import sys
PROJECT_PATH = "./Data_Pattern_Learnability" # Path to the project directory (relative, so it is resolved against the current working directory)
sys.path.append(PROJECT_PATH)  # Add the project path to sys.path

import torch.nn as nn
import torch
from utils.utils import logmeanexp_nodiag, js_fgan_lower_bound,sample_context_future
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm 
from utils.lower_bounds import infonce_lower_bound,nwj_lower_bound,tuba_lower_bound,js_lower_bound,smile_lower_bound,dv_upper_lower_bound
from utils.critics import SeparableCritic,ConcatCritic,SequentialCritic,SeparableCriticEvoRate,ConcatCriticEvoRate
from utils.process.process import SyntheticDataGenerator,MarkovGaussianGenerator,EvoRateGenerator,AutoRegressiveGenerator
import numpy as np


#-------------- 1)- the classic evoRate Estimator --------------


class evoRateEstimator(nn.Module):
    """EvoRate-style estimator of the predictive information between a past and a future window.

    The past window is summarised by the last output of an LSTM. The score of a
    (past, future) pair is the negative squared Euclidean distance between that
    summary and the first observation of the future window.
    """

    def __init__(self, dim_data, k, k_prime, num_layer_lstm_projector):
        """Build the LSTM projectors.

        Args:
            dim_data: feature size of one observation; also used as the LSTM hidden size.
            k: stored on the instance as ``self.k`` (not read by ``forward``).
            k_prime: stored on the instance as ``self.k_prime`` (not read by ``forward``).
            num_layer_lstm_projector: number of stacked layers in each LSTM.
        """
        super().__init__()

        # Projector that summarises the past window.
        self.seq_proj_past = nn.LSTM(
            input_size=dim_data,
            hidden_size=dim_data,
            num_layers=num_layer_lstm_projector,
            batch_first=True,
        )

        # Not used by forward(); kept so the module's parameter set and
        # state_dict keys stay exactly what they were.
        self.seq_proj_future = nn.LSTM(
            input_size=dim_data,
            hidden_size=dim_data,
            num_layers=num_layer_lstm_projector,
            batch_first=True,
        )
        self.k = k
        self.k_prime = k_prime

    def forward(self, x, y, evoRate_compute=False):
        """Estimate the predictive information for one batch.

        Args:
            x: the past window, fed to a ``batch_first`` LSTM, i.e. indexed as
                (batch, time, feature).
            y: the future window; only its first time step ``y[:, 0, :]`` is used.
            evoRate_compute: accepted for backward compatibility; currently ignored.

        Returns:
            A scalar tensor whose value is (up to rounding) the ``dv`` term and
            whose gradient is the gradient of the ``js`` term.
        """
        # 1) Encode the past and keep the last time step as its summary.
        out_x, _ = self.seq_proj_past(x)
        x_embed = out_x[:, -1, :]

        # 2) The future is represented by its first raw observation.
        #    (An LSTM pass over the time-reversed future used to be computed
        #    here, but its output was never read, so it has been dropped.)
        y_embed = y[:, 0, :]

        # Pairwise scores between every past summary and every future sample.
        scores = -torch.cdist(x_embed, y_embed) ** 2

        z = logmeanexp_nodiag(scores, dim=(0, 1))
        dv = scores.diag().mean() - z
        js = js_fgan_lower_bound(scores)

        # The correction is computed without gradient tracking, so the returned
        # value equals dv while back-propagation only flows through js.
        with torch.no_grad():
            dv_js = dv - js

        return js + dv_js

    # NOTE: this belongs with the trainers; it temporarily lives with the estimators.
    def train_evoRate_estimator(self, sequence, k, k_prime, dim_data, batch_size, iterations=2000, lr=5e-4):
        """Train a freshly built ``evoRateEstimator`` on windows sampled from ``sequence``.

        NOTE: this method does not optimise ``self``. It instantiates a new
        model with 3 LSTM layers, trains it on CPU with Adam, and returns only
        the history of estimates.

        Args:
            sequence: the data handed to ``sample_context_future``
                (expected layout not visible here - TODO confirm).
            k: forwarded to ``sample_context_future`` and to the new model
                (presumably the past window length).
            k_prime: forwarded to ``sample_context_future`` and to the new model
                (presumably the future window length).
            dim_data: feature size of one observation.
            batch_size: forwarded to ``sample_context_future``.
            iterations: number of optimisation steps.
            lr: Adam learning rate.

        Returns:
            A list with one float estimate per iteration.
        """
        model = evoRateEstimator(dim_data=dim_data, k=k, k_prime=k_prime, num_layer_lstm_projector=3)
        model = model.to("cpu")
        optimizer = optim.Adam(model.parameters(), lr=lr)

        estimates = []
        for _ in tqdm(range(iterations)):
            optimizer.zero_grad()

            x, y = sample_context_future(sequence, batch_size, k, k_prime, dim_data=dim_data)
            # A 2-D sample has no batch axis: add a leading one.
            if x.ndim == 2:
                x = x.unsqueeze(0)
            if y.ndim == 2:
                y = y.unsqueeze(0)

            mi_est = model(x, y)
            loss = -mi_est  # maximise the estimate by minimising its negative
            loss.backward()
            optimizer.step()
            estimates.append(mi_est.detach().cpu().item())
        return estimates



