"""Data module for plkit"""
from types import GeneratorType
from typing import (
    Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
)
from itertools import islice
from diot import FrozenDiot
from pytorch_lightning import LightningDataModule
from torch.utils.data import (
DataLoader,
Dataset as TorchDataset,
IterableDataset as TorchIterableDataset,
random_split
)
from .exceptions import PlkitDataException
from .utils import (
normalize_tvt_ratio,
check_config,
logger
)
# pylint: disable=unused-argument
# Either a map-style or an iterable-style torch dataset
DatasetType = Union[TorchDataset, TorchIterableDataset]
class Dataset(TorchDataset):
"""The dataset that used internally by Data class
Examples:
>>> ds = Dataset(data=[('a', 'x'), ('b', 'y'), ('c', 'z')], ids=[1, 2])
>>> len(ds) == 2
>>> ds[0] == ('b', 'y')
>>> ds[1] == ('c', 'z')
>>> # The features are what you get by
>>> # x, y = batch
Args:
        data: The data for the dataset. It should be a sequence of
            samples (for example, a list of feature tuples) that
            supports `len()` and indexing.
        ids: The ids (indices or keys) of the samples to use. If not
            given, all samples are used in their original order.
"""
def __init__(self,
                 data: Sequence[tuple],
ids: Optional[List[int]] = None) -> None:
self.data = data
self.ids = ids or list(range(len(data)))
def __len__(self) -> int:
return len(self.ids)
    def __getitem__(self, idx: int) -> Tuple[Any, ...]:
data_id = self.ids[idx]
return self.data[data_id]
class IterDataset(TorchIterableDataset):
"""Iterable dataset
The iterable dataset where each feature of the data is an iterable
Examples:
>>> feat1 = (x for x in range(10)
>>> feat2 = (x for x in range(10)
>>> ds = IterDataset(zip(feat1, feat2), ids=[4,3])
>>> next(ds) == (0, 0)
Args:
data: a tuple of iterable features
length: The length of the iterables
"""
# pylint: disable=abstract-method
def __init__(self,
data: Iterable[tuple],
length: int) -> None:
self.data = data
self.length = length
    def __iter__(self) -> Iterator[Any]:
return iter(self.data)
class DataModule(LightningDataModule):
"""Data module for plkit"""
def __init__(self,
train_transforms=None,
val_transforms=None,
test_transforms=None,
config: Optional[FrozenDiot] = None) -> None:
super().__init__(train_transforms=train_transforms,
val_transforms=val_transforms,
test_transforms=test_transforms)
self.config = config or FrozenDiot()
check_config(self.config, 'batch_size')
self.num_workers = self.config.get('data_num_workers', 0)
self.tvt_ratio = normalize_tvt_ratio(self.config.get('data_tvt'))
self.data = None
self.splits = None
self._length = None
    @property
    def length(self) -> Optional[int]:
        """The length of the data

        This is required when `self.data_reader()` returns a generator
        and `config.data_tvt` is given as ratios; record it by setting
        `self._length` inside `data_reader`.

        Returns:
            The length of the data.
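
        Examples:
            A sketch of recording the length inside `data_reader`; the
            count and the `make_sample` helper are hypothetical:

            >>> def data_reader(self):
            ...     self._length = 100  # record before returning
            ...     # make_sample is a hypothetical sample builder
            ...     return (make_sample(i) for i in range(100))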
"""
return self._length
    def data_reader(self) -> Union[Iterable[tuple], Iterator[tuple]]:
"""Read the data
Returns:
A tuple of iterables of features. Or it yields the following
Yields:
An iterable of tuple of features. In such a case, self.length
property is required to be defined.
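
        Examples:
            Two sketches of overriding this method; the feature values
            are hypothetical:

            >>> def data_reader(self):
            ...     # an indexable collection of samples
            ...     return [(1.0, 1), (2.0, 0), (3.0, 1)]

            >>> def data_reader(self):
            ...     # a generator of samples; record the length eagerly
            ...     # so it is known before the generator is consumed
            ...     self._length = 3
            ...     return ((x, x % 2) for x in [1.0, 2.0, 3.0])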
"""
raise NotImplementedError # pragma: no cover
    def _split_data_generator(
            self, data: Iterable[tuple]
    ) -> Dict[str, Union[IterDataset, List[IterDataset]]]:
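        """Split a generator into train/val/test `IterDataset`s

        The underlying generator is shared, so the splits must be
        consumed in order: train first, then each val split, then each
        test split.
        """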
ret = {}
is_ratio = self.tvt_ratio[0] <= 1.0
if is_ratio and self.length is None:
raise PlkitDataException(
                'Got a generator from `data_reader` and ratios from '
                '`config.data_tvt`; in this case `self.length` must be '
                'recorded in `data_reader`.'
)
        # Split by consuming the generator sequentially with islice.
        # Absolute offsets must not be used here: each islice shares
        # the same underlying generator, which has already been
        # advanced by the previous splits.
        train_len = (round(self.tvt_ratio[0] * float(self.length))
                     if is_ratio else self.tvt_ratio[0])
        ret['train'] = IterDataset(islice(data, train_len), train_len)
        if self.tvt_ratio[1]:
            ret['val'] = []
            for val_ratio in self.tvt_ratio[1]:
                val_len = (round(val_ratio * float(self.length))
                           if is_ratio else val_ratio)
                ret['val'].append(IterDataset(islice(data, val_len),
                                              val_len))
        if self.tvt_ratio[2]:
            ret['test'] = []
            for test_ratio in self.tvt_ratio[2]:
                test_len = (round(test_ratio * float(self.length))
                            if is_ratio else test_ratio)
                ret['test'].append(IterDataset(islice(data, test_len),
                                               test_len))
        return ret
    def _split_data_list(
            self, data: List[Any]
    ) -> Dict[str, Union[Dataset, List[Dataset]]]:
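        """Split an indexable collection into train/val/test `Dataset`s

        Ids are partitioned randomly with
        `torch.utils.data.random_split`; the ids left over after the
        training split are reused for the val and test splits.
        """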
ret = {}
is_ratio = self.tvt_ratio[0] <= 1.0
self._length = len(data)
all_ids = range(self.length)
train_len = (round(self.tvt_ratio[0] * float(self.length))
if is_ratio else self.tvt_ratio[0])
train_ids, rest_ids = random_split(
all_ids, [train_len, len(all_ids) - train_len]
)
ret['train'] = Dataset(data, train_ids)
if self.tvt_ratio[1]:
ret['val'] = []
for val_ratio in self.tvt_ratio[1]:
val_len = (round(val_ratio * float(self.length))
if is_ratio else val_ratio)
val_ids, rest_ids = random_split(
rest_ids, [val_len, len(rest_ids) - val_len]
)
ret['val'].append(Dataset(data, val_ids))
if self.tvt_ratio[2]:
ret['test'] = []
for test_ratio in self.tvt_ratio[2]:
test_len = (round(test_ratio * float(self.length))
if is_ratio else test_ratio)
test_ids, rest_ids = random_split(
rest_ids, [test_len, len(rest_ids) - test_len]
)
ret['test'].append(Dataset(data, test_ids))
return ret
    def data_splits(  # pylint: disable=unused-argument
            self,
            data: Optional[Iterable[tuple]] = None,
            stage: Optional[str] = None
    ) -> Optional[Dict[str, Union[DatasetType, List[DatasetType]]]]:
"""Split data from data_source for each dataloader
Args:
data: The data read by self.data_reader()
stage: The stage argument same as the one from
`LightningDataModule.setup(...)`
Returns:
A dictionary with keys `train`, `val` and `test`, and values a
Dataset or an IterDataset (config.data_tvt will be ignored)
Or if config.data_tvt is specified, one could just return an
iterable of features, then the dataset will be automatically
split by config.data_tvt
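
        Examples:
            A sketch of overriding this method directly; the split
            sizes are hypothetical:

            >>> def data_splits(self, data=None, stage=None):
            ...     data = data or self.data
            ...     return {
            ...         'train': Dataset(data, ids=list(range(80))),
            ...         'val': [Dataset(data, ids=list(range(80, 90)))],
            ...         'test': [Dataset(data, ids=list(range(90, 100)))],
            ...     }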
"""
if not self.tvt_ratio:
return None
data = data or self.data
if isinstance(data, GeneratorType):
return self._split_data_generator(data)
return self._split_data_list(data)
    def prepare_data(self, *args, **kwargs) -> None:
"""Prepare data"""
logger.info('Reading data ...')
self.data = self.data_reader()
    def setup(self, stage: Optional[str] = None) -> None:
"""Setup data"""
if stage == 'fit':
            # Split only once, during the 'fit' stage.
            # Override this method if you need stage-specific
            # splitting behavior.
logger.info('Splitting data ...')
self.splits = self.data_splits(self.data, stage)
if not self.tvt_ratio and not self.splits:
raise PlkitDataException(
                'No train-val-test ratio (`data_tvt`) specified in the '
                'configuration; in this case the `data_splits` method '
                'must be implemented in your DataModule.'
)
    def train_dataloader(self, *args, **kwargs) -> Optional[DataLoader]:
"""Train data loaders"""
if 'train' not in self.splits:
return None
return DataLoader(self.splits['train'],
batch_size=self.config.batch_size,
num_workers=self.num_workers)
    def val_dataloader(
            self, *args, **kwargs
    ) -> Optional[Union[DataLoader, List[DataLoader]]]:
"""Validation data loaders"""
if 'val' not in self.splits:
return None
ret = []
for val_data in self.splits['val']:
ret.append(DataLoader(val_data,
batch_size=self.config.batch_size,
num_workers=self.num_workers))
return ret[0] if len(ret) == 1 else ret
    def test_dataloader(
            self, *args, **kwargs
    ) -> Optional[Union[DataLoader, List[DataLoader]]]:
"""Test data loaders"""
if 'test' not in self.splits:
return None
ret = []
for test_data in self.splits['test']:
ret.append(DataLoader(test_data,
batch_size=self.config.batch_size,
num_workers=self.num_workers))
return ret[0] if len(ret) == 1 else ret