mdlearn.nn.models.wae.model

Classes

WAE(*args, **kwargs) – Wasserstein autoencoder base class module.

class mdlearn.nn.models.wae.model.WAE(*args: Any, **kwargs: Any)

Wasserstein autoencoder base class module. Inherits from mdlearn.nn.models.vae.VAE.

mmdrf_loss(z: torch.Tensor, sigma: float, kernel: str, rf_dim: int, rf_resample: bool) → torch.Tensor

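Computes the maximum mean discrepancy (MMD) loss \(\|\mu_{real} - \mu_{fake}\|_{\mathcal{H}}\), the distance in the kernel's reproducing kernel Hilbert space \(\mathcal{H}\) between the mean embeddings of the real and fake samples, approximated with random features.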
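For reference, the standard definitions behind this loss are written out below. They describe the usual random Fourier features estimate of MMD; that mmdrf_loss uses exactly this estimator (including the \(\sqrt{2/D}\) scaling and the uniform phases \(b\)) is an assumption, not something stated in this reference. Here \(P\) and \(Q\) stand for the real and fake distributions and \(D\) corresponds to rf_dim.

```latex
\mathrm{MMD}(P, Q) = \|\mu_P - \mu_Q\|_{\mathcal{H}},
\qquad \mu_P = \mathbb{E}_{x \sim P}\,[k(x, \cdot)]

% Random Fourier features for a shift-invariant kernel k:
k(x, y) \approx \phi(x)^\top \phi(y),
\qquad \phi(x) = \sqrt{2/D}\,\cos(Wx + b),
\qquad W_{j\cdot} \sim p_k(\omega),\; b_j \sim \mathcal{U}[0, 2\pi]

% where p_k is the spectral density of k. Substituting the empirical means:
\|\mu_P - \mu_Q\|_{\mathcal{H}} \approx
\Big\| \frac{1}{n}\sum_{i=1}^{n} \phi(x_i) - \frac{1}{m}\sum_{j=1}^{m} \phi(y_j) \Big\|_2
```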

Parameters
  • z (torch.Tensor) – The \(z\)-latent vector.

  • sigma (float) – TODO

  • kernel (str) – The type of kernel function to use.

  • rf_dim (int) – Random features kernel dimension.

  • rf_resample (bool) – Whether or not to resample the random features.

Returns

torch.Tensor – MMD RF loss.
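A minimal, dependency-free sketch of an MMD estimate with random Fourier features, for intuition only. It uses plain Python lists instead of torch.Tensor, and it is not mdlearn's implementation: the class MMDRandomFeatures, the helper functions, the kernel names "gaussian" and "laplace", the explicit prior_samples argument (treated as the "real" set, with z as the "fake" set), and the way rf_resample redraws cached features are all hypothetical. Treating sigma as the kernel bandwidth is likewise an assumption.

```python
import math
import random


def sample_random_features(dim, rf_dim, sigma, kernel, rng):
    """Draw frequencies W (rf_dim x dim) and phases b for random Fourier features."""
    if kernel == "gaussian":
        # exp(-||d||^2 / (2 sigma^2)) has spectral density N(0, sigma^-2 I).
        def draw():
            return rng.gauss(0.0, 1.0 / sigma)
    elif kernel == "laplace":
        # exp(-||d||_1 / sigma) has spectral density prod_k Cauchy(0, 1/sigma).
        def draw():
            return math.tan(math.pi * (rng.random() - 0.5)) / sigma
    else:
        raise ValueError(f"unsupported kernel: {kernel!r}")
    W = [[draw() for _ in range(dim)] for _ in range(rf_dim)]
    b = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(rf_dim)]
    return W, b


def mean_embedding(samples, W, b):
    """Average the feature map phi(x) = sqrt(2/D) cos(Wx + b) over a batch."""
    D = len(b)
    scale = math.sqrt(2.0 / D)
    mu = [0.0] * D
    for x in samples:
        for j in range(D):
            proj = sum(w_k * x_k for w_k, x_k in zip(W[j], x)) + b[j]
            mu[j] += scale * math.cos(proj)
    return [m / len(samples) for m in mu]


class MMDRandomFeatures:
    """Hypothetical helper: MMD ~= ||mu_real - mu_fake||_2 in random-feature space."""

    def __init__(self, dim, sigma, kernel="gaussian", rf_dim=500, seed=0):
        self.dim, self.sigma, self.kernel, self.rf_dim = dim, sigma, kernel, rf_dim
        self.rng = random.Random(seed)
        self.W, self.b = sample_random_features(dim, rf_dim, sigma, kernel, self.rng)

    def __call__(self, z, prior_samples, rf_resample=False):
        if rf_resample:
            # Draw fresh frequencies and phases instead of reusing the cached ones.
            self.W, self.b = sample_random_features(
                self.dim, self.rf_dim, self.sigma, self.kernel, self.rng
            )
        mu_fake = mean_embedding(z, self.W, self.b)
        mu_real = mean_embedding(prior_samples, self.W, self.b)
        return math.sqrt(sum((r - f) ** 2 for r, f in zip(mu_real, mu_fake)))


# Usage: compare latents drawn from the prior against a shifted batch.
data_rng = random.Random(1)
prior = [[data_rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(200)]
same = [[data_rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(200)]
shifted = [[data_rng.gauss(3.0, 1.0) for _ in range(2)] for _ in range(200)]
mmd = MMDRandomFeatures(dim=2, sigma=1.0, rf_dim=200)
print(mmd(same, prior), mmd(shifted, prior))
```

Because this estimate is the norm of a difference of feature means, it is exactly 0 for identical batches and grows as the two batches move apart; a larger rf_dim reduces the approximation noise at the cost of more computation.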