mdlearn.nn.models.aae.model

Classes

AAE(*args, **kwargs)

ChamferLoss(*args, **kwargs)

class mdlearn.nn.models.aae.model.AAE(*args: Any, **kwargs: Any)
discriminate(*args, **kwargs) → torch.Tensor

Discriminator forward pass.

Parameters
  • *args – Variable length discriminator argument list.

  • **kwargs – Arbitrary discriminator keyword arguments.

Returns

torch.Tensor – The discriminator output.
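In the standard adversarial autoencoder formulation, the discriminator learns to tell encoder outputs (latent codes) apart from samples drawn from a chosen prior. `discriminate` exposes that forward pass. The sketch below is a minimal, hypothetical stand-in, not mdlearn's implementation. It assumes a small MLP discriminator over latent vectors and a standard normal prior. The class names `ToyAAE` and `LatentDiscriminator`, the layer sizes, and the binary cross-entropy objective are all illustrative choices.

```python
import torch
from torch import nn


class LatentDiscriminator(nn.Module):
    """Hypothetical MLP that scores latent vectors (one logit per sample)."""

    def __init__(self, latent_dim: int = 8, hidden_dim: int = 32) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


class ToyAAE(nn.Module):
    """Hypothetical AAE skeleton: only the discriminator path is shown."""

    def __init__(self, latent_dim: int = 8) -> None:
        super().__init__()
        self.discriminator = LatentDiscriminator(latent_dim)

    def discriminate(self, *args, **kwargs) -> torch.Tensor:
        # Forward all arguments unchanged to the discriminator module.
        return self.discriminator(*args, **kwargs)


torch.manual_seed(0)
model = ToyAAE(latent_dim=8)
z_fake = torch.randn(4, 8) * 3.0 + 1.0  # stand-in for encoder output
z_real = torch.randn(4, 8)  # samples from the assumed N(0, I) prior

# Assumed objective: label prior samples 1 and encoder codes 0.
bce = nn.BCEWithLogitsLoss()
d_loss = bce(model.discriminate(z_real), torch.ones(4, 1)) + bce(
    model.discriminate(z_fake), torch.zeros(4, 1)
)
print(model.discriminate(z_real).shape)  # torch.Size([4, 1])
```

Plain binary cross-entropy keeps the example short. Other adversarial objectives, such as a Wasserstein critic loss, plug into the same `discriminate` call.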

reset_parameters() → None

Reset encoder, decoder and discriminator parameters.
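A common PyTorch pattern for this kind of reset is to walk every submodule and call its own `reset_parameters()` where one exists. Built-in layers such as `nn.Linear` define this method. The sketch below assumes that pattern. `ToyAutoencoder` and its layer sizes are hypothetical, and mdlearn may reset its components differently.

```python
import torch
from torch import nn


class ToyAutoencoder(nn.Module):
    """Hypothetical model with encoder, decoder and discriminator parts."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(16, 4)
        self.decoder = nn.Linear(4, 16)
        self.discriminator = nn.Sequential(
            nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)
        )

    def reset_parameters(self) -> None:
        # Re-initialize every submodule that knows how to reset itself;
        # modules without that method (e.g. nn.ReLU, nn.Sequential) are skipped.
        for part in (self.encoder, self.decoder, self.discriminator):
            for module in part.modules():
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()


torch.manual_seed(0)
model = ToyAutoencoder()
before = model.encoder.weight.detach().clone()
model.reset_parameters()
after = model.encoder.weight.detach().clone()
print(torch.equal(before, after))  # False: fresh random weights were drawn
```

The loop iterates over the three named parts rather than over `self.modules()`. This avoids calling the model's own `reset_parameters` recursively.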

class mdlearn.nn.models.aae.model.ChamferLoss(*args: Any, **kwargs: Any)
batch_pairwise_dist(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
forward(preds: torch.Tensor, gts: torch.Tensor) → torch.Tensor
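The two `ChamferLoss` methods carry no description on this page. The sketch below assumes they follow the usual Chamfer-distance construction for point clouds, which this page does not confirm. Under that assumption, `batch_pairwise_dist` returns all pairwise squared distances between two point sets. `forward` then adds, in both directions, each point's squared distance to its nearest neighbor in the other set:

d(X, Y) = sum over x in X of min over y in Y of ||x - y||^2 + sum over y in Y of min over x in X of ||x - y||^2

The sketch writes both methods as free functions. It assumes inputs shaped (batch, num_points, dims); mdlearn's actual tensor layout may differ.

```python
import torch


def batch_pairwise_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distance between every point of x and every point of y.

    x: (B, N, D), y: (B, M, D)  ->  (B, N, M)
    """
    xx = (x * x).sum(dim=-1, keepdim=True)  # (B, N, 1)
    yy = (y * y).sum(dim=-1).unsqueeze(1)  # (B, 1, M)
    xy = torch.bmm(x, y.transpose(1, 2))  # (B, N, M)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>; clamp tiny negatives from round-off.
    return (xx + yy - 2.0 * xy).clamp(min=0.0)


def chamfer_loss(preds: torch.Tensor, gts: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance, summed over all points and batch entries."""
    dist = batch_pairwise_dist(gts, preds)  # (B, N_gt, N_pred)
    # For each predicted point: squared distance to its nearest ground-truth point.
    pred_to_gt = dist.min(dim=1).values.sum()
    # For each ground-truth point: squared distance to its nearest predicted point.
    gt_to_pred = dist.min(dim=2).values.sum()
    return pred_to_gt + gt_to_pred


gts = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])  # (1, 2, 3)
preds = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]])  # (1, 3, 3)
print(chamfer_loss(preds, gts).item())  # 1.0
```

In the usage example, every ground-truth point has an exact match among the predictions. Only the extra predicted point (1, 1, 0) contributes, at squared distance 1 from (1, 0, 0). Because the loss is symmetric, it also penalizes predictions that cover only part of the target cloud.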