import torch
import torch.distributions as dist
from src.utils.sampling import sample_from_qz_given_x, sample_from_qc_given_x, modulate_words
from src.nn.modules import dclamp

def log_bernoulli(probs, observation):
    """
    Log-likelihood of observations under independent Bernoulli variables.

    Parameters
    ----------
    probs : torch.tensor
        Bernoulli success probabilities. Must have the same shape as
        ``observation`` (e.g. (batch_size, dimension)).
    observation : torch.tensor
        Observed values, used as the BCE target.

    Returns
    -------
    torch.tensor
        Sum over dim=1 of ``observation*log(probs) + (1-observation)*log(1-probs)``.
        As in PyTorch's binary cross-entropy, each log term is clamped at -100.
    """
    neg_log_lik = torch.nn.functional.binary_cross_entropy(probs, observation, reduction='none')
    return -neg_log_lik.sum(dim=1)


def log_gaussian(x, mean, covar):
    """
    Log-density of ``x`` under a Gaussian with diagonal covariance.

    Parameters
    ----------
    x : torch.tensor
        Batch of data.
    mean : torch.tensor
        Means of the distribution.
    covar : torch.tensor
        Diagonal of the covariance matrix (variances, not standard deviations).

    Returns
    -------
    torch.tensor
        Log probability, summed over the last dimension.
    """
    std = covar.sqrt()
    # Treat the last dimension as the event dimension, so log_prob sums over it.
    diag_normal = dist.Independent(dist.Normal(mean, std), reinterpreted_batch_ndims=1)
    return diag_normal.log_prob(x)


def kl_div_bernoulli(q_probs, p_probs):
    """
    KL divergence D_KL(q || p) between two element-wise Bernoulli distributions.

    Parameters
    ----------
    q_probs : torch.tensor
        Probabilities of the q distribution. Before use they are clamped to
        [0, 1 - 1e-3] with ``dclamp`` to avoid numerical instabilities.
    p_probs : torch.tensor
        Probabilities of the p distribution (used as given).

    Returns
    -------
    kl_div : torch.tensor
        Element-wise KL divergence summed over dim=1.
    """
    q_dist = dist.Bernoulli(dclamp(q_probs, min=0, max=1-1e-3))
    p_dist = dist.Bernoulli(p_probs)
    return dist.kl_divergence(q_dist, p_dist).sum(dim=1)


def compute_word_logprobs(bit_probs, code_words):

    """
    Compute the log probability of every word in the codebook.

    Each code word c is scored with the factorised bit distribution,
    log q(c|x) = sum_k log Bernoulli(c_k; bit_probs_k). The scores are then
    renormalised over the codebook.

    Parameters
    ----------
    bit_probs : torch.tensor
        Bit probabilities, shape [batch_size, K], values in [0, 1].
    code_words : torch.tensor
        Codebook, shape [n_words, K], with binary (0/1) entries.

    Returns
    -------
    logq : torch.tensor
        Unnormalised log-probabilities of the code words, shape [batch_size, n_words].
    logq_norm : torch.tensor
        Log-probabilities normalised over the codebook, shape [batch_size, n_words].

    Raises
    ------
    AssertionError
        If the inputs are out of range, or a numerical sanity check fails.
    """

    # Sanity check
    assert not torch.any(bit_probs < 0), "Negative value encountered in bit probabilities."
    assert not torch.any(bit_probs > 1), "Value larger than 1 encountered in bit probabilities."
    assert torch.all(torch.logical_or(code_words == 0, code_words == 1)), "Invalid word encountered. All words should be binary vectors."

    n_words = code_words.shape[0]
    batch_size = bit_probs.shape[0]

    # Clamp to avoid numerical instabilities
    bit_probs = dclamp(bit_probs, min=0.001, max=0.999)

    # Broadcast both operands to [batch_size, K, n_words]. expand() returns views
    # without copying data, whereas repeat() materialised two full tensors of
    # this size. Values and gradients are identical.
    probs_3d = bit_probs.unsqueeze(2).expand(-1, -1, n_words)
    words_3d = code_words.T.unsqueeze(0).expand(batch_size, -1, -1)

    # Evaluate log(q_uncoded(c|x)) for every code word. log_bernoulli reduces
    # over dim=1 (K), leaving a [batch_size, n_words] matrix.
    logq = log_bernoulli(probs_3d, words_3d)

    # Clamp to avoid numerical instabilities
    logq = dclamp(logq, min=-100, max=1)

    # Sanity check
    assert not torch.any(torch.isinf(logq)), "Invalid logq value (inf)."
    assert not torch.any(torch.isnan(logq)), "Invalid logq value (nan)."

    # Normalization over the codebook (last dimension)
    logq_norm = logq - logq.logsumexp(dim=-1, keepdim=True)

    # Sanity check
    assert torch.all(torch.exp(logq_norm) >= 0), "Negative value encountered in normalized probs."
    assert torch.all((torch.exp(logq_norm).sum(-1) - 1).abs() < 1e-5), "Normalized probabilities do not sum 1."

    return logq, logq_norm


