import torch.nn as nn
import torch.nn.functional as F

class AEBlock(nn.Module):
    """Fully-connected autoencoder block with a 4x bottleneck.

    The encoder maps ``in_dim -> in_dim // 2 -> in_dim // 4`` with ReLU
    activations. The decoder mirrors it back ``in_dim // 4 -> in_dim // 2 ->
    in_dim`` and ends in a Sigmoid, so reconstructions lie in (0, 1).

    Args:
        in_dim: Size of the last dimension of the input. ``nn.Linear`` acts
            on the last axis, so any leading batch dimensions are preserved.
        loss_type: Reconstruction loss used by :meth:`compute_loss`; one of
            ``'mse'`` or ``'l1'``. It is only validated when
            :meth:`compute_loss` is called.
    """

    def __init__(self, in_dim: int, loss_type: str = 'mse'):
        super().__init__()
        # NOTE: integer division means in_dim < 4 yields a zero-width
        # bottleneck (in_dim // 4 == 0), so the decoder output cannot depend
        # on the input. Callers are expected to pass in_dim >= 4.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.ReLU(),
            nn.Linear(in_dim // 2, in_dim // 4),
            nn.ReLU()
        )
        # The Sigmoid bounds reconstructions to (0, 1). Presumably inputs are
        # scaled to [0, 1] as well -- TODO confirm against the data pipeline.
        self.decoder = nn.Sequential(
            nn.Linear(in_dim // 4, in_dim // 2),
            nn.ReLU(),
            nn.Linear(in_dim // 2, in_dim),
            nn.Sigmoid()
        )
        self.loss_type = loss_type

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Returns:
            A tuple ``(recon, z)``. ``recon`` has the same shape as ``x``.
            ``z`` is the bottleneck code, whose last dimension is
            ``in_dim // 4``.
        """
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon, z

    def compute_loss(self, x, recon, reduction: str = 'none'):
        """Reconstruction loss between ``recon`` and the target ``x``.

        Args:
            x: Target tensor.
            recon: Reconstruction, typically the first output of
                :meth:`forward`.
            reduction: Forwarded to the underlying ``torch.nn.functional``
                loss: ``'none'`` (default, elementwise loss with the same
                shape as the inputs), ``'mean'``, or ``'sum'``.

        Returns:
            The loss tensor, reduced according to ``reduction``.

        Raises:
            ValueError: If ``self.loss_type`` is not ``'mse'`` or ``'l1'``.
        """
        if self.loss_type == 'mse':
            return F.mse_loss(recon, x, reduction=reduction)
        elif self.loss_type == 'l1':
            return F.l1_loss(recon, x, reduction=reduction)
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")