# Please note that with k_prime = 1 and the ConcatCriticEvoRate critic, we recover the evoRate estimator.

class evoPredEstimator:
    """Estimate evoPred(k, k') with a critic network and a mutual-information bound."""

    def __init__(self,
                 dim_data: int,
                 k: int,
                 k_prime: int,
                 batch_size: int,
                 estimator: str = "smile",
                 type_of_critic: str = "SeparableCritic",
                 **critic_kwargs):
        """Store the configuration and build the critic.

        Args:
            dim_data: the dimension of the observations.
            k: length of the past window (k in the paper).
            k_prime: length of the future window (k' in the paper).
            batch_size: the size of the batch used by the estimator.
            estimator: the mutual-information bound to use. One of
                "infonce", "nwj", "tuba", "js", "smile", "dv".
            type_of_critic: the critic to build. One of "SeparableCritic",
                "ConcatCritic", "SequentialCritic", "SeparableCriticEvoRate",
                "ConcatCriticEvoRate".
            critic_kwargs: overrides for the critic defaults
                (hidden_dim, embed_dim, layers, activation).

        Raises:
            ValueError: if ``type_of_critic`` is not one of the names above.
        """
        self.dim_data = dim_data
        self.k = k
        self.k_prime = k_prime
        self.batch_size = batch_size
        self.estimator = estimator
        self.type_of_critic = type_of_critic

        # Default parameters for critics, overridden by any provided kwargs.
        default_params = {
            'hidden_dim': 256,
            'embed_dim': 32,
            'layers': 2,
            'activation': "relu"
        }
        critic_params = {**default_params, **critic_kwargs}

        if self.type_of_critic == "SeparableCritic":
            self.critic = SeparableCritic(
                dim_past=k*dim_data,
                dim_fut=k_prime*dim_data,
                hidden_dim=critic_params['hidden_dim'],
                embed_dim=critic_params['embed_dim'],
                layers=critic_params['layers'],
                activation=critic_params['activation']
            )

        elif self.type_of_critic == "ConcatCritic":
            self.critic = ConcatCritic(
                dim_past=dim_data*k,
                dim_fut=k_prime*dim_data,
                hidden_dim=critic_params['hidden_dim'],
                layers=critic_params['layers'],
                activation=critic_params['activation']
            )

        elif self.type_of_critic == "SequentialCritic":
            self.critic = SequentialCritic(
                dim_data=dim_data,
                num_layer_lstm_projector=critic_params['layers'],
                T_past=self.k,
                T_future=self.k_prime
            )

        elif self.type_of_critic == "SeparableCriticEvoRate":
            self.critic = SeparableCriticEvoRate(
                dim=dim_data,
                hidden_dim=critic_params['hidden_dim'],
                embed_dim=critic_params['embed_dim'],
                layers=critic_params['layers'],
                activation=critic_params['activation']
            )

        elif self.type_of_critic == "ConcatCriticEvoRate":
            self.critic = ConcatCriticEvoRate(
                dim=dim_data,
                hidden_dim=critic_params['hidden_dim'],
                layers=critic_params['layers'],
                activation=critic_params['activation']
            )

        else:
            # Without this, self.critic would silently stay undefined and the
            # first forward() call would fail with an unrelated AttributeError.
            raise ValueError(
                f"Unknown type_of_critic {self.type_of_critic!r}; expected one of "
                "'SeparableCritic', 'ConcatCritic', 'SequentialCritic', "
                "'SeparableCriticEvoRate', 'ConcatCriticEvoRate'."
            )

    def forward(self, x, y, clip=None):
        """Score a batch with the critic and turn the scores into an MI estimate.

        Args:
            x: past windows, [batch_size, T, d] according to the original notes.
            y: future windows, [batch_size, T_prime, d] according to the original notes.
            clip: forwarded to ``smile_lower_bound``; ignored by the other bounds.

        Returns:
            The value returned by the selected lower-bound function.

        Raises:
            ValueError: if ``self.estimator`` is not a supported bound name.
        """
        scores = self.critic(x, y)

        if self.estimator == 'infonce':
            return infonce_lower_bound(scores)
        if self.estimator == 'nwj':
            return nwj_lower_bound(scores)
        if self.estimator == 'tuba':
            return tuba_lower_bound(scores, None)
        if self.estimator == 'js':
            return js_lower_bound(scores)
        if self.estimator == 'smile':
            return smile_lower_bound(scores, clip=clip)
        if self.estimator == 'dv':
            return dv_upper_lower_bound(scores)

        raise ValueError(
            f"Unknown estimator {self.estimator!r}; expected one of "
            "'infonce', 'nwj', 'tuba', 'js', 'smile', 'dv'."
        )

    def _fit(self, draw_batch, iterations, lr):
        """Run the shared Adam loop; ``draw_batch()`` must return one ``(x, y)`` pair."""
        self.critic = self.critic.to("cpu")
        optimizer = optim.Adam(self.critic.parameters(), lr=float(lr))

        estimates = []
        for _ in tqdm(range(iterations)):
            optimizer.zero_grad()

            x, y = draw_batch()
            mi_est = self.forward(x, y)
            loss = -mi_est  # maximise the estimate by minimising its negative
            loss.backward()
            optimizer.step()
            estimates.append(mi_est.detach().cpu().item())
        return estimates

    def train_estimator_on_large_sequence(self, sequence: np.ndarray, iterations: int = 2000, lr: float = 5e-4):
        """Train the critic on windows sampled from one long sequence.

        Args:
            sequence: the long sequence, described by the original author as
                (N, dimension_data); it is handed to ``sample_context_future`` as is.
            iterations: number of optimisation steps.
            lr: Adam learning rate (converted with ``float``).

        Returns:
            A list with one float estimate per iteration.
        """
        def draw_batch():
            return sample_context_future(
                sequence, self.batch_size, self.k, self.k_prime, dim_data=self.dim_data
            )

        return self._fit(draw_batch, iterations, lr)

    def train_on_sample_generator(self, generator: "SyntheticDataGenerator", iterations: int = 2000, lr: float = 5e-4):
        """Train the critic on batches drawn directly from a synthetic generator.

        Args:
            generator: object whose ``sample()`` returns an ``(x, y)`` pair of tensors.
            iterations: number of optimisation steps.
            lr: Adam learning rate (converted with ``float``).

        Returns:
            A list with one float estimate per iteration.
        """
        def draw_batch():
            x, y = generator.sample()
            return x.to("cpu"), y.to("cpu")

        return self._fit(draw_batch, iterations, lr)
        
