From 1e80d3414f2ee17871d948affba76f5b390262e8 Mon Sep 17 00:00:00 2001
From: Nikos Skalkotos <skalkoto@grnet.gr>
Date: Thu, 3 Jul 2014 19:37:14 +0300
Subject: [PATCH] windows: Uninstall the old VirtIO drivers

When installing new drivers, make sure the old ones get uninstalled.
---
 image_creator/os_type/windows/__init__.py   | 168 ++++++++++++--------
 image_creator/os_type/windows/powershell.py |   3 +
 2 files changed, 105 insertions(+), 66 deletions(-)

diff --git a/image_creator/os_type/windows/__init__.py b/image_creator/os_type/windows/__init__.py
index 665eae3..d3ba2a1 100644
--- a/image_creator/os_type/windows/__init__.py
+++ b/image_creator/os_type/windows/__init__.py
@@ -24,7 +24,7 @@ from image_creator.os_type.windows.vm import VM, RANDOM_TOKEN as TOKEN
 from image_creator.os_type.windows.registry import Registry
 from image_creator.os_type.windows.winexe import WinEXE
 from image_creator.os_type.windows.powershell import DRVINST_HEAD, SAFEBOOT, \
-    DRVINST_TAIL
+    DRVINST_TAIL, DRVUNINST
 
 import tempfile
 import re
@@ -118,6 +118,56 @@ def virtio_dir_check(dirname):
     return dirname
 
 
