Module
Class Module is subclassed from LightningModule of pytorch-lightning. The differences are listed below.
Initialization
Every instance of Module needs a config argument to initialize. The OptunaSuggest objects in it are collapsed, and the result becomes an attribute of the instance, so a configuration item can be accessed as self.config.xxx.
Danger
self.config is a FrozenDiot object, which prevents changes from being passed down to other modules (see diot).
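As a rough illustration of what a frozen configuration means in practice, here is a minimal sketch built on a hypothetical stand-in class, FrozenConfig. It is not diot's FrozenDiot, whose real interface and exception types may differ. It only shows attribute-style reads combined with rejected writes.

```python
class FrozenConfig(dict):
    """Hypothetical stand-in for a frozen, attribute-accessible config.

    This is NOT diot's FrozenDiot; it only mimics the read-only idea.
    """

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails,
        # so fall back to the dictionary keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def _readonly(self, *args, **kwargs):
        raise TypeError("configuration is frozen")

    # Reject every common way of mutating the mapping.
    __setattr__ = __setitem__ = __delitem__ = _readonly
    update = pop = popitem = clear = setdefault = _readonly


config = FrozenConfig(batch_size=32, optim="adam")
print(config.batch_size)  # read access works like self.config.xxx

try:
    config.batch_size = 64  # write access is refused
except TypeError as exc:
    print(f"rejected: {exc}")
```

If a module needs a modified value, one option under this sketch is to build a new plain dict from the frozen one, change the copy, and leave the shared configuration untouched.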
on_epoch_end
An empty line is printed at the end of every epoch to keep the progress bar of the previous epoch on screen, so that the metrics of all epochs are easy to track. To disable this, override on_epoch_end with a pass statement. To keep this behavior and do something else in on_epoch_end, call super().on_epoch_end() inside your own on_epoch_end.
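The two options can be sketched as follows. To keep the example runnable without any dependencies, BaseModule is a hypothetical stand-in that models only the hook described here; in real code you would subclass Module instead. The class names QuietModule and CountingModule and the epochs_seen attribute are also illustrative assumptions.

```python
class BaseModule:
    """Hypothetical stand-in for Module; only on_epoch_end is modelled."""

    def on_epoch_end(self):
        # Print an empty line so the finished epoch's progress bar stays visible.
        print()


class QuietModule(BaseModule):
    """Option 1: disable the empty line entirely."""

    def on_epoch_end(self):
        pass


class CountingModule(BaseModule):
    """Option 2: keep the empty line and do extra work afterwards."""

    def __init__(self):
        self.epochs_seen = 0

    def on_epoch_end(self):
        super().on_epoch_end()  # keeps the default behavior
        self.epochs_seen += 1   # illustrative extra work


module = CountingModule()
module.on_epoch_end()
print(f"epochs seen: {module.epochs_seen}")
```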
loss_function
Calculates the loss using the loss function selected by the loss configuration item.
configure_optimizers
You don't have to implement this function if optim is specified in the configuration.
Note
Currently only adam and sgd are supported. Please also specify learning_rate and/or momentum for the corresponding optimizer (see configurations).
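To make the relationship between these configuration items concrete, here is a minimal sketch of how they could be turned into optimizer keyword arguments. The helper optimizer_settings is hypothetical and is not how Module is implemented. It assumes that adam only needs learning_rate, that sgd additionally accepts momentum, and that a missing momentum falls back to 0.0.

```python
def optimizer_settings(config):
    """Hypothetical helper: derive optimizer keyword arguments from a config dict.

    Not part of the library; it only illustrates which items each optimizer reads.
    """
    optim = config["optim"]
    if optim == "adam":
        return {"lr": config["learning_rate"]}
    if optim == "sgd":
        return {
            "lr": config["learning_rate"],
            # Assumed default when momentum is not configured.
            "momentum": config.get("momentum", 0.0),
        }
    raise ValueError(f"Unsupported optimizer: {optim!r}")


print(optimizer_settings({"optim": "adam", "learning_rate": 1e-3}))
print(optimizer_settings({"optim": "sgd", "learning_rate": 1e-2, "momentum": 0.9}))
```

In a hand-written configure_optimizers, keyword arguments of this shape would typically be passed to torch.optim.Adam or torch.optim.SGD together with self.parameters().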
measure
This method calculates a metric according to num_classes. For num_classes=1 (regression), the available metrics are mse, rmse, mae and rmsle. For classification, the available metrics are accuracy, precision, recall, f1_score, iou, fbeta_score, auroc, average_precision and dice_score.
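For reference, the four regression metrics can be written out in plain Python using their textbook definitions. This is only a sketch for intuition: the implementations in pytorch-lightning operate on tensors and may differ in details such as reduction options. The rmsle version below assumes non-negative inputs.

```python
import math


def mse(pred, target):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def rmse(pred, target):
    """Root mean squared error."""
    return math.sqrt(mse(pred, target))


def mae(pred, target):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def rmsle(pred, target):
    """Root mean squared logarithmic error, i.e. RMSE computed on log(1 + x)."""
    log_pred = [math.log1p(p) for p in pred]
    log_target = [math.log1p(t) for t in target]
    return rmse(log_pred, log_target)


pred = [2.0, 4.0]
target = [1.0, 2.0]
print(mse(pred, target), mae(pred, target))
print(rmse(pred, target), rmsle(pred, target))
```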
To use it in your training_step, validation_step or test_step, you can do:
self.measure(output, labels, 'accuracy')
self.measure(output, labels, 'fbeta_score', beta=1.0)
For more details of the metrics, see pytorch-lightning's documentation.