import torch 
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
import math
from typing import *

class ConvNormAct(nn.Module):
    """2-D convolution followed by batch normalisation and a ReLU.

    The constructor arguments are forwarded positionally to ``nn.Conv2d``;
    the batch-norm layer is sized to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return ``activation(bn(conv(x)))``."""
        return self.activation(self.bn(self.conv(x)))

class ConvTransposeNormAct(nn.Module):
    """Transposed 2-D convolution, batch normalisation and optional activation.

    ``in_channels`` through ``output_padding`` are forwarded positionally to
    ``nn.ConvTranspose2d``.

    ``activation`` selects the non-linearity applied after the batch norm:
    ``'relu'`` (default) gives ``nn.ReLU``, ``'sigmoid'`` gives
    ``nn.Sigmoid``. Any other value, including ``None``, stores
    ``self.activation = None`` and the batch-norm output is returned as is.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
                 activation='relu'):
        super(ConvTransposeNormAct, self).__init__()
        self.conv_trans = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
        self.bn = nn.BatchNorm2d(out_channels)
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            # Unrecognised values deliberately mean "no activation".
            self.activation = None

    def forward(self, x):
        """Apply the transposed conv, the batch norm and, if set, the activation."""
        x = self.conv_trans(x)
        x = self.bn(x)
        # Identity test: ``None`` is a singleton, so ``is not`` is the correct check.
        if self.activation is not None:
            x = self.activation(x)
        return x

class Encoder(nn.Module):
    """Convolutional encoder of the conditional VAE.

    The condition vector is projected by ``linear1`` to ``img_h * img_w``
    values, reshaped into one extra image plane and concatenated to the
    input along the channel axis. Six stride-2 ``ConvNormAct`` blocks then
    reduce the spatial size, and a small MLP produces the mean and the
    log-variance of the latent Gaussian.

    Args:
        num_classes: Length of the condition vector.
        latent_len: Size of the latent code.
        img_ch: Number of channels of the input image.
        img_h: Input image height.
        img_w: Input image width.
    """

    def __init__(self, num_classes = 1000, latent_len = 1024, img_ch = 3, img_h = 224, img_w = 224):
        super(Encoder, self).__init__()
        self.img_ch = img_ch
        self.img_h = img_h
        self.img_w = img_w
        self.linear1 = nn.Linear(num_classes, 1 * img_h * img_w)
        # The first block sees the image channels plus the one condition plane.
        self.conv1 = ConvNormAct(img_ch + 1, 32, 4, 2, 1)
        self.conv2 = ConvNormAct(32, 64, 4, 2, 1)
        self.conv3 = ConvNormAct(64, 128, 4, 2, 1)
        self.conv4 = ConvNormAct(128, 256, 4, 2, 1)
        self.conv5 = ConvNormAct(256, 512, 4, 2, 1)
        self.conv6 = ConvNormAct(512, 1024, 4, 2, 1)
        # Six stride-2 blocks (kernel 4, padding 1) each halve H and W,
        # rounding down, hence the integer division by 2**6.
        flat_features = 1024 * (img_h // 64) * (img_w // 64)
        self.linear2 = nn.Linear(flat_features, latent_len * 4)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(latent_len * 4, latent_len * 2)
        self.relu3 = nn.ReLU()
        self.linear4 = nn.Linear(latent_len * 2, latent_len)  # mean head
        self.linear5 = nn.Linear(latent_len * 2, latent_len)  # log-variance head

    def forward(self, x, cond):
        """Encode ``x`` conditioned on ``cond``.

        Args:
            x: Image batch of shape (B, img_ch, img_h, img_w).
            cond: Condition batch of shape (B, num_classes).

        Returns:
            Tuple ``(z_mu, z_log_var)``, each of shape (B, latent_len).
        """
        cond = self.linear1(cond)
        cond_cnn = torch.reshape(cond, (cond.shape[0], 1, self.img_h, self.img_w))
        x_cond = torch.cat((x, cond_cnn), dim=1)  # (B, img_ch + 1, H, W)
        q = self.conv1(x_cond)
        q = self.conv2(q)
        q = self.conv3(q)
        q = self.conv4(q)
        q = self.conv5(q)
        q = self.conv6(q)
        flat_q = torch.flatten(q, start_dim=1)
        h_q = self.relu2(self.linear2(flat_q))
        h_q = self.relu3(self.linear3(h_q))
        z_mu = self.linear4(h_q)
        z_log_var = self.linear5(h_q)
        return z_mu, z_log_var

class Decoder(nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    A three-layer MLP expands the concatenated latent/condition vector to a
    ``1024 x (img_h // 64) x (img_w // 64)`` feature map, which six stride-2
    ``ConvTransposeNormAct`` blocks upsample back to ``img_h x img_w``. The
    last block has no activation, and the result is returned flattened.

    Args:
        num_classes: Length of the condition vector.
        latent_len: Size of the latent code.
        img_ch: Number of channels of the reconstructed image.
        img_h: Reconstructed image height.
        img_w: Reconstructed image width.
    """

    def __init__(self, num_classes = 1000, latent_len = 1024, img_ch = 3, img_h = 224, img_w = 224):
        super(Decoder, self).__init__()
        self.img_ch = img_ch
        self.img_h = img_h
        self.img_w = img_w

        def out_pad(shift):
            # Each block maps a side n to 2*n + output_padding. Feeding back
            # the bit of the target size dropped by ``>> shift`` rebuilds the
            # exact height and width, one binary digit per block.
            return ((img_h >> shift) % 2, (img_w >> shift) % 2)

        self.linear1 = nn.Linear(latent_len + num_classes, latent_len * 2)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(latent_len * 2, latent_len * 4)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(latent_len * 4, 1024 * (img_h // 64) * (img_w // 64))
        self.conv1 = ConvTransposeNormAct(1024, 512, 4, 2, 1, out_pad(5))
        self.conv2 = ConvTransposeNormAct(512, 256, 4, 2, 1, out_pad(4))
        self.conv3 = ConvTransposeNormAct(256, 128, 4, 2, 1, out_pad(3))
        self.conv4 = ConvTransposeNormAct(128, 64, 4, 2, 1, out_pad(2))
        self.conv5 = ConvTransposeNormAct(64, 32, 4, 2, 1, out_pad(1))
        # Output block: img_ch channels and no activation (activation=None).
        self.conv6 = ConvTransposeNormAct(32, img_ch, 4, 2, 1, out_pad(0), None)

    def forward(self, x_encoded):
        """Decode a latent/condition batch.

        Args:
            x_encoded: Tensor of shape (B, latent_len + num_classes).

        Returns:
            Tensor of shape (B, img_ch * img_h * img_w): the reconstructed
            images flattened from dimension 1 onwards.
        """
        h_p = self.relu1(self.linear1(x_encoded))
        h_p = self.relu2(self.linear2(h_p))
        h = self.linear3(h_p)
        p = torch.reshape(h, (h.shape[0], 1024, self.img_h // 64, self.img_w // 64))
        p = self.conv1(p)
        p = self.conv2(p)
        p = self.conv3(p)
        p = self.conv4(p)
        p = self.conv5(p)
        p = self.conv6(p)
        x_decoded = torch.flatten(p, start_dim=1)
        return x_decoded
    
class CVAE(nn.Module):
    """Conditional VAE built from an ``Encoder`` and a ``Decoder``.

    Only square images are accepted: the constructor asserts
    ``img_h == img_w``. All constructor arguments are passed unchanged to
    both sub-networks.
    """

    def __init__(self, num_classes = 1000, latent_len = 1024, img_ch = 3, img_h = 224, img_w = 224):
        super().__init__()
        assert img_h == img_w
        net_args = (num_classes, latent_len, img_ch, img_h, img_w)
        self.encoder = Encoder(*net_args)
        self.decoder = Decoder(*net_args)

    def reparametrize(self, mean, std):
        """Draw a reparameterised sample from ``Normal(mean, std)``."""
        return torch.distributions.Normal(mean, std).rsample()

    def forward(self, x, cond):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Image batch, (B, C, H, W).
            cond: Condition batch, (B, n_classes).

        Returns:
            Tuple ``(x_decoded, z_mu, z_sigma, z)`` where ``z_sigma`` is
            ``exp(0.5 * z_log_var)`` and ``z`` is the sampled latent code
            that was concatenated with ``cond`` and fed to the decoder.
        """
        z_mu, z_log_var = self.encoder(x, cond)
        z_sigma = (0.5 * z_log_var).exp()
        z = self.reparametrize(z_mu, z_sigma)
        x_decoded = self.decoder(torch.cat((z, cond), dim=1))
        return x_decoded, z_mu, z_sigma, z
    
class ConditionalVAE(nn.Module):
    """Conditional variational autoencoder for square images.

    The label vector is linearly embedded into one extra
    ``img_size x img_size`` plane and concatenated to a 1x1-convolved copy
    of the input before encoding. The decoder receives the sampled latent
    code concatenated with the raw label vector.

    :param in_channels: (Int) Number of channels of the input images; the
        reconstruction has the same number of channels
    :param num_classes: (Int) Length of the label vector
    :param latent_dim: (Int) Size of the latent code
    :param hidden_dims: Must be None; the widths [32, 64, 128, 256, 512]
        are always used and any other value raises NotImplementedError
    :param img_size: (Int) Side length of the input images. The bottleneck
        side is ``img_size // 32`` and the decoder output side is 32 times
        that, so a multiple of 32 is expected
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 latent_dim: int,
                 hidden_dims: Optional[List] = None,
                 img_size:int = 64,
                 **kwargs) -> None:
        super(ConditionalVAE, self).__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size

        self.embed_class = nn.Linear(num_classes, img_size * img_size)
        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            raise NotImplementedError

        # Remember the image channel count: ``in_channels`` is reused below
        # as the running input width of the encoder layers.
        data_channels = in_channels
        # Five stride-2 convolutions reduce each side by a factor of 32.
        bottleneck_side = img_size // 32
        flat_features = hidden_dims[-1] * bottleneck_side * bottleneck_side

        in_channels += 1 # To account for the extra label channel
        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim + num_classes, flat_features)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            # Reconstruct as many channels as the input has.
                            nn.Conv2d(hidden_dims[-1], out_channels=data_channels,
                                      kernel_size=3, padding=1),
                            nn.Sigmoid())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List [mu, log_var] of latent Gaussian parameters
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes (already concatenated with the labels) to images.
        :param z: (Tensor) [N x (latent_dim + num_classes)]
        :return: (Tensor) [N x C x H x W], values in [0, 1] (sigmoid output)
        """
        result = self.decoder_input(z)
        side = self.img_size // 32
        # The channel count is inferred from the width of ``decoder_input``.
        result = result.view(result.size(0), -1, side, side)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: returns mu + eps * std with
        eps ~ N(0, I) and std = exp(0.5 * logvar). A single sample is drawn
        per input.
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sampled latent code, same shape as mu
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        Encodes the input conditioned on kwargs['labels'] and decodes a
        sampled latent code.
        :param input: (Tensor) [N x C x H x W]
        :param kwargs: must contain 'labels', a [N x num_classes] tensor
        :return: List [reconstruction, input, mu, log_var]
        """
        y = kwargs['labels'].float()
        embedded_class = self.embed_class(y)
        embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
        embedded_input = self.embed_data(input)

        x = torch.cat([embedded_input, embedded_class], dim=1)
        mu, log_var = self.encode(x)

        z = self.reparameterize(mu, log_var)

        z = torch.cat([z, y], dim=1)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self, recons, input, mu, log_var, rec_loss='mse', kld_weight=0.0025,
                      **kwargs) -> dict:
        """
        Computes reconstruction loss + weighted KL divergence.
        :param recons: (Tensor) Reconstruction
        :param input: (Tensor) Target
        :param mu: (Tensor) Latent mean
        :param log_var: (Tensor) Latent log-variance
        :param rec_loss: one of 'mse', 'bce', 'l1', 'gaussian'; anything
            else raises NotImplementedError
        :param kld_weight: weight of the KL term, used when 'M_N' is absent
        :param kwargs: optional 'M_N' overriding kld_weight
        :return: dict with 'loss', 'Reconstruction_Loss' and 'KLD' (the
            negated KL term)
        """
        # 'M_N' accounts for the minibatch samples from the dataset; when it
        # is not supplied fall back to the explicit kld_weight argument.
        kld_weight = kwargs.get('M_N', kld_weight)
        if rec_loss == 'mse':
            recons_loss = F.mse_loss(recons, input)
        elif rec_loss == 'bce':
            bce_loss_fn = nn.BCELoss()
            recons_loss = bce_loss_fn(recons, input)
        elif rec_loss == 'l1':
            l1_loss_fn = nn.L1Loss()
            recons_loss = l1_loss_fn(recons, input)
        elif rec_loss == 'gaussian':
            dist = torch.distributions.Normal(recons, 1.0)
            log_p = dist.log_prob(input)
            recons_loss = - log_p.mean()
        else:
            raise NotImplementedError

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int,
               **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :param kwargs: must contain 'labels', a [num_samples x num_classes]
            tensor
        :return: (Tensor)
        """
        # Labels must live on the same device as z for the concatenation.
        y = kwargs['labels'].float().to(current_device)
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        z = torch.cat([z, y], dim=1)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x, **kwargs)[0]
