mdlearn.nn.models.ae.model

Classes

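AE(encoder, decoder)

Autoencoder base class module.

class mdlearn.nn.models.ae.model.AE(encoder: torch.nn.Module, decoder: torch.nn.Module)

Autoencoder base class module. It holds an encoder and a decoder (both torch.nn.Module instances) and exposes encoder and decoder forward passes, a reconstruction loss, and a parameter reset.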

__init__(encoder: torch.nn.Module, decoder: torch.nn.Module)

Parameters
  • encoder (torch.nn.Module) – The encoder module.

  • decoder (torch.nn.Module) – The decoder module.
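A minimal sketch of how the constructor is meant to be used, assuming only what the parameter list states: any two torch.nn.Module objects can serve as encoder and decoder, provided the decoder accepts the encoder's output and maps it back to the input shape. SketchAE and the layer sizes are hypothetical stand-ins, not mdlearn code.

```python
import torch
from torch import nn


class SketchAE(nn.Module):
    """Hypothetical stand-in mirroring the documented AE constructor."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        # Assigning modules as attributes registers them as submodules,
        # so their parameters appear in self.parameters().
        self.encoder = encoder
        self.decoder = decoder


input_dim, latent_dim = 16, 3  # hypothetical sizes

# The decoder must accept the encoder's output (latent_dim features)
# and map it back to the input dimension.
encoder = nn.Sequential(nn.Linear(input_dim, 8), nn.ReLU(), nn.Linear(8, latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 8), nn.ReLU(), nn.Linear(8, input_dim))

model = SketchAE(encoder, decoder)
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 339
```

Because both modules are registered as submodules, a single optimizer built from model.parameters() trains the encoder and the decoder together.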

decode(*args, **kwargs)

Decoder forward pass.

encode(*args, **kwargs)

Encoder forward pass.
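The *args, **kwargs signatures of decode and encode suggest that both methods hand their arguments straight to the wrapped decoder and encoder modules. The sketch below assumes exactly that pass-through behavior; SketchAE and the tensor sizes are hypothetical.

```python
import torch
from torch import nn


class SketchAE(nn.Module):
    """Hypothetical stand-in: encode/decode forward to the submodules."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, *args, **kwargs):
        # Assumed pass-through: encode accepts whatever the encoder accepts.
        return self.encoder(*args, **kwargs)

    def decode(self, *args, **kwargs):
        # Assumed pass-through to the decoder module.
        return self.decoder(*args, **kwargs)


torch.manual_seed(0)
model = SketchAE(nn.Linear(16, 3), nn.Linear(3, 16))

x = torch.randn(4, 16)     # batch of 4 hypothetical 16-dim inputs
z = model.encode(x)        # latent codes
recon_x = model.decode(z)  # reconstructions
print(z.shape, recon_x.shape)  # torch.Size([4, 3]) torch.Size([4, 16])
```

Keeping the two passes separate lets downstream code compute latent embeddings with encode alone, without running the decoder.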

recon_loss(x: torch.Tensor, recon_x: torch.Tensor) → torch.Tensor

Compute the reconstruction loss between x and recon_x.

Parameters
  • x (torch.Tensor) – The input data.

  • recon_x (torch.Tensor) – The reconstruction of the input data x.

Returns

torch.Tensor – The reconstruction loss between x and recon_x.
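This page does not say which loss recon_loss computes. Purely as an illustration, the sketch below assumes a mean squared error averaged over all elements and uses it in a short training loop; the MSE choice, SketchAE, the optimizer settings, and the random batch are all hypothetical and may differ from what mdlearn does.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SketchAE(nn.Module):
    """Hypothetical stand-in; the MSE loss below is an assumption."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def recon_loss(self, x: torch.Tensor, recon_x: torch.Tensor) -> torch.Tensor:
        # Assumed loss: mean of (recon_x - x)**2 over every element.
        return F.mse_loss(recon_x, x)


torch.manual_seed(0)
model = SketchAE(nn.Linear(16, 3), nn.Linear(3, 16))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.randn(32, 16)  # hypothetical training batch

losses = []
for _ in range(200):
    optimizer.zero_grad()
    recon_x = model.decoder(model.encoder(x))
    loss = model.recon_loss(x, recon_x)  # scalar tensor
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first={losses[0]:.3f} last={losses[-1]:.3f}")
```

Note the argument order: the documented signature takes the input x first and the reconstruction recon_x second, while torch.nn.functional losses take the prediction first, so the sketch swaps them when delegating.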

reset_parameters()

Reset encoder and decoder parameters.
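The page does not say how the reset is carried out. One common PyTorch pattern, assumed here rather than taken from mdlearn, is to visit every submodule of the encoder and decoder and call its own reset_parameters() where one exists. reset_module and SketchAE are hypothetical names.

```python
import torch
from torch import nn


def reset_module(module: nn.Module) -> None:
    """Hypothetical helper: re-initialize every submodule that knows how."""
    # module.modules() yields the module itself and all nested submodules.
    for child in module.modules():
        # Layers such as nn.Linear define reset_parameters();
        # containers such as nn.Sequential and activations do not.
        if hasattr(child, "reset_parameters"):
            child.reset_parameters()


class SketchAE(nn.Module):
    """Hypothetical stand-in with an assumed reset strategy."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reset_parameters(self) -> None:
        # Only the encoder and decoder are visited, never self,
        # so this method does not call itself recursively.
        reset_module(self.encoder)
        reset_module(self.decoder)


model = SketchAE(
    nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 3)),
    nn.Linear(3, 16),
)

# Overwrite every parameter with zeros, then draw a fresh initialization.
with torch.no_grad():
    for p in model.parameters():
        p.zero_()

model.reset_parameters()
all_zero = all(torch.count_nonzero(p).item() == 0 for p in model.parameters())
print(all_zero)  # False
```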