"""Optuna wrapper for plkit"""
from diot import FrozenDiot
from torch import Tensor
from pytorch_lightning.callbacks import ModelCheckpoint
import optuna
from .trainer import Trainer
from .utils import log_config, logger, plkit_seed_everything

# suppress optuna logging
optuna.logging._get_library_root_logger().handlers.clear()

class OptunaSuggest:
    """Optuna suggests for configuration items

    Args:
        default (any): The default value that the item falls back to when
            optuna is opted out, so you don't have to change your code
            when you are not running optuna.
        suggtype (str): The type of suggestion.
            For example, `cat` refers to `trial.suggest_categorical`.
            The mappings are:
            cat -> 'suggest_categorical',
            categorical -> 'suggest_categorical',
            distuni -> 'suggest_discrete_uniform',
            dist_uni -> 'suggest_discrete_uniform',
            discrete_uniform -> 'suggest_discrete_uniform',
            float -> 'suggest_float',
            int -> 'suggest_int',
            loguni -> 'suggest_loguniform',
            log_uni -> 'suggest_loguniform',
            uni -> 'suggest_uniform'
        *args: The args used in `trial.suggest_xxx(name, *args, **kwargs)`
        **kwargs: The kwargs used in `trial.suggest_xxx(name, *args, **kwargs)`

    Attributes:
        default (any): The default from Args
        suggfunc (str): The transformed suggestion name according to `suggtype`
        args: *args from Args
        kwargs: **kwargs from Args
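
    Example:
        A minimal sketch; `lr` and `batch_size` are hypothetical
        configuration keys:

        >>> # 'lr' and 'batch_size' are hypothetical keys
        >>> config = {
        >>>     'lr': OptunaSuggest(1e-3, 'loguni', 1e-5, 1e-1),
        >>>     'batch_size': OptunaSuggest(32, 'cat', [16, 32, 64]),
        >>> }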
    """
    def __init__(self, default, suggtype, *args, **kwargs):
        self.default = default
        self.suggfunc = dict(
            cat='suggest_categorical',
            categorical='suggest_categorical',
            distuni='suggest_discrete_uniform',
            dist_uni='suggest_discrete_uniform',
            discrete_uniform='suggest_discrete_uniform',
            float='suggest_float',
            int='suggest_int',
            loguni='suggest_loguniform',
            log_uni='suggest_loguniform',
            uni='suggest_uniform'
        )[suggtype]
        self.args = args
        self.kwargs = kwargs

    def suggest(self, name, trial):
        """Get the suggested value

        This is used by the `Optuna` class; you don't have to call it directly

        Args:
            name (str): The name of the parameter
            trial (optuna.Trial): The trial to get the suggested value from

        Returns:
            Any: The suggested value
        """
        return getattr(trial, self.suggfunc)(name, *self.args, **self.kwargs)


class Optuna:
    """The class uses optuna to automate hyperparameter tuning

    Example:
        >>> from plkit import Data, Module, Optuna
        >>> class MyData(Data):
        >>>     ...
        >>> class MyModel(Module):
        >>>     ...
        >>> class MyOptuna(Optuna):
        >>>     def suggests(self, trial, config):
        >>>         ...
        >>>         return suggested_params
        >>> optuna = MyOptuna('val_loss', 100)
        >>> optuna.run(config, MyData, MyModel)

    Args:
        on (str): The metric to optimize on. Should be one of the keys of the
            dictionary returned from `validation_epoch_end`,
            `val_acc` or `val_loss` for example.
        n_trials (int): Number of trials
        **kwargs: Other keyword arguments for `optuna.create_study`

    Attributes:
        on (str): on from Args
        n_trials (int): n_trials from Args
        study (optuna.Study): study object created from kwargs
        trainers (list): list of (best checkpoint path, config, best score)
            records, one per trial, used to keep track of the best one
    """

    def __init__(self,
                 on,
                 n_trials,
                 **kwargs):
        self.on = on
        self.n_trials = n_trials
        self._best_trainer = None
        self._best_model = None
        self.study = optuna.create_study(**kwargs)
        # one record per trial, used to retrieve the best one later
        self.trainers = []

    def _create_objective(self, config, data, model_class):
        """Create objective function for the study to optimize

        The objective function is built to return the best value from
        `validation_epoch_end` based on `self.on`. To implement this, a
        `ModelCheckpoint` callback is used and `best_model_score` is returned
        from it.

        Args:
            config (dict): The configuration dictionary
            data (Data): The instantiated data object (an instance of a
                `plkit.Data` subclass), not the class itself
            model_class (class): The model class subclassed from `plkit.Module`
                Note it's the class itself, not an instantiated object

        Returns:
            callable: The objective function
        """
        def _objective(trial):
            logger.info('--------------------------------')
            logger.info('Start optuna trial #%s / %s',
                        len(self.trainers), self.n_trials - 1)
            logger.info('--------------------------------')
            suggested = self.suggests(trial, config)

            config_copy = config.copy()

            with config_copy.thaw():
                config_copy.update(suggested)

            log_config(suggested, "Tunable parameters")

            model = model_class(config_copy)
            model.hparams.update(suggested)

            # expose filepath argument?
            checkpoint_callback = ModelCheckpoint(monitor=self.on)
            trainer = Trainer.from_config(
                config_copy,
                checkpoint_callback=checkpoint_callback
            )

            trainer.fit(model, data)
            best_score = checkpoint_callback.best_model_score

            self.trainers.append((checkpoint_callback.best_model_path,
                                  config_copy,
                                  best_score))

            logger.info('')
            logger.info("'Optuna': trial #%s done with  %s = %s "
                        "and parameters: %s",
                        trial.number,
                        self.on,
                        best_score,
                        trial.params)
            number, score, params = self._current_best
            params = {key: val for key, val in params.items()
                      if key in suggested}
            logger.info("'Optuna': the best is #%s with %s = %s "
                        "and parameters: %s",
                        number,
                        self.on,
                        score,
                        params)
            logger.info('')
            return best_score

        return _objective

    @property
    def _current_best(self):
        """Get the current best trial number and parameters when the tuning
        is incomplete"""
        func = (max
                if self.study.direction == optuna.study.StudyDirection.MAXIMIZE
                else min)
        values = [trainer[2] if not isinstance(trainer[2], Tensor)
                  else trainer[2].cpu()
                  for trainer in self.trainers]
        number = values.index(func(values))
        trainer = self.trainers[number]
        return number, trainer[2], trainer[1]

    def suggests(self, trial, conf):
        """Collect the hyperparameters from the trial suggestions
        if any configuration item is an `OptunaSuggest` object

        Args:
            trial (optuna.Trial): the trial object
            conf (dict): The configuration dictionary

        Returns:
            dict: A dictionary of suggested parameters
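
        Example:
            A sketch of overriding this in a subclass to add extra
            suggestions on top of the `OptunaSuggest` items; `dropout`
            is a hypothetical hyperparameter:

            >>> class MyOptuna(Optuna):
            >>>     def suggests(self, trial, conf):
            >>>         suggested = super().suggests(trial, conf)
            >>>         # 'dropout' is a hypothetical extra hyperparameter
            >>>         suggested['dropout'] = trial.suggest_float('dropout',
            >>>                                                    0.0, 0.5)
            >>>         return suggested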
        """

        return {key: val.suggest(key, trial)
                for key, val in conf.items()
                if isinstance(val, OptunaSuggest)}

    def run(self, config, data_class, model_class, **kwargs):
        # pylint: disable=line-too-long
        """Run the optimization

        The optimization runs on the fit of the trainer. If test data is
        provided, testing is performed as well using the best trial.

        Args:
            config (dict): The configuration dictionary
            data_class (class): The data class subclassed from `plkit.Data`
                Note that this is the class itself, not an instantiated object.
            model_class (class): The model class subclassed from `plkit.Module`
                Note that this is the class itself, not an instantiated object.
            **kwargs: Other arguments for `study.optimize` other than
                `func` and `n_trials`.
                See: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize
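
        Returns:
            Trainer: The trainer of the best trial

        Example:
            A minimal sketch, assuming `config`, `MyData`, `MyModel` and
            `MyOptuna` are defined as in the class docstring:

            >>> optuna = MyOptuna('val_loss', 100)
            >>> optuna.run(config, MyData, MyModel)
            >>> optuna.best_params  # best hyperparameters found
            >>> optuna.best_model   # model loaded from the best checkpoint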
        """
        # pylint: enable=line-too-long
        if not isinstance(config, FrozenDiot):
            config = FrozenDiot(config)

        plkit_seed_everything(config)

        data = data_class(config=config)
        objective = self._create_objective(config, data, model_class)
        self.study.optimize(objective, self.n_trials, **kwargs)

        self._best_trainer = Trainer.from_config(
            self.trainers[self.best_trial.number][1]
        )

        self._best_model = model_class.load_from_checkpoint(
            self.trainers[self.best_trial.number][0],
            # https://github.com/PyTorchLightning/pytorch-lightning/issues/2550
            config=self.trainers[self.best_trial.number][1]
        )

        if hasattr(data, 'test_dataloader'):
            test_dataloaders = data.test_dataloader()
        else: # pragma: no cover
            test_dataloaders = None

        if test_dataloaders:
            logger.info('')
            logger.info('---------------------------------')
            logger.info('Testing using best trial: #%s', self.best_trial.number)
            logger.info('---------------------------------')
            self.best_trainer.test(self.best_model,
                                   test_dataloaders=test_dataloaders)

        return self.best_trainer

    optimize = run

    @property
    def best_params(self):
        """The best parameters from the study

        Returns:
            dict: A dictionary containing parameters of the best trial.
        """
        return self.study.best_params

    @property
    def best_trial(self):
        """The best trial from the study

        Returns:
            optuna.FrozenTrial: A FrozenTrial object of the best trial.
        """
        return self.study.best_trial

    @property
    def trials(self):
        """The trials

        Returns:
            list: A list of FrozenTrial objects.
        """
        return self.study.trials

    @property
    def best_trainer(self):
        """Get the best trainer

        Returns:
            Trainer: The trainer object of the best trial.
        """
        return self._best_trainer

    @property
    def best_model(self):
        """Get the model from best trainer

        Returns:
            Module: The model of the best trial
        """
        return self._best_model