summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMike Frysinger <vapier@google.com>2021-05-05 15:53:03 -0400
committerMike Frysinger <vapier@google.com>2021-05-06 19:09:16 +0000
commit5291eafa412117b80ebbf645fc51559dd0b2caaf (patch)
treef92dd1030f36cbf8aaa3c208bee7b94cd9c72927
parent8e768eaaa722a99405f6542ac718880c8c22f060 (diff)
downloadgit-repo-5291eafa412117b80ebbf645fc51559dd0b2caaf.tar.gz
ssh: move all ssh logic to a common place
We had ssh logic sprinkled between two git modules, and neither was quite the right home for it. This largely moves the logic as-is to its new home. We'll leave major refactoring to followup commits. Bug: https://crbug.com/gerrit/12389 Change-Id: I300a8f7dba74f2bd132232a5eb1e856a8490e0e9 Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305483 Reviewed-by: Chris Mcdonald <cjmcdonald@google.com> Tested-by: Mike Frysinger <vapier@google.com>
-rw-r--r--git_command.py91
-rw-r--r--git_config.py156
-rwxr-xr-xmain.py7
-rw-r--r--ssh.py257
-rw-r--r--tests/test_git_command.py32
-rw-r--r--tests/test_ssh.py52
6 files changed, 320 insertions, 275 deletions
diff --git a/git_command.py b/git_command.py
index f8cb280c..fabad0e0 100644
--- a/git_command.py
+++ b/git_command.py
@@ -14,16 +14,14 @@
14 14
15import functools 15import functools
16import os 16import os
17import re
18import sys 17import sys
19import subprocess 18import subprocess
20import tempfile
21from signal import SIGTERM
22 19
23from error import GitError 20from error import GitError
24from git_refs import HEAD 21from git_refs import HEAD
25import platform_utils 22import platform_utils
26from repo_trace import REPO_TRACE, IsTrace, Trace 23from repo_trace import REPO_TRACE, IsTrace, Trace
24import ssh
27from wrapper import Wrapper 25from wrapper import Wrapper
28 26
29GIT = 'git' 27GIT = 'git'
@@ -43,85 +41,6 @@ GIT_DIR = 'GIT_DIR'
43LAST_GITDIR = None 41LAST_GITDIR = None
44LAST_CWD = None 42LAST_CWD = None
45 43
46_ssh_proxy_path = None
47_ssh_sock_path = None
48_ssh_clients = []
49
50
51def _run_ssh_version():
52 """run ssh -V to display the version number"""
53 return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
54
55
56def _parse_ssh_version(ver_str=None):
57 """parse a ssh version string into a tuple"""
58 if ver_str is None:
59 ver_str = _run_ssh_version()
60 m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
61 if m:
62 return tuple(int(x) for x in m.group(1).split('.'))
63 else:
64 return ()
65
66
67@functools.lru_cache(maxsize=None)
68def ssh_version():
69 """return ssh version as a tuple"""
70 try:
71 return _parse_ssh_version()
72 except subprocess.CalledProcessError:
73 print('fatal: unable to detect ssh version', file=sys.stderr)
74 sys.exit(1)
75
76
77def ssh_sock(create=True):
78 global _ssh_sock_path
79 if _ssh_sock_path is None:
80 if not create:
81 return None
82 tmp_dir = '/tmp'
83 if not os.path.exists(tmp_dir):
84 tmp_dir = tempfile.gettempdir()
85 if ssh_version() < (6, 7):
86 tokens = '%r@%h:%p'
87 else:
88 tokens = '%C' # hash of %l%h%p%r
89 _ssh_sock_path = os.path.join(
90 tempfile.mkdtemp('', 'ssh-', tmp_dir),
91 'master-' + tokens)
92 return _ssh_sock_path
93
94
95def _ssh_proxy():
96 global _ssh_proxy_path
97 if _ssh_proxy_path is None:
98 _ssh_proxy_path = os.path.join(
99 os.path.dirname(__file__),
100 'git_ssh')
101 return _ssh_proxy_path
102
103
104def _add_ssh_client(p):
105 _ssh_clients.append(p)
106
107
108def _remove_ssh_client(p):
109 try:
110 _ssh_clients.remove(p)
111 except ValueError:
112 pass
113
114
115def terminate_ssh_clients():
116 global _ssh_clients
117 for p in _ssh_clients:
118 try:
119 os.kill(p.pid, SIGTERM)
120 p.wait()
121 except OSError:
122 pass
123 _ssh_clients = []
124
125 44
126class _GitCall(object): 45class _GitCall(object):
127 @functools.lru_cache(maxsize=None) 46 @functools.lru_cache(maxsize=None)
@@ -256,8 +175,8 @@ class GitCommand(object):
256 if disable_editor: 175 if disable_editor:
257 env['GIT_EDITOR'] = ':' 176 env['GIT_EDITOR'] = ':'
258 if ssh_proxy: 177 if ssh_proxy:
259 env['REPO_SSH_SOCK'] = ssh_sock() 178 env['REPO_SSH_SOCK'] = ssh.sock()
260 env['GIT_SSH'] = _ssh_proxy() 179 env['GIT_SSH'] = ssh.proxy()
261 env['GIT_SSH_VARIANT'] = 'ssh' 180 env['GIT_SSH_VARIANT'] = 'ssh'
262 if 'http_proxy' in env and 'darwin' == sys.platform: 181 if 'http_proxy' in env and 'darwin' == sys.platform:
263 s = "'http.proxy=%s'" % (env['http_proxy'],) 182 s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -340,7 +259,7 @@ class GitCommand(object):
340 raise GitError('%s: %s' % (command[1], e)) 259 raise GitError('%s: %s' % (command[1], e))
341 260
342 if ssh_proxy: 261 if ssh_proxy:
343 _add_ssh_client(p) 262 ssh.add_client(p)
344 263
345 self.process = p 264 self.process = p
346 if input: 265 if input:
@@ -352,7 +271,7 @@ class GitCommand(object):
352 try: 271 try:
353 self.stdout, self.stderr = p.communicate() 272 self.stdout, self.stderr = p.communicate()
354 finally: 273 finally:
355 _remove_ssh_client(p) 274 ssh.remove_client(p)
356 self.rc = p.wait() 275 self.rc = p.wait()
357 276
358 @staticmethod 277 @staticmethod
diff --git a/git_config.py b/git_config.py
index fcd0446c..1d8d1363 100644
--- a/git_config.py
+++ b/git_config.py
@@ -18,25 +18,17 @@ from http.client import HTTPException
18import json 18import json
19import os 19import os
20import re 20import re
21import signal
22import ssl 21import ssl
23import subprocess 22import subprocess
24import sys 23import sys
25try:
26 import threading as _threading
27except ImportError:
28 import dummy_threading as _threading
29import time
30import urllib.error 24import urllib.error
31import urllib.request 25import urllib.request
32 26
33from error import GitError, UploadError 27from error import GitError, UploadError
34import platform_utils 28import platform_utils
35from repo_trace import Trace 29from repo_trace import Trace
36 30import ssh
37from git_command import GitCommand 31from git_command import GitCommand
38from git_command import ssh_sock
39from git_command import terminate_ssh_clients
40from git_refs import R_CHANGES, R_HEADS, R_TAGS 32from git_refs import R_CHANGES, R_HEADS, R_TAGS
41 33
42ID_RE = re.compile(r'^[0-9a-f]{40}$') 34ID_RE = re.compile(r'^[0-9a-f]{40}$')
@@ -440,129 +432,6 @@ class RefSpec(object):
440 return s 432 return s
441 433
442 434
443_master_processes = []
444_master_keys = set()
445_ssh_master = True
446_master_keys_lock = None
447
448
449def init_ssh():
450 """Should be called once at the start of repo to init ssh master handling.
451
452 At the moment, all we do is to create our lock.
453 """
454 global _master_keys_lock
455 assert _master_keys_lock is None, "Should only call init_ssh once"
456 _master_keys_lock = _threading.Lock()
457
458
459def _open_ssh(host, port=None):
460 global _ssh_master
461
462 # Bail before grabbing the lock if we already know that we aren't going to
463 # try creating new masters below.
464 if sys.platform in ('win32', 'cygwin'):
465 return False
466
467 # Acquire the lock. This is needed to prevent opening multiple masters for
468 # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
469 # manifest <remote fetch="ssh://xyz"> specifies a different host from the
470 # one that was passed to repo init.
471 _master_keys_lock.acquire()
472 try:
473
474 # Check to see whether we already think that the master is running; if we
475 # think it's already running, return right away.
476 if port is not None:
477 key = '%s:%s' % (host, port)
478 else:
479 key = host
480
481 if key in _master_keys:
482 return True
483
484 if not _ssh_master or 'GIT_SSH' in os.environ:
485 # Failed earlier, so don't retry.
486 return False
487
488 # We will make two calls to ssh; this is the common part of both calls.
489 command_base = ['ssh',
490 '-o', 'ControlPath %s' % ssh_sock(),
491 host]
492 if port is not None:
493 command_base[1:1] = ['-p', str(port)]
494
495 # Since the key wasn't in _master_keys, we think that master isn't running.
496 # ...but before actually starting a master, we'll double-check. This can
497 # be important because we can't tell that that 'git@myhost.com' is the same
498 # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
499 check_command = command_base + ['-O', 'check']
500 try:
501 Trace(': %s', ' '.join(check_command))
502 check_process = subprocess.Popen(check_command,
503 stdout=subprocess.PIPE,
504 stderr=subprocess.PIPE)
505 check_process.communicate() # read output, but ignore it...
506 isnt_running = check_process.wait()
507
508 if not isnt_running:
509 # Our double-check found that the master _was_ infact running. Add to
510 # the list of keys.
511 _master_keys.add(key)
512 return True
513 except Exception:
514 # Ignore excpetions. We we will fall back to the normal command and print
515 # to the log there.
516 pass
517
518 command = command_base[:1] + ['-M', '-N'] + command_base[1:]
519 try:
520 Trace(': %s', ' '.join(command))
521 p = subprocess.Popen(command)
522 except Exception as e:
523 _ssh_master = False
524 print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
525 % (host, port, str(e)), file=sys.stderr)
526 return False
527
528 time.sleep(1)
529 ssh_died = (p.poll() is not None)
530 if ssh_died:
531 return False
532
533 _master_processes.append(p)
534 _master_keys.add(key)
535 return True
536 finally:
537 _master_keys_lock.release()
538
539
540def close_ssh():
541 global _master_keys_lock
542
543 terminate_ssh_clients()
544
545 for p in _master_processes:
546 try:
547 os.kill(p.pid, signal.SIGTERM)
548 p.wait()
549 except OSError:
550 pass
551 del _master_processes[:]
552 _master_keys.clear()
553
554 d = ssh_sock(create=False)
555 if d:
556 try:
557 platform_utils.rmdir(os.path.dirname(d))
558 except OSError:
559 pass
560
561 # We're done with the lock, so we can delete it.
562 _master_keys_lock = None
563
564
565URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
566URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') 435URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
567 436
568 437
@@ -614,27 +483,6 @@ def GetUrlCookieFile(url, quiet):
614 yield cookiefile, None 483 yield cookiefile, None
615 484
616 485
617def _preconnect(url):
618 m = URI_ALL.match(url)
619 if m:
620 scheme = m.group(1)
621 host = m.group(2)
622 if ':' in host:
623 host, port = host.split(':')
624 else:
625 port = None
626 if scheme in ('ssh', 'git+ssh', 'ssh+git'):
627 return _open_ssh(host, port)
628 return False
629
630 m = URI_SCP.match(url)
631 if m:
632 host = m.group(1)
633 return _open_ssh(host)
634
635 return False
636
637
638class Remote(object): 486class Remote(object):
639 """Configuration options related to a remote. 487 """Configuration options related to a remote.
640 """ 488 """
@@ -673,7 +521,7 @@ class Remote(object):
673 521
674 def PreConnectFetch(self): 522 def PreConnectFetch(self):
675 connectionUrl = self._InsteadOf() 523 connectionUrl = self._InsteadOf()
676 return _preconnect(connectionUrl) 524 return ssh.preconnect(connectionUrl)
677 525
678 def ReviewUrl(self, userEmail, validate_certs): 526 def ReviewUrl(self, userEmail, validate_certs):
679 if self._review_url is None: 527 if self._review_url is None:
diff --git a/main.py b/main.py
index 8aba2ec2..96744335 100755
--- a/main.py
+++ b/main.py
@@ -39,7 +39,7 @@ from color import SetDefaultColoring
39import event_log 39import event_log
40from repo_trace import SetTrace 40from repo_trace import SetTrace
41from git_command import user_agent 41from git_command import user_agent
42from git_config import init_ssh, close_ssh, RepoConfig 42from git_config import RepoConfig
43from git_trace2_event_log import EventLog 43from git_trace2_event_log import EventLog
44from command import InteractiveCommand 44from command import InteractiveCommand
45from command import MirrorSafeCommand 45from command import MirrorSafeCommand
@@ -56,6 +56,7 @@ from error import RepoChangedException
56import gitc_utils 56import gitc_utils
57from manifest_xml import GitcClient, RepoClient 57from manifest_xml import GitcClient, RepoClient
58from pager import RunPager, TerminatePager 58from pager import RunPager, TerminatePager
59import ssh
59from wrapper import WrapperPath, Wrapper 60from wrapper import WrapperPath, Wrapper
60 61
61from subcmds import all_commands 62from subcmds import all_commands
@@ -592,7 +593,7 @@ def _Main(argv):
592 repo = _Repo(opt.repodir) 593 repo = _Repo(opt.repodir)
593 try: 594 try:
594 try: 595 try:
595 init_ssh() 596 ssh.init()
596 init_http() 597 init_http()
597 name, gopts, argv = repo._ParseArgs(argv) 598 name, gopts, argv = repo._ParseArgs(argv)
598 run = lambda: repo._Run(name, gopts, argv) or 0 599 run = lambda: repo._Run(name, gopts, argv) or 0
@@ -604,7 +605,7 @@ def _Main(argv):
604 else: 605 else:
605 result = run() 606 result = run()
606 finally: 607 finally:
607 close_ssh() 608 ssh.close()
608 except KeyboardInterrupt: 609 except KeyboardInterrupt:
609 print('aborted by user', file=sys.stderr) 610 print('aborted by user', file=sys.stderr)
610 result = 1 611 result = 1
diff --git a/ssh.py b/ssh.py
new file mode 100644
index 00000000..d06c4eb2
--- /dev/null
+++ b/ssh.py
@@ -0,0 +1,257 @@
1# Copyright (C) 2008 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Common SSH management logic."""
16
17import functools
18import os
19import re
20import signal
21import subprocess
22import sys
23import tempfile
24try:
25 import threading as _threading
26except ImportError:
27 import dummy_threading as _threading
28import time
29
30import platform_utils
31from repo_trace import Trace
32
33
34_ssh_proxy_path = None
35_ssh_sock_path = None
36_ssh_clients = []
37
38
39def _run_ssh_version():
40 """run ssh -V to display the version number"""
41 return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
42
43
44def _parse_ssh_version(ver_str=None):
45 """parse a ssh version string into a tuple"""
46 if ver_str is None:
47 ver_str = _run_ssh_version()
48 m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
49 if m:
50 return tuple(int(x) for x in m.group(1).split('.'))
51 else:
52 return ()
53
54
55@functools.lru_cache(maxsize=None)
56def version():
57 """return ssh version as a tuple"""
58 try:
59 return _parse_ssh_version()
60 except subprocess.CalledProcessError:
61 print('fatal: unable to detect ssh version', file=sys.stderr)
62 sys.exit(1)
63
64
65def proxy():
66 global _ssh_proxy_path
67 if _ssh_proxy_path is None:
68 _ssh_proxy_path = os.path.join(
69 os.path.dirname(__file__),
70 'git_ssh')
71 return _ssh_proxy_path
72
73
74def add_client(p):
75 _ssh_clients.append(p)
76
77
78def remove_client(p):
79 try:
80 _ssh_clients.remove(p)
81 except ValueError:
82 pass
83
84
85def _terminate_clients():
86 global _ssh_clients
87 for p in _ssh_clients:
88 try:
89 os.kill(p.pid, signal.SIGTERM)
90 p.wait()
91 except OSError:
92 pass
93 _ssh_clients = []
94
95
96_master_processes = []
97_master_keys = set()
98_ssh_master = True
99_master_keys_lock = None
100
101
102def init():
103 """Should be called once at the start of repo to init ssh master handling.
104
105 At the moment, all we do is to create our lock.
106 """
107 global _master_keys_lock
108 assert _master_keys_lock is None, "Should only call init once"
109 _master_keys_lock = _threading.Lock()
110
111
112def _open_ssh(host, port=None):
113 global _ssh_master
114
115 # Bail before grabbing the lock if we already know that we aren't going to
116 # try creating new masters below.
117 if sys.platform in ('win32', 'cygwin'):
118 return False
119
120 # Acquire the lock. This is needed to prevent opening multiple masters for
121 # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
122 # manifest <remote fetch="ssh://xyz"> specifies a different host from the
123 # one that was passed to repo init.
124 _master_keys_lock.acquire()
125 try:
126
127 # Check to see whether we already think that the master is running; if we
128 # think it's already running, return right away.
129 if port is not None:
130 key = '%s:%s' % (host, port)
131 else:
132 key = host
133
134 if key in _master_keys:
135 return True
136
137 if not _ssh_master or 'GIT_SSH' in os.environ:
138 # Failed earlier, so don't retry.
139 return False
140
141 # We will make two calls to ssh; this is the common part of both calls.
142 command_base = ['ssh',
143 '-o', 'ControlPath %s' % sock(),
144 host]
145 if port is not None:
146 command_base[1:1] = ['-p', str(port)]
147
148 # Since the key wasn't in _master_keys, we think that master isn't running.
149 # ...but before actually starting a master, we'll double-check. This can
150 # be important because we can't tell that that 'git@myhost.com' is the same
151 # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
152 check_command = command_base + ['-O', 'check']
153 try:
154 Trace(': %s', ' '.join(check_command))
155 check_process = subprocess.Popen(check_command,
156 stdout=subprocess.PIPE,
157 stderr=subprocess.PIPE)
158 check_process.communicate() # read output, but ignore it...
159 isnt_running = check_process.wait()
160
161 if not isnt_running:
162 # Our double-check found that the master _was_ infact running. Add to
163 # the list of keys.
164 _master_keys.add(key)
165 return True
166 except Exception:
167 # Ignore excpetions. We we will fall back to the normal command and print
168 # to the log there.
169 pass
170
171 command = command_base[:1] + ['-M', '-N'] + command_base[1:]
172 try:
173 Trace(': %s', ' '.join(command))
174 p = subprocess.Popen(command)
175 except Exception as e:
176 _ssh_master = False
177 print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
178 % (host, port, str(e)), file=sys.stderr)
179 return False
180
181 time.sleep(1)
182 ssh_died = (p.poll() is not None)
183 if ssh_died:
184 return False
185
186 _master_processes.append(p)
187 _master_keys.add(key)
188 return True
189 finally:
190 _master_keys_lock.release()
191
192
193def close():
194 global _master_keys_lock
195
196 _terminate_clients()
197
198 for p in _master_processes:
199 try:
200 os.kill(p.pid, signal.SIGTERM)
201 p.wait()
202 except OSError:
203 pass
204 del _master_processes[:]
205 _master_keys.clear()
206
207 d = sock(create=False)
208 if d:
209 try:
210 platform_utils.rmdir(os.path.dirname(d))
211 except OSError:
212 pass
213
214 # We're done with the lock, so we can delete it.
215 _master_keys_lock = None
216
217
218URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
219URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
220
221
222def preconnect(url):
223 m = URI_ALL.match(url)
224 if m:
225 scheme = m.group(1)
226 host = m.group(2)
227 if ':' in host:
228 host, port = host.split(':')
229 else:
230 port = None
231 if scheme in ('ssh', 'git+ssh', 'ssh+git'):
232 return _open_ssh(host, port)
233 return False
234
235 m = URI_SCP.match(url)
236 if m:
237 host = m.group(1)
238 return _open_ssh(host)
239
240 return False
241
242def sock(create=True):
243 global _ssh_sock_path
244 if _ssh_sock_path is None:
245 if not create:
246 return None
247 tmp_dir = '/tmp'
248 if not os.path.exists(tmp_dir):
249 tmp_dir = tempfile.gettempdir()
250 if version() < (6, 7):
251 tokens = '%r@%h:%p'
252 else:
253 tokens = '%C' # hash of %l%h%p%r
254 _ssh_sock_path = os.path.join(
255 tempfile.mkdtemp('', 'ssh-', tmp_dir),
256 'master-' + tokens)
257 return _ssh_sock_path
diff --git a/tests/test_git_command.py b/tests/test_git_command.py
index 76c092f4..93300a6f 100644
--- a/tests/test_git_command.py
+++ b/tests/test_git_command.py
@@ -26,38 +26,6 @@ import git_command
26import wrapper 26import wrapper
27 27
28 28
29class SSHUnitTest(unittest.TestCase):
30 """Tests the ssh functions."""
31
32 def test_parse_ssh_version(self):
33 """Check parse_ssh_version() handling."""
34 ver = git_command._parse_ssh_version('Unknown\n')
35 self.assertEqual(ver, ())
36 ver = git_command._parse_ssh_version('OpenSSH_1.0\n')
37 self.assertEqual(ver, (1, 0))
38 ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
39 self.assertEqual(ver, (6, 6, 1))
40 ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
41 self.assertEqual(ver, (7, 6))
42
43 def test_ssh_version(self):
44 """Check ssh_version() handling."""
45 with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'):
46 self.assertEqual(git_command.ssh_version(), (1, 2))
47
48 def test_ssh_sock(self):
49 """Check ssh_sock() function."""
50 with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
51 # old ssh version uses port
52 with mock.patch('git_command.ssh_version', return_value=(6, 6)):
53 self.assertTrue(git_command.ssh_sock().endswith('%p'))
54 git_command._ssh_sock_path = None
55 # new ssh version uses hash
56 with mock.patch('git_command.ssh_version', return_value=(6, 7)):
57 self.assertTrue(git_command.ssh_sock().endswith('%C'))
58 git_command._ssh_sock_path = None
59
60
61class GitCallUnitTest(unittest.TestCase): 29class GitCallUnitTest(unittest.TestCase):
62 """Tests the _GitCall class (via git_command.git).""" 30 """Tests the _GitCall class (via git_command.git)."""
63 31
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
new file mode 100644
index 00000000..5a4f27e4
--- /dev/null
+++ b/tests/test_ssh.py
@@ -0,0 +1,52 @@
1# Copyright 2019 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the ssh.py module."""
16
17import unittest
18from unittest import mock
19
20import ssh
21
22
23class SshTests(unittest.TestCase):
24 """Tests the ssh functions."""
25
26 def test_parse_ssh_version(self):
27 """Check _parse_ssh_version() handling."""
28 ver = ssh._parse_ssh_version('Unknown\n')
29 self.assertEqual(ver, ())
30 ver = ssh._parse_ssh_version('OpenSSH_1.0\n')
31 self.assertEqual(ver, (1, 0))
32 ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
33 self.assertEqual(ver, (6, 6, 1))
34 ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
35 self.assertEqual(ver, (7, 6))
36
37 def test_version(self):
38 """Check version() handling."""
39 with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
40 self.assertEqual(ssh.version(), (1, 2))
41
42 def test_ssh_sock(self):
43 """Check sock() function."""
44 with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
45 # old ssh version uses port
46 with mock.patch('ssh.version', return_value=(6, 6)):
47 self.assertTrue(ssh.sock().endswith('%p'))
48 ssh._ssh_sock_path = None
49 # new ssh version uses hash
50 with mock.patch('ssh.version', return_value=(6, 7)):
51 self.assertTrue(ssh.sock().endswith('%C'))
52 ssh._ssh_sock_path = None