author     Mike Frysinger <vapier@google.com>   2021-05-06 00:44:42 -0400
committer  Mike Frysinger <vapier@google.com>   2021-05-10 21:16:06 +0000
commit     339f2df1ddd741070e340ec01d6882dd1eee617c (patch)
tree       d16fe7c87ba966a400d545bef5a49c460b75dc57
parent     19e409c81863878d5d313fdc40b3975b98602454 (diff)
ssh: rewrite proxy management for multiprocessing usage  (tag: v2.15)
We changed sync to use multiprocessing for parallel work. This broke
the ssh proxy code as it's all based on threads. Rewrite the logic to
be multiprocessing safe.

Now instead of the module acting as a stateful object, callers have to
instantiate a new ProxyManager class that holds all the state, and pass
that down to any users.

Bug: https://crbug.com/gerrit/12389
Change-Id: I4b1af116f7306b91e825d3c56fb4274c9b033562
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305486
Tested-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
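In practice, the new calling convention is what subcmds/sync.py adopts below: build a ProxyManager on top of a multiprocessing.Manager, and hand it to anything that may spawn ssh. A minimal sketch of that flow (the fetch work itself is elided):

    import multiprocessing

    import ssh

    with multiprocessing.Manager() as manager:
      with ssh.ProxyManager(manager) as ssh_proxy:
        # Create the socket dir once in the parent so children share it.
        ssh_proxy.sock()
        # Hand |ssh_proxy| down to workers; exiting the context kills any
        # leftover ssh clients & masters and removes the socket dir.
        ...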
-rw-r--r--  git_command.py      12
-rw-r--r--  git_config.py       11
-rw-r--r--  project.py           4
-rw-r--r--  ssh.py             282
-rw-r--r--  subcmds/sync.py     44
-rw-r--r--  tests/test_ssh.py   30
6 files changed, 225 insertions(+), 158 deletions(-)
diff --git a/git_command.py b/git_command.py
index fabad0e0..04953f38 100644
--- a/git_command.py
+++ b/git_command.py
@@ -21,7 +21,6 @@ from error import GitError
 from git_refs import HEAD
 import platform_utils
 from repo_trace import REPO_TRACE, IsTrace, Trace
-import ssh
 from wrapper import Wrapper
 
 GIT = 'git'
@@ -167,7 +166,7 @@ class GitCommand(object):
                capture_stderr=False,
                merge_output=False,
                disable_editor=False,
-               ssh_proxy=False,
+               ssh_proxy=None,
                cwd=None,
                gitdir=None):
     env = self._GetBasicEnv()
@@ -175,8 +174,8 @@ class GitCommand(object):
     if disable_editor:
       env['GIT_EDITOR'] = ':'
     if ssh_proxy:
-      env['REPO_SSH_SOCK'] = ssh.sock()
-      env['GIT_SSH'] = ssh.proxy()
+      env['REPO_SSH_SOCK'] = ssh_proxy.sock()
+      env['GIT_SSH'] = ssh_proxy.proxy
       env['GIT_SSH_VARIANT'] = 'ssh'
     if 'http_proxy' in env and 'darwin' == sys.platform:
       s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -259,7 +258,7 @@ class GitCommand(object):
       raise GitError('%s: %s' % (command[1], e))
 
     if ssh_proxy:
-      ssh.add_client(p)
+      ssh_proxy.add_client(p)
 
     self.process = p
     if input:
@@ -271,7 +270,8 @@ class GitCommand(object):
     try:
       self.stdout, self.stderr = p.communicate()
     finally:
-      ssh.remove_client(p)
+      if ssh_proxy:
+        ssh_proxy.remove_client(p)
     self.rc = p.wait()
 
   @staticmethod
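For callers of GitCommand, `ssh_proxy` is now an object (or None, the new default) rather than a bool. A hypothetical invocation, assuming a `project` object and an active ProxyManager `proxy` (real call sites pass many more options):

    # Hypothetical call site; |project| and |proxy| are assumed to exist.
    # With ssh_proxy=None, the env tweaks and client tracking are skipped.
    cmd = GitCommand(project, ['fetch', 'origin'], ssh_proxy=proxy)
    if cmd.Wait() != 0:
      print('fetch failed', file=sys.stderr)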
diff --git a/git_config.py b/git_config.py
index d7fef8ca..978f6a59 100644
--- a/git_config.py
+++ b/git_config.py
@@ -27,7 +27,6 @@ import urllib.request
 from error import GitError, UploadError
 import platform_utils
 from repo_trace import Trace
-import ssh
 from git_command import GitCommand
 from git_refs import R_CHANGES, R_HEADS, R_TAGS
 
@@ -519,17 +518,23 @@ class Remote(object):
 
     return self.url.replace(longest, longestUrl, 1)
 
-  def PreConnectFetch(self):
+  def PreConnectFetch(self, ssh_proxy):
     """Run any setup for this remote before we connect to it.
 
    In practice, if the remote is using SSH, we'll attempt to create a new
    SSH master session to it for reuse across projects.
 
+    Args:
+      ssh_proxy: The SSH settings for managing master sessions.
+
    Returns:
      Whether the preconnect phase for this remote was successful.
    """
+    if not ssh_proxy:
+      return True
+
    connectionUrl = self._InsteadOf()
-    return ssh.preconnect(connectionUrl)
+    return ssh_proxy.preconnect(connectionUrl)
 
   def ReviewUrl(self, userEmail, validate_certs):
     if self._review_url is None:
diff --git a/project.py b/project.py
index 37558061..2f83d796 100644
--- a/project.py
+++ b/project.py
@@ -2045,8 +2045,8 @@ class Project(object):
     name = self.remote.name
 
     remote = self.GetRemote(name)
-    if not remote.PreConnectFetch():
-      ssh_proxy = False
+    if not remote.PreConnectFetch(ssh_proxy):
+      ssh_proxy = None
 
     if initial:
       if alt_dir and 'objects' == os.path.basename(alt_dir):
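The two hunks above degrade gracefully together: PreConnectFetch() now reports success when no proxy is in play at all, while a failed preconnect simply drops the proxy instead of aborting the fetch. A sketch of the resulting control flow, assuming a `remote` and an optional `ssh_proxy`:

    # Sketch of the fallback path; |remote| and |ssh_proxy| are assumed.
    if not remote.PreConnectFetch(ssh_proxy):
      # Masters are broken or unsupported here; fetch over plain ssh.
      # None (not False) keeps the type consistent with GitCommand's default.
      ssh_proxy = None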
diff --git a/ssh.py b/ssh.py
index d06c4eb2..0ae8d120 100644
--- a/ssh.py
+++ b/ssh.py
@@ -15,25 +15,20 @@
 """Common SSH management logic."""
 
 import functools
+import multiprocessing
 import os
 import re
 import signal
 import subprocess
 import sys
 import tempfile
-try:
-  import threading as _threading
-except ImportError:
-  import dummy_threading as _threading
 import time
 
 import platform_utils
 from repo_trace import Trace
 
 
-_ssh_proxy_path = None
-_ssh_sock_path = None
-_ssh_clients = []
+PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh')
 
 
 def _run_ssh_version():
@@ -62,68 +57,104 @@ def version():
     sys.exit(1)
 
 
-def proxy():
-  global _ssh_proxy_path
-  if _ssh_proxy_path is None:
-    _ssh_proxy_path = os.path.join(
-        os.path.dirname(__file__),
-        'git_ssh')
-  return _ssh_proxy_path
+URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
+URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
 
 
-def add_client(p):
-  _ssh_clients.append(p)
+class ProxyManager:
+  """Manage various ssh clients & masters that we spawn.
 
+  This will take care of sharing state between multiprocessing children, and
+  make sure that if we crash, we don't leak any of the ssh sessions.
 
-def remove_client(p):
-  try:
-    _ssh_clients.remove(p)
-  except ValueError:
-    pass
-
+  The code should work with a single-process scenario too, and not add too much
+  overhead due to the manager.
+  """
 
-def _terminate_clients():
-  global _ssh_clients
-  for p in _ssh_clients:
+  # Path to the ssh program to run which will pass our master settings along.
+  # Set here more as a convenience API.
+  proxy = PROXY_PATH
+
+  def __init__(self, manager):
+    # Protect access to the list of active masters.
+    self._lock = multiprocessing.Lock()
+    # List of active masters (pid). These will be spawned on demand, and we are
+    # responsible for shutting them all down at the end.
+    self._masters = manager.list()
+    # Set of active masters indexed by "host:port" information.
+    # The value isn't used, but multiprocessing doesn't provide a set class.
+    self._master_keys = manager.dict()
+    # Whether ssh masters are known to be broken, so we give up entirely.
+    self._master_broken = manager.Value('b', False)
+    # List of active ssh sessions. Clients will be added & removed as
+    # connections finish, so this list is just for safety & cleanup if we crash.
+    self._clients = manager.list()
+    # Path to directory for holding master sockets.
+    self._sock_path = None
+
+  def __enter__(self):
+    """Enter a new context."""
+    return self
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    """Exit a context & clean up all resources."""
+    self.close()
+
+  def add_client(self, proc):
+    """Track a new ssh session."""
+    self._clients.append(proc.pid)
+
+  def remove_client(self, proc):
+    """Remove a completed ssh session."""
     try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
+      self._clients.remove(proc.pid)
+    except ValueError:
       pass
-  _ssh_clients = []
-
-
-_master_processes = []
-_master_keys = set()
-_ssh_master = True
-_master_keys_lock = None
-
-
-def init():
-  """Should be called once at the start of repo to init ssh master handling.
-
-  At the moment, all we do is to create our lock.
-  """
-  global _master_keys_lock
-  assert _master_keys_lock is None, "Should only call init once"
-  _master_keys_lock = _threading.Lock()
-
-
-def _open_ssh(host, port=None):
-  global _ssh_master
-
-  # Bail before grabbing the lock if we already know that we aren't going to
-  # try creating new masters below.
-  if sys.platform in ('win32', 'cygwin'):
-    return False
-
-  # Acquire the lock. This is needed to prevent opening multiple masters for
-  # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
-  # manifest <remote fetch="ssh://xyz"> specifies a different host from the
-  # one that was passed to repo init.
-  _master_keys_lock.acquire()
-  try:
 
+  def add_master(self, proc):
+    """Track a new master connection."""
+    self._masters.append(proc.pid)
+
+  def _terminate(self, procs):
+    """Kill all |procs|."""
+    for pid in procs:
+      try:
+        os.kill(pid, signal.SIGTERM)
+        os.waitpid(pid, 0)
+      except OSError:
+        pass
+
+    # The multiprocessing.list() API doesn't provide many standard list()
+    # methods, so we have to manually clear the list.
+    while True:
+      try:
+        procs.pop(0)
+      except:
+        break
+
+  def close(self):
+    """Close this active ssh session.
+
+    Kill all ssh clients & masters we created, and nuke the socket dir.
+    """
+    self._terminate(self._clients)
+    self._terminate(self._masters)
+
+    d = self.sock(create=False)
+    if d:
+      try:
+        platform_utils.rmdir(os.path.dirname(d))
+      except OSError:
+        pass
+
+  def _open_unlocked(self, host, port=None):
+    """Make sure a ssh master session exists for |host| & |port|.
+
+    If one doesn't exist already, we'll create it.
+
+    We won't grab any locks, so the caller has to do that. This helps keep the
+    business logic of actually creating the master separate from grabbing locks.
+    """
     # Check to see whether we already think that the master is running; if we
     # think it's already running, return right away.
     if port is not None:
@@ -131,17 +162,15 @@ def _open_ssh(host, port=None):
     else:
       key = host
 
-    if key in _master_keys:
+    if key in self._master_keys:
       return True
 
-    if not _ssh_master or 'GIT_SSH' in os.environ:
+    if self._master_broken.value or 'GIT_SSH' in os.environ:
       # Failed earlier, so don't retry.
       return False
 
     # We will make two calls to ssh; this is the common part of both calls.
-    command_base = ['ssh',
-                    '-o', 'ControlPath %s' % sock(),
-                    host]
+    command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host]
     if port is not None:
       command_base[1:1] = ['-p', str(port)]
 
@@ -161,7 +190,7 @@ def _open_ssh(host, port=None):
       if not isnt_running:
        # Our double-check found that the master _was_ in fact running. Add to
        # the list of keys.
-        _master_keys.add(key)
+        self._master_keys[key] = True
        return True
     except Exception:
      # Ignore exceptions. We will fall back to the normal command and print
@@ -173,7 +202,7 @@ def _open_ssh(host, port=None):
      Trace(': %s', ' '.join(command))
      p = subprocess.Popen(command)
     except Exception as e:
-      _ssh_master = False
+      self._master_broken.value = True
      print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
            % (host, port, str(e)), file=sys.stderr)
      return False
@@ -183,75 +212,66 @@
     if ssh_died:
       return False
 
-    _master_processes.append(p)
-    _master_keys.add(key)
+    self.add_master(p)
+    self._master_keys[key] = True
     return True
-  finally:
-    _master_keys_lock.release()
 
+  def _open(self, host, port=None):
+    """Make sure a ssh master session exists for |host| & |port|.
 
-def close():
-  global _master_keys_lock
-
-  _terminate_clients()
-
-  for p in _master_processes:
-    try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  del _master_processes[:]
-  _master_keys.clear()
-
-  d = sock(create=False)
-  if d:
-    try:
-      platform_utils.rmdir(os.path.dirname(d))
-    except OSError:
-      pass
-
-  # We're done with the lock, so we can delete it.
-  _master_keys_lock = None
+    If one doesn't exist already, we'll create it.
 
+    This will obtain any necessary locks to avoid inter-process races.
+    """
+    # Bail before grabbing the lock if we already know that we aren't going to
+    # try creating new masters below.
+    if sys.platform in ('win32', 'cygwin'):
+      return False
 
-URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
-URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
+    # Acquire the lock. This is needed to prevent opening multiple masters for
+    # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
+    # manifest <remote fetch="ssh://xyz"> specifies a different host from the
+    # one that was passed to repo init.
+    with self._lock:
+      return self._open_unlocked(host, port)
+
+  def preconnect(self, url):
+    """If |uri| will create a ssh connection, setup the ssh master for it."""
+    m = URI_ALL.match(url)
+    if m:
+      scheme = m.group(1)
+      host = m.group(2)
+      if ':' in host:
+        host, port = host.split(':')
+      else:
+        port = None
+      if scheme in ('ssh', 'git+ssh', 'ssh+git'):
+        return self._open(host, port)
+      return False
 
+    m = URI_SCP.match(url)
+    if m:
+      host = m.group(1)
+      return self._open(host)
 
-def preconnect(url):
-  m = URI_ALL.match(url)
-  if m:
-    scheme = m.group(1)
-    host = m.group(2)
-    if ':' in host:
-      host, port = host.split(':')
-    else:
-      port = None
-    if scheme in ('ssh', 'git+ssh', 'ssh+git'):
-      return _open_ssh(host, port)
     return False
 
-  m = URI_SCP.match(url)
-  if m:
-    host = m.group(1)
-    return _open_ssh(host)
-
-  return False
-
-def sock(create=True):
-  global _ssh_sock_path
-  if _ssh_sock_path is None:
-    if not create:
-      return None
-    tmp_dir = '/tmp'
-    if not os.path.exists(tmp_dir):
-      tmp_dir = tempfile.gettempdir()
-    if version() < (6, 7):
-      tokens = '%r@%h:%p'
-    else:
-      tokens = '%C'  # hash of %l%h%p%r
-    _ssh_sock_path = os.path.join(
-        tempfile.mkdtemp('', 'ssh-', tmp_dir),
-        'master-' + tokens)
-  return _ssh_sock_path
+  def sock(self, create=True):
+    """Return the path to the ssh socket dir.
+
+    This has all the master sockets so clients can talk to them.
+    """
+    if self._sock_path is None:
+      if not create:
+        return None
+      tmp_dir = '/tmp'
+      if not os.path.exists(tmp_dir):
+        tmp_dir = tempfile.gettempdir()
+      if version() < (6, 7):
+        tokens = '%r@%h:%p'
+      else:
+        tokens = '%C'  # hash of %l%h%p%r
+      self._sock_path = os.path.join(
+          tempfile.mkdtemp('', 'ssh-', tmp_dir),
+          'master-' + tokens)
+    return self._sock_path
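The module keeps the old URI_SCP/URI_ALL patterns, so preconnect() still accepts both URI-style and scp-style ssh remotes. A quick illustration of the routing (the hosts are made up):

    import re

    URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
    URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')

    # URI style: scheme://[user@]host[:port]/path
    m = URI_ALL.match('ssh://user@review.example.com:29418/repo')
    print(m.group(1), m.group(2))  # -> ssh user@review.example.com:29418

    # scp style: [user@]host:path -- no scheme, so only the host is captured.
    m = URI_SCP.match('user@example.com:path/to/repo.git')
    print(m.group(1))  # -> user@example.com

preconnect() only opens a master for the 'ssh', 'git+ssh', and 'ssh+git' schemes; anything else (https, git, etc.) falls through and returns False.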
diff --git a/subcmds/sync.py b/subcmds/sync.py
index 28568062..fb25c221 100644
--- a/subcmds/sync.py
+++ b/subcmds/sync.py
@@ -358,7 +358,7 @@ later is required to fix a server side protocol bug.
           optimized_fetch=opt.optimized_fetch,
           retry_fetches=opt.retry_fetches,
           prune=opt.prune,
-          ssh_proxy=True,
+          ssh_proxy=self.ssh_proxy,
           clone_filter=self.manifest.CloneFilter,
           partial_clone_exclude=self.manifest.PartialCloneExclude)
 
@@ -380,7 +380,11 @@ later is required to fix a server side protocol bug.
     finish = time.time()
     return (success, project, start, finish)
 
-  def _Fetch(self, projects, opt, err_event):
+  @classmethod
+  def _FetchInitChild(cls, ssh_proxy):
+    cls.ssh_proxy = ssh_proxy
+
+  def _Fetch(self, projects, opt, err_event, ssh_proxy):
     ret = True
 
     jobs = opt.jobs_network if opt.jobs_network else self.jobs
@@ -410,8 +414,14 @@ later is required to fix a server side protocol bug.
            break
      return ret
 
+    # We pass the ssh proxy settings via the class. This allows multiprocessing
+    # to pickle it up when spawning children. We can't pass it as an argument
+    # to _FetchProjectList below as multiprocessing is unable to pickle those.
+    Sync.ssh_proxy = None
+
    # NB: Multiprocessing is heavy, so don't spin it up for one job.
    if len(projects_list) == 1 or jobs == 1:
+      self._FetchInitChild(ssh_proxy)
      if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list):
        ret = False
    else:
@@ -429,7 +439,8 @@ later is required to fix a server side protocol bug.
      else:
        pm.update(inc=0, msg='warming up')
      chunksize = 4
-      with multiprocessing.Pool(jobs) as pool:
+      with multiprocessing.Pool(
+          jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool:
        results = pool.imap_unordered(
            functools.partial(self._FetchProjectList, opt),
            projects_list,
@@ -438,6 +449,11 @@ later is required to fix a server side protocol bug.
          ret = False
      pool.close()
 
+    # Cleanup the reference now that we're done with it, and we're going to
+    # release any resources it points to. If we don't, later multiprocessing
+    # usage (e.g. checkouts) will try to pickle and then crash.
+    del Sync.ssh_proxy
+
    pm.end()
    self._fetch_times.Save()
 
@@ -447,7 +463,7 @@ later is required to fix a server side protocol bug.
    return (ret, fetched)
 
   def _FetchMain(self, opt, args, all_projects, err_event, manifest_name,
-                 load_local_manifests):
+                 load_local_manifests, ssh_proxy):
    """The main network fetch loop.
 
    Args:
@@ -457,6 +473,7 @@ later is required to fix a server side protocol bug.
      err_event: Whether an error was hit while processing.
      manifest_name: Manifest file to be reloaded.
      load_local_manifests: Whether to load local manifests.
+      ssh_proxy: SSH manager for clients & masters.
    """
    rp = self.manifest.repoProject
 
@@ -467,7 +484,7 @@ later is required to fix a server side protocol bug.
    to_fetch.extend(all_projects)
    to_fetch.sort(key=self._fetch_times.Get, reverse=True)
 
-    success, fetched = self._Fetch(to_fetch, opt, err_event)
+    success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy)
    if not success:
      err_event.set()
 
@@ -498,7 +515,7 @@ later is required to fix a server side protocol bug.
        if previously_missing_set == missing_set:
          break
        previously_missing_set = missing_set
-        success, new_fetched = self._Fetch(missing, opt, err_event)
+        success, new_fetched = self._Fetch(missing, opt, err_event, ssh_proxy)
        if not success:
          err_event.set()
        fetched.update(new_fetched)
@@ -985,12 +1002,15 @@ later is required to fix a server side protocol bug.
 
    self._fetch_times = _FetchTimes(self.manifest)
    if not opt.local_only:
-      try:
-        ssh.init()
-        self._FetchMain(opt, args, all_projects, err_event, manifest_name,
-                        load_local_manifests)
-      finally:
-        ssh.close()
+      with multiprocessing.Manager() as manager:
+        with ssh.ProxyManager(manager) as ssh_proxy:
+          # Initialize the socket dir once in the parent.
+          ssh_proxy.sock()
+          self._FetchMain(opt, args, all_projects, err_event, manifest_name,
+                          load_local_manifests, ssh_proxy)
+
+    if opt.network_only:
+      return
 
    # If we saw an error, exit with code 1 so that other scripts can check.
    if err_event.is_set():
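The initializer/initargs dance above is the standard way to hand state to pool workers exactly once, instead of pickling it along with every task. A self-contained sketch of the same pattern (names here are illustrative, not repo code):

    import multiprocessing

    _state = None

    def _init_worker(state):
      # Runs once in each worker; stashes the shared object globally.
      global _state
      _state = state

    def _work(item):
      # Each task reads the state installed by the initializer.
      return '%s %d' % (_state['greeting'], item)

    if __name__ == '__main__':
      with multiprocessing.Manager() as manager:
        shared = manager.dict(greeting='hello')
        with multiprocessing.Pool(
            2, initializer=_init_worker, initargs=(shared,)) as pool:
          for result in pool.imap_unordered(_work, range(4)):
            print(result)  # e.g. 'hello 2'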
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
index 5a4f27e4..ffb5cb94 100644
--- a/tests/test_ssh.py
+++ b/tests/test_ssh.py
@@ -14,6 +14,8 @@
 
 """Unittests for the ssh.py module."""
 
+import multiprocessing
+import subprocess
 import unittest
 from unittest import mock
 
@@ -39,14 +41,34 @@ class SshTests(unittest.TestCase):
    with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
      self.assertEqual(ssh.version(), (1, 2))
 
+  def test_context_manager_empty(self):
+    """Verify context manager with no clients works correctly."""
+    with multiprocessing.Manager() as manager:
+      with ssh.ProxyManager(manager):
+        pass
+
+  def test_context_manager_child_cleanup(self):
+    """Verify orphaned clients & masters get cleaned up."""
+    with multiprocessing.Manager() as manager:
+      with ssh.ProxyManager(manager) as ssh_proxy:
+        client = subprocess.Popen(['sleep', '964853320'])
+        ssh_proxy.add_client(client)
+        master = subprocess.Popen(['sleep', '964853321'])
+        ssh_proxy.add_master(master)
+    # If the process still exists, these will throw timeout errors.
+    client.wait(0)
+    master.wait(0)
+
   def test_ssh_sock(self):
    """Check sock() function."""
+    manager = multiprocessing.Manager()
+    proxy = ssh.ProxyManager(manager)
    with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
      # old ssh version uses port
      with mock.patch('ssh.version', return_value=(6, 6)):
-        self.assertTrue(ssh.sock().endswith('%p'))
-      ssh._ssh_sock_path = None
+        self.assertTrue(proxy.sock().endswith('%p'))
+
+      proxy._sock_path = None
      # new ssh version uses hash
      with mock.patch('ssh.version', return_value=(6, 7)):
-        self.assertTrue(ssh.sock().endswith('%C'))
-      ssh._ssh_sock_path = None
+        self.assertTrue(proxy.sock().endswith('%C'))
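The new tests lean on the property that makes the whole rewrite work: only manager-backed containers are visible across processes, which is why ProxyManager keeps its bookkeeping in manager.list()/manager.dict() rather than plain attributes. A minimal, self-contained illustration of the difference (not repo code):

    import multiprocessing
    import os

    def _child(shared, plain):
      shared.append(os.getpid())  # proxy: recorded in the manager process
      plain.append(os.getpid())   # plain list: mutates the child's copy only

    if __name__ == '__main__':
      with multiprocessing.Manager() as manager:
        shared, plain = manager.list(), []
        p = multiprocessing.Process(target=_child, args=(shared, plain))
        p.start()
        p.join()
        print(list(shared))  # [<child pid>]
        print(plain)         # [] -- the parent never sees the child's append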