Skip to content

SOURCE CODE plkit.module DOCS

"""The core base module class based on pytorch_lightning.LightningModule"""
import torch
from torch import nn
from pytorch_lightning import LightningModule
from .utils import collapse_suggest_config

class Module(LightningModule): # pylint: disable=too-many-ancestorsDOCS
    """The Module class

    `on_epoch_end` is added to print a newline to keep the progress bar and the
    stats on it for each epoch. If you don't want this, just overwrite it with:
        >>> def on_epoch_end(self):
        >>>     pass

    If you have other stuff to do in `on_epoch_end`, make sure to you call:
        >>> super().on_epoch_end()

    You may or may not need to write `loss_function`, as it will be inferred
    from config item `loss` and `num_classes`. Basically, `MSELoss` will be
    used for regression and `CrossEntropyLoss` for classification.

    `measure` added for convinience to get some metrics between logits
    and targets.

    Args:
        config: The configuration dictionary

    Attributes:
        Apart from attributes of `LightningModule`, following attributes added:
        config: The configs
        optim: The optimizer name. currently only `adam` and `sgd`
            are supported. With this, of course you can, but you don't need to
            write `configure_optimizers`.
        num_classes: Number of classes to predict. 1 for regression
        _loss_func: The loss function
    """

    def __init__(self, config):
        super().__init__()
        self.config = collapse_suggest_config(config)
        self.optim = self.config.get('optim', 'adam')
        self.num_classes = self.config.get('num_classes')
        # We may run test only without measurement.
        # if not self.num_classes:
        #     raise ValueError('We need `num_classes` from config or passed in '
        #                      'explictly to check final logits size.')

        loss = self.config.get('loss', 'auto')
        if loss == 'auto':
            if self.num_classes == 1: # regression
                self._loss_func = nn.MSELoss()
            else:
                self._loss_func = nn.CrossEntropyLoss()
        else:
            self._loss_func = loss

        # the hyperparameters to be logged to tensorboard
        self.hparams = {}

    def on_epoch_end(self):DOCS
        """Keep the epoch progress bar
        This is not documented but working."""
        print()

    def loss_function(self, logits, labels):DOCS
        """Calculate the loss"""
        return self._loss_func(logits, labels)

    # pylint: disable=inconsistent-return-statements
    def configure_optimizers(self):DOCS
        """Configure the optimizers"""
        if self.optim == 'adam':
            return torch.optim.Adam(self.parameters(),
                                    lr=self.config.get('learning_rate', 1e-3))
        if self.optim == 'sgd':
            return torch.optim.SGD(self.parameters(),
                                   lr=self.config.get('learning_rate', 1e-3),
                                   momentum=self.config.get('momentum', .9))
        # more to support