# Source code of the xqute.xqute module

"""The xqute module"""
from __future__ import annotations

import asyncio
import functools
import signal
from collections import deque
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Type

from .defaults import (
    DEFAULT_JOB_METADIR,
    DEFAULT_JOB_ERROR_STRATEGY,
    DEFAULT_JOB_NUM_RETRIES,
    DEFAULT_SCHEDULER_FORKS,
    DEFAULT_JOB_SUBMISSION_BATCH,
    JobErrorStrategy,
    JobStatus,
)
from .utils import logger
from .plugin import plugin
from .schedulers import get_scheduler

if TYPE_CHECKING:  # pragma: no cover
    from .scheduler import Scheduler
    from .job import Job


class Xqute:
    """The main class of the package

    It keeps a registry of jobs, buffers the pushed jobs, and runs one
    producer plus `job_submission_batch` consumers that hand the jobs to
    the scheduler. It must be instantiated inside a running asyncio event
    loop, since signal handlers are installed on the running loop.

    Attributes:
        name: The name, used in logger
        EMPTY_BUFFER_SLEEP_TIME: The time to sleep while waiting when
            the buffer is empty

        jobs: The jobs registry
        plugins: The plugins to be enabled or disabled
            to disable a plugin, using `-plugin_name`
            either all plugin names should be prefixed with '+'/'-' or none
            of them should
        plugin_context: The plugin context entered at construction and
            exited when `run_until_complete()` finishes
        job_submission_batch: The number of consumers to submit jobs

        _job_metadir: The job meta directory
        _job_error_strategy: The strategy when there is error happened
        _job_num_retries: The number of retries when strategy is retry

        _cancelling: A mark to mark whether a shutting down event
            is triggered (True for natural cancelling, the signal for
            cancelling with a signal, SIGINT for example)

        buffer_queue: A buffer queue to save the pushed jobs
        queue: The job queue
        scheduler: The scheduler
        task: The task of producer and consumers

    Args:
        scheduler: The scheduler class or name
        plugins: The plugins to be enabled or disabled
            to disable a plugin, using `-plugin_name`
            either all plugin names should be prefixed with '+'/'-' or none
            of them should
            NOTE(review): the original note about enabling plugins by
            passing objects was truncated -- confirm the accepted forms
            against the plugin system.
        job_metadir: The job meta directory
        job_submission_batch: The number of consumers to submit jobs
        job_error_strategy: The strategy when there is error happened
        job_num_retries: Max number of retries when job_error_strategy is retry
        scheduler_forks: Max number of job forks
        scheduler_prescript: Forwarded to the scheduler constructor
            (presumably a script to run before each job -- TODO confirm)
        scheduler_postscript: Forwarded to the scheduler constructor
            (presumably a script to run after each job -- TODO confirm)
        **scheduler_opts: Additional keyword arguments for scheduler
    """

    name: str = "Xqute"
    EMPTY_BUFFER_SLEEP_TIME: int = 1

    def __init__(
        self,
        scheduler: str | Type[Scheduler] = "local",
        plugins: List[Any] | None = None,
        *,
        job_metadir: PathLike = DEFAULT_JOB_METADIR,
        job_submission_batch: int = DEFAULT_JOB_SUBMISSION_BATCH,
        job_error_strategy: str = DEFAULT_JOB_ERROR_STRATEGY,
        job_num_retries: int = DEFAULT_JOB_NUM_RETRIES,
        scheduler_forks: int = DEFAULT_SCHEDULER_FORKS,
        scheduler_prescript: str = "",
        scheduler_postscript: str = "",
        **scheduler_opts,
    ) -> None:
        """Construct"""
        self.jobs: List[Job] = []

        # The context is entered manually (not with a `with` statement)
        # because it has to stay active until `run_until_complete()`
        # exits it.
        self.plugin_context = plugin.plugins_context(plugins)
        self.plugin_context.__enter__()

        logger.info(
            "/%s Enabled plugins: %s",
            self.name,
            plugin.get_enabled_plugin_names(),
        )

        self.job_submission_batch = job_submission_batch
        self._job_metadir = job_metadir
        # parents=True so that a nested, not-yet-existing meta directory
        # does not make construction fail
        Path(self._job_metadir).mkdir(parents=True, exist_ok=True)
        self._job_error_strategy = job_error_strategy
        self._job_num_retries = job_num_retries

        self._cancelling: bool | signal.Signals = False

        self.buffer_queue: deque = deque()
        self.queue: asyncio.Queue = asyncio.Queue()
        scheduler_class = get_scheduler(scheduler)
        self.scheduler = scheduler_class(
            scheduler_forks,
            scheduler_prescript,
            scheduler_postscript,
            **scheduler_opts,
        )

        # requires to be defined in a loop
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, functools.partial(self.cancel, sig))

        # One producer moving jobs from the buffer to the queue, and
        # `job_submission_batch` consumers submitting them.
        self.task = asyncio.gather(
            self._producer(),
            *(self._consumer(i) for i in range(self.job_submission_batch)),
        )
        plugin.hooks.on_init(self)

    def cancel(self, sig: signal.Signals | None = None) -> None:
        """Cancel the producer-consumer task

        `self._cancelling` will be set to the signal if sig is provided,
        otherwise it will be set to `True`

        The task is only cancelled when the `on_shutdown` hook does not
        return `False`.

        Args:
            sig: The signal that caused this cancelling, if any
        """
        self._cancelling = True
        if sig:
            self._cancelling = sig
            logger.warning(
                "/%s Got signal %r, trying a graceful shutdown ...",
                self.name,
                sig.name,
            )

        if plugin.hooks.on_shutdown(self, sig) is not False:
            self.task.cancel()

    async def _producer(self):
        """The producer

        Moves jobs from the buffer queue to the job queue, as long as the
        scheduler reports that a job can be submitted.
        """

        while True:
            if not self.buffer_queue:
                logger.debug(
                    "/%s Buffer queue is empty, waiting ...", self.name
                )
                await asyncio.sleep(self.EMPTY_BUFFER_SLEEP_TIME)
                continue

            job = self.buffer_queue.popleft()
            can_submit = await self.scheduler.polling_jobs(
                self.jobs,
                "can_submit",
                self._job_error_strategy == JobErrorStrategy.HALT,
            )
            if not can_submit:
                logger.debug("/%s Hit max forks of scheduler ...", self.name)
                # Put the job back before yielding control, so it is not
                # missing from the buffer if we get cancelled while
                # sleeping.
                self.buffer_queue.appendleft(job)
                await asyncio.sleep(0.1)
                continue

            job.status = JobStatus.QUEUED
            await self.queue.put(job)

    async def _consumer(self, index: int) -> None:
        """The consumer

        Takes jobs from the job queue and submits them via the scheduler.

        Args:
            index: The index of the consumer
        """
        while True:
            job = await self.queue.get()
            logger.debug(
                "/%s 'Consumer-%s' submitting %s", self.name, index, job
            )
            try:
                await self.scheduler.submit_job_and_update_status(job)
            finally:
                # Always balance the `get()` above, even when submission
                # raises or is cancelled.
                self.queue.task_done()

    async def put(self, cmd: Job | str | List[str]) -> None:
        """Put a command into the buffer

        If `cmd` is already a job of the scheduler's job class, its unset
        (`None`) retry settings are filled from this object's settings;
        otherwise a new job is created from the command.

        Args:
            cmd: The command
        """
        retry_on_error = self._job_error_strategy == JobErrorStrategy.RETRY
        if isinstance(cmd, self.scheduler.job_class):
            job = cmd
            if job._error_retry is None:
                job._error_retry = retry_on_error
            if job._num_retries is None:
                job._num_retries = self._job_num_retries
        else:
            job = self.scheduler.job_class(
                len(self.jobs),  # the index of the new job
                cmd,  # type: ignore
                self._job_metadir,
                retry_on_error,
                self._job_num_retries,
            )
        await plugin.hooks.on_job_init(self.scheduler, job)
        self.jobs.append(job)
        logger.info("/%s Pushing job: %r", self.name, job)

        self.buffer_queue.append(job)
        await plugin.hooks.on_job_queued(self.scheduler, job)

    async def _polling_jobs(self) -> None:
        """Polling the jobs to see if they are all done.

        If yes, cancel the producer-consumer task naturally.
        """
        while (
            self._cancelling is False
            and not await self.scheduler.polling_jobs(
                self.jobs,
                "all_done",
                self._job_error_strategy == JobErrorStrategy.HALT,
            )
        ):
            await asyncio.sleep(1.0)

        # Nothing else triggered a shutdown, so all jobs are done
        if self._cancelling is False:
            self.cancel()

    async def _await_task(self) -> None:
        """Await the producer-consumer task and catch the CancelledError

        When the cancellation was caused by a signal, the running jobs are
        killed via the scheduler.
        """
        try:
            await self.task
        except asyncio.CancelledError:
            logger.debug("/%s Stoping producer and consumer ...", self.name)
            # `_cancelling` is a bool unless a signal was received.
            # Do not test with `in (True, False)`: signals are IntEnum
            # members compared by value, so a signal with value 1 (SIGHUP
            # on POSIX) equals True and would be mistaken for a natural
            # cancelling.
            if not isinstance(self._cancelling, bool):  # signaled
                await self.scheduler.kill_running_jobs(self.jobs)

        logger.info("/%s Done!", self.name)

    async def run_until_complete(self) -> None:
        """Wait until all jobs complete

        The plugin context entered at construction is exited when the
        waiting ends, whether it succeeded or not.
        """
        logger.debug(
            "/%s Done feeding jobs, waiting for jobs to be done ...", self.name
        )
        try:
            await asyncio.gather(self._polling_jobs(), self._await_task())
        finally:
            if self.plugin_context:
                # `__exit__` takes (exc_type, exc_value, traceback) by the
                # context manager protocol; pass None's for a plain exit.
                self.plugin_context.__exit__(None, None, None)