import numpy as np
import torch
import torch.nn as nn
import math

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

import torch
import matplotlib.pyplot as plt


def logmeanexp_nodiag(x, dim=None, device='cpu'):
    batch_size = x.size(0)
    if dim is None:
        dim = (0, 1)

    logsumexp = torch.logsumexp(
        x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim)

    try:
        if len(dim) == 1:
            num_elem = batch_size - 1.
        else:
            num_elem = batch_size * (batch_size - 1.)
    except ValueError:
        num_elem = batch_size - 1
    return logsumexp - torch.log(torch.tensor(num_elem)).to(device)


#borne inf JS-fGAN --- https://arxiv.org/pdf/1606.00709 
def js_fgan_lower_bound(f):
    """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016)."""
    f_diag = f.diag()
    first_term = -F.softplus(-f_diag).mean()
    n = f.size(0)
    second_term = (torch.sum(F.softplus(f)) -
                   torch.sum(F.softplus(f_diag))) / (n * (n - 1.))
    return first_term - second_term


def weighted_logmeanexp_nodiag(X, P, device='cpu'):
    """
    Calcule : log ( 1 / B(B-1) * sum_{i ≠ j} exp(X_ij) * P_ij )
    
    Arguments:
    - X : Matrice de scores (B, B)
    - P : Matrice de poids/probabilités (B, B)
    
    Retourne:
    - logmeanexp pondérée par P, sans la diagonale
    """
    B = X.size(0)  # Taille du batch

    # Créer un masque pour exclure la diagonale
    mask = ~torch.eye(B, dtype=torch.bool, device=device)  # Matrice de booléens (True hors diagonale)

    # Appliquer le masque pour garder uniquement les éléments hors-diagonale
    X_filtered = X[mask]  # (B * (B-1),)
    P_filtered = P[mask]  # (B * (B-1),)

    # Calcul de la somme pondérée exponentielle
    weighted_exp_sum = torch.sum(torch.exp(X_filtered) * P_filtered)

    # Nombre d'éléments hors-diagonale (normalisation)
    num_elements = B * (B - 1)

    # Calcul final
    log_weighted_mean_exp = torch.log(weighted_exp_sum) - torch.log(torch.tensor(num_elements, dtype=torch.float, device=device))

    return log_weighted_mean_exp


def smile_lower_bound(f, score_prob=None,clip=None): #f = scores. 
    if score_prob is not None:

        f_pond = f*score_prob
        z = weighted_logmeanexp_nodiag(f,score_prob)
        dv = f_pond.diag().mean() - z 
        return dv 


    else:

        if clip is not None:
            f_ = torch.clamp(f, -clip, clip) #si on veut faire un tronckage des data. 
        else:
            f_ = f

        z = logmeanexp_nodiag(f_, dim=(0, 1)) #log-moyenne-exponentielle des data hors de la diagonnale.. 
        dv = f.diag().mean() - z #diff entre moyenne des paires positives, et celle des paires négatives (ici)- peut etre instable car dépend ds valeurs de f.
        #dv est une premiere approximation de l'info mutuelle enft
        js = js_fgan_lower_bound(f)  #lower bound de f avec fgan -- JS-fGAN avec une correction inspirée de DV.
        #stable mais sous-estime 
        #deuxieme approximation de l'info mutuelle 
        with torch.no_grad():
            dv_js = dv - js #là on fait juste un trick pour calculer sans rétropropagation ? 

        return dv #js + dv_js #temp modif 



def sample_context_future(array, batch_size, k, k_prime,dim_data):
    """
    Génère un batch de paires (x, y) à partir d'un array 1D.
    
    Arguments:
      array      : array 1D (numpy) de taille N.
      batch_size : nombre de paires à générer.
      T_context  : longueur du contexte (nombre d'observations passées).
      T_target   : longueur de la cible (nombre d'observations futures).
      
    Retourne:
      x_tensor   : tenseur de forme (batch_size, T_context) contenant les contextes.
      y_tensor   : tenseur de forme (batch_size, T_target) contenant les cibles.
    """
    N = len(array)

    # Les indices t valides sont dans [T_context, N - T_target]
    valid_indices = np.arange(k, N - k_prime + 1)
    chosen_t = np.random.choice(valid_indices, size=batch_size, replace=False)
    
    x_samples = []
    y_samples = []
    for t in chosen_t:
        x = array[t - k:t]        # Contexte : T_context valeurs avant t
        y = array[t:t + k_prime]           # Cible : T_target valeurs à partir de t
        x_samples.append(x)
        y_samples.append(y)
    
    # Convertir en tenseurs torch (type float32)
    # if dim_data == 0:
    #     x_tensor = torch.tensor(np.stack(x_samples), dtype=torch.float32).unsqueeze(-1)
    #     y_tensor = torch.tensor(np.stack(y_samples), dtype=torch.float32).unsqueeze(-1)
    # else:
    x_tensor = torch.tensor(np.stack(x_samples), dtype=torch.float32).unsqueeze(-1)
    y_tensor = torch.tensor(np.stack(y_samples), dtype=torch.float32).unsqueeze(-1)


    return x_tensor, y_tensor


def plot_trajectory(X_past, X_fut, batch_idx=0, delimiter='vertical'):
    """
    Trace une trajectoire complète pour un indice de batch donné.
    
    Paramètres:
      - X_past (torch.Tensor): tenseur de forme (batch_size, T_past, d)
      - X_fut (torch.Tensor): tenseur de forme (batch_size, T_fut, d)
      - batch_idx (int): indice du batch à tracer.
      - delimiter (str): 'vertical' (par défaut) ou 'horizontal' pour tracer la ligne de séparation.
    
    La fonction trace chaque dimension dans un sous-graphe séparé.
    """
    # Conversion du batch sélectionné en numpy
    past = X_past[batch_idx].detach().cpu().numpy()  # forme: (T_past, d)
    fut  = X_fut[batch_idx].detach().cpu().numpy()    # forme: (T_fut, d)
    
    T_past = past.shape[0]
    T_fut = fut.shape[0]
    d = past.shape[1]
    
    # Axes temporels pour le passé et le futur
    time_past = np.arange(-T_past + 1, 1)  # par exemple de -9 à 0 pour T_past=10
    time_fut  = np.arange(1, T_fut + 1)     # de 1 à T_fut
    
    # Création de sous-graphes pour chaque dimension
    fig, axes = plt.subplots(d, 1, figsize=(8, 2 * d), sharex=True)
    if d == 1:
        axes = [axes]  # rendre axes itérable si d==1
    
    for i in range(d):
        ax = axes[i]
        ax.plot(time_past, past[:, i], marker='o', label='Past')
        ax.plot(time_fut, fut[:, i], marker='o', label='Future')
        
        # Affichage du séparateur
        if delimiter == 'vertical':
            ax.axvline(x=0, color='k', linestyle='--', label="Separation" if i==0 else None)
        elif delimiter == 'horizontal':
            # Moins courant pour la séparation temps, mais possible si désiré.
            ax.axhline(y=0, color='k', linestyle='--', label="Separation" if i==0 else None)
        
        ax.set_ylabel(f"Dim {i+1}")
        ax.legend(loc='upper left')
        ax.grid(True)
    
    axes[-1].set_xlabel("Time")
    plt.suptitle(f"Trajectory for batch index {batch_idx}")
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()
