import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from models.types_ import *
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
from util import TensorStorage
import copy

class ConditionalVAE(BaseVAE.BaseVAE):
    """VAE whose training path is a fully-connected encoder/decoder pair.

    ``forward`` sends the input through a 1x1 convolution, flattens it,
    encodes it with the ``AE_e`` MLP, samples a latent code and reconstructs
    it with the ``AE_d`` MLP.  Batch means of ``mu`` and ``log_var`` are
    pushed into two ``TensorStorage`` meters, and the divergence returned by
    ``forward`` is measured against the averages those meters report.

    A convolutional encoder/decoder (``encoder``, ``decoder_input``,
    ``decoder``, ``final_layer``) is also built; it is only used by
    ``encode``, ``decode`` and ``sample``, never by ``forward``.
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 img_size:int = 128,
                 **kwargs) -> None:
        """
        Builds all sub-networks.
        :param in_channels: (Int) Channels of the input; only used for the
                            1x1 ``embed_data`` convolution
        :param num_classes: (Int) Accepted for interface compatibility; unused
        :param latent_dim: (Int) Currently ignored, see the note below
        :param hidden_dims: (List) Channel widths of the convolutional
                            encoder; defaults to [128, 256, 512, 1024, 2048].
                            The list passed in is not modified.
        :param img_size: (Int) The MLP encoder expects ``img_size * 5``
                         flattened features per sample
        """
        super().__init__()

        self.mu_meter = TensorStorage('Mu', ':.4e')
        self.logVar_meter = TensorStorage('Log', ':.2%')

        # NOTE(review): the latent_dim argument is overridden here, so every
        # instance has a 32-dimensional latent space regardless of what the
        # caller asked for.  Kept as-is because changing it would alter the
        # shape of fc_mu / fc_var / ff and break existing checkpoints.
        latent_dim = 32

        # Fall back to the CPU when CUDA is not available instead of failing
        # at construction time; on CUDA hosts this is still "cuda:0".
        stats_device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.mu_ = torch.zeros(4096).to(device=stats_device)
        self.log_var_ = torch.zeros(4096).to(device=stats_device)

        self.latent_dim = latent_dim
        print("latent_dim:")
        print(latent_dim)
        self.img_size = img_size

        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        if hidden_dims is None:
            hidden_dims = [128, 256, 512, 1024, 2048]
        else:
            # Work on a copy: the list is reversed in place further down and
            # the caller's object must not be affected by that.
            hidden_dims = list(hidden_dims)

        # ---- Fully-connected encoder / decoder (used by forward) ----
        layers = [128, 128, 128, 128]
        flat_features = img_size * 5

        AE_e = [
            self._linear_block(n_in, n_out)
            for n_in, n_out in zip([flat_features] + layers[:-1], layers)
        ]
        AE_d = [
            self._linear_block(n_in, n_out)
            for n_in, n_out in zip([latent_dim] + layers[:-1], layers)
        ]
        # Final projection back to the flattened input size, no activation.
        AE_d.append(
            nn.Sequential(
                nn.Linear(layers[-1], flat_features)
            )
        )
        self.AE_e = nn.Sequential(*AE_e)
        self.AE_d = nn.Sequential(*AE_d)

        # ---- Convolutional encoder (used by encode only) ----
        # The first convolution always takes a single channel, independent
        # of the in_channels argument.
        conv_in = 1
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(conv_in, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            conv_in = h_dim

        self.encoder = nn.Sequential(*modules)
        # 128 input features == width of the last AE_e layer.
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)

        # Maps the per-dimension divergence to two outputs.
        self.ff = nn.Linear(latent_dim, 2)

        self.mm = nn.Sequential(
            nn.BatchNorm1d(latent_dim),
            nn.LeakyReLU()
        )

        # ---- Convolutional decoder (used by decode / sample only) ----
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        modules = [
            nn.Sequential(
                nn.ConvTranspose2d(c_in,
                                   c_out,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU())
            for c_in, c_out in zip(hidden_dims, hidden_dims[1:])
        ]
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=1,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

    @staticmethod
    def _linear_block(n_in: int, n_out: int) -> nn.Sequential:
        """Returns a Linear -> BatchNorm1d -> LeakyReLU block."""
        return nn.Sequential(
            nn.Linear(n_in, n_out),
            nn.BatchNorm1d(n_out),
            nn.LeakyReLU()
        )

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the convolutional encoder
        network and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes [mu, log_var]
        """
        # NOTE(review): fc_mu / fc_var accept 128 features, while the
        # flattened encoder output is a multiple of hidden_dims[-1] (2048 by
        # default), so this path looks unusable with the default
        # configuration - confirm before relying on it.
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes through the convolutional decoder.
        :param z: (Tensor) Latent codes [N x latent_dim]
        :return: (Tensor) Output of ``final_layer``
        """
        result = self.decoder_input(z)
        # decoder_input emits (channels * 2 * 2) features; recover the channel
        # count from the layer instead of hard-coding the default of 2048.
        channels = self.decoder_input.out_features // 4
        result = result.view(-1, channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Draws one sample z = mu + eps * std with eps ~ N(0, I).
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sample with the same shape as ``mu``
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, stage='train') -> List[Tensor]:
        """
        Reconstructs the input with the MLP auto-encoder and measures how far
        each sample's latent Gaussian is from the running-average Gaussian.
        :param input: (Tensor) Input batch; must flatten to ``img_size * 5``
                      features per sample after the 1x1 convolution
        :param stage: (str) The meters are only updated when this is 'train'
        :return: Tuple ``(output, kld, kld0)`` where ``output`` is the
                 reconstruction viewed as [N x 1 x 5 x img_size], ``kld0`` is
                 the per-dimension divergence [N x latent_dim] and ``kld`` is
                 ``self.ff(kld0)``
        """
        x = self.embed_data(input)
        x = x.view(x.shape[0], -1)

        hid = self.AE_e(x)
        mu = self.fc_mu(hid)
        log_var = self.fc_var(hid)

        latent = self.reparameterize(mu, log_var)
        output = self.AE_d(latent)
        # AE_d emits img_size * 5 features per sample.
        output = output.view(-1, 1, 5, self.img_size)

        if stage == 'train':
            self.mu_meter.update(torch.mean(mu, dim=0))
            self.logVar_meter.update(torch.mean(log_var, dim=0))
        mu_ = self.mu_meter.avg
        log_var_ = self.logVar_meter.avg

        # Element-wise KL( N(mu_, exp(log_var_)) || N(mu, exp(log_var)) ).
        kld0 = 0.5 * (
            log_var - log_var_ - 1
            + torch.exp(log_var_) / torch.exp(log_var)
            + (mu_ - mu).pow(2) / torch.exp(log_var)
        )
        kld = self.ff(kld0)

        return output, kld, kld0

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Standard VAE loss: MSE reconstruction plus weighted KL divergence to
        a unit Gaussian.
        :param args: ``(recons, input, mu, log_var)`` in that order.  Note
                     that ``forward`` of this class returns
                     ``(output, kld, kld0)``, which does not match this
                     layout and cannot be passed through unchanged.
        :param kwargs: Must contain ``'M_N'``, the weight of the KL term
        :return: (dict) with keys 'loss', 'Reconstruction_Loss' and 'KLD'
        """
        recons, target, mu, log_var = args[0], args[1], args[2], args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, target)

        kld_per_sample = -0.5 * torch.sum(
            1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_per_sample, dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int,
               **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :param kwargs: Must contain ``'labels'``, a tensor that is
                       concatenated to the latent codes along dim 1
        :return: (Tensor)
        """
        y = kwargs['labels'].float()
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        # NOTE(review): decoder_input is built with latent_dim input
        # features, so appending any label columns here makes the tensor too
        # wide for decode(); this only works if labels has zero columns.
        # Left unchanged until the intended conditioning is confirmed.
        z = torch.cat([z, y], dim=1)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) Input batch, passed unchanged to ``forward``
        :param kwargs: Forwarded to ``forward`` (e.g. ``stage``)
        :return: (Tensor) First element of the ``forward`` result
        """
        return self.forward(x, **kwargs)[0]



def CVAE(in_channels = 1, num_classes = 2, latent_dim = 4096, hidden_dims = None, img_size = 64):
    """
    Factory that builds a ConditionalVAE with this module's default settings.
    :param in_channels: (Int) Forwarded to ConditionalVAE
    :param num_classes: (Int) Forwarded to ConditionalVAE
    :param latent_dim: (Int) Forwarded to ConditionalVAE
    :param hidden_dims: (List) Forwarded to ConditionalVAE; None selects
                        that class's own default
    :param img_size: (Int) Forwarded to ConditionalVAE
    :return: (ConditionalVAE) A newly constructed model
    """
    return ConditionalVAE(
        in_channels=in_channels,
        num_classes=num_classes,
        latent_dim=latent_dim,
        hidden_dims=hidden_dims,
        img_size=img_size,
    )