from __future__ import annotations
import abc
import math
import multiprocessing as mp
import os
import random
import signal
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Iterator, Optional
from tqdm.auto import tqdm
class Error(Enum):
"""Type of termination error."""
#: No error occurred during execution
NONE = auto()
#: A keyboard interrupt or termination signal was received
TERMINATE = auto()
#: A task produced an undesired output
SYSTEM = auto()
class _PoolDummy:
def terminate(self):
pass
def join(self):
pass
def close(self):
pass
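# Module-level state shared with the signal handler below: `errored` records why
# execution stopped, and `pool` is swapped for the live multiprocessing pool while
# tasks are running so that the handler can terminate it.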
errored = Error.NONE
pool = _PoolDummy()
def _sigterm_handler_soft(*_):
global errored, pool
print("SIGTERM SOFT")
errored = Error.TERMINATE
pool.terminate()
raise KeyboardInterrupt
@dataclass(slots=True, order=True)
class TaskParameters:
"""Data class containing the parameters that should be used by a task.
Inherit this class to add parameters to your task.
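Example:
    A subclass could look like this (the field names are illustrative). The
    parameters are de-duplicated and sorted internally, so they should be
    hashable and orderable:

    .. code-block:: python

        from dataclasses import dataclass
        from slurmer import TaskParameters

        @dataclass(slots=True, order=True, unsafe_hash=True)
        class MyTaskParameters(TaskParameters):
            seed: int
            dataset: str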
"""
pass
@dataclass(slots=True)
class _PrivateTaskParameters:
f: Callable[[TaskParameters], Optional[TaskResult]]
p: TaskParameters
@dataclass(slots=True)
class TaskResult:
"""Data class containing the result of a task.
Inherit this class to add results to your task.
Example:
You can inherit this class as follows:
.. code-block:: python
:linenos:
from dataclasses import dataclass
from slurmer import TaskResult
@dataclass
class MyTaskResult(TaskResult):
my_result: int
"""
pass
class TaskFailedError(Exception):
"""Exception thrown if one of the tasks fails."""
pass
class Task(abc.ABC):
"""Base class defining a set of Tasks to be done."""
def __init__(
self,
cluster_id: Optional[int] = None,
cluster_total: Optional[int] = None,
):
"""Initialize the class.
Args:
cluster_id (Optional[int], optional): ID of the current node in the cluster, if None
the value is read from the SLURM environment variable SLURM_ARRAY_TASK_ID. Defaults to
None.
cluster_total (Optional[int], optional): Number of allocated nodes in the cluster, if
None the value is read from the SLURM environment variable SLURM_ARRAY_TASK_MAX.
Defaults to None.
"""
self.cluster_id, self.cluster_total = Task.get_cluster_ids(cluster_id, cluster_total)
@abc.abstractmethod
def generate_parameters(self) -> Iterator[TaskParameters]:
"""Generate all the parameters for the different tasks to run.
This function **must be implemented**. It should return an iterator over all the parameters
that should be passed to the processor_function(). The number of tasks to run is determined
by the number of parameters returned by this method.
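Example:
    Inside a Task subclass, a minimal implementation could look like this
    (``MyTaskParameters`` stands for an illustrative TaskParameters subclass):

    .. code-block:: python

        def generate_parameters(self):
            # One task per (seed, dataset) combination.
            for seed in range(10):
                for dataset in ("train", "test"):
                    yield MyTaskParameters(seed=seed, dataset=dataset)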
"""
pass
def make_dirs(self):
"""Make directories before execution.
If some directories need to be created before executing the tasks, create them here.
*Override this method to add some behaviour*.
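Example:
    An override might create an output directory (the path is illustrative):

    .. code-block:: python

        import os

        def make_dirs(self):
            os.makedirs("results", exist_ok=True)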
"""
pass
@staticmethod
@abc.abstractmethod
def processor_function(parameters: TaskParameters) -> Optional[TaskResult]:
"""Execute the task.
This function **must be implemented**. Here you can run the task that you want to run.
It should return a TaskResult if the execution worked as expected (even if it is just an
empty object), or None if the execution failed and processing should be terminated.
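Example:
    Inside a Task subclass, an implementation could look like this
    (``MyTaskResult`` and ``run_my_computation`` are illustrative names):

    .. code-block:: python

        @staticmethod
        def processor_function(parameters):
            try:
                value = run_my_computation(parameters)
                return MyTaskResult(my_result=value)
            except Exception:
                # None tells the runner to terminate execution.
                return None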
"""
pass
def process_output(self, result: TaskResult) -> bool:
"""Process the generated output.
Process the output generated with processor_function() and evaluate whether execution
should be terminated. Override this method if the default behaviour (return False) should
be changed.
Args:
result (TaskResult): Result of the task as returned by the processor_function()
Returns:
bool: True if the execution should be terminated, False otherwise.
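Example:
    An override could collect results and stop on a bad value (``results`` and
    ``my_result`` are illustrative names):

    .. code-block:: python

        def process_output(self, result):
            self.results.append(result)
            # Returning True terminates the remaining tasks.
            return result.my_result < 0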
"""
return False
def after_run(self):
"""Handle results after running tasks.
Override this method to handle results after everything has been run. Keep in mind that
this runs after all the tasks of this fold have finished, but tasks on other nodes may
still be running.
"""
pass
def key_interrupt(self):
"""Override this method to handle the status after a keyboard interrupt."""
pass
# Do not override anything past this point
@staticmethod
def get_cluster_ids(cluster_id: Optional[int], cluster_total: Optional[int]) -> tuple[int, int]:
"""Get the cluster id and the total number of nodes.
If the cluster_id and cluster_total are not set, they are obtained from the environment
variables passed by the SLURM task scheduler.
Args:
cluster_id (Optional[int]): ID of the current node.
cluster_total (Optional[int]): Total number of nodes that we can use in the
cluster.
Returns:
tuple[int, int]: Tuple with current cluster ID and total number of nodes.
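Example:
    In a SLURM array job launched with ``#SBATCH --array=1-4``, the third array
    task sees SLURM_ARRAY_TASK_ID=3 and SLURM_ARRAY_TASK_MAX=4, so this method
    returns ``(2, 4)``. Outside SLURM both variables are unset and the result
    is ``(0, 1)``.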
"""
if cluster_id is None:
cluster_id = int(os.getenv("SLURM_ARRAY_TASK_ID", 1)) - 1
if cluster_total is None:
cluster_total = int(os.getenv("SLURM_ARRAY_TASK_MAX", 1))
return cluster_id, cluster_total
@staticmethod
def _tasks_distribution(total_tasks: int, workers: int) -> list[tuple[int, int]]:
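# Split `total_tasks` into `workers` contiguous (begin, end) index ranges of
# nearly equal size: the first `limit` workers get `length` tasks each, the
# rest get `length - 1`. For example, 10 tasks over 3 workers gives
# [(0, 4), (4, 7), (7, 10)].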
length = int(math.ceil(total_tasks / workers))
limit = total_tasks - (length - 1) * workers
task_list = []
prev_end = 0
for i in range(workers):
task_begin = prev_end
task_end = task_begin + length - (0 if i < limit else 1)
task_end = prev_end = min(task_end, total_tasks)
task_list.append((task_begin, task_end))
return task_list
def _obtain_current_fold(self):
params = sorted(set(self.generate_parameters()))
task_list = Task._tasks_distribution(len(params), self.cluster_total)
task_begin, task_end = task_list[self.cluster_id]
# Shuffle so that big tasks get mixed with small ones across folds, which helps save memory
random.seed(42)
random.shuffle(params)
if self.cluster_total > 1:
print(f"Running fold {self.cluster_id + 1} out of {self.cluster_total}")
print(
f"{task_begin} to {task_end} tasks will be run instead of the whole {len(params)}"
)
return params[task_begin:task_end]
@staticmethod
def _process(p: _PrivateTaskParameters) -> Optional[TaskResult]:
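# Runs inside a worker process. A KeyboardInterrupt here is converted into a
# None result, which the main loop interprets as a request to terminate.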
try:
return p.f(p.p)
except KeyboardInterrupt:
return None
def execute_tasks(
self,
make_dirs_only: bool = False,
debug: bool = False,
processes: Optional[int] = None,
no_bar: bool = False,
description: str = "",
) -> Error:
"""Execute all the tasks.
The tasks are executed in a random order, so do not rely on the order; use IDs instead.
Even though the order is random, it is consistent across nodes.
Args:
make_dirs_only (bool, optional): If True, only the directories will be created and
execution will return. Defaults to False.
debug (bool, optional): If True, the execution will be run in debug mode disabling all
the parallelism. Defaults to False.
processes (int, optional): Number of processes to use. Pass None to use as many as the
system has. Defaults to None.
no_bar (bool, optional): If True, the progress bar will not be displayed. Defaults to
False.
description (str, optional): Description of the task. Defaults to "".
Raises:
KeyboardInterrupt: If a keyboard interrupt is received, all the tasks are stopped and a
KeyboardInterrupt is raised.
TaskFailedError: If process_output() flags a result as failed (returns True), a TaskFailedError is
raised.
Returns:
Error: The termination result of the tasks. If everything is fine, the result is
Error.NONE.
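Example:
    A typical call, assuming ``MyTask`` implements the abstract methods:

    .. code-block:: python

        if __name__ == "__main__":
            MyTask().execute_tasks(processes=8, description="running tasks")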
"""
global errored, pool
params = self._obtain_current_fold()
params2 = (_PrivateTaskParameters(self.processor_function, p) for p in params)
self.make_dirs()
if make_dirs_only:
return Error.NONE
if len(params) <= 0:
print("WARNING: No tasks have to be done, are you sure you did everything right?")
signal_list = [signal.SIGINT, signal.SIGTERM]
signals = {s: signal.getsignal(s) for s in signal_list}
if debug:
pool = _PoolDummy()
generator = map(self._process, params2)
else:
pool = mp.Pool(processes=processes)
generator = pool.imap_unordered(self._process, params2)
errored = Error.NONE
for s in signal_list:
signal.signal(s, _sigterm_handler_soft)
try:
for output in tqdm(
generator,
total=len(params),
desc=description,
disable=(len(params) <= 0) or no_bar,
smoothing=0.02,
):
if output is None:
# This is a keyboard interrupt
errored = Error.TERMINATE
pool.terminate()
break
if self.process_output(output):
# This is caused by an actual error while running
errored = Error.SYSTEM
pool.terminate()
break
if errored != Error.NONE:
pool.terminate()
break
except KeyboardInterrupt:
errored = Error.TERMINATE
if errored != Error.NONE:
pool.terminate()
pool.join()
if errored == Error.TERMINATE:
self.key_interrupt()
raise KeyboardInterrupt
elif errored == Error.SYSTEM:
raise TaskFailedError("A task produced an undesired output")
pool.close()
self.after_run()
for k, v in signals.items():
if v is None:
continue
signal.signal(k, v)
return errored