mdlearn.nn.models.aae.point_3d_aae

Adversarial Autoencoder for 3D point cloud data (3dAAE)

Classes

AAE3d(*args, **kwargs)

class mdlearn.nn.models.aae.point_3d_aae.AAE3d(*args: Any, **kwargs: Any)
__init__(num_points: int, num_features: int = 0, latent_dim: int = 20, encoder_bias: bool = True, encoder_relu_slope: float = 0.0, encoder_filters: List[int] = [64, 128, 256, 256, 512], encoder_kernels: List[int] = [5, 5, 3, 1, 1], decoder_bias: bool = True, decoder_relu_slope: float = 0.0, decoder_affine_widths: List[int] = [64, 128, 512, 1024], discriminator_bias: bool = True, discriminator_relu_slope: float = 0.0, discriminator_affine_widths: List[int] = [512, 128, 64])

Adversarial Autoencoder module for point cloud data from the “Adversarial Autoencoders for Compact Representations of 3D Point Clouds” paper and adapted to work on atomic coordinate data in the “AI-Driven Multiscale Simulations Illuminate Mechanisms of SARS-CoV-2 Spike Dynamics” paper. Inherits from mdlearn.nn.models.aae.AAE.

Parameters
  • num_points (int) – Number of input points in point cloud.

  • num_features (int, optional) – Number of scalar features per point in addition to 3D coordinates, by default 0

  • latent_dim (int, optional) – Latent dimension of the encoder, by default 20

  • encoder_bias (bool, optional) – Use a bias term in the encoder Conv1d layers, by default True.

  • encoder_relu_slope (float, optional) – If greater than 0.0, use a LeakyReLU activation in the encoder with negative_slope set to encoder_relu_slope (otherwise ReLU), by default 0.0

  • encoder_filters (List[int], optional) – Encoder Conv1d filter sizes, by default [64, 128, 256, 256, 512]

  • encoder_kernels (List[int], optional) – Encoder Conv1d kernel sizes, by default [5, 5, 3, 1, 1]

  • decoder_bias (bool, optional) – Use a bias term in the decoder Linear layers, by default True

  • decoder_relu_slope (float, optional) – If greater than 0.0, use a LeakyReLU activation in the decoder with negative_slope set to decoder_relu_slope (otherwise ReLU), by default 0.0

  • decoder_affine_widths (List[int], optional) – Widths (in_features) of the decoder Linear layers, by default [64, 128, 512, 1024]

  • discriminator_bias (bool, optional) – Use a bias term in the discriminator Linear layers, by default True.

  • discriminator_relu_slope (float, optional) – If greater than 0.0, use a LeakyReLU activation in the discriminator with negative_slope set to discriminator_relu_slope (otherwise ReLU), by default 0.0

  • discriminator_affine_widths (List[int], optional) – Widths (in_features) of the discriminator Linear layers, by default [512, 128, 64]
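All three *_relu_slope parameters follow the same rule. The following is a minimal scalar sketch of that rule in plain Python rather than PyTorch; in PyTorch terms it corresponds to choosing torch.nn.LeakyReLU(negative_slope=relu_slope) when the slope is positive and torch.nn.ReLU() otherwise. The function name relu_or_leaky is hypothetical and not part of mdlearn.

```python
def relu_or_leaky(x: float, relu_slope: float = 0.0) -> float:
    """Scalar version of the activation rule shared by the *_relu_slope parameters.

    relu_slope > 0.0 -> LeakyReLU with negative_slope = relu_slope
    otherwise        -> plain ReLU
    """
    if x >= 0.0:
        return x
    # Negative inputs are scaled by the slope (LeakyReLU) or clamped to 0 (ReLU).
    return relu_slope * x if relu_slope > 0.0 else 0.0


print(relu_or_leaky(-2.0))       # ReLU: 0.0
print(relu_or_leaky(-2.0, 0.2))  # LeakyReLU: -0.4
```

With the default slope of 0.0 the two branches coincide mathematically, since LeakyReLU with a zero slope is ReLU.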

critic_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) → torch.Tensor

Classification loss (critic) function.

Parameters
  • real_logits (torch.Tensor) – Discriminator output logits from prior distribution.

  • fake_logits (torch.Tensor) – Discriminator output logits from encoded latent vectors.

Returns

torch.Tensor – Classification loss, i.e., the difference between the means of the two sets of logits.
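A minimal plain-Python sketch of a critic loss of this kind, assuming the Wasserstein-GAN sign convention (mean of the fake logits minus mean of the real logits). The text above states only that the loss is a difference of logit means, so the ordering here is an assumption. Lists of floats stand in for the torch.Tensor arguments.

```python
from statistics import fmean
from typing import Sequence


def critic_loss(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    """Wasserstein-style critic loss: mean(fake) - mean(real) (assumed convention).

    real_logits: discriminator outputs for samples drawn from the prior.
    fake_logits: discriminator outputs for encoded latent vectors.
    Minimizing this pushes the critic to score prior samples higher than encodings.
    """
    return fmean(fake_logits) - fmean(real_logits)


print(critic_loss([2.0, 4.0], [1.0, 1.0]))  # 1.0 - 3.0 = -2.0
```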

decoder_loss(fake_logit: torch.Tensor) → torch.Tensor

Decoder/Generator loss.

Parameters

fake_logit (torch.Tensor) – Discriminator output logits from encoded latent vectors.

Returns

torch.Tensor – Negative mean of the fake logits.
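Because the return value is fully specified above, a sketch is short. Plain lists stand in for torch.Tensor here; with a tensor argument the same quantity is -fake_logit.mean().

```python
from statistics import fmean
from typing import Sequence


def decoder_loss(fake_logit: Sequence[float]) -> float:
    """Negative mean of the discriminator's logits on encoded latent vectors.

    Minimizing this raises the critic's mean score on the encoded vectors.
    """
    return -fmean(fake_logit)


print(decoder_loss([1.0, 3.0]))  # -2.0
```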

forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Forward pass of encoder and decoder.

Parameters

x (torch.Tensor) – Input point cloud data.

Returns

Tuple[torch.Tensor, torch.Tensor] – The latent vector z and the reconstruction recon_x.
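The return order (z first, then recon_x) matters when unpacking. A minimal sketch of the encode-then-decode composition, using hypothetical toy encoder and decoder callables on plain lists; the real module operates on torch.Tensor point clouds, and its internal layers are not modeled here.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def forward(
    x: Vector,
    encoder: Callable[[Vector], Vector],
    decoder: Callable[[Vector], Vector],
) -> Tuple[Vector, Vector]:
    """Encode x to the latent vector z, decode z back, and return (z, recon_x)."""
    z = encoder(x)
    recon_x = decoder(z)
    return z, recon_x


# Hypothetical stand-ins: a 2-to-1 "encoder" (mean) and a 1-to-2 "decoder" (copy).
z, recon_x = forward([1.0, 3.0], lambda v: [sum(v) / len(v)], lambda w: [w[0], w[0]])
print(z, recon_x)  # [2.0] [2.0, 2.0]
```

With the actual module, calling the instance invokes forward, so the same unpacking applies: z, recon_x = model(x).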

gp_loss(noise: torch.Tensor, z: torch.Tensor) → torch.Tensor

Gradient penalty loss function.

Parameters
  • noise (torch.Tensor) – Random noise sampled from the prior distribution.

  • z (torch.Tensor) – Encoded latent vectors.

Returns

torch.Tensor – The gradient penalty loss.
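The exact penalty is not spelled out above. The sketch below assumes the widely used WGAN-GP formulation: for each sample, draw alpha uniformly from [0, 1), interpolate between the prior sample and the encoding, and penalize the squared deviation of the critic's gradient norm at that point from 1, averaged over the batch. Any penalty weight is omitted. To stay dependency-free it uses central finite differences in place of automatic differentiation (a PyTorch version would typically use torch.autograd.grad with create_graph=True), and the critic argument is a hypothetical stand-in for the discriminator.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def numerical_grad(f: Callable[[Vector], float], v: Vector, h: float = 1e-5) -> Vector:
    """Central-difference gradient of a scalar function f at point v."""
    grad = []
    for i in range(len(v)):
        up = list(v)
        up[i] += h
        down = list(v)
        down[i] -= h
        grad.append((f(up) - f(down)) / (2.0 * h))
    return grad


def gp_loss(
    noise: Sequence[Vector],
    z: Sequence[Vector],
    critic: Callable[[Vector], float],
    rng: random.Random,
) -> float:
    """WGAN-GP style penalty (assumed): batch mean of (||grad D(z_hat)|| - 1)^2."""
    penalties = []
    for n, e in zip(noise, z):
        alpha = rng.random()  # one interpolation weight per sample
        # Random point on the segment from the prior sample to the encoding.
        z_hat = [ni + alpha * (ei - ni) for ni, ei in zip(n, e)]
        grad_norm = math.sqrt(sum(g * g for g in numerical_grad(critic, z_hat)))
        penalties.append((grad_norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


# A linear critic D(v) = 3*v0 + 4*v1 has gradient norm 5 everywhere,
# so every sample contributes (5 - 1)^2 = 16.
linear_critic = lambda v: 3.0 * v[0] + 4.0 * v[1]
noise = [[0.0, 0.0], [1.0, 1.0]]
encoded = [[1.0, 2.0], [0.5, -1.0]]
print(gp_loss(noise, encoded, linear_critic, random.Random(0)))  # approximately 16.0
```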

recon_loss(x: torch.Tensor, recon_x: torch.Tensor) → torch.Tensor

Reconstruction loss using ChamferLoss.

Parameters
  • x (torch.Tensor) – The original input tensor.

  • recon_x (torch.Tensor) – The reconstructed output tensor.

Returns

torch.Tensor – Reconstruction loss measured by Chamfer distance.
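A minimal plain-Python sketch of a symmetric Chamfer distance between two point clouds: for each point in one cloud, take the squared distance to its nearest neighbour in the other cloud, and sum these in both directions. The reduction actually used by the ChamferLoss above (sum or mean, squared or unsquared distances, batch handling) is not stated here, so this convention is an assumption. A tensor version would typically compute a batched pairwise distance matrix instead of the O(N·M) Python loops.

```python
from typing import Sequence, Tuple

Point = Tuple[float, float, float]


def squared_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two 3D points."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def chamfer_distance(x: Sequence[Point], recon_x: Sequence[Point]) -> float:
    """Symmetric Chamfer distance (summed squared nearest-neighbour distances)."""
    # Each original point to its closest reconstructed point.
    x_to_recon = sum(min(squared_dist(p, q) for q in recon_x) for p in x)
    # Each reconstructed point to its closest original point.
    recon_to_x = sum(min(squared_dist(q, p) for p in x) for q in recon_x)
    return x_to_recon + recon_to_x


x = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
recon_x = [(0.0, 0.0, 0.0)]
print(chamfer_distance(x, recon_x))  # (0 + 1) + 0 = 1.0
```

Because the nearest neighbour is taken independently in each direction, the two clouds need not have the same number of points.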