from torch import nn

class ParticleAE(nn.Module):
    """Fully connected autoencoder with a symmetric encoder and decoder.

    The encoder maps ``in_features -> hidden_dim -> hidden_dim -> latent_dim``
    and the decoder mirrors it back to ``in_features``. An activation sits
    between consecutive linear layers; neither network applies an activation
    after its final linear layer.

    Args:
        in_features: Size of the input's last dimension (and of the
            reconstruction's last dimension).
        hidden_dim: Width of the two hidden layers in each network.
        latent_dim: Size of the latent code produced by the encoder.
        activation: Zero-argument callable (typically an ``nn.Module``
            subclass such as ``nn.ReLU``) that returns an activation module.
            It is called once per activation position, so parameterized
            activations (e.g. ``nn.PReLU``) get independent parameters in
            every layer.
    """

    def __init__(
        self,
        in_features,
        hidden_dim=32,
        latent_dim=16,
        activation=nn.ReLU,
    ):
        super().__init__()

        self.in_features = in_features
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        encoder = self._build_mlp(
            in_features, hidden_dim, latent_dim, activation
        )
        # Kept for backward compatibility with code that reads
        # ``model.activation``. It aliases the encoder's first activation
        # instead of being a separate (and, for parameterized activations,
        # otherwise unused) module. Assigned before ``encoder`` so that the
        # submodule registration order matches earlier versions.
        self.activation = encoder[1]
        self.encoder = encoder
        self.decoder = self._build_mlp(
            latent_dim, hidden_dim, in_features, activation
        )

    @staticmethod
    def _build_mlp(in_dim, hidden_dim, out_dim, activation):
        """Return a ``Linear-act-Linear-act-Linear`` stack from in_dim to out_dim."""
        # Instantiate a fresh activation for each position: reusing a single
        # instance would tie the weights of parameterized activations
        # (e.g. PReLU) across all layers.
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tuple ``(x_hat, z)`` of the reconstruction (last dimension
            ``in_features``) and the latent code (last dimension
            ``latent_dim``).
        """
        z = self.encode(x)
        x_hat = self.decode(z)
        return x_hat, z

    def encode(self, x):
        """Map ``x`` (last dimension ``in_features``) to its latent code."""
        return self.encoder(x)

    def decode(self, z):
        """Map a latent code ``z`` (last dimension ``latent_dim``) back to input space."""
        return self.decoder(z)