"""Forecastability, trend and seasonality analysis."""


import numpy as np

from scipy.stats import entropy


def forecastabilty(ts):
  """Score a series by one minus the normalised entropy of its spectrum.

  The series is min-max scaled, the real part of its real-input FFT is
  min-max scaled again and normalised to sum to one, and the entropy of
  those weights is divided by ``log(len(ts))``.

  Args:
    ts: time series; must provide ``min()`` / ``max()`` and support
      element-wise arithmetic (e.g. a NumPy array).

  Returns:
    ``1 - entropy / log(len(ts))``, or ``0`` when that value is NaN
    (for instance when the series is constant).
  """
  def _unit_range(values):
    # Min-max scaling; a constant input yields 0/0 and therefore NaNs.
    lowest = values.min()
    return (values - lowest) / (values.max() - lowest)

  scaled = _unit_range(ts)
  spectrum = _unit_range(np.fft.rfft(scaled).real)
  weights = spectrum / spectrum.sum()
  score = 1 - entropy(weights) / np.log(len(scaled))
  return 0 if np.isnan(score) else score


def foreca_func(ts, window, jump=1):
  """Compute forecastability over a sliding window.

  Args:
    ts: time series
    window: length of each slice
    jump: step between the end positions of consecutive slices

  Returns:
    For series of at most 25 points, the single forecastability value of
    the whole series. Otherwise an array with one value per slice
    ``ts[end - window:end]``, where ``end`` runs from ``window`` up to, but
    not including, ``len(ts)``; NaN entries are removed.
  """
  # Short series are scored as a whole instead of being windowed.
  if len(ts) <= 25:
    return forecastabilty(ts)

  window_ends = np.arange(window, len(ts), jump)
  scores = np.array(
      [forecastabilty(ts[end - window:end]) for end in window_ends])
  keep = np.logical_not(np.isnan(scores))  # drop nan
  return scores[keep]

