import torch
import torch.nn as nn
from torch import optim

from src.nn.encoder import Encoder
from src.nn.decoder import Decoder

from src.utils.sampling import sample_from_qz_given_x, sample_from_qc_given_x, modulate_words
from src.train.loss import compute_word_logprobs
from src.train.train import trainloop
from src.utils.functions import check_args, set_random_seed
from src.nn import modules


class CodedVAE(nn.Module):

    """
    Class implementing the Coded-DVAE.

    The model wraps an encoder and a decoder network, keeps one Adam optimizer
    for each of them, and supports three inference modes: 'uncoded', 'word'
    and 'rep' (see `__init__`).
    """

    def __init__(self, enc, dec, latent_dim, bits_info=None, code_words=None, G=None, H=None, max_iter=5, likelihood='gauss', beta=10, lr=1e-4, weight_decay=1e-4, inference='word', seed=None):
        """
        Initialize an instance of the class.

        Parameters
        ----------
        enc : torch.nn.Module
            Module with the architecture of the encoder neural network without the output activation.
        dec : torch.nn.Module
            Module with the architecture of the decoder neural network.
        latent_dim : int
            Latent dimension of the model.
        bits_info : int, optional
            Number of information bits. Defaults to `latent_dim` when None.
        code_words : torch.tensor, optional
            Codebook.
        G : torch.tensor, optional
            Matrix used to encode information words.
        H : torch.tensor, optional
            Matrix used to decode coded words.
        max_iter : int, optional
            Not used by this class; kept so existing callers keep working. Default 5.
        likelihood : string, optional
            Distribution used to compute the reconstruction term. Default 'gauss'.
            - 'gauss': Gaussian likelihood.
            - 'ber': Bernoulli likelihood.
        beta : float, optional
            Temperature term that controls the decay of the exponentials in the smoothing transformation. Default to 10.
        lr : float, optional
            Learning rate. Default to 1e-4.
        weight_decay : float, optional
            Weight decay. Default to 1e-4.
        inference : string, optional
            Inference type. Default 'word'.
            - 'uncoded' for the uncoded case.
            - 'word' for the coded case with inference at word level.
            - 'rep' for the coded case with inference at bit level using repetition codes.
        seed : int, optional
            Seed for reproducibility. No seed is set when None.
        """

        super(CodedVAE, self).__init__()

        # Configuration
        self.beta = torch.tensor(beta)
        self.likelihood = likelihood
        self.latent_dim = latent_dim
        self.inference = inference

        # Code words
        self.code_words = code_words

        # LDPC
        self.G = G
        self.H = H

        # Bits info (falls back to the latent dimension)
        self.bits_info = self.latent_dim if bits_info is None else bits_info

        # Check arguments
        check_args(self.inference, self.code_words, self.G, self.H)

        # Encoder
        self.encoder = Encoder(enc, inference_type=self.inference)
        # Decoder
        self.decoder = Decoder(dec)

        # Set device (GPU if available; otherwise CPU). The model is moved
        # before the optimizers are built, as recommended by torch.optim.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        # Optimizers
        self.optimizer_encoder = optim.Adam(self.encoder.parameters(), lr=lr, weight_decay=weight_decay)
        self.optimizer_decoder = optim.Adam(self.decoder.parameters(), lr=lr, weight_decay=weight_decay)

        # Set random seed
        if seed is not None:
            set_random_seed(seed)


    def forward(self, x):

        """
        Forward pass.

        Parameters
        ----------
        x: torch.tensor
            Batch of data.

        Returns
        -------
        z_sample : torch.tensor
            Sample of the latent variable drawn with a single sample.
        reconstructed : torch.tensor
            Decoder output for `z_sample[:,:,0]`.

        Raises
        ------
        ValueError
            If `self.inference` is not 'uncoded', 'word' or 'rep'.
        """

        x = x.to(self.device)

        # Forward encoder
        encoder_out = self.encoder.forward(x)

        # Sample from the latent distribution
        if self.inference == 'uncoded':
            # Uncoded case
            z_sample = sample_from_qz_given_x(encoder_out, beta=self.beta, n_samples=1)

        elif self.inference == 'word':
            # Coded case (word inference level)
            code_words = self.code_words.to(self.device)
            logits, _ = compute_word_logprobs(encoder_out, code_words)
            c_sample, _ = sample_from_qc_given_x(logits, code_words, n_samples=1)
            z_sample = modulate_words(c_sample, beta=self.beta)

        elif self.inference == 'rep':
            # Coded case (bit inference level)
            H = self.H.to(self.device)
            G = self.G.to(self.device)

            # Log-evidence for each bit being 1 / 0, projected through H
            logpm1 = torch.matmul(torch.log(encoder_out), H)
            logpm0 = torch.matmul(torch.log(1-encoder_out), H)

            # Normalize the two hypotheses (stacked on the last axis) in the log domain
            log_marginals = torch.stack((logpm0, logpm1), dim=2)
            log_marginals_norm = log_marginals - torch.logsumexp(log_marginals, dim=-1, keepdim=True)

            # Introduce code structure: probabilities of bit=1 mapped through G
            qc = torch.matmul(torch.exp(log_marginals_norm[:,:,1]), G)
            # Modulate c to obtain z
            z_sample = sample_from_qz_given_x(qc, beta=self.beta, n_samples=1)

        else:
            raise ValueError(
                f"Unknown inference type: {self.inference!r}. "
                "Expected 'uncoded', 'word' or 'rep'."
            )

        # Forward decoder (index 0 on the last axis; presumably the single drawn sample)
        reconstructed = self.decoder.forward(z_sample[:,:,0])

        return z_sample, reconstructed


    def train(self, train_dataloader, n_epochs=100, n_epochs_wu=0, start_epoch=0, n_samples=1, train_enc=True, train_dec=True, verbose=True, wb=False):

        """
        Train the model for a given number of epochs.

        This method shadows `torch.nn.Module.train(mode)`. To keep
        `model.eval()` and `model.train(True/False)` working, a boolean first
        argument is forwarded to `torch.nn.Module.train` and no training loop
        is run.

            Parameters
            ----------
            train_dataloader : torch Dataloader
                Dataloader with the training set. If a bool is given instead, it is
                treated as the `mode` flag of `torch.nn.Module.train`.
            n_epochs: int, optional
                Number of epochs. Default 100.
            n_epochs_wu: int, optional
                Number of warmup epochs. For the first n_epochs_wu the model is trained in 'uncoded' mode.
            start_epoch: int, optional
                Epoch where the trainloop starts. This is useful to obtain coherent logs in weights and biases when we finetune a model.
                It is only forwarded to the warmup phase; the main phase always starts at `n_epochs_wu`.
            n_samples : int, optional
                Number of samples used for computing the ELBO. The number of samples is 1 by default.
            train_enc : boolean, optional
                Flag to indicate if the parameters of the encoder need to be updated. True by default.
            train_dec : boolean, optional
                Flag to indicate if the parameters of the decoder need to be updated. True by default.
            verbose: boolean, optional
                Flag to print the ELBO during training. True by default.
            wb: boolean, optional
                Flag to log the ELBO, KL term and reconstruction term to Weights&Biases.

            Returns
            -------
            elbo_evol : list
                One entry per phase (warmup first, if run, then main training); each entry is
                the ELBO history returned by `trainloop` for that phase.
            kl_evol : list
                Same layout, with the Kullback-Leibler divergence histories.
            rec_evol : list
                Same layout, with the reconstruction term histories.

            When called with a bool, returns `self` (as `torch.nn.Module.train` does).
        """

        # nn.Module.eval() calls self.train(False); honour that contract
        # instead of starting a training run with a bool as dataloader.
        if isinstance(train_dataloader, bool):
            return super().train(train_dataloader)

        # Track loss evolution during training
        elbo_evol = []
        kl_evol = []
        rec_evol = []

        # Neural Networks in training mode
        self.encoder.train()
        self.decoder.train()

        if n_epochs_wu > 0:

            # set the output activation to Sigmoid(), as the warmup is always done with the uncoded version
            self.encoder.enc[-1] = nn.Sigmoid()
            # Warmup!
            print('Starting warmup...')
            elbo_evol_wu, kl_evol_wu, rec_evol_wu = trainloop(
                self,
                train_dataloader,
                n_epochs_wu,
                start_epoch=start_epoch,
                n_samples=n_samples,
                train_enc=train_enc,
                train_dec=train_dec,
                inference='uncoded',
                verbose=verbose,
                wb=wb)

            elbo_evol.append(elbo_evol_wu)
            kl_evol.append(kl_evol_wu)
            rec_evol.append(rec_evol_wu)
            print('Warmup finished!')

        # Train!
        # NOTE(review): 'bit' is not one of the inference modes documented in
        # __init__ ('uncoded', 'word', 'rep') - confirm this branch is still reachable.
        if self.inference == 'bit':
            self.encoder.enc[-1] = modules.ScaledTanh(factor=3.)
        print('Starting training...')
        elbo_evol_train, kl_evol_train, rec_evol_train = trainloop(
            self,
            train_dataloader,
            n_epochs+n_epochs_wu,
            start_epoch=n_epochs_wu,
            n_samples=n_samples,
            train_enc=train_enc,
            train_dec=train_dec,
            inference=self.inference,
            verbose=verbose,
            wb=wb
        )

        elbo_evol.append(elbo_evol_train)
        kl_evol.append(kl_evol_train)
        rec_evol.append(rec_evol_train)
        print('Training finished!')

        return elbo_evol, kl_evol, rec_evol


    def generate(self, n_samples=100):

        """
        Generate new samples following the generative model.

        Parameters
        ----------
        n_samples: int, optional
            Number of samples to generate.

        Returns
        -------
        Generated samples (output of the decoder).

        Raises
        ------
        ValueError
            If `self.inference` is not 'uncoded', 'word' or 'rep'.

        """

        if self.inference == 'uncoded':
            # Uniform distribution for each bit
            m_probs = torch.ones((n_samples, self.latent_dim))*0.5
            # Sample z
            z_sample = sample_from_qz_given_x(m_probs.to(self.device), beta=self.beta, n_samples=1)
            # Forward decoder
            reconstructed = self.decoder.forward(z_sample[:,:,0])

        elif self.inference == 'word':
            # Uniform distribution over words
            n_words = self.code_words.shape[0]
            word_probs = torch.ones(n_words)*(1/n_words)
            idx = word_probs.multinomial(num_samples=n_samples, replacement=True)
            # Random sampled words with uniform probability
            c_sample = self.code_words[idx]
            z_sample = modulate_words(c_sample.to(self.device), beta=self.beta)
            # Forward decoder
            reconstructed = self.decoder.forward(z_sample[:,:])

        elif self.inference == 'rep':
            # Uniform distribution for each bit
            m_probs = torch.ones((n_samples, self.bits_info))*0.5
            # Sample information words
            m_sample = m_probs.bernoulli()
            # Obtain a codeword; both operands are moved to the model device
            # (as forward() does for G) so the product cannot mix devices.
            c = torch.matmul(m_sample.to(self.device), self.G.to(self.device))
            # Sample z
            z_sample = modulate_words(c, beta=self.beta)
            # Forward decoder
            reconstructed = self.decoder.forward(z_sample[:,:])

        else:
            raise ValueError(
                f"Unknown inference type: {self.inference!r}. "
                "Expected 'uncoded', 'word' or 'rep'."
            )

        return reconstructed


    def save(self, path):

        """
        Save model.

        Writes the model's `state_dict()` with `torch.save` and prints a
        confirmation message.

        Parameters
        ----------
        path: str or path-like
           Path where the model will be saved.

        """

        torch.save(self.state_dict(), path)
        # f-string so that path-like objects (not only str) can be reported
        print(f'Model saved at {path}')





