import torch
import wandb
from src.train.loss import get_elbo_uncoded, get_elbo_coded_word, get_elbo_rep
from src.train.gradients import compute_gloo


def _check_elbo(elbo):
    """
    Sanity-check an ELBO value before back-propagating it.

        Parameters
        ----------
        elbo : torch.Tensor
            ELBO returned by one of the ``get_elbo_*`` functions.

        Raises
        ------
        AssertionError
            If any entry of ``elbo`` is ``inf`` or ``nan``.
    """
    assert not torch.any(torch.isinf(elbo)), "Invalid ELBO value (inf)."
    assert not torch.any(torch.isnan(elbo)), "Invalid ELBO value (nan)."


def train_step(model, x, inference=None, n_samples=1, train_enc=True, train_dec=True):

    """
    Perform one optimisation step of the model on a single batch.

        Parameters
        ----------
        model : CodedDVAE instance
            Model to be trained. Must expose ``device``, ``encoder``, ``decoder``,
            ``optimizer_encoder``, ``optimizer_decoder``, ``beta`` and ``likelihood``;
            the 'rep' branch also reads ``G`` and ``H``, the 'word' branch ``code_words``.
        x : torch.Tensor
            Batch of data. It is moved to ``model.device`` before use.
        inference : str
            Inference type.
            - 'uncoded' for the uncoded case.
            - 'word' for the coded case with inference at word level.
            - 'rep' for the coded case with inference at bit level using repetition codes.
        n_samples : int, optional
            Number of samples used for computing the ELBO. 1 by default.
            Ignored when ``inference == 'word'``: that branch always uses 2 samples.
        train_enc : bool, optional
            Whether the encoder parameters are updated. True by default.
        train_dec : bool, optional
            Whether the decoder parameters are updated. True by default.

        Returns
        -------
        elbo : torch.Tensor
            Value of the ELBO.
        kl_div : torch.Tensor
            Value of the Kullback-Leibler divergence term in the ELBO.
        reconstruction : torch.Tensor
            Value of the reconstruction term in the ELBO.

        Raises
        ------
        ValueError
            If ``inference`` is not one of 'uncoded', 'rep' or 'word'.
        AssertionError
            If the computed ELBO contains ``inf`` or ``nan``.
    """

    # Validate first: an unknown value would otherwise fall through every branch
    # and fail later with an unhelpful UnboundLocalError on `elbo`.
    if inference not in ('uncoded', 'rep', 'word'):
        raise ValueError(
            f"Unknown inference type {inference!r}; expected 'uncoded', 'rep' or 'word'."
        )

    x = x.to(model.device)

    model.optimizer_encoder.zero_grad()
    model.optimizer_decoder.zero_grad()

    if inference == 'uncoded':

        # Compute loss
        elbo, kl_div, reconstruction = get_elbo_uncoded(x, model.encoder, model.decoder, beta=model.beta, likelihood=model.likelihood, n_samples=n_samples)
        _check_elbo(elbo)

        # Gradients: maximise the ELBO by minimising its negative.
        loss = -elbo
        loss.backward()

        # Optimizer step (gradient norm clipped to 1)
        if train_dec:
            torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 1)
            model.optimizer_decoder.step()
        if train_enc:
            torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1)
            model.optimizer_encoder.step()

    elif inference == 'rep':

        # Compute loss (repetition code with generator G and parity-check H)
        elbo, kl_div, reconstruction = get_elbo_rep(x, model.encoder, model.decoder, model.G, model.H, beta=model.beta, likelihood=model.likelihood, n_samples=n_samples)
        _check_elbo(elbo)

        # Gradients
        loss = -elbo
        loss.backward()

        # Optimizer step.
        # NOTE(review): unlike the other branches, no gradient clipping is
        # applied here — confirm this is intentional.
        if train_dec:
            model.optimizer_decoder.step()
        if train_enc:
            model.optimizer_encoder.step()

    else:  # inference == 'word'

        # Compute loss. The number of samples is fixed to 2 here (and below in
        # compute_gloo) regardless of `n_samples`; presumably the leave-one-out
        # estimator needs at least two samples — TODO confirm.
        elbo, kl_div, reconstruction, elbo_not_reduced, c_sample_logprob = get_elbo_coded_word(x, model.encoder, model.decoder, model.code_words, beta=model.beta, likelihood=model.likelihood, n_samples=2)
        _check_elbo(elbo)

        # Gradients: the extra encoder gradients from compute_gloo are computed
        # before loss.backward(); keep this order (it presumably needs the graph
        # intact — TODO confirm).
        grads = compute_gloo(elbo_not_reduced, c_sample_logprob, model.encoder, n_samples=2)
        loss = -elbo
        loss.backward()

        # Optimizer step
        if train_dec:
            torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 1)
            model.optimizer_decoder.step()

        if train_enc:
            # Add the custom encoder gradients on top of those from backward().
            # A parameter that received no gradient from backward() has
            # grad None, so `+=` would fail; start from the custom gradient.
            for param, custom_grad in zip(model.encoder.parameters(), grads):
                if param.grad is None:
                    param.grad = custom_grad.detach().clone()
                else:
                    param.grad += custom_grad
            torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1)
            model.optimizer_encoder.step()

    return elbo, kl_div, reconstruction

    
    
