import torch
import torch.nn as nn
from torch.nn import ModuleList as mdl
from torch.nn import ModuleDict as mdd
from torch.nn import BatchNorm1d as bn1
from torch.nn import Sequential as seq
from torch.nn import functional as F
import numpy as np 
import pdb

from ..utils.deepnet_tools import reparam_trick, gumbel_softmax


class multimodal_mixVAE(nn.Module):
    """
    Neural network module for a mixture of continuous and discrete latent
    variables over several modalities. Every modality owns one or more arms;
    each arm encodes its input, draws a categorical sample through the
    imported `gumbel_softmax` helper, draws a continuous state sample through
    the imported `reparam_trick` helper, and decodes the pair.

    Methods
        encoder: encoder network of a single arm.
        intermed: the intermediate layer producing the mean and variance of the continuous RV.
        decoder: decoder network of a single arm.
        forward: module for forward path over all modalities and arms.
    """

    def __init__(self, modalities, networks, input_dim, n_categories, n_arm, x_drop, tau, eps, noise_model):
        """
        Class instantiation.

        input args
            modalities: a list of modalities (used as dictionary keys).
            networks: a dictionary of network builders, one per modality; each must expose get_layers().
            input_dim: a dictionary of input dimensions (size of the input layer) for each modality.
            n_categories: number of categories of the latent variables.
            n_arm: a dictionary with the number of arms for each modality.
            x_drop: a dictionary of dropout probabilities at the first (input) layer for each modality.
            tau: temperature of the softmax layers, usually equals to 1/n_categories (0 < tau <= 1).
            eps: a small constant value to fix computation overflow.
            noise_model: a string variable, 'Gaussian' or 'ZINB' for the noise model.
        """
        super(multimodal_mixVAE, self).__init__()
        self.input_dim = input_dim
        self.n_categories = n_categories
        self.n_arm = n_arm
        self.tau = tau
        self.eps = eps
        self.noise_model = noise_model
        # A ModuleDict (rather than a plain dict) registers the dropout layers as
        # submodules, so that train()/eval()/to() actually reach them.
        self.x_dp = mdd({key: nn.Dropout(x_drop[key]) for key in modalities})
        self.encoder_layers = mdd({key: mdl() for key in modalities})
        self.decoder_layers = mdd({key: mdl() for key in modalities})
        self.qc = mdd({key: mdl() for key in modalities})
        self.mu = mdd({key: mdl() for key in modalities})
        self.sigma = mdd({key: mdl() for key in modalities})
        self.fc_lowD = mdd({key: mdl() for key in modalities})

        self.nonLin_f = nn.SiLU()
        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

        for m in modalities:
            layers = networks[m].get_layers()
            if m == 'M':
                # The 'M' modality uses a convolutional encoder/decoder stored
                # directly on the module instead of in the per-modality dicts.
                (self.conv_layers, self.linear_encoder, self.deconv_layers, self.linear_decoder,
                 self.qc[m], self.mu[m], self.sigma[m], self.fc_lowD[m]) = layers
            else:
                (self.encoder_layers[m], self.decoder_layers[m],
                 self.qc[m], self.mu[m], self.sigma[m], self.fc_lowD[m]) = layers

    def encoder(self, x, mod, arm):
        """
        Encodes the input data of one arm into a low dimensional representation.

        input args
            x: input data tensor.
            mod: modality key; 'M' selects the convolutional encoder, any other key the layers in encoder_layers.
            arm: index of the arm within the modality.

        return
            z: low dimensional representation produced by the encoder.
            qc: softmax (over the last dimension) of the categorical layer applied to z.
        """
        if mod == 'M':
            # Collapse everything after dim 1 and move it in front of the channel axis
            # before the convolution stack; then flatten per sample.
            x_ = torch.flatten(x, 2).transpose(1, 2)
            x_ = self.conv_layers[arm](x_)
            x_ = x_.view(x_.shape[0], -1)
            z = self.linear_encoder[arm](self.x_dp[mod](x_))
        else:
            z = self.encoder_layers[mod][arm](self.x_dp[mod](x))

        return z, F.softmax(self.qc[mod][arm](z), dim=-1)

    def intermed(self, x, mod, arm):
        """
        Intermediate processing step producing the parameters of the continuous variable.

        input args
            x: concatenation of the low dimensional representation and the categorical sample.
            mod: modality key.
            arm: index of the arm within the modality.

        return
            mu: output of the mean layer.
            var: sigmoid of the output of the sigma layer (forward() takes its log as log-variance).
        """
        return self.mu[mod][arm](x), self.sigmoid(self.sigma[mod][arm](x))

    def decoder(self, c, s, mod, arm):
        """
        Decodes a categorical sample and a continuous state sample of one arm.

        input args
            c: sample of the categorical variable.
            s: sample of the continuous (state) variable.
            mod: modality key; 'M' selects the deconvolutional decoder.
            arm: index of the arm within the modality.

        return
            The reconstructed data for the given arm.
        """
        z = torch.cat((c, s), dim=1)
        z = self.elu(self.fc_lowD[mod][arm](z))
        if mod != 'M':
            return self.decoder_layers[mod][arm](z)

        x_ = self.linear_decoder[arm](z)
        # NOTE(review): the 10 x 27 layout is hard-coded; presumably it matches the
        # input expected by the deconvolution stack -- confirm against the network builder.
        x_ = x_.view(x_.size(0), 10, 27)
        x = self.deconv_layers[arm](x_)
        x = x.transpose(1, 2)
        # Side length of the square the last axis is unfolded into (truncated square root).
        side = int(np.sqrt(self.input_dim[mod]))
        try:
            return torch.unflatten(x, 2, (side, side))
        except (AttributeError, IndexError, RuntimeError, TypeError):
            # Best-effort fallback to an externally provided `unflatten` callable.
            # This class does not define one itself, so when it is absent the
            # original error is re-raised instead of being masked.
            fallback = getattr(self, 'unflatten', None)
            if fallback is None:
                raise
            return fallback(x)

    def forward(self, x, temp, eval=False, pruning_mask=None, hard=False, variational=True):
        """
        Forward pass for the model, generating reconstructed data from the input.

        input args
            x: a dictionary mapping each modality to a per-arm sequence of input tensors.
            temp: temperature passed to the Gumbel-softmax helper.
            eval: if True, the categorical sample is drawn with hard=True and gumble_noise=False and the
                  state sample is set to its mean. Dropout is NOT affected by this flag; it follows the
                  module's train()/eval() mode.
            pruning_mask: indices (or mask) of the categories that are kept; None or empty keeps all.
            hard: forwarded to the Gumbel-softmax helper when eval is False.
            variational: if False, the state sample is set to its mean instead of being sampled.

        return
            recon_x: a dictionary including the reconstructed data for each modality.
            x_low: a dictionary including a low dimensional representation of the input for each modality.
            qc: dictionary of pdf of the categorical variable for all modalities.
            s: dictionary of sample of the state variable for all modalities.
            c: dictionary of sample of the categorical variable for all modalities.
            mu: dictionary of mean of the state variable for all modalities.
            log_var: dictionary of log of variance of the state variable for all modalities.
            log_qc: dictionary of the encoder's categorical output (already a softmax) for all modalities.
        """
        use_mask = pruning_mask is not None and len(pruning_mask) > 0

        modalities = list(x.keys())
        recon_x = {m: [None] * self.n_arm[m] for m in modalities}
        x_low = {m: [None] * self.n_arm[m] for m in modalities}
        c = {m: [None] * self.n_arm[m] for m in modalities}
        s = {m: [None] * self.n_arm[m] for m in modalities}
        mu = {m: [None] * self.n_arm[m] for m in modalities}
        log_var = {m: [None] * self.n_arm[m] for m in modalities}
        qc = {m: [None] * self.n_arm[m] for m in modalities}
        log_qc = {m: [None] * self.n_arm[m] for m in modalities}

        for m in modalities:
            for arm in range(self.n_arm[m]):
                x_low[m][arm], log_qc[m][arm] = self.encoder(x[m][arm], m, arm)
                enc_qc = log_qc[m][arm]

                if use_mask:
                    # Re-normalise over the kept categories only; pruned ones stay at zero.
                    kept = F.softmax(enc_qc[:, pruning_mask] / self.tau, dim=-1)
                    qc[m][arm] = enc_qc.new_zeros(enc_qc.size(0), enc_qc.size(1))
                    qc[m][arm][:, pruning_mask] = kept
                else:
                    qc[m][arm] = F.softmax(enc_qc / self.tau, dim=-1)

                q_ = qc[m][arm].view(enc_qc.size(0), 1, self.n_categories)

                if eval:
                    c[m][arm] = gumbel_softmax(q_, 1, self.n_categories, temp, hard=True, gumble_noise=False, eps=self.eps)
                else:
                    c[m][arm] = gumbel_softmax(q_, 1, self.n_categories, temp, hard=hard, eps=self.eps)

                zz = torch.cat((x_low[m][arm], c[m][arm]), dim=1)

                mu[m][arm], var = self.intermed(zz, m, arm)
                # eps keeps the log finite when the sigmoid output underflows to zero.
                log_var[m][arm] = (var + self.eps).log()

                if eval or (not variational):
                    s[m][arm] = mu[m][arm]
                else:
                    s[m][arm] = reparam_trick(mu[m][arm], log_var[m][arm])

                recon_x[m][arm] = self.decoder(c[m][arm], s[m][arm], m, arm)

        return recon_x, x_low, qc, s, c, mu, log_var, log_qc
    



