diff options
Diffstat (limited to 'command.py')
-rw-r--r-- | command.py | 50 |
1 files changed, 46 insertions, 4 deletions
@@ -12,6 +12,7 @@ | |||
12 | # See the License for the specific language governing permissions and | 12 | # See the License for the specific language governing permissions and |
13 | # limitations under the License. | 13 | # limitations under the License. |
14 | 14 | ||
15 | import contextlib | ||
15 | import multiprocessing | 16 | import multiprocessing |
16 | import optparse | 17 | import optparse |
17 | import os | 18 | import os |
@@ -70,6 +71,14 @@ class Command: | |||
70 | # migrated subcommands can set it to False. | 71 | # migrated subcommands can set it to False. |
71 | MULTI_MANIFEST_SUPPORT = True | 72 | MULTI_MANIFEST_SUPPORT = True |
72 | 73 | ||
74 | # Shared data across parallel execution workers. | ||
75 | _parallel_context = None | ||
76 | |||
77 | @classmethod | ||
78 | def get_parallel_context(cls): | ||
79 | assert cls._parallel_context is not None | ||
80 | return cls._parallel_context | ||
81 | |||
73 | def __init__( | 82 | def __init__( |
74 | self, | 83 | self, |
75 | repodir=None, | 84 | repodir=None, |
@@ -242,9 +251,36 @@ class Command: | |||
242 | """Perform the action, after option parsing is complete.""" | 251 | """Perform the action, after option parsing is complete.""" |
243 | raise NotImplementedError | 252 | raise NotImplementedError |
244 | 253 | ||
245 | @staticmethod | 254 | @classmethod |
255 | @contextlib.contextmanager | ||
256 | def ParallelContext(cls): | ||
257 | """Obtains the context, which is shared to ExecuteInParallel workers. | ||
258 | |||
259 | Callers can store data in the context dict before invocation of | ||
260 | ExecuteInParallel. The dict will then be shared to child workers of | ||
261 | ExecuteInParallel. | ||
262 | """ | ||
263 | assert cls._parallel_context is None | ||
264 | cls._parallel_context = {} | ||
265 | try: | ||
266 | yield | ||
267 | finally: | ||
268 | cls._parallel_context = None | ||
269 | |||
270 | @classmethod | ||
271 | def _SetParallelContext(cls, context): | ||
272 | cls._parallel_context = context | ||
273 | |||
274 | @classmethod | ||
246 | def ExecuteInParallel( | 275 | def ExecuteInParallel( |
247 | jobs, func, inputs, callback, output=None, ordered=False | 276 | cls, |
277 | jobs, | ||
278 | func, | ||
279 | inputs, | ||
280 | callback, | ||
281 | output=None, | ||
282 | ordered=False, | ||
283 | chunksize=WORKER_BATCH_SIZE, | ||
248 | ): | 284 | ): |
249 | """Helper for managing parallel execution boiler plate. | 285 | """Helper for managing parallel execution boiler plate. |
250 | 286 | ||
@@ -269,6 +305,8 @@ class Command: | |||
269 | output: An output manager. May be progress.Progess or | 305 | output: An output manager. May be progress.Progess or |
270 | color.Coloring. | 306 | color.Coloring. |
271 | ordered: Whether the jobs should be processed in order. | 307 | ordered: Whether the jobs should be processed in order. |
308 | chunksize: The number of jobs processed in batch by parallel | ||
309 | workers. | ||
272 | 310 | ||
273 | Returns: | 311 | Returns: |
274 | The |callback| function's results are returned. | 312 | The |callback| function's results are returned. |
@@ -278,12 +316,16 @@ class Command: | |||
278 | if len(inputs) == 1 or jobs == 1: | 316 | if len(inputs) == 1 or jobs == 1: |
279 | return callback(None, output, (func(x) for x in inputs)) | 317 | return callback(None, output, (func(x) for x in inputs)) |
280 | else: | 318 | else: |
281 | with multiprocessing.Pool(jobs) as pool: | 319 | with multiprocessing.Pool( |
320 | jobs, | ||
321 | initializer=cls._SetParallelContext, | ||
322 | initargs=(cls._parallel_context,), | ||
323 | ) as pool: | ||
282 | submit = pool.imap if ordered else pool.imap_unordered | 324 | submit = pool.imap if ordered else pool.imap_unordered |
283 | return callback( | 325 | return callback( |
284 | pool, | 326 | pool, |
285 | output, | 327 | output, |
286 | submit(func, inputs, chunksize=WORKER_BATCH_SIZE), | 328 | submit(func, inputs, chunksize=chunksize), |
287 | ) | 329 | ) |
288 | finally: | 330 | finally: |
289 | if isinstance(output, progress.Progress): | 331 | if isinstance(output, progress.Progress): |