Source code: xqute.schedulers.gbatch_scheduler

import asyncio
import json
import re
import shlex
from copy import deepcopy
from hashlib import sha256
from yunpath import GSPath
from diot import Diot

from ..job import Job
from ..scheduler import Scheduler
from ..defaults import JOBCMD_WRAPPER_LANG, get_jobcmd_wrapper_init
from ..utils import logger
from ..path import SpecPath


JOBNAME_PREFIX_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{0,47}$")
DEFAULT_MOUNTED_WORKDIR = "/mnt/xqute_workdir"


class GbatchScheduler(Scheduler):
    """Scheduler for Google Cloud Batch"""

    name = "gbatch"

    __slots__ = Scheduler.__slots__ + (
        "gcloud",
        "project",
        "location",
    )

    def __init__(self, *args, project: str, location: str, **kwargs):
        """Construct the gbatch scheduler"""
        self.gcloud = kwargs.pop("gcloud", "gcloud")
        self.project = project
        self.location = location
        kwargs.setdefault("mounted_workdir", DEFAULT_MOUNTED_WORKDIR)
        super().__init__(*args, **kwargs)

        if not isinstance(self.workdir, GSPath):
            raise ValueError(
                "'gbatch' scheduler requires google cloud storage 'workdir'."
            )

        if not JOBNAME_PREFIX_RE.match(self.jobname_prefix):
            raise ValueError(
                "'jobname_prefix' for gbatch scheduler doesn't follow pattern "
                "^[a-zA-Z][a-zA-Z0-9-]{0,47}$."
            )

        self.config.setdefault("taskGroups", [])
        if not self.config.taskGroups:
            self.config.taskGroups.append(Diot())
        if not self.config.taskGroups[0]:
            self.config.taskGroups[0] = Diot()

        self.config.taskGroups[0].setdefault("taskSpec", Diot())
        self.config.taskGroups[0].taskSpec.setdefault("runnables", [])
        if not self.config.taskGroups[0].taskSpec.runnables:
            self.config.taskGroups[0].taskSpec.runnables.append(Diot())
        if not self.config.taskGroups[0].taskSpec.runnables[0]:
            self.config.taskGroups[0].taskSpec.runnables[0] = Diot()
        self.config.taskGroups[0].taskSpec.runnables[0].script = Diot(
            text=None  # placeholder for job command
        )
        # Only the stdout/stderr of the submission itself is sent to Cloud
        # Logging (useful when the wrapped script doesn't get to run).
        # The stdout/stderr of the wrapped script are written to files
        # in the workdir.
        self.config.setdefault("logsPolicy", Diot())
        self.config.logsPolicy.setdefault("destination", "CLOUD_LOGGING")
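        # For illustration only, the default logging section rendered into
        # the job config is a sketch like:
        #   "logsPolicy": {"destination": "CLOUD_LOGGING"}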

        self.config.taskGroups[0].taskSpec.setdefault("volumes", [])
        if not isinstance(self.config.taskGroups[0].taskSpec.volumes, list):
            raise ValueError(
                "'taskGroups[0].taskSpec.volumes' should be a list for "
                "gbatch configuration."
            )

        meta_volume = Diot()
        meta_volume.gcs = Diot(remotePath=self.workdir._no_prefix)
        meta_volume.mountPath = str(self.workdir.mounted)

        self.config.taskGroups[0].taskSpec.volumes.insert(0, meta_volume)
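        # For illustration only: with a hypothetical workdir
        # "gs://my-bucket/xqute-workdir", the volume inserted above would
        # render roughly as
        #   {"gcs": {"remotePath": "my-bucket/xqute-workdir"},
        #    "mountPath": "/mnt/xqute_workdir"}
        # so the GCS workdir is visible inside the Batch container.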

    @property
    def jobcmd_wrapper_init(self) -> str:
        return get_jobcmd_wrapper_init(True, self.remove_jid_after_done)

    def job_config_file(self, job: Job) -> SpecPath:
        """Generate the Batch job configuration (JSON) file for the job

        Args:
            job: The job

        Returns:
            The path to the generated configuration file
        """
        base = f"job.wrapped.{self.name}.json"
        conf_file = job.metadir / base

        wrapt_script = self.wrapped_job_script(job)
        config = deepcopy(self.config)
        config.taskGroups[0].taskSpec.runnables[0].script.text = shlex.join(
            shlex.split(JOBCMD_WRAPPER_LANG) + [str(wrapt_script.mounted)]
        )
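        # For illustration only: assuming the default wrapper language is
        # bash and a hypothetical mounted metadir "/mnt/xqute_workdir/0",
        # the script text would look roughly like
        #   bash /mnt/xqute_workdir/0/job.wrapped.gbatch
        # (the exact wrapper filename comes from `wrapped_job_script`).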
        with conf_file.open("w") as f:
            json.dump(config, f, indent=2)

        return conf_file

    async def _delete_job(self, job: Job) -> None:
        """Try to delete the job from google cloud's registry

        As google doesn't allow jobs to have the same id.

        Args:
            job: The job to delete
        """
        logger.debug(
            "/Scheduler-%s Try deleting job %r on GCP.",
            self.name,
            job,
        )
        status = await self._get_job_status(job)
        while status.endswith("_IN_PROGRESS"):  # pragma: no cover
            await asyncio.sleep(1)
            status = await self._get_job_status(job)

        command = [
            self.gcloud,
            "batch",
            "jobs",
            "delete",
            job.jid,
            "--project",
            self.project,
            "--location",
            self.location,
        ]

        try:
            proc = await asyncio.create_subprocess_exec(
                *command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except Exception:
            pass
        else:  # pragma: no cover
            await proc.wait()

        status = await self._get_job_status(job)
        while status == "DELETION_IN_PROGRESS":  # pragma: no cover
            await asyncio.sleep(1)
            status = await self._get_job_status(job)

        if status != "UNKNOWN":  # pragma: no cover
            logger.warning(
                "/Scheduler-%s Failed to delete job %r on GCP, submision may fail.",
                self.name,
                job,
            )

    async def submit_job(self, job: Job) -> str:
        """Submit the job to Google Cloud Batch

        Args:
            job: The job to submit

        Returns:
            The job id on Google Cloud Batch
        """
        sha = sha256(str(self.workdir).encode()).hexdigest()[:8]
        job.jid = f"{self.jobname_prefix}-{sha}-{job.index}".lower()
        await self._delete_job(job)

        conf_file = self.job_config_file(job)
        proc = await asyncio.create_subprocess_exec(
            self.gcloud,
            "batch",
            "jobs",
            "submit",
            job.jid,
            "--config",
            conf_file.fspath,
            "--project",
            self.project,
            "--location",
            self.location,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        _, stderr = await proc.communicate()
        if proc.returncode != 0:  # pragma: no cover
            raise RuntimeError(
                "Can't submit job to Google Cloud Batch: \n"
                f"{stderr.decode()}\n"
                "Check the configuration file:\n"
                f"{conf_file}"
            )

        return job.jid

    async def kill_job(self, job: Job):
        """Cancel the job on Google Cloud Batch

        Args:
            job: The job to cancel
        """
        command = [
            self.gcloud,
            "alpha",
            "batch",
            "jobs",
            "cancel",
            job.jid,
            "--project",
            self.project,
            "--location",
            self.location,
            "--quiet",
        ]
        proc = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        await proc.wait()

    async def _get_job_status(self, job: Job) -> str:
        """Get the status of the job from Google Cloud Batch

        Args:
            job: The job to check

        Returns:
            The job state (e.g. "RUNNING"), or "UNKNOWN" if it can't be
            determined.
        """
        if not job.jid_file.is_file():
            return "UNKNOWN"

        # Do not rely on _jid, as it may point to an obsolete job.
        jid = job.jid_file.read_text().strip()

        command = [
            self.gcloud,
            "batch",
            "jobs",
            "describe",
            jid,
            "--project",
            self.project,
            "--location",
            self.location,
        ]

        try:
            proc = await asyncio.create_subprocess_exec(
                *command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except Exception:  # pragma: no cover
            return "UNKNOWN"

        if await proc.wait() != 0:
            return "UNKNOWN"

        stdout = (await proc.stdout.read()).decode()
        # The describe output includes a line like "state: RUNNING";
        # fall back to UNKNOWN if no state can be parsed.
        matched = re.search(r"state: (.+)", stdout)
        return matched.group(1) if matched else "UNKNOWN"

    async def job_is_running(self, job: Job) -> bool:
        """Tell whether the job is running (or pending) on Google Cloud Batch

        Args:
            job: The job to check

        Returns:
            True if the job is in the RUNNING, QUEUED or SCHEDULED state.
        """
        status = await self._get_job_status(job)
        return status in ("RUNNING", "QUEUED", "SCHEDULED")