plkit.trainer

module

plkit.trainer

Wrapper of the Trainer class

Classes
class

plkit.trainer.ProgressBar(refresh_rate=1, process_position=0)

Bases
pytorch_lightning.callbacks.progress.ProgressBar, pytorch_lightning.callbacks.progress.ProgressBarBase, pytorch_lightning.callbacks.base.Callback

Align the Epoch in progress bar

Attributes
  • test_batch_idx (int) The current batch index being processed during testing. Use this to update your progress bar.
  • total_test_batches (int) The total number of test batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size.
  • total_train_batches (int) The total number of training batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size.
  • total_val_batches (int) The total number of validation batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size.
  • train_batch_idx (int) The current batch index being processed during training. Use this to update your progress bar.
  • val_batch_idx (int) The current batch index being processed during validation. Use this to update your progress bar.
Methods
  • disable() You should provide a way to disable the progress bar. The :class:~pytorch_lightning.trainer.trainer.Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.
  • enable() You should provide a way to enable the progress bar. The :class:~pytorch_lightning.trainer.trainer.Trainer will call this in e.g. pre-training routines like the :ref:learning rate finder <lr_finder> to temporarily enable and disable the main progress bar.
  • init_sanity_tqdm() (tqdm) Override this to customize the tqdm bar for the validation sanity run.
  • init_test_tqdm() (tqdm) Override this to customize the tqdm bar for testing.
  • init_train_tqdm() (tqdm) Override this to customize the tqdm bar for training.
  • init_validation_tqdm() (tqdm) Override this to customize the tqdm bar for validation.
  • on_after_backward(trainer, pl_module) Called after loss.backward() and before optimizers do anything.
  • on_batch_end(trainer, pl_module) Called when the training batch ends.
  • on_batch_start(trainer, pl_module) Called when the training batch begins.
  • on_before_zero_grad(trainer, pl_module, optimizer) Called after optimizer.step() and before optimizer.zero_grad().
  • on_epoch_end(trainer, pl_module) Called when the epoch ends.
  • on_epoch_start(trainer, pl_module) Try to align the epoch number.
  • on_fit_end(trainer, pl_module) Called when fit ends.
  • on_fit_start(trainer, pl_module) Called when fit begins.
  • on_init_end(trainer) Called when the trainer initialization ends, model has not yet been set.
  • on_init_start(trainer) Called when the trainer initialization begins, model has not yet been set.
  • on_keyboard_interrupt(trainer, pl_module) Called when the training is interrupted by KeyboardInterrupt.
  • on_load_checkpoint(checkpointed_state) Called when loading a model checkpoint, use to reload state.
  • on_pretrain_routine_end(trainer, pl_module) Called when the pretrain routine ends.
  • on_pretrain_routine_start(trainer, pl_module) Called when the pretrain routine begins.
  • on_sanity_check_end(trainer, pl_module) Called when the validation sanity check ends.
  • on_sanity_check_start(trainer, pl_module) Called when the validation sanity check starts.
  • on_save_checkpoint(trainer, pl_module) Called when saving a model checkpoint, use to persist state.
  • on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) Called when the test batch ends.
  • on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) Called when the test batch begins.
  • on_test_end(trainer, pl_module) Called when the test ends.
  • on_test_epoch_end(trainer, pl_module) Called when the test epoch ends.
  • on_test_epoch_start(trainer, pl_module) Called when the test epoch begins.
  • on_test_start(trainer, pl_module) Called when the test begins.
  • on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) Called when the train batch ends.
  • on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) Called when the train batch begins.
  • on_train_end(trainer, pl_module) Called when the train ends.
  • on_train_epoch_end(trainer, pl_module, outputs) Called when the train epoch ends.
  • on_train_epoch_start(trainer, pl_module) Called when the train epoch begins.
  • on_train_start(trainer, pl_module) Called when the train begins.
  • on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) Called when the validation batch ends.
  • on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) Called when the validation batch begins.
  • on_validation_end(trainer, pl_module) Called when the validation loop ends.
  • on_validation_epoch_end(trainer, pl_module) Called when the val epoch ends.
  • on_validation_epoch_start(trainer, pl_module) Called when the val epoch begins.
  • on_validation_start(trainer, pl_module) Called when the validation loop begins.
  • setup(trainer, pl_module, stage) Called when fit or test begins.
  • teardown(trainer, pl_module, stage) Called when fit or test ends.
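"Aligning the epoch" here means keeping the bar description a fixed width so the tqdm display does not jitter as the epoch counter gains digits. A minimal sketch of that idea; the epoch_description helper is illustrative, not part of plkit:

```python
def epoch_description(epoch, max_epochs):
    """Pad the epoch number to the width of max_epochs so the
    progress-bar description stays a constant length."""
    width = len(str(max_epochs))
    return f"Epoch {epoch:>{width}}/{max_epochs}"

print(epoch_description(3, 100))   # "Epoch   3/100"
print(epoch_description(42, 100))  # "Epoch  42/100"
```

Without the padding, the bar would shift one column to the right each time the epoch count gains a digit.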
method

setup(trainer, pl_module, stage)

Called when fit or test begins

method

teardown(trainer, pl_module, stage)

Called when fit or test ends

method

on_init_start(trainer)

Called when the trainer initialization begins, model has not yet been set.

method

on_fit_start(trainer, pl_module)

Called when fit begins

method

on_fit_end(trainer, pl_module)

Called when fit ends

method

on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the train batch begins.

method

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

method

on_train_epoch_end(trainer, pl_module, outputs)

Called when the train epoch ends.

method

on_validation_epoch_start(trainer, pl_module)

Called when the val epoch begins.

method

on_validation_epoch_end(trainer, pl_module)

Called when the val epoch ends.

method

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

method

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

method

on_epoch_end(trainer, pl_module)

Called when the epoch ends.

method

on_batch_start(trainer, pl_module)

Called when the training batch begins.

method

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the validation batch begins.

method

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the test batch begins.

method

on_batch_end(trainer, pl_module)

Called when the training batch ends.

method

on_pretrain_routine_start(trainer, pl_module)

Called when the pretrain routine begins.

method

on_pretrain_routine_end(trainer, pl_module)

Called when the pretrain routine ends.

method

on_keyboard_interrupt(trainer, pl_module)

Called when the training is interrupted by KeyboardInterrupt.

method

on_save_checkpoint(trainer, pl_module)

Called when saving a model checkpoint, use to persist state.

method

on_load_checkpoint(checkpointed_state)

Called when loading a model checkpoint, use to reload state.

method

on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers do anything.

method

on_before_zero_grad(trainer, pl_module, optimizer)

Called after optimizer.step() and before optimizer.zero_grad().
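The relative order of these two gradient hooks can be easy to mix up. The sketch below records the sequence around a single optimizer step; the RecordingCallback class and run_training_step driver are illustrative stand-ins, not Lightning internals.

```python
class RecordingCallback:
    """Stand-in callback that records when each hook fires."""

    def __init__(self):
        self.events = []

    def on_after_backward(self):
        self.events.append("on_after_backward")

    def on_before_zero_grad(self):
        self.events.append("on_before_zero_grad")


def run_training_step(cb):
    """Illustrative driver showing where the hooks sit in one step."""
    cb.events.append("loss.backward()")
    cb.on_after_backward()      # gradients populated, not yet applied
    cb.events.append("optimizer.step()")
    cb.on_before_zero_grad()    # last chance to inspect gradients
    cb.events.append("optimizer.zero_grad()")


cb = RecordingCallback()
run_training_step(cb)
```

on_after_backward is the place to inspect or log raw gradients; on_before_zero_grad is the last point at which they are still available before being cleared.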

method

on_init_end(trainer)

Called when the trainer initialization ends, model has not yet been set.

method

disable()

You should provide a way to disable the progress bar. The :class:~pytorch_lightning.trainer.trainer.Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.

method

enable()

You should provide a way to enable the progress bar. The :class:~pytorch_lightning.trainer.trainer.Trainer will call this in e.g. pre-training routines like the :ref:learning rate finder <lr_finder> to temporarily enable and disable the main progress bar.

method

init_sanity_tqdm() → tqdm

Override this to customize the tqdm bar for the validation sanity run.

method

init_train_tqdm() → tqdm

Override this to customize the tqdm bar for training.

method

init_validation_tqdm() → tqdm

Override this to customize the tqdm bar for validation.

method

init_test_tqdm() → tqdm

Override this to customize the tqdm bar for testing.

method

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

method

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

method

on_train_start(trainer, pl_module)

Called when the train begins.

method

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the train batch ends.

method

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

method

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the validation batch ends.

method

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

method

on_train_end(trainer, pl_module)

Called when the train ends.

method

on_test_start(trainer, pl_module)

Called when the test begins.

method

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the test batch ends.

method

on_test_end(trainer, pl_module)

Called when the test ends.

method

on_epoch_start(trainer, pl_module)

Try to align the epoch number

class

plkit.trainer.Trainer(*args, **kwargs)

Bases
pytorch_lightning.trainer.trainer.Trainer, pytorch_lightning.trainer.properties.TrainerProperties, pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, abc.ABC, pytorch_lightning.trainer.deprecated_api.DeprecatedDistDeviceAttributes

The Trainer class

from_config (aka from_dict) is added as a classmethod to instantiate the trainer from a configuration dictionary.

Attributes
  • checkpoint_callback (ModelCheckpoint, optional) The first checkpoint callback in the Trainer.callbacks list, or None if no checkpoint callbacks exist.
  • checkpoint_callbacks (list of ModelCheckpoint) A list of all instances of ModelCheckpoint found in the Trainer.callbacks list.
  • default_root_dir (str) The default location to save artifacts of loggers, checkpoints etc. It is used as a fallback if the logger or checkpoint callback does not define specific save paths.
  • disable_validation (bool) Check if validation is disabled during training.
  • enable_validation (bool) Check if we should run validation during training.
  • progress_bar_dict (dict) Format progress bar metrics.
  • weights_save_path (str) The default root location to save weights (checkpoints), e.g., when the :class:~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint does not define a file path.
Classes
  • ABCMeta Metaclass for defining Abstract Base Classes (ABCs).
Methods
class

abc.ABCMeta(name, bases, namespace, **kwargs)

Metaclass for defining Abstract Base Classes (ABCs).

Use this metaclass to create an ABC. An ABC can be subclassed directly, and then acts as a mix-in class. You can also register unrelated concrete classes (even built-in classes) and unrelated ABCs as 'virtual subclasses' -- these and their descendants will be considered subclasses of the registering ABC by the built-in issubclass() function, but the registering ABC won't show up in their MRO (Method Resolution Order) nor will method implementations defined by the registering ABC be callable (not even via super()).

Methods
staticmethod
register(cls, subclass)

Register a virtual subclass of an ABC.

Returns the subclass, to allow usage as a class decorator.

staticmethod
__instancecheck__(cls, instance)

Override for isinstance(instance, cls).

staticmethod
__subclasscheck__(cls, subclass)

Override for issubclass(subclass, cls).

method

reset_train_dataloader(model)

Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.).

Parameters
  • model (LightningModule) The current LightningModule
method

reset_val_dataloader(model)

Resets the validation dataloader and determines the number of batches.

Parameters
  • model (LightningModule) The current LightningModule
method

reset_test_dataloader(model)

Resets the test dataloader and determines the number of batches.

Parameters
  • model (LightningModule) The current LightningModule
method

request_dataloader(dataloader_fx)

Handles downloading data in the GPU or TPU case.

Parameters
  • dataloader_fx (callable) The bound dataloader getter
Returns (DataLoader)

The dataloader

method

process_dict_result(output, train=False)

Reduces output according to the training mode.

Separates loss from logging and progress bar metrics
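The separation this method performs can be pictured with the old Lightning dict-return convention, where a step returns a dict with 'loss', 'log', and 'progress_bar' keys. The split_result helper below is a hypothetical illustration of that split, not the actual implementation:

```python
def split_result(output):
    """Separate the loss from logging and progress-bar metrics in a
    dict-style step result (illustrative, not Lightning's code)."""
    loss = output.get("loss")
    log_metrics = output.get("log", {})
    pbar_metrics = output.get("progress_bar", {})
    return loss, log_metrics, pbar_metrics

loss, logs, pbar = split_result(
    {"loss": 0.25, "log": {"train_loss": 0.25}, "progress_bar": {"acc": 0.9}}
)
```

The loss drives the backward pass, while the other two groups are routed to the logger and the progress bar respectively.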

method

setup(model, stage)

Called at the beginning of fit and test

method

teardown(stage)

Called at the end of fit and test

method

on_init_start()

Called when the trainer initialization begins, model has not yet been set.

method

on_init_end()

Called when the trainer initialization ends, model has not yet been set.

method

on_fit_start()

Called when fit begins.

method

on_fit_end()

Called when fit ends.

method

on_sanity_check_start()

Called when the validation sanity check starts.

method

on_sanity_check_end()

Called when the validation sanity check ends.

method

on_train_epoch_start()

Called when the epoch begins.

method

on_train_epoch_end(outputs)

Called when the epoch ends.

method

on_validation_epoch_start()

Called when the epoch begins.

method

on_validation_epoch_end()

Called when the epoch ends.

method

on_test_epoch_start()

Called when the epoch begins.

method

on_test_epoch_end()

Called when the epoch ends.

method

on_epoch_start()

Called when the epoch begins.

method

on_epoch_end()

Called when the epoch ends.

method

on_train_start()

Called when the train begins.

method

on_train_end()

Called when the train ends.

method

on_pretrain_routine_start(model)

Called when the pretrain routine begins.

method

on_pretrain_routine_end(model)

Called when the pretrain routine ends.

method

on_batch_start()

Called when the training batch begins.

method

on_batch_end()

Called when the training batch ends.

method

on_train_batch_start(batch, batch_idx, dataloader_idx)

Called when the training batch begins.

method

on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called when the training batch ends.

method

on_validation_batch_start(batch, batch_idx, dataloader_idx)

Called when the validation batch begins.

method

on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called when the validation batch ends.

method

on_test_batch_start(batch, batch_idx, dataloader_idx)

Called when the test batch begins.

method

on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called when the test batch ends.

method

on_validation_start()

Called when the validation loop begins.

method

on_validation_end()

Called when the validation loop ends.

method

on_test_start()

Called when the test begins.

method

on_test_end()

Called when the test ends.

method

on_keyboard_interrupt()

Called when the training is interrupted by KeyboardInterrupt.

method

on_save_checkpoint()

Called when saving a model checkpoint.

method

on_load_checkpoint(checkpoint)

Called when loading a model checkpoint.

method

on_after_backward()

Called after loss.backward() and before optimizers do anything.

method

on_before_zero_grad(optimizer)

Called after optimizer.step() and before optimizer.zero_grad().

classmethod

get_deprecated_arg_names() → list

Returns a list with deprecated Trainer arguments.

method

tune(model, train_dataloader=None, val_dataloaders=None, datamodule=None)

Runs routines to tune hyperparameters before training.

Parameters
  • model (LightningModule) Model to tune.
  • train_dataloader (DataLoader, optional) A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.
  • val_dataloaders (DataLoader or list of DataLoader, optional) Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped.
  • datamodule (LightningDataModule, optional) An instance of :class:LightningDataModule.
staticmethod

available_plugins()

Returns a list of all available plugins that are supported as string arguments to the trainer.

classmethod

from_config(config, **kwargs)

Create an instance from a configuration dictionary.

Examples
>>> config = {'my_custom_arg': 'something'}
>>> trainer = Trainer.from_dict(config, logger=False)
Parameters
  • config The configuration dictionary to take arguments from. Only known arguments will be parsed and passed to the :class:Trainer.
  • **kwargs Additional keyword arguments that may override ones in the configuration dictionary. These must be valid Trainer arguments.
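A minimal sketch of the pattern behind from_config: filter the dictionary down to the keys the constructor understands, and let explicit keyword arguments win. SimpleTrainer and its parameters are stand-ins for illustration; the real method operates on the full set of Trainer arguments.

```python
import inspect

class SimpleTrainer:
    """Stand-in for the real Trainer, with a tiny argument set."""

    def __init__(self, max_epochs=10, logger=True):
        self.max_epochs = max_epochs
        self.logger = logger

    @classmethod
    def from_config(cls, config, **kwargs):
        # Keep only the keys the constructor knows about.
        known = set(inspect.signature(cls.__init__).parameters) - {"self"}
        args = {k: v for k, v in config.items() if k in known}
        args.update(kwargs)  # explicit kwargs override the config
        return cls(**args)

t = SimpleTrainer.from_config(
    {"max_epochs": 5, "my_custom_arg": "ignored"}, logger=False
)
```

Unknown keys such as my_custom_arg are silently dropped rather than raising a TypeError, which is what lets one configuration dictionary serve several consumers.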
method

fit(*args, **kwargs)

Train and validate the model

method

test(*args, **kwargs)

Test the model