Configurations

One of the principles of plkit is to keep all configuration items in a single dictionary, which is used to construct both the data and the module. Any item that is a valid argument for initializing pytorch-lightning's Trainer can also be a configuration item (see the Trainer API in pytorch-lightning's documentation).
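To illustrate the single-dictionary idea, here is a minimal sketch of such a configuration. The helper split_config and the PLKIT_KEYS set are hypothetical and are not plkit's implementation; they only show how the plkit-specific items documented on this page could be told apart from items that are passed on as Trainer arguments (max_epochs here).

```python
# Hypothetical sketch: one configuration dictionary for data, module and Trainer.
# PLKIT_KEYS lists the plkit-specific items documented on this page; the
# split_config helper is an illustration, not plkit's actual implementation.
PLKIT_KEYS = {
    "seed", "num_classes", "data_num_workers", "data_tvt",
    "optim", "learning_rate", "momentum", "loss",
}


def split_config(config):
    """Separate plkit-specific items from the rest, which are assumed
    to be forwarded as pytorch-lightning Trainer arguments."""
    plkit_items = {k: v for k, v in config.items() if k in PLKIT_KEYS}
    trainer_items = {k: v for k, v in config.items() if k not in PLKIT_KEYS}
    return plkit_items, trainer_items


config = {
    "seed": 8,                   # plkit: seeds everything, deterministic=True
    "num_classes": 2,            # plkit: selects the builtin loss and metric
    "data_tvt": (.7, .15, .15),  # plkit: train/val/test split
    "optim": "adam",
    "learning_rate": 1e-3,
    "max_epochs": 10,            # a regular Trainer argument
}

plkit_items, trainer_items = split_config(config)
print(sorted(trainer_items))  # ['max_epochs']
```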

Some configuration items, however, are either additional to the Trainer's arguments or differ from them in value or behavior.

seed

For full reproducibility with pytorch-lightning, one should call seed_everything and set deterministic to True when initializing the trainer (see Reproducibility).

With plkit, however, you only need to set a seed in the configuration: seed_everything is called automatically, and deterministic is automatically set to True when the trainer is initialized.

If you don't want deterministic to be True when a seed is specified, set deterministic to False in the configuration.
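The seed behavior described above can be sketched as follows. This is a hypothetical helper (prepare_trainer_kwargs is not a plkit function), and only the seed/deterministic logic is shown. As a stand-in for pytorch_lightning.seed_everything, which also seeds numpy and torch, it seeds only Python's random module so the sketch runs without extra dependencies.

```python
import random


def prepare_trainer_kwargs(config):
    """Hypothetical sketch of how a seed could imply deterministic=True."""
    kwargs = dict(config)
    seed = kwargs.pop("seed", None)  # seed is not a Trainer argument
    if seed is not None:
        # Stand-in for pytorch_lightning.seed_everything(seed);
        # only Python's random module is seeded here.
        random.seed(seed)
        # setdefault keeps an explicit deterministic=False from the config.
        kwargs.setdefault("deterministic", True)
    return kwargs


print(prepare_trainer_kwargs({"seed": 8, "max_epochs": 5}))
```

Using setdefault is what lets an explicit deterministic=False in the configuration win over the automatic True.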

num_classes

Specifying num_classes in the configuration ensures that the builtin measurement uses the right loss function and metric for the outputs and labels (see the loss configuration item and Builtin measurement for more details).

data_num_workers

The num_workers argument of the DataLoader used by the DataModule.

data_tvt

The train-val-test ratio for splitting the data read by DataModule.data_reader.

It can be a tuple of at most 3 elements or a single scalar. The elements can be ratios (<=1) or absolute numbers of samples.

The first element is for the training set, and the latter two are for the validation and test sets, respectively. Each of these two can be a list to specify multiple validation or test sets.

If no ratio or number is specified for a dataset, that dataset will not be generated. For example:

data_tvt                          Meaning
.7                                Use 70% of the data for training (no validation or test data)
(.7, .1)                          Use 70% for training and 10% for validation (no test data)
(.7, .15, .15)                    Use 70% for training, 15% for validation and 15% for testing
300                               Use 300 samples for training
(300, [100, 100], [100, 100])     Use 300 samples for training, two validation sets of 100 samples each and two test sets of 100 samples each
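The rules above can be made concrete with a small sketch. The exact rounding and splitting that plkit performs are not described on this page, so the following is only an illustration under stated assumptions: values <= 1 are ratios of the dataset size, rounded to the nearest integer; larger values are absolute counts; and all sets are cut from one shuffled index list without overlap. resolve_tvt and split_indices are hypothetical names.

```python
import random


def _size(spec, n):
    """Assumption: values <= 1 are ratios of n, larger values are counts."""
    return round(spec * n) if spec <= 1 else int(spec)


def resolve_tvt(data_tvt, n):
    """Return (train_size, [val_sizes], [test_sizes]) for n samples."""
    if not isinstance(data_tvt, (tuple, list)):
        data_tvt = (data_tvt,)  # a single scalar means training data only
    if not 1 <= len(data_tvt) <= 3:
        raise ValueError("data_tvt must have 1 to 3 elements")
    train = _size(data_tvt[0], n)
    sets = []
    for spec in data_tvt[1:]:
        # A list means multiple validation (or test) sets.
        specs = spec if isinstance(spec, (tuple, list)) else [spec]
        sets.append([_size(s, n) for s in specs])
    val_sizes = sets[0] if len(sets) > 0 else []
    test_sizes = sets[1] if len(sets) > 1 else []
    if train + sum(val_sizes) + sum(test_sizes) > n:
        raise ValueError("data_tvt asks for more samples than available")
    return train, val_sizes, test_sizes


def split_indices(data_tvt, n, seed=None):
    """Shuffle 0..n-1 and slice consecutive, non-overlapping index sets."""
    train, val_sizes, test_sizes = resolve_tvt(data_tvt, n)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    train_idx = indices[:train]
    start = train
    val_idx, test_idx = [], []
    for size in val_sizes:
        val_idx.append(indices[start:start + size])
        start += size
    for size in test_sizes:
        test_idx.append(indices[start:start + size])
        start += size
    return train_idx, val_idx, test_idx


print(resolve_tvt((.7, .15, .15), 1000))
```

For a dataset of 1000 samples, each row of the table above maps to a (train, validation sizes, test sizes) triple, for example (300, [100, 100], [100, 100]) for the last row.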

optim

The name of the optimizer (currently only adam and sgd are supported).

learning_rate

The learning rate for the optimizer.

momentum

Momentum for the SGD optimizer.
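How optim, learning_rate and momentum fit together can be sketched as a mapping from the configuration to a torch.optim class name plus keyword arguments. The optimizer_spec helper is hypothetical, not plkit's code, and it requires optim and learning_rate to be present because this page does not state their defaults. The momentum fallback of 0 mirrors torch.optim.SGD's own default.

```python
def optimizer_spec(config):
    """Hypothetical: map optim/learning_rate/momentum to a class name and kwargs."""
    name = config["optim"].lower()
    lr = config["learning_rate"]
    if name == "adam":
        return "Adam", {"lr": lr}  # momentum does not apply to Adam
    if name == "sgd":
        # torch.optim.SGD uses momentum=0 when none is given.
        return "SGD", {"lr": lr, "momentum": config.get("momentum", 0.0)}
    raise ValueError(f"Unsupported optimizer: {name!r} (expected 'adam' or 'sgd')")


print(optimizer_spec({"optim": "sgd", "learning_rate": 0.01, "momentum": 0.9}))
```

With PyTorch installed, the result could be turned into an optimizer with getattr(torch.optim, cls_name)(model.parameters(), **kwargs).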

loss

The loss function. It is auto by default, meaning nn.MSELoss() for regression (num_classes=1) and nn.CrossEntropyLoss() for classification. You can also specify your own loss function, for example loss=nn.L1Loss().
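The auto rule can be sketched as a small selection function. resolve_loss is hypothetical, not plkit's code. To keep the sketch free of a PyTorch dependency, it returns the name of the builtin loss class rather than an instance; a user-supplied loss object is returned unchanged.

```python
def resolve_loss(config):
    """Hypothetical: return the user's loss, or the name of the builtin default."""
    loss = config.get("loss", "auto")
    if not (isinstance(loss, str) and loss == "auto"):
        return loss  # e.g. an nn.L1Loss() instance supplied by the user
    # num_classes=1 is treated as regression, anything else as classification.
    return "nn.MSELoss" if config["num_classes"] == 1 else "nn.CrossEntropyLoss"


print(resolve_loss({"num_classes": 3}))
```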