"""Wrapper of the Trainer class"""
import inspect
from pytorch_lightning import Trainer as PlTrainer
from pytorch_lightning.callbacks.progress import (
ProgressBar as PlProgressBar,
ProgressBarBase
)
from .utils import collapse_suggest_config, warning_to_logging
class ProgressBar(PlProgressBar):
    """Progress bar that right-aligns the epoch number in its description.

    The padding width is derived from `max_epochs`, which the owner (e.g.
    `Trainer.__init__` in this module) assigns after construction. It
    defaults to ``None`` at class level, so an instance created directly and
    passed through ``callbacks`` works (without padding) instead of raising
    ``AttributeError`` on the first epoch.
    """

    # Total number of epochs, used only to compute the padding width.
    # ``None`` or a non-positive value (e.g. ``-1``) disables the padding.
    max_epochs = None

    def on_epoch_start(self, trainer, pl_module):
        """Run the parent's epoch-start handling, then pad the epoch number.

        The epoch index is right-justified to the width of the last epoch
        index (``max_epochs - 1``). For example, with ``max_epochs=100`` the
        description goes from ``Epoch  0`` to ``Epoch 99``.
        """
        super().on_epoch_start(trainer, pl_module)
        max_epochs = self.max_epochs
        if max_epochs and max_epochs > 0:
            # Epochs are 0-indexed, so the widest label is max_epochs - 1.
            nchar = len(str(max_epochs - 1))
            self.main_progress_bar.set_description(
                f'Epoch {str(trainer.current_epoch).rjust(nchar)}'
            )
class Trainer(PlTrainer):  # pylint: disable=too-many-ancestors
    """The Trainer class.

    A thin wrapper around :class:`pytorch_lightning.Trainer` that:

    * adds `from_config` (aka `from_dict`) as a classmethod to instantiate a
      trainer from a configuration dictionary;
    * installs this module's epoch-aligned :class:`ProgressBar` unless a
      :class:`ProgressBarBase` callback is already supplied;
    * formats float progress-bar metrics with 3 decimals;
    * runs construction, `fit` and `test` inside ``warning_to_logging()``
      (presumably redirecting warnings to logging — see `.utils`).
    """

    # pylint: disable=signature-differs

    @classmethod
    def from_config(cls, config, **kwargs):
        """Create an instance from a configuration dictionary.

        Examples:
            >>> config = {'my_custom_arg': 'something'}
            >>> trainer = Trainer.from_dict(config, logger=False)

        Args:
            config: A mapping of arguments. Only keys that are named
                parameters of :class:`pytorch_lightning.Trainer` are passed
                on to the :class:`Trainer`. The remaining keys may be user
                specific and are ignored.
            **kwargs: Additional keyword arguments that override the ones in
                `config`. These must be valid Trainer arguments.

        Returns:
            A new instance of `cls`, built from the filtered arguments after
            they pass through `collapse_suggest_config`.
        """
        # We only want to pass in valid Trainer args; the rest may be user
        # specific. `self` and any *args/**kwargs catch-alls in the signature
        # are not real argument names and must not be picked up from config.
        params = inspect.signature(PlTrainer.__init__).parameters
        valid_names = [
            name for name, param in params.items()
            if name != 'self'
            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        ]
        trainer_kwargs = {
            name: config[name] for name in valid_names if name in config
        }
        trainer_kwargs.update(kwargs)
        trainer_kwargs = collapse_suggest_config(trainer_kwargs)
        return cls(**trainer_kwargs)

    from_dict = from_config

    def __init__(self, *args, **kwargs):
        """Create the trainer, installing the aligned progress bar if needed.

        Accepts the same arguments as :class:`pytorch_lightning.Trainer`.
        Unless a :class:`ProgressBarBase` is already among ``callbacks``, a
        :class:`ProgressBar` is appended. It is configured from
        ``process_position`` / ``progress_bar_refresh_rate``, and its
        ``max_epochs`` is set so it can pad the epoch number.
        """
        callbacks = kwargs.get('callbacks')
        # Work on a fresh list. Appending to the caller's list would mutate
        # their argument and leak the same bar into every trainer built from
        # it. An explicit ``None`` and a single bare callback are accepted too.
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, (list, tuple)):
            callbacks = list(callbacks)
        else:
            callbacks = [callbacks]
        kwargs['callbacks'] = callbacks

        if not any(isinstance(callback, ProgressBarBase)
                   for callback in callbacks):
            # The Trainer calls it `progress_bar_refresh_rate`, while the bar
            # calls it `refresh_rate`. ``None`` values are dropped so that
            # ProgressBar's own defaults apply instead of an invalid None.
            renames = {'progress_bar_refresh_rate': 'refresh_rate'}
            pbar_kwargs = {
                renames.get(key, key): kwargs[key]
                for key in ('process_position', 'progress_bar_refresh_rate')
                if kwargs.get(key) is not None
            }
            pbar = ProgressBar(**pbar_kwargs)
            pbar.max_epochs = kwargs.get('max_epochs')
            callbacks.append(pbar)

        with warning_to_logging():
            super().__init__(*args, **kwargs)

    @property
    def progress_bar_dict(self) -> dict:
        """Progress bar metrics, with float values formatted to 3 decimals."""
        metrics = super().progress_bar_dict
        return {
            key: f'{val:.3f}' if isinstance(val, float) else val
            for key, val in metrics.items()
        }

    def fit(self, *args, **kwargs):
        """Train and validate the model.

        Arguments are forwarded to :meth:`pytorch_lightning.Trainer.fit`.
        Its return value is passed back to the caller.
        """
        with warning_to_logging():
            return super().fit(*args, **kwargs)

    def test(self, *args, **kwargs):
        """Test the model.

        Arguments are forwarded to :meth:`pytorch_lightning.Trainer.test`.
        Its return value (the test results) is passed back to the caller.
        """
        with warning_to_logging():
            return super().test(*args, **kwargs)