diff --git a/lib/utils.py b/lib/utils.py index a71cd1f92b8db241437d797df6034bfb8c24b692..68ac6a6fd4fd61e721a1170f1f7073dc997303d4 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -815,6 +815,47 @@ def RemoveEtcHostsEntry(file_name, hostname): raise +def _SplitKnownHostsHosts(hosts): + """Parses the first field of a known_hosts file. + + TODO: Support other formats. + """ + return hosts.split(',') + + +def AddKnownHost(file_name, hostname, pubkey): + """Adds a new known host to a known_hosts file. + + """ + f = open(file_name, 'a+') + try: + nl = True + for line in f: + fields = line.split() + if (len(fields) < 3 or + fields[0].startswith('#') or + fields[1] != 'ssh-rsa'): + continue + hosts = _SplitKnownHostsHosts(fields[0]) + if hostname in hosts and fields[2] == pubkey: + break + nl = line.endswith('\n') + else: + if not nl: + f.write("\n") + f.write(hostname) + f.write(' ssh-rsa ') + f.write(pubkey) + f.write("\n") + f.flush() + finally: + f.close() + + +def RemoveKnownHost(file_name, hostname): + pass + + def CreateBackup(file_name): """Creates a backup of a file. diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index ab9956a330887a6432fdd2e2e3dd256375e80820..73250c5bb4e2b5247a3dfbf3e219053493a1007f 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -38,7 +38,8 @@ from ganeti.utils import IsProcessAlive, Lock, Unlock, RunCmd, \ RemoveFile, CheckDict, MatchNameComponent, FormatUnit, \ ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \ ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \ - AddEtcHostsEntry, RemoveEtcHostsEntry + AddEtcHostsEntry, RemoveEtcHostsEntry, \ + AddKnownHost, RemoveKnownHost from ganeti.errors import LockError, UnitParseError @@ -518,6 +519,51 @@ class TestEtcHosts(unittest.TestCase): os.unlink(tmpname) +class TestKnownHosts(unittest.TestCase): + """Test functions modifying known_hosts files""" + + def writeTestFile(self): + (fd, tmpname) = tempfile.mkstemp(prefix = 'ganeti-test') + f = os.fdopen(fd, 'w') + try: + f.write('node1.tld,node1\tssh-rsa AAAA1234567890=\n') + f.write('node2,node2.tld ssh-rsa AAAA1234567890=\n') + finally: + f.close() + + return tmpname + + def testAddingNewHost(self): + tmpname = self.writeTestFile() + try: + AddKnownHost(tmpname, 'node3.tld', 'AAAA0987654321=') + + f = open(tmpname, 'r') + try: + self.assertEqual(md5.new(f.read(8192)).hexdigest(), + '86cf3c7c7983a3bd5c475c4c1a3e5678') + finally: + f.close() + finally: + os.unlink(tmpname) + + def testAddingOldHost(self): + tmpname = self.writeTestFile() + try: + AddKnownHost(tmpname, 'node2.tld', 'AAAA0987654321=') + + f = open(tmpname, 'r') + try: + os.system("vim %s" % tmpname) + self.assertEqual(md5.new(f.read(8192)).hexdigest(), + '86cf3c7c7983a3bd5c475c4c1a3e5678') + finally: + f.close() + finally: + os.unlink(tmpname) + + + class TestShellQuoting(unittest.TestCase): """Test case for shell quoting functions"""