import torch
import torch.nn as nn


class CNNLayer(nn.Module):
    """Down-sampling block: 4x4 convolution with stride 2, then ReLU.

    The convolution uses no padding, so a spatial extent ``n`` becomes
    ``(n - 4) // 2 + 1``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
        )
        # The activation runs in place on the convolution output.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(conv(input))``."""
        return self.relu(self.conv(input))


class CNNTransposedLayer(nn.Module):
    """Up-sampling block: 4x4 transposed convolution (stride 2), then ReLU.

    With ``padding=1`` the transposed convolution doubles the spatial
    extent: ``n`` becomes ``(n - 1) * 2 - 2 + 4 == 2 * n``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample_kwargs = dict(kernel_size=4, stride=2, padding=1)
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, **upsample_kwargs)
        # The activation runs in place on the transposed-convolution output.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(conv(input))``."""
        return self.relu(self.conv(input))

class Encoder(nn.Module):
    """Convolutional encoder producing a 2-dimensional mean and log-variance.

    Three stride-2 ``CNNLayer`` blocks reduce a single-channel image to a
    32-channel feature map, which is flattened and fed to two linear heads.
    The heads expect exactly 32 features per sample, i.e. the conv stack must
    reduce the image to 1x1 (for example 28 -> 13 -> 5 -> 1).
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            CNNLayer(1, 16),
            CNNLayer(16, 32),
            CNNLayer(32, 32),
        )
        self.mean = nn.Linear(32, 2)
        self.logvar = nn.Linear(32, 2)

    def forward(self, x):
        """Encode ``x`` and return ``(mean, logvar)``.

        Args:
            x: Image tensor. A 2-D tensor is treated as a single image
                (batch and channel axes are added); a 3-D tensor gets a
                channel axis inserted at position 1; any other rank is
                passed to the conv stack unchanged.

        Returns:
            Tuple ``(mean, logvar)``, each with one row per input sample
            and 2 columns.

        Raises:
            RuntimeError: Raised by the linear heads when the conv stack
                does not leave exactly 32 features per sample.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        elif x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(1)
        w = self.encoder(x)
        # Flatten per sample rather than ``view(-1, 32)``: the latter silently
        # spills leftover spatial positions into the batch axis when the
        # feature map is larger than 1x1, yielding the wrong number of rows.
        w = torch.flatten(w, start_dim=1)
        mean = self.mean(w)
        logvar = self.logvar(w)
        return mean, logvar

    def w_to_z(self, w):
        """Apply both linear heads to already-flattened features ``w``.

        Args:
            w: Tensor whose last dimension has size 32.

        Returns:
            Tuple ``(mean, logvar)`` of the two head outputs.
        """
        return self.mean(w), self.logvar(w)
    
class Decoder(nn.Module):
    """Decode a 2-dimensional latent vector into a single-channel image.

    A linear layer expands the latent to ``32 * 7 * 7`` features, which are
    reshaped to a 32x7x7 map and up-sampled twice (7 -> 14 -> 28) by
    transposed convolutions; a sigmoid bounds the output to ``(0, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.dec_linear = nn.Linear(2, 32 * 7 * 7)
        upsampling_stages = [
            CNNTransposedLayer(32, 32),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*upsampling_stages)

    def forward(self, z):
        """Return the decoded image(s) for latent ``z``.

        Args:
            z: Tensor whose last dimension has size 2.

        Returns:
            The decoder output with every size-1 axis removed.
        """
        feature_map = self.dec_linear(z).view(-1, 32, 7, 7)
        decoded = self.decoder(feature_map)
        # NOTE(review): the argument-less squeeze drops the channel axis and,
        # for a batch of one, the batch axis as well -- confirm callers
        # expect a 2-D result in that case.
        return decoded.squeeze()

class VAE(nn.Module):
    """Variational auto-encoder composed of ``Encoder`` and ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterization(self, mean, logvar):
        """Sample ``z = mean + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mean: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution; the standard
                deviation is ``exp(0.5 * logvar)``.

        Returns:
            A sample with the broadcast shape of ``mean`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        # The noise must be standard normal (``randn``), not uniform
        # (``rand``). ``randn_like`` also inherits device and dtype from
        # ``std``, so the model runs on CPU as well as on GPU.
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: Input passed unchanged to the encoder.

        Returns:
            Tuple ``(output, z, mean, logvar)``: the decoder output, the
            sampled latent, and the encoder's mean and log-variance.
        """
        mean, logvar = self.encoder(x)
        z = self.reparameterization(mean, logvar)
        output = self.decoder(z)
        return output, z, mean, logvar