mdlearn.nn.models.vae.model

Classes

VAE(*args, **kwargs)

Variational autoencoder base class module.

class mdlearn.nn.models.vae.model.VAE(*args: Any, **kwargs: Any)

Variational autoencoder base class module. Inherits from mdlearn.nn.models.ae.AE.

__init__(encoder, decoder)
Parameters
  • encoder (torch.nn.Module) – The encoder module.

  • decoder (torch.nn.Module) – The decoder module.

encode(*args, **kwargs) → torch.Tensor

Encoder forward pass and reparameterization of mu and logstd.

Parameters
  • *args – Variable length encoder argument list.

  • **kwargs – Arbitrary encoder keyword arguments.

Returns

torch.Tensor – The encoded \(z\)-latent batch tensor.

Notes

Clamps logstd using a max logstd of 10.
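The clamp described in the note can be sketched as follows. This is a minimal illustration, not the library's source: the constant name MAX_LOGSTD and the helper clamp_logstd are hypothetical, and only the cap value of 10 comes from the note above. Capping logstd from above bounds exp(logstd) at e^10 (about 2.2e4), while values below the cap pass through unchanged.

```python
import torch

# Cap value taken from the Notes; the constant name is a hypothetical choice.
MAX_LOGSTD = 10.0


def clamp_logstd(logstd: torch.Tensor) -> torch.Tensor:
    """Cap logstd from above; entries at or below the cap are left unchanged."""
    return logstd.clamp(max=MAX_LOGSTD)


logstd = torch.tensor([-1.0, 5.0, 12.0])
clamped = clamp_logstd(logstd)  # only the last entry exceeds the cap
```

Only an upper bound is applied in this sketch, so very negative logstd values are not altered.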

kld_loss(mu: Optional[torch.Tensor] = None, logstd: Optional[torch.Tensor] = None) → torch.Tensor

Computes the KL divergence (KLD) loss, either for the passed arguments mu and logstd or, if they are omitted, for the latent variables stored from the last encoding.

Parameters
  • mu (torch.Tensor, optional) – The latent mean \(\mu\). If set to None, uses the last computation of \(\mu\).

  • logstd (torch.Tensor, optional) – The latent log standard deviation \(\log\sigma\). If set to None, uses the last computation of \(\log\sigma\).

Returns

torch.Tensor – KL divergence loss given mu and logstd.

Notes

Clamps logstd using a max logstd of 10.
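For reference, the closed-form KL divergence between a diagonal Gaussian \(N(\mu, \sigma^2)\) and a standard normal prior can be sketched as below. The function name kld_loss_sketch is hypothetical, and the reduction (sum over latent dimensions, mean over the batch) is an assumption rather than something this page states. The fallback to values stored from the last encoding is omitted.

```python
import torch

# Cap value taken from the Notes; the constant name is a hypothetical choice.
MAX_LOGSTD = 10.0


def kld_loss_sketch(mu: torch.Tensor, logstd: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, sigma^2) || N(0, I)) for batches of shape (batch, latent_dim)."""
    logstd = logstd.clamp(max=MAX_LOGSTD)
    var = torch.exp(2 * logstd)  # sigma^2 = exp(2 * log sigma)
    # Per-sample KL, summed over latent dimensions (assumed reduction).
    per_sample = -0.5 * torch.sum(1 + 2 * logstd - mu**2 - var, dim=1)
    # Averaged over the batch (assumed reduction).
    return per_sample.mean()


mu = torch.ones(2, 3)
logstd = torch.zeros(2, 3)
loss = kld_loss_sketch(mu, logstd)  # each dimension contributes 0.5, so 1.5 here
```

With mu = 0 and logstd = 0 the encoded distribution equals the prior and the loss is zero.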

reparametrize(mu: torch.Tensor, logstd: torch.Tensor) → torch.Tensor

Reparameterization trick for mu and logstd.

Parameters
  • mu (torch.Tensor) – The latent mean \(\mu\) (first encoder output).

  • logstd (torch.Tensor) – The latent log standard deviation \(\log\sigma\) (second encoder output).

Returns

torch.Tensor – If training, return the reparametrized output. Otherwise, return mu.
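The behavior described above can be sketched as a standalone function. This is an illustrative sketch under stated assumptions: reparametrize_sketch is a hypothetical name, and its explicit training argument stands in for the module's training mode, which a torch.nn.Module would normally track itself.

```python
import torch


def reparametrize_sketch(
    mu: torch.Tensor, logstd: torch.Tensor, training: bool
) -> torch.Tensor:
    """Sample z = mu + eps * sigma when training; return mu otherwise."""
    if training:
        eps = torch.randn_like(logstd)  # eps ~ N(0, I), same shape as logstd
        return mu + eps * torch.exp(logstd)  # sigma = exp(log sigma)
    return mu


mu = torch.zeros(4, 2)
logstd = torch.full((4, 2), -1.0)
z_train = reparametrize_sketch(mu, logstd, training=True)   # stochastic sample
z_eval = reparametrize_sketch(mu, logstd, training=False)   # deterministic mean
```

Drawing the noise separately and scaling it by \(\sigma\) keeps the sample differentiable with respect to mu and logstd, which is what lets gradients flow through the encoder during training.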