import asyncio
import json
import re
import shlex
from copy import deepcopy
from hashlib import sha256
from yunpath import GSPath
from diot import Diot
from ..job import Job
from ..scheduler import Scheduler
from ..defaults import JOBCMD_WRAPPER_LANG, get_jobcmd_wrapper_init
from ..utils import logger
from ..path import SpecPath
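# Valid jobname prefixes: a letter followed by up to 47 letters, digits or
# hyphens; the final job id is lowercased in `submit_job` to meet Google
# Cloud Batch's job-id rules.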
JOBNAME_PREFIX_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{0,47}$")
DEFAULT_MOUNTED_WORKDIR = "/mnt/disks/xqute_workdir"
class GbatchScheduler(Scheduler):
"""Scheduler for Google Cloud Batch
You can pass extra configuration parameters to the constructor
that will be used in the job configuration file.
For example, you can pass `taskGroups` to specify the task groups
and their specifications.
    When using containers, it is a little tricky to specify the commands.
When no `entrypoint` is specified, the `commands` should be a list
with the first element being the interpreter (e.g. `/bin/bash`)
and the second element being the path to the wrapped job script.
If the `entrypoint` is specified, we can use the `{lang}` and `{script}`
placeholders in the `commands` list, where `{lang}` will be replaced
with the interpreter (e.g. `/bin/bash`) and `{script}` will be replaced
with the path to the wrapped job script.
    With `entrypoint` specified but no `{script}` placeholder, a joined
    command (the interpreter followed by the path to the wrapped job script)
    will be appended to the `commands` list.
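
    Examples:
        A minimal sketch of a container-based configuration (the image,
        bucket and project names below are illustrative, not required
        values):

        >>> scheduler = GbatchScheduler(
        ...     project="my-project",
        ...     location="us-central1",
        ...     workdir="gs://my-bucket/workdir",
        ...     taskGroups=[{
        ...         "taskSpec": {
        ...             "runnables": [{
        ...                 "container": {
        ...                     "imageUri": "ubuntu",
        ...                     "entrypoint": "/bin/bash",
        ...                     # {lang}/{script} are replaced with the
        ...                     # interpreter and the wrapped job script
        ...                     "commands": ["-c", "{lang} {script}"],
        ...                 },
        ...             }],
        ...         },
        ...     }],
        ... )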
"""
name = "gbatch"
__slots__ = Scheduler.__slots__ + (
"gcloud",
"project",
"location",
)
    def __init__(self, *args, project: str, location: str, **kwargs):
        """Construct the gbatch scheduler

        Args:
            project: The Google Cloud project to run the jobs in
            location: The location (region) to run the jobs in
            *args: Positional arguments for the base scheduler
            **kwargs: Keyword arguments for the base scheduler; `gcloud`
                (path to the gcloud executable) is consumed here, and
                `mounted_workdir` defaults to `/mnt/disks/xqute_workdir`.
        """
self.gcloud = kwargs.pop("gcloud", "gcloud")
self.project = project
self.location = location
kwargs.setdefault("mounted_workdir", DEFAULT_MOUNTED_WORKDIR)
super().__init__(*args, **kwargs)
if not isinstance(self.workdir, GSPath):
raise ValueError(
"'gbatch' scheduler requires google cloud storage 'workdir'."
)
if not JOBNAME_PREFIX_RE.match(self.jobname_prefix):
raise ValueError(
"'jobname_prefix' for gbatch scheduler doesn't follow pattern "
"^[a-zA-Z][a-zA-Z0-9-]{0,47}$."
)
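        # Normalize the config so taskGroups[0].taskSpec.runnables[0]
        # always exists before we fill in the job command.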
task_groups = self.config.setdefault("taskGroups", [])
if not task_groups:
task_groups.append(Diot())
if not task_groups[0]:
task_groups[0] = Diot()
task_spec = task_groups[0].setdefault("taskSpec", Diot())
runnables = task_spec.setdefault("runnables", [])
if not runnables:
runnables.append(Diot())
if not runnables[0]:
runnables[0] = Diot()
if "container" in runnables[0]:
if not isinstance(runnables[0].container, dict): # pragma: no cover
raise ValueError(
"'taskGroups[0].taskSpec.runnables[0].container' should be a "
"dictionary for gbatch configuration."
)
runnables[0].container.setdefault("commands", [])
else:
runnables[0].script = Diot(text=None) # placeholder for job command
        # Cloud Logging captures only the submission/bootstrap output (useful
        # when the wrapped script doesn't get to run); the wrapped script's
        # own logs go to the stdout/stderr files in the workdir.
logs_policy = self.config.setdefault("logsPolicy", Diot())
logs_policy.setdefault("destination", "CLOUD_LOGGING")
volumes = task_spec.setdefault("volumes", [])
if not isinstance(volumes, list):
raise ValueError(
"'taskGroups[0].taskSpec.volumes' should be a list for "
"gbatch configuration."
)
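        # Mount the GCS workdir into the VM so the wrapped script and job
        # metadata are available at the mounted path inside the job.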
meta_volume = Diot()
meta_volume.gcs = Diot(remotePath=self.workdir._no_prefix)
meta_volume.mountPath = str(self.workdir.mounted)
volumes.insert(0, meta_volume)
    @property
    def jobcmd_wrapper_init(self) -> str:
        """The initialization section of the job command wrapper script."""
        return get_jobcmd_wrapper_init(True, self.remove_jid_after_done)
    def job_config_file(self, job: Job) -> SpecPath:
        """Generate the JSON configuration file used to submit the job."""
base = f"job.wrapped.{self.name}.json"
conf_file = job.metadir / base
wrapt_script = self.wrapped_job_script(job)
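        # Mutate a deep copy so per-job changes don't leak into the shared
        # scheduler config.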
config = deepcopy(self.config)
runnable = config.taskGroups[0].taskSpec.runnables[0]
if "container" in runnable:
container = runnable.container
if "entrypoint" not in container:
                # The entrypoint only takes the bare interpreter
                # (e.g. /bin/bash), not one with flags like /bin/bash -u
container.entrypoint = JOBCMD_WRAPPER_LANG
container.commands.append(str(wrapt_script.mounted))
elif any("{script}" in cmd for cmd in container.commands):
                # The entrypoint is set and the commands contain placeholders;
                # substitute the interpreter and the wrapped job script.
container.commands = [
cmd
.replace("{lang}", str(JOBCMD_WRAPPER_LANG))
.replace("{script}", str(wrapt_script.mounted))
for cmd in container.commands
]
else:
container.commands.append(
shlex.join(
shlex.split(JOBCMD_WRAPPER_LANG) + [str(wrapt_script.mounted)]
)
)
else:
runnable.script.text = shlex.join(
shlex.split(JOBCMD_WRAPPER_LANG) + [str(wrapt_script.mounted)]
)
with conf_file.open("w") as f:
json.dump(config, f, indent=2)
return conf_file
async def _delete_job(self, job: Job) -> None:
"""Try to delete the job from google cloud's registry
As google doesn't allow jobs to have the same id.
Args:
job: The job to delete
"""
logger.debug(
"/Scheduler-%s Try deleting job %r on GCP.",
self.name,
job,
)
status = await self._get_job_status(job)
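        # Wait for any in-flight state transition (any *_IN_PROGRESS state)
        # to settle before issuing the delete.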
while status.endswith("_IN_PROGRESS"): # pragma: no cover
await asyncio.sleep(1)
status = await self._get_job_status(job)
command = [
self.gcloud,
"batch",
"jobs",
"delete",
job.jid,
"--project",
self.project,
"--location",
self.location,
]
try:
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
        except Exception:
            # gcloud couldn't be spawned; skip the deletion attempt
            pass
else: # pragma: no cover
await proc.wait()
status = await self._get_job_status(job)
while status == "DELETION_IN_PROGRESS": # pragma: no cover
await asyncio.sleep(1)
status = await self._get_job_status(job)
if status != "UNKNOWN": # pragma: no cover
logger.warning(
"/Scheduler-%s Failed to delete job %r on GCP, submision may fail.",
self.name,
job,
)
    async def submit_job(self, job: Job) -> str:
        """Submit the job to Google Cloud Batch and return the job id."""
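        # Tie the job id to the workdir via a short hash so ids from
        # different workdirs don't collide.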
sha = sha256(str(self.workdir).encode()).hexdigest()[:8]
job.jid = f"{self.jobname_prefix}-{sha}-{job.index}".lower()
await self._delete_job(job)
conf_file = self.job_config_file(job)
proc = await asyncio.create_subprocess_exec(
self.gcloud,
"batch",
"jobs",
"submit",
job.jid,
"--config",
conf_file.fspath,
"--project",
self.project,
"--location",
self.location,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
_, stderr = await proc.communicate()
if proc.returncode != 0: # pragma: no cover
raise RuntimeError(
"Can't submit job to Google Cloud Batch: \n"
f"{stderr.decode()}\n"
"Check the configuration file:\n"
f"{conf_file}"
)
return job.jid
    async def kill_job(self, job: Job):
        """Cancel the job via `gcloud alpha batch jobs cancel`."""
command = [
self.gcloud,
"alpha",
"batch",
"jobs",
"cancel",
job.jid,
"--project",
self.project,
"--location",
self.location,
"--quiet",
]
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await proc.wait()
    async def _get_job_status(self, job: Job) -> str:
        """Query the job state via `gcloud batch jobs describe`."""
if not job.jid_file.is_file():
return "UNKNOWN"
        # Do not rely on _jid, as it may point to an obsolete job.
jid = job.jid_file.read_text().strip()
command = [
self.gcloud,
"batch",
"jobs",
"describe",
jid,
"--project",
self.project,
"--location",
self.location,
]
try:
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except Exception: # pragma: no cover
return "UNKNOWN"
if await proc.wait() != 0:
return "UNKNOWN"
stdout = (await proc.stdout.read()).decode()
return re.search(r"state: (.+)", stdout).group(1)
    async def job_is_running(self, job: Job) -> bool:
        """Check whether the job is queued, scheduled or running."""
status = await self._get_job_status(job)
return status in ("RUNNING", "QUEUED", "SCHEDULED")