mdlearn.utils

Configurations and utilities for model building and training.

Functions

get_torch_optimizer(name, hparams, parameters)

Construct a PyTorch optimizer specified by name and hparams.

get_torch_scheduler(name, hparams, optimizer)

Construct a PyTorch lr_scheduler specified by name and hparams.

log_checkpoint(checkpoint_file, epoch, ...)

Write a torch .pt file containing the epoch, model, optimizers, and scheduler.

parse_args()

Parse command-line arguments using the argparse library.

resume_checkpoint(checkpoint_file, model, ...)

Modify model, optimizers, and scheduler with the values stored in the torch .pt file checkpoint_file to resume from a previous training checkpoint.

pydantic model mdlearn.utils.BaseModel

JSON schema:
{
   "title": "BaseModel",
   "type": "object",
   "properties": {}
}

dump_yaml(cfg_path: str | Path)
classmethod from_yaml(filename: str | Path) → _T

pydantic model mdlearn.utils.OptimizerConfig

Pydantic schema for a PyTorch optimizer that allows arbitrary optimizer hyperparameters.

JSON schema:
{
   "title": "OptimizerConfig",
   "description": "pydantic schema for PyTorch optimizer which allows\nfor arbitrary optimizer hyperparameters.",
   "type": "object",
   "properties": {
      "name": {
         "default": "Adam",
         "title": "Name",
         "type": "string"
      },
      "hparams": {
         "default": {},
         "title": "Hparams",
         "type": "object"
      }
   },
   "additionalProperties": true
}

Config:
  • extra: str = allow

Fields:
field hparams: dict[str, Any] = {}
field name: str = 'Adam'
class Config
extra = 'allow'
pydantic model mdlearn.utils.SchedulerConfig

Pydantic schema for a PyTorch scheduler that allows arbitrary scheduler hyperparameters.

JSON schema:
{
   "title": "SchedulerConfig",
   "description": "pydantic schema for PyTorch scheduler which allows for arbitrary\nscheduler hyperparameters.",
   "type": "object",
   "properties": {
      "name": {
         "default": "ReduceLROnPlateau",
         "title": "Name",
         "type": "string"
      },
      "hparams": {
         "default": {},
         "title": "Hparams",
         "type": "object"
      }
   },
   "additionalProperties": true
}

Config:
  • extra: str = allow

Fields:
field hparams: dict[str, Any] = {}
field name: str = 'ReduceLROnPlateau'
class Config
extra = 'allow'
pydantic model mdlearn.utils.WandbConfig

JSON schema:
{
   "title": "WandbConfig",
   "type": "object",
   "properties": {
      "wandb_project_name": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "title": "Wandb Project Name"
      },
      "wandb_entity_name": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "title": "Wandb Entity Name"
      },
      "wandb_model_tag": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "title": "Wandb Model Tag"
      }
   }
}

Fields:
field wandb_entity_name: str | None = None
field wandb_model_tag: str | None = None
field wandb_project_name: str | None = None
init(cfg: BaseModel, model: torch.nn.Module, wandb_path: PathLike) → wandb.config | None

Initialize wandb with model and config.

Parameters:
  • cfg (BaseModel) – Model configuration with hyperparameters and training settings.

  • model (torch.nn.Module) – Model to train, passed to wandb.watch(model) for logging.

  • wandb_path (PathLike) – Path to write wandb/ directory containing training logs.

Returns:

Optional[wandb.config] – wandb config object or None if wandb_project_name is None.

mdlearn.utils.get_torch_optimizer(name: str, hparams: dict[str, Any], parameters) → torch.optim.Optimizer

Construct a PyTorch optimizer specified by name and hparams.
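
Resolving a class from its string name is what lets OptimizerConfig carry an arbitrary name and hparams. The sketch below shows that name-based construction pattern without requiring PyTorch. The SimpleNamespace registry and the toy Adam/SGD classes are hypothetical stand-ins for torch.optim, and the getattr lookup is an assumption about how name is resolved, not a copy of mdlearn's source.

```python
from types import SimpleNamespace
from typing import Any


class Adam:
    """Toy stand-in for torch.optim.Adam (hypothetical, for illustration only)."""

    def __init__(self, params, lr: float = 1e-3, weight_decay: float = 0.0):
        self.params = list(params)
        self.defaults = {"lr": lr, "weight_decay": weight_decay}


class SGD:
    """Toy stand-in for torch.optim.SGD (hypothetical, for illustration only)."""

    def __init__(self, params, lr: float, momentum: float = 0.0):
        self.params = list(params)
        self.defaults = {"lr": lr, "momentum": momentum}


# Hypothetical namespace playing the role of the torch.optim module.
optim = SimpleNamespace(Adam=Adam, SGD=SGD)


def get_optimizer(name: str, hparams: dict[str, Any], parameters):
    """Look up an optimizer class by name and pass hparams as keyword arguments."""
    optimizer_cls = getattr(optim, name)  # AttributeError for unknown names
    return optimizer_cls(parameters, **hparams)


opt = get_optimizer("SGD", {"lr": 0.1, "momentum": 0.9}, parameters=[1.0, 2.0])
print(type(opt).__name__, opt.defaults)  # SGD {'lr': 0.1, 'momentum': 0.9}
```

With PyTorch installed, the same pattern would read getattr(torch.optim, "SGD")(model.parameters(), lr=0.1, momentum=0.9). Under this pattern every key in hparams must be a keyword argument accepted by the chosen class; an unexpected key raises a TypeError.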

mdlearn.utils.get_torch_scheduler(name: str | None, hparams: dict[str, Any], optimizer: torch.optim.Optimizer) → torch.optim.lr_scheduler._LRScheduler | None

Construct a PyTorch lr_scheduler specified by name and hparams.

Parameters:
  • name (Optional[str]) – Name of PyTorch lr_scheduler class to use. If name is None, simply return None.

  • hparams (Dict[str, Any]) – Hyperparameters to pass to the lr_scheduler.

  • optimizer (torch.optim.Optimizer) – The initialized optimizer.

Returns:

Optional[torch.optim.lr_scheduler._LRScheduler] – The initialized PyTorch scheduler, or None if name is None.
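
The only branching described above is the None short-circuit, which lets a scheduler be switched off entirely from the config. A minimal sketch of that behavior, assuming the class is again resolved by name; the toy StepLR class and the lr_scheduler namespace are hypothetical stand-ins for torch.optim.lr_scheduler so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Optional


class StepLR:
    """Toy stand-in for torch.optim.lr_scheduler.StepLR (hypothetical)."""

    def __init__(self, optimizer, step_size: int, gamma: float = 0.1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma


# Hypothetical namespace playing the role of torch.optim.lr_scheduler.
lr_scheduler = SimpleNamespace(StepLR=StepLR)


def get_scheduler(name: Optional[str], hparams: dict[str, Any], optimizer):
    """Return None when no scheduler is requested, otherwise build it by name."""
    if name is None:
        return None
    scheduler_cls = getattr(lr_scheduler, name)
    return scheduler_cls(optimizer, **hparams)


optimizer = object()  # placeholder for an initialized optimizer
print(get_scheduler(None, {}, optimizer))  # None
sched = get_scheduler("StepLR", {"step_size": 10, "gamma": 0.5}, optimizer)
print(sched.step_size, sched.gamma)  # 10 0.5
```

Callers then guard each step with a check such as if scheduler is not None. Since SchedulerConfig defaults to ReduceLROnPlateau, note that in PyTorch its step() expects the monitored metric (for example scheduler.step(val_loss)), whereas most other schedulers are stepped without arguments.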

mdlearn.utils.log_checkpoint(checkpoint_file: str | Path, epoch: int, model: torch.nn.Module, optimizers: dict[str, torch.optim.Optimizer], scheduler: torch.optim.lr_scheduler._LRScheduler | None = None)

Write a torch .pt file containing the epoch, model, optimizers, and scheduler.

Parameters:
  • checkpoint_file (PathLike) – Path to save checkpoint file.

  • epoch (int) – The current training epoch.

  • model (torch.nn.Module) – The model whose parameters are saved.

  • optimizers (Dict[str, torch.optim.Optimizer]) – The optimizers whose parameters are saved.

  • scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]) – Optional scheduler whose parameters are saved.
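
A dependency-free sketch of the kind of dictionary such a checkpoint can hold, one entry per component. The key names (model_state_dict, optimizer_<name>_state_dict, scheduler_state_dict) are hypothetical, since the layout mdlearn writes is not documented here, and pickle with plain dicts stands in for torch.save and the state_dict() outputs so the example runs without PyTorch.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


def save_checkpoint(
    checkpoint_file: Path,
    epoch: int,
    model_state: dict[str, Any],
    optimizer_states: dict[str, dict[str, Any]],
    scheduler_state: Optional[dict[str, Any]] = None,
) -> None:
    """Write the epoch and component states to one file (hypothetical key layout)."""
    checkpoint: dict[str, Any] = {"epoch": epoch, "model_state_dict": model_state}
    # One entry per named optimizer, mirroring the Dict[str, Optimizer] argument.
    for name, state in optimizer_states.items():
        checkpoint[f"optimizer_{name}_state_dict"] = state
    # The scheduler is optional, so it is only stored when one is supplied.
    if scheduler_state is not None:
        checkpoint["scheduler_state_dict"] = scheduler_state
    with open(checkpoint_file, "wb") as f:
        pickle.dump(checkpoint, f)


tmp_dir = Path(tempfile.mkdtemp())
ckpt_path = tmp_dir / "checkpoint-epoch-5.pkl"
save_checkpoint(ckpt_path, 5, {"w": [0.1, 0.2]}, {"encoder": {"lr": 1e-3}})
with open(ckpt_path, "rb") as f:
    print(sorted(pickle.load(f)))  # ['epoch', 'model_state_dict', 'optimizer_encoder_state_dict']
```

With PyTorch, the same dictionary would be filled from model.state_dict(), optimizer.state_dict(), and scheduler.state_dict() and written with torch.save(checkpoint, checkpoint_file).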

mdlearn.utils.parse_args() → Namespace

Parse command-line arguments using the argparse library.

Returns:

argparse.Namespace – Dict-like object containing the path to a YAML file, accessed via the config attribute.

Example

>>> from mdlearn.utils import parse_args
>>> args = parse_args()
>>> # MyConfig should inherit from BaseModel
>>> cfg = MyConfig.from_yaml(args.config)
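
The flag that parse_args registers is not named above. The sketch below assumes -c/--config, which matches the args.config attribute used in the example; parse_config_args is a hypothetical name chosen so it is not mistaken for mdlearn's own function. Accepting an explicit argument list makes the parser testable without a real command line.

```python
import argparse
from pathlib import Path
from typing import Optional, Sequence


def parse_config_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse a single required YAML config path (flag names are an assumption)."""
    parser = argparse.ArgumentParser(description="Train a model from a YAML config.")
    parser.add_argument(
        "-c", "--config", type=Path, required=True, help="Path to the YAML config file."
    )
    # argv=None makes argparse read sys.argv[1:]; passing a list is handy in tests.
    return parser.parse_args(argv)


args = parse_config_args(["--config", "experiment.yaml"])
print(args.config)  # experiment.yaml
```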

mdlearn.utils.resume_checkpoint(checkpoint_file: str | Path, model: torch.nn.Module, optimizers: dict[str, torch.optim.Optimizer], scheduler: torch.optim.lr_scheduler._LRScheduler | None = None) → int

Modify model, optimizers, and scheduler with the values stored in the torch .pt file checkpoint_file to resume from a previous training checkpoint.

Parameters:
  • checkpoint_file (PathLike) – Path to checkpoint file to resume from.

  • model (torch.nn.Module) – Module to update the parameters of.

  • optimizers (Dict[str, torch.optim.Optimizer]) – Optimizers to update.

  • scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]) – Optional scheduler to update.

Returns:

int – The epoch at which the checkpoint was saved, plus one; that is, the training epoch to resume from.
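
Because the return value is already the next epoch, it can serve directly as the start of the training loop. The sketch assumes a checkpoint stored as a pickled dictionary with an epoch key, a hypothetical stand-in for the torch .pt file; restoring the model, optimizer, and scheduler state is only indicated in a comment, and resume_epoch is a hypothetical helper name.

```python
import pickle
import tempfile
from pathlib import Path


def resume_epoch(checkpoint_file: Path) -> int:
    """Return the epoch to resume from: the saved epoch plus one."""
    with open(checkpoint_file, "rb") as f:
        checkpoint = pickle.load(f)
    # With PyTorch, model.load_state_dict(...), optimizer.load_state_dict(...),
    # and scheduler.load_state_dict(...) would restore each component here.
    return checkpoint["epoch"] + 1


ckpt_path = Path(tempfile.mkdtemp()) / "checkpoint.pkl"
with open(ckpt_path, "wb") as f:
    pickle.dump({"epoch": 4}, f)  # pretend epoch 4 finished and was checkpointed

start_epoch = resume_epoch(ckpt_path)
num_epochs = 8
remaining_epochs = list(range(start_epoch, num_epochs))
print(start_epoch, remaining_epochs)  # 5 [5, 6, 7]
```

Starting the loop at the returned value avoids repeating the epoch that was already completed and saved.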