import torch
from torch import nn


class Reshape(nn.Module):
    """Module wrapper that reshapes its input to a fixed target shape.

    Lets a reshape step be placed inside an ``nn.Sequential`` pipeline.

    Args:
        *args: Target dimensions, passed on to ``Tensor.reshape``. One
            entry may be ``-1`` so that dimension is inferred.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        """Return ``x`` reshaped to the shape given at construction."""
        # ``reshape`` returns a view whenever ``view`` would have succeeded
        # and falls back to a copy for non-contiguous inputs (for example
        # after a transpose) instead of raising.
        return x.reshape(self.shape)


class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel square images.

    The encoder applies two conv + max-pool stages (each pool halves the
    spatial size), flattens, and projects to a hidden vector from which the
    latent mean and log-variance are predicted. The decoder maps a latent
    sample back to an image through two linear layers and two stride-2
    transposed convolutions (each doubles the spatial size).

    Because the encoder halves twice with floor division and the decoder
    doubles twice, the reconstruction has side ``4 * (img_size // 4)``; it
    only matches the input size when ``img_size`` is divisible by 4.

    Args:
        img_size: Height/width of the square input images.
        latent_dim: Size of the latent vector ``z``. Defaults to 64, the
            value that was previously hard-coded.
    """

    def __init__(self, img_size=32, latent_dim=64):
        super().__init__()
        base_channels = 32
        h1_dim = base_channels
        h2_dim = 2 * base_channels
        fc_dim = 4 * base_channels
        # Spatial side length after the two 2x max-pools.
        feature_size = img_size // 4
        feature_volume = feature_size * feature_size * h2_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(3, h1_dim, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(h1_dim, h2_dim, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(feature_volume, fc_dim),
        )
        # Two heads sharing the encoder output: mean and log-variance of
        # the approximate posterior.
        self.enc_fc_z_mean = nn.Linear(fc_dim, latent_dim)
        self.enc_fc_z_log_var = nn.Linear(fc_dim, latent_dim)

        self.decoder = nn.Sequential(
            # NOTE(review): no activation separates these two Linear layers,
            # so together they form a single affine map. Left unchanged so
            # the Sequential indices (and state_dict keys) stay stable.
            nn.Linear(latent_dim, fc_dim),
            nn.Linear(fc_dim, feature_volume),
            Reshape(-1, h2_dim, feature_size, feature_size),
            nn.ReLU(),
            nn.ConvTranspose2d(h2_dim, h1_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(h1_dim, 3, kernel_size=4, stride=2, padding=1),
        )
        # Summed (not averaged) squared error, matching the summed KL term.
        self.mse_loss = nn.MSELoss(reduction="sum")

    def encode(self, x):
        """Return ``(z_mean, z_log_var)`` predicted for the batch ``x``."""
        h = self.encoder(x)
        return self.enc_fc_z_mean(h), self.enc_fc_z_log_var(h)

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images."""
        return self.decoder(z)

    def reparameterize(self, mu, logVar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        ``logVar`` is the log-variance, so ``std = exp(logVar / 2)``.
        """
        std = torch.exp(logVar / 2)
        eps = torch.randn_like(std)
        return mu + std * eps

    def loss(self, x):
        """Return the scalar training loss for the batch ``x``.

        The loss is the summed squared reconstruction error plus the KL
        divergence between the predicted Gaussian and a standard normal,
        both summed over every element in the batch.
        """
        # Call the module rather than ``forward`` directly so that any
        # registered forward hooks are run.
        out, mu, logVar = self(x)
        kl_divergence = -0.5 * torch.sum(1 + logVar - mu.pow(2) - logVar.exp())
        mse = self.mse_loss(out, x)
        return mse + kl_divergence

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            Tuple ``(reconstruction, z_mean, z_log_var)``.
        """
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        return self.decode(z), z_mean, z_log_var


class VAE_miniimagenet(VAE):
    """``VAE`` preconfigured for 84x84 input images."""

    def __init__(self):
        side_length = 84
        super().__init__(img_size=side_length)


class VAE_tinyimagenet(VAE):
    """``VAE`` preconfigured for 64x64 input images."""

    def __init__(self):
        side_length = 64
        super().__init__(img_size=side_length)


class VAE_MLP(VAE):
    """Fully connected VAE over flat 512-dimensional inputs.

    Reuses the sampling, loss, and forward logic of ``VAE`` but replaces the
    convolutional encoder/decoder with two-layer MLPs (512 -> 400 -> 100 and
    100 -> 400 -> 512).
    """

    def __init__(self):
        # The base constructor sets up the loss and the default conv
        # layers; the encoder, latent heads, and decoder are replaced below.
        super().__init__(img_size=64)
        in_dim, hidden_dim, latent_dim = 512, 400, 100

        self.encoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        # The inherited heads expect the conv encoder's 128 features and
        # emit a 64-dim latent, which matches neither the 100-feature
        # encoder output above nor the 100-input decoder below. Rebuild
        # them so encode -> reparameterize -> decode is shape-consistent.
        self.enc_fc_z_mean = nn.Linear(latent_dim, latent_dim)
        self.enc_fc_z_log_var = nn.Linear(latent_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_dim),
        )