class vae_model(nn.Module):
    """
    Class for the neural network module for variational inference.
    The module contains a VAE with a continuous latent representation using the reparameterization trick.
    The default setting of this network is for smart-seq datasets. If you
    want to use another dataset, you may need to modify the network's
    parameters.

    Methods
        encoder: encoder network.
        decoder: decoder network.
        forward: module for forward path.
        reparam_trick: module for reparameterization.
        loss: ELBO loss function module
    """
    def __init__(self, input_dim, fc_dim, lowD_dim, n_layer, x_drop, beta, variational, device, eps, momentum):
        """
        Class instantiation.

        input args
            input_dim: input dimension (size of the input layer).
            fc_dim: dimension of the hidden layer.
            lowD_dim: dimension of the latent representation.
            n_layer: number of hidden layers in the encoder and decoder networks, at least one.
            x_drop: dropout probability at the first (input) layer.
            beta: regularizer for the KL divergence term.
            variational: a boolean variable, True uses sampling, False is just a regular AE.
            device: computing device, either 'cpu' or 'cuda'.
            eps: a small constant value to fix computation overflow.
            momentum: a hyperparameter for batch normalization that updates its running statistics.
        """
        super(vae_model, self).__init__()
        self.input_dim = input_dim
        self.fc_dim = fc_dim
        self.x_dp = nn.Dropout(x_drop)
        self.beta = beta
        # The attribute keeps its historical (misspelled) name so that code
        # reading `model.varitional` keeps working.
        self.varitional = variational
        self.eps = eps
        self.momentum = momentum
        self.device = device

        self.sigmoid = nn.Sigmoid()

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed random seed is unchanged.
        self.fc_mu = nn.Linear(fc_dim, lowD_dim)
        self.fc_sigma = nn.Linear(fc_dim, lowD_dim)

        enc_layers = [nn.Linear(input_dim, fc_dim), nn.ReLU()]
        for _ in range(n_layer - 1):
            enc_layers.extend([nn.Linear(fc_dim, fc_dim), nn.ReLU()])
        self.layer_e = seq(*enc_layers)

        dec_layers = [nn.Linear(lowD_dim, fc_dim), nn.ReLU()]
        for _ in range(n_layer - 1):
            dec_layers.extend([nn.Linear(fc_dim, fc_dim), nn.ReLU()])
        dec_layers.extend([nn.Linear(fc_dim, input_dim), nn.ReLU()])
        self.layer_d = seq(*dec_layers)

        # Only used in the non-variational (plain auto-encoder) mode.
        self.batchnorm = nn.BatchNorm1d(num_features=lowD_dim, eps=eps, momentum=momentum, affine=False)

    def encoder(self, x):
        """
        Encoder network.

        input args
            x: input data.

        return
            variational mode: a tuple (mu, var), where var is the sigmoid of the sigma layer.
            otherwise: the batch-normalized output of the mean layer.
        """
        h = self.layer_e(self.x_dp(x))
        if self.varitional:
            return self.fc_mu(h), self.sigmoid(self.fc_sigma(h))
        return self.batchnorm(self.fc_mu(h))

    def decoder(self, z) -> torch.Tensor:
        """
        Decoder network.

        input args
            z: latent representation.

        return
            reconstructed data.
        """
        return self.layer_d(z)

    def forward(self, x):
        """
        input args
            x: input data.

        return
            recon_x: reconstructed data.
            z: latent representation (a sample in variational mode, the encoder output otherwise).
            mu: mean of the latent representation (empty list in non-variational mode).
            log_var: log of variance of the latent representation (empty list in non-variational mode).
        """
        mu, log_var = [], []
        if self.varitional:
            mu, var = self.encoder(x)
            log_var = torch.log(var + self.eps)
            # reparam_trick expects the variance, not its log.
            z = self.reparam_trick(mu, var)
        else:
            z = self.encoder(x)

        recon_x = self.decoder(z)

        return recon_x, z, mu, log_var

    def reparam_trick(self, mu, sigma):
        """
        Generate samples from a normal distribution for reparametrization trick.

        input args
            mu: mean of the Gaussian distribution for q(z|x) = N(mu, sigma*I).
            sigma: variance of the Gaussian distribution; its square root is used as standard deviation.

        return
            a sample mu + sqrt(sigma) * noise with noise drawn from N(0, I).
        """
        std = sigma.sqrt()
        # Standard-normal noise (randn, not the uniform rand), created directly
        # on the device and with the dtype of `std`.
        noise = torch.randn_like(std)
        return noise.mul(std).add(mu)

    def loss(self, x, recon_x, mu, log_var, mode='MSE'):
        """
        Loss function of the network.

        input args
            x: input data.
            recon_x: reconstructed data.
            mu: mean of the Gaussian distribution for the latent variable.
            log_var: log of variance of the Gaussian distribution for the latent variable.
            mode: a string to define the type of reconstruction loss function, either MSE or MSE-BCE.

        return
            loss: ELBO loss in variational mode, reconstruction loss otherwise.
            l_rec: reconstruction loss.
            KLD: KL divergence term (0. in non-variational mode).

        raises
            ValueError: if mode is neither 'MSE' nor 'MSE-BCE'.
        """
        if mode not in ('MSE', 'MSE-BCE'):
            raise ValueError("mode must be 'MSE' or 'MSE-BCE', got {!r}".format(mode))

        l_rec = F.mse_loss(recon_x, x, reduction='mean')

        if mode == 'MSE-BCE':
            # NOTE(review): both tensors are hard-thresholded constants, so this
            # term adds to the loss value but carries no gradient -- confirm intent.
            rec_bin = torch.where(recon_x > 0.01, 1., 0.)
            x_bin = torch.where(x > 0.01, 1., 0.)
            l_rec = l_rec + F.binary_cross_entropy(rec_bin, x_bin)

        if self.varitional:
            # KL divergence to N(0, I): averaged over the batch, summed over latent dimensions.
            KLD = (-0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp(), dim=0)).sum()
            loss = l_rec + self.beta * KLD
        else:
            KLD = 0.
            loss = l_rec

        return loss, l_rec, KLD
    