+def parse_inf(inf):
+    """Parse the content of a Windows INF file and fetch all information found
+    in the Version section.
+    """
+
+    version = {}  # The 'Version' section
+    strings = {}  # The 'Strings' section
+    section = ""
+    current = None
+    prev_line = ""
+    for line in iter(inf):
+        line = prev_line + line.strip().split(';')[0].strip()
+        prev_line = ""
+
+        if not len(line):
+            continue
+
+        if line[-1] == "\\":
+            prev_line = line
+            continue
+
+        # Does the line denote a section?
+        if line.startswith('[') and line.endswith(']'):
+            section = line[1:-1].strip().lower()
+            if section == 'version':
+                current = version
+            if section == 'strings':
+                current = strings
+
+        # We only care about 'version' and 'string' sections
+        if section not in ('version', 'strings'):
+            continue
+
+        # We only care about param = value lines
+        if line.find('=') < 0:
+            continue
+
+        param, value = line.split('=', 1)
+        current[param.strip()] = value.strip()
+
+    # Replace all strkey tokens with their actual value
+    for k, v in version.items():
+        if v.startswith('%') and v.endswith('%'):
+            strkey = v[1:-1]
+            if strkey in strings:
+                version[k] = strings[strkey]
+
+    return version
+
+
 DESCR = {
     "boot_timeout":
     "Time in seconds to wait for the Windows customization VM to boot.",
@@ -563,84 +613,63 @@ class Windows(OSBase):
         raise FatalError("Connection to the Windows VM failed after %d retries"
                          % retries)
 
-    def _virtio_state(self):
-        """Check if the virtio drivers are install and return the information
-        about the installed driver
+    def _virtio_state(self, directory=None):
+        """Returns information about the VirtIO drivers found either in a
+        directory or the media itself if the directory is None.
         """
 
-        inf_path = self.image.g.case_sensitive_path("%s/inf" % self.systemroot)
-
         state = {}
         for driver in VIRTIO:
             state[driver] = {}
 
-        def parse_inf(filename):
-            """Parse a Windows INF file and fetch all information found in the
-            Version section.
-            """
-            version = {}  # The 'Version' section
-            strings = {}  # The 'Strings' section
-            section = ""
-            current = None
-            prev_line = ""
-            fullpath = "%s/%s" % (inf_path, filename)
-            for line in self.image.g.cat(fullpath).splitlines():
-                line = prev_line + line.strip().split(';')[0].strip()
-                prev_line = ""
-
-                if not len(line):
+        def oem_files():
+            # Read oem*.inf files under \Windows\Inf\ directory
+            path = self.image.g.case_sensitive_path("%s/inf" % self.systemroot)
+            oem = re.compile(r'^oem\d+\.inf', flags=re.IGNORECASE)
+            for f in self.image.g.readdir(path):
+                name = f['name']
+                if not oem.match(name):
                     continue
+                yield name, \
+                    self.image.g.cat("%s/%s" % (path, name)).splitlines()
+
+        def local_files():
+            # Read *.inf files under a local directory
+            assert os.path.isdir(directory)
+            inf = re.compile(r'^.+\.inf', flags=re.IGNORECASE)
+            for name in os.listdir(directory):
+                fullpath = os.path.join(directory, name)
+                if inf.match(name) and os.path.isfile(fullpath):
+                    with open(fullpath, 'r') as f:
+                        yield name, f
+
+        for name, txt in oem_files() if directory is None else local_files():
+            content = parse_inf(txt)
+            cat = content['CatalogFile'] if 'CatalogFile' in content else ""
 
-                if line[-1] == "\\":
-                    prev_line = line
-                    continue
-
-                # Does the line denote a section?
-                if line.startswith('[') and line.endswith(']'):
-                    section = line[1:-1].lower()
-                    if section == 'version':
-                        current = version
-                    if section == 'strings':
-                        current = strings
-
-                # We only care about 'version' and 'string' sections
-                if section not in ('version', 'strings'):
-                    continue
-
-                # We only care about param = value lines
-                if line.find('=') < 0:
-                    continue
-
-                param, value = line.split('=', 1)
-                current[param.strip()] = value.strip()
-
-            # Replace all strkey tokens with their actual value
-            for k, v in version.items():
-                if v.startswith('%') and v.endswith('%'):
-                    strkey = v[1:-1]
-                    if strkey in strings:
-                        version[k] = strings[strkey]
-
-            cat = version['CatalogFile'] if 'CatalogFile' in version else ""
             for driver in VIRTIO:
                 if cat.lower() == "%s.cat" % driver:
-                    state[driver][filename] = version
-
-        oem = re.compile(r'^oem\d+\.inf', flags=re.IGNORECASE)
-        for f in self.image.g.readdir(inf_path):
-            if oem.match(f['name']):
-                parse_inf(f['name'])
+                    state[driver][name] = content
 
         return state
 
-    def install_virtio_drivers(self):
-        """Install the virtio drivers on the media"""
+    def install_virtio_drivers(self, upgrade=True):
+        """Install new VirtIO drivers in the input media. If upgrade is True,
+        then the old drivers found in the media will be removed.
+        """
 
         dirname = self.sysprep_params['virtio'].value
         if not dirname:
             raise FatalError('No directory hosting the VirtIO drivers defined')
 
-        self.out.output('Installing virtio drivers:')
+        new_drvs = self._virtio_state(dirname)
+        for k, v in new_drvs.items():
+            if len(v) == 0:
+                del new_drvs[k]
+
+        assert len(new_drvs)
+
+        self.out.output('Installing VirtIO drivers:')
 
         with self.mount(readonly=False, silent=True):
 
@@ -649,7 +678,13 @@ class Windows(OSBase):
             self.registry.enable_autologon(admin)
             self._upload_virtio_drivers(dirname)
 
-            drvs_install = DRVINST_HEAD.replace('\n', '\r\n')
+            drvs_install = DRVINST_HEAD
+
+            if upgrade:
+                # Add code to remove the old drivers
+                for drv in new_drvs:
+                    for oem in self.virtio_state[drv]:
+                        drvs_install += DRVUNINST % oem
 
             if self.check_version(6, 1) <= 0:
                 self._install_viostor_driver(dirname)
@@ -657,13 +692,14 @@ class Windows(OSBase):
                 # In newer windows, in order to reduce the boot process the
                 # boot drivers are cached. To be able to boot with viostor, we
                 # need to reboot in safe mode.
-                drvs_install += SAFEBOOT.replace('\n', '\r\n')
+                drvs_install += SAFEBOOT
 
-            drvs_install += DRVINST_TAIL.replace('\n', '\r\n')
+            drvs_install += DRVINST_TAIL
 
             remotedir = self.image.g.case_sensitive_path("%s/VirtIO" %
                                                          self.systemroot)
-            self.image.g.write(remotedir + "/InstallDrivers.ps1", drvs_install)
+            self.image.g.write(remotedir + "/InstallDrivers.ps1",
+                               drvs_install.replace('\n', '\r\n'))
 
             cmd = (
                 '%(drive)s:%(root)s\\System32\\WindowsPowerShell\\v1.0\\'
@@ -695,7 +731,7 @@ class Windows(OSBase):
             if not self.vm.wait_on_serial(timeout):
                 raise FatalError("Windows VM booting timed out!")
             self.out.success('done')
-            self.out.output("Performing the drivers installation ...", False)
+            self.out.output("Installing new drivers ...", False)
             if not self.vm.wait_on_serial(virtio_timeout):
                 raise FatalError("Windows VirtIO installation timed out!")
             self.out.success('done')
diff --git a/image_creator/os_type/windows/powershell.py b/image_creator/os_type/windows/powershell.py
index 909fbfc..3acd8f9 100644
--- a/image_creator/os_type/windows/powershell.py
+++ b/image_creator/os_type/windows/powershell.py
@@ -72,4 +72,7 @@ New-ItemProperty `
     -Value 'cmd /q /c "bcdedit /deletevalue safeboot & shutdown /s /t 0"'
 
 """
+
+DRVUNINST = 'pnputil.exe -f -d %s\n'
+
 # vim: set sta sts=4 shiftwidth=4 sw=4 et ai :
-- 
GitLab