import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a latent feature map.

    Four 3x3, stride-2 convolutions (each followed by a ReLU) halve the
    spatial size at every step while the channels go
    ``in_channels -> 256 -> 128 -> 64 -> 32``. A final 2x2 convolution with
    no activation projects to ``latent_dim`` channels. For a 32x32 input
    the result has shape ``[B, latent_dim, 1, 1]``.

    Args:
        latent_dim: Number of channels of the produced latent code.
        in_channels: Number of channels of the input image.
    """

    def __init__(self, latent_dim, in_channels=1):
        super().__init__()
        # Channel schedule of the down-sampling stages:
        # [B, in, 32, 32] -> [B, 256, 16, 16] -> [B, 128, 8, 8]
        #                 -> [B, 64, 4, 4]    -> [B, 32, 2, 2]
        widths = (in_channels, 256, 128, 64, 32)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
            )
            stages.append(nn.ReLU())
        # Projection head: [B, 32, 2, 2] -> [B, latent_dim, 1, 1]
        stages.append(nn.Conv2d(widths[-1], latent_dim, kernel_size=2))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` and return the latent feature map."""
        latent = self.layers(x)
        return latent

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code to an image.

    A 2x2 transposed convolution first expands the latent code to 32
    channels at 2x2. Four 3x3, stride-2 transposed convolutions then double
    the spatial size at every step while the channels go
    ``32 -> 64 -> 128 -> 256 -> out_channels``. A ReLU sits between
    consecutive layers; no activation is applied after the last one. For a
    ``[B, latent_dim, 1, 1]`` input the result has shape
    ``[B, out_channels, 32, 32]``.

    Args:
        latent_dim: Number of channels of the incoming latent code.
        out_channels: Number of channels of the reconstructed image.
    """

    def __init__(self, latent_dim, out_channels=1):
        super().__init__()
        # Channel schedule of the up-sampling stages:
        # [B, 32, 2, 2] -> [B, 64, 4, 4] -> [B, 128, 8, 8]
        #               -> [B, 256, 16, 16] -> [B, out, 32, 32]
        widths = (32, 64, 128, 256, out_channels)
        # Stem: [B, latent_dim, 1, 1] -> [B, 32, 2, 2]
        stages = [nn.ConvTranspose2d(latent_dim, widths[0], kernel_size=2)]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.ReLU())
            stages.append(
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )
        self.layers = nn.Sequential(*stages)

    def forward(self, s):
        """Decode the latent code ``s`` and return the reconstruction."""
        reconstruction = self.layers(s)
        return reconstruction
    
# ----- Autoencoder -----
class Autoencoder(nn.Module):
    """Encoder/decoder pair that reconstructs its input through a bottleneck.

    Args:
        latent_dim: Number of channels of the latent code passed from the
            encoder to the decoder.
        in_channels: Number of image channels, used both for the encoder
            input and for the decoder output. Defaults to 1, which matches
            the previously hard-coded single-channel configuration.
    """

    def __init__(self, latent_dim, in_channels=1):
        super().__init__()
        # The reconstruction must have as many channels as the input, so one
        # value drives both sub-modules.
        self.encoder = Encoder(in_channels=in_channels, latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, out_channels=in_channels)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A tuple ``(x_recon, z)`` with the decoder output and the latent
            code produced by the encoder.
        """
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon, z
