Skip to content

SOURCE CODE xqute.schedulers.container_scheduler DOCS

"""The scheduler to run jobs via containers"""
from __future__ import annotations

import os
import shlex
import shutil
import asyncio
from pathlib import Path
from typing import Dict, List, Sequence

from ..job import Job
from ..defaults import JOBCMD_WRAPPER_LANG
from ..path import SpecPath
from .local_scheduler import LocalScheduler


# Default in-container path for the scheduler workdir; used as the
# `mounted_workdir` when the caller does not supply one.
DEFAULT_MOUNTED_WORKDIR = "/mnt/disks/xqute_workdir"

# Basename of the runtime binary -> flavor of command-line flags to emit.
# Singularity is driven with the Apptainer flavor; any name not listed here
# falls back to Docker-style flags in the scheduler.
CONTAINER_TYPES = dict(
    docker="docker",
    podman="podman",
    apptainer="apptainer",
    singularity="apptainer",
)


class ContainerScheduler(LocalScheduler):
    """Scheduler to run jobs via containers (Docker/Podman/Apptainer)

    This scheduler can execute jobs inside containers using Docker, Podman,
    or Apptainer (a `singularity` binary is driven with Apptainer flags).
    Each job is started by running the container runtime binary as a local
    subprocess, and the scheduler workdir is always bind-mounted into the
    container at `workdir.mounted`.

    Args:
        image: Container image to use for running jobs. For Docker/Podman a
            leading `docker://` prefix is stripped.
        entrypoint: Command (a string or a list of arguments) placed right
            after the image in the `run` command; the path of the wrapped
            job script is appended after it at submission time.
        bin: Name of, or path to, the container runtime binary
            (e.g. `docker` or `/path/to/podman`). It is resolved with
            `shutil.which`; its basename selects the flavor of flags to use,
            and unrecognized names fall back to Docker-style flags.
        volumes: host:container volume mapping string or strings.
            The `workdir:workdir.mounted` mapping is appended automatically.
        envs: Environment variables to set in container
        remove: Whether to remove the container after execution.
            Only applies to Docker/Podman.
        user: User to run the container as (only for Docker/Podman).
            By default, it runs as the current user (os.getuid() and
            os.getgid()). On platforms without `os.getuid` (e.g. Windows),
            `--user` is omitted unless a user is given explicitly.
        bin_args: Additional arguments to pass to the container runtime,
            inserted right before the image
        **kwargs: Additional arguments passed to parent Scheduler.
            `mounted_workdir` defaults to `DEFAULT_MOUNTED_WORKDIR`.

    Raises:
        ValueError: If the runtime binary `bin` cannot be found in PATH.
    """

    name = "container"

    __slots__ = (
        "image",
        "entrypoint",
        "bin",
        "volumes",
        "envs",
        "remove",
        "user",
        "bin_args",
        "_container_type",
    )

    def __init__(
        self,
        image: str,
        entrypoint: str | List[str] = JOBCMD_WRAPPER_LANG,
        bin: str = "docker",
        volumes: str | Sequence[str] | None = None,
        envs: Dict[str, str] | None = None,
        remove: bool = True,
        user: str | None = None,
        bin_args: List[str] | None = None,
        **kwargs,
    ):
        kwargs.setdefault("mounted_workdir", DEFAULT_MOUNTED_WORKDIR)
        super().__init__(**kwargs)

        # Resolve the runtime up front so a missing binary fails here rather
        # than when the first job is submitted.
        self.bin = shutil.which(bin)
        if not self.bin:
            raise ValueError(
                f"Container runtime binary '{bin}' not found in PATH"
            )

        self.image = image
        self.entrypoint = (
            list(entrypoint)
            if isinstance(entrypoint, (list, tuple))
            else [entrypoint]
        )

        if not volumes:
            volumes = []
        elif isinstance(volumes, str):
            volumes = [volumes]
        # Copy (so the caller's sequence is never mutated by the workdir
        # mapping appended below) and coerce entries such as Path objects to
        # str, which shlex.join in jobcmd_shebang requires.
        self.volumes = [str(vol) for vol in volumes]

        self.envs = envs or {}
        self.remove = remove
        if not user and hasattr(os, "getuid"):
            # os.getuid/os.getgid are only available on Unix
            user = f"{os.getuid()}:{os.getgid()}"
        # None means "do not pass --user" (no explicit user on non-Unix)
        self.user = user or None
        self.bin_args = bin_args or []
        # The scheduler workdir must be visible inside the container
        self.volumes.append(f"{self.workdir}:{self.workdir.mounted}")

        self._container_type = CONTAINER_TYPES.get(
            Path(self.bin).name.lower(),
            "docker",
        )
        if (
            self._container_type in ("docker", "podman")
            and self.image.startswith("docker://")
        ):
            # Docker/Podman expect a bare image reference, not a docker:// URI
            self.image = self.image[len("docker://"):]

    def wrapped_job_script(self, job: Job) -> SpecPath:
        """Write the wrapped job script and return its path

        The script is (re)written on every call, to
        `<job.metadir>/job.wrapped.<container type>`.

        Args:
            job: The job

        Returns:
            The path of the wrapped job script
        """
        base = f"job.wrapped.{self._container_type}"
        wrapt_script = job.metadir / base
        wrapt_script.write_text(self.wrap_job_script(job))

        return wrapt_script

    def jobcmd_shebang(self, job: Job) -> str:
        """Build the container runtime command used to run the wrapped script

        Args:
            job: The job (not referenced when building this command)

        Returns:
            A shell-quoted command line of the form
            `<bin> run <runtime flags> <bin_args> <image> <entrypoint...>`.
            Apptainer gets `--pwd/--env/--bind`; everything else gets
            Docker-style `--rm/--user/-e/-v/--workdir`.
        """
        cmd = [self.bin, "run"]
        if self._container_type == "apptainer":
            cmd.extend(["--pwd", str(self.workdir.mounted)])
            for key, value in self.envs.items():
                cmd.extend(["--env", f"{key}={value}"])
            for vol in self.volumes:
                cmd.extend(["--bind", vol])
        else:
            if self.remove:
                cmd.append("--rm")
            if self.user:
                cmd.extend(["--user", self.user])
            for key, value in self.envs.items():
                cmd.extend(["-e", f"{key}={value}"])
            for vol in self.volumes:
                cmd.extend(["-v", vol])
            cmd.extend(["--workdir", str(self.workdir.mounted)])

        cmd.extend(self.bin_args)
        cmd.append(self.image)
        cmd.extend(self.entrypoint)

        return shlex.join(cmd)

    async def submit_job(self, job: Job) -> int:
        """Submit a job by running the container runtime as a local process

        The process is started but not awaited; this coroutine only waits
        until the job's stdout or stderr file appears.

        Args:
            job: The job

        Returns:
            The process id of the container runtime process

        Raises:
            RuntimeError: If the process exits before either the job's
                stdout or stderr file exists.
        """
        # Build the command and write the wrapped script exactly once; both
        # are reused in the error message below instead of being regenerated
        # (regenerating would rewrite the script file on disk again).
        shebang = self.jobcmd_shebang(job)
        script = str(self.wrapped_job_script(job).mounted)
        proc = await asyncio.create_subprocess_exec(
            *shlex.split(shebang),
            script,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # wait for a while to make sure the process is running
        # this is to avoid the real command is not run when proc is recycled too early
        # this happens for python < 3.12
        while not job.stdout_file.exists() and not job.stderr_file.exists():
            if proc.returncode is not None:
                # The process has already finished and no stdout/stderr files
                # were generated; something went wrong with the runtime or the
                # wrapper script. Decode leniently so undecodable output
                # cannot mask the real error.
                stderr = await proc.stderr.read()
                raise RuntimeError(
                    f"Failed to submit job #{job.index}: "
                    f"{stderr.decode(errors='replace')}\n"
                    f"Command: {shebang} "
                    f"{script}\n"
                )

            await asyncio.sleep(0.1)

        # don't await for the results, as this will run the real command
        return proc.pid