module

plkit.data

Data module for plkit

Classes
class

plkit.data.Dataset(data, ids=None)

Bases
torch.utils.data.dataset.Dataset

The dataset used internally by the Data class

Examples
>>> ds = Dataset(data=[('a', 'x'), ('b', 'y'), ('c', 'z')], ids=[1, 2])
>>> len(ds) == 2
True
>>> ds[0] == ('b', 'y')
True
>>> ds[1] == ('c', 'z')
True
>>> # The features are what you get by
>>> # x, y = batch
Parameters
  • data (iterable of tuple) The data for the dataset, given as an iterable of sample tuples; each tuple holds the features of one sample and must be accessible by index.
  • ids (list of int, optional) The ids or keys of the samples, used to index into data (in the example above, ds[0] is data[1] because ids=[1, 2]).
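
Since Dataset subclasses torch.utils.data.Dataset, it plugs directly into a standard DataLoader. A minimal sketch reusing the data from the example above (the batch size is illustrative; note that default collation turns string features into lists):

.. code-block:: python

    from torch.utils.data import DataLoader
    from plkit.data import Dataset

    # Three samples, each a (feature, label) tuple; ids=[1, 2] keeps only
    # the samples at indices 1 and 2.
    ds = Dataset(data=[('a', 'x'), ('b', 'y'), ('c', 'z')], ids=[1, 2])

    loader = DataLoader(ds, batch_size=2)
    for batch in loader:
        x, y = batch  # each position of the sample tuple becomes one feature batch
        print(x, y)   # ['b', 'c'] ['y', 'z']
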
class

plkit.data.IterDataset(data, length)

Bases
torch.utils.data.dataset.IterableDataset, torch.utils.data.dataset.Dataset

Iterable dataset

The iterable dataset where each feature of the data is an iterable

Examples
>>> feat1 = (x for x in range(10))
>>> feat2 = (x for x in range(10))
>>> ds = IterDataset(zip(feat1, feat2), length=10)
>>> next(iter(ds)) == (0, 0)
True
Parameters
  • data (iterable of tuple) An iterable that yields one tuple of features per sample (for example, a zip over per-feature iterables)
  • length (int) The number of samples the iterable will yield
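
Because the feature iterables are consumed lazily, the dataset cannot measure its own size, so the explicit length is what tells downstream consumers how long an epoch is. A minimal sketch, assuming the zip-of-generators pattern from the example above:

.. code-block:: python

    from torch.utils.data import DataLoader
    from plkit.data import IterDataset

    feat1 = (x for x in range(10))      # first feature stream
    feat2 = (x * x for x in range(10))  # second feature stream

    # zip() pairs the streams into per-sample tuples; length must match
    # the number of samples the generators actually produce.
    ds = IterDataset(zip(feat1, feat2), length=10)

    loader = DataLoader(ds, batch_size=4)
    for x, y in loader:
        print(x, y)  # 3 batches: sizes 4, 4 and 2
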
class

plkit.data.DataModule(*args, **kwargs)

Bases
pytorch_lightning.core.datamodule.LightningDataModule, pytorch_lightning.core.hooks.DataHooks, pytorch_lightning.core.hooks.CheckpointHooks

Data module for plkit

Attributes
  • dims A tuple describing the shape of your data. Extra functionality exposed in size.
  • has_prepared_data Return bool letting you know if datamodule.prepare_data() has been called or not.
  • has_setup_fit Return bool letting you know if datamodule.setup('fit') has been called or not.
  • has_setup_test Return bool letting you know if datamodule.setup('test') has been called or not.
  • length The length of the data. This is required when self.data_reader() yields (it is a generator).
  • test_transforms Optional transforms (or collection of transforms) you can apply to the test dataset.
  • train_transforms Optional transforms (or collection of transforms) you can apply to the train dataset.
  • val_transforms Optional transforms (or collection of transforms) you can apply to the validation dataset.
Classes
  • _DataModuleWrapper A metaclass wrapper for LightningDataModule (see __call__ below).
Methods
  • add_argparse_args(parent_parser) (ArgumentParser) Extends existing argparse by default LightningDataModule attributes.
  • data_reader() (iterable or iterator of tuple) Read the data.
  • data_splits(data, stage) (dict of str to Dataset, IterDataset, or list thereof) Split data from data_source for each dataloader.
  • from_argparse_args(args, **kwargs) Create an instance from CLI arguments.
  • get_init_arguments_and_types() (list of 3-tuples) Scans the DataModule signature and returns argument names, types and default values.
  • on_load_checkpoint(checkpoint) Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
  • on_save_checkpoint(checkpoint) Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
  • prepare_data(*args, **kwargs) Prepare data.
  • setup(stage) Set up data.
  • size(dim) (tuple or int) Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor.
  • test_dataloader(*args, **kwargs) (DataLoader or list of DataLoader) Test data loaders.
  • train_dataloader(*args, **kwargs) (DataLoader) Train data loader.
  • transfer_batch_to_device(batch, device) (any) Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data structure.
  • val_dataloader(*args, **kwargs) (DataLoader or list of DataLoader) Validation data loaders.
method

on_load_checkpoint(checkpoint)

Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.

Parameters
  • checkpoint (dict(str: any)) Loaded checkpoint
Example

.. code-block:: python

    def on_load_checkpoint(self, checkpoint):
        # 99% of the time you don't need to implement this method
        self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

method

on_save_checkpoint(checkpoint)

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters
  • checkpoint (dict(str: any)) Checkpoint to be saved
Example

.. code-block:: python

    def on_save_checkpoint(self, checkpoint):
        # 99% of use cases you don't need to implement this method
        checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

Note

Lightning saves all aspects of training (epoch, global step, etc...) including amp scaling. There is no need for you to store anything about training.

class

pytorch_lightning.core.datamodule._DataModuleWrapper(*args, **kwargs)

A metaclass wrapper for LightningDataModule (subclasses type); see __call__ below.

Methods
  • __call__(cls, *args, **kwargs) A wrapper for LightningDataModule; see below.
staticmethod
__call__(cls, *args, **kwargs)

A wrapper for LightningDataModule that:

  1. Runs user defined subclass's init
  2. Assures prepare_data() runs on rank 0
  3. Lets you check prepare_data and setup to see if they've been called
method

size(dim=None) → tuple or int

Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor.

abstract method

transfer_batch_to_device(batch, device)

Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • :class:`torch.Tensor` or anything that implements .to(...)
  • :class:`list`
  • :class:`dict`
  • :class:`tuple`
  • :class:`torchtext.data.batch.Batch`

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

Example::

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # move all tensors in your custom data structure to the device
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
        else:
            batch = super().transfer_batch_to_device(batch, device)
        return batch
Parameters
  • batch (any) A batch of data that needs to be transferred to a new device.
  • device (device) The target device as defined in PyTorch.
Returns (any)

A reference to the data on the new device.

Note

This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing).

Note

This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support for your custom batch objects, you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

See Also
  • :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
  • :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
classmethod

add_argparse_args(parent_parser) → ArgumentParser

Extends existing argparse by default LightningDataModule attributes.

classmethod

from_argparse_args(args, **kwargs)

Create an instance from CLI arguments.

Parameters
  • args (Namespace or ArgumentParser) The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the :class:`LightningDataModule`.
  • **kwargs Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments.

Example::

    parser = ArgumentParser(add_help=False)
    parser = LightningDataModule.add_argparse_args(parser)
    args = parser.parse_args()
    module = LightningDataModule.from_argparse_args(args)
classmethod

get_init_arguments_and_types()

Scans the DataModule signature and returns argument names, types and default values.

Returns (list of 3-tuples)

A list of (argument name, type, default value) tuples, one per constructor argument.

method

data_reader()

Read the data

Returns (iterable of tuple)

A tuple of feature iterables. Alternatively, this method may be a generator (see Yields below).

Yields (tuple)

One tuple of features per sample. In that case, the self.length property is required to be defined.
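
For instance, a generator-style data_reader() that streams samples from a tab-separated file might look like the sketch below. The file name, the parsing logic, and overriding length as a property are illustrative assumptions, not plkit specifics:

.. code-block:: python

    from plkit.data import DataModule

    class MyData(DataModule):

        def data_reader(self):
            # One sample per line: "<text>\t<label>". Because this is a
            # generator, self.length must be defined (see below).
            with open('data.tsv') as fh:  # hypothetical data file
                for line in fh:
                    text, label = line.rstrip('\n').split('\t')
                    yield text, int(label)

        @property
        def length(self):
            # Known (or precomputed) number of samples in data.tsv
            return 1000
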

method

data_splits(data=None, stage=None)

Split the data from data_reader() for each dataloader

Parameters
  • data (iterable of tuple, optional) The data read by self.data_reader()
  • stage (str, optional) The stage argument, as passed to LightningDataModule.setup(...)
Returns (dict of str to Dataset/IterDataset, or list thereof)

A dictionary with keys train, val and test, whose values are a Dataset or an IterDataset (in this case config.data_tvt is ignored).

Alternatively, if config.data_tvt is specified, one can simply return an iterable of samples; the data will then be split into train/val/test automatically according to config.data_tvt.
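
A sketch of a data_splits() override returning the dictionary form; the 80/10/10 split arithmetic is illustrative, not plkit's built-in behavior (which is driven by config.data_tvt):

.. code-block:: python

    from plkit.data import DataModule, Dataset

    class MyData(DataModule):

        def data_splits(self, data=None, stage=None):
            data = list(data)  # materialize the samples read by data_reader()
            n_train = int(len(data) * 0.8)
            n_val = int(len(data) * 0.1)
            return {
                'train': Dataset(data[:n_train]),
                'val': Dataset(data[n_train:n_train + n_val]),
                'test': Dataset(data[n_train + n_val:]),
            }
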

method

prepare_data(*args, **kwargs)

Prepare data

method

setup(stage=None)

Setup data

method

train_dataloader(*args, **kwargs) → DataLoader

Train data loader

method

val_dataloader(*args, **kwargs) → DataLoader or list of DataLoader

Validation data loaders

method

test_dataloader(*args, **kwargs) → DataLoader or list of DataLoader

Test data loaders