import numpy as np
import torch
import torch.nn.functional as F


def logmeanexp_diag(x, device=None):
    """Compute log(mean(exp(.))) over the diagonal elements of x.

    Args:
      x: score matrix, assumed square ([batch_size, batch_size]); only its
        diagonal is used and the mean is taken over batch_size entries.
      device: unused; kept for backward compatibility. The result is
        computed on x's device.

    Returns:
      0-dim tensor with the same dtype and device as x's diagonal logsumexp.
    """
    batch_size = x.size(0)
    logsumexp = torch.logsumexp(x.diag(), dim=0)
    # Build log(N) in the result's dtype/device so float64 inputs keep full
    # precision and no constant is created on a different device than x.
    num_elem = torch.tensor(float(batch_size), dtype=logsumexp.dtype,
                            device=logsumexp.device)
    return logsumexp - torch.log(num_elem)


def logmeanexp_nodiag(x, dim=None, device=None):
    """Compute log(mean(exp(.))) over the off-diagonal elements of x.

    Args:
      x: score matrix, assumed square ([batch_size, batch_size]).
      dim: reduction dims. None (default) reduces over both axes, giving a
        0-dim result. An int or a 1-element sequence reduces one axis, giving
        one value per remaining row/column.
      device: unused; kept for backward compatibility. Everything is computed
        on x's device.

    Returns:
      logsumexp over the off-diagonal entries minus log(number of entries
      reduced per output element).
    """
    batch_size = x.size(0)
    if dim is None:
        dim = (0, 1)

    # Mask the diagonal with -inf so it contributes exp(-inf) = 0 to the sum.
    # masked_fill keeps x's dtype and avoids inf - inf = nan on +inf entries.
    diag_mask = torch.eye(batch_size, dtype=torch.bool, device=x.device)
    logsumexp = torch.logsumexp(x.masked_fill(diag_mask, float('-inf')), dim=dim)

    # Reducing a single axis of a square matrix sees batch_size - 1
    # off-diagonal entries per slice; reducing both sees all n * (n - 1).
    # An int dim is accepted here (len() on an int would raise TypeError).
    if isinstance(dim, int) or len(dim) == 1:
        num_elem = batch_size - 1.
    else:
        num_elem = batch_size * (batch_size - 1.)

    return logsumexp - torch.log(
        torch.tensor(num_elem, dtype=logsumexp.dtype, device=x.device))

def tuba_lower_bound(scores, log_baseline=None, device=None):
    """TUBA lower bound on mutual information.

    Args:
      scores: [batch_size, batch_size] critic score matrix. Diagonal entries
        are samples from the joint; off-diagonal entries are samples from the
        product of marginals.
      log_baseline (optional): per-row log baseline, shape [batch_size] or
        [batch_size, 1]. It is subtracted from every score in its row.
      device: forwarded to logmeanexp_nodiag.

    Returns:
      0-dim tensor: 1 + mean(diag) - mean(exp(off-diagonal)).

    The caller's `scores` tensor is never modified.
    """
    if log_baseline is not None:
        # Out-of-place so the caller's tensor (and its autograd history) is
        # left untouched. reshape(-1) accepts both [B] and [B, 1] baselines.
        scores = scores - log_baseline.reshape(-1)[:, None]

    # First term is an expectation over samples from the joint,
    # which are the diagonal elements of the scores matrix.
    joint_term = scores.diag().mean()

    # Second term is an expectation over samples from the marginal,
    # which are the off-diagonal elements of the scores matrix.
    marg_term = logmeanexp_nodiag(scores, device=device).exp()
    return 1. + joint_term - marg_term


def nwj_lower_bound(scores, device=None):
    """NWJ lower bound on mutual information.

    Computed as the TUBA bound applied to the critic scores after every entry
    has been shifted down by 1 (no per-row baseline).
    """
    shifted_scores = scores - 1.
    return tuba_lower_bound(shifted_scores, device=device)


def infonce_lower_bound(scores, device=None):
    """InfoNCE lower bound on mutual information.

    Args:
      scores: [batch_size, batch_size] critic scores; the diagonal holds the
        joint (positive) pairs.
      device: device on which the log(batch_size) constant is created.

    Returns:
      0-dim tensor: log(batch_size) + mean(diag) - mean over rows of
      logsumexp(row).
    """
    batch = scores.size(0)
    log_batch = torch.tensor(batch, device=device).float().log()
    per_row = log_batch + (scores.diag().mean() - scores.logsumexp(dim=1))
    return per_row.mean()


def js_fgan_lower_bound(f, device=None):
    """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016).

    Args:
      f: [batch_size, batch_size] critic scores; the diagonal holds joint
        samples and off-diagonal entries hold marginal samples.
      device: unused; accepted for a uniform estimator signature.

    Returns:
      0-dim tensor: -mean(softplus(-diag)) - mean(softplus(off-diagonal)).
    """
    n = f.size(0)
    diag_scores = f.diag()
    joint_term = -F.softplus(-diag_scores).mean()
    # Sum softplus over the whole matrix, then remove the diagonal's share to
    # keep only the n * (n - 1) off-diagonal entries.
    off_diag_total = F.softplus(f).sum() - F.softplus(diag_scores).sum()
    marginal_term = off_diag_total / (n * (n - 1.))
    return joint_term - marginal_term


def js_lower_bound(f, device=None):
    """Return the NWJ MI estimate while back-propagating only through JS.

    The correction (nwj - js) is computed without gradient tracking. The
    returned value therefore equals the NWJ estimate numerically, but its
    gradient is that of the JS f-GAN bound.
    """
    nwj_estimate = nwj_lower_bound(f, device=device)
    js_estimate = js_fgan_lower_bound(f, device=device)

    with torch.no_grad():
        correction = nwj_estimate - js_estimate

    return js_estimate + correction


