mdlearn.nn.models.ae.model
Classes
AE – Autoencoder base class module.
- class mdlearn.nn.models.ae.model.AE(*args: Any, **kwargs: Any)
Autoencoder base class module.
- __init__(encoder: torch.nn.Module, decoder: torch.nn.Module)
- Parameters:
encoder (torch.nn.Module) – The encoder module.
decoder (torch.nn.Module) – The decoder module.
- decode(*args: Any, **kwargs: Any) → torch.Tensor
Decoder forward pass.
- encode(*args: Any, **kwargs: Any) → torch.Tensor
Encoder forward pass.
- recon_loss(x: torch.Tensor, recon_x: torch.Tensor) → torch.Tensor
Compute the reconstruction loss between x and recon_x.
- Parameters:
x (torch.Tensor) – The input data.
recon_x (torch.Tensor) – The reconstruction of the input data x.
- Returns:
torch.Tensor – The reconstruction loss between x and recon_x.
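The entry above does not say which loss function recon_loss computes. As a rough illustration of the call shape only, the sketch below uses mean squared error, which is an assumption and may differ from what the library does. The name recon_loss_sketch is hypothetical and not part of mdlearn.

```python
import torch
import torch.nn.functional as F


def recon_loss_sketch(x: torch.Tensor, recon_x: torch.Tensor) -> torch.Tensor:
    """Hypothetical reconstruction loss: mean squared error (an assumption)."""
    # Same argument order as the documented method: input first, reconstruction second.
    return F.mse_loss(recon_x, x)


x = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
recon_x = torch.tensor([[0.0, 1.0], [2.0, 5.0]])

# One of four elements is off by 2, so the mean squared error is 4 / 4 = 1.0.
loss = recon_loss_sketch(x, recon_x)
```

The result is a zero-dimensional tensor, matching the documented torch.Tensor return type, so it can be passed directly to backward() in a training loop.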
- reset_parameters()
Reset encoder and decoder parameters.
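To show how the pieces documented above might fit together, here is a minimal sketch of a class with the same interface. SketchAE is a hypothetical stand-in written for this example, not the mdlearn implementation. The layer sizes, the pass-through behavior of encode and decode, and the way reset_parameters walks the submodules are all assumptions.

```python
import torch
from torch import nn


class SketchAE(nn.Module):
    """Hypothetical stand-in mirroring the documented AE interface."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, *args, **kwargs) -> torch.Tensor:
        # Encoder forward pass: arguments are handed to the encoder unchanged.
        return self.encoder(*args, **kwargs)

    def decode(self, *args, **kwargs) -> torch.Tensor:
        # Decoder forward pass: arguments are handed to the decoder unchanged.
        return self.decoder(*args, **kwargs)

    def reset_parameters(self) -> None:
        # Re-initialize every submodule that defines reset_parameters().
        modules = list(self.encoder.modules()) + list(self.decoder.modules())
        for module in modules:
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()


torch.manual_seed(0)

# Illustrative sizes: 8 input features compressed to a 2-dimensional latent code.
encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
decoder = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 8))
model = SketchAE(encoder, decoder)

x = torch.randn(16, 8)      # batch of 16 samples
z = model.encode(x)         # latent codes, shape (16, 2)
recon_x = model.decode(z)   # reconstructions, shape (16, 8)
```

Passing the encoder and decoder in as separate modules keeps the base class independent of any particular architecture. A subclass only needs to supply two modules whose input and output shapes are compatible.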