From 339f2df1ddd741070e340ec01d6882dd1eee617c Mon Sep 17 00:00:00 2001 From: Mike Frysinger Date: Thu, 6 May 2021 00:44:42 -0400 Subject: ssh: rewrite proxy management for multiprocessing usage We changed sync to use multiprocessing for parallel work. This broke the ssh proxy code as it's all based on threads. Rewrite the logic to be multiprocessing safe. Now instead of the module acting as a stateful object, callers have to instantiate a new ProxyManager class that holds all the state, and pass that down to any users. Bug: https://crbug.com/gerrit/12389 Change-Id: I4b1af116f7306b91e825d3c56fb4274c9b033562 Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305486 Tested-by: Mike Frysinger Reviewed-by: Chris Mcdonald --- subcmds/sync.py | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) (limited to 'subcmds') diff --git a/subcmds/sync.py b/subcmds/sync.py index 28568062..fb25c221 100644 --- a/subcmds/sync.py +++ b/subcmds/sync.py @@ -358,7 +358,7 @@ later is required to fix a server side protocol bug. optimized_fetch=opt.optimized_fetch, retry_fetches=opt.retry_fetches, prune=opt.prune, - ssh_proxy=True, + ssh_proxy=self.ssh_proxy, clone_filter=self.manifest.CloneFilter, partial_clone_exclude=self.manifest.PartialCloneExclude) @@ -380,7 +380,11 @@ later is required to fix a server side protocol bug. finish = time.time() return (success, project, start, finish) - def _Fetch(self, projects, opt, err_event): + @classmethod + def _FetchInitChild(cls, ssh_proxy): + cls.ssh_proxy = ssh_proxy + + def _Fetch(self, projects, opt, err_event, ssh_proxy): ret = True jobs = opt.jobs_network if opt.jobs_network else self.jobs @@ -410,8 +414,14 @@ later is required to fix a server side protocol bug. break return ret + # We pass the ssh proxy settings via the class. This allows multiprocessing + # to pickle it up when spawning children. 
We can't pass it as an argument + # to _FetchProjectList below as multiprocessing is unable to pickle those. + Sync.ssh_proxy = None + # NB: Multiprocessing is heavy, so don't spin it up for one job. if len(projects_list) == 1 or jobs == 1: + self._FetchInitChild(ssh_proxy) if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list): ret = False else: @@ -429,7 +439,8 @@ later is required to fix a server side protocol bug. else: pm.update(inc=0, msg='warming up') chunksize = 4 - with multiprocessing.Pool(jobs) as pool: + with multiprocessing.Pool( + jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool: results = pool.imap_unordered( functools.partial(self._FetchProjectList, opt), projects_list, @@ -438,6 +449,11 @@ later is required to fix a server side protocol bug. ret = False pool.close() + # Cleanup the reference now that we're done with it, and we're going to + # release any resources it points to. If we don't, later multiprocessing + # usage (e.g. checkouts) will try to pickle and then crash. + del Sync.ssh_proxy + pm.end() self._fetch_times.Save() @@ -447,7 +463,7 @@ later is required to fix a server side protocol bug. return (ret, fetched) def _FetchMain(self, opt, args, all_projects, err_event, manifest_name, - load_local_manifests): + load_local_manifests, ssh_proxy): """The main network fetch loop. Args: @@ -457,6 +473,7 @@ later is required to fix a server side protocol bug. err_event: Whether an error was hit while processing. manifest_name: Manifest file to be reloaded. load_local_manifests: Whether to load local manifests. + ssh_proxy: SSH manager for clients & masters. """ rp = self.manifest.repoProject @@ -467,7 +484,7 @@ later is required to fix a server side protocol bug. 
to_fetch.extend(all_projects) to_fetch.sort(key=self._fetch_times.Get, reverse=True) - success, fetched = self._Fetch(to_fetch, opt, err_event) + success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy) if not success: err_event.set() @@ -498,7 +515,7 @@ later is required to fix a server side protocol bug. if previously_missing_set == missing_set: break previously_missing_set = missing_set - success, new_fetched = self._Fetch(missing, opt, err_event) + success, new_fetched = self._Fetch(missing, opt, err_event, ssh_proxy) if not success: err_event.set() fetched.update(new_fetched) @@ -985,12 +1002,15 @@ later is required to fix a server side protocol bug. self._fetch_times = _FetchTimes(self.manifest) if not opt.local_only: - try: - ssh.init() - self._FetchMain(opt, args, all_projects, err_event, manifest_name, - load_local_manifests) - finally: - ssh.close() + with multiprocessing.Manager() as manager: + with ssh.ProxyManager(manager) as ssh_proxy: + # Initialize the socket dir once in the parent. + ssh_proxy.sock() + self._FetchMain(opt, args, all_projects, err_event, manifest_name, + load_local_manifests, ssh_proxy) + + if opt.network_only: + return # If we saw an error, exit with code 1 so that other scripts can check. if err_event.is_set(): -- cgit v1.2.3-54-g00ecf