def trainloop(model, train_dataloader, n_epochs, inference=None, n_samples=1, train_enc=True, train_dec=True, verbose=True, wb=False, start_epoch=0):

    """
    Train the model by calling `train_step` on every batch, epoch after epoch.

        Parameters
        ----------
        model : CodedDVAE instance
            Model to be trained.
        train_dataloader : iterable of (x, target) pairs (e.g. torch DataLoader)
            Training set. Only ``x`` is used; the second element is ignored.
        n_epochs : int
            Exclusive end index of the epoch range. The loop runs epochs
            ``start_epoch, ..., n_epochs - 1``, i.e. ``n_epochs - start_epoch`` epochs.
        inference : str
            Inference type.
            - 'uncoded' for the uncoded case.
            - 'word' for the coded case with inference at word level.
            - 'rep' for the coded case with inference at bit level using repetition codes.
        n_samples : int, optional
            Number of samples used for computing the ELBO. 1 by default.
        train_enc : bool, optional
            Whether the encoder parameters are updated. True by default.
        train_dec : bool, optional
            Whether the decoder parameters are updated. True by default.
        verbose : bool, optional
            Print the mean ELBO after every epoch. True by default.
        wb : bool, optional
            Log the mean ELBO, KL term and reconstruction term to Weights & Biases.
        start_epoch : int, optional
            Epoch where the trainloop starts. Useful for coherent Weights & Biases
            logs when fine-tuning a model.

        Returns
        -------
        elbo_evolution : list of float
            Mean ELBO over the batches of each epoch (1 value per epoch).
        kl_evolution : list of float
            Mean Kullback-Leibler divergence term per epoch.
        rec_evolution : list of float
            Mean reconstruction term per epoch.

        Raises
        ------
        ValueError
            If ``train_dataloader`` yields no batches during an epoch.
    """

    elbo_evolution = []
    kl_evolution = []
    rec_evolution = []

    for e in range(start_epoch, n_epochs):

        elbo_epoch = 0
        kl_epoch = 0
        reconstruction_epoch = 0
        # Count the batches actually seen instead of relying on len(), which
        # may be missing or inaccurate for iterable-style loaders.
        n_batches = 0

        for x, _ in train_dataloader:    # Batches

            elbo, kl, reconstruction = train_step(model, x, inference=inference, n_samples=n_samples, train_enc=train_enc, train_dec=train_dec)

            elbo_epoch += elbo.item()
            kl_epoch += kl.item()
            reconstruction_epoch += reconstruction.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("train_dataloader yielded no batches; cannot compute epoch averages.")

        elbo_mean = elbo_epoch / n_batches
        kl_mean = kl_epoch / n_batches
        rec_mean = reconstruction_epoch / n_batches

        elbo_evolution.append(elbo_mean)
        kl_evolution.append(kl_mean)
        rec_evolution.append(rec_mean)

        # Release cached GPU memory between epochs
        torch.cuda.empty_cache()

        if wb:
            # NOTE(review): the trailing colon in the "epoch:" key looks
            # accidental; it is kept so existing W&B dashboards stay consistent.
            wandb.log({"elbo/epoch": elbo_mean,
                        "kl/epoch": kl_mean,
                        "reconstruction/epoch": rec_mean,
                        "epoch:": e })

        if verbose:
            print("ELBO after %d epochs: %f" % (e + 1, elbo_mean))

    return elbo_evolution, kl_evolution, rec_evolution