mdlearn.data.utils
Utility functions for handling PyTorch data objects.
Functions
train_valid_split | Creates training and validation DataLoaders from dataset.
- mdlearn.data.utils.train_valid_split(dataset: torch.utils.data.Dataset, split_pct: float = 0.8, method: str = 'random', **kwargs) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]
Creates training and validation DataLoaders from dataset.
- Parameters
dataset (Dataset) – A PyTorch dataset class derived from torch.utils.data.Dataset.
split_pct (float, default=0.8) – Fraction of the data (e.g. 0.8 for 80%) to use as training data after the split; the remainder is used for validation.
method (str, default="random") – Method used to split the data. Use "random" for a random split, or "partition" for a simple partition.
**kwargs – Keyword arguments passed to torch.utils.data.DataLoader, such as batch_size, drop_last, etc. (see the PyTorch docs).
- Raises
ValueError – If method is not "random" or "partition".
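Example
The two split methods and the ValueError case can be illustrated with a minimal, dependency-free sketch. This is not mdlearn's implementation: split_indices and its seed argument are hypothetical names introduced here, and truncating the training count with int() is an assumption about how a fractional split_pct is rounded. It is also assumed that "partition" means the leading samples, in their original order, form the training set.

```python
import random
from typing import List, Tuple


def split_indices(
    n: int, split_pct: float = 0.8, method: str = "random", seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (train_indices, valid_indices) for n samples."""
    if method not in ("random", "partition"):
        raise ValueError(f"method must be 'random' or 'partition', got {method!r}")
    # Assumption: the training count is truncated toward zero.
    n_train = int(n * split_pct)
    indices = list(range(n))
    if method == "random":
        # Shuffle the indices with a fixed seed so the split is reproducible.
        random.Random(seed).shuffle(indices)
    # For "partition" the original order is kept, so the first n_train
    # samples go to training and the rest to validation.
    return indices[:n_train], indices[n_train:]


train_idx, valid_idx = split_indices(10, split_pct=0.8, method="partition")
print(train_idx)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(valid_idx)  # [8, 9]
```

With mdlearn and PyTorch installed, the documented call would follow the signature above, for example `train_loader, valid_loader = train_valid_split(dataset, split_pct=0.8, method="partition", batch_size=32)`, where batch_size is forwarded to torch.utils.data.DataLoader through **kwargs.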