Source code for pypesto.engine.multi_thread

"""Engines with multi-threading parallelization."""

import copy
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union

from ..util import tqdm
from .base import Engine
from .task import Task

logger = logging.getLogger(__name__)


def work(task):
    """Execute task."""
    return task.execute()


[docs] class MultiThreadEngine(Engine): """ Parallelize the task execution using multithreading. Parameters ---------- n_threads: The maximum number of threads to use in parallel. Defaults to the number of CPUs available on the system according to `os.cpu_count()`. The effectively used number of threads will be the minimum of `n_threads` and the number of tasks submitted. """
[docs] def __init__(self, n_threads: Union[int, None] = None): super().__init__() if n_threads is None: n_threads = os.cpu_count() logger.info( f"Engine will use up to {n_threads} threads (= CPU count)." ) self.n_threads: int = n_threads
[docs] def execute( self, tasks: list[Task], progress_bar: bool = None ) -> list[Any]: """Deepcopy tasks and distribute work over parallel threads. Parameters ---------- tasks: List of tasks to execute. progress_bar: Whether to display a progress bar. Returns ------- A list of results. """ n_tasks = len(tasks) copied_tasks = [copy.deepcopy(task) for task in tasks] n_threads = min(self.n_threads, n_tasks) logger.debug(f"Parallelizing on {n_threads} threads.") with ThreadPoolExecutor(max_workers=n_threads) as pool: results = list( tqdm( pool.map(work, copied_tasks), total=len(copied_tasks), enable=progress_bar, ), ) return results