def get_elbo_uncoded(x, encoder, decoder, prior_m=0.5, beta=10, likelihood='gauss', n_samples=1):

    """
    Compute the ELBO in the uncoded scenario.

    Parameters
    ----------
    x : torch.tensor
        Batch of data. The first dimension is the batch; the remaining
        dimensions are flattened.
    encoder : Encoder instance
        Encoder of the model. Its ``forward`` returns the bit probabilities.
    decoder : Decoder instance
        Decoder of the model.
    prior_m : float, optional
        Prior bit probability. Default to 0.5.
    beta : float, optional
        Temperature term that controls the decay of the exponentials in the
        smoothing transformation. Default to 10.
    likelihood : str, optional
        Distribution used to compute the reconstruction term (case-insensitive).
        - 'gauss': Gaussian likelihood with fixed variance 0.1.
        - 'ber': Bernoulli likelihood.
    n_samples : int, optional
        Number of samples used to estimate the ELBO. Default to 1.

    Returns
    -------
    elbo : torch.tensor
        Value of the ELBO, averaged over the batch.
    kl_div : torch.tensor
        Value of the Kullback-Leibler divergence term, averaged over the batch.
    reconstruction : torch.tensor
        Value of the reconstruction term, averaged over the batch.

    Raises
    ------
    ValueError
        If ``likelihood`` is not 'gauss'/'ber', or ``n_samples`` < 1.
    """

    lik = likelihood.lower()
    if lik not in ('ber', 'gauss'):
        # Previously an unknown likelihood silently produced a reconstruction of 0.
        raise ValueError(f"Unknown likelihood '{likelihood}'. Expected 'gauss' or 'ber'.")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}.")

    N = x.shape[0]
    x_flat = x.reshape(N, -1)
    n_features = x_flat.shape[1]

    # Forward encoder
    bit_probs = encoder.forward(x)

    # Obtain n_samples from q(z|x) for each observed x
    qz_sample = sample_from_qz_given_x(bit_probs, beta=beta, n_samples=n_samples)  # shape [N, K, n_samples]

    # Fixed diagonal variance for the Gaussian observation model (loop-invariant).
    covar = torch.ones(n_features, device=x_flat.device) * 0.1

    # Compute the reconstruction term E_{q(z|x)}[log p(x|z)] by Monte Carlo
    reconstruction_sum = 0
    for n in range(n_samples):
        out_decoder = decoder.forward(qz_sample[:, :, n]).reshape(-1, n_features)
        if lik == 'ber':
            reconstruction_sum += log_bernoulli(out_decoder, x_flat)
        else:
            reconstruction_sum += log_gaussian(x_flat, out_decoder, covar)
    reconstruction = reconstruction_sum / n_samples

    # Compute the KL divergence term against a factorised Bernoulli(prior_m) prior
    prior_probs = torch.full_like(bit_probs, prior_m)
    kl_div = kl_div_bernoulli(bit_probs, prior_probs)

    # Obtain the ELBO (batch average)
    elbo = torch.sum(reconstruction - kl_div, dim=0) / N

    return elbo, torch.sum(kl_div, dim=0) / N, torch.sum(reconstruction, dim=0) / N


