import torch
import torch.nn as nn
import torch.nn.functional as F

class BetaCVAE(nn.Module):
    """Conditional beta-VAE built from fully connected layers.

    The encoder consumes the input concatenated with the condition vector
    (along dim 1) and produces the mean and log-variance of a diagonal
    Gaussian over the latent code. The decoder consumes a latent code
    concatenated with the same condition and produces a reconstruction of
    width ``input_dim``.
    """

    def __init__(self, input_dim, latent_dim, condition_dim, beta=1, hidden_dims=None):
        """Build the encoder and decoder stacks.

        Args:
            input_dim: Width of the data vector ``x``.
            latent_dim: Width of the latent code ``z``.
            condition_dim: Width of the condition vector ``c``.
            beta: Multiplier applied to the KL term in ``loss_function``.
            hidden_dims: Encoder layer widths, in order. The decoder uses the
                same widths reversed. Defaults to ``[512, 256, 128]``.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()

        # Work on a private copy: a shared mutable default (or the caller's
        # own list) must never be reversed in place, otherwise every second
        # construction would silently get a flipped architecture.
        if hidden_dims is None:
            hidden_dims = [512, 256, 128]
        else:
            hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        self.latent_dim = latent_dim
        self.beta = beta

        # Build Encoder
        encoder_layers = []
        in_dim = input_dim + condition_dim
        for h_dim in hidden_dims:
            encoder_layers.append(
                nn.Sequential(
                    nn.Linear(in_dim, h_dim),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_dim = h_dim
        self.encoder = nn.Sequential(*encoder_layers)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Build Decoder (mirror of the encoder widths)
        self.decoder_input = nn.Linear(latent_dim + condition_dim, hidden_dims[-1])

        decoder_dims = hidden_dims[::-1]
        decoder_layers = []
        for d_in, d_out in zip(decoder_dims[:-1], decoder_dims[1:]):
            decoder_layers.append(
                nn.Sequential(
                    nn.Linear(d_in, d_out),
                    nn.BatchNorm1d(d_out),
                    nn.LeakyReLU(),
                )
            )
        self.decoder = nn.Sequential(*decoder_layers)

        self.final_layer = nn.Sequential(
            nn.Linear(decoder_dims[-1], input_dim)
        )

    def encode(self, x, c):
        """Encode ``x`` conditioned on ``c``.

        Args:
            x: Data tensor; concatenated with ``c`` along dim 1.
            c: Condition tensor with the same leading (batch) dimension.

        Returns:
            ``[mu, log_var]``, the outputs of the two latent heads.
        """
        encoder_in = torch.cat([x, c], dim=1)
        hidden = self.encoder(encoder_in)
        hidden = torch.flatten(hidden, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(hidden)
        log_var = self.fc_var(hidden)

        return [mu, log_var]

    def decode(self, z, c):
        """Decode latent codes ``z`` conditioned on ``c``.

        Args:
            z: Latent tensor; concatenated with ``c`` along dim 1.
            c: Condition tensor with the same leading (batch) dimension.

        Returns:
            The reconstruction produced by the final linear layer.
        """
        decoder_in = torch.cat([z, c], dim=1)
        hidden = self.decoder_input(decoder_in)
        hidden = self.decoder(hidden)
        return self.final_layer(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is ``exp(0.5 * logvar)``; sampling the noise separately keeps
        the draw differentiable with respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x, c):
        """Run encode -> reparameterize -> decode.

        Returns:
            ``[reconstruction, x, mu, log_var]``, in the positional order
            expected by ``loss_function``.
        """
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z, c), x, mu, log_var]

    def loss_function(self, *args, **kwargs):
        """Compute the beta-weighted VAE loss.

        Args:
            *args: ``reconstruction, input, mu, log_var`` in that order (the
                list returned by ``forward``); extra positionals are ignored.
            **kwargs: Must contain ``M_N``, the weight applied to the KL term.

        Returns:
            Dict with keys ``'loss'``, ``'Reconstruction_Loss'`` and ``'KLD'``.
            Note that ``'KLD'`` holds the *negated* KL term; this sign is
            kept as-is so existing consumers of the dict are unaffected.

        Raises:
            KeyError: If ``M_N`` is not supplied.
        """
        recons, target, mu, log_var = args[:4]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, target)

        # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and
        # averaged over the batch.
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_per_sample, dim=0)

        loss = recons_loss + self.beta * kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self, num_samples, current_device, c):
        """Decode ``num_samples`` standard-normal latent draws.

        Args:
            num_samples: Number of latent vectors to draw.
            current_device: Device the latent draws are moved to.
            c: Condition tensor; it is concatenated with the latent draws, so
                it needs ``num_samples`` rows and should live on
                ``current_device``.

        Returns:
            The decoded samples.
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        return self.decode(z, c)

    def generate(self, x, c):
        """Return only the reconstruction of ``x`` under condition ``c``."""
        return self.forward(x, c)[0]