diff options
| -rw-r--r-- | git_command.py | 32 | ||||
| -rw-r--r-- | tests/test_git_command.py | 9 |
2 files changed, 20 insertions, 21 deletions
diff --git a/git_command.py b/git_command.py index d06fc77c..f8cb280c 100644 --- a/git_command.py +++ b/git_command.py | |||
| @@ -12,6 +12,7 @@ | |||
| 12 | # See the License for the specific language governing permissions and | 12 | # See the License for the specific language governing permissions and |
| 13 | # limitations under the License. | 13 | # limitations under the License. |
| 14 | 14 | ||
| 15 | import functools | ||
| 15 | import os | 16 | import os |
| 16 | import re | 17 | import re |
| 17 | import sys | 18 | import sys |
| @@ -45,7 +46,6 @@ LAST_CWD = None | |||
| 45 | _ssh_proxy_path = None | 46 | _ssh_proxy_path = None |
| 46 | _ssh_sock_path = None | 47 | _ssh_sock_path = None |
| 47 | _ssh_clients = [] | 48 | _ssh_clients = [] |
| 48 | _ssh_version = None | ||
| 49 | 49 | ||
| 50 | 50 | ||
| 51 | def _run_ssh_version(): | 51 | def _run_ssh_version(): |
| @@ -64,16 +64,14 @@ def _parse_ssh_version(ver_str=None): | |||
| 64 | return () | 64 | return () |
| 65 | 65 | ||
| 66 | 66 | ||
| 67 | @functools.lru_cache(maxsize=None) | ||
| 67 | def ssh_version(): | 68 | def ssh_version(): |
| 68 | """return ssh version as a tuple""" | 69 | """return ssh version as a tuple""" |
| 69 | global _ssh_version | 70 | try: |
| 70 | if _ssh_version is None: | 71 | return _parse_ssh_version() |
| 71 | try: | 72 | except subprocess.CalledProcessError: |
| 72 | _ssh_version = _parse_ssh_version() | 73 | print('fatal: unable to detect ssh version', file=sys.stderr) |
| 73 | except subprocess.CalledProcessError: | 74 | sys.exit(1) |
| 74 | print('fatal: unable to detect ssh version', file=sys.stderr) | ||
| 75 | sys.exit(1) | ||
| 76 | return _ssh_version | ||
| 77 | 75 | ||
| 78 | 76 | ||
| 79 | def ssh_sock(create=True): | 77 | def ssh_sock(create=True): |
| @@ -125,18 +123,14 @@ def terminate_ssh_clients(): | |||
| 125 | _ssh_clients = [] | 123 | _ssh_clients = [] |
| 126 | 124 | ||
| 127 | 125 | ||
| 128 | _git_version = None | ||
| 129 | |||
| 130 | |||
| 131 | class _GitCall(object): | 126 | class _GitCall(object): |
| 127 | @functools.lru_cache(maxsize=None) | ||
| 132 | def version_tuple(self): | 128 | def version_tuple(self): |
| 133 | global _git_version | 129 | ret = Wrapper().ParseGitVersion() |
| 134 | if _git_version is None: | 130 | if ret is None: |
| 135 | _git_version = Wrapper().ParseGitVersion() | 131 | print('fatal: unable to detect git version', file=sys.stderr) |
| 136 | if _git_version is None: | 132 | sys.exit(1) |
| 137 | print('fatal: unable to detect git version', file=sys.stderr) | 133 | return ret |
| 138 | sys.exit(1) | ||
| 139 | return _git_version | ||
| 140 | 134 | ||
| 141 | def __getattr__(self, name): | 135 | def __getattr__(self, name): |
| 142 | name = name.replace('_', '-') | 136 | name = name.replace('_', '-') |
diff --git a/tests/test_git_command.py b/tests/test_git_command.py index 912a9dbe..76c092f4 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py | |||
| @@ -29,8 +29,8 @@ import wrapper | |||
| 29 | class SSHUnitTest(unittest.TestCase): | 29 | class SSHUnitTest(unittest.TestCase): |
| 30 | """Tests the ssh functions.""" | 30 | """Tests the ssh functions.""" |
| 31 | 31 | ||
| 32 | def test_ssh_version(self): | 32 | def test_parse_ssh_version(self): |
| 33 | """Check ssh_version() handling.""" | 33 | """Check parse_ssh_version() handling.""" |
| 34 | ver = git_command._parse_ssh_version('Unknown\n') | 34 | ver = git_command._parse_ssh_version('Unknown\n') |
| 35 | self.assertEqual(ver, ()) | 35 | self.assertEqual(ver, ()) |
| 36 | ver = git_command._parse_ssh_version('OpenSSH_1.0\n') | 36 | ver = git_command._parse_ssh_version('OpenSSH_1.0\n') |
| @@ -40,6 +40,11 @@ class SSHUnitTest(unittest.TestCase): | |||
| 40 | ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') | 40 | ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') |
| 41 | self.assertEqual(ver, (7, 6)) | 41 | self.assertEqual(ver, (7, 6)) |
| 42 | 42 | ||
| 43 | def test_ssh_version(self): | ||
| 44 | """Check ssh_version() handling.""" | ||
| 45 | with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'): | ||
| 46 | self.assertEqual(git_command.ssh_version(), (1, 2)) | ||
| 47 | |||
| 43 | def test_ssh_sock(self): | 48 | def test_ssh_sock(self): |
| 44 | """Check ssh_sock() function.""" | 49 | """Check ssh_sock() function.""" |
| 45 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): | 50 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): |