def get_elbo_rep(x, encoder, decoder, G, H, prior_m=0.5, beta=10, likelihood='gauss', n_samples=1):

    """
    Compute the ELBO in the coded scenario that uses the G/H matrices.

    The encoder's per-bit probabilities are combined through H into
    log-marginals for each information bit. Those marginals are normalised and
    mapped back to coded-bit probabilities through G. The combination is a
    matmul of log-probabilities with H; presumably H/G are the 0/1 matrices of
    a repetition code (TODO confirm).

    Parameters
    ----------
    x : torch.tensor
        Batch of data. The first dimension is the batch; the remaining
        dimensions are flattened.
    encoder : Encoder instance
        Encoder of the model. Its ``forward`` returns bit probabilities.
    decoder : Decoder instance
        Decoder of the model.
    G : torch.tensor
        Matrix used to encode information words (right-multiplies the
        information-bit marginals).
    H : torch.tensor
        Matrix used to decode coded words (right-multiplies the coded-bit
        log-probabilities).
    prior_m : float, optional
        Prior probability of each information bit. Default to 0.5.
    beta : float, optional
        Temperature term that controls the decay of the exponentials in the
        smoothing transformation. Default to 10.
    likelihood : str, optional
        Distribution used to compute the reconstruction term (case-insensitive).
        - 'gauss': Gaussian likelihood with fixed variance 0.1.
        - 'ber': Bernoulli likelihood.
    n_samples : int, optional
        Number of samples used to estimate the ELBO. Default to 1.

    Returns
    -------
    elbo : torch.tensor
        Value of the ELBO, averaged over the batch.
    kl_div : torch.tensor
        Value of the Kullback-Leibler divergence term, averaged over the batch.
    reconstruction : torch.tensor
        Value of the reconstruction term, averaged over the batch.

    Raises
    ------
    ValueError
        If ``likelihood`` is not 'gauss'/'ber', or ``n_samples`` < 1.
    """

    lik = likelihood.lower()
    if lik not in ('ber', 'gauss'):
        # Previously an unknown likelihood silently produced a reconstruction of 0.
        raise ValueError(f"Unknown likelihood '{likelihood}'. Expected 'gauss' or 'ber'.")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}.")

    N = x.shape[0]
    x_flat = x.reshape(N, -1)
    n_features = x_flat.shape[1]

    probs = encoder.forward(x)

    # Sanity check
    assert not torch.any(torch.isinf(probs)), "Invalid probs value (inf)."
    assert not torch.any(torch.isnan(probs)), "Invalid probs value (nan)."

    # A saturated encoder output of exactly 0 or 1 makes log(probs) or
    # log(1 - probs) equal -inf, and -inf * 0 inside the matmul with H yields NaN.
    # Clamp (with dclamp, as elsewhere in this module) before taking logs.
    safe_probs = dclamp(probs, min=1e-6, max=1-1e-6)

    H_dev = H.to(probs.device)
    logpm1 = torch.matmul(torch.log(safe_probs), H_dev)
    logpm0 = torch.matmul(torch.log(1 - safe_probs), H_dev)

    # Normalise the {0, 1} log-marginals of each information bit
    log_marginals = torch.stack((logpm0, logpm1), dim=2)
    log_marginals_norm = log_marginals - torch.logsumexp(log_marginals, dim=-1, keepdim=True)
    q_info = torch.exp(log_marginals_norm[:, :, 1])

    # Introduce code structure
    qc = torch.matmul(q_info, G.to(probs.device))

    # Obtain n_samples from q(z|x) for each observed x
    qz_sample = sample_from_qz_given_x(qc, beta=beta, n_samples=n_samples)  # shape [N, K, n_samples]

    # Fixed diagonal variance for the Gaussian observation model (loop-invariant).
    covar = torch.ones(n_features, device=x_flat.device) * 0.1

    # Compute the reconstruction term E_{q(z|x)}[log p(x|z)] by Monte Carlo
    reconstruction_sum = 0
    for n in range(n_samples):
        out_decoder = decoder.forward(qz_sample[:, :, n]).reshape(-1, n_features)

        assert not torch.any(torch.isinf(out_decoder)), "Invalid out_decoder value (inf)."
        assert not torch.any(torch.isnan(out_decoder)), "Invalid out_decoder value (nan)."

        if lik == 'ber':
            reconstruction_sum += log_bernoulli(out_decoder, x_flat)
        else:
            reconstruction_sum += log_gaussian(x_flat, out_decoder, covar)
    reconstruction = reconstruction_sum / n_samples

    # KL divergence of the information-bit marginals against a Bernoulli(prior_m) prior
    prior_probs = torch.full_like(q_info, prior_m)
    kl_div = kl_div_bernoulli(q_info, prior_probs)

    # Obtain the ELBO (batch average)
    elbo = torch.sum(reconstruction - kl_div, dim=0) / N

    return elbo, torch.sum(kl_div, dim=0) / N, torch.sum(reconstruction, dim=0) / N


