Source code for pypesto.engine.multi_thread

"""Engines with multi-threading parallelization."""
import copy
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional

from tqdm import tqdm

from .base import Engine
from .task import Task

logger = logging.getLogger(__name__)


def work(task):
    """Run a single task and return its result.

    Parameters
    ----------
    task:
        Any object exposing an ``execute()`` method.

    Returns
    -------
    Whatever ``task.execute()`` returns.
    """
    outcome = task.execute()
    return outcome


class MultiThreadEngine(Engine):
    """
    Parallelize the task execution using multithreading.

    Parameters
    ----------
    n_threads:
        The maximum number of threads to use in parallel.
        Defaults to the number of CPUs available on the system according to
        `os.cpu_count()`. If that count cannot be determined (``os.cpu_count``
        returns ``None``), a single thread is used.
        The effectively used number of threads will be the minimum of
        `n_threads` and the number of tasks submitted.
    """

    def __init__(self, n_threads: Optional[int] = None):
        super().__init__()
        if n_threads is None:
            # os.cpu_count() returns None when the CPU count is undeterminable;
            # fall back to one thread so `execute` can still compute `min`.
            n_threads = os.cpu_count() or 1
            logger.info(
                f"Engine will use up to {n_threads} threads (= CPU count)."
            )
        self.n_threads: int = n_threads

    def execute(
        self, tasks: List[Task], progress_bar: bool = True
    ) -> List[Any]:
        """Deepcopy tasks and distribute work over parallel threads.

        Parameters
        ----------
        tasks:
            List of tasks to execute.
        progress_bar:
            Whether to display a progress bar.

        Returns
        -------
        The results of ``task.execute()`` for each task, in the same order
        as `tasks`. An empty list if `tasks` is empty.
        """
        n_tasks = len(tasks)
        if n_tasks == 0:
            # ThreadPoolExecutor raises ValueError for max_workers=0, and
            # there is nothing to execute anyway.
            return []

        # Each thread works on its own copy, so the caller's tasks are not
        # mutated by execution.
        copied_tasks = [copy.deepcopy(task) for task in tasks]

        n_threads = min(self.n_threads, n_tasks)
        logger.debug(f"Parallelizing on {n_threads} threads.")

        with ThreadPoolExecutor(max_workers=n_threads) as pool:
            # pool.map preserves input order; tqdm only wraps the iterator to
            # report progress as results arrive.
            results = list(
                tqdm(
                    pool.map(work, copied_tasks),
                    total=n_tasks,
                    disable=not progress_bar,
                ),
            )

        return results