mdlearn.nn.models.wae.model
Classes
WAE: Wasserstein autoencoder base class module.
- class mdlearn.nn.models.wae.model.WAE(*args: Any, **kwargs: Any)
Wasserstein autoencoder base class module. Inherits from mdlearn.nn.models.vae.VAE.
- mmdrf_loss(z: torch.Tensor, sigma: float, kernel: str, rf_dim: int, rf_resample: bool) → torch.Tensor
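Computes the loss \(\|\mu_{real} - \mu_{fake}\|_H\), the distance in the kernel feature space \(H\) between the mean embeddings of the real and fake samples.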
- Parameters
z (torch.Tensor) – The \(z\)-latent vector.
sigma (float) – TODO
kernel (str) – The type of kernel function to use.
rf_dim (int) – Random features kernel dimension.
rf_resample (bool) – Whether or not to resample the random features.
- Returns
torch.Tensor – MMD RF loss.
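The quantity \(\|\mu_{real} - \mu_{fake}\|_H\) can be illustrated with a small standalone sketch of a random-features MMD estimate. This is not the mdlearn implementation. It assumes a Gaussian kernel and treats sigma as that kernel's bandwidth, which the documentation above does not confirm. The names rff_features, mean_embedding, and mmd_rf are hypothetical, and plain Python lists stand in for tensors so the example has no dependencies.

```python
import math
import random


def rff_features(x, omegas, phases):
    """Map a vector x to random Fourier features sqrt(2/D) * cos(w.x + b)."""
    scale = math.sqrt(2.0 / len(omegas))
    return [
        scale * math.cos(sum(w_i * x_i for w_i, x_i in zip(w, x)) + b)
        for w, b in zip(omegas, phases)
    ]


def mean_embedding(samples, omegas, phases):
    """Average the feature vectors of all samples (an estimate of mu)."""
    feats = [rff_features(x, omegas, phases) for x in samples]
    return [sum(col) / len(feats) for col in zip(*feats)]


def mmd_rf(real, fake, sigma=1.0, rf_dim=256, seed=0):
    """Estimate ||mu_real - mu_fake||_H with rf_dim random features.

    Assumption: Gaussian kernel exp(-||x - y||^2 / (2 * sigma^2)), whose
    random-feature frequencies are drawn from N(0, 1 / sigma^2).
    """
    rng = random.Random(seed)
    dim = len(real[0])
    omegas = [[rng.gauss(0.0, 1.0 / sigma) for _ in range(dim)]
              for _ in range(rf_dim)]
    phases = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(rf_dim)]
    mu_real = mean_embedding(real, omegas, phases)
    mu_fake = mean_embedding(fake, omegas, phases)
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(mu_real, mu_fake)))


# Toy data: two draws from the same 2-D Gaussian, and one shifted draw.
rng = random.Random(1)
prior = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(200)]
close = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(200)]
shifted = [[rng.gauss(3.0, 1.0) for _ in range(2)] for _ in range(200)]

loss_close = mmd_rf(prior, close)      # same distribution: small loss
loss_shifted = mmd_rf(prior, shifted)  # shifted distribution: larger loss
```

In this sketch the seed argument fixes the random features between calls. Drawing a fresh set on each call would roughly correspond to what an rf_resample-style option suggests, though how mdlearn handles resampling is not stated above.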