def get_elbo_coded_word(x, encoder, decoder, code_words, beta=10, likelihood='gauss', n_samples=2):

    """
    Compute the ELBO in the coded case with inference at word level.

    Parameters
    ----------
    x : torch.tensor
        Batch of data. The first dimension is the batch; the remaining
        dimensions are flattened.
    encoder : Encoder instance
        Encoder of the model. Its ``forward`` returns bit probabilities.
    decoder : Decoder instance
        Decoder of the model.
    code_words : torch.tensor
        Codebook, one code word per row.
    beta : float, optional
        Temperature term that controls the decay of the exponentials in the
        smoothing transformation. Default to 10.
    likelihood : str, optional
        Distribution used to compute the reconstruction term (case-insensitive).
        - 'gauss': Gaussian likelihood with fixed variance 0.1.
        - 'ber': Bernoulli likelihood.
    n_samples : int, optional
        Number of samples used to estimate the ELBO. Default to 2 (the minimum
        for estimating the gradients with REINFORCE leave-one-out).

    Returns
    -------
    elbo : torch.tensor
        Value of the ELBO (mean over batch and samples).
    kl_div : torch.tensor
        Value of the Kullback-Leibler divergence term (mean over the batch).
    reconstruction : torch.tensor
        Value of the reconstruction term (mean over batch and samples).
    elbo_not_reduced : torch.tensor
        ELBO values for each data point and sample before reducing.
        [shape (n_observations, n_samples)]
    c_sample_logprob : torch.tensor
        The log probability of the sampled words c.

    Raises
    ------
    ValueError
        If ``likelihood`` is not 'gauss'/'ber', if ``n_samples`` < 1, or if the
        decoder output lies outside [0, 1] under the Bernoulli likelihood.
    """

    lik = likelihood.lower()
    if lik not in ('ber', 'gauss'):
        # Previously an unknown likelihood silently left the reconstruction at zeros.
        raise ValueError(f"Unknown likelihood '{likelihood}'. Expected 'gauss' or 'ber'.")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}.")

    N = x.shape[0]
    x_flat = x.reshape(N, -1)
    n_features = x_flat.shape[1]
    n_words = code_words.shape[0]

    bit_probs = encoder.forward(x)
    code_words = code_words.to(bit_probs.device)
    logq, logq_norm = compute_word_logprobs(bit_probs, code_words)

    # Sample words
    words_sample, c_sample_logprob = sample_from_qc_given_x(logq, code_words, n_samples=n_samples)

    # Modulate the words sampling from p(z|c)
    z_sample = modulate_words(words_sample, beta=beta)

    # Fixed diagonal variance for the Gaussian observation model (loop-invariant).
    covar = torch.ones(n_features, device=x_flat.device) * 0.1

    # Reconstruction term log p(x|z) for every data point and sample
    reconstruction = torch.zeros((N, n_samples), device=x_flat.device)  # shape [batch_size, n_samples]
    for s in range(n_samples):
        out_decoder = decoder.forward(z_sample[:, :, s]).reshape(-1, n_features)

        if lik == 'ber':
            if torch.any(out_decoder > 1) or torch.any(out_decoder < 0):
                raise ValueError("Invalid values encountered in out_decoder. All elements should be between 0 and 1.")
            reconstruction[:, s] = log_bernoulli(out_decoder, x_flat)
        else:
            reconstruction[:, s] = log_gaussian(x_flat, out_decoder, covar)

    # Uniform prior over the codebook: log p(c) = log(1 / n_words)
    log_prior_c = torch.log(torch.tensor((1/n_words), dtype=torch.float32)).to(x_flat.device)
    kl_div = torch.sum(torch.exp(logq_norm) * (logq_norm - log_prior_c), dim=-1)  # sum(q(c)*log(q(c)/p(c)))
    elbo_not_reduced = reconstruction - kl_div.unsqueeze(-1)  # shape [batch_size, n_samples]
    elbo = torch.mean(elbo_not_reduced)

    return elbo, torch.mean(kl_div), torch.mean(reconstruction), elbo_not_reduced, c_sample_logprob





    

    

