mdlearn.data.utils

Utility functions for handling PyTorch data objects.

Functions

train_valid_split(dataset[, split_pct, method])

Creates training and validation DataLoaders from dataset.

mdlearn.data.utils.train_valid_split(dataset: torch.utils.data.Dataset, split_pct: float = 0.8, method: str = 'random', **kwargs) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]

Creates training and validation DataLoaders from dataset.

Parameters
  • dataset (Dataset) – A PyTorch dataset, i.e. an instance of a class derived from torch.utils.data.Dataset.

  • split_pct (float, default=0.8) – Fraction of the data, given as a value between 0 and 1, to use as training data; the remainder is used for validation.

  • method (str, default="random") – Method used to split the data. Use "random" for a random split or "partition" for a simple partition.

  • **kwargs – Keyword arguments passed to torch.utils.data.DataLoader, such as batch_size and drop_last (see the PyTorch documentation).

Raises

ValueError – If method is not "random" or "partition".
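
Example

The split semantics described above can be illustrated with a minimal, standard-library sketch. This is not mdlearn's implementation: split_indices and its seed argument are hypothetical names introduced here, and the sketch assumes that "partition" keeps the original sample order while "random" shuffles the indices before cutting at split_pct.

```python
import random
from typing import List, Tuple


def split_indices(
    n_samples: int,
    split_pct: float = 0.8,
    method: str = "random",
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (train_indices, valid_indices).

    Mirrors the documented parameters of train_valid_split, but works on
    plain index lists instead of building DataLoaders.
    """
    # Reject unknown methods, as the Raises section documents.
    if method not in ("random", "partition"):
        raise ValueError(f"Invalid method: {method!r}")

    indices = list(range(n_samples))

    # Assumption: "random" shuffles the indices, "partition" keeps their order.
    if method == "random":
        random.Random(seed).shuffle(indices)

    # The first split_pct fraction goes to training, the rest to validation.
    n_train = int(n_samples * split_pct)
    return indices[:n_train], indices[n_train:]


train_idx, valid_idx = split_indices(10, split_pct=0.8, method="partition")
print(train_idx)
print(valid_idx)
```

With the real function, the equivalent call would take the form train_valid_split(dataset, split_pct=0.8, method="partition", batch_size=32), where batch_size is forwarded to torch.utils.data.DataLoader through **kwargs and the two returned values are the training and validation DataLoaders.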