plkit.module

plkit.module.Module(config)

The Module class.
on_epoch_end is added to print a newline so that the progress bar and the
stats on it are kept for each epoch. If you don't want this, just override it with:
>>> def on_epoch_end(self):
>>>     pass
If you have other stuff to do in on_epoch_end, make sure you call:
>>> super().on_epoch_end()
You may or may not need to write loss_function, as it will be inferred
from the config items loss and num_classes. Basically, MSELoss will be
used for regression and CrossEntropyLoss for classification.
measure is added for convenience to get some metrics between logits
and targets.
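For orientation, a minimal subclass might look like the sketch below. It relies only on what is documented here (the config argument, num_classes, and the inferred loss_function); the class name, layer sizes, and data shapes are illustrative, not part of plkit.

.. code-block:: python

    # A minimal sketch, assuming the config keys described above
    # (num_classes, optim, loss); layer sizes are placeholders.
    from torch import nn
    from plkit.module import Module

    class LitClassifier(Module):
        def __init__(self, config):
            super().__init__(config)
            # num_classes comes from the config, as described above
            self.net = nn.Sequential(
                nn.Linear(28 * 28, 128),
                nn.ReLU(),
                nn.Linear(128, self.num_classes),
            )

        def forward(self, x):
            return self.net(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            # loss_function is inferred from the config (CrossEntropyLoss here)
            return self.loss_function(logits, y)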
Attributes

- config — The configuration dictionary
- _loss_func — The loss function
- automatic_optimization (bool) — If False you are responsible for calling .backward, .step, zero_grad.
- current_epoch (int) — The current epoch
- global_step (int) — Total training batches seen across all epochs
- num_classes — Number of classes to predict. 1 for regression
- on_gpu — True if your model is currently running on GPUs. Useful to set flags around the LightningModule for different CPU vs GPU behavior.
- optim — The optimizer name. Currently only adam and sgd are supported. With this you can, but do not need to, write configure_optimizers.

Methods

- add_module(name, module) — Adds a child module to the current module.
- all_gather(tensor, group, sync_grads) — Allows users to call self.all_gather() from the LightningModule, thus making the all_gather operation accelerator agnostic.
- apply(fn) (Module) — Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).
- backward(loss, optimizer, optimizer_idx, *args, **kwargs) — Override backward with your own implementation if you need to.
- buffers(recurse) (torch.Tensor) — Returns an iterator over module buffers.
- children() (Module) — Returns an iterator over immediate children modules.
- configure_optimizers() — Configure the optimizers
- cpu() (Module) — Moves all model parameters and buffers to the CPU.
- cuda(device) (Module) — Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.
- double() (Module) — Casts all floating point parameters and buffers to double datatype.
- eval() (Module) — Sets the module in evaluation mode.
- extra_repr() — Set the extra representation of the module
- float() (Module) — Casts all floating point parameters and buffers to float datatype.
- forward(*args, **kwargs) — Same as :meth:torch.nn.Module.forward(), however in Lightning you want this to define the operations you want to use for prediction (i.e.: on a server or as a feature extractor).
- freeze() — Freeze all params for inference.
- get_progress_bar_dict() (dict(str: int or str)) — Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) and the version of the experiment when using a logger.
- grad_norm(norm_type) (dict(str: float)) — Compute each parameter's gradient's norm and their overall norm.
- half() (Module) — Casts all floating point parameters and buffers to half datatype.
- load_from_checkpoint(checkpoint_path, map_location, hparams_file, strict, **kwargs) — Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under hyper_parameters
- load_state_dict(state_dict, strict) (NamedTuple with missing_keys and unexpected_keys fields) — Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.
- log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, tbptt_reduce_fx, tbptt_pad_token, enable_graph, sync_dist, sync_dist_op, sync_dist_group) — Log a key, value
- log_dict(dictionary, prog_bar, logger, on_step, on_epoch, reduce_fx, tbptt_reduce_fx, tbptt_pad_token, enable_graph, sync_dist, sync_dist_op, sync_dist_group) — Log a dictionary of values at once
- loss_function(logits, labels) — Calculate the loss
- manual_backward(loss, optimizer, *args, **kwargs) — Call this directly from your training_step when doing optimizations manually. By using this we can ensure that all the proper scaling when using 16-bit etc. has been done for you
- modules() (Module) — Returns an iterator over all modules in the network.
- named_buffers(prefix, recurse) (string, torch.Tensor) — Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- named_children() (string, Module) — Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
- named_modules(memo, prefix) (string, Module) — Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- named_parameters(prefix, recurse) (string, Parameter) — Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- on_after_backward() — Called in the training loop after loss.backward() and before optimizers do anything. This is the ideal place to inspect or log gradient information.
- on_before_zero_grad(optimizer) — Called after optimizer.step() and before optimizer.zero_grad().
- on_epoch_end() — Keep the epoch progress bar. This is not documented upstream but working.
- on_epoch_start() — Called in the training loop at the very beginning of the epoch.
- on_fit_end() — Called at the very end of fit. If on DDP it is called on every process
- on_fit_start() — Called at the very beginning of fit. If on DDP it is called on every process
- on_hpc_load(checkpoint) — Hook to do whatever you need right before Slurm manager loads the model.
- on_hpc_save(checkpoint) — Hook to do whatever you need right before Slurm manager saves the model.
- on_load_checkpoint(checkpoint) — Do something with the checkpoint. Gives model a chance to load something before state_dict is restored.
- on_pretrain_routine_end() — Called at the end of the pretrain routine (between fit and train start).
- on_pretrain_routine_start() — Called at the beginning of the pretrain routine (between fit and train start).
- on_save_checkpoint(checkpoint) — Give the model a chance to add something to the checkpoint. state_dict is already there.
- on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) — Called in the test loop after the batch.
- on_test_batch_start(batch, batch_idx, dataloader_idx) — Called in the test loop before anything happens for that batch.
- on_test_epoch_end() — Called in the test loop at the very end of the epoch.
- on_test_epoch_start() — Called in the test loop at the very beginning of the epoch.
- on_test_model_eval() — Sets the model to eval during the test loop
- on_test_model_train() — Sets the model to train during the test loop
- on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) — Called in the training loop after the batch.
- on_train_batch_start(batch, batch_idx, dataloader_idx) — Called in the training loop before anything happens for that batch.
- on_train_end() — Called at the end of training before logger experiment is closed.
- on_train_epoch_end(outputs) — Called in the training loop at the very end of the epoch.
- on_train_epoch_start() — Called in the training loop at the very beginning of the epoch.
- on_train_start() — Called at the beginning of training before sanity check.
- on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) — Called in the validation loop after the batch.
- on_validation_batch_start(batch, batch_idx, dataloader_idx) — Called in the validation loop before anything happens for that batch.
- on_validation_epoch_end() — Called in the validation loop at the very end of the epoch.
- on_validation_epoch_start() — Called in the validation loop at the very beginning of the epoch.
- on_validation_model_eval() — Sets the model to eval during the val loop
- on_validation_model_train() — Sets the model to train during the val loop
- optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs) — Override this method to adjust the default way the :class:~pytorch_lightning.trainer.trainer.Trainer calls each optimizer. By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer.
- parameters(recurse) (Parameter) — Returns an iterator over module parameters.
- prepare_data() — Use this to download and prepare data.
- print(*args, **kwargs) — Prints only from process 0. Use this in any distributed mode to log only once.
- register_backward_hook(hook) — Registers a backward hook on the module.
- register_buffer(name, tensor) — Adds a persistent buffer to the module.
- register_forward_hook(hook) — Registers a forward hook on the module.
- register_forward_pre_hook(hook) — Registers a forward pre-hook on the module.
- register_parameter(name, param) — Adds a parameter to the module.
- requires_grad_(requires_grad) (Module) — Change if autograd should record operations on parameters in this module.
- save_hyperparameters(*args, frame) — Save all model arguments.
- setup(stage) — Called at the beginning of fit and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- state_dict(destination, prefix, keep_vars) (dict) — Returns a dictionary containing a whole state of the module.
- tbptt_split_batch(batch, split_size) (list) — When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function.
- teardown(stage) — Called at the end of fit and test.
- test_dataloader() (Union(dataloader, list of dataloader)) — Implement one or multiple PyTorch DataLoaders for testing.
- test_epoch_end(outputs) — Called at the end of a test epoch with the output of all test steps.
- test_step(*args, **kwargs) — Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
- test_step_end(*args, **kwargs) — Use this when testing with dp or ddp2 because :meth:test_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
- to(*args, **kwargs) (Module) — Moves and/or casts the parameters and buffers.
- to_onnx(file_path, input_sample, **kwargs) — Saves the model in ONNX format
- to_torchscript(file_path, method, example_inputs, **kwargs) (Union(scriptmodule, dict(str: scriptmodule))) — By default compiles the whole model to a :class:~torch.jit.ScriptModule. If you want to use tracing, please provide the argument method='trace' and make sure that either the example_inputs argument is provided, or the model has self.example_input_array set. If you would like to customize the modules that are scripted you should override this method. In case you want to return multiple modules, we recommend using a dictionary.
- toggle_optimizer(optimizer, optimizer_idx) — Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
- train(mode) (Module) — Sets the module in training mode.
- train_dataloader() (DataLoader) — Implement a PyTorch DataLoader for training.
- training_epoch_end(outputs) — Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs for every training_step.
- training_step(*args, **kwargs) — Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- training_step_end(*args, **kwargs) — Use this when training with dp or ddp2 because :meth:training_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
- transfer_batch_to_device(batch, device) (any) — Override this hook if your :class:~torch.utils.data.DataLoader returns tensors wrapped in a custom data structure.
- type(dst_type) (Module) — Casts all parameters and buffers to :attr:dst_type.
- unfreeze() — Unfreeze all parameters for training.
- val_dataloader() (Union(dataloader, list of dataloader)) — Implement one or multiple PyTorch DataLoaders for validation.
- validation_epoch_end(outputs) — Called at the end of the validation epoch with the outputs of all validation steps.
- validation_step(*args, **kwargs) — Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- validation_step_end(*args, **kwargs) — Use this when validating with dp or ddp2 because :meth:validation_step will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
- zero_grad() — Sets gradients of all model parameters to zero.
register_buffer(name, tensor)

Adds a persistent buffer to the module.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's running_mean
is not a parameter, but is part of the persistent state.
Buffers can be accessed as attributes using given names.
name (string) — name of the buffer. The buffer can be accessed from this module using the given name
tensor (Tensor) — buffer to be registered.
Example::
>>> self.register_buffer('running_mean', torch.zeros(num_features))
register_parameter(name, param)

Adds a parameter to the module.
The parameter can be accessed as an attribute using the given name.
name (string) — name of the parameter. The parameter can be accessed from this module using the given name
param (Parameter) — parameter to be added to the module.
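As a quick illustration (not part of the original docs), registering a learnable scalar could look like:
Example::
>>> self.register_parameter('scale', nn.Parameter(torch.ones(1)))
>>> self.scale  # afterwards accessible as an attribute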
add_module(name, module)

Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
name (string) — name of the child module. The child module can be accessed from this module using the given name
module (Module) — child module to be added to the module.
apply(fn)

Applies fn recursively to every submodule (as returned by .children())
as well as self. Typical use includes initializing the parameters of a model
(see also :ref:nn-init-doc).
fn (:class:Module -> None) — function to be applied to each submodule
Returns: self
Example::
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.data.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1., 1.],
[ 1., 1.]])
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1., 1.],
[ 1., 1.]])
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
register_backward_hook(hook)
Registers a backward hook on the module.
The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature::
hook(module, grad_input, grad_output) -> Tensor or None
The :attr:grad_input
and :attr:grad_output
may be tuples if the
module has multiple inputs or outputs. The hook should not modify its
arguments, but it can optionally return a new gradient with respect to
input that will be used in place of :attr:grad_input
in subsequent
computations.
Returns: a :class:torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove()
.. warning ::
The current implementation will not have the presented behavior
for complex :class:`Module` that perform many operations.
In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
contain the gradients for a subset of the inputs and outputs.
For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
directly on a specific input or output to get the required gradients.
register_forward_pre_hook(hook)

Registers a forward pre-hook on the module.
The hook will be called every time before :func:forward is invoked.
It should have the following signature::
hook(module, input) -> None or modified input
The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple).
Returns: a :class:torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove()
register_forward_hook(hook)

Registers a forward hook on the module.
The hook will be called every time after :func:forward has computed an output.
It should have the following signature::
hook(module, input, output) -> None or modified output
The hook can modify the output. It can modify the input in-place but
it will not have effect on forward since this is called after
:func:forward is called.
Returns: a :class:torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove()
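A common use is caching intermediate activations; the sketch below is illustrative and the submodule name (encoder) is hypothetical:
Example::
>>> activations = {}
>>> def save_output(module, input, output):
>>>     activations['encoder'] = output.detach()
>>> handle = model.encoder.register_forward_hook(save_output)
>>> # ... run a forward pass, then inspect activations['encoder']
>>> handle.remove()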
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.
Example::
>>> module.state_dict().keys()
['bias', 'weight']
load_state_dict(state_dict, strict=True)
Copies parameters and buffers from :attr:state_dict
into
this module and its descendants. If :attr:strict
is True
, then
the keys of :attr:state_dict
must exactly match the keys returned
by this module's :meth:~torch.nn.Module.state_dict
function.
state_dict (dict) — a dict containing parameters and persistent buffers.
strict (bool, optional) — whether to strictly enforce that the keys in :attr:state_dict match the keys returned by this module's :meth:~torch.nn.Module.state_dict function. Default: True
Returns: NamedTuple with missing_keys and unexpected_keys fields
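A sketch of typical usage (the weights path is a placeholder):
Example::
>>> state = torch.load('path/to/weights.pt', map_location='cpu')
>>> missing, unexpected = model.load_state_dict(state, strict=False)
>>> print(missing, unexpected)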
parameters(recurse=True)

Returns an iterator over module parameters.
This is typically passed to an optimizer.
recurse (bool) — if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields: module parameter
Example::
>>> for param in model.parameters():
>>>     print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
named_parameters(prefix='', recurse=True)

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
prefix (str) — prefix to prepend to all parameter names.
recurse (bool) — if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields: Tuple containing the name and parameter
Example::
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
buffers(recurse=True)

Returns an iterator over module buffers.
recurse (bool) — if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
Yields: module buffer
Example::
>>> for buf in model.buffers():
>>>     print(type(buf.data), buf.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
named_buffers(prefix='', recurse=True)

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
prefix (str) — prefix to prepend to all buffer names.
recurse (bool) — if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
Yields: Tuple containing the name and buffer
Example::
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
children()

Returns an iterator over immediate children modules.
Yields: a child module
named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
Yields: Tuple containing a name and child module
Example::
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
modules()

Returns an iterator over all modules in the network.
Yields: a module in the network
Note
Duplicate modules are returned only once. In the following
example, l
will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
named_modules(memo=None, prefix='')

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Yields: Tuple of name and module
Note
Duplicate modules are returned only once. In the following
example, l
will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)
0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
train(mode=True)
Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:Dropout
, :class:BatchNorm
,
etc.
mode (bool) — whether to set training mode (True) or evaluation mode (False). Default: True.
Returns: self
eval()
Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:Dropout
, :class:BatchNorm
,
etc.
This is equivalent with :meth:self.train(False) <torch.nn.Module.train>.
Returns: self
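A typical inference pattern combines eval mode with disabled gradients; a small sketch:
Example::
>>> model.eval()
>>> with torch.no_grad():
>>>     preds = model(x)
>>> model.train()  # switch back when you resume training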
requires_grad_(requires_grad=True)
Change if autograd should record operations on parameters in this module.
This method sets the parameters' :attr:requires_grad
attributes
in-place.
This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).
requires_grad (bool) — whether autograd should record operations on parameters in this module. Default: True.
Returns: self
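For example, freezing a feature extractor while finetuning a head could look like this sketch (the submodule names backbone and head are hypothetical):
Example::
>>> model.backbone.requires_grad_(False)  # freeze
>>> model.head.requires_grad_(True)       # keep trainable
>>> optimizer = torch.optim.SGD(model.head.parameters(), lr=1e-3)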
zero_grad()
Sets gradients of all model parameters to zero.
extra_repr()
Set the extra representation of the module
To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.
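A minimal sketch of a custom extra_repr (the attribute names are illustrative):

.. code-block:: python

    def extra_repr(self):
        # shown inside this module's repr(), e.g. MyLayer(in_features=8, out_features=2)
        return f'in_features={self.in_features}, out_features={self.out_features}'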
setup(stage)
Called at the beginning of fit and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
stage
(str) — either 'fit' or 'test'
Example::
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = something_else

    def setup(self, stage):
        data = Load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
teardown(stage)
Called at the end of fit and test.
stage
(str) — either 'fit' or 'test'
on_fit_start()
Called at the very beginning of fit. If on DDP it is called on every process
on_fit_end()
Called at the very end of fit. If on DDP it is called on every process
on_train_start()
Called at the beginning of training before sanity check.
on_train_end()
Called at the end of training before logger experiment is closed.
on_pretrain_routine_start()
Called at the beginning of the pretrain routine (between fit and train start).
- fit
- pretrain_routine start
- pretrain_routine end
- training_start
on_pretrain_routine_end()
Called at the end of the pretrain routine (between fit and train start).
- fit
- pretrain_routine start
- pretrain_routine end
- training_start
on_train_batch_start(batch, batch_idx, dataloader_idx)
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
batch (any) — The batched data as it is returned by the training DataLoader.
batch_idx (int) — the index of the batch
dataloader_idx (int) — the index of the dataloader
on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
Called in the training loop after the batch.
outputs (any) — The outputs of training_step_end(training_step(x))
batch (any) — The batched data as it is returned by the training DataLoader.
batch_idx (int) — the index of the batch
dataloader_idx (int) — the index of the dataloader
on_validation_model_eval()
Sets the model to eval during the val loop
on_validation_model_train()
Sets the model to train during the val loop
on_validation_batch_start(batch, batch_idx, dataloader_idx)
Called in the validation loop before anything happens for that batch.
batch (any) — The batched data as it is returned by the validation DataLoader.
batch_idx (int) — the index of the batch
dataloader_idx (int) — the index of the dataloader
on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
Called in the validation loop after the batch.
outputs (any) — The outputs of validation_step_end(validation_step(x))
batch (any) — The batched data as it is returned by the validation DataLoader.
batch_idx (int) — the index of the batch
dataloader_idx (int) — the index of the dataloader
on_test_batch_start(batch, batch_idx, dataloader_idx)
Called in the test loop before anything happens for that batch.
batch (any) — The batched data as it is returned by the test DataLoader.
batch_idx (int) — the index of the batch
dataloader_idx (int) — the index of the dataloader
on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)
Called in the test loop after the batch.
outputs (any) — The outputs of test_step_end(test_step(x))
batch (any) — The batched data as it is returned by the test DataLoader.
batch_idx (int) — the index of the batch
dataloader_idx (int) — the index of the dataloader
on_test_model_eval()
Sets the model to eval during the test loop
on_test_model_train()
Sets the model to train during the test loop
on_epoch_start()
Called in the training loop at the very beginning of the epoch.
on_train_epoch_start()
Called in the training loop at the very beginning of the epoch.
on_train_epoch_end(outputs)
Called in the training loop at the very end of the epoch.
on_validation_epoch_start()
Called in the validation loop at the very beginning of the epoch.
on_validation_epoch_end()
Called in the validation loop at the very end of the epoch.
on_test_epoch_start()
Called in the test loop at the very beginning of the epoch.
on_test_epoch_end()
Called in the test loop at the very end of the epoch.
on_before_zero_grad(optimizer)
Called after optimizer.step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.
This is where it is called::
for optimizer in optimizers:
    optimizer.step()
    model.on_before_zero_grad(optimizer)  # <---- called here
    optimizer.zero_grad()
optimizer
(Optimizer) — The optimizer for which grads should be zeroed.
on_after_backward()
Called in the training loop after loss.backward() and before optimizers do anything. This is the ideal place to inspect or log gradient information.
Example::
def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        params = self.state_dict()
        for k, v in params.items():
            grads = v
            name = k
            self.logger.experiment.add_histogram(tag=name, values=grads,
                                                 global_step=self.trainer.global_step)
prepare_data()
Use this to download and prepare data.
.. warning:: DO NOT set state to the model (use setup
instead)
since this is NOT called on every GPU in DDP/TPU
Example::
def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()
In DDP prepare_data can be called in two ways (using Trainer(prepare_data_per_node)):
- Once per node. This is the default and is only called on LOCAL_RANK=0.
- Once in total. Only called on GLOBAL_RANK=0.
Example::
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)
# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)
This is called before requesting the dataloaders:
.. code-block:: python
model.prepare_data()
if ddp/tpu: init()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
train_dataloader() → DataLoader
Implement a PyTorch DataLoader for training.
Return:
Single PyTorch :class:~torch.utils.data.DataLoader
.
The dataloader you return will not be called every epoch unless you set
:paramref:~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch
to True
.
For data processing use the following pattern:
- download in :meth:`prepare_data`
- process and split in :meth:`setup`
However, the above are only necessary for distributed processing.
.. warning:: do not assign state in prepare_data
- :meth:~pytorch_lightning.trainer.Trainer.fit
- ...
- :meth:prepare_data
- :meth:setup
- :meth:train_dataloader
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
.. code-block:: python
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader
test_dataloader() → Union(dataloader, list of dataloader)
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be called every epoch unless you set
:paramref:~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch
to True
.
For data processing use the following pattern:
- download in :meth:`prepare_data`
- process and split in :meth:`setup`
However, the above are only necessary for distributed processing.
.. warning:: do not assign state in prepare_data
- :meth:~pytorch_lightning.trainer.Trainer.fit
- ...
- :meth:prepare_data
- :meth:setup
- :meth:train_dataloader
- :meth:val_dataloader
- :meth:test_dataloader
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Return: Single or multiple PyTorch DataLoaders.
.. code-block:: python
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don't need a test dataset and a :meth:test_step
, you don't need to implement
this method.
Note
In the case where you return multiple test dataloaders, the :meth:test_step
will have an argument dataloader_idx
which matches the order here.
val_dataloader() → Union(dataloader, list of dataloader)
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be called every epoch unless you set
:paramref:~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch
to True
.
It's recommended that all data downloads and preparation happen in :meth:prepare_data
.
- :meth:~pytorch_lightning.trainer.Trainer.fit
- ...
- :meth:prepare_data
- :meth:train_dataloader
- :meth:val_dataloader
- :meth:test_dataloader
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Return: Single or multiple PyTorch DataLoaders.
.. code-block:: python
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don't need a validation dataset and a :meth:validation_step
, you don't need to
implement this method.
Note
In the case where you return multiple validation dataloaders, the :meth:validation_step
will have an argument dataloader_idx
which matches the order here.
transfer_batch_to_device(batch, device=None)
Override this hook if your :class:~torch.utils.data.DataLoader
returns tensors
wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- :class:torch.Tensor or anything that implements .to(...)
- :class:list
- :class:dict
- :class:tuple
- :class:torchtext.data.batch.Batch
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
Example::
def transfer_batch_to_device(self, batch, device):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    else:
        batch = super().transfer_batch_to_device(batch, device)
    return batch
batch
(any) — A batch of data that needs to be transferred to a new device.device
(device, optional) — The target device as defined in PyTorch.
A reference to the data on the new device.
Note
This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing).
Note
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:~torch.nn.parallel.DistributedDataParallel
or
:class:~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel
and
override :meth:~pytorch_lightning.core.lightning.LightningModule.configure_ddp
.
- :func:~pytorch_lightning.utilities.apply_func.move_data_to_device
- :func:~pytorch_lightning.utilities.apply_func.apply_to_collection
load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to __init__ in the checkpoint under hyper_parameters.
Any arguments specified through *args and **kwargs will override args stored in hyper_parameters.
checkpoint_path (str or IO) — Path to checkpoint. This can also be a URL, or file-like object
map_location (Union(dict(str: str), str, device, int, callable, nonetype), optional) — If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:torch.load.
hparams_file (str, optional) — Optional path to a .yaml file with hierarchical structure as in this example::
    drop_prob: 0.2
    dataloader:
        batch_size: 32
You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you'd like to use. These will be converted into a :class:~dict and passed into your :class:LightningModule for use.
If your model's hparams argument is :class:~argparse.Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as :class:~dict.
strict (bool, optional) — Whether to strictly enforce that the keys in :attr:checkpoint_path match the keys returned by this module's state dict. Default: True.
kwargs — Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
Return:
:class:LightningModule
with loaded weights and hyperparameters (if available).
.. code-block:: python
# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
on_load_checkpoint(checkpoint)
Do something with the checkpoint.
Gives model a chance to load something before state_dict
is restored.
checkpoint
(dict(str: any)) — A dictionary with variables from the checkpoint.
on_save_checkpoint(checkpoint)
Give the model a chance to add something to the checkpoint.
state_dict
is already there.
checkpoint
(dict(str: any)) — A dictionary in which you can save variables to save in a checkpoint. Contents need to be pickleable.
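For instance, stashing an extra value alongside the weights might look like this sketch (the key and attribute names are arbitrary placeholders):

.. code-block:: python

    def on_save_checkpoint(self, checkpoint):
        # anything pickleable can be added under a custom key
        checkpoint['my_extra_state'] = {'vocab_size': self.vocab_size}

    def on_load_checkpoint(self, checkpoint):
        self.vocab_size = checkpoint['my_extra_state']['vocab_size']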
on_hpc_save(checkpoint)
Hook to do whatever you need right before Slurm manager saves the model.
checkpoint
(dict(str: any)) — A dictionary in which you can save variables to save in a checkpoint. Contents need to be pickleable.
on_hpc_load(checkpoint)
Hook to do whatever you need right before Slurm manager loads the model.
checkpoint
(dict(str: any)) — A dictionary with variables from the checkpoint.
grad_norm(norm_type) → dict(str: float)
Compute each parameter's gradient's norm and their overall norm.
The overall norm is computed over all gradients together, as if they were concatenated into a single vector.
norm_type (float, int, or str) — The type of the used p-norm, cast to float if necessary. Can be 'inf' for infinity norm.
Return: norms: The dictionary of p-norms of each parameter's gradient and a special entry for the total p-norm of the gradients viewed as a single vector.
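A hedged sketch of logging these norms from on_after_backward (the logging frequency is arbitrary):

.. code-block:: python

    def on_after_backward(self):
        if self.trainer.global_step % 100 == 0:
            # 2-norm of each parameter's gradient plus the overall norm
            norms = self.grad_norm(2)
            self.log_dict(norms)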
to(*args, **kwargs)
Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:torch.Tensor.to
, but only accepts
floating point desired :attr:dtype
s. In addition, this method will
only cast the floating point parameters and buffers to :attr:dtype
(if given). The integral parameters and buffers will be moved
:attr:device
, if that is given, but with dtypes unchanged. When
:attr:non_blocking
is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
device — the desired device of the parameters and buffers in this module
dtype — the desired floating point type of the floating point parameters and buffers in this module
tensor — Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
Returns: self
Example::
>>> class ExampleModule(DeviceDtypeModuleMixin):
...     def __init__(self, weight: torch.Tensor):
...         super().__init__()
...         self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.device
device(type='cpu')
>>> module.dtype
torch.float16
cuda(device=None)
Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.
device (int, optional) — if specified, all parameters will be copied to that device
Returns: self
cpu()

Moves all model parameters and buffers to the CPU.
Returns: self
type(dst_type)

Casts all parameters and buffers to :attr:dst_type.
dst_type (type or string) — the desired type
Returns: self
float()

Casts all floating point parameters and buffers to float datatype.
Returns: self
double()

Casts all floating point parameters and buffers to double datatype.
Returns: self
half()

Casts all floating point parameters and buffers to half datatype.
Returns: self
abc.ABCMeta(name, bases, namespace, **kwargs)
Metaclass for defining Abstract Base Classes (ABCs).
Use this metaclass to create an ABC. An ABC can be subclassed directly, and then acts as a mix-in class. You can also register unrelated concrete classes (even built-in classes) and unrelated ABCs as 'virtual subclasses' -- these and their descendants will be considered subclasses of the registering ABC by the built-in issubclass() function, but the registering ABC won't show up in their MRO (Method Resolution Order) nor will method implementations defined by the registering ABC be callable (not even via super()).
__instancecheck__(cls, instance) — Override for isinstance(instance, cls).
__subclasscheck__(cls, subclass) — Override for issubclass(subclass, cls).
register(cls, subclass) — Register a virtual subclass of an ABC.
register(cls, subclass)
Register a virtual subclass of an ABC.
Returns the subclass, to allow usage as a class decorator.
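A short illustration of virtual subclassing with an ABC; this is standard library behaviour, not specific to plkit, and the class names are made up:

.. code-block:: python

    from abc import ABC

    class Serializer(ABC):
        pass

    @Serializer.register
    class JsonSerializer:  # not a real subclass of Serializer
        pass

    # True via the registration, even though there is no inheritance
    assert issubclass(JsonSerializer, Serializer)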
__instancecheck__(cls, instance)
Override for isinstance(instance, cls).
__subclasscheck__(cls, subclass)
Override for issubclass(subclass, cls).
print(*args, **kwargs)

Prints only from process 0. Use this in any distributed mode to log only once.
*args — The thing to print. Will be passed to Python's built-in print function.
**kwargs — Will be passed to Python's built-in print function.
log(name, value, prog_bar=False, logger=True, on_step=None, on_epoch=None, reduce_fx=torch.mean, tbptt_reduce_fx=torch.mean, tbptt_pad_token=0, enable_graph=False, sync_dist=False, sync_dist_op='mean', sync_dist_group=None)
Log a key, value
Example::
self.log('train_loss', loss)
The default behavior per hook is as follows (* also applies to the test loop):

LightningModule Hook    on_step  on_epoch  prog_bar  logger
training_step           T        F         F         T
training_step_end       T        F         F         T
training_epoch_end      F        T         F         T
validation_step         F        T         F         T
validation_step_end     F        T         F         T
validation_epoch_end*   F        T         F         T
name (str) — key name
value (any) — value name
prog_bar (bool, optional) — if True logs to the progress bar
logger (bool, optional) — if True logs to the logger
on_step (bool, optional) — if True logs at this step. None auto-logs at the training_step but not validation/test_step
on_epoch (bool, optional) — if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
reduce_fx (callable, optional) — reduction function over step values for end of epoch. torch.mean by default
tbptt_reduce_fx (callable, optional) — function to reduce on truncated back prop
tbptt_pad_token (int, optional) — token to use for padding
enable_graph (bool, optional) — if True, will not auto detach the graph
sync_dist (bool, optional) — if True, reduces the metric across GPUs/TPUs
sync_dist_op (any or str, optional) — the op to sync across GPUs/TPUs
sync_dist_group (any, optional) — the ddp group
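A few hedged usage patterns; the metric names and values are placeholders:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = ...
        # shown on the progress bar and logged every step
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        acc = ...
        # accumulated and logged once per epoch
        self.log('val_acc', acc, on_step=False, on_epoch=True)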
log_dict(dictionary, prog_bar=False, logger=True, on_step=None, on_epoch=None, reduce_fx=torch.mean, tbptt_reduce_fx=torch.mean, tbptt_pad_token=0, enable_graph=False, sync_dist=False, sync_dist_op='mean', sync_dist_group=None)

Log a dictionary of values at once
Example::
values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
self.log_dict(values)
dictionary (dict) — key value pairs (str, tensors)
prog_bar (bool, optional) — if True logs to the progress bar
logger (bool, optional) — if True logs to the logger
on_step (bool, optional) — if True logs at this step. None auto-logs for training_step but not validation/test_step
on_epoch (bool, optional) — if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
reduce_fx (callable, optional) — reduction function over step values for end of epoch. torch.mean by default
tbptt_reduce_fx (callable, optional) — function to reduce on truncated back prop
tbptt_pad_token (int, optional) — token to use for padding
enable_graph (bool, optional) — if True, will not auto detach the graph
sync_dist (bool, optional) — if True, reduces the metric across GPUs/TPUs
sync_dist_op (any or str, optional) — the op to sync across GPUs/TPUs
sync_dist_group (any, optional) — the ddp group
all_gather(tensor, group=None, sync_grads=False)
Allows users to call self.all_gather()
from the LightningModule, thus making
the all_gather
operation accelerator agnostic.
all_gather
is a function provided by accelerators to gather a tensor from several
distributed processes
tensor (Tensor) — tensor of shape (batch, ...)
group (any, optional) — the process group to gather results from. Defaults to all processes (world)
sync_grads (bool, optional) — flag that allows users to synchronize gradients for the all_gather op
Return: A tensor of shape (world_size, batch, ...)
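A sketch of gathering per-process outputs, for example inside validation_epoch_end; the output keys and shapes are illustrative:

.. code-block:: python

    def validation_epoch_end(self, outputs):
        # concatenate this process's predictions: (batch, ...)
        preds = torch.cat([o['preds'] for o in outputs])
        # gather from every process: (world_size, batch, ...)
        all_preds = self.all_gather(preds)
        if self.trainer.is_global_zero:
            self.print(all_preds.shape)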
forward(*args, **kwargs)
Same as :meth:torch.nn.Module.forward()
, however in Lightning you want this to define
the operations you want to use for prediction (i.e.: on a server or as a feature extractor).
Normally you'd call self()
from your :meth:training_step
method.
This makes it easy to write a complex system for training with the outputs
you'd want in a prediction setting.
You may also find the :func:~pytorch_lightning.core.decorators.auto_move_data
decorator useful
when using the module outside Lightning in a production setting.
*args — Whatever you decide to pass into the forward method.
**kwargs — Keyword arguments are also possible.
Return: Predicted output
.. code-block:: python
# example if we were using this model as a feature extractor
def forward(self, x):
    feature_maps = self.convnet(x)
    return feature_maps

def training_step(self, batch, batch_idx):
    x, y = batch
    feature_maps = self(x)
    logits = self.classifier(feature_maps)

    # ...
    return loss

# splitting it this way allows the model to be used as a feature extractor
model = MyModelAbove()

inputs = server.get_request()
results = model(inputs)
server.write_results(results)

# -------------
# This is in stark contrast to torch.nn.Module where normally you would have this:
def forward(self, batch):
    x, y = batch
    feature_maps = self.convnet(x)
    logits = self.classifier(feature_maps)
    return logits
training_step(*args, **kwargs)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
batch (:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]) — The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.
batch_idx (int) — Integer displaying index of this batch
optimizer_idx (int) — When using multiple optimizers, this argument will also be present.
hiddens (:class:~torch.Tensor) — Passed in if :paramref:~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps > 0.
Return: Any of.
- :class:`~torch.Tensor` - The loss tensor
- `dict` - A dictionary. Can include any keys, but must include the key 'loss'
- `None` - Training will skip to the next batch
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example::
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional
optimizer_idx
parameter.
.. code-block:: python
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
    if optimizer_idx == 1:
        # do training_step with decoder
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
.. code-block:: python
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {'loss': loss, 'hiddens': hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
training_step_end(*args, **kwargs)
Use this when training with dp or ddp2 because :meth:training_step
will operate on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code
.. code-block:: python
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(batch_parts_outputs)
batch_parts_outputs — What you return in training_step for each batch part.
Return: Anything
When using dp/ddp2 distributed backends, only a portion of the batch is inside the training_step:
.. code-block:: python
def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self(x)

    # softmax uses only a portion of the batch in the denominator
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return loss
If you wish to do something with all the parts of the batch, then use this method to do it:
.. code-block:: python
def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self.encoder(x)
    return {'pred': out}

def training_step_end(self, training_step_outputs):
    gpu_0_pred = training_step_outputs[0]['pred']
    gpu_1_pred = training_step_outputs[1]['pred']
    gpu_n_pred = training_step_outputs[n]['pred']

    # this softmax now uses the full batch
    loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred])
    return loss
See the :ref:multi_gpu
guide for more details.
training_epoch_end(outputs)
Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs for every training_step.
.. code-block:: python
# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
    out = training_step(train_batch)
    train_outs.append(out)
training_epoch_end(train_outs)
outputs (list of any) — List of outputs you defined in :meth:training_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
Return: None
Note
If this method is not overridden, this won't be called.
Example::
def training_epoch_end(self, training_step_outputs):
    # do something with all training_step outputs
    return result
With multiple dataloaders, outputs
will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each training step for that dataloader.
.. code-block:: python
def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        # do something here
validation_step(*args, **kwargs)
Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy.
.. code-block:: python
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
batch (:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]) — The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.
batch_idx (int) — The index of this batch
dataloader_idx (int) — The index of the dataloader that produced this batch (only if multiple val datasets used)
Return: Any of.
- Any object or value
- `None` - Validation will skip to the next batch
.. code-block:: python
# pseudocode of order
out = validation_step()
if defined('validation_step_end'):
out = validation_step_end(out)
out = validation_epoch_end(out)
.. code-block:: python
# if you have one val dataloader:
def validation_step(self, batch, batch_idx)
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx)
.. code-block:: python
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val datasets, validation_step will have an additional argument.
.. code-block:: python
# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.
Note
If you don't need to validate you don't need to implement this method.
Note
When the :meth:validation_step
is called, the model has been put in eval mode
and PyTorch gradients have been disabled. At the end of validation,
the model goes back to training mode and gradients are enabled.
validation_step_end(*args, **kwargs)
Use this when validating with dp or ddp2 because :meth:validation_step
will operate on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code.
.. code-block:: python
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
validation_step_end(batch_parts_outputs)
batch_parts_outputs — What you return in :meth:validation_step for each batch part.
Return: None or anything
.. code-block:: python
# WITHOUT validation_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self.encoder(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    self.log('val_loss', loss)

# --------------
# with validation_step_end to do softmax over the full batch
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self(x)
    return out

def validation_step_end(self, val_step_outputs):
    for out in val_step_outputs:
        # do something with these
See the :ref:multi_gpu
guide for more details.
validation_epoch_end(outputs)
Called at the end of the validation epoch with the outputs of all validation steps.
.. code-block:: python
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
outputs (list of any) — List of outputs you defined in :meth:validation_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
Return: None
Note
If you didn't define a :meth:validation_step
, this won't be called.
With a single dataloader:
.. code-block:: python
def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...  # do something
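For instance, a common pattern is to average the per-batch values returned by validation_step. A minimal sketch, assuming each step returned a scalar loss tensor:
.. code-block:: python
    # minimal sketch: average the scalar losses returned by validation_step
    def validation_epoch_end(self, val_step_outputs):
        avg_loss = torch.stack(val_step_outputs).mean()
        self.log('avg_val_loss', avg_loss)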
With multiple dataloaders, outputs
will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
.. code-block:: python
def validation_epoch_end(self, outputs):
    final_value = 0
    for dataloader_outputs in outputs:
        for val_step_out in dataloader_outputs:
            # do something with each validation step output
            final_value += val_step_out

    self.log('final_metric', final_value)
test_step
(
*args
, **kwargs
)
Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
.. code-block:: python
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)
batch
(:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]) — The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.
batch_idx
(int) — The index of this batch.
dataloader_idx
(int) — The index of the dataloader that produced this batch (only if multiple test datasets are used).
Return: Any of the following.
- Any object or value
- `None` - Testing will skip to the next batch
.. code-block:: python
# if you have one test dataloader:
def test_step(self, batch, batch_idx)
# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx)
.. code-block:: python
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test datasets, :meth:test_step
will have an additional argument.
.. code-block:: python
# CASE 2: multiple test datasets
def test_step(self, batch, batch_idx, dataloader_idx):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don't need to test you don't need to implement this method.
Note
When the :meth:test_step
is called, the model has been put in eval mode and
PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
to training mode and gradients are enabled.
test_step_end
(
*args
, **kwargs
)
Use this when testing with dp or ddp2 because :meth:test_step
will operate
on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code.
.. code-block:: python
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
test_step_end(batch_parts_outputs)
batch_parts_outputs
— What you return in :meth:test_step
for each batch part.
Return: None or anything
.. code-block:: python
# WITHOUT test_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self(x)
    loss = self.softmax(out)
    self.log('test_loss', loss)

# --------------
# with test_step_end to do softmax over the full batch
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self.encoder(x)
    return out

def test_step_end(self, output_results):
    # this out is now the full size of the batch
    all_test_step_outs = output_results.out
    loss = nce_loss(all_test_step_outs)
    self.log('test_loss', loss)
See the :ref:multi_gpu
guide for more details.
test_epoch_end
(
outputs
)
Called at the end of a test epoch with the output of all test steps.
.. code-block:: python
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)
outputs
(list of any) — List of outputs you defined in :meth:test_step_end
, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader
Return: None
Note
If you didn't define a :meth:test_step
, this won't be called.
With a single dataloader:
.. code-block:: python
def test_epoch_end(self, outputs):
    # do something with the outputs of all test batches
    all_test_preds = outputs.predictions
    some_result = calc_all_results(all_test_preds)
    self.log('some_result', some_result)
With multiple dataloaders, outputs
will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each test step for that dataloader.
.. code-block:: python
def test_epoch_end(self, outputs):
    final_value = 0
    for dataloader_outputs in outputs:
        for test_step_out in dataloader_outputs:
            # do something
            final_value += test_step_out

    self.log('final_metric', final_value)
manual_backward
(
loss
, optimizer
, *args
, **kwargs
)
Call this directly from your training_step when doing optimizations manually. Using it ensures that all the proper scaling (e.g. for 16-bit precision) has been done for you.
This function forwards all args to the .backward() call as well.
.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set
.. tip:: In manual mode we still automatically accumulate grad over batches if
Trainer(accumulate_grad_batches=x) is set and you use optimizer.step()
Example::
def training_step(...):
    (opt_a, opt_b) = self.optimizers()
    loss = ...
    # automatically applies scaling, etc...
    self.manual_backward(loss, opt_a)
    opt_a.step()
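For a fuller picture, here is a minimal sketch of a manual-optimization training_step with a single optimizer; it assumes automatic optimization has been disabled and that self.encoder and self.loss exist on your module:
.. code-block:: python
    # minimal sketch: single-optimizer manual optimization
    # assumes automatic optimization is disabled and that
    # self.encoder / self.loss are defined on your module
    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = self.loss(self.encoder(x), y)
        opt.zero_grad()
        self.manual_backward(loss, opt)  # applies scaling (AMP, etc.) for you
        opt.step()
        return loss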
backward
(
loss
, optimizer
, optimizer_idx
, *args
, **kwargs
)
Override backward with your own implementation if you need to.
loss
(Tensor) — Loss is already scaled by accumulated grads
optimizer
(Optimizer) — Current optimizer being used
optimizer_idx
(int) — Index of the current optimizer being used
Called to perform the backward step. Feel free to override as needed. The loss passed in has already been scaled for accumulated gradients if requested.
Example::
def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()
toggle_optimizer
(
optimizer
, optimizer_idx
)
Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
.. note:: Only called when using multiple optimizers
Override for your own behavior
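As a rough illustration of what an override could do, the sketch below enables gradients only for the parameters of the optimizer being toggled; the built-in behavior may additionally track the previous requires_grad state, so treat this as a simplified starting point:
.. code-block:: python
    # simplified sketch of a toggle_optimizer override: freeze everything,
    # then re-enable gradients only for the active optimizer's parameters
    def toggle_optimizer(self, optimizer, optimizer_idx):
        for param in self.parameters():
            param.requires_grad = False
        for group in optimizer.param_groups:
            for param in group['params']:
                param.requires_grad = True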
optimizer_step
(
epoch=None
, batch_idx=None
, optimizer=None
, optimizer_idx=None
, optimizer_closure=None
, on_tpu=None
, using_native_amp=None
, using_lbfgs=None
)
Override this method to adjust the default way the
:class:~pytorch_lightning.trainer.trainer.Trainer
calls each optimizer.
By default, Lightning calls step()
and zero_grad()
as shown in the example
once per optimizer.
.. tip:: With Trainer(enable_pl_optimizer=True)
, you can use optimizer.step()
directly and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.
Warning
If you are overriding this method, make sure that you pass the optimizer_closure
parameter
to optimizer.step()
function as shown in the examples. This ensures that
train_step_and_backward_closure
is called within
:meth:~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch
.
epoch
(int, optional) — Current epoch
batch_idx
(int, optional) — Index of current batch
optimizer
(Optimizer, optional) — A PyTorch optimizer
optimizer_idx
(int, optional) — If you used multiple optimizers this indexes into that list
optimizer_closure
(callable, optional) — Closure for all optimizers
on_tpu
(bool, optional) — True if TPU backward is required
using_native_amp
(bool, optional) — True if using native AMP
using_lbfgs
(bool, optional) — True if the matching optimizer is LBFGS
.. code-block:: python
# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
optimizer.step(closure=optimizer_closure)
# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0 :
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()
# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0 :
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()
# ...
# add as many optimizers as you want
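Another common reason to override this hook is learning-rate warm-up. A minimal sketch, assuming your module defines a self.learning_rate attribute and you want 500 warm-up steps:
.. code-block:: python
    # learning rate warm-up (sketch; assumes self.learning_rate exists
    # and 500 warm-up steps are desired)
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # linearly ramp up the learning rate over the first 500 steps
        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.learning_rate

        # update params as usual
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()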
tbptt_split_batch
(
batch
, split_size
)
→ list
When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function.
batch
(Tensor) — Current batch
split_size
(int) — The size of the split
Return:
List of batch splits. Each split will be passed to :meth:training_step
to enable truncated
back propagation through time. The default implementation splits root level Tensors and
Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
.. code-block:: python
def tbptt_split_batch(self, batch, split_size):
    # length of the time dimension (dim=1), taken from the batch itself
    time_dims = [len(x[0]) for x in batch
                 if isinstance(x, (torch.Tensor, collections.Sequence))]
    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for i, x in enumerate(batch):
            if isinstance(x, torch.Tensor):
                split_x = x[:, t:t + split_size]
            elif isinstance(x, collections.Sequence):
                split_x = [None] * len(x)
                for batch_idx in range(len(x)):
                    split_x[batch_idx] = x[batch_idx][t:t + split_size]
            batch_split.append(split_x)
        splits.append(batch_split)
    return splits
Note
Called in the training loop after
:meth:~pytorch_lightning.callbacks.base.Callback.on_batch_start
if :paramref:~pytorch_lightning.trainer.Trainer.truncated_bptt_steps
> 0.
Each returned batch split is passed separately to :meth:training_step
.
freeze
(
)
Freeze all params for inference.
.. code-block:: python
model = MyLightningModule(...)
model.freeze()
unfreeze
(
)
Unfreeze all parameters for training.
.. code-block:: python
model = MyLightningModule(...)
model.unfreeze()
get_progress_bar_dict
(
)
→ dict(str: int or str)
Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) and the version of the experiment when using a logger.
.. code-block::
Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
Here is an example how to override the defaults:
.. code-block:: python
def get_progress_bar_dict(self):
    # don't show the version number
    items = super().get_progress_bar_dict()
    items.pop("v_num", None)
    return items
Return: Dictionary with the items to be displayed in the progress bar.
save_hyperparameters
(
*args
, frame=None
)
Save all model arguments.
args
— single object of dict
, NameSpace
or OmegaConf
, or string names of arguments from the class __init__
>>> from collections import OrderedDict
>>> class ManuallyArgsModel(LightningModule):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # manually assign arguments
...         self.save_hyperparameters('arg1', 'arg3')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> class AutomaticArgsModel(LightningModule):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # equivalent automatic
...         self.save_hyperparameters()
...     def forward(self, *args, **kwargs):
...         ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> class SingleArgModel(LightningModule):
...     def __init__(self, params):
...         super().__init__()
...         # manually assign single argument
...         self.save_hyperparameters(params)
...     def forward(self, *args, **kwargs):
...         ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
to_onnx
(
file_path
, input_sample=None
, **kwargs
)
Saves the model in ONNX format
file_path
(str or Path) — The path of the file the onnx model should be saved to.
input_sample
(any, optional) — An input for tracing. Default: None (use self.example_input_array)
**kwargs
— Will be passed to the torch.onnx.export function.
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
>>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
...     model = SimpleModel()
...     input_sample = torch.randn((1, 64))
...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
...     os.path.isfile(tmpfile.name)
True
to_torchscript
(
file_path=None
, method='script'
, example_inputs=None
, **kwargs
)
→ Union(scriptmodule, dict(str: scriptmodule))
By default compiles the whole model to a :class:~torch.jit.ScriptModule
.
If you want to use tracing, please provide the argument method='trace'
and make sure that either the
example_inputs argument is provided, or the model has self.example_input_array set.
If you would like to customize the modules that are scripted you should override this method.
In case you want to return multiple modules, we recommend using a dictionary.
file_path
(str, Path, or NoneType, optional) — Path where to save the torchscript. Default: None (no file saved).
method
(str, optional) — Whether to use TorchScript's script or trace method. Default: 'script'
example_inputs
(any, optional) — An input to be used to do tracing when method is set to 'trace'. Default: None (use self.example_input_array)
**kwargs
— Additional arguments that will be passed to the :func:torch.jit.script
or :func:torch.jit.trace
function.
Note
- Requires the implementation of the
:meth:~pytorch_lightning.core.lightning.LightningModule.forward
method.
- The exported script will be set to evaluation mode.
- It is recommended that you install the latest supported version of PyTorch
to use this feature without limitations. See also the :mod:torch.jit
documentation for supported features.
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> model = SimpleModel()
>>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP
>>> os.path.isfile("model.pt") # doctest: +SKIP
>>> model.to_torchscript(file_path="model_trace.pt", method='trace',  # doctest: +SKIP
...     example_inputs=torch.randn(1, 64))  # doctest: +SKIP
>>> os.path.isfile("model_trace.pt") # doctest: +SKIP
True
Return: This LightningModule as a torchscript, regardless of whether file_path is defined or not.
on_epoch_end
(
)
Print a newline at the end of each epoch so that the progress bar and its stats are kept for that epoch. This relies on behavior that is not documented upstream but works.
loss_function
(
logits
, labels
)
Calculate the loss given the logits and the labels.
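A minimal sketch of an override, assuming a 3-class classification task where a weighted cross entropy is wanted (the class weights are illustrative):
.. code-block:: python
    # sketch of a custom loss_function override; the class weights are
    # illustrative and assume a 3-class classification task
    def loss_function(self, logits, labels):
        weights = torch.tensor([1.0, 2.0, 0.5], device=logits.device)
        return torch.nn.functional.cross_entropy(logits, labels, weight=weights)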
configure_optimizers
(
)
Configure the optimizers
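Should you need something beyond the built-in optimizer choices, you can override this method as in plain PyTorch Lightning. A minimal sketch; AdamW and the learning rate here are illustrative choices, not defaults of this class:
.. code-block:: python
    # sketch of an override; AdamW and lr=1e-3 are illustrative choices
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)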