diff options
-rw-r--r-- | git_command.py | 36 | ||||
-rw-r--r-- | tests/test_git_command.py | 27 |
2 files changed, 62 insertions, 1 deletions
diff --git a/git_command.py b/git_command.py index a2782151..1cb8f1aa 100644 --- a/git_command.py +++ b/git_command.py | |||
@@ -16,6 +16,7 @@ | |||
16 | 16 | ||
17 | from __future__ import print_function | 17 | from __future__ import print_function |
18 | import os | 18 | import os |
19 | import re | ||
19 | import sys | 20 | import sys |
20 | import subprocess | 21 | import subprocess |
21 | import tempfile | 22 | import tempfile |
@@ -47,6 +48,35 @@ LAST_CWD = None | |||
47 | _ssh_proxy_path = None | 48 | _ssh_proxy_path = None |
48 | _ssh_sock_path = None | 49 | _ssh_sock_path = None |
49 | _ssh_clients = [] | 50 | _ssh_clients = [] |
51 | _ssh_version = None | ||
52 | |||
53 | |||
54 | def _run_ssh_version(): | ||
55 | """run ssh -V to display the version number""" | ||
56 | return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() | ||
57 | |||
58 | |||
59 | def _parse_ssh_version(ver_str=None): | ||
60 | """parse a ssh version string into a tuple""" | ||
61 | if ver_str is None: | ||
62 | ver_str = _run_ssh_version() | ||
63 | m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) | ||
64 | if m: | ||
65 | return tuple(int(x) for x in m.group(1).split('.')) | ||
66 | else: | ||
67 | return () | ||
68 | |||
69 | |||
70 | def ssh_version(): | ||
71 | """return ssh version as a tuple""" | ||
72 | global _ssh_version | ||
73 | if _ssh_version is None: | ||
74 | try: | ||
75 | _ssh_version = _parse_ssh_version() | ||
76 | except subprocess.CalledProcessError: | ||
77 | print('fatal: unable to detect ssh version', file=sys.stderr) | ||
78 | sys.exit(1) | ||
79 | return _ssh_version | ||
50 | 80 | ||
51 | 81 | ||
52 | def ssh_sock(create=True): | 82 | def ssh_sock(create=True): |
@@ -57,9 +87,13 @@ def ssh_sock(create=True): | |||
57 | tmp_dir = '/tmp' | 87 | tmp_dir = '/tmp' |
58 | if not os.path.exists(tmp_dir): | 88 | if not os.path.exists(tmp_dir): |
59 | tmp_dir = tempfile.gettempdir() | 89 | tmp_dir = tempfile.gettempdir() |
90 | if ssh_version() < (6, 7): | ||
91 | tokens = '%r@%h:%p' | ||
92 | else: | ||
93 | tokens = '%C' # hash of %l%h%p%r | ||
60 | _ssh_sock_path = os.path.join( | 94 | _ssh_sock_path = os.path.join( |
61 | tempfile.mkdtemp('', 'ssh-', tmp_dir), | 95 | tempfile.mkdtemp('', 'ssh-', tmp_dir), |
62 | 'master-%r@%h:%p') | 96 | 'master-' + tokens) |
63 | return _ssh_sock_path | 97 | return _ssh_sock_path |
64 | 98 | ||
65 | 99 | ||
diff --git a/tests/test_git_command.py b/tests/test_git_command.py index c2d3f1df..2c22b250 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py | |||
@@ -30,6 +30,33 @@ import git_command | |||
30 | import wrapper | 30 | import wrapper |
31 | 31 | ||
32 | 32 | ||
33 | class SSHUnitTest(unittest.TestCase): | ||
34 | """Tests the ssh functions.""" | ||
35 | |||
36 | def test_ssh_version(self): | ||
37 | """Check ssh_version() handling.""" | ||
38 | ver = git_command._parse_ssh_version('Unknown\n') | ||
39 | self.assertEqual(ver, ()) | ||
40 | ver = git_command._parse_ssh_version('OpenSSH_1.0\n') | ||
41 | self.assertEqual(ver, (1, 0)) | ||
42 | ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') | ||
43 | self.assertEqual(ver, (6, 6, 1)) | ||
44 | ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') | ||
45 | self.assertEqual(ver, (7, 6)) | ||
46 | |||
47 | def test_ssh_sock(self): | ||
48 | """Check ssh_sock() function.""" | ||
49 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): | ||
50 | # old ssh version uses port | ||
51 | with mock.patch('git_command.ssh_version', return_value=(6, 6)): | ||
52 | self.assertTrue(git_command.ssh_sock().endswith('%p')) | ||
53 | git_command._ssh_sock_path = None | ||
54 | # new ssh version uses hash | ||
55 | with mock.patch('git_command.ssh_version', return_value=(6, 7)): | ||
56 | self.assertTrue(git_command.ssh_sock().endswith('%C')) | ||
57 | git_command._ssh_sock_path = None | ||
58 | |||
59 | |||
33 | class GitCallUnitTest(unittest.TestCase): | 60 | class GitCallUnitTest(unittest.TestCase): |
34 | """Tests the _GitCall class (via git_command.git).""" | 61 | """Tests the _GitCall class (via git_command.git).""" |
35 | 62 | ||