mdlearn.nn.models.aae.point_3d_aae
Adversarial Autoencoder for 3D point cloud data (3dAAE)
Classes
- class mdlearn.nn.models.aae.point_3d_aae.AAE3d(*args: Any, **kwargs: Any)
- __init__(num_points: int, num_features: int = 0, latent_dim: int = 20, encoder_bias: bool = True, encoder_relu_slope: float = 0.0, encoder_filters: List[int] = [64, 128, 256, 256, 512], encoder_kernels: List[int] = [5, 5, 3, 1, 1], decoder_bias: bool = True, decoder_relu_slope: float = 0.0, decoder_affine_widths: List[int] = [64, 128, 512, 1024], discriminator_bias: bool = True, discriminator_relu_slope: float = 0.0, discriminator_affine_widths: List[int] = [512, 128, 64])
Adversarial Autoencoder module for point cloud data from the “Adversarial Autoencoders for Compact Representations of 3D Point Clouds” paper, adapted to work on atomic coordinate data in the “AI-Driven Multiscale Simulations Illuminate Mechanisms of SARS-CoV-2 Spike Dynamics” paper. Inherits from mdlearn.nn.models.aae.AAE.
- Parameters
num_points (int) – Number of input points in point cloud.
num_features (int, optional) – Number of scalar features per point in addition to 3D coordinates, by default 0
latent_dim (int, optional) – Latent dimension of the encoder, by default 20
encoder_bias (bool, optional) – Use a bias term in the encoder Conv1d layers, by default True.
encoder_relu_slope (float, optional) – If greater than 0.0, use LeakyReLU activation in the encoder with negative_slope set to encoder_relu_slope, by default 0.0
encoder_filters (List[int], optional) – Encoder Conv1d filter sizes, by default [64, 128, 256, 256, 512]
encoder_kernels (List[int], optional) – Encoder Conv1d kernel sizes, by default [5, 5, 3, 1, 1]
decoder_bias (bool, optional) – Use a bias term in the decoder Linear layers, by default True
decoder_relu_slope (float, optional) – If greater than 0.0, use LeakyReLU activation in the decoder with negative_slope set to decoder_relu_slope, by default 0.0
decoder_affine_widths (List[int], optional) – in_features of the decoder Linear layers, by default [64, 128, 512, 1024]
discriminator_bias (bool, optional) – Use a bias term in the discriminator Linear layers, by default True.
discriminator_relu_slope (float, optional) – If greater than 0.0, use LeakyReLU activation in the discriminator with negative_slope set to discriminator_relu_slope, by default 0.0
discriminator_affine_widths (List[int], optional) – in_features of the discriminator Linear layers, by default [512, 128, 64]
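The encoder parameters above read as a layer recipe: one Conv1d per (filter, kernel) pair, each followed by an activation chosen by the slope. The sketch below is a hypothetical PointEncoder, not mdlearn's implementation. It assumes a channel-first input of shape (batch, 3 + num_features, num_points), no convolution padding, ReLU when the slope is 0.0, and a global max-pool over points followed by a Linear projection to latent_dim, as in PointNet-style encoders.

```python
from typing import Sequence

import torch
from torch import nn


class PointEncoder(nn.Module):
    """Hypothetical encoder sketch: Conv1d stack, max-pool over points, Linear to latent."""

    def __init__(
        self,
        num_features: int = 0,
        latent_dim: int = 20,
        bias: bool = True,
        relu_slope: float = 0.0,
        filters: Sequence[int] = (64, 128, 256, 256, 512),
        kernels: Sequence[int] = (5, 5, 3, 1, 1),
    ):
        super().__init__()
        layers = []
        in_channels = 3 + num_features  # xyz coordinates plus optional per-point scalars
        for out_channels, k in zip(filters, kernels):
            layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=k, bias=bias))
            # A positive slope switches to LeakyReLU(negative_slope=relu_slope)
            act = nn.LeakyReLU(negative_slope=relu_slope) if relu_slope > 0.0 else nn.ReLU()
            layers.append(act)
            in_channels = out_channels
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(in_channels, latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 3 + num_features, num_points)
        h = self.conv(x)
        h = h.max(dim=2).values  # global max-pool over the point dimension
        return self.fc(h)  # (batch, latent_dim)


enc = PointEncoder(latent_dim=20)
z = enc(torch.randn(2, 3, 100))
```

With the default kernel sizes and no padding, each input cloud in this sketch needs at least 11 points, because the three non-unit kernels shorten the point dimension by 4 + 4 + 2 = 10.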
- critic_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) → torch.Tensor
Classification loss (critic) function.
- Parameters
real_logits (torch.Tensor) – Discriminator output logits from prior distribution.
fake_logits (torch.Tensor) – Discriminator output logits from encoded latent vectors.
- Returns
torch.Tensor – Classification loss, i.e., the difference between the two logit means.
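A minimal sketch of a WGAN-style critic objective that matches the description "difference between logit means". The order of subtraction, mean(fake) - mean(real), is an assumption: under it, minimizing the loss pushes prior (real) logits up and encoded (fake) logits down.

```python
import torch


def critic_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    # Assumed sign convention: mean of fake logits minus mean of real logits
    return fake_logits.mean() - real_logits.mean()


real = torch.tensor([2.0, 4.0])  # logits for samples drawn from the prior
fake = torch.tensor([1.0, 1.0])  # logits for encoded latent vectors
loss = critic_loss(real, fake)
```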
- decoder_loss(fake_logit: torch.Tensor) → torch.Tensor
Decoder/Generator loss.
- Parameters
fake_logit (torch.Tensor) – Output of discriminator.
- Returns
torch.Tensor – Negative mean of the fake logits.
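The documented return value, the negative mean of the fake logits, can be written directly. The sketch below is a stand-alone function rather than the method itself, and it also inspects the gradient to show that minimizing this loss pushes every fake logit upward by the same amount.

```python
import torch


def decoder_loss(fake_logit: torch.Tensor) -> torch.Tensor:
    # Negative mean of the discriminator's logits for generated samples
    return -fake_logit.mean()


fake = torch.tensor([1.0, 3.0], requires_grad=True)
loss = decoder_loss(fake)
loss.backward()  # d(loss)/d(fake_i) = -1 / len(fake) for every element
```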
- forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Forward pass of encoder and decoder.
- Parameters
x (torch.Tensor) – Input point cloud data.
- Returns
Tuple[torch.Tensor, torch.Tensor] – The \(z\)-latent vector and the reconstruction recon_x.
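To make the (z, recon_x) return contract concrete, here is a hypothetical ToyPointAE that returns the same kind of tuple. It is an illustration only: flat Linear layers stand in for the Conv1d encoder and affine decoder, and the (batch, 3, num_points) input layout is an assumption.

```python
from typing import Tuple

import torch
from torch import nn


class ToyPointAE(nn.Module):
    """Hypothetical stand-in for a forward pass that returns (z, recon_x)."""

    def __init__(self, num_points: int, latent_dim: int = 20):
        super().__init__()
        self.num_points = num_points
        # Flatten (batch, 3, num_points) into one vector per sample, then project to z
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * num_points, latent_dim))
        self.decoder = nn.Linear(latent_dim, 3 * num_points)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        z = self.encoder(x)
        # Reshape the decoder output back into a point cloud
        recon_x = self.decoder(z).view(-1, 3, self.num_points)
        return z, recon_x


model = ToyPointAE(num_points=50, latent_dim=8)
z, recon_x = model(torch.randn(4, 3, 50))
```

Because the reconstruction has the same shape as the input, it can be passed straight to a point-cloud reconstruction loss together with x.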
- gp_loss(noise: torch.Tensor, z: torch.Tensor) → torch.Tensor
Gradient penalty loss function.
- Parameters
noise (torch.Tensor) – Random noise sampled from the prior distribution.
z (torch.Tensor) – Encoded latent vectors.
- Returns
torch.Tensor – The gradient penalty loss.
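The gradient penalty here is assumed to be the WGAN-GP formulation: sample points on the line segments between prior noise and encoded z, then penalize the discriminator's gradient norm at those points for deviating from 1. In the sketch below, passing the discriminator as an argument is a hypothetical signature change, since the documented method takes only noise and z.

```python
import torch
from torch import nn


def gp_loss(discriminator: nn.Module, noise: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # One random interpolation weight per sample, broadcast across latent dimensions
    alpha = torch.rand(z.shape[0], 1, device=z.device)
    interpolates = (noise + alpha * (z - noise)).requires_grad_(True)
    logits = discriminator(interpolates)
    grads = torch.autograd.grad(
        outputs=logits,
        inputs=interpolates,
        grad_outputs=torch.ones_like(logits),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    slopes = grads.norm(2, dim=1)  # per-sample gradient norm
    return ((slopes - 1.0) ** 2).mean()


# A linear discriminator has the same gradient (its weight vector) everywhere
disc = nn.Linear(4, 1)
with torch.no_grad():
    disc.weight.copy_(torch.tensor([[3.0, 4.0, 0.0, 0.0]]))  # gradient norm 5
penalty = gp_loss(disc, torch.randn(8, 4), torch.randn(8, 4))
```

For this linear discriminator every interpolate has gradient norm 5, so the penalty is (5 - 1)^2 = 16 regardless of the random samples.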
- recon_loss(x: torch.Tensor, recon_x: torch.Tensor) → torch.Tensor
Reconstruction loss using ChamferLoss.
- Parameters
x (torch.Tensor) – The original input tensor.
recon_x (torch.Tensor) – The reconstructed output tensor.
- Returns
torch.Tensor – Reconstruction loss measured by Chamfer distance.
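The Chamfer distance sums, in both directions, the squared distance from each point to its nearest neighbour in the other cloud. The sketch below is one common variant (summed over points, averaged over the batch) written in plain PyTorch with an O(N·M) pairwise distance tensor. mdlearn's ChamferLoss may normalize or reduce differently, and the (batch, 3, num_points) layout is an assumption.

```python
import torch


def chamfer_distance(x: torch.Tensor, recon_x: torch.Tensor) -> torch.Tensor:
    # x: (B, 3, N), recon_x: (B, 3, M); channel-first layout assumed
    diff = x.unsqueeze(3) - recon_x.unsqueeze(2)  # (B, 3, N, M)
    dist = (diff ** 2).sum(dim=1)  # squared pairwise distances, (B, N, M)
    # Nearest-neighbour squared distance for every point, in each direction
    x_to_recon = dist.min(dim=2).values.sum(dim=1)  # (B,)
    recon_to_x = dist.min(dim=1).values.sum(dim=1)  # (B,)
    return (x_to_recon + recon_to_x).mean()


# Two 1D-aligned clouds: x = {(0,0,0), (1,0,0)}, recon = {(0,0,0), (3,0,0)}
x = torch.tensor([[[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]])
recon = torch.tensor([[[0.0, 3.0], [0.0, 0.0], [0.0, 0.0]]])
loss = chamfer_distance(x, recon)
```

Here the x-to-recon term is 0 + 1 and the recon-to-x term is 0 + 4, giving a total of 5.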