if __name__ == "__main__":
    # Demo run: the same estimator instance is trained, in turn, on three data sources.
    n_steps = 10000

    estimator = evoPredEstimator(
        dim_data=10,
        k=10,
        k_prime=10,
        batch_size=10,
        estimator="smile",
        type_of_critic="ConcatCritic",
    )

    # 1) Batches drawn directly from a synthetic generator; its own
    #    get_mutual_information() value is printed next to the estimate.
    #generator_data = MarkovGaussianGenerator(T_past=10, T_fut=10, batch_size=10, rho=0.5, dim=10)
    generator_data = EvoRateGenerator(T_past=10, T_fut=10, batch_size=10, rho=0.5, dim=10)
    estimates = estimator.train_on_sample_generator(generator_data, iterations=n_steps)
    print(f" the estimation evoPred is {np.mean(estimates[-100:])}, while the true one is {generator_data.get_mutual_information()}")

    # 2) Windows sampled from a standard-normal random sequence.
    noise_seq = np.random.randn(1000, 10)
    noise_estimates = estimator.train_estimator_on_large_sequence(noise_seq, iterations=n_steps)
    print(np.mean(noise_estimates[-100:]))

    # 3) Windows sampled from one long array produced by the autoregressive generator.
    ar_generator = AutoRegressiveGenerator(batch_size=65, p=5, rho=0.7, dim=10)
    ar_seq = ar_generator.generate_long_array(N=100000)
    ar_estimates = estimator.train_estimator_on_large_sequence(ar_seq, iterations=n_steps)
    print(np.mean(ar_estimates[-100:]))

# class evoRateEstimator(nn.Module):
    
#     def __init__(self,dim_data,num_layer_lstm_projector):
#         super(evoRateEstimator, self).__init__()

#         #the projector 
#         self.seq_proj_past = nn.LSTM(
#             input_size=dim_data,
#             hidden_size=dim_data,
#             num_layers=num_layer_lstm_projector,
#             batch_first=True
#         )
     
#         self.seq_proj_future = nn.LSTM(
#             input_size=dim_data,
#             hidden_size=dim_data,
#             num_layers=num_layer_lstm_projector,
#             batch_first=True
#         )

       

#     def forward(self,x,y,evoRate_compute=False):
#         """x: input, the past,
#            y: output, the future 
#            return : the predictive information estimation, derived from the evoRate estimator"""
        
#         # 1) On encode le passé
#         out_x, _ = self.seq_proj_past(x)  
#         #out_x = self.seq_proj_past(x)               # out_x : (batch_size, T, hidden_dim)
#         x_embed = out_x[:, -1, :]                   # on prend le dernier pas de temps comme résumé

       
#         # 2) On encode le futur en BACKWARD, bidirectionnal version 
#         y_rev = torch.flip(y, dims=[1])             # (batch_size, T, dim_data), renversé
#         out_y, _ = self.seq_proj_past(y_rev)      # out_y : (batch_size, T, hidden_dim) - its speeding up calculation to use that.
#         #y_embed = out_y[:, -1, :]                   # dernier pas de temps (qui correspond en fait au "vrai" début de la seq initiale)

