
plkit.module

module

plkit.module

The core base module class based on pytorch_lightning.LightningModule

Classes
class

plkit.module.Module(config)

Bases
pytorch_lightning.core.lightning.LightningModule abc.ABC pytorch_lightning.utilities.device_dtype_mixin.DeviceDtypeModuleMixin pytorch_lightning.core.grads.GradInformation pytorch_lightning.core.saving.ModelIO pytorch_lightning.core.hooks.ModelHooks pytorch_lightning.core.hooks.DataHooks pytorch_lightning.core.hooks.CheckpointHooks torch.nn.modules.module.Module

The Module class

on_epoch_end is added to print a newline, which keeps the progress bar and its stats visible for each epoch. If you don't want this, just override it with:

>>> def on_epoch_end(self):
...     pass

If you have other things to do in on_epoch_end, make sure you call:

>>> super().on_epoch_end()

You may not need to write loss_function, as it can be inferred from the config items loss and num_classes: MSELoss is used for regression (num_classes of 1) and CrossEntropyLoss for classification.
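The inference rule can be pictured as a small standalone function. This is a minimal sketch, not plkit's implementation: the name infer_loss_function and the assumption that config['loss'] is either the string 'auto' or a ready-made loss module are hypothetical.

```python
from torch import nn


def infer_loss_function(config):
    """Hypothetical helper: pick a loss from the config items `loss` and `num_classes`."""
    loss = config.get("loss", "auto")  # assumption: "auto" means "infer it"
    if isinstance(loss, nn.Module):
        # An explicit loss module in the config takes precedence.
        return loss
    # plkit's convention: num_classes == 1 means regression.
    if config.get("num_classes", 1) == 1:
        return nn.MSELoss()
    return nn.CrossEntropyLoss()
```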

measure is added for convenience, to compute some metrics between logits and targets.
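The signature of measure is not shown here, so the following is only a hypothetical illustration of the kind of metric it computes: a standalone accuracy_from_logits function (an assumed name) that compares arg-max predictions with integer class targets.

```python
import torch


def accuracy_from_logits(logits, targets):
    """Hypothetical metric: fraction of rows whose arg-max logit equals the target class."""
    predictions = logits.argmax(dim=1)
    return (predictions == targets).float().mean().item()


logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0]])
targets = torch.tensor([0, 1, 1])
print(accuracy_from_logits(logits, targets))  # two of the three rows are correct
```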

Parameters
  • config The configuration dictionary
Attributes
  • _loss_func The loss function
  • automatic_optimization (bool) If False, you are responsible for calling .backward, .step and .zero_grad.
  • config The configs
  • current_epoch (int) The current epoch
  • global_step (int) Total training batches seen across all epochs
  • num_classes Number of classes to predict. 1 for regression
  • on_gpu True if your model is currently running on GPUs. Useful to set flags around the LightningModule for different CPU vs GPU behavior.
  • optim The optimizer name. Currently only adam and sgd are supported. When it is set, you don't need to write configure_optimizers, although you still can.
Classes
  • ABCMeta Metaclass for defining Abstract Base Classes (ABCs).
Methods
  • add_module(name, module) Adds a child module to the current module.
  • all_gather(tensor, group, sync_grads) Allows users to call self.all_gather() from the LightningModule, thus making the all_gather operation accelerator agnostic.
  • apply(fn) (Module) Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).
  • backward(loss, optimizer, optimizer_idx, *args, **kwargs) Override backward with your own implementation if you need to.
  • buffers(recurse) (torch.Tensor) Returns an iterator over module buffers.
  • children() (Module) Returns an iterator over immediate children modules.
  • configure_optimizers() Configure the optimizers.
  • cpu() (Module) Moves all model parameters and buffers to the CPU.
  • cuda(device) (Module) Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.
  • double() (Module) Casts all floating point parameters and buffers to double datatype.
  • eval() (Module) Sets the module in evaluation mode.
  • extra_repr() Sets the extra representation of the module.
  • float() (Module) Casts all floating point parameters and buffers to float datatype.
  • forward(*args, **kwargs) Same as :meth:torch.nn.Module.forward(); however, in Lightning you want this to define the operations you want to use for prediction (e.g. on a server or as a feature extractor).
  • freeze() Freeze all params for inference.
  • get_progress_bar_dict() (dict(str: int or str)) Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) and the version of the experiment when using a logger.
  • grad_norm(norm_type) (dict(str: float)) Compute each parameter's gradient's norm and their overall norm.
  • half() (Module) Casts all floating point parameters and buffers to half datatype.
  • load_from_checkpoint(checkpoint_path, map_location, hparams_file, strict, **kwargs) Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under hyper_parameters.
  • load_state_dict(state_dict, strict) (``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields) Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.
  • log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, tbptt_reduce_fx, tbptt_pad_token, enable_graph, sync_dist, sync_dist_op, sync_dist_group) Log a key-value pair.
  • log_dict(dictionary, prog_bar, logger, on_step, on_epoch, reduce_fx, tbptt_reduce_fx, tbptt_pad_token, enable_graph, sync_dist, sync_dist_op, sync_dist_group) Log a dictionary of values at once.
  • loss_function(logits, labels) Calculate the loss.
  • manual_backward(loss, optimizer, *args, **kwargs) Call this directly from your training_step when doing optimizations manually. Using it ensures that all the proper scaling (e.g. for 16-bit precision) has been done for you.
  • modules() (Module) Returns an iterator over all modules in the network.
  • named_buffers(prefix, recurse) (string, torch.Tensor) Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
  • named_children() (string, Module) Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
  • named_modules(memo, prefix) (string, Module) Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
  • named_parameters(prefix, recurse) (string, Parameter) Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
  • on_after_backward() Called in the training loop after loss.backward() and before optimizers do anything. This is the ideal place to inspect or log gradient information.
  • on_before_zero_grad(optimizer) Called after optimizer.step() and before optimizer.zero_grad().
  • on_epoch_end() Keeps the epoch progress bar. This hook is not documented, but it works.
  • on_epoch_start() Called in the training loop at the very beginning of the epoch.
  • on_fit_end() Called at the very end of fit. When using DDP, it is called on every process.
  • on_fit_start() Called at the very beginning of fit. When using DDP, it is called on every process.
  • on_hpc_load(checkpoint) Hook to do whatever you need right before Slurm manager loads the model.
  • on_hpc_save(checkpoint) Hook to do whatever you need right before Slurm manager saves the model.
  • on_load_checkpoint(checkpoint) Do something with the checkpoint. Gives model a chance to load something before state_dict is restored.
  • on_pretrain_routine_end() Called at the end of the pretrain routine (between fit and train start).
  • on_pretrain_routine_start() Called at the beginning of the pretrain routine (between fit and train start).
  • on_save_checkpoint(checkpoint) Give the model a chance to add something to the checkpoint. state_dict is already there.
  • on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) Called in the test loop after the batch.
  • on_test_batch_start(batch, batch_idx, dataloader_idx) Called in the test loop before anything happens for that batch.
  • on_test_epoch_end() Called in the test loop at the very end of the epoch.
  • on_test_epoch_start() Called in the test loop at the very beginning of the epoch.
  • on_test_model_eval() Sets the model to eval during the test loop
  • on_test_model_train() Sets the model to train during the test loop
  • on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) Called in the training loop after the batch.
  • on_train_batch_start(batch, batch_idx, dataloader_idx) Called in the training loop before anything happens for that batch.
  • on_train_end() Called at the end of training before logger experiment is closed.
  • on_train_epoch_end(outputs) Called in the training loop at the very end of the epoch.
  • on_train_epoch_start() Called in the training loop at the very beginning of the epoch.
  • on_train_start() Called at the beginning of training before sanity check.
  • on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) Called in the validation loop after the batch.
  • on_validation_batch_start(batch, batch_idx, dataloader_idx) Called in the validation loop before anything happens for that batch.
  • on_validation_epoch_end() Called in the validation loop at the very end of the epoch.
  • on_validation_epoch_start() Called in the validation loop at the very beginning of the epoch.
  • on_validation_model_eval() Sets the model to eval during the val loop
  • on_validation_model_train() Sets the model to train during the val loop
  • optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs) Override this method to adjust the default way the :class:~pytorch_lightning.trainer.trainer.Trainer calls each optimizer. By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer.
  • parameters(recurse) (Parameter) Returns an iterator over module parameters.
  • prepare_data() Use this to download and prepare data.
  • print(*args, **kwargs) Prints only from process 0. Use this in any distributed mode to log only once.
  • register_backward_hook(hook) Registers a backward hook on the module.
  • register_buffer(name, tensor) Adds a persistent buffer to the module.
  • register_forward_hook(hook) Registers a forward hook on the module.
  • register_forward_pre_hook(hook) Registers a forward pre-hook on the module.
  • register_parameter(name, param) Adds a parameter to the module.
  • requires_grad_(requires_grad) (Module) Changes whether autograd should record operations on parameters in this module.
  • save_hyperparameters(*args, frame) Save all model arguments.
  • setup(stage) Called at the beginning of fit and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
  • state_dict(destination, prefix, keep_vars) (dict) Returns a dictionary containing a whole state of the module.
  • tbptt_split_batch(batch, split_size) (list) When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function.
  • teardown(stage) Called at the end of fit and test.
  • test_dataloader() (Union(dataloader, list of dataloader)) Implement one or multiple PyTorch DataLoaders for testing.
  • test_epoch_end(outputs) Called at the end of a test epoch with the output of all test steps.
  • test_step(*args, **kwargs) Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
  • test_step_end(*args, **kwargs) Use this when testing with dp or ddp2 because :meth:test_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
  • to(*args, **kwargs) (Module) Moves and/or casts the parameters and buffers.
  • to_onnx(file_path, input_sample, **kwargs) Saves the model in ONNX format.
  • to_torchscript(file_path, method, example_inputs, **kwargs) (Union(scriptmodule, dict(str: scriptmodule))) By default compiles the whole model to a :class:~torch.jit.ScriptModule. If you want to use tracing, please provide the argument method='trace' and make sure that either the example_inputs argument is provided, or the model has self.example_input_array set. If you would like to customize the modules that are scripted you should override this method. In case you want to return multiple modules, we recommend using a dictionary.
  • toggle_optimizer(optimizer, optimizer_idx) Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
  • train(mode) (Module) Sets the module in training mode.
  • train_dataloader() (DataLoader) Implement a PyTorch DataLoader for training.
  • training_epoch_end(outputs) Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs for every training_step.
  • training_step(*args, **kwargs) Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.
  • training_step_end(*args, **kwargs) Use this when training with dp or ddp2 because :meth:training_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
  • transfer_batch_to_device(batch, device) (any) Override this hook if your :class:~torch.utils.data.DataLoader returns tensors wrapped in a custom data structure.
  • type(dst_type) (Module) Casts all parameters and buffers to :attr:dst_type.
  • unfreeze() Unfreeze all parameters for training.
  • val_dataloader() (Union(dataloader, list of dataloader)) Implement one or multiple PyTorch DataLoaders for validation.
  • validation_epoch_end(outputs) Called at the end of the validation epoch with the outputs of all validation steps.
  • validation_step(*args, **kwargs) Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
  • validation_step_end(*args, **kwargs) Use this when validating with dp or ddp2 because :meth:validation_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
  • zero_grad() Sets gradients of all model parameters to zero.
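Selecting the optimizer from the optim attribute could look roughly like the sketch below. The helper build_optimizer and the learning_rate config key are hypothetical; only torch.optim.Adam and torch.optim.SGD are real PyTorch classes.

```python
import torch
from torch import nn


def build_optimizer(model, config):
    """Hypothetical helper: build an optimizer from config['optim'] ("adam" or "sgd")."""
    name = config.get("optim", "adam")
    lr = config.get("learning_rate", 1e-3)  # assumed config key name
    if name == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr)
    if name == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr)
    raise ValueError(f"Unsupported optimizer: {name!r}")
```

Raising on an unknown name mirrors the statement that only adam and sgd are supported, instead of silently falling back to a default.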
method

register_buffer(name, tensor)

Adds a persistent buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the persistent state.

Buffers can be accessed as attributes using given names.

Parameters
  • name (string) name of the buffer. The buffer can be accessed from this module using the given name
  • tensor (Tensor) buffer to be registered.

Example::

>>> self.register_buffer('running_mean', torch.zeros(num_features))
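A self-contained example with plain torch.nn.Module, which this class inherits from. RunningMean is a hypothetical module: its buffer shows up in state_dict() but not in parameters(), so an optimizer never updates it.

```python
import torch
from torch import nn


class RunningMean(nn.Module):
    """Hypothetical module that tracks an exponential running mean of its inputs."""

    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        # A buffer: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer("running_mean", torch.zeros(num_features))

    def forward(self, x):
        with torch.no_grad():
            # running_mean = (1 - momentum) * running_mean + momentum * batch_mean
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * x.mean(dim=0))
        return x - self.running_mean
```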
method

register_parameter(name, param)

Adds a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters
  • name (string) name of the parameter. The parameter can be accessed from this module using the given name
  • param (Parameter) parameter to be added to the module.
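A minimal sketch using a hypothetical Scale module. register_parameter is mainly useful when the attribute name is only known at runtime; otherwise, assigning an nn.Parameter as an attribute has the same effect.

```python
import torch
from torch import nn


class Scale(nn.Module):
    def __init__(self):
        super().__init__()
        # Same effect as `self.weight = nn.Parameter(torch.ones(1))`.
        self.register_parameter("weight", nn.Parameter(torch.ones(1)))

    def forward(self, x):
        return x * self.weight
```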
method

add_module(name, module)

Adds a child module to the current module.

The module can be accessed as an attribute using the given name.

Parameters
  • name (string) name of the child module. The child module can be accessed from this module using the given name
  • module (Module) child module to be added to the module.
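For example, add_module can name children inside a loop, which plain attribute assignment does not allow as directly. The layer names are illustrative.

```python
from torch import nn

net = nn.Module()
for i in range(3):
    # The child's name is built at runtime and becomes an attribute of `net`.
    net.add_module(f"layer{i}", nn.Linear(4, 4))

print([name for name, _ in net.named_children()])  # ['layer0', 'layer1', 'layer2']
```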
method

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).

Parameters
  • fn (:class:Module -> None) function to be applied to each submodule
Returns (Module)

self

Example::

>>> from torch import nn
>>> def init_weights(m):
...     print(m)
...     if isinstance(m, nn.Linear):
...         m.weight.data.fill_(1.0)
...         print(m.weight)
...
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1.,  1.],
        [ 1.,  1.]])
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1.,  1.],
        [ 1.,  1.]])
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
method

register_backward_hook(hook)

Registers a backward hook on the module.

The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature::

hook(module, grad_input, grad_output) -> Tensor or None

The :attr:grad_input and :attr:grad_output may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to input that will be used in place of :attr:grad_input in subsequent computations.

Returns

A :class:torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove()

.. warning ::

The current implementation will not have the presented behavior
for complex :class:`Module` that perform many operations.
In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
contain the gradients for a subset of the inputs and outputs.
For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
directly on a specific input or output to get the required gradients.
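As the warning suggests, a hook registered directly on a tensor with torch.Tensor.register_hook receives exactly the gradient of that tensor. A minimal sketch with a single nn.Linear layer:

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
x = torch.randn(5, 3, requires_grad=True)
captured = {}


def save_grad(grad):
    # Called with d(loss)/d(x) when backward reaches x; returning None keeps it unchanged.
    captured["x"] = grad.clone()


handle = x.register_hook(save_grad)
layer(x).sum().backward()
handle.remove()  # Tensor.register_hook also returns a removable handle
```

For a sum over the outputs of a linear layer, every row of the captured gradient equals the column sums of the weight matrix.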
method

register_forward_pre_hook(hook)

Registers a forward pre-hook on the module.

The hook will be called every time before :func:forward is invoked. It should have the following signature::

hook(module, input) -> None or modified input

The hook can modify the input. It can return either a tuple or a single modified value; a single value is wrapped into a tuple (unless that value is already a tuple).

Returns

A :class:torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove()
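A minimal sketch: a hypothetical pre-hook that doubles the input before nn.Linear sees it, returning a single value that PyTorch wraps into a tuple.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)


def double_input(module, inputs):
    # `inputs` is always a tuple of the positional arguments passed to forward.
    (x,) = inputs
    return x * 2  # a single value; PyTorch wraps it into a tuple


handle = layer.register_forward_pre_hook(double_input)
x = torch.ones(1, 2)
with_hook = layer(x)
handle.remove()
without_hook = layer(x * 2)  # same result, computed without the hook
```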

method

register_forward_hook(hook)

Registers a forward hook on the module.

The hook will be called every time after :func:forward has computed an output. It should have the following signature::

hook(module, input, output) -> None or modified output

The hook can modify the output. It can also modify the input in place, but this has no effect on forward, since the hook is called after :func:forward has run.

Returns

A :class:torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove()
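Forward hooks are a common way to capture intermediate activations without changing forward. A minimal sketch on a small nn.Sequential, where the record helper is hypothetical:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
activations = {}


def record(name):
    def hook(module, inputs, output):
        # Store a detached copy; returning None leaves the output unchanged.
        activations[name] = output.detach()
    return hook


handle = model[1].register_forward_hook(record("relu"))
model(torch.randn(8, 4))
handle.remove()
```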

method

state_dict(destination=None, prefix='', keep_vars=False)

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.

Returns (dict)

a dictionary containing a whole state of the module

Example::

>>> module.state_dict().keys()
['bias', 'weight']
method

load_state_dict(state_dict, strict=True)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.

Parameters
  • state_dict (dict) a dict containing parameters and persistent buffers.
  • strict (bool, optional) whether to strictly enforce that the keys in :attr:state_dict match the keys returned by this module's :meth:~torch.nn.Module.state_dict function. Default: True
Returns (``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields)

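missing_keys is a list of str containing the missing keys; unexpected_keys is a list of str containing the unexpected keys.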
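With strict=False, mismatched keys are reported instead of raising an error. A minimal sketch that loads a one-layer state dict into a two-layer model:

```python
import torch
from torch import nn

source = nn.Sequential(nn.Linear(4, 4))
target = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Layer "0" matches and is copied; layer "1" has no entries in the source.
result = target.load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []
```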

generator

parameters(recurse=True)

Returns an iterator over module parameters.

This is typically passed to an optimizer.

Parameters
  • recurse (bool) if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields (Parameter)

module parameter

Example::

>>> for param in model.parameters():
...     print(type(param.data), param.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
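A short sketch contrasting the default recursive behavior with recurse=False, and showing the typical hand-off to an optimizer:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 1))
# Typical use: hand all trainable parameters to an optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

total = sum(p.numel() for p in model.parameters())
direct = list(model.parameters(recurse=False))
print(total)        # 61, i.e. 10*5 + 5 + 5*1 + 1
print(len(direct))  # 0, since Sequential itself owns no parameters
```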
generator

named_parameters(prefix='', recurse=True)

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters
  • prefix (str) prefix to prepend to all parameter names.
  • recurse (bool) if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields (string, Parameter)

Tuple containing the name and parameter

Example::

>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())
generator

buffers(recurse=True)

Returns an iterator over module buffers.

Parameters
  • recurse (bool) if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
Yields (torch.Tensor)

module buffer

Example::

>>> for buf in model.buffers():
...     print(type(buf.data), buf.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
generator

named_buffers(prefix='', recurse=True)

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters
  • prefix (str) prefix to prepend to all buffer names.
  • recurse (bool) if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
Yields (string, torch.Tensor)

Tuple containing the name and buffer

Example::

>>> for name, buf in self.named_buffers():
...     if name in ['running_var']:
...         print(buf.size())
generator

children()

Returns an iterator over immediate children modules.

Yields (Module)

a child module
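The difference between children() and modules() is easiest to see on a nested model. A minimal sketch:

```python
from torch import nn

block = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
model = nn.Sequential(block, nn.Linear(2, 1))

# children(): immediate submodules only (block and the last Linear).
print(len(list(model.children())))  # 2
# modules(): the model itself plus every descendant.
print(len(list(model.modules())))   # 5
```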

generator

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields (string, Module)

Tuple containing a name and child module

Example::

>>> for name, module in model.named_children():
...     if name in ['conv4', 'conv5']:
...         print(module)
generator

modules()

Returns an iterator over all modules in the network.

Yields (Module)

a module in the network

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example::

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)
...
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)

generator

named_modules(memo=None, prefix='')

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Yields (string, Module)

Tuple of name and module

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example::

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)
...
0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

method

train(mode=True)

Sets the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm, etc.

Parameters
  • mode (bool) whether to set training mode (True) or evaluation mode (False). Default: True.
Returns (Module)

self

method

eval()

Sets the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm, etc.

This is equivalent to :meth:self.train(False) <torch.nn.Module.train>.

Returns (Module)

self
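nn.Dropout is a typical module whose behavior depends on the mode. A minimal sketch:

```python
import torch
from torch import nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
y_train = dropout(x)  # roughly half the entries zeroed, the rest scaled by 1 / (1 - p) = 2
dropout.eval()
y_eval = dropout(x)   # identity in evaluation mode
```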

method

requires_grad_(requires_grad=True)

Changes whether autograd should record operations on parameters in this module.

This method sets the parameters' :attr:requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

Parameters
  • requires_grad (bool) whether autograd should record operations on parameters in this module. Default: True.
Returns (Module)

self
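A minimal sketch of freezing the first layer of a small model for fine-tuning; the architecture is illustrative.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
# Freeze the first layer (e.g. a pretrained feature extractor) ...
model[0].requires_grad_(False)
# ... and give the optimizer only the parameters that still require gradients.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

model(torch.randn(3, 8)).sum().backward()
```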

method

zero_grad()

Sets gradients of all model parameters to zero.

method

extra_repr()

Sets the extra representation of the module.

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.
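A minimal sketch of a hypothetical Scale module whose extra_repr reports its constructor argument:

```python
from torch import nn


class Scale(nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def extra_repr(self):
        # This string appears inside the parentheses of the module's repr.
        return f"factor={self.factor}"

    def forward(self, x):
        return x * self.factor


print(Scale(2.0))  # Scale(factor=2.0)
```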

method

setup(stage)

Called at the beginning of fit and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters
  • stage (str) either 'fit' or 'test'

Example::

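class LitModel(...):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data does not run on every process,
        # so state assigned here is missing elsewhere
        self.something = ...

    def setup(self, stage):
        # do this instead: setup runs on every process
        data = Load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)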
method

teardown(stage)

Called at the end of fit and test.

Parameters
  • stage (str) either 'fit' or 'test'
method

on_fit_start()

Called at the very beginning of fit. When using DDP, it is called on every process.

method

on_fit_end()

Called at the very end of fit. When using DDP, it is called on every process.

method

on_train_start()

Called at the beginning of training before sanity check.

method

on_train_end()

Called at the end of training before logger experiment is closed.

method

on_pretrain_routine_start()

Called at the beginning of the pretrain routine (between fit and train start).

The order of calls is:

  • fit
  • pretrain_routine start
  • pretrain_routine end
  • training_start
method

on_pretrain_routine_end()

Called at the end of the pretrain routine (between fit and train start).

The order of calls is:

  • fit
  • pretrain_routine start
  • pretrain_routine end
  • training_start
method

on_train_batch_start(batch, batch_idx, dataloader_idx)

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters
  • batch (any) The batched data as it is returned by the training DataLoader.
  • batch_idx (int) the index of the batch
  • dataloader_idx (int) the index of the dataloader
method

on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the training loop after the batch.

Parameters
  • outputs (any) The outputs of training_step_end(training_step(x))
  • batch (any) The batched data as it is returned by the training DataLoader.
  • batch_idx (int) the index of the batch
  • dataloader_idx (int) the index of the dataloader
method

on_validation_model_eval()

Sets the model to eval during the val loop

method

on_validation_model_train()

Sets the model to train during the val loop

method

on_validation_batch_start(batch, batch_idx, dataloader_idx)

Called in the validation loop before anything happens for that batch.

Parameters
  • batch (any) The batched data as it is returned by the validation DataLoader.
  • batch_idx (int) the index of the batch
  • dataloader_idx (int) the index of the dataloader
method

on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the validation loop after the batch.

Parameters
  • outputs (any) The outputs of validation_step_end(validation_step(x))
  • batch (any) The batched data as it is returned by the validation DataLoader.
  • batch_idx (int) the index of the batch
  • dataloader_idx (int) the index of the dataloader
method

on_test_batch_start(batch, batch_idx, dataloader_idx)

Called in the test loop before anything happens for that batch.

Parameters
  • batch (any) The batched data as it is returned by the test DataLoader.
  • batch_idx (int) the index of the batch
  • dataloader_idx (int) the index of the dataloader
method

on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the test loop after the batch.

Parameters
  • outputs (any) The outputs of test_step_end(test_step(x))
  • batch (any) The batched data as it is returned by the test DataLoader.
  • batch_idx (int) the index of the batch
  • dataloader_idx (int) the index of the dataloader
method

on_test_model_eval()

Sets the model to eval during the test loop

method

on_test_model_train()

Sets the model to train during the test loop

method

on_epoch_start()

Called in the training loop at the very beginning of the epoch.

method

on_train_epoch_start()

Called in the training loop at the very beginning of the epoch.

method

on_train_epoch_end(outputs)

Called in the training loop at the very end of the epoch.

method

on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

method

on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

method

on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

method

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

method

on_before_zero_grad(optimizer)

Called after optimizer.step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.

This is where it is called::

for optimizer in optimizers:
    optimizer.step()
    model.on_before_zero_grad(optimizer) # < ---- called here
    optimizer.zero_grad()
Parameters
  • optimizer (Optimizer) The optimizer for which grads should be zeroed.
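The same ordering written as a plain PyTorch step, with a hypothetical standalone on_before_zero_grad function standing in for the hook. At that point the weights have been updated and the gradients are still available:

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
seen = []


def on_before_zero_grad(optimizer):
    # Hypothetical stand-in for the hook: grads are still populated here.
    seen.append(model.weight.grad.clone())


loss = model(torch.ones(4, 2)).sum()
loss.backward()
optimizer.step()
on_before_zero_grad(optimizer)  # called between step() and zero_grad()
optimizer.zero_grad()
```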
method

on_after_backward()

Called in the training loop after loss.backward() and before optimizers do anything. This is the ideal place to inspect or log gradient information.

Example::

def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for name, param in self.named_parameters():
            # log the gradient, not the weight; frozen params have no grad
            if param.grad is not None:
                self.logger.experiment.add_histogram(tag=name, values=param.grad,
                                                     global_step=self.trainer.global_step)
method

prepare_data()

Use this to download and prepare data.

.. warning:: DO NOT set state to the model (use setup instead) since this is NOT called on every GPU in DDP/TPU

Example::

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In DDP prepare_data can be called in two ways (using Trainer(prepare_data_per_node)):

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.
  2. Once in total. Only called on GLOBAL_RANK=0.

Example::

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)

# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)
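The two modes can be summarized as a rank check. This hypothetical runs_prepare_data helper only restates the rules listed above; it is not Lightning's internal code.

```python
def runs_prepare_data(local_rank, global_rank, prepare_data_per_node=True):
    """Hypothetical helper: does the process with these ranks call prepare_data()?"""
    if prepare_data_per_node:
        return local_rank == 0  # once per node
    return global_rank == 0     # once in total


# Two nodes with two processes each, as (local_rank, global_rank) pairs.
processes = [(0, 0), (1, 1), (0, 2), (1, 3)]
print([runs_prepare_data(l, g) for l, g in processes])         # [True, False, True, False]
print([runs_prepare_data(l, g, False) for l, g in processes])  # [True, False, False, False]
```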

This is called before requesting the dataloaders:

.. code-block:: python

model.prepare_data()
    if ddp/tpu: init()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
method

train_dataloader() → DataLoader

Implement a PyTorch DataLoader for training.

Return: Single PyTorch :class:~torch.utils.data.DataLoader.

The dataloader you return will not be reloaded every epoch unless you set :paramref:~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch to True.

For data processing use the following pattern:

- download in :meth:`prepare_data`
- process and split in :meth:`setup`

However, the above are only necessary for distributed processing.

.. warning:: do not assign state in prepare_data

The hooks are called in this order:

  • :meth:~pytorch_lightning.trainer.Trainer.fit
  • ...
  • :meth:prepare_data
  • :meth:setup
  • :meth:train_dataloader

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example

.. code-block:: python

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader
method

test_dataloader() → Union(dataloader, list of dataloader)

Implement one or multiple PyTorch DataLoaders for testing.

This method is not called again at every epoch unless you set :paramref:~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch to True.

For data processing use the following pattern:

- download in :meth:`prepare_data`
- process and split in :meth:`setup`

However, the above are only necessary for distributed processing.

.. warning:: do not assign state in prepare_data

The hooks are called in this order:

  • :meth:~pytorch_lightning.trainer.Trainer.fit
  • ...
  • :meth:prepare_data
  • :meth:setup
  • :meth:train_dataloader
  • :meth:val_dataloader
  • :meth:test_dataloader

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return: Single or multiple PyTorch DataLoaders.

Example

.. code-block:: python

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don't need a test dataset and a :meth:test_step, you don't need to implement this method.

Note

In the case where you return multiple test dataloaders, the :meth:test_step will have an argument dataloader_idx which matches the order here.

method

val_dataloader() → Union(dataloader, list of dataloader)

Implement one or multiple PyTorch DataLoaders for validation.

This method is not called again at every epoch unless you set :paramref:~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch to True.

It's recommended that all data downloads and preparation happen in :meth:prepare_data.

The hooks are called in this order:

  • :meth:~pytorch_lightning.trainer.Trainer.fit
  • ...
  • :meth:prepare_data
  • :meth:train_dataloader
  • :meth:val_dataloader
  • :meth:test_dataloader

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return: Single or multiple PyTorch DataLoaders.

Examples

.. code-block:: python

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don't need a validation dataset and a :meth:validation_step, you don't need to implement this method.

Note

In the case where you return multiple validation dataloaders, the :meth:validation_step will have an argument dataloader_idx which matches the order here.

method

transfer_batch_to_device(batch, device=None)

Override this hook if your :class:~torch.utils.data.DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • :class:torch.Tensor or anything that implements .to(...)
  • :class:list
  • :class:dict
  • :class:tuple
  • :class:torchtext.data.batch.Batch

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

Example::

def transfer_batch_to_device(self, batch, device):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    else:
        # fall back to the default handling for all other batch types
        batch = super().transfer_batch_to_device(batch, device)
    return batch
Parameters
  • batch (any) A batch of data that needs to be transferred to a new device.
  • device (device, optional) The target device as defined in PyTorch.
Returns (any)

A reference to the data on the new device.

Note

This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing).

Note

This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support for your custom batch objects, you need to define your custom :class:~torch.nn.parallel.DistributedDataParallel or :class:~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel and override :meth:~pytorch_lightning.core.lightning.LightningModule.configure_ddp.

See Also
  • :func:~pytorch_lightning.utilities.apply_func.move_data_to_device
  • :func:~pytorch_lightning.utilities.apply_func.apply_to_collection
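The recursive traversal behind move_data_to_device can be sketched in plain Python. This is a simplified illustration, not the Lightning source: move_to_device is a hypothetical name, and FakeTensor is a stand-in for any object that implements .to(...).

```python
from typing import Any


class FakeTensor:
    """Stand-in for torch.Tensor: anything with a .to(device) method."""

    def __init__(self, value: Any, device: str = "cpu"):
        self.value = value
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # return a new object on the target device, as Tensor.to does
        return FakeTensor(self.value, device)


def move_to_device(batch: Any, device: str) -> Any:
    """Hypothetical helper: move every .to()-capable leaf of a nested batch."""
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(val, device) for key, val in batch.items()}
    if isinstance(batch, (list, tuple)):
        moved = [move_to_device(item, device) for item in batch]
        return tuple(moved) if isinstance(batch, tuple) else moved
    # anything else (ints, strings, unknown objects) is returned unchanged
    return batch


batch = {"x": FakeTensor([1, 2]), "meta": ("id-7", [FakeTensor(3)])}
moved = move_to_device(batch, "cuda:0")
print(moved["x"].device, moved["meta"][1][0].device, moved["meta"][0])
# cuda:0 cuda:0 id-7
```

A custom batch class with neither a .to method nor a list/tuple/dict shape falls through unchanged, which is exactly the case where overriding transfer_batch_to_device is needed.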
classmethod

load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under hyper_parameters.

Any arguments specified through *args and **kwargs will override args stored in hyper_parameters.
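The override rule amounts to a dictionary merge in which the keyword arguments win. A minimal sketch, assuming the saved hyperparameters are a flat dict; merge_init_kwargs is hypothetical and not part of Lightning.

```python
def merge_init_kwargs(saved_hparams: dict, overrides: dict) -> dict:
    """Hypothetical helper: keys in `overrides` take precedence."""
    return {**saved_hparams, **overrides}


checkpoint = {"hyper_parameters": {"num_layers": 4, "lr": 0.01}}
init_kwargs = merge_init_kwargs(checkpoint["hyper_parameters"], {"num_layers": 128})
print(init_kwargs)  # {'num_layers': 128, 'lr': 0.01}
```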

Parameters
  • checkpoint_path (str or IO) Path to checkpoint. This can also be a URL or a file-like object.
  • map_location (Union(dict(str: str), str, device, int, callable, nonetype), optional) If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:torch.load.
  • hparams_file (str, optional) Optional path to a .yaml file with hierarchical structure as in this example::
    drop_prob: 0.2
    dataloader:
        batch_size: 32
    
    You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this argument to pass in a .yaml file with the hparams you'd like to use. These will be converted into a :class:~dict and passed into your :class:LightningModule for use.
    If your model's hparams argument is :class:~argparse.Namespace and .yaml file has hierarchical structure, you need to refactor your model to treat hparams as :class:~dict.
  • strict (bool, optional) Whether to strictly enforce that the keys in :attr:checkpoint_path match the keys returned by this module's state dict. Default: True.
  • kwargs Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.

Return: :class:LightningModule with loaded weights and hyperparameters (if available).

Example

.. code-block:: python

# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
pretrained_model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
method

on_load_checkpoint(checkpoint)

Do something with the checkpoint. Gives model a chance to load something before state_dict is restored.

Parameters
  • checkpoint (dict(str: any)) A dictionary with variables from the checkpoint.
method

on_save_checkpoint(checkpoint)

Give the model a chance to add something to the checkpoint. state_dict is already there.

Parameters
  • checkpoint (dict(str: any)) A dictionary in which you can save variables to save in a checkpoint. Contents need to be pickleable.
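on_save_checkpoint and on_load_checkpoint are typically used as a pair. The sketch below uses a plain class with the same hook names so it runs without Lightning; the vocab attribute and the 'vocab' checkpoint key are hypothetical, and pickle stands in for the checkpoint serializer to show why the contents must be picklable.

```python
import pickle


class VocabModel:
    """Plain stand-in for a LightningModule that owns a vocabulary."""

    def __init__(self):
        self.vocab = {}

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # add extra, picklable state next to the existing 'state_dict' entry
        checkpoint["vocab"] = self.vocab

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # restore the extra state before the weights are loaded
        self.vocab = checkpoint["vocab"]


model = VocabModel()
model.vocab = {"<pad>": 0, "hello": 1}
checkpoint = {"state_dict": {}}
model.on_save_checkpoint(checkpoint)
blob = pickle.dumps(checkpoint)

restored = VocabModel()
restored.on_load_checkpoint(pickle.loads(blob))
print(restored.vocab)  # {'<pad>': 0, 'hello': 1}
```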
method

on_hpc_save(checkpoint)

Hook to do whatever you need right before Slurm manager saves the model.

Parameters
  • checkpoint (dict(str: any)) A dictionary in which you can save variables to save in a checkpoint. Contents need to be pickleable.
method

on_hpc_load(checkpoint)

Hook to do whatever you need right before Slurm manager loads the model.

Parameters
  • checkpoint (dict(str: any)) A dictionary with variables from the checkpoint.
method

grad_norm(norm_type) → dict(str: float)

Compute each parameter's gradient's norm and their overall norm.

The overall norm is computed over all gradients together, as if they were concatenated into a single vector.

Parameters
  • norm_type (float, int, or str) The type of the used p-norm, cast to float if necessary. Can be 'inf' for infinity norm.

Return: norms: The dictionary of p-norms of each parameter's gradient and a special entry for the total p-norm of the gradients viewed as a single vector.
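The same computation in plain Python, assuming gradients are given as flat lists of floats per parameter name. grad_norms and the "total" key are hypothetical and do not reproduce Lightning's key names. Because the total is taken over the concatenated values, it equals the p-norm of the vector of per-parameter norms.

```python
import math
from typing import Dict, List, Union


def _p_norm(values: List[float], p: float) -> float:
    if math.isinf(p):
        # infinity norm: largest absolute entry
        return max((abs(v) for v in values), default=0.0)
    return sum(abs(v) ** p for v in values) ** (1.0 / p)


def grad_norms(grads: Dict[str, List[float]],
               norm_type: Union[float, int, str]) -> Dict[str, float]:
    """Hypothetical sketch: per-parameter p-norms plus the overall p-norm."""
    p = float(norm_type)  # float('inf') also accepts the string 'inf'
    norms = {name: _p_norm(values, p) for name, values in grads.items()}
    # overall norm, as if all gradients were concatenated into one vector
    all_values = [v for values in grads.values() for v in values]
    norms["total"] = _p_norm(all_values, p)
    return norms


grads = {"weight": [3.0, 0.0], "bias": [4.0]}
print(grad_norms(grads, 2))      # {'weight': 3.0, 'bias': 4.0, 'total': 5.0}
print(grad_norms(grads, "inf"))  # {'weight': 3.0, 'bias': 4.0, 'total': 4.0}
```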

method

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

This can be called as

  • to(device=None, dtype=None, non_blocking=False)
  • to(dtype, non_blocking=False)
  • to(tensor, non_blocking=False)

Its signature is similar to :meth:torch.Tensor.to, but it only accepts floating-point types for :attr:dtype. In addition, this method will only cast the floating point parameters and buffers to :attr:dtype (if given). The integral parameters and buffers will be moved to :attr:device, if that is given, but with dtypes unchanged. When :attr:non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples.

Note

This method modifies the module in-place.

Parameters
  • device the desired device of the parameters and buffers in this module
  • dtype the desired floating point type of the floating point parameters and buffers in this module
  • tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
Returns (Module)

self

Example::

>>> class ExampleModule(DeviceDtypeModuleMixin):
...     def __init__(self, weight: torch.Tensor):
...         super().__init__()
...         self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.device
device(type='cpu')
>>> module.dtype
torch.float16

method

cuda(device=None)

Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.

Parameters
  • device (int, optional) if specified, all parameters will be copied to that device
Returns (Module)

self

method

cpu()

Moves all model parameters and buffers to the CPU.

Returns (Module)

self

method

type(dst_type)

Casts all parameters and buffers to :attr:dst_type.

Parameters
  • dst_type (type or string) the desired type
Returns (Module)

self

method

float()

Casts all floating point parameters and buffers to float datatype.

Returns (Module)

self

method

double()

Casts all floating point parameters and buffers to double datatype.

Returns (Module)

self

method

half()

Casts all floating point parameters and buffers to half datatype.

Returns (Module)

self

class

abc.ABCMeta(name, bases, namespace, **kwargs)

Metaclass for defining Abstract Base Classes (ABCs).

Use this metaclass to create an ABC. An ABC can be subclassed directly, and then acts as a mix-in class. You can also register unrelated concrete classes (even built-in classes) and unrelated ABCs as 'virtual subclasses' -- these and their descendants will be considered subclasses of the registering ABC by the built-in issubclass() function, but the registering ABC won't show up in their MRO (Method Resolution Order) nor will method implementations defined by the registering ABC be callable (not even via super()).
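Virtual-subclass registration can be shown with the standard library alone. The class names below are illustrative.

```python
from abc import ABCMeta


class Record(metaclass=ABCMeta):
    """An ABC with no required methods, used only for registration."""


Record.register(tuple)  # tuple becomes a virtual subclass of Record


@Record.register  # register returns the class, so it works as a decorator
class Point:
    pass


print(issubclass(tuple, Record))   # True
print(isinstance((1, 2), Record))  # True
print(Record in tuple.__mro__)     # False: the MRO is not modified
print(issubclass(Point, Record))   # True
```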

Methods
staticmethod
register(cls, subclass)

Register a virtual subclass of an ABC.

Returns the subclass, to allow usage as a class decorator.

staticmethod
__instancecheck__(cls, instance)

Override for isinstance(instance, cls).

staticmethod
__subclasscheck__(cls, subclass)

Override for issubclass(subclass, cls).

method

print(*args, **kwargs)

Prints only from process 0. Use this in any distributed mode to log only once.

Parameters
  • *args The thing to print. Will be passed to Python's built-in print function.
  • **kwargs Will be passed to Python's built-in print function.
Example::

def forward(self, x):
    # printed once, from process 0 only
    self.print(x, 'in forward')

method

log(name, value, prog_bar=False, logger=True, on_step=None, on_epoch=None, reduce_fx=torch.mean, tbptt_reduce_fx=torch.mean, tbptt_pad_token=0, enable_graph=False, sync_dist=False, sync_dist_op='mean', sync_dist_group=None)

Log a key-value pair.

Example::

self.log('train_loss', loss)

The default behavior per hook is as follows (* also applies to the test loop):

=========================  =======  ========  ========  ======
LightningModule hook       on_step  on_epoch  prog_bar  logger
=========================  =======  ========  ========  ======
training_step              T        F         F         T
training_step_end          T        F         F         T
training_epoch_end         F        T         F         T
validation_step            F        T         F         T
validation_step_end        F        T         F         T
validation_epoch_end*      F        T         F         T
=========================  =======  ========  ========  ======

Parameters
  • name (str) key name
  • value (any) value to log
  • prog_bar (bool, optional) if True logs to the progress bar
  • logger (bool, optional) if True logs to the logger
  • on_step (bool, optional) if True logs at this step. None auto-logs at the training_step but not validation/test_step
  • on_epoch (bool, optional) if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
  • reduce_fx (callable, optional) reduction function over step values for end of epoch. torch.mean by default
  • tbptt_reduce_fx (callable, optional) function to reduce on truncated back prop
  • tbptt_pad_token (int, optional) token to use for padding
  • enable_graph (bool, optional) if True, will not auto detach the graph
  • sync_dist (bool, optional) if True, reduces the metric across GPUs/TPUs
  • sync_dist_op (any or str, optional) the op to sync across GPUs/TPUs
  • sync_dist_group (any, optional) the ddp group
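To make the on_step / on_epoch / reduce_fx interplay concrete, here is a toy accumulator in plain Python. It is a hypothetical sketch of the semantics described above, not Lightning's result-collection code; ToyLogger is an invented name and statistics.mean stands in for torch.mean.

```python
from collections import defaultdict
from statistics import mean
from typing import Callable, Dict, List


class ToyLogger:
    """Hypothetical stand-in for the bookkeeping behind self.log in one epoch."""

    def __init__(self):
        self.step_values: Dict[str, float] = {}
        self._epoch_buffer: Dict[str, List[float]] = defaultdict(list)
        self._reducers: Dict[str, Callable] = {}

    def log(self, name: str, value: float, on_step: bool = True,
            on_epoch: bool = False, reduce_fx: Callable = mean) -> None:
        if on_step:
            # reported immediately, overwritten every step
            self.step_values[name] = value
        if on_epoch:
            # buffered and reduced when the epoch ends
            self._epoch_buffer[name].append(value)
            self._reducers[name] = reduce_fx

    def epoch_end(self) -> Dict[str, float]:
        results = {name: self._reducers[name](values)
                   for name, values in self._epoch_buffer.items()}
        self._epoch_buffer.clear()
        return results


logger = ToyLogger()
for loss in [3.0, 2.0, 1.0]:
    logger.log("train_loss", loss, on_step=True, on_epoch=True)
    logger.log("max_loss", loss, on_step=False, on_epoch=True, reduce_fx=max)
epoch_results = logger.epoch_end()
print(logger.step_values)  # {'train_loss': 1.0}
print(epoch_results)       # {'train_loss': 2.0, 'max_loss': 3.0}
```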
method

log_dict(dictionary, prog_bar=False, logger=True, on_step=None, on_epoch=None, reduce_fx=torch.mean, tbptt_reduce_fx=torch.mean, tbptt_pad_token=0, enable_graph=False, sync_dist=False, sync_dist_op='mean', sync_dist_group=None)

Log a dictionary of values at once.

Example::

values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
self.log_dict(values)
Parameters
  • dictionary (dict) key value pairs (str, tensors)
  • prog_bar (bool, optional) if True logs to the progress bar
  • logger (bool, optional) if True logs to the logger
  • on_step (bool, optional) if True logs at this step. None auto-logs for training_step but not validation/test_step
  • on_epoch (bool, optional) if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
  • reduce_fx (callable, optional) reduction function over step values for end of epoch. torch.mean by default
  • tbptt_reduce_fx (callable, optional) function to reduce on truncated back prop
  • tbptt_pad_token (int, optional) token to use for padding
  • enable_graph (bool, optional) if True, will not auto detach the graph
  • sync_dist (bool, optional) if True, reduces the metric across GPUs/TPUs
  • sync_dist_op (any or str, optional) the op to sync across GPUs/TPUs
  • sync_dist_group (any, optional) the ddp group
method

all_gather(tensor, group=None, sync_grads=False)

Allows users to call self.all_gather() from the LightningModule, thus making the all_gather operation accelerator agnostic.

all_gather is a function provided by accelerators to gather a tensor from several distributed processes.

Parameters
  • tensor (Tensor) tensor of shape (batch, ...)
  • group (any, optional) the process group to gather results from. Defaults to all processes (world)
  • sync_grads (bool, optional) flag that allows users to synchronize gradients for all_gather op

Return: A tensor of shape (world_size, batch, ...)
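The shape contract can be illustrated without a process group by simulating each process's tensor as a nested list. simulate_all_gather is hypothetical: it only mimics the stacking along a new leading world_size dimension. The flattening step at the end is one common follow-up (an assumption, not part of the API) for getting back a single (world_size * batch, ...) batch.

```python
from typing import List


def simulate_all_gather(per_rank: List[list]) -> List[list]:
    """Hypothetical stand-in: stack each rank's (batch, ...) data along dim 0."""
    if len({len(part) for part in per_rank}) != 1:
        # gathering expects every process to contribute the same shape
        raise ValueError("all ranks must contribute the same batch size")
    return [list(part) for part in per_rank]


# world_size = 3, batch = 2, one feature per sample
per_rank = [[[0.1], [0.2]], [[1.1], [1.2]], [[2.1], [2.2]]]
gathered = simulate_all_gather(per_rank)                 # shape (3, 2, 1)
flat = [sample for part in gathered for sample in part]  # shape (6, 1)
print(len(gathered), len(gathered[0]), len(flat))        # 3 2 6
```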

method

forward(*args, **kwargs)

Same as :meth:torch.nn.Module.forward(); however, in Lightning you want this to define the operations you want to use for prediction (e.g. on a server or as a feature extractor).

Normally you'd call self() from your :meth:training_step method. This makes it easy to write a complex system for training with the outputs you'd want in a prediction setting.

You may also find the :func:~pytorch_lightning.core.decorators.auto_move_data decorator useful when using the module outside Lightning in a production setting.

Parameters
  • *args Whatever you decide to pass into the forward method.
  • **kwargs Keyword arguments are also possible.

Return: Predicted output

Examples

.. code-block:: python

# example if we were using this model as a feature extractor
def forward(self, x):
    feature_maps = self.convnet(x)
    return feature_maps

def training_step(self, batch, batch_idx):
    x, y = batch
    feature_maps = self(x)
    logits = self.classifier(feature_maps)

    # ...
    return loss

# splitting it this way allows the model to be used as a feature extractor
model = MyModelAbove()

inputs = server.get_request()
results = model(inputs)
server.write_results(results)

# -------------
# This is in stark contrast to torch.nn.Module where normally you would have this:
def forward(self, batch):
    x, y = batch
    feature_maps = self.convnet(x)
    logits = self.classifier(feature_maps)
    return logits
method

training_step(*args, **kwargs)

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters
  • batch (:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]) The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.
  • batch_idx (int) The index of this batch.
  • optimizer_idx (int) When using multiple optimizers, this argument will also be present.
  • hiddens (:class:~torch.Tensor) Passed in if :paramref:~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps > 0.

Return: Any of.

- :class:`~torch.Tensor` - The loss tensor
- `dict` - A dictionary. Can include any keys, but must include the key 'loss'
- `None` - Training will skip to the next batch

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

.. code-block:: python

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

.. code-block:: python

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {'loss': loss, 'hiddens': hiddens}
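Conceptually, truncated back-propagation through time cuts each sequence into windows of truncated_bptt_steps time steps and feeds the hidden state returned for one window into the next. The plain-Python sketch below shows only that control flow; split_into_windows, run_tbptt, and toy_step are hypothetical, and a real implementation would split tensors along the time dimension and typically detach the hidden state between windows.

```python
from typing import Callable, List, Optional, Tuple


def split_into_windows(sequence: list, steps: int) -> List[list]:
    """Cut a sequence into consecutive windows of at most `steps` items."""
    return [sequence[i:i + steps] for i in range(0, len(sequence), steps)]


def run_tbptt(sequence: list, steps: int,
              step_fn: Callable[[list, Optional[float]], dict]
              ) -> Tuple[List[float], Optional[float]]:
    """Call step_fn once per window, carrying the hidden state forward."""
    hiddens = None  # the first window starts without a hidden state
    losses = []
    for window in split_into_windows(sequence, steps):
        out = step_fn(window, hiddens)
        hiddens = out["hiddens"]
        losses.append(out["loss"])
    return losses, hiddens


def toy_step(window: list, hiddens: Optional[float]) -> dict:
    # toy "RNN": the hidden state is a running sum of everything seen so far
    state = (hiddens or 0.0) + sum(window)
    return {"loss": float(len(window)), "hiddens": state}


losses, final_hidden = run_tbptt([1, 2, 3, 4, 5], steps=2, step_fn=toy_step)
print(losses, final_hidden)  # [2.0, 2.0, 1.0] 15.0
```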

Note

The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
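A running mean over a fixed window is one way to produce such a smoothed value. The sketch below is illustrative: RunningMean is a hypothetical helper, and the window size of 3 is an arbitrary choice, not Lightning's setting.

```python
from collections import deque


class RunningMean:
    """Average of the most recent `window` values (hypothetical helper)."""

    def __init__(self, window: int):
        self.values = deque(maxlen=window)  # old values drop off automatically

    def append(self, value: float) -> None:
        self.values.append(value)

    def mean(self) -> float:
        return sum(self.values) / len(self.values)


smoothed = RunningMean(window=3)
for loss in [5.0, 4.0, 3.0, 2.0, 1.0]:
    smoothed.append(loss)
print(smoothed.mean())  # 2.0, while the last returned loss is 1.0
```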

method

training_step_end(*args, **kwargs)

Use this when training with dp or ddp2 because :meth:training_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.

Note

If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code.

.. code-block:: python

# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(batch_parts_outputs)
Parameters
  • batch_parts_outputs What you return in training_step for each batch part.

Return: Anything

When using dp/ddp2 distributed backends, only a portion of the batch is inside the training_step:

.. code-block:: python

def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)

    # softmax uses only a portion of the batch in the denominator
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return loss

If you wish to do something with all the parts of the batch, then use this method to do it:

.. code-block:: python

def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    return {'pred': out}

def training_step_end(self, training_step_outputs):
    gpu_0_pred = training_step_outputs[0]['pred']
    gpu_1_pred = training_step_outputs[1]['pred']
    gpu_n_pred = training_step_outputs[n]['pred']

    # this softmax now uses the full batch
    loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred])
    return loss
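Why the full batch matters: a softmax normalizes over everything it sees, so computing it separately on each batch part gives different probabilities than computing it once over the whole batch. A plain-Python illustration with arbitrary scores, where a two-way split stands in for two GPUs:

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a flat list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0, 4.0]      # scores for the full batch
parts = [scores[:2], scores[2:]]   # what each of two GPUs would see

per_part = [p for part in parts for p in softmax(part)]  # normalized per GPU
full = softmax(scores)                                   # normalized once

print(round(sum(per_part), 6), round(sum(full), 6))  # 2.0 1.0
```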
See Also

See the :ref:multi_gpu guide for more details.

method

training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs for every training_step.

.. code-block:: python

# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
    out = training_step(train_batch)
    train_outs.append(out)
training_epoch_end(train_outs)
Parameters
  • outputs (list of any) List of outputs you defined in :meth:training_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.

Return: None

Note

If this method is not overridden, this won't be called.

Example::

def training_epoch_end(self, training_step_outputs):
    # do something with all training_step outputs
    for out in training_step_outputs:
        ...

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each training step for that dataloader.

.. code-block:: python

def training_epoch_end(self, training_step_outputs):
    # outer list: one entry per dataloader
    for dataloader_outputs in training_step_outputs:
        for out in dataloader_outputs:
            # do something here
            ...
method

validation_step(*args, **kwargs)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

.. code-block:: python

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
Parameters
  • batch (:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]) The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.
  • batch_idx (int) The index of this batch
  • dataloader_idx (int) The index of the dataloader that produced this batch (only if multiple val datasets used)

Return: Any of.

- Any object or value
- `None` - Validation will skip to the next batch

.. code-block:: python

# pseudocode of order
out = validation_step()
if defined('validation_step_end'):
    out = validation_step_end(out)
out = validation_epoch_end(out)

.. code-block:: python

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx): ...
Examples

.. code-block:: python

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val datasets, validation_step will have an additional argument.

.. code-block:: python

# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_idx, dataloader_idx):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don't need to validate, you don't need to implement this method.

Note

When the :meth:validation_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

method

validation_step_end(*args, **kwargs)

Use this when validating with dp or ddp2 because :meth:validation_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.

Note

If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code.

.. code-block:: python

# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
validation_step_end(batch_parts_outputs)
Parameters
  • batch_parts_outputs What you return in :meth:validation_step for each batch part.

Return: None or anything

.. code-block:: python

# WITHOUT validation_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    self.log('val_loss', loss)

# --------------
# with validation_step_end to do softmax over the full batch
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    return out

def validation_step_end(self, batch_parts_outputs):
    # the outputs of all batch parts are available here,
    # so the loss is computed over the full batch
    loss = nce_loss(batch_parts_outputs)
    self.log('val_loss', loss)
See Also

See the :ref:multi_gpu guide for more details.

method

validation_epoch_end(outputs)

Called at the end of the validation epoch with the outputs of all validation steps.

.. code-block:: python

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
Parameters
  • outputs (list of any) List of outputs you defined in :meth:validation_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.

Return: None

Note

If you didn't define a :meth:validation_step, this won't be called.

Examples

With a single dataloader:

.. code-block:: python

def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        # do something
        ...

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.

.. code-block:: python

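def validation_epoch_end(self, outputs):
    final_value = 0
    # outer list: one entry per dataloader
    for dataloader_outputs in outputs:
        for val_step_out in dataloader_outputs:
            # do something, e.g. accumulate a value returned by validation_step
            final_value += val_step_out

    self.log('final_metric', final_value)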
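For the multiple-dataloader case, the nesting can also be handled with one reduction per inner list. A plain-Python sketch, assuming each validation_step returned a scalar loss (your step may return anything); per_dataloader_means and the key format are hypothetical.

```python
from statistics import mean
from typing import Dict, List


def per_dataloader_means(outputs: List[List[float]]) -> Dict[str, float]:
    """Reduce each dataloader's step outputs to a single value."""
    # outer list: one entry per dataloader; inner list: one entry per step
    return {f"val_loss/dataloader_{idx}": mean(step_outputs)
            for idx, step_outputs in enumerate(outputs)}


outputs = [[0.5, 0.25, 0.75], [1.0, 2.0]]  # two val dataloaders
print(per_dataloader_means(outputs))
# {'val_loss/dataloader_0': 0.5, 'val_loss/dataloader_1': 1.5}
```

Inside validation_epoch_end, each resulting pair could then be passed to self.log, or all of them at once to self.log_dict.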
method

test_step(*args, **kwargs)

Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.

.. code-block:: python

# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)
Parameters
  • batch (:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]) The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.
  • batch_idx (int) The index of this batch.
  • dataloader_idx (int) The index of the dataloader that produced this batch (only if multiple test datasets used).

Return: Any of.

- Any object or value
- `None` - Testing will skip to the next batch

.. code-block:: python

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx): ...
Examples

.. code-block:: python

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test datasets, :meth:test_step will have an additional argument.

.. code-block:: python

# CASE 2: multiple test datasets
def test_step(self, batch, batch_idx, dataloader_idx):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don't need to test, you don't need to implement this method.

Note

When the :meth:test_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

method

test_step_end(*args, **kwargs)

Use this when testing with dp or ddp2 because :meth:test_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.

Note

If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code.

.. code-block:: python

# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
test_step_end(batch_parts_outputs)
Parameters
  • batch_parts_outputs What you return in :meth:test_step for each batch part.

Return: None or anything

.. code-block:: python

# WITHOUT test_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    loss = self.softmax(out)
    self.log('test_loss', loss)

# --------------
# with test_step_end to do softmax over the full batch
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    return out

def test_step_end(self, output_results):
    # output_results now holds the outputs for the full batch
    loss = nce_loss(output_results)
    self.log('test_loss', loss)
See Also

See the :ref:multi_gpu guide for more details.

method

test_epoch_end(outputs)

Called at the end of a test epoch with the output of all test steps.

.. code-block:: python

# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)
Parameters
  • outputs (list of any) List of outputs you defined in :meth:test_step_end, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader

Return: None

Note

If you didn't define a :meth:test_step, this won't be called.

Examples

With a single dataloader:

.. code-block:: python

def test_epoch_end(self, outputs):
    # outputs holds one entry per test batch
    all_test_preds = outputs

    some_result = calc_all_results(all_test_preds)
    self.log('some_result', some_result)

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each test step for that dataloader.

.. code-block:: python

def test_epoch_end(self, outputs):
    final_value = 0
    for dataloader_outputs in outputs:
        for test_step_out in dataloader_outputs:
            # do something
            final_value += test_step_out

    self.log('final_metric', final_value)
method

manual_backward(loss, optimizer, *args, **kwargs)

Call this directly from your training_step when doing optimization manually. Using it ensures that all the proper scaling (for example, for 16-bit precision) has been applied for you.

This function forwards all args to the .backward() call as well.

.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set

.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set and you use optimizer.step()

Example::

def training_step(self, batch, batch_idx):
    (opt_a, opt_b) = self.optimizers()
    loss = ...
    # automatically applies scaling, etc...
    self.manual_backward(loss, opt_a)
    opt_a.step()
method

backward(loss, optimizer, optimizer_idx, *args, **kwargs)

Override backward with your own implementation if you need to.

Parameters
  • loss (Tensor) Loss is already scaled by accumulated grads
  • optimizer (Optimizer) Current optimizer being used
  • optimizer_idx (int) Index of the current optimizer being used

Called to perform backward step. Feel free to override as needed. The loss passed in has already been scaled for accumulated gradients if requested.

Example::

def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()
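The statement that the loss passed in is already scaled for accumulated gradients can be checked with a one-parameter model: dividing each batch loss by the number of accumulated batches and summing the resulting gradients gives the gradient of the mean loss. The model `loss = (w * x - y) ** 2` and the helper names are assumptions for illustration only.

```python
import math


def grad_squared_error(w, x, y):
    """Derivative with respect to w of the hypothetical loss (w * x - y) ** 2."""
    return 2.0 * (w * x - y) * x


def accumulated_grad(w, batches, accumulate_grad_batches):
    # each backward() adds the gradient of loss / accumulate_grad_batches
    grad = 0.0
    for x, y in batches[:accumulate_grad_batches]:
        grad += grad_squared_error(w, x, y) / accumulate_grad_batches
    return grad


batches = [(1.0, 2.0), (2.0, 1.0), (3.0, 3.0), (0.5, 1.0)]
w = 0.5
mean_grad = sum(grad_squared_error(w, x, y) for x, y in batches) / len(batches)
print(mean_grad)  # -3.1875
print(math.isclose(accumulated_grad(w, batches, accumulate_grad_batches=4), mean_grad))  # True
```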
method

toggle_optimizer(optimizer, optimizer_idx)

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

.. note:: Only called when using multiple optimizers

Override for your own behavior
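The idea behind toggle_optimizer can be sketched with plain objects: every parameter that does not belong to the current optimizer has requires_grad switched off, so the backward pass cannot leave gradients on it. `Param` and the lists of parameters below are hypothetical stand-ins for torch tensors and optimizers, and this is not Lightning's exact implementation.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Param:
    # hypothetical stand-in for a torch parameter
    name: str
    requires_grad: bool = True


def toggle_optimizer(optimizers, optimizer_idx):
    """Enable grads only for the parameters owned by optimizers[optimizer_idx]."""
    current = {id(p) for p in optimizers[optimizer_idx]}
    for params in optimizers:
        for p in params:
            p.requires_grad = id(p) in current


gen_w, disc_w = Param("generator.weight"), Param("discriminator.weight")
optimizers = [[gen_w], [disc_w]]  # optimizer 0 owns the generator, 1 the discriminator

toggle_optimizer(optimizers, 1)
print(gen_w.requires_grad, disc_w.requires_grad)  # False True
```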

method

optimizer_step(epoch=None, batch_idx=None, optimizer=None, optimizer_idx=None, optimizer_closure=None, on_tpu=None, using_native_amp=None, using_lbfgs=None)

Override this method to adjust the default way the :class:~pytorch_lightning.trainer.trainer.Trainer calls each optimizer. By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer.

.. tip:: With Trainer(enable_pl_optimizer=True), you can use optimizer.step() directly and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.

Warning

If you are overriding this method, make sure that you pass the optimizer_closure parameter to optimizer.step() function as shown in the examples. This ensures that train_step_and_backward_closure is called within :meth:~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch.

Parameters
  • epoch (int, optional) Current epoch
  • batch_idx (int, optional) Index of current batch
  • optimizer (Optimizer, optional) A PyTorch optimizer
  • optimizer_idx (int, optional) If you used multiple optimizers this indexes into that list.
  • optimizer_closure (callable, optional) Closure for all optimizers
  • on_tpu (bool, optional) True if TPU backward is required
  • using_native_amp (bool, optional) True if using native amp
  • using_lbfgs (bool, optional) True if the matching optimizer is lbfgs
Examples

.. code-block:: python

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    # update generator opt every 2 steps
    if optimizer_idx == 0:
        if batch_idx % 2 == 0:
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()

    # update discriminator opt every 4 steps
    if optimizer_idx == 1:
        if batch_idx % 4 == 0:
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()

    # ...
    # add as many optimizers as you want
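To check which optimizer actually steps on which batch under the alternating schedule, the same modulo conditions can be evaluated on their own. `steps_for_batch` is a hypothetical helper that mirrors only the `if` tests of the GAN example.

```python
def steps_for_batch(batch_idx):
    """Hypothetical: optimizer indices that call step() for this batch_idx."""
    stepped = []
    if batch_idx % 2 == 0:  # generator (optimizer_idx == 0) every 2 batches
        stepped.append(0)
    if batch_idx % 4 == 0:  # discriminator (optimizer_idx == 1) every 4 batches
        stepped.append(1)
    return stepped


for batch_idx in range(6):
    print(batch_idx, steps_for_batch(batch_idx))
# 0 [0, 1]
# 1 []
# 2 [0]
# 3 []
# 4 [0, 1]
# 5 []
```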


method

tbptt_split_batch(batch, split_size) → list

When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function.

Parameters
  • batch (Tensor) Current batch
  • split_size (int) The size of the split

Return: List of batch splits. Each split will be passed to :meth:training_step to enable truncated back propagation through time. The default implementation splits root level Tensors and Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.

Examples

.. code-block:: python

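import collections.abc

import torch


def tbptt_split_batch(self, batch, split_size):
    # length of the time dimension (dim=1) of every splittable item
    time_dims = [len(x[0]) for x in batch
                 if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(d == time_dims[0] for d in time_dims), "Batch time dimension length is ambiguous"

    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                split_x = x[:, t:t + split_size]
            elif isinstance(x, collections.abc.Sequence):
                split_x = [sample[t:t + split_size] for sample in x]
            else:
                split_x = x  # leave non-sequence items untouched

            batch_split.append(split_x)

        splits.append(batch_split)

    return splits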

Note

Called in the training loop after :meth:~pytorch_lightning.callbacks.base.Callback.on_batch_start if :paramref:~pytorch_lightning.trainer.Trainer.truncated_bptt_steps > 0. Each returned batch split is passed separately to :meth:training_step.
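A torch-free sketch of the default splitting rule for the simplest case, a batch of equal-length Python sequences with time along dim=1, shows that the last split is shorter than split_size when the sequence length is not a multiple of it. `split_time` is a hypothetical helper.

```python
def split_time(batch, split_size):
    """Hypothetical: split each sample of each batch item along time (dim=1)."""
    time_len = len(batch[0][0])  # assumes every sample has the same length
    splits = []
    for t in range(0, time_len, split_size):
        splits.append([[sample[t:t + split_size] for sample in item] for item in batch])
    return splits


# one batch item with 2 samples, each 5 time steps long
inputs = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
for chunk in split_time([inputs], split_size=2):
    print(chunk)
# [[[1, 2], [6, 7]]]
# [[[3, 4], [8, 9]]]
# [[[5], [10]]]
```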

method

freeze()

Freeze all params for inference.

Example

.. code-block:: python

model = MyLightningModule(...)
model.freeze()
method

unfreeze()

Unfreeze all parameters for training.

.. code-block:: python

model = MyLightningModule(...)
model.unfreeze()
method

get_progress_bar_dict() → dict(str: int or str)

Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) and the version of the experiment when using a logger.

.. code-block::

Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

Here is an example of how to override the defaults:

.. code-block:: python

def get_progress_bar_dict(self):
    # don't show the version number
    items = super().get_progress_bar_dict()
    items.pop("v_num", None)
    return items

Return: Dictionary with the items to be displayed in the progress bar.
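Because get_progress_bar_dict returns a plain dictionary, overriding it is ordinary dict manipulation. In the sketch below, `BaseModule` is a hypothetical stand-in that returns default items like those in the progress bar line above; a real override would call the LightningModule method through `super()` in the same way. The added `lr` entry is also illustrative.

```python
class BaseModule:
    # hypothetical stand-in for the LightningModule defaults
    def get_progress_bar_dict(self):
        return {"loss": "4.501", "v_num": 10}


class MyModule(BaseModule):
    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)  # hide the experiment version
        items["lr"] = "1e-03"     # show an extra, hypothetical value
        return items


items = MyModule().get_progress_bar_dict()
print(", ".join(f"{k}={v}" for k, v in items.items()))  # loss=4.501, lr=1e-03
```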

method

save_hyperparameters(*args, frame=None)

Save all model arguments.

Parameters
  • args A single object of dict, Namespace or OmegaConf, or string names of arguments from the class __init__
>>> from argparse import Namespace
>>> class ManuallyArgsModel(LightningModule):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # manually assign arguments
...         self.save_hyperparameters('arg1', 'arg3')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> class AutomaticArgsModel(LightningModule):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # equivalent automatic
...         self.save_hyperparameters()
...     def forward(self, *args, **kwargs):
...         ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> class SingleArgModel(LightningModule):
...     def __init__(self, params):
...         super().__init__()
...         # manually assign single argument
...         self.save_hyperparameters(params)
...     def forward(self, *args, **kwargs):
...         ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
method

to_onnx(file_path, input_sample=None, **kwargs)

Saves the model in ONNX format

Parameters
  • file_path (str or Path) The path of the file the onnx model should be saved to.
  • input_sample (any, optional) An input for tracing. Default: None (Use self.example_input_array)
  • **kwargs Will be passed to torch.onnx.export function.
Example
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
>>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
...     model = SimpleModel()
...     input_sample = torch.randn((1, 64))
...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
...     os.path.isfile(tmpfile.name)
True
method

to_torchscript(file_path=None, method='script', example_inputs=None, **kwargs) → Union(scriptmodule, dict(str: scriptmodule))

By default compiles the whole model to a :class:~torch.jit.ScriptModule. If you want to use tracing, please provide the argument method='trace' and make sure that either the example_inputs argument is provided, or the model has self.example_input_array set. If you would like to customize the modules that are scripted, you should override this method. In case you want to return multiple modules, we recommend using a dictionary.

Parameters
  • file_path (str, Path, or NoneType, optional) Path where to save the torchscript. Default: None (no file saved).
  • method (str, optional) Whether to use TorchScript's script or trace method. Default: 'script'
  • example_inputs (any, optional) An input to be used to do tracing when method is set to 'trace'. Default: None (Use self.example_input_array)
  • **kwargs Additional arguments that will be passed to the :func:torch.jit.script or :func:torch.jit.trace function.

Note

  • Requires the implementation of the :meth:~pytorch_lightning.core.lightning.LightningModule.forward method.
  • The exported script will be set to evaluation mode.
  • It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. See also the :mod:torch.jit documentation for supported features.
Example
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> model = SimpleModel()
>>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
>>> os.path.isfile("model.pt")  # doctest: +SKIP
True
>>> _ = model.to_torchscript(file_path="model_trace.pt", method='trace',  # doctest: +SKIP
...                          example_inputs=torch.randn(1, 64))  # doctest: +SKIP
>>> os.path.isfile("model_trace.pt")  # doctest: +SKIP
True

Return: This LightningModule as a torchscript, regardless of whether file_path is defined or not.

method

on_epoch_end()

Keeps the epoch progress bar on screen by printing a newline at the end of each epoch. This hook is not documented in Lightning, but it works.
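Extending on_epoch_end while keeping the newline means calling `super().on_epoch_end()` first. In the sketch below, `PlkitLikeModule` is a hypothetical stand-in for plkit.module.Module so the snippet runs without Lightning installed, and the epoch counter is illustrative bookkeeping.

```python
import contextlib
import io


class PlkitLikeModule:
    # hypothetical stand-in for plkit.module.Module
    def on_epoch_end(self):
        # keep the finished epoch's progress bar by moving to a new line
        print()


class MyModule(PlkitLikeModule):
    def __init__(self):
        self.epochs_done = 0

    def on_epoch_end(self):
        super().on_epoch_end()  # keep the newline behaviour
        self.epochs_done += 1   # then do your own work


module = MyModule()
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    module.on_epoch_end()
print(repr(buffer.getvalue()), module.epochs_done)  # '\n' 1
```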

method

loss_function(logits, labels)

Calculate the loss
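The loss is inferred from the config: MSELoss for regression (num_classes of 1) and CrossEntropyLoss for classification. The torch-free sketch below reproduces the arithmetic of those two losses with their default mean reduction. `infer_loss` and the two loss helpers are hypothetical, and the exact config keys plkit reads are not modeled here.

```python
import math


def mse_loss(logits, labels):
    """Mean squared error, like torch.nn.MSELoss() with reduction='mean'.

    Assumes one output per sample (num_classes == 1).
    """
    return sum((row[0] - y) ** 2 for row, y in zip(logits, labels)) / len(labels)


def cross_entropy_loss(logits, labels):
    """Mean of -log softmax(logits)[label], like torch.nn.CrossEntropyLoss()."""
    total = 0.0
    for row, y in zip(logits, labels):
        log_sum_exp = math.log(sum(math.exp(v) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


def infer_loss(num_classes):
    # hypothetical rule: one output means regression, more means classification
    return mse_loss if num_classes == 1 else cross_entropy_loss


print(infer_loss(1)([[1.0], [3.0]], [0.0, 3.0]))   # 0.5
print(round(infer_loss(2)([[0.0, 0.0]], [0]), 4))  # 0.6931
```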

method

configure_optimizers()

Configure the optimizers