plkit.trainer
Wrapper of the Trainer class
ProgressBar — Aligns the epoch number in the progress bar.
Trainer — The Trainer class.
plkit.trainer.
ProgressBar
(
refresh_rate=1
, process_position=0
)
Aligns the epoch number in the progress bar.
test_batch_idx
(int) — The current batch index being processed during testing. Use this to update your progress bar.
total_test_batches
(int) — The total number of test batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size.
total_train_batches
(int) — The total number of training batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size.
total_val_batches
(int) — The total number of validation batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size.
train_batch_idx
(int) — The current batch index being processed during training. Use this to update your progress bar.
val_batch_idx
(int) — The current batch index being processed during validation. Use this to update your progress bar.
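Because the total_*_batches attributes can return inf, a progress bar needs either a finite total or no total at all. A minimal sketch, assuming a hypothetical helper bar_total (not part of plkit or Lightning) that maps unknown, infinite, or NaN counts to None, which tqdm accepts as "total unknown":

```python
import math
from typing import Optional, Union


def bar_total(num_batches: Union[int, float, None]) -> Optional[int]:
    """Hypothetical helper: turn a batch count into a progress-bar total.

    Returns None for unknown, infinite, or NaN counts so that a tqdm-style
    bar shows a running count instead of an impossible percentage.
    """
    if num_batches is None or math.isinf(num_batches) or math.isnan(num_batches):
        return None
    return int(num_batches)


print(bar_total(float("inf")))  # None
print(bar_total(10))            # 10
```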
disable() — You should provide a way to disable the progress bar. The Trainer (pytorch_lightning.trainer.trainer.Trainer) calls this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.
enable() — You should provide a way to enable the progress bar. The Trainer (pytorch_lightning.trainer.trainer.Trainer) calls this in pre-training routines such as the learning rate finder to temporarily enable and disable the main progress bar.
init_sanity_tqdm() → tqdm — Override this to customize the tqdm bar for the validation sanity run.
init_test_tqdm() → tqdm — Override this to customize the tqdm bar for testing.
init_train_tqdm() → tqdm — Override this to customize the tqdm bar for training.
init_validation_tqdm() → tqdm — Override this to customize the tqdm bar for validation.
on_after_backward(trainer, pl_module) — Called after loss.backward() and before optimizers do anything.
on_batch_end(trainer, pl_module) — Called when the training batch ends.
on_batch_start(trainer, pl_module) — Called when the training batch begins.
on_before_zero_grad(trainer, pl_module, optimizer) — Called after optimizer.step() and before optimizer.zero_grad().
on_epoch_end(trainer, pl_module) — Called when the epoch ends.
on_epoch_start(trainer, pl_module) — Tries to align the epoch number in the progress bar.
on_fit_end(trainer, pl_module) — Called when fit ends.
on_fit_start(trainer, pl_module) — Called when fit begins.
on_init_end(trainer) — Called when the trainer initialization ends; the model has not yet been set.
on_init_start(trainer) — Called when the trainer initialization begins; the model has not yet been set.
on_keyboard_interrupt(trainer, pl_module) — Called when the training is interrupted by KeyboardInterrupt.
on_load_checkpoint(checkpointed_state) — Called when loading a model checkpoint; use this to reload state.
on_pretrain_routine_end(trainer, pl_module) — Called when the pretrain routine ends.
on_pretrain_routine_start(trainer, pl_module) — Called when the pretrain routine begins.
on_sanity_check_end(trainer, pl_module) — Called when the validation sanity check ends.
on_sanity_check_start(trainer, pl_module) — Called when the validation sanity check starts.
on_save_checkpoint(trainer, pl_module) — Called when saving a model checkpoint; use this to persist state.
on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) — Called when the test batch ends.
on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) — Called when the test batch begins.
on_test_end(trainer, pl_module) — Called when the test ends.
on_test_epoch_end(trainer, pl_module) — Called when the test epoch ends.
on_test_epoch_start(trainer, pl_module) — Called when the test epoch begins.
on_test_start(trainer, pl_module) — Called when the test begins.
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) — Called when the train batch ends.
on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) — Called when the train batch begins.
on_train_end(trainer, pl_module) — Called when training ends.
on_train_epoch_end(trainer, pl_module, outputs) — Called when the train epoch ends.
on_train_epoch_start(trainer, pl_module) — Called when the train epoch begins.
on_train_start(trainer, pl_module) — Called when training begins.
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) — Called when the validation batch ends.
on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) — Called when the validation batch begins.
on_validation_end(trainer, pl_module) — Called when the validation loop ends.
on_validation_epoch_end(trainer, pl_module) — Called when the validation epoch ends.
on_validation_epoch_start(trainer, pl_module) — Called when the validation epoch begins.
on_validation_start(trainer, pl_module) — Called when the validation loop begins.
setup(trainer, pl_module, stage) — Called when fit or test begins.
teardown(trainer, pl_module, stage) — Called when fit or test ends.
setup
(
trainer
, pl_module
, stage
)
Called when fit or test begins.
teardown
(
trainer
, pl_module
, stage
)
Called when fit or test ends.
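The hooks in this class are plain methods that the trainer calls at fixed points. To show the idea without depending on Lightning, here is a hypothetical, heavily simplified single-epoch loop (RecordingCallback and run_fit are illustrations, not plkit or Lightning APIs). The real Trainer calls many more hooks, and their exact order differs between Lightning versions.

```python
from typing import List


class RecordingCallback:
    """Hypothetical callback that records which hooks were called, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def setup(self, trainer, pl_module, stage):
        self.calls.append(f"setup({stage})")

    def on_fit_start(self, trainer, pl_module):
        self.calls.append("on_fit_start")

    def on_train_epoch_start(self, trainer, pl_module):
        self.calls.append("on_train_epoch_start")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.calls.append(f"on_train_batch_start[{batch_idx}]")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.calls.append(f"on_train_batch_end[{batch_idx}]")

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        self.calls.append("on_train_epoch_end")

    def on_fit_end(self, trainer, pl_module):
        self.calls.append("on_fit_end")

    def teardown(self, trainer, pl_module, stage):
        self.calls.append(f"teardown({stage})")


def run_fit(callback, batches, trainer=None, pl_module=None):
    """Hypothetical one-epoch 'fit' loop that fires the hooks above in order."""
    callback.setup(trainer, pl_module, stage="fit")
    callback.on_fit_start(trainer, pl_module)
    callback.on_train_epoch_start(trainer, pl_module)
    outputs = []
    for batch_idx, batch in enumerate(batches):
        callback.on_train_batch_start(trainer, pl_module, batch, batch_idx, 0)
        out = sum(batch)  # stand-in for a real training step
        outputs.append(out)
        callback.on_train_batch_end(trainer, pl_module, out, batch, batch_idx, 0)
    callback.on_train_epoch_end(trainer, pl_module, outputs)
    callback.on_fit_end(trainer, pl_module)
    callback.teardown(trainer, pl_module, stage="fit")


cb = RecordingCallback()
run_fit(cb, batches=[[1, 2], [3, 4]])
print(cb.calls)
```

A callback only needs to define the hooks it cares about; in Lightning, the base Callback class supplies no-op defaults for the rest.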
on_init_start
(
trainer
)
Called when the trainer initialization begins; the model has not yet been set.
on_fit_start
(
trainer
, pl_module
)
Called when fit begins.
on_fit_end
(
trainer
, pl_module
)
Called when fit ends.
on_train_batch_start
(
trainer
, pl_module
, batch
, batch_idx
, dataloader_idx
)
Called when the train batch begins.
on_train_epoch_start
(
trainer
, pl_module
)
Called when the train epoch begins.
on_train_epoch_end
(
trainer
, pl_module
, outputs
)
Called when the train epoch ends.
on_validation_epoch_start
(
trainer
, pl_module
)
Called when the validation epoch begins.
on_validation_epoch_end
(
trainer
, pl_module
)
Called when the validation epoch ends.
on_test_epoch_start
(
trainer
, pl_module
)
Called when the test epoch begins.
on_test_epoch_end
(
trainer
, pl_module
)
Called when the test epoch ends.
on_epoch_end
(
trainer
, pl_module
)
Called when the epoch ends.
on_batch_start
(
trainer
, pl_module
)
Called when the training batch begins.
on_validation_batch_start
(
trainer
, pl_module
, batch
, batch_idx
, dataloader_idx
)
Called when the validation batch begins.
on_test_batch_start
(
trainer
, pl_module
, batch
, batch_idx
, dataloader_idx
)
Called when the test batch begins.
on_batch_end
(
trainer
, pl_module
)
Called when the training batch ends.
on_pretrain_routine_start
(
trainer
, pl_module
)
Called when the pretrain routine begins.
on_pretrain_routine_end
(
trainer
, pl_module
)
Called when the pretrain routine ends.
on_keyboard_interrupt
(
trainer
, pl_module
)
Called when the training is interrupted by KeyboardInterrupt.
on_save_checkpoint
(
trainer
, pl_module
)
Called when saving a model checkpoint; use this to persist state.
on_load_checkpoint
(
checkpointed_state
)
Called when loading a model checkpoint; use this to reload state.
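on_save_checkpoint and on_load_checkpoint form a pair: the state produced when saving is later handed back as checkpointed_state. A minimal sketch of that round trip, assuming (as with Lightning 1.x callbacks) that the save hook returns a dictionary. CounterCallback is hypothetical, and the JSON string stands in for a real checkpoint file.

```python
import json


class CounterCallback:
    """Hypothetical callback that counts finished epochs and survives restarts."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def on_epoch_end(self, trainer, pl_module):
        self.epochs_seen += 1

    def on_save_checkpoint(self, trainer, pl_module):
        # Return the state that should be stored alongside the checkpoint.
        return {"epochs_seen": self.epochs_seen}

    def on_load_checkpoint(self, checkpointed_state):
        # Receives the dictionary previously returned by on_save_checkpoint.
        self.epochs_seen = checkpointed_state["epochs_seen"]


original = CounterCallback()
for _ in range(3):
    original.on_epoch_end(trainer=None, pl_module=None)

# Simulate writing the checkpoint to disk and reading it back later.
saved = json.dumps(original.on_save_checkpoint(trainer=None, pl_module=None))

restored = CounterCallback()
restored.on_load_checkpoint(json.loads(saved))
print(restored.epochs_seen)  # 3
```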
on_after_backward
(
trainer
, pl_module
)
Called after loss.backward() and before optimizers do anything.
on_before_zero_grad
(
trainer
, pl_module
, optimizer
)
Called after optimizer.step() and before optimizer.zero_grad().
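on_after_backward and on_before_zero_grad bracket the optimizer step. The stdlib-only sketch below mimics one optimization step on a toy scalar model with loss = param ** 2 (ToyModule, ToyOptimizer, GradInspector, and optimization_step are all hypothetical) to show what each hook can observe: after backward the gradient exists but the parameter is unchanged, and before zero_grad the parameter is updated while the gradient is still set.

```python
from typing import List, Tuple


class ToyModule:
    """Hypothetical one-parameter 'model' with loss = param ** 2."""

    def __init__(self) -> None:
        self.param = 1.0
        self.grad = 0.0

    def backward(self) -> None:
        self.grad = 2 * self.param  # d(param ** 2) / d(param)


class ToyOptimizer:
    """Hypothetical plain SGD over ToyModule."""

    def __init__(self, module: ToyModule, lr: float) -> None:
        self.module = module
        self.lr = lr

    def step(self) -> None:
        self.module.param -= self.lr * self.module.grad

    def zero_grad(self) -> None:
        self.module.grad = 0.0


class GradInspector:
    """Hypothetical callback recording (hook, grad, param) at each hook."""

    def __init__(self) -> None:
        self.seen: List[Tuple[str, float, float]] = []

    def on_after_backward(self, trainer, pl_module):
        self.seen.append(("on_after_backward", pl_module.grad, pl_module.param))

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        self.seen.append(("on_before_zero_grad", pl_module.grad, pl_module.param))


def optimization_step(module, optimizer, callback, trainer=None):
    module.backward()
    callback.on_after_backward(trainer, module)               # grads set, params untouched
    optimizer.step()
    callback.on_before_zero_grad(trainer, module, optimizer)  # params updated, grads still set
    optimizer.zero_grad()


module = ToyModule()
opt = ToyOptimizer(module, lr=0.25)
cb = GradInspector()
optimization_step(module, opt, cb)
print(cb.seen)  # [('on_after_backward', 2.0, 1.0), ('on_before_zero_grad', 2.0, 0.5)]
```

This is why on_after_backward is the usual place to inspect or log gradients.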
on_init_end
(
trainer
)
Called when the trainer initialization ends; the model has not yet been set.
disable
(
)
You should provide a way to disable the progress bar.
The Trainer (pytorch_lightning.trainer.trainer.Trainer) calls this to disable the
output on processes that have a rank different from 0, e.g., in multi-node training.
enable
(
)
You should provide a way to enable the progress bar.
The Trainer (pytorch_lightning.trainer.trainer.Trainer) calls this in pre-training
routines such as the learning rate finder to temporarily enable and
disable the main progress bar.
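The disable/enable contract is simple: output stays on for rank 0, other ranks are silenced, and routines such as the learning rate finder can toggle the bar temporarily. A minimal stdlib sketch of that contract. QuietableBar, its is_enabled flag, and make_bar with its global_rank argument are hypothetical, not the Lightning implementation.

```python
import io
from typing import TextIO


class QuietableBar:
    """Hypothetical progress bar honouring enable()/disable()."""

    def __init__(self, stream: TextIO) -> None:
        self.stream = stream
        self.is_enabled = True

    def disable(self) -> None:
        self.is_enabled = False

    def enable(self) -> None:
        self.is_enabled = True

    def update(self, text: str) -> None:
        # A disabled bar silently drops output instead of writing it.
        if self.is_enabled:
            self.stream.write(text + "\n")


def make_bar(global_rank: int, stream: TextIO) -> QuietableBar:
    """Create a bar that is silenced on every rank except 0."""
    bar = QuietableBar(stream)
    if global_rank != 0:
        bar.disable()
    return bar


# Two "processes": only rank 0 produces output.
rank0_out, rank1_out = io.StringIO(), io.StringIO()
make_bar(0, rank0_out).update("Epoch 0")
make_bar(1, rank1_out).update("Epoch 0")

# Temporarily silencing the main bar, e.g. while a tuning routine runs.
bar = make_bar(0, io.StringIO())
bar.disable()
bar.update("hidden")
bar.enable()
bar.update("visible")
print(bar.stream.getvalue())
```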
init_sanity_tqdm
(
)
→ tqdm
Override this to customize the tqdm bar for the validation sanity run.
init_train_tqdm
(
)
→ tqdm
Override this to customize the tqdm bar for training.
init_validation_tqdm
(
)
→ tqdm
Override this to customize the tqdm bar for validation.
init_test_tqdm
(
)
→ tqdm
Override this to customize the tqdm bar for testing.
on_sanity_check_start
(
trainer
, pl_module
)
Called when the validation sanity check starts.
on_sanity_check_end
(
trainer
, pl_module
)
Called when the validation sanity check ends.
on_train_start
(
trainer
, pl_module
)
Called when training begins.
on_train_batch_end
(
trainer
, pl_module
, outputs
, batch
, batch_idx
, dataloader_idx
)
Called when the train batch ends.
on_validation_start
(
trainer
, pl_module
)
Called when the validation loop begins.
on_validation_batch_end
(
trainer
, pl_module
, outputs
, batch
, batch_idx
, dataloader_idx
)
Called when the validation batch ends.
on_validation_end
(
trainer
, pl_module
)
Called when the validation loop ends.
on_train_end
(
trainer
, pl_module
)
Called when training ends.
on_test_start
(
trainer
, pl_module
)
Called when the test begins.
on_test_batch_end
(
trainer
, pl_module
, outputs
, batch
, batch_idx
, dataloader_idx
)
Called when the test batch ends.
on_test_end
(
trainer
, pl_module
)
Called when the test ends.
on_epoch_start
(
trainer
, pl_module
)
Tries to align the epoch number in the progress bar.
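One way to "align" the epoch number is to pad it to the width of the largest possible epoch value, so that the bar's description keeps a constant width and the bar itself does not shift between epochs. A sketch, assuming right-justification against max_epochs; aligned_epoch_description is a hypothetical helper, and plkit's actual formatting may differ. The resulting string could be passed to a tqdm bar's set_description().

```python
def aligned_epoch_description(current_epoch: int, max_epochs: int) -> str:
    """Hypothetical: 'Epoch N' label padded so every epoch has the same width.

    Padding to the width of max_epochs covers both 0-based and 1-based display.
    """
    width = len(str(max_epochs))
    return f"Epoch {str(current_epoch).rjust(width)}"


for epoch in (0, 9, 10, 99):
    print(repr(aligned_epoch_description(epoch, max_epochs=100)))
```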
plkit.trainer.
Trainer
(
*args
, **kwargs
)
The Trainer class.
from_config (aka from_dict) is added as a classmethod to instantiate a trainer
from configuration dictionaries.
checkpoint_callback
(ModelCheckpoint, optional) — The first checkpoint callback in the Trainer.callbacks list, or None if no checkpoint callbacks exist.
checkpoint_callbacks
(list of ModelCheckpoint) — A list of all instances of ModelCheckpoint found in the Trainer.callbacks list.
default_root_dir
(str) — The default location to save artifacts of loggers, checkpoints, etc. It is used as a fallback if the logger or checkpoint callback does not define specific save paths.
disable_validation
(bool) — Whether validation is disabled during training.
enable_validation
(bool) — Whether validation should run during training.
progress_bar_dict
(dict) — The formatted progress bar metrics.
weights_save_path
(str) — The default root location to save weights (checkpoints), e.g., when the ModelCheckpoint callback (pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint) does not define a file path.
available_plugins() — Returns a list of all available plugins that are supported as string arguments to the trainer.
fit(*args, **kwargs) — Train and validate the model.
from_config(config, **kwargs) — Create an instance from a configuration dictionary or CLI arguments. Also available as from_dict.
get_deprecated_arg_names() → list — Returns a list with deprecated Trainer arguments.
on_after_backward() — Called after loss.backward() and before optimizers do anything.
on_batch_end() — Called when the training batch ends.
on_batch_start() — Called when the training batch begins.
on_before_zero_grad(optimizer) — Called after optimizer.step() and before optimizer.zero_grad().
on_epoch_end() — Called when the epoch ends.
on_epoch_start() — Called when the epoch begins.
on_fit_end() — Called when fit ends.
on_fit_start() — Called when fit begins.
on_init_end() — Called when the trainer initialization ends; the model has not yet been set.
on_init_start() — Called when the trainer initialization begins; the model has not yet been set.
on_keyboard_interrupt() — Called when the training is interrupted by KeyboardInterrupt.
on_load_checkpoint(checkpoint) — Called when loading a model checkpoint.
on_pretrain_routine_end(model) — Called when the pretrain routine ends.
on_pretrain_routine_start(model) — Called when the pretrain routine begins.
on_sanity_check_end() — Called when the validation sanity check ends.
on_sanity_check_start() — Called when the validation sanity check starts.
on_save_checkpoint() — Called when saving a model checkpoint.
on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) — Called when the test batch ends.
on_test_batch_start(batch, batch_idx, dataloader_idx) — Called when the test batch begins.
on_test_end() — Called when the test ends.
on_test_epoch_end() — Called when the test epoch ends.
on_test_epoch_start() — Called when the test epoch begins.
on_test_start() — Called when the test begins.
on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) — Called when the training batch ends.
on_train_batch_start(batch, batch_idx, dataloader_idx) — Called when the training batch begins.
on_train_end() — Called when training ends.
on_train_epoch_end(outputs) — Called when the training epoch ends.
on_train_epoch_start() — Called when the training epoch begins.
on_train_start() — Called when training begins.
on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) — Called when the validation batch ends.
on_validation_batch_start(batch, batch_idx, dataloader_idx) — Called when the validation batch begins.
on_validation_end() — Called when the validation loop ends.
on_validation_epoch_end() — Called when the validation epoch ends.
on_validation_epoch_start() — Called when the validation epoch begins.
on_validation_start() — Called when the validation loop begins.
process_dict_result(output, train) — Reduces output according to the training mode.
request_dataloader(dataloader_fx) → DataLoader — Handles downloading data in the GPU or TPU case.
reset_test_dataloader(model) — Resets the test dataloader and determines the number of batches.
reset_train_dataloader(model) — Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.).
reset_val_dataloader(model) — Resets the validation dataloader and determines the number of batches.
setup(model, stage) — Called at the beginning of fit and test.
teardown(stage) — Called at the end of fit and test.
test(*args, **kwargs) — Test the model.
tune(model, train_dataloader, val_dataloaders, datamodule) — Runs routines to tune hyperparameters before training.
abc.
ABCMeta
(
name
, bases
, namespace
, **kwargs
)
Metaclass for defining Abstract Base Classes (ABCs).
Use this metaclass to create an ABC. An ABC can be subclassed directly, and then acts as a mix-in class. You can also register unrelated concrete classes (even built-in classes) and unrelated ABCs as 'virtual subclasses' -- these and their descendants will be considered subclasses of the registering ABC by the built-in issubclass() function, but the registering ABC won't show up in their MRO (Method Resolution Order) nor will method implementations defined by the registering ABC be callable (not even via super()).
__instancecheck__(cls, instance) — Override for isinstance(instance, cls).
__subclasscheck__(cls, subclass) — Override for issubclass(subclass, cls).
register(cls, subclass) — Register a virtual subclass of an ABC.
register
(
cls
, subclass
)
Register a virtual subclass of an ABC.
Returns the subclass, to allow usage as a class decorator.
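Because register returns the subclass, it can be applied as a class decorator. A short example with the standard-library abc module (Shape and Square are made-up names for illustration). Note that the registered class must supply its own methods, since nothing is inherited from the ABC.

```python
from abc import ABC, abstractmethod


class Shape(ABC):
    @abstractmethod
    def area(self) -> float:
        ...


@Shape.register
class Square:
    """Unrelated concrete class registered as a virtual subclass of Shape."""

    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:
        return self.side ** 2


print(issubclass(Square, Shape))       # True
print(isinstance(Square(2.0), Shape))  # True
print(Shape in Square.__mro__)         # False: registration does not touch the MRO
```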
__instancecheck__
(
cls
, instance
)
Override for isinstance(instance, cls).
__subclasscheck__
(
cls
, subclass
)
Override for issubclass(subclass, cls).
reset_train_dataloader
(
model
)
Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.).
model
(LightningModule) — The current LightningModule.
reset_val_dataloader
(
model
)
Resets the validation dataloader and determines the number of batches.
model
(LightningModule) — The current LightningModule.
reset_test_dataloader
(
model
)
Resets the test dataloader and determines the number of batches.
model
(LightningModule) — The current LightningModule.
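The reset_*_dataloader methods determine how many batches a loader yields. For a sized dataset this follows from the dataset length, the batch size, and drop_last, while an iterable without a length is treated as infinite, which is how the inf totals seen by the progress bar arise. A stdlib sketch of that arithmetic; num_batches is a hypothetical helper and ignores Lightning options such as limit_train_batches.

```python
import math
from typing import Iterable, Union


def num_batches(data: Iterable, batch_size: int, drop_last: bool = False) -> Union[int, float]:
    """Hypothetical: number of batches a loader over `data` would yield.

    Sized data gives ceil(len / batch_size), or floor when drop_last=True,
    matching how torch.utils.data.DataLoader batches a map-style dataset.
    Unsized iterables give float('inf').
    """
    try:
        n = len(data)  # type: ignore[arg-type]
    except TypeError:
        return float("inf")  # e.g. a generator or streaming dataset
    return n // batch_size if drop_last else math.ceil(n / batch_size)


print(num_batches(range(10), batch_size=4))                  # 3
print(num_batches(range(10), batch_size=4, drop_last=True))  # 2
print(num_batches((x for x in range(10)), batch_size=4))     # inf
```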
request_dataloader
(
dataloader_fx
)
Handles downloading data in the GPU or TPU case.
dataloader_fx
(callable) — The bound dataloader getter.
Returns: the dataloader.
process_dict_result
(
output
, train=False
)
Reduces output according to the training mode.
Separates the loss from logging and progress bar metrics.
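The core of this separation can be illustrated with the Lightning 0.x/1.0-style result dictionary that carries "loss", "log", and "progress_bar" keys. split_result is hypothetical; the real method also reduces values across devices and handles further keys, so treat this only as a sketch of the split itself.

```python
from typing import Any, Dict, Tuple


def split_result(output: Dict[str, Any]) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
    """Hypothetical: pull the loss, logging metrics and progress-bar metrics apart."""
    loss = output.get("loss")
    log_metrics = dict(output.get("log", {}))
    progress_bar_metrics = dict(output.get("progress_bar", {}))
    return loss, log_metrics, progress_bar_metrics


loss, log_metrics, bar_metrics = split_result(
    {"loss": 0.42, "log": {"train_loss": 0.42}, "progress_bar": {"acc": 0.9}}
)
print(loss, log_metrics, bar_metrics)  # 0.42 {'train_loss': 0.42} {'acc': 0.9}
```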
setup
(
model
, stage
)
Called at the beginning of fit and test.
teardown
(
stage
)
Called at the end of fit and test.
on_init_start
(
)
Called when the trainer initialization begins; the model has not yet been set.
on_init_end
(
)
Called when the trainer initialization ends; the model has not yet been set.
on_fit_start
(
)
Called when fit begins.
on_fit_end
(
)
Called when fit ends.
on_sanity_check_start
(
)
Called when the validation sanity check starts.
on_sanity_check_end
(
)
Called when the validation sanity check ends.
on_train_epoch_start
(
)
Called when the training epoch begins.
on_train_epoch_end
(
outputs
)
Called when the training epoch ends.
on_validation_epoch_start
(
)
Called when the validation epoch begins.
on_validation_epoch_end
(
)
Called when the validation epoch ends.
on_test_epoch_start
(
)
Called when the test epoch begins.
on_test_epoch_end
(
)
Called when the test epoch ends.
on_epoch_start
(
)
Called when the epoch begins.
on_epoch_end
(
)
Called when the epoch ends.
on_train_start
(
)
Called when training begins.
on_train_end
(
)
Called when training ends.
on_pretrain_routine_start
(
model
)
Called when the pretrain routine begins.
on_pretrain_routine_end
(
model
)
Called when the pretrain routine ends.
on_batch_start
(
)
Called when the training batch begins.
on_batch_end
(
)
Called when the training batch ends.
on_train_batch_start
(
batch
, batch_idx
, dataloader_idx
)
Called when the training batch begins.
on_train_batch_end
(
outputs
, batch
, batch_idx
, dataloader_idx
)
Called when the training batch ends.
on_validation_batch_start
(
batch
, batch_idx
, dataloader_idx
)
Called when the validation batch begins.
on_validation_batch_end
(
outputs
, batch
, batch_idx
, dataloader_idx
)
Called when the validation batch ends.
on_test_batch_start
(
batch
, batch_idx
, dataloader_idx
)
Called when the test batch begins.
on_test_batch_end
(
outputs
, batch
, batch_idx
, dataloader_idx
)
Called when the test batch ends.
on_validation_start
(
)
Called when the validation loop begins.
on_validation_end
(
)
Called when the validation loop ends.
on_test_start
(
)
Called when the test begins.
on_test_end
(
)
Called when the test ends.
on_keyboard_interrupt
(
)
Called when the training is interrupted by KeyboardInterrupt.
on_save_checkpoint
(
)
Called when saving a model checkpoint.
on_load_checkpoint
(
checkpoint
)
Called when loading a model checkpoint.
on_after_backward
(
)
Called after loss.backward() and before optimizers do anything.
on_before_zero_grad
(
optimizer
)
Called after optimizer.step() and before optimizer.zero_grad().
get_deprecated_arg_names
(
)
→ list
Returns a list with deprecated Trainer arguments.
tune
(
model
, train_dataloader=None
, val_dataloaders=None
, datamodule=None
)
Runs routines to tune hyperparameters before training.
model
(LightningModule) — Model to tune.
train_dataloader
(DataLoader, optional) — A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
val_dataloaders
(DataLoader, list of DataLoader, or None, optional) — Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.
datamodule
(LightningDataModule, optional) — An instance of LightningDataModule.
available_plugins
(
)
Returns a list of all available plugins that are supported as string arguments to the trainer.
from_config
(
config
, **kwargs
)
Create an instance from a configuration dictionary or CLI arguments.
>>> config = {'my_custom_arg': 'something'}
>>> trainer = Trainer.from_dict(config, logger=False)
config
— The configuration (a dictionary, parser, or namespace) to take arguments from. Only known arguments will be passed to the Trainer.
**kwargs
— Additional keyword arguments that may override ones in the configuration. These must be valid Trainer arguments.
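The rule described above (only known arguments are passed on, and keyword arguments override the configuration) can be sketched with inspect.signature. DemoTrainer and trainer_kwargs_from_config are hypothetical stand-ins; plkit's own implementation may filter arguments differently.

```python
import inspect
from typing import Any, Dict


class DemoTrainer:
    """Hypothetical stand-in for a Trainer with a few keyword arguments."""

    def __init__(self, max_epochs: int = 1, logger: bool = True, gpus: Any = None) -> None:
        self.max_epochs = max_epochs
        self.logger = logger
        self.gpus = gpus


def trainer_kwargs_from_config(cls, config: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Keep only config keys that `cls.__init__` accepts; kwargs take precedence."""
    known = set(inspect.signature(cls.__init__).parameters) - {"self"}
    params = {key: value for key, value in config.items() if key in known}
    params.update(kwargs)  # explicit keyword arguments override the config
    return params


config = {"my_custom_arg": "something", "max_epochs": 5, "logger": True}
params = trainer_kwargs_from_config(DemoTrainer, config, logger=False)
print(params)  # {'max_epochs': 5, 'logger': False}
trainer = DemoTrainer(**params)
```

Unknown keys such as my_custom_arg are dropped silently here, so the same configuration dictionary can also carry settings meant for other components.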
fit
(
*args
, **kwargs
)
Train and validate the model.
test
(
*args
, **kwargs
)
Test the model.