#         # if evoRate_compute:
#         y_embed =  y[:,0,:] #for evorate - temporary we test it.
#         #x_embed = x[:,-1,:]
#         scores = -torch.cdist(x_embed, y_embed)**2 #the new score here used!! 

#         z = logmeanexp_nodiag(scores, dim=(0, 1)) 
#         dv = scores.diag().mean() - z
#         js = js_fgan_lower_bound(scores)

#         with torch.no_grad():
#             dv_js = dv - js #là on fait juste un trick pour calculer sans rétropropagation ? 

#         return js + dv_js #temp modif 
    


# #should be with trainers. but temporary we but it with estimators. 
# def train_evoRate_estimator(sequence,past_contexte_length,dim_data,batch_size, iterations=2000, lr=5e-4):
#     """
#     Entraîne le modèle evoRateEstimator pour maximiser l'information prédictive.
    
#     Arguments:
#       chain3_bin      : array numpy de forme (N,) contenant la chaîne (ex: 2000 valeurs)
#       context_length  : nombre d'observations pour le passé
#       target_length   : nombre d'observations pour le futur
#       batch_size      : nombre de fenêtres à extraire par itération
#       iterations      : nombre d'itérations d'entraînement
#       lr              : taux d'apprentissage
      
#     Retourne :
#       estimates : liste des estimations de l'information prédictive au cours de l'entraînement.
#     """
#     # Convertir chain3_bin en tenseur float et obtenir la longueur totale
   
#     # Instanciation du modèle evoRateEstimator
#     # Ici, dim_data = 1 (puisque chaque observation est un scalaire)
#     model = evoRateEstimator(dim_data=dim_data, num_layer_lstm_projector=2)
#     model = model.to("cpu")
#     optimizer = optim.Adam(model.parameters(), lr=lr)
    
#     estimates = []
#     for _ in tqdm(range(iterations)):
#         optimizer.zero_grad()

#         x,y = sample_context_future(sequence,batch_size,past_contexte_length,10,dim_data=dim_data)
   
#         mi_est = model(x, y)
#         loss = -mi_est  # on maximise mi_est, donc on minimise -mi_est
#         loss.backward()
#         optimizer.step()
#         estimates.append(mi_est.detach().cpu().item())
#     return estimates




# #-------------- 2)- predictive information estimator  --------------

# class predictiveInformationEstimator(nn.Module):
#     """
#     Estime l'information prédictive (basée sur CPC) en combinant plusieurs scores
#     calculés à partir d'indices contrôlés par T_past (pour le passé) et T_future (pour le futur).
    
#     Les embeddings sont obtenus via des LSTM :
#       - x_embed (résumé du passé) est obtenu en prenant le dernier pas de temps de l'encodeur du passé.
#       - y_embed (résumé du futur) est obtenu en encodant la séquence future en sens inverse.
    
#     Pour le calcul des scores :
#       - Pour le futur, on regarde les indices de 0 à T_future-1 dans y.
#       - Pour le passé, on regarde les indices négatifs de -1 à -T_past dans x.
    
#     Les scores sont ensuite combinés de manière pondérée, les poids étant appris et normalisés via softmax.
#     """
#     def __init__(self, dim_data, num_layer_lstm_projector, T_past=5, T_future=5):
#         super(predictiveInformationEstimator, self).__init__()
        
#         self.T_past = T_past      # nombre d'indices à regarder dans le passé (x)
#         self.T_future = T_future  # nombre d'indices à regarder dans le futur (y)
        
#         # LSTM pour encoder le passé
#         self.seq_proj_past = nn.LSTM(
#             input_size=dim_data,
#             hidden_size=dim_data,
#             num_layers=num_layer_lstm_projector,
#             batch_first=True
#         )
     
#         # LSTM pour encoder le futur (calculé sur la séquence inversée)
#         self.seq_proj_future = nn.LSTM(
#             input_size=dim_data,
#             hidden_size=dim_data,
#             num_layers=num_layer_lstm_projector,
#             batch_first=True
#         )
        
