from __future__ import annotations
import asyncio
import json
import re
import shlex
import getpass
from typing import Sequence
from copy import deepcopy
from hashlib import sha256
from yunpath import GSPath, AnyPath
from ..job import Job
from ..scheduler import Scheduler
from ..defaults import JOBCMD_WRAPPER_LANG
from ..utils import logger
from ..path import SpecPath
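# Assumed-descriptive comment: the jobname prefix must start with a letter,
# followed by up to 47 letters, digits or hyphens (Google Cloud Batch job name constraint)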
JOBNAME_PREFIX_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{0,47}$")
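# Named mounts are given as NAME=gs://bucket/dir; NAME is exported as an
# environment variable pointing to the mounted path for the job scripts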
NAMED_MOUNT_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*=.+$")
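# Root directory under which the workdir and named mounts are mounted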
DEFAULT_MOUNTED_ROOT = "/mnt/disks"
class GbatchScheduler(Scheduler):
"""Scheduler for Google Cloud Batch
You can pass extra configuration parameters to the constructor
that will be used in the job configuration file.
For example, you can pass `taskGroups` to specify the task groups
and their specifications.
When using containers, it is a little bit tricky to specify the commands.
When no `entrypoint` is specified, the `commands` should be a list
with the first element being the interpreter (e.g. `/bin/bash`)
and the second element being the path to the wrapped job script.
If the `entrypoint` is specified, we can use the `{lang}` and `{script}`
placeholders in the `commands` list, where `{lang}` will be replaced
with the interpreter (e.g. `/bin/bash`) and `{script}` will be replaced
with the path to the wrapped job script.
With `entrypoint` specified and no `{script}` placeholder, the joined command
(the interpreter followed by the path to the wrapped job script) will be
appended to the `commands` list.
Args:
project: GCP project ID
location: GCP location (e.g. us-central1)
mount: GCS path to mount (e.g. gs://my-bucket:/mnt/my-bucket)
You can pass a list of mounts.
You can also use named mount like `NAME=gs://bucket/dir`
then it will be mounted to `/mnt/disks/NAME` in the container.
You can use environment variable `NAME` in your job scripts to
refer to the mounted path.
service_account: GCP service account email (e.g. test-account@example.com)
network: GCP network (e.g. default-network)
subnetwork: GCP subnetwork (e.g. regions/us-central1/subnetworks/default)
no_external_ip_address: Whether to disable external IP address
machine_type: GCP machine type (e.g. e2-standard-4)
provisioning_model: GCP provisioning model (e.g. SPOT)
image_uri: Container image URI (e.g. ubuntu-2004-lts)
entrypoint: Container entrypoint (e.g. /bin/bash)
commands: The command list to run in the container.
There are three ways to specify the commands:
1. If no entrypoint is specified, the final command will be
[commands, wrapped_script], where the entrypoint is the wrapper script
interpreter that is determined by `JOBCMD_WRAPPER_LANG` (e.g. /bin/bash),
commands is the list you provided, and wrapped_script is the path to the
wrapped job script.
2. You can specify something like ["-c"], then the final command
will be ["-c", "wrapper_interpreter wrapper_script"], where the interpreter
and the path to the wrapped job script are joined into a single argument.
3. You can use the placeholders `{lang}` and `{script}` in the commands
list, where `{lang}` will be replaced with the interpreter (e.g. /bin/bash)
and `{script}` will be replaced with the path to the wrapped job script.
For example, you can specify ["{lang} {script}"] and the final command
will be ["wrapper_interpreter wrapper_script"]. See the Examples section below.
runnables: Additional runnables to run before or after the main job.
Each runnable should be a dictionary that follows the
[GCP Batch API specification](https://cloud.google.com/batch/docs/reference/rest/v1/projects.locations.jobs#runnable).
You can also specify an "order" key in the dictionary to control the
execution order of the runnables. Runnables with negative order
will be executed before the main job, and those with non-negative
order will be executed after the main job. The main job runnable
will always be executed in the order it is defined in the list.
*args, **kwargs: Other arguments passed to the base Scheduler class
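Examples:
A minimal, illustrative sketch (all values are placeholders; `workdir` and
the other keyword arguments are assumed to be forwarded to and handled by
the base Scheduler class, and the workdir must be a Google Cloud Storage path):
>>> scheduler = GbatchScheduler(
...     project="my-project",
...     location="us-central1",
...     workdir="gs://my-bucket/xqute-workdir",
...     mount="DATA=gs://my-bucket/data",
...     image_uri="ubuntu-2004-lts",
...     entrypoint="/bin/bash",
...     commands=["-c", "{lang} {script}"],
...     runnables=[{"order": -1, "script": {"text": "echo 'before the job'"}}],
... )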
""" # noqa: E501
name = "gbatch"
__slots__ = Scheduler.__slots__ + (
"gcloud",
"project",
"location",
"runnable_index",
"_path_envs",
)
def __init__(
self,
*args,
project: str,
location: str,
mount: str | Sequence[str] | None = None,
service_account: str | None = None,
network: str | None = None,
subnetwork: str | None = None,
no_external_ip_address: bool | None = None,
machine_type: str | None = None,
provisioning_model: str | None = None,
image_uri: str | None = None,
entrypoint: str | None = None,
commands: str | Sequence[str] | None = None,
runnables: Sequence[dict] | None = None,
**kwargs,
):
"""Construct the gbatch scheduler"""
self.gcloud = kwargs.pop("gcloud", "gcloud")
self.project = project
self.location = location
kwargs.setdefault("mounted_workdir", f"{DEFAULT_MOUNTED_ROOT}/xqute_workdir")
super().__init__(*args, **kwargs)
if not isinstance(self.workdir, GSPath):
raise ValueError(
"'gbatch' scheduler requires google cloud storage 'workdir'."
)
if not JOBNAME_PREFIX_RE.match(self.jobname_prefix):
raise ValueError(
"'jobname_prefix' for gbatch scheduler doesn't follow pattern "
"^[a-zA-Z][a-zA-Z0-9-]{0,47}$."
)
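# Environment variables exposing named-mount paths to the job scripts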
self._path_envs = {}
task_groups = self.config.setdefault("taskGroups", [])
if not task_groups:
task_groups.append({})
if not task_groups[0]:
task_groups[0] = {}
task_spec = task_groups[0].setdefault("taskSpec", {})
task_runnables = task_spec.setdefault("runnables", [])
# Process additional runnables with ordering
additional_runnables = []
if runnables:
for runnable_dict in runnables:
runnable_copy = deepcopy(runnable_dict)
order = runnable_copy.pop("order", 0)
additional_runnables.append((order, runnable_copy))
# Sort by order
additional_runnables.sort(key=lambda x: x[0])
# Create main job runnable
if not task_runnables:
task_runnables.append({})
if not task_runnables[0]:
task_runnables[0] = {}
job_runnable = task_runnables[0]
if "container" in job_runnable or image_uri:
job_runnable.setdefault("container", {})
if not isinstance(job_runnable["container"], dict): # pragma: no cover
raise ValueError(
"'taskGroups[0].taskSpec.runnables[0].container' should be a "
"dictionary for gbatch configuration."
)
if image_uri:
job_runnable["container"].setdefault("image_uri", image_uri)
if entrypoint:
job_runnable["container"].setdefault("entrypoint", entrypoint)
job_runnable["container"].setdefault("commands", commands or [])
else:
job_runnable["script"] = {
"text": None, # placeholder for job command
"_commands": commands, # Store commands for later use
}
# Clear existing runnables and rebuild with proper ordering
task_runnables.clear()
# Add runnables with negative order (before job)
for order, runnable_dict in additional_runnables:
if order < 0:
task_runnables.append(runnable_dict)
# Add the main job runnable
task_runnables.append(job_runnable)
self.runnable_index = len(task_runnables) - 1
# Add runnables with positive order (after job)
for order, runnable_dict in additional_runnables:
if order >= 0:
task_runnables.append(runnable_dict)
# Only log the stdout/stderr of the submission (i.e. when the wrapped script
# doesn't run). The stdout/stderr of the wrapped script itself are written to
# files in the workdir.
logs_policy = self.config.setdefault("logsPolicy", {})
logs_policy.setdefault("destination", "CLOUD_LOGGING")
volumes = task_spec.setdefault("volumes", [])
if not isinstance(volumes, list):
raise ValueError(
"'taskGroups[0].taskSpec.volumes' should be a list for "
"gbatch configuration."
)
volumes.insert(
0,
{
"gcs": {"remotePath": self.workdir._no_prefix},
"mountPath": str(self.workdir.mounted),
},
)
if mount and not isinstance(mount, (tuple, list)):
mount = [mount]
if mount:
for m in mount:
# Let's check if mount is provided as "OUTDIR=gs://bucket/dir"
# If so, we mount it to $DEFAULT_MOUNTED_ROOT/OUTDIR
# and record the mounted path as the OUTDIR env variable in self._path_envs
if NAMED_MOUNT_RE.match(m):
name, gcs = m.split("=", 1)
if not gcs.startswith("gs://"):
raise ValueError(
"When using named mount, it should be in the format "
"'NAME=gs://bucket/dir', where NAME matches "
"^[A-Za-z][A-Za-z0-9_]*$"
)
gcs_path = AnyPath(gcs)
# Check if it is a file path
if gcs_path.is_file():
# Mount the parent directory
gcs = str(gcs_path.parent._no_prefix)
mount_path = (
f"{DEFAULT_MOUNTED_ROOT}/{name}/{gcs_path.parent.name}"
)
self._path_envs[name] = f"{mount_path}/{gcs_path.name}"
else:
gcs = gcs[5:]
mount_path = f"{DEFAULT_MOUNTED_ROOT}/{name}"
self._path_envs[name] = mount_path
volumes.append(
{
"gcs": {"remotePath": gcs},
"mountPath": mount_path,
}
)
else:
# Or, we expect a literal mount "gs://bucket/dir:/mount/path"
gcs, mount_path = m.rsplit(":", 1)
if gcs.startswith("gs://"):
gcs = gcs[5:]
volumes.append(
{
"gcs": {"remotePath": gcs},
"mountPath": mount_path,
}
)
# Add some labels for filtering by `gcloud batch jobs list`
labels = self.config.setdefault("labels", {})
labels.setdefault("xqute", "true")
labels.setdefault("user", getpass.getuser())
allocation_policy = self.config.setdefault("allocationPolicy", {})
if service_account:
allocation_policy.setdefault("serviceAccount", {}).setdefault(
"email", service_account
)
if network or subnetwork or no_external_ip_address is not None:
network_interface = allocation_policy.setdefault("network", {}).setdefault(
"networkInterfaces", []
)
if not network_interface:
network_interface.append({})
network_interface = network_interface[0]
if network:
network_interface.setdefault("network", network)
if subnetwork:
network_interface.setdefault("subnetwork", subnetwork)
if no_external_ip_address is not None:
network_interface.setdefault(
"noExternalIpAddress", no_external_ip_address
)
if machine_type or provisioning_model:
instances = allocation_policy.setdefault("instances", [])
if not instances:
instances.append({})
policy = instances[0].setdefault("policy", {})
if machine_type:
policy.setdefault("machineType", machine_type)
if provisioning_model:
policy.setdefault("provisioningModel", provisioning_model)
email = allocation_policy.get("serviceAccount", {}).get("email")
if email:
# 63 character limit, '@' is not allowed in labels
# labels.setdefault("email", email[:63])
labels.setdefault("sacct", email.split("@", 1)[0][:63])
def job_config_file(self, job: Job) -> SpecPath:
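"""Generate the Google Cloud Batch configuration file for a job.
The main runnable is completed with the wrapped job script (resolving the
`{lang}`/`{script}` placeholders when present) and the configuration is
written as JSON into the job's metadata directory.
Args:
job: The job to generate the configuration file for
Returns:
The path to the generated configuration file
"""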
base = f"job.wrapped.{self.name}.json"
conf_file = job.metadir / base
wrapt_script = self.wrapped_job_script(job)
config = deepcopy(self.config)
runnable = config["taskGroups"][0]["taskSpec"]["runnables"][self.runnable_index]
if "container" in runnable:
container = runnable["container"]
if "entrypoint" not in container:
# The entrypoint supports only a plain interpreter (e.g. /bin/bash),
# not one with flags (e.g. /bin/bash -u)
container["entrypoint"] = JOBCMD_WRAPPER_LANG
container["commands"].append(str(wrapt_script.mounted))
elif any("{script}" in cmd for cmd in container["commands"]):
# The entrypoint is already set and the commands contain placeholders;
# fill {lang}/{script} with the interpreter and the wrapped script path.
container["commands"] = [
cmd.replace("{lang}", str(JOBCMD_WRAPPER_LANG)).replace(
"{script}", str(wrapt_script.mounted)
)
for cmd in container["commands"]
]
else:
container["commands"].append(
shlex.join(
shlex.split(JOBCMD_WRAPPER_LANG) + [str(wrapt_script.mounted)]
)
)
else:
# Apply commands for script runnables as well
stored_commands = runnable["script"].pop("_commands", None)
if stored_commands:
if any("{script}" in str(cmd) for cmd in stored_commands):
# Use commands with script placeholder replacement
command_parts = [
shlex.quote(cmd)
.replace("{lang}", str(JOBCMD_WRAPPER_LANG))
.replace("{script}", str(wrapt_script.mounted))
for cmd in stored_commands
]
else:
# Append script to commands
command_parts = [
*(shlex.quote(str(cmd)) for cmd in stored_commands),
shlex.quote(
shlex.join(
(
*shlex.split(JOBCMD_WRAPPER_LANG),
str(wrapt_script.mounted),
)
)
),
]
else:
command_parts = [
*shlex.split(JOBCMD_WRAPPER_LANG),
str(wrapt_script.mounted),
]
runnable["script"]["text"] = " ".join(command_parts)
with conf_file.open("w") as f:
json.dump(config, f, indent=2)
return conf_file
async def _delete_job(self, job: Job) -> None:
"""Try to delete the job from google cloud's registry
As google doesn't allow jobs to have the same id.
Args:
job: The job to delete
"""
logger.debug(
"/Scheduler-%s Try deleting job %r on GCP.",
self.name,
job,
)
status = await self._get_job_status(job)
while status.endswith("_IN_PROGRESS"): # pragma: no cover
await asyncio.sleep(1)
status = await self._get_job_status(job)
command = [
self.gcloud,
"batch",
"jobs",
"delete",
job.jid,
"--project",
self.project,
"--location",
self.location,
]
try:
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except Exception:
pass
else: # pragma: no cover
await proc.wait()
status = await self._get_job_status(job)
while status == "DELETION_IN_PROGRESS": # pragma: no cover
await asyncio.sleep(1)
status = await self._get_job_status(job)
if status != "UNKNOWN": # pragma: no cover
logger.warning(
"/Scheduler-%s Failed to delete job %r on GCP, submision may fail.",
self.name,
job,
)
async def submit_job(self, job: Job) -> str:
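"""Submit a job to Google Cloud Batch via `gcloud batch jobs submit`.
Args:
job: The job to submit
Returns:
The job id assigned to the job
"""
# Derive a deterministic job id from a short hash of the workdir and the job index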
sha = sha256(str(self.workdir).encode()).hexdigest()[:8]
job.jid = f"{self.jobname_prefix}-{sha}-{job.index}".lower()
await self._delete_job(job)
conf_file = self.job_config_file(job)
proc = await asyncio.create_subprocess_exec(
self.gcloud,
"batch",
"jobs",
"submit",
job.jid,
"--config",
conf_file.fspath,
"--project",
self.project,
"--location",
self.location,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
_, stderr = await proc.communicate()
if proc.returncode != 0: # pragma: no cover
raise RuntimeError(
"Can't submit job to Google Cloud Batch: \n"
f"{stderr.decode()}\n"
"Check the configuration file:\n"
f"{conf_file}"
)
return job.jid
async def kill_job(self, job: Job):
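"""Cancel the job using `gcloud alpha batch jobs cancel`.
Args:
job: The job to cancel
"""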
command = [
self.gcloud,
"alpha",
"batch",
"jobs",
"cancel",
job.jid,
"--project",
self.project,
"--location",
self.location,
"--quiet",
]
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await proc.wait()
async def _get_job_status(self, job: Job) -> str:
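"""Query the job state via `gcloud batch jobs describe`.
Returns "UNKNOWN" if the job id file is missing or the command fails.
Args:
job: The job to query
Returns:
The state reported by Google Cloud Batch, or "UNKNOWN"
"""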
if not job.jid_file.is_file():
return "UNKNOWN"
# Do not rely on _jid, as it can point to an obsolete job.
jid = job.jid_file.read_text().strip()
command = [
self.gcloud,
"batch",
"jobs",
"describe",
jid,
"--project",
self.project,
"--location",
self.location,
]
try:
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except Exception: # pragma: no cover
return "UNKNOWN"
if await proc.wait() != 0:
return "UNKNOWN"
stdout = (await proc.stdout.read()).decode()
matched = re.search(r"state: (.+)", stdout)
# Be defensive: report UNKNOWN if no state is found in the output
return matched.group(1) if matched else "UNKNOWN"
async def job_is_running(self, job: Job) -> bool:
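"""Check if the job is currently queued, scheduled or running on GCP.
Args:
job: The job to check
Returns:
True if the job state is RUNNING, QUEUED or SCHEDULED, otherwise False
"""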
status = await self._get_job_status(job)
return status in ("RUNNING", "QUEUED", "SCHEDULED")
def jobcmd_init(self, job) -> str:
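"""Prepend exports of the named-mount environment variables to the
job's initialization commands.
Args:
job: The job to generate the initialization commands for
Returns:
The initialization commands
"""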
init_cmd = super().jobcmd_init(job)
path_envs_exports = [
f"export {key}={shlex.quote(value)}"
for key, value in self._path_envs.items()
]
if path_envs_exports:
path_envs_exports.insert(0, "# Mounted paths")
init_cmd = "\n".join(path_envs_exports) + "\n" + init_cmd
return init_cmd