From 704b890addc6be09dd878f4f7cff6a8ba6b4fd4d Mon Sep 17 00:00:00 2001
From: Nikos Skalkotos <skalkoto@grnet.gr>
Date: Fri, 4 Jul 2014 17:33:08 +0300
Subject: [PATCH] windows: Make Windows INF parsing more accurate

 * Detect the VirtIO driver type by the Device ID it refers to and not
   the catalog file name
 * Check if a driver is suitable for the input media by examining the
   TargetOSVersion field
---
 image_creator/os_type/windows/__init__.py | 206 +++++++++++++++-------
 1 file changed, 140 insertions(+), 66 deletions(-)

diff --git a/image_creator/os_type/windows/__init__.py b/image_creator/os_type/windows/__init__.py
index d3ba2a1..3fc2559 100644
--- a/image_creator/os_type/windows/__init__.py
+++ b/image_creator/os_type/windows/__init__.py
@@ -86,55 +86,42 @@ KMS_CLIENT_SETUP_KEYS = {
     "Windows Server 2008 for Itanium-Based Systems":
     "4DWFP-JF3DJ-B7DTH-78FJB-PDRHK"}
 
-VIRTIO = (
-    "viostor",  # "VirtIO SCSI controller"
-    "vioscsi",  # "VirtIO SCSI pass-through controller"
-    "vioser",   # "VirtIO Serial Driver"
-    "netkvm",   # "VirtIO Ethernet Adapter"
-    "balloon",  # "VirtIO Balloon Driver
-    "viorng")   # "VirtIO RNG Driver"
-
-
-def virtio_dir_check(dirname):
-    """Check if the needed virtio driver files are present in the dirname
-    directory
-    """
-    if not dirname:
-        return ""  # value not set
-
-    ext = ('cat', 'inf', 'sys')
-
-    # Check files in a case insensitive manner
-    files = set([f.lower() for f in os.listdir(dirname)])
-
-    found = False
-    for cat, inf, sys in [["%s.%s" % (b, e) for e in ext] for b in VIRTIO]:
-        if cat in files and inf in files and sys in files:
-            found = True
-
-    if not found:
-        raise ValueError("Invalid VirtIO directory. No VirtIO driver found")
-
-    return dirname
+# The PCI Device ID for VirtIO devices. 1af4 is the Vendor ID for Red Hat, Inc
+VIRTIO_DEVICE_ID = re.compile(r'pci\\ven_1af4&dev_100[0-5]')
+VIRTIO = (      # id    Name
+    "netkvm",   # 1000	Virtio network device
+    "viostor",  # 1001	Virtio block device
+    "balloon",  # 1002	Virtio memory balloon
+    "vioser",   # 1003	Virtio console
+    "vioscsi",  # 1004	Virtio SCSI
+    "viorng")   # 1005	Virtio RNG
 
 
 def parse_inf(inf):
     """Parse the content of a Windows INF file and fetch all information found
-    in the Version section.
+    in the Version section, the target OS as well as the VirtIO drivers it
+    defines.
+
+    For more info check here:
+        http://msdn.microsoft.com/en-us/library/windows/hardware/ff549520
     """
 
-    version = {}  # The 'Version' section
-    strings = {}  # The 'Strings' section
-    section = ""
-    current = None
+    driver = None
+    target_os = set()
+
+    sections = {}
+    current = {}
+
     prev_line = ""
     for line in iter(inf):
-        line = prev_line + line.strip().split(';')[0].strip()
+        # Strip comments
+        line = prev_line + line.split(';')[0].strip()
         prev_line = ""
 
         if not len(line):
             continue
 
+        # Does the directive span more lines?
         if line[-1] == "\\":
             prev_line = line
             continue
@@ -142,30 +129,80 @@ def parse_inf(inf):
         # Does the line denote a section?
         if line.startswith('[') and line.endswith(']'):
             section = line[1:-1].strip().lower()
-            if section == 'version':
-                current = version
-            if section == 'strings':
-                current = strings
-
-        # We only care about 'version' and 'string' sections
-        if section not in ('version', 'strings'):
+            if section not in sections:
+                current = {}
+                sections[section] = current
+            else:
+                current = sections[section]
             continue
 
         # We only care about param = value lines
-        if line.find('=') < 0:
-            continue
+        if line.find('=') > 0:
+            param, value = line.split('=', 1)
+            current[param.strip()] = value.strip()
+
+    models = []
+    if 'manufacturer' in sections:
+        for value in sections['manufacturer'].values():
+            value = value.split(',')
+            if len(value) == 0:
+                continue
+
+            # %strkey%=models-section-name [,TargetOSVersion] ...
+            models.append(value[0].strip().lower())
+            for i in range(len(value) - 1):
+                target_os.add(value[i+1].strip().lower())
+
+    if len(models):
+        # [models-section-name] | [models-section-name.TargetOSVersion]
+        models_section_name = re.compile('^(' + "|".join(models) + ')(\..+)?$')
+        for model in [s for s in sections if models_section_name.match(s)]:
+            for value in sections[model].values():
+                value = value.split(',')
+                if len(value) == 1:
+                    continue
+                # The second value in a device-description entry is always the
+                # hardware ID:
+                #   install-section-name[,hw-id][,compatible-id...]
+                hw_id = value[1].strip().lower()
+                # If this matches a VirtIO device, then this is a VirtIO driver
+                id_match = VIRTIO_DEVICE_ID.match(hw_id)
+                if id_match:
+                    driver = VIRTIO[int(id_match.group(0)[-1])]
 