def dv_upper_lower_bound(f, device=None):
    """Donsker-Varadhan bound with the log taken outside the marginal mean.

    Computes mean(diag(f)) - log(mean(exp(off-diagonal of f))). This is the
    MINE objective without the moving-average gradient correction.
    """
    joint_mean = f.diag().mean()
    log_marginal = logmeanexp_nodiag(f, device=device)
    return joint_mean - log_marginal


def mine_lower_bound(f, buffer=None, momentum=0.9, device=None):
    """
    MINE lower bound based on the Donsker-Varadhan (DV) inequality.

    The returned value numerically equals
    mean(diag(f)) - log(mean(exp(off-diagonal of f))). The gradient of the
    log term is replaced by grad(mean(exp(off-diag))) / buffer_new, where
    buffer_new is a momentum-weighted average of the denominator.

    Args:
      f: [batch_size, batch_size] critic scores (diagonal = joint samples).
      buffer: previous denominator estimate; defaults to tensor(1.0) on
        `device`.
      momentum: weight given to `buffer` in the moving average.
      device: device for the default buffer; forwarded to logmeanexp_nodiag.

    Returns:
      (estimate, buffer_update), where buffer_update = mean(exp(off-diag)) for
      this batch.

    NOTE(review): this returns buffer_update rather than buffer_new. If
    callers feed it back as `buffer`, the average only blends the previous
    batch with the current one instead of forming a true exponential moving
    average. buffer_update also still carries the autograd graph; consider
    detaching it. Confirm the intended behavior with callers before changing.
    """
    if buffer is None:
        buffer = torch.tensor(1.0, device=device)
    first_term = f.diag().mean()

    # Tracked by autograd: this is where the corrected gradient comes from.
    buffer_update = logmeanexp_nodiag(f, device=device).exp()
    with torch.no_grad():
        # Value-only terms: the log term, the updated (clamped) average, and a
        # detached copy of the ratio that cancels its tracked twin below.
        second_term = logmeanexp_nodiag(f, device=device)
        buffer_new = buffer * momentum + buffer_update * (1 - momentum)
        buffer_new = torch.clamp(buffer_new, min=1e-4)
        third_term_no_grad = buffer_update / buffer_new

    # Same ratio, but with gradient; buffer_new is a constant here.
    third_term_grad = buffer_update / buffer_new

    # third_term_grad and third_term_no_grad cancel in value, leaving only
    # their gradient contribution.
    return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update


def smile_lower_bound(f, clip=None, device=None):
    """SMILE estimate: the DV value with the JS f-GAN gradient.

    The log-partition term uses scores clipped to [-clip, clip] when `clip`
    is given. The diagonal term and the JS bound always use the raw scores.
    The (dv - js) correction is computed without gradient tracking, so only
    the JS term is back-propagated.
    """
    clipped = f if clip is None else torch.clamp(f, -clip, clip)

    log_partition = logmeanexp_nodiag(clipped, dim=(0, 1), device=device)
    dv_estimate = f.diag().mean() - log_partition

    js_estimate = js_fgan_lower_bound(f, device=device)

    with torch.no_grad():
        correction = dv_estimate - js_estimate

    return js_estimate + correction


def estimate_mutual_information(estimator, x, y, critic_fn,
                                baseline_fn=None, alpha_logit=None, device=None, **kwargs):
    """Estimate variational lower bounds on mutual information.

    Args:
      estimator: string specifying estimator, one of:
        'infonce', 'nwj', 'tuba', 'js', 'smile', 'dv'
      x: [batch_size, dim_x] Tensor
      y: [batch_size, dim_y] Tensor
      critic_fn: callable that takes x and y as input and outputs critic
        scores; output shape is a [batch_size, batch_size] matrix
      baseline_fn (optional): callable that takes y as input and outputs a
        [batch_size] or [batch_size, 1] vector; only consumed by 'tuba'
      alpha_logit (optional): accepted for backward compatibility but
        currently unused (no interpolated estimator is implemented here)
      device: device to use for computation; defaults to MPS when available,
        otherwise CPU
      **kwargs: forwarded to smile_lower_bound (e.g. `clip`) for 'smile'

    Returns:
      scalar estimate of mutual information

    Raises:
      ValueError: if `estimator` is not one of the supported names.
    """
    if device is None:
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    x, y = x.to(device), y.to(device)
    scores = critic_fn(x, y)

    # Always bound so 'tuba' without a baseline_fn runs with no baseline
    # instead of failing with UnboundLocalError.
    log_baseline = None
    if baseline_fn is not None:
        # Flatten [batch_size, 1] (or [batch_size]) outputs to [batch_size].
        # Unlike squeeze, this keeps a length-1 axis when batch_size == 1.
        log_baseline = baseline_fn(y).reshape(-1).to(device)

    if estimator == 'infonce':
        mi = infonce_lower_bound(scores, device=device)
    elif estimator == 'nwj':
        mi = nwj_lower_bound(scores, device=device)
    elif estimator == 'tuba':
        mi = tuba_lower_bound(scores, log_baseline, device=device)
    elif estimator == 'js':
        mi = js_lower_bound(scores, device=device)
    elif estimator == 'smile':
        mi = smile_lower_bound(scores, device=device, **kwargs)
    elif estimator == 'dv':
        mi = dv_upper_lower_bound(scores, device=device)
    else:
        raise ValueError(
            f"Unknown estimator {estimator!r}; expected one of "
            "'infonce', 'nwj', 'tuba', 'js', 'smile', 'dv'.")

    return mi