Skip to content

SOURCE CODE xqute.schedulers.ssh_scheduler.scheduler DOCS

"""The scheduler to run jobs on SSH"""
from __future__ import annotations

from typing import Mapping, Type

from ...scheduler import Scheduler
from ...utils import a_read_text
from ...job import Job
from ..local_scheduler import LocalJob

from .client import SSHClient


class SshJob(LocalJob):DOCS
    """SSH job"""


class SshScheduler(Scheduler):DOCS
    """The ssh scheduler

    Attributes:
        name: The name of the scheduler
        job_class: The job class

    Args:
        ssh_*: SSH options
        ... other Scheduler args
    """
    name: str = "ssh"
    job_class: Type[Job] = SshJob

    def __init__(self, forks: int, prescript: str = "", postscript: str = "", **kwargs):
        super().__init__(forks, prescript, postscript, **kwargs)
        self.servers: Mapping[str, SSHClient] = {}
        ssh = self.config.get('ssh', 'ssh')
        for key, val in self.config.get('ssh_servers', {}).items():
            client = SSHClient(ssh, key, **val)
            self.servers[client.name] = client

        if not self.servers:
            raise ValueError(
                "No ssh_servers defined in config, "
                "please define at least one server",
            )

    def __del__(self):
        for server in self.servers.values():
            if server.is_connected:
                server.disconnect()

    async def submit_job(self, job: Job) -> str:DOCS
        """Submit a job to SSH

        Args:
            job: The job

        Returns:
            The job id
        """
        server = list(self.servers.values())[job.index % len(self.servers)]
        await server.connect()

        return await server.submit(
            job.__class__.CMD_WRAPPER_SHELL,
            await job.wrapped_script(self),
        )

    async def kill_job(self, job: Job):DOCS
        """Kill a job on SSH

        Args:
            job: The job
        """
        try:
            server, pid = str(job.jid).split('/')
            await self.servers[server].kill(pid)
        except Exception:  # pragma: no cover
            pass

    async def job_is_running(self, job: Job) -> bool:DOCS
        """Tell if a job is really running, not only the job.jid_file

        In case where the jid file is not cleaned when job is done.

        Args:
            job: The job

        Returns:
            True if it is, otherwise False
        """
        try:
            jid = await a_read_text(job.jid_file)
        except FileNotFoundError:
            return False

        if not jid:
            return False

        server, pid = jid.split('/')
        if server not in self.servers:
            return False

        try:
            return await self.servers[server].is_running(pid)
        except Exception:  # pragma: no cover
            return False