-        param, value = line.split('=', 1)
-        current[param.strip()] = value.strip()
+    if 'version' in sections and 'strings' in sections:
+        # Replace all strkey tokens with their actual value
+        for k, v in sections['version'].items():
+            if v.startswith('%') and v.endswith('%'):
+                strkey = v[1:-1]
+                if strkey in sections['strings']:
+                    sections['version'][k] = sections['strings'][strkey]
 
-    # Replace all strkey tokens with their actual value
-    for k, v in version.items():
-        if v.startswith('%') and v.endswith('%'):
-            strkey = v[1:-1]
-            if strkey in strings:
-                version[k] = strings[strkey]
+    if len(target_os) == 0:
+        target_os.add('ntx86')
 
-    return version
+    version = sections['version'] if 'version' in sections else {}
+
+    return driver, target_os, version
+
+
+def virtio_dir_check(dirname):
+    """Check if the needed virtio driver files are present in the dirname
+    directory
+    """
+    if not dirname:
+        return ""  # value not set
+
+    # Check files in a case insensitive manner
+    files = set(os.listdir(dirname))
+
+    for inf in [f for f in files if f.lower().endswith('.inf')]:
+        with open(os.path.join(dirname, inf)) as f:
+            driver, _, _ = parse_inf(f)
+            if driver:
+                return dirname
+
+    raise ValueError("Invalid VirtIO directory. No VirtIO driver found")
 
 
 DESCR = {
@@ -233,6 +270,16 @@ class Windows(OSBase):
             self.out.output("Checking media state ...", False)
             self.sysprepped = self.registry.get_setup_state() > 0
             self.virtio_state = self._virtio_state()
+            arch = self.image.g.inspect_get_arch(self.root)
+            if arch == 'x86_64':
+                arch = 'amd64'
+            elif arch == 'i386':
+                arch = 'x86'
+            major = self.image.g.inspect_get_major_version(self.root)
+            minor = self.image.g.inspect_get_minor_version(self.root)
+            # This is the OS version as defined in INF files to check if a
+            # driver is valid for this OS.
+            self.windows_version = "nt%s.%s.%s" % (arch, major, minor)
             self.out.success("done")
 
         # If the image is sysprepped no driver mappings will be present.
@@ -617,7 +664,6 @@ class Windows(OSBase):
         """Returns information about the VirtIO drivers found either in a
         directory or the media itself if the directory is None.
         """
-
         state = {}
         for driver in VIRTIO:
             state[driver] = {}
@@ -644,15 +690,45 @@ class Windows(OSBase):
                         yield name, f
 
         for name, txt in oem_files() if directory is None else local_files():
-            content = parse_inf(txt)
-            cat = content['CatalogFile'] if 'CatalogFile' in content else ""
+            driver, target, content = parse_inf(txt)
 
-            for driver in VIRTIO:
-                if cat.lower() == "%s.cat" % driver:
-                    state[driver][name] = content
+            if driver:
+                content['TargetOSVersions'] = target
+                state[driver][name] = content
 
         return state
 
+    def _fetch_virtio_drivers(self, dirname):
+        """Examines a directory for VirtIO drivers and returns only the drivers
+        that are suitable for this media.
+        """
+        collection = self._virtio_state(dirname)
+
+        self.out.output('Checking new drivers:')
+        for drv_type, drvs in collection.items():
+            for inf, content in drvs.items():
+                found_match = False
+                # Check if the driver is suitable for the input media
+                for target in content['TargetOSVersions']:
+                    if len(target) > len(self.windows_version):
+                        match = target.startswith(self.windows_version)
+                    else:
+                        match = self.windows_version.startswith(target)
+                    if match:
+                        found_match = True
+
+                if not found_match:  # Wrong Target
+                    self.out.warn(
+                        'Ignoring %s. Not suitable for this OS version.' % inf)
+                    del collection[drv_type][inf]
+                else:
+                    self.out.output('Found %s driver: %s' % (drv_type, inf))
+
+            if len(drvs) == 0:
+                del collection[drv_type]
+
+        return collection
+
     def install_virtio_drivers(self, upgrade=True):
         """Install new VirtIO drivers in the input media. If upgrade is True,
         then the old drivers found in the media will be removed.
@@ -662,15 +738,13 @@ class Windows(OSBase):
         if not dirname:
             raise FatalError('No directory hosting the VirtIO drivers defined')
 
-        new_drvs = self._virtio_state(dirname)
-        for k, v in new_drvs.items():
-            if len(v) == 0:
-                del new_drvs[k]
+        new_drvs = self._fetch_virtio_drivers(dirname)
 
-        assert len(new_drvs)
+        if not len(new_drvs):
+            self.out.warn('No suitable driver found to install!')
+            return
 
         self.out.output('Installing VirtIO drivers:')
-
         with self.mount(readonly=False, silent=True):
 
             admin = self.sysprep_params['admin'].value
-- 
GitLab