#         # Nombre total de scores : T_future (pour le futur) + T_past (pour le passé)
#         self.score_weights = nn.Parameter(torch.ones(T_future + T_past))
    

#     def forward(self, x, y):
#         """
#         Arguments:
#           x : tenseur d'entrée du passé, de forme (batch_size, T_past_sequence, dim_data)
#           y : tenseur d'entrée du futur, de forme (batch_size, T_future_sequence, dim_data)
#         Retourne :
#           Une estimation de l'information prédictive.
#         """
#         # 1) Encodage du passé
#         out_x, _ = self.seq_proj_past(x)
#         # x_embed est le résumé du passé (dernier pas de temps)
#         x_embed = out_x[:, -1, :]

#         # 2) Encodage du futur (en inversant la séquence)
#         y_rev = torch.flip(y, dims=[1])
#         out_y, _ = self.seq_proj_future(y_rev)
#         # y_embed correspond au résumé du futur (premier élément de la séquence inversée)
#         y_embed = out_y[:, -1, :]

#         scores_list = []
        
#         # Calcul des scores pour le futur : on prend les indices 0 à T_future-1
#         for i in range(self.T_future):
#             score_future = -torch.cdist(x_embed, y[:, i, :])**2
#             scores_list.append(score_future)
        
#         # Calcul des scores pour le passé : on prend les indices -1, -2, ..., -T_past
#         for i in range(self.T_past):
#             score_past = -torch.cdist(x[:, -(i+1), :], y_embed)**2
#             scores_list.append(score_past)
        
#         # Empilement des scores en un tenseur de forme (T_future + T_past, batch_size, batch_size)
#         scores_tensor = torch.stack(scores_list, dim=0)
        
#         # Normalisation des poids pour garantir que leur somme soit 1
#         normalized_weights = F.softmax(self.score_weights, dim=0)
#         weighted_scores = normalized_weights.view(-1, 1, 1) * scores_tensor
        
#         # Somme pondérée des scores
#         scores = weighted_scores.sum(dim=0)

#         # Calcul du log-mean-exp hors diagonale et des bornes (fonctions définies ailleurs)
#         z = logmeanexp_nodiag(scores, dim=(0, 1))
#         dv = scores.diag().mean() - z
#         js = js_fgan_lower_bound(scores)

#         # Trick sans rétropropagation pour la différence dv - js
#         with torch.no_grad():
#             dv_js = dv - js

#         return js + dv_js



# #should be with trainers. but temporary we but it with estimators. 
# def train_predictive_information(sequence,
#                                  T_past,
#                                  T_futur,
#                                  dim_data,
#                                  batch_size,
#                                  iterations=2000, 
#                                  lr=5e-4):
#     """
#     Entraîne le modèle evoRateEstimator pour maximiser l'information prédictive.
    
#     Arguments:
#       chain3_bin      : array numpy de forme (N,) contenant la chaîne (ex: 2000 valeurs)
#       context_length  : nombre d'observations pour le passé
#       target_length   : nombre d'observations pour le futur
#       batch_size      : nombre de fenêtres à extraire par itération
#       iterations      : nombre d'itérations d'entraînement
#       lr              : taux d'apprentissage
      
#     Retourne :
#       estimates : liste des estimations de l'information prédictive au cours de l'entraînement.
#     """
#     # Convertir chain3_bin en tenseur float et obtenir la longueur totale
   
#     # Instanciation du modèle evoRateEstimator
#     # Ici, dim_data = 1 (puisque chaque observation est un scalaire)
#     model = predictiveInformationEstimator(dim_data=dim_data, num_layer_lstm_projector=2,T_past=T_past,T_future=T_futur)
#     model = model.to("cpu")
#     optimizer = optim.Adam(model.parameters(), lr=lr)
    
#     estimates = []
    

#     for i in tqdm(range(iterations)):
       
#         optimizer.zero_grad()

#         x,y = sample_context_future(sequence,batch_size,T_past,T_futur,dim_data=dim_data)
   
#         mi_est = model(x, y)
#         loss = -mi_est  # on maximise mi_est, donc on minimise -mi_est
#         loss.backward()
#         optimizer.step()
#         estimates.append(mi_est.detach().cpu().item())
#     return estimates
