diff options
-rw-r--r-- | git_command.py | 12 | ||||
-rw-r--r-- | git_config.py | 11 | ||||
-rw-r--r-- | project.py | 4 | ||||
-rw-r--r-- | ssh.py | 282 | ||||
-rw-r--r-- | subcmds/sync.py | 44 | ||||
-rw-r--r-- | tests/test_ssh.py | 30 |
6 files changed, 225 insertions, 158 deletions
diff --git a/git_command.py b/git_command.py index fabad0e0..04953f38 100644 --- a/git_command.py +++ b/git_command.py | |||
@@ -21,7 +21,6 @@ from error import GitError | |||
21 | from git_refs import HEAD | 21 | from git_refs import HEAD |
22 | import platform_utils | 22 | import platform_utils |
23 | from repo_trace import REPO_TRACE, IsTrace, Trace | 23 | from repo_trace import REPO_TRACE, IsTrace, Trace |
24 | import ssh | ||
25 | from wrapper import Wrapper | 24 | from wrapper import Wrapper |
26 | 25 | ||
27 | GIT = 'git' | 26 | GIT = 'git' |
@@ -167,7 +166,7 @@ class GitCommand(object): | |||
167 | capture_stderr=False, | 166 | capture_stderr=False, |
168 | merge_output=False, | 167 | merge_output=False, |
169 | disable_editor=False, | 168 | disable_editor=False, |
170 | ssh_proxy=False, | 169 | ssh_proxy=None, |
171 | cwd=None, | 170 | cwd=None, |
172 | gitdir=None): | 171 | gitdir=None): |
173 | env = self._GetBasicEnv() | 172 | env = self._GetBasicEnv() |
@@ -175,8 +174,8 @@ class GitCommand(object): | |||
175 | if disable_editor: | 174 | if disable_editor: |
176 | env['GIT_EDITOR'] = ':' | 175 | env['GIT_EDITOR'] = ':' |
177 | if ssh_proxy: | 176 | if ssh_proxy: |
178 | env['REPO_SSH_SOCK'] = ssh.sock() | 177 | env['REPO_SSH_SOCK'] = ssh_proxy.sock() |
179 | env['GIT_SSH'] = ssh.proxy() | 178 | env['GIT_SSH'] = ssh_proxy.proxy |
180 | env['GIT_SSH_VARIANT'] = 'ssh' | 179 | env['GIT_SSH_VARIANT'] = 'ssh' |
181 | if 'http_proxy' in env and 'darwin' == sys.platform: | 180 | if 'http_proxy' in env and 'darwin' == sys.platform: |
182 | s = "'http.proxy=%s'" % (env['http_proxy'],) | 181 | s = "'http.proxy=%s'" % (env['http_proxy'],) |
@@ -259,7 +258,7 @@ class GitCommand(object): | |||
259 | raise GitError('%s: %s' % (command[1], e)) | 258 | raise GitError('%s: %s' % (command[1], e)) |
260 | 259 | ||
261 | if ssh_proxy: | 260 | if ssh_proxy: |
262 | ssh.add_client(p) | 261 | ssh_proxy.add_client(p) |
263 | 262 | ||
264 | self.process = p | 263 | self.process = p |
265 | if input: | 264 | if input: |
@@ -271,7 +270,8 @@ class GitCommand(object): | |||
271 | try: | 270 | try: |
272 | self.stdout, self.stderr = p.communicate() | 271 | self.stdout, self.stderr = p.communicate() |
273 | finally: | 272 | finally: |
274 | ssh.remove_client(p) | 273 | if ssh_proxy: |
274 | ssh_proxy.remove_client(p) | ||
275 | self.rc = p.wait() | 275 | self.rc = p.wait() |
276 | 276 | ||
277 | @staticmethod | 277 | @staticmethod |
diff --git a/git_config.py b/git_config.py index d7fef8ca..978f6a59 100644 --- a/git_config.py +++ b/git_config.py | |||
@@ -27,7 +27,6 @@ import urllib.request | |||
27 | from error import GitError, UploadError | 27 | from error import GitError, UploadError |
28 | import platform_utils | 28 | import platform_utils |
29 | from repo_trace import Trace | 29 | from repo_trace import Trace |
30 | import ssh | ||
31 | from git_command import GitCommand | 30 | from git_command import GitCommand |
32 | from git_refs import R_CHANGES, R_HEADS, R_TAGS | 31 | from git_refs import R_CHANGES, R_HEADS, R_TAGS |
33 | 32 | ||
@@ -519,17 +518,23 @@ class Remote(object): | |||
519 | 518 | ||
520 | return self.url.replace(longest, longestUrl, 1) | 519 | return self.url.replace(longest, longestUrl, 1) |
521 | 520 | ||
522 | def PreConnectFetch(self): | 521 | def PreConnectFetch(self, ssh_proxy): |
523 | """Run any setup for this remote before we connect to it. | 522 | """Run any setup for this remote before we connect to it. |
524 | 523 | ||
525 | In practice, if the remote is using SSH, we'll attempt to create a new | 524 | In practice, if the remote is using SSH, we'll attempt to create a new |
526 | SSH master session to it for reuse across projects. | 525 | SSH master session to it for reuse across projects. |
527 | 526 | ||
527 | Args: | ||
528 | ssh_proxy: The SSH settings for managing master sessions. | ||
529 | |||
528 | Returns: | 530 | Returns: |
529 | Whether the preconnect phase for this remote was successful. | 531 | Whether the preconnect phase for this remote was successful. |
530 | """ | 532 | """ |
533 | if not ssh_proxy: | ||
534 | return True | ||
535 | |||
531 | connectionUrl = self._InsteadOf() | 536 | connectionUrl = self._InsteadOf() |
532 | return ssh.preconnect(connectionUrl) | 537 | return ssh_proxy.preconnect(connectionUrl) |
533 | 538 | ||
534 | def ReviewUrl(self, userEmail, validate_certs): | 539 | def ReviewUrl(self, userEmail, validate_certs): |
535 | if self._review_url is None: | 540 | if self._review_url is None: |
@@ -2045,8 +2045,8 @@ class Project(object): | |||
2045 | name = self.remote.name | 2045 | name = self.remote.name |
2046 | 2046 | ||
2047 | remote = self.GetRemote(name) | 2047 | remote = self.GetRemote(name) |
2048 | if not remote.PreConnectFetch(): | 2048 | if not remote.PreConnectFetch(ssh_proxy): |
2049 | ssh_proxy = False | 2049 | ssh_proxy = None |
2050 | 2050 | ||
2051 | if initial: | 2051 | if initial: |
2052 | if alt_dir and 'objects' == os.path.basename(alt_dir): | 2052 | if alt_dir and 'objects' == os.path.basename(alt_dir): |
@@ -15,25 +15,20 @@ | |||
15 | """Common SSH management logic.""" | 15 | """Common SSH management logic.""" |
16 | 16 | ||
17 | import functools | 17 | import functools |
18 | import multiprocessing | ||
18 | import os | 19 | import os |
19 | import re | 20 | import re |
20 | import signal | 21 | import signal |
21 | import subprocess | 22 | import subprocess |
22 | import sys | 23 | import sys |
23 | import tempfile | 24 | import tempfile |
24 | try: | ||
25 | import threading as _threading | ||
26 | except ImportError: | ||
27 | import dummy_threading as _threading | ||
28 | import time | 25 | import time |
29 | 26 | ||
30 | import platform_utils | 27 | import platform_utils |
31 | from repo_trace import Trace | 28 | from repo_trace import Trace |
32 | 29 | ||
33 | 30 | ||
34 | _ssh_proxy_path = None | 31 | PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh') |
35 | _ssh_sock_path = None | ||
36 | _ssh_clients = [] | ||
37 | 32 | ||
38 | 33 | ||
39 | def _run_ssh_version(): | 34 | def _run_ssh_version(): |
@@ -62,68 +57,104 @@ def version(): | |||
62 | sys.exit(1) | 57 | sys.exit(1) |
63 | 58 | ||
64 | 59 | ||
65 | def proxy(): | 60 | URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') |
66 | global _ssh_proxy_path | 61 | URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') |
67 | if _ssh_proxy_path is None: | ||
68 | _ssh_proxy_path = os.path.join( | ||
69 | os.path.dirname(__file__), | ||
70 | 'git_ssh') | ||
71 | return _ssh_proxy_path | ||
72 | 62 | ||
73 | 63 | ||
74 | def add_client(p): | 64 | class ProxyManager: |
75 | _ssh_clients.append(p) | 65 | """Manage various ssh clients & masters that we spawn. |
76 | 66 | ||
67 | This will take care of sharing state between multiprocessing children, and | ||
68 | make sure that if we crash, we don't leak any of the ssh sessions. | ||
77 | 69 | ||
78 | def remove_client(p): | 70 | The code should work with a single-process scenario too, and not add too much |
79 | try: | 71 | overhead due to the manager. |
80 | _ssh_clients.remove(p) | 72 | """ |
81 | except ValueError: | ||
82 | pass | ||
83 | |||
84 | 73 | ||
85 | def _terminate_clients(): | 74 | # Path to the ssh program to run which will pass our master settings along. |
86 | global _ssh_clients | 75 | # Set here more as a convenience API. |
87 | for p in _ssh_clients: | 76 | proxy = PROXY_PATH |
77 | |||
78 | def __init__(self, manager): | ||
79 | # Protect access to the list of active masters. | ||
80 | self._lock = multiprocessing.Lock() | ||
81 | # List of active masters (pid). These will be spawned on demand, and we are | ||
82 | # responsible for shutting them all down at the end. | ||
83 | self._masters = manager.list() | ||
84 | # Set of active masters indexed by "host:port" information. | ||
85 | # The value isn't used, but multiprocessing doesn't provide a set class. | ||
86 | self._master_keys = manager.dict() | ||
87 | # Whether ssh masters are known to be broken, so we give up entirely. | ||
88 | self._master_broken = manager.Value('b', False) | ||
89 | # List of active ssh sesssions. Clients will be added & removed as | ||
90 | # connections finish, so this list is just for safety & cleanup if we crash. | ||
91 | self._clients = manager.list() | ||
92 | # Path to directory for holding master sockets. | ||
93 | self._sock_path = None | ||
94 | |||
95 | def __enter__(self): | ||
96 | """Enter a new context.""" | ||
97 | return self | ||
98 | |||
99 | def __exit__(self, exc_type, exc_value, traceback): | ||
100 | """Exit a context & clean up all resources.""" | ||
101 | self.close() | ||
102 | |||
103 | def add_client(self, proc): | ||
104 | """Track a new ssh session.""" | ||
105 | self._clients.append(proc.pid) | ||
106 | |||
107 | def remove_client(self, proc): | ||
108 | """Remove a completed ssh session.""" | ||
88 | try: | 109 | try: |
89 | os.kill(p.pid, signal.SIGTERM) | 110 | self._clients.remove(proc.pid) |
90 | p.wait() | 111 | except ValueError: |
91 | except OSError: | ||
92 | pass | 112 | pass |
93 | _ssh_clients = [] | ||
94 | |||
95 | |||
96 | _master_processes = [] | ||
97 | _master_keys = set() | ||
98 | _ssh_master = True | ||
99 | _master_keys_lock = None | ||
100 | |||
101 | |||
102 | def init(): | ||
103 | """Should be called once at the start of repo to init ssh master handling. | ||
104 | |||
105 | At the moment, all we do is to create our lock. | ||
106 | """ | ||
107 | global _master_keys_lock | ||
108 | assert _master_keys_lock is None, "Should only call init once" | ||
109 | _master_keys_lock = _threading.Lock() | ||
110 | |||
111 | |||
112 | def _open_ssh(host, port=None): | ||
113 | global _ssh_master | ||
114 | |||
115 | # Bail before grabbing the lock if we already know that we aren't going to | ||
116 | # try creating new masters below. | ||
117 | if sys.platform in ('win32', 'cygwin'): | ||
118 | return False | ||
119 | |||
120 | # Acquire the lock. This is needed to prevent opening multiple masters for | ||
121 | # the same host when we're running "repo sync -jN" (for N > 1) _and_ the | ||
122 | # manifest <remote fetch="ssh://xyz"> specifies a different host from the | ||
123 | # one that was passed to repo init. | ||
124 | _master_keys_lock.acquire() | ||
125 | try: | ||
126 | 113 | ||
114 | def add_master(self, proc): | ||
115 | """Track a new master connection.""" | ||
116 | self._masters.append(proc.pid) | ||
117 | |||
118 | def _terminate(self, procs): | ||
119 | """Kill all |procs|.""" | ||
120 | for pid in procs: | ||
121 | try: | ||
122 | os.kill(pid, signal.SIGTERM) | ||
123 | os.waitpid(pid, 0) | ||
124 | except OSError: | ||
125 | pass | ||
126 | |||
127 | # The multiprocessing.list() API doesn't provide many standard list() | ||
128 | # methods, so we have to manually clear the list. | ||
129 | while True: | ||
130 | try: | ||
131 | procs.pop(0) | ||
132 | except: | ||
133 | break | ||
134 | |||
135 | def close(self): | ||
136 | """Close this active ssh session. | ||
137 | |||
138 | Kill all ssh clients & masters we created, and nuke the socket dir. | ||
139 | """ | ||
140 | self._terminate(self._clients) | ||
141 | self._terminate(self._masters) | ||
142 | |||
143 | d = self.sock(create=False) | ||
144 | if d: | ||
145 | try: | ||
146 | platform_utils.rmdir(os.path.dirname(d)) | ||
147 | except OSError: | ||
148 | pass | ||
149 | |||
150 | def _open_unlocked(self, host, port=None): | ||
151 | """Make sure a ssh master session exists for |host| & |port|. | ||
152 | |||
153 | If one doesn't exist already, we'll create it. | ||
154 | |||
155 | We won't grab any locks, so the caller has to do that. This helps keep the | ||
156 | business logic of actually creating the master separate from grabbing locks. | ||
157 | """ | ||
127 | # Check to see whether we already think that the master is running; if we | 158 | # Check to see whether we already think that the master is running; if we |
128 | # think it's already running, return right away. | 159 | # think it's already running, return right away. |
129 | if port is not None: | 160 | if port is not None: |
@@ -131,17 +162,15 @@ def _open_ssh(host, port=None): | |||
131 | else: | 162 | else: |
132 | key = host | 163 | key = host |
133 | 164 | ||
134 | if key in _master_keys: | 165 | if key in self._master_keys: |
135 | return True | 166 | return True |
136 | 167 | ||
137 | if not _ssh_master or 'GIT_SSH' in os.environ: | 168 | if self._master_broken.value or 'GIT_SSH' in os.environ: |
138 | # Failed earlier, so don't retry. | 169 | # Failed earlier, so don't retry. |
139 | return False | 170 | return False |
140 | 171 | ||
141 | # We will make two calls to ssh; this is the common part of both calls. | 172 | # We will make two calls to ssh; this is the common part of both calls. |
142 | command_base = ['ssh', | 173 | command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host] |
143 | '-o', 'ControlPath %s' % sock(), | ||
144 | host] | ||
145 | if port is not None: | 174 | if port is not None: |
146 | command_base[1:1] = ['-p', str(port)] | 175 | command_base[1:1] = ['-p', str(port)] |
147 | 176 | ||
@@ -161,7 +190,7 @@ def _open_ssh(host, port=None): | |||
161 | if not isnt_running: | 190 | if not isnt_running: |
162 | # Our double-check found that the master _was_ infact running. Add to | 191 | # Our double-check found that the master _was_ infact running. Add to |
163 | # the list of keys. | 192 | # the list of keys. |
164 | _master_keys.add(key) | 193 | self._master_keys[key] = True |
165 | return True | 194 | return True |
166 | except Exception: | 195 | except Exception: |
167 | # Ignore excpetions. We we will fall back to the normal command and print | 196 | # Ignore excpetions. We we will fall back to the normal command and print |
@@ -173,7 +202,7 @@ def _open_ssh(host, port=None): | |||
173 | Trace(': %s', ' '.join(command)) | 202 | Trace(': %s', ' '.join(command)) |
174 | p = subprocess.Popen(command) | 203 | p = subprocess.Popen(command) |
175 | except Exception as e: | 204 | except Exception as e: |
176 | _ssh_master = False | 205 | self._master_broken.value = True |
177 | print('\nwarn: cannot enable ssh control master for %s:%s\n%s' | 206 | print('\nwarn: cannot enable ssh control master for %s:%s\n%s' |
178 | % (host, port, str(e)), file=sys.stderr) | 207 | % (host, port, str(e)), file=sys.stderr) |
179 | return False | 208 | return False |
@@ -183,75 +212,66 @@ def _open_ssh(host, port=None): | |||
183 | if ssh_died: | 212 | if ssh_died: |
184 | return False | 213 | return False |
185 | 214 | ||
186 | _master_processes.append(p) | 215 | self.add_master(p) |
187 | _master_keys.add(key) | 216 | self._master_keys[key] = True |
188 | return True | 217 | return True |
189 | finally: | ||
190 | _master_keys_lock.release() | ||
191 | 218 | ||
219 | def _open(self, host, port=None): | ||
220 | """Make sure a ssh master session exists for |host| & |port|. | ||
192 | 221 | ||
193 | def close(): | 222 | If one doesn't exist already, we'll create it. |
194 | global _master_keys_lock | ||
195 | |||
196 | _terminate_clients() | ||
197 | |||
198 | for p in _master_processes: | ||
199 | try: | ||
200 | os.kill(p.pid, signal.SIGTERM) | ||
201 | p.wait() | ||
202 | except OSError: | ||
203 | pass | ||
204 | del _master_processes[:] | ||
205 | _master_keys.clear() | ||
206 | |||
207 | d = sock(create=False) | ||
208 | if d: | ||
209 | try: | ||
210 | platform_utils.rmdir(os.path.dirname(d)) | ||
211 | except OSError: | ||
212 | pass | ||
213 | |||
214 | # We're done with the lock, so we can delete it. | ||
215 | _master_keys_lock = None | ||
216 | 223 | ||
224 | This will obtain any necessary locks to avoid inter-process races. | ||
225 | """ | ||
226 | # Bail before grabbing the lock if we already know that we aren't going to | ||
227 | # try creating new masters below. | ||
228 | if sys.platform in ('win32', 'cygwin'): | ||
229 | return False | ||
217 | 230 | ||
218 | URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') | 231 | # Acquire the lock. This is needed to prevent opening multiple masters for |
219 | URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') | 232 | # the same host when we're running "repo sync -jN" (for N > 1) _and_ the |
233 | # manifest <remote fetch="ssh://xyz"> specifies a different host from the | ||
234 | # one that was passed to repo init. | ||
235 | with self._lock: | ||
236 | return self._open_unlocked(host, port) | ||
237 | |||
238 | def preconnect(self, url): | ||
239 | """If |uri| will create a ssh connection, setup the ssh master for it.""" | ||
240 | m = URI_ALL.match(url) | ||
241 | if m: | ||
242 | scheme = m.group(1) | ||
243 | host = m.group(2) | ||
244 | if ':' in host: | ||
245 | host, port = host.split(':') | ||
246 | else: | ||
247 | port = None | ||
248 | if scheme in ('ssh', 'git+ssh', 'ssh+git'): | ||
249 | return self._open(host, port) | ||
250 | return False | ||
220 | 251 | ||
252 | m = URI_SCP.match(url) | ||
253 | if m: | ||
254 | host = m.group(1) | ||
255 | return self._open(host) | ||
221 | 256 | ||
222 | def preconnect(url): | ||
223 | m = URI_ALL.match(url) | ||
224 | if m: | ||
225 | scheme = m.group(1) | ||
226 | host = m.group(2) | ||
227 | if ':' in host: | ||
228 | host, port = host.split(':') | ||
229 | else: | ||
230 | port = None | ||
231 | if scheme in ('ssh', 'git+ssh', 'ssh+git'): | ||
232 | return _open_ssh(host, port) | ||
233 | return False | 257 | return False |
234 | 258 | ||
235 | m = URI_SCP.match(url) | 259 | def sock(self, create=True): |
236 | if m: | 260 | """Return the path to the ssh socket dir. |
237 | host = m.group(1) | 261 | |
238 | return _open_ssh(host) | 262 | This has all the master sockets so clients can talk to them. |
239 | 263 | """ | |
240 | return False | 264 | if self._sock_path is None: |
241 | 265 | if not create: | |
242 | def sock(create=True): | 266 | return None |
243 | global _ssh_sock_path | 267 | tmp_dir = '/tmp' |
244 | if _ssh_sock_path is None: | 268 | if not os.path.exists(tmp_dir): |
245 | if not create: | 269 | tmp_dir = tempfile.gettempdir() |
246 | return None | 270 | if version() < (6, 7): |
247 | tmp_dir = '/tmp' | 271 | tokens = '%r@%h:%p' |
248 | if not os.path.exists(tmp_dir): | 272 | else: |
249 | tmp_dir = tempfile.gettempdir() | 273 | tokens = '%C' # hash of %l%h%p%r |
250 | if version() < (6, 7): | 274 | self._sock_path = os.path.join( |
251 | tokens = '%r@%h:%p' | 275 | tempfile.mkdtemp('', 'ssh-', tmp_dir), |
252 | else: | 276 | 'master-' + tokens) |
253 | tokens = '%C' # hash of %l%h%p%r | 277 | return self._sock_path |
254 | _ssh_sock_path = os.path.join( | ||
255 | tempfile.mkdtemp('', 'ssh-', tmp_dir), | ||
256 | 'master-' + tokens) | ||
257 | return _ssh_sock_path | ||
diff --git a/subcmds/sync.py b/subcmds/sync.py index 28568062..fb25c221 100644 --- a/subcmds/sync.py +++ b/subcmds/sync.py | |||
@@ -358,7 +358,7 @@ later is required to fix a server side protocol bug. | |||
358 | optimized_fetch=opt.optimized_fetch, | 358 | optimized_fetch=opt.optimized_fetch, |
359 | retry_fetches=opt.retry_fetches, | 359 | retry_fetches=opt.retry_fetches, |
360 | prune=opt.prune, | 360 | prune=opt.prune, |
361 | ssh_proxy=True, | 361 | ssh_proxy=self.ssh_proxy, |
362 | clone_filter=self.manifest.CloneFilter, | 362 | clone_filter=self.manifest.CloneFilter, |
363 | partial_clone_exclude=self.manifest.PartialCloneExclude) | 363 | partial_clone_exclude=self.manifest.PartialCloneExclude) |
364 | 364 | ||
@@ -380,7 +380,11 @@ later is required to fix a server side protocol bug. | |||
380 | finish = time.time() | 380 | finish = time.time() |
381 | return (success, project, start, finish) | 381 | return (success, project, start, finish) |
382 | 382 | ||
383 | def _Fetch(self, projects, opt, err_event): | 383 | @classmethod |
384 | def _FetchInitChild(cls, ssh_proxy): | ||
385 | cls.ssh_proxy = ssh_proxy | ||
386 | |||
387 | def _Fetch(self, projects, opt, err_event, ssh_proxy): | ||
384 | ret = True | 388 | ret = True |
385 | 389 | ||
386 | jobs = opt.jobs_network if opt.jobs_network else self.jobs | 390 | jobs = opt.jobs_network if opt.jobs_network else self.jobs |
@@ -410,8 +414,14 @@ later is required to fix a server side protocol bug. | |||
410 | break | 414 | break |
411 | return ret | 415 | return ret |
412 | 416 | ||
417 | # We pass the ssh proxy settings via the class. This allows multiprocessing | ||
418 | # to pickle it up when spawning children. We can't pass it as an argument | ||
419 | # to _FetchProjectList below as multiprocessing is unable to pickle those. | ||
420 | Sync.ssh_proxy = None | ||
421 | |||
413 | # NB: Multiprocessing is heavy, so don't spin it up for one job. | 422 | # NB: Multiprocessing is heavy, so don't spin it up for one job. |
414 | if len(projects_list) == 1 or jobs == 1: | 423 | if len(projects_list) == 1 or jobs == 1: |
424 | self._FetchInitChild(ssh_proxy) | ||
415 | if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list): | 425 | if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list): |
416 | ret = False | 426 | ret = False |
417 | else: | 427 | else: |
@@ -429,7 +439,8 @@ later is required to fix a server side protocol bug. | |||
429 | else: | 439 | else: |
430 | pm.update(inc=0, msg='warming up') | 440 | pm.update(inc=0, msg='warming up') |
431 | chunksize = 4 | 441 | chunksize = 4 |
432 | with multiprocessing.Pool(jobs) as pool: | 442 | with multiprocessing.Pool( |
443 | jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool: | ||
433 | results = pool.imap_unordered( | 444 | results = pool.imap_unordered( |
434 | functools.partial(self._FetchProjectList, opt), | 445 | functools.partial(self._FetchProjectList, opt), |
435 | projects_list, | 446 | projects_list, |
@@ -438,6 +449,11 @@ later is required to fix a server side protocol bug. | |||
438 | ret = False | 449 | ret = False |
439 | pool.close() | 450 | pool.close() |
440 | 451 | ||
452 | # Cleanup the reference now that we're done with it, and we're going to | ||
453 | # release any resources it points to. If we don't, later multiprocessing | ||
454 | # usage (e.g. checkouts) will try to pickle and then crash. | ||
455 | del Sync.ssh_proxy | ||
456 | |||
441 | pm.end() | 457 | pm.end() |
442 | self._fetch_times.Save() | 458 | self._fetch_times.Save() |
443 | 459 | ||
@@ -447,7 +463,7 @@ later is required to fix a server side protocol bug. | |||
447 | return (ret, fetched) | 463 | return (ret, fetched) |
448 | 464 | ||
449 | def _FetchMain(self, opt, args, all_projects, err_event, manifest_name, | 465 | def _FetchMain(self, opt, args, all_projects, err_event, manifest_name, |
450 | load_local_manifests): | 466 | load_local_manifests, ssh_proxy): |
451 | """The main network fetch loop. | 467 | """The main network fetch loop. |
452 | 468 | ||
453 | Args: | 469 | Args: |
@@ -457,6 +473,7 @@ later is required to fix a server side protocol bug. | |||
457 | err_event: Whether an error was hit while processing. | 473 | err_event: Whether an error was hit while processing. |
458 | manifest_name: Manifest file to be reloaded. | 474 | manifest_name: Manifest file to be reloaded. |
459 | load_local_manifests: Whether to load local manifests. | 475 | load_local_manifests: Whether to load local manifests. |
476 | ssh_proxy: SSH manager for clients & masters. | ||
460 | """ | 477 | """ |
461 | rp = self.manifest.repoProject | 478 | rp = self.manifest.repoProject |
462 | 479 | ||
@@ -467,7 +484,7 @@ later is required to fix a server side protocol bug. | |||
467 | to_fetch.extend(all_projects) | 484 | to_fetch.extend(all_projects) |
468 | to_fetch.sort(key=self._fetch_times.Get, reverse=True) | 485 | to_fetch.sort(key=self._fetch_times.Get, reverse=True) |
469 | 486 | ||
470 | success, fetched = self._Fetch(to_fetch, opt, err_event) | 487 | success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy) |
471 | if not success: | 488 | if not success: |
472 | err_event.set() | 489 | err_event.set() |
473 | 490 | ||
@@ -498,7 +515,7 @@ later is required to fix a server side protocol bug. | |||
498 | if previously_missing_set == missing_set: | 515 | if previously_missing_set == missing_set: |
499 | break | 516 | break |
500 | previously_missing_set = missing_set | 517 | previously_missing_set = missing_set |
501 | success, new_fetched = self._Fetch(missing, opt, err_event) | 518 | success, new_fetched = self._Fetch(missing, opt, err_event, ssh_proxy) |
502 | if not success: | 519 | if not success: |
503 | err_event.set() | 520 | err_event.set() |
504 | fetched.update(new_fetched) | 521 | fetched.update(new_fetched) |
@@ -985,12 +1002,15 @@ later is required to fix a server side protocol bug. | |||
985 | 1002 | ||
986 | self._fetch_times = _FetchTimes(self.manifest) | 1003 | self._fetch_times = _FetchTimes(self.manifest) |
987 | if not opt.local_only: | 1004 | if not opt.local_only: |
988 | try: | 1005 | with multiprocessing.Manager() as manager: |
989 | ssh.init() | 1006 | with ssh.ProxyManager(manager) as ssh_proxy: |
990 | self._FetchMain(opt, args, all_projects, err_event, manifest_name, | 1007 | # Initialize the socket dir once in the parent. |
991 | load_local_manifests) | 1008 | ssh_proxy.sock() |
992 | finally: | 1009 | self._FetchMain(opt, args, all_projects, err_event, manifest_name, |
993 | ssh.close() | 1010 | load_local_manifests, ssh_proxy) |
1011 | |||
1012 | if opt.network_only: | ||
1013 | return | ||
994 | 1014 | ||
995 | # If we saw an error, exit with code 1 so that other scripts can check. | 1015 | # If we saw an error, exit with code 1 so that other scripts can check. |
996 | if err_event.is_set(): | 1016 | if err_event.is_set(): |
diff --git a/tests/test_ssh.py b/tests/test_ssh.py index 5a4f27e4..ffb5cb94 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py | |||
@@ -14,6 +14,8 @@ | |||
14 | 14 | ||
15 | """Unittests for the ssh.py module.""" | 15 | """Unittests for the ssh.py module.""" |
16 | 16 | ||
17 | import multiprocessing | ||
18 | import subprocess | ||
17 | import unittest | 19 | import unittest |
18 | from unittest import mock | 20 | from unittest import mock |
19 | 21 | ||
@@ -39,14 +41,34 @@ class SshTests(unittest.TestCase): | |||
39 | with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): | 41 | with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): |
40 | self.assertEqual(ssh.version(), (1, 2)) | 42 | self.assertEqual(ssh.version(), (1, 2)) |
41 | 43 | ||
44 | def test_context_manager_empty(self): | ||
45 | """Verify context manager with no clients works correctly.""" | ||
46 | with multiprocessing.Manager() as manager: | ||
47 | with ssh.ProxyManager(manager): | ||
48 | pass | ||
49 | |||
50 | def test_context_manager_child_cleanup(self): | ||
51 | """Verify orphaned clients & masters get cleaned up.""" | ||
52 | with multiprocessing.Manager() as manager: | ||
53 | with ssh.ProxyManager(manager) as ssh_proxy: | ||
54 | client = subprocess.Popen(['sleep', '964853320']) | ||
55 | ssh_proxy.add_client(client) | ||
56 | master = subprocess.Popen(['sleep', '964853321']) | ||
57 | ssh_proxy.add_master(master) | ||
58 | # If the process still exists, these will throw timeout errors. | ||
59 | client.wait(0) | ||
60 | master.wait(0) | ||
61 | |||
42 | def test_ssh_sock(self): | 62 | def test_ssh_sock(self): |
43 | """Check sock() function.""" | 63 | """Check sock() function.""" |
64 | manager = multiprocessing.Manager() | ||
65 | proxy = ssh.ProxyManager(manager) | ||
44 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): | 66 | with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): |
45 | # old ssh version uses port | 67 | # old ssh version uses port |
46 | with mock.patch('ssh.version', return_value=(6, 6)): | 68 | with mock.patch('ssh.version', return_value=(6, 6)): |
47 | self.assertTrue(ssh.sock().endswith('%p')) | 69 | self.assertTrue(proxy.sock().endswith('%p')) |
48 | ssh._ssh_sock_path = None | 70 | |
71 | proxy._sock_path = None | ||
49 | # new ssh version uses hash | 72 | # new ssh version uses hash |
50 | with mock.patch('ssh.version', return_value=(6, 7)): | 73 | with mock.patch('ssh.version', return_value=(6, 7)): |
51 | self.assertTrue(ssh.sock().endswith('%C')) | 74 | self.assertTrue(proxy.sock().endswith('%C')) |
52 | ssh._ssh_sock_path = None | ||