mdlearn.nn.models.vae.model
Classes
VAE – Variational autoencoder base class module.
- class mdlearn.nn.models.vae.model.VAE(*args: Any, **kwargs: Any)
Variational autoencoder base class module. Inherits from mdlearn.nn.models.ae.AE.
- __init__(encoder, decoder)
- Parameters
encoder (torch.nn.Module) – The encoder module.
decoder (torch.nn.Module) – The decoder module.
- encode(*args, **kwargs) → torch.Tensor
Encoder forward pass and reparameterization of mu and logstd.
- Parameters
*args – Variable length encoder argument list.
**kwargs – Arbitrary encoder keyword arguments.
- Returns
torch.Tensor – The encoded \(z\)-latent batch tensor.
Notes
Clamps logstd using a max logstd of 10.
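The clamping described in the note can be illustrated with a minimal sketch. It uses plain Python floats in place of tensors so that it runs without PyTorch. The names MAX_LOGSTD and clamp_logstd are hypothetical; only the cap of 10 comes from the note, and the sketch assumes that only an upper bound is applied.

```python
import math

MAX_LOGSTD = 10.0  # cap taken from the note; the constant name is an assumption


def clamp_logstd(logstd):
    """Cap each log standard deviation from above; no lower bound is applied."""
    return [min(v, MAX_LOGSTD) for v in logstd]


logstd = [-2.0, 0.5, 37.0]
clamped = clamp_logstd(logstd)

# After clamping, no standard deviation can exceed exp(10).
std = [math.exp(v) for v in clamped]
```

Capping the log standard deviation keeps exp(logstd) finite, which guards against overflow when the encoder emits very large values early in training.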
- kld_loss(mu: Optional[torch.Tensor] = None, logstd: Optional[torch.Tensor] = None) → torch.Tensor
Computes the KLD loss, either for the passed arguments mu and logstd, or based on the latent variables from the last encoding.
- Parameters
mu (torch.Tensor, optional) – The latent space for \(\mu\). If set to None, uses the last computation of \(\mu\).
logstd (torch.Tensor, optional) – The latent space for \(\log\sigma\). If set to None, uses the last computation of \(\log\sigma\).
- Returns
torch.Tensor – KL divergence loss given mu and logstd.
Notes
Clamps logstd using a max logstd of 10.
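As a rough illustration of what a KLD loss of this kind computes, the sketch below evaluates the standard closed-form KL divergence between a diagonal Gaussian \(\mathcal{N}(\mu, \sigma^2)\) and a standard normal prior. It assumes the terms are summed over latent dimensions and averaged over the batch, which is not confirmed by the text above. The function is a plain-Python stand-in operating on nested lists and is not the library's implementation.

```python
import math


def kld_loss(mu, logstd, max_logstd=10.0):
    """KL(N(mu, sigma^2) || N(0, 1)) for a batch of latent vectors.

    mu and logstd are lists of rows (one row per sample), each row a list of
    floats (one per latent dimension). The per-dimension terms are summed
    within a row and the row totals are averaged over the batch (assumed
    reduction).
    """
    total = 0.0
    for mu_row, logstd_row in zip(mu, logstd):
        for m, s in zip(mu_row, logstd_row):
            s = min(s, max_logstd)  # clamp as described in the note
            # -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2), with log(sigma^2) = 2 * s
            total += -0.5 * (1.0 + 2.0 * s - m * m - math.exp(2.0 * s))
    return total / len(mu)


# A posterior equal to the prior has zero divergence.
zero = kld_loss([[0.0, 0.0]], [[0.0, 0.0]])

# Shifting one mean by 1 adds 0.5 * mu^2 = 0.5 to the divergence.
shifted = kld_loss([[1.0, 0.0]], [[0.0, 0.0]])
```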
- reparametrize(mu: torch.Tensor, logstd: torch.Tensor) → torch.Tensor
Reparameterization trick for mu and logstd.
- Parameters
mu (torch.Tensor) – First encoder output.
logstd (torch.Tensor) – Second encoder output.
- Returns
torch.Tensor – If training, return the reparametrized output. Otherwise, return mu.
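The training and evaluation behaviour described for reparametrize can be sketched as follows. This assumes the usual form of the trick, \(z = \mu + \epsilon \cdot \exp(\text{logstd})\) with \(\epsilon \sim \mathcal{N}(0, 1)\), which the text above does not spell out. The training flag is a hypothetical stand-in for the module's training mode, and plain Python lists replace tensors.

```python
import math
import random


def reparametrize(mu, logstd, training, rng=random):
    """Return mu + eps * exp(logstd) when training, otherwise a copy of mu."""
    if not training:
        return list(mu)
    # eps is drawn independently per latent dimension from N(0, 1).
    return [m + rng.gauss(0.0, 1.0) * math.exp(s) for m, s in zip(mu, logstd)]


mu = [0.0, 1.0]
logstd = [-1.0, -1.0]
rng = random.Random(0)  # seeded generator so the sample is reproducible

z_train = reparametrize(mu, logstd, training=True, rng=rng)   # stochastic sample
z_eval = reparametrize(mu, logstd, training=False, rng=rng)   # deterministic, equals mu
```

Writing the sample as a deterministic function of mu and logstd plus external noise is what lets gradients flow back to the encoder outputs during training.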