"""Run jobs via non-local runners."""
import os
import sys
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type
from diot import FrozenDiot
import cmdy
from .data import DataModule
from .module import Module
from .optuna import Optuna
from .trainer import Trainer
from .utils import logger, warning_to_logging, plkit_seed_everything

class Runner(ABC):
    """The base class for runner"""
    @abstractmethod
    def run(self,
            config: Dict[str, Any],
            data_class: Type[DataModule],
            model_class: Type[Module],
            optuna: Optional[Optuna] = None) -> Trainer:
        """Run the whole pipeline using the runner

        Args:
            config: A dictionary of configuration, which must have the
                following items:
                - batch_size: The batch size
                - num_classes: The number of classes for classification
                    (1 means regression)
            data_class: The data class subclassed from `DataModule`
            model_class: The model class subclassed from `Module`
            optuna: The optuna object

        Returns:
            The trainer object
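
        Examples:
            >>> # A minimal sketch, assuming hypothetical `MyData` and
            >>> # `MyModel` subclasses of `DataModule` and `Module`:
            >>> config = {'batch_size': 32, 'num_classes': 2}
            >>> # trainer = LocalRunner().run(config, MyData, MyModel)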
        """


class LocalRunner(Runner):
    """The local runner for the pipeline"""

    def run(self,
            config: Dict[str, Any],
            data_class: Type[DataModule],
            model_class: Type[Module],
            optuna: Optional[Optuna] = None) -> Trainer:
        """Run the pipeline locally"""
        if not isinstance(config, FrozenDiot):
            config = FrozenDiot(config)

        if optuna: # pragma: no cover
            return optuna.run(config, data_class, model_class)

        plkit_seed_everything(config)

        data = data_class(config=config)
        model = model_class(config)
        trainer = Trainer.from_config(config)
        with warning_to_logging():
            trainer.fit(model, data)

        if hasattr(data, 'test_dataloader'):
            test_dataloader = data.test_dataloader()
        else: # pragma: no cover
            test_dataloader = None

        if test_dataloader:
            with warning_to_logging():
                trainer.test(test_dataloaders=test_dataloader)
        return trainer

class SGERunner(LocalRunner):
    """The SGE runner for the pipeline

    Args:
        *args: Positional arguments passed through to `qsub`
        opts: The options for the SGE runner, which will be translated into
            arguments for `qsub`. For example, `opts={'notify': True}` will be
            translated into `qsub -notify ...` on the command line.

            There are two special options, `qsub` and `workdir`: `qsub`
            specifies the path to the `qsub` executable, and `workdir`
            specifies a directory to save the outputs, errors, and scripts of
            each job.

    Attributes:
        qsub: The path to the `qsub` executable
        workdir: The path to the workdir
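
    Examples:
        >>> # A minimal sketch: `-l h_vmem=4G` is a standard qsub resource
        >>> # request and the workdir path is arbitrary:
        >>> runner = SGERunner(l='h_vmem=4G', workdir='./sge_workdir')
        >>> # trainer = runner.run(config, MyData, MyModel)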
    """

    ENV_FLAG_PREFIX = "PLKIT_SGE_RUNNER_"

    def __init__(self, *args, **opts):
        self.qsub = opts.pop("qsub", "qsub") # type: str
        self.workdir = opts.pop("workdir", "./workdir") # type: str
        os.makedirs(self.workdir, exist_ok=True)

        self.args = args
        self.opts = opts
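        # Derive a deterministic id from the full command line, so the same
        # invocation always maps to the same workdir and env flag name.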
        self.uid = uuid.uuid5(uuid.NAMESPACE_DNS, str(sys.argv))
        self.envname = SGERunner.ENV_FLAG_PREFIX + str(self.uid).split('-')[0]


    def run(self,
            config: Dict[str, Any],
            data_class: Type[DataModule],
            model_class: Type[Module],
            optuna: Optional[Optuna] = None) -> Trainer:
        """Run the job depending on the env flag"""
        if not os.environ.get(self.envname):
            logger.info('Wrapping up the job ...')
            workdir = os.path.join(self.workdir, f'plkit-{self.uid}')
            os.makedirs(workdir, exist_ok=True)
            logger.info('  - Workdir: %s', workdir)

            script = os.path.join(workdir, 'job.sge.sh')
            logger.info('  - Script: %s', script)
            with open(script, 'w') as fscript:
                fscript.write("#!/bin/sh\n\n")
                cmd = cmdy._(*sys.argv, _exe=sys.executable).h.strcmd
                fscript.write(f"{self.envname}=1 {cmd}\n")

            opts = self.opts.copy()
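            # Sensible qsub defaults: -o sets the output file, -cwd runs the
            # job in the current directory, -j y merges stderr into stdout,
            # -notify asks for warning signals before the job is suspended
            # or killed, and -N names the job.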
            opts.setdefault('o', os.path.join(workdir, 'job.stdout'))
            opts.setdefault('cwd', True)
            opts.setdefault('j', 'y')
            opts.setdefault('notify', True)
            opts.setdefault('N', os.path.basename(workdir))

            logger.info('Submitting the job ...')
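            # cmdy_prefix='-' renders the options with a single leading dash
            # (qsub style), and cmdy_exe points at the configured qsub
            # executable; .h() holds the command so it can be logged before
            # being run below.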
            cmd = cmdy.qsub(*self.args,
                            opts,
                            script,
                            cmdy_dupkey=True,
                            cmdy_prefix='-',
                            cmdy_exe=self.qsub).h()
            logger.info('  - Running: %s', cmd.strcmd)
            logger.info('  - %s', cmd.run().stdout.strip())

            cmdy.touch(opts['o'])
            logger.info('Streaming content from %s', opts['o'])
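            # tail -f blocks in the foreground, streaming the job's output;
            # the file was touched above so tail can open it before the job
            # starts writing.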
            cmdy.tail(f=True, _=opts['o']).fg()
            return None # pragma: no cover

        return super().run(config, data_class, model_class, optuna)