import numpy as np
import torch
import torch.nn as nn

import torch.nn.functional as F


def mlp(dim, hidden_dim, output_dim, layers, activation):
    """Create a fully-connected network (MLP) from the configuration.

    The network is ``Linear(dim, hidden_dim) -> act``, followed by ``layers``
    repetitions of ``Linear(hidden_dim, hidden_dim) -> act``, and a final
    ``Linear(hidden_dim, output_dim)`` with no activation. It therefore
    contains ``layers + 2`` Linear layers in total.

    Args:
        dim: Input feature size.
        hidden_dim: Width of every hidden layer.
        output_dim: Output feature size.
        layers: Number of extra hidden ``Linear`` + activation pairs added
            after the first one.
        activation: Either an activation name (``'relu'``, ``'tanh'``,
            ``'elu'``, ``'leaky_relu'``, ``'sigmoid'``, ``'gelu'``), or an
            ``nn.Module`` subclass / zero-argument factory that returns a
            fresh activation module.

    Returns:
        An ``nn.Sequential`` implementing the network.

    Raises:
        KeyError: If ``activation`` is a string that is not a known name.
    """
    activations = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'elu': nn.ELU,
        'leaky_relu': nn.LeakyReLU,
        'sigmoid': nn.Sigmoid,
        'gelu': nn.GELU,
    }
    if isinstance(activation, str):
        try:
            act_factory = activations[activation]
        except KeyError:
            # Keep KeyError (the original exception type), but make it readable.
            raise KeyError(
                f"Unknown activation {activation!r}; "
                f"expected one of {sorted(activations)}"
            ) from None
    else:
        # Already a module class / factory: call it once per layer below.
        act_factory = activation

    # A fresh activation instance per layer, so no module is shared in the graph.
    seq = [nn.Linear(dim, hidden_dim), act_factory()]
    for _ in range(layers):
        seq += [nn.Linear(hidden_dim, hidden_dim), act_factory()]
    seq.append(nn.Linear(hidden_dim, output_dim))

    return nn.Sequential(*seq)


class SeparableCritic(nn.Module):
    """Separable critic whose scores factorise as a dot product h(y) . g(x).

    Both inputs are flattened per sample and embedded by two independent
    MLPs: ``_g`` embeds ``x`` and ``_h`` embeds ``y``. Every (y, x) pair of
    embeddings is then scored with a dot product.
    """

    def __init__(self, dim_past, dim_fut, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
        super().__init__()
        # Embedding network for the (flattened) past input x.
        self._g = mlp(dim_past, hidden_dim, embed_dim, layers, activation)
        # Embedding network for the (flattened) future input y.
        self._h = mlp(dim_fut, hidden_dim, embed_dim, layers, activation)

    def forward(self, x, y):
        """Score every pair of samples.

        Returns a matrix with ``scores[i, j] = h(y_i) . g(x_j)``; rows are
        indexed by ``y`` and columns by ``x``.
        """
        n = x.size(0)
        # Flatten all trailing dims. Note that y is reshaped with x's batch size.
        flat_x = x.reshape(n, -1)
        flat_y = y.reshape(n, -1)
        y_emb = self._h(flat_y)
        x_emb = self._g(flat_x)
        return y_emb @ x_emb.t()


class ConcatCritic(nn.Module):
    """Concat critic: score each (x, y) pair by running one MLP on [x, y]."""

    def __init__(self, dim_past, dim_fut, hidden_dim, layers, activation, **extra_kwargs):
        super().__init__()
        # Single MLP mapping a concatenated (x, y) pair to one scalar score.
        self._f = mlp(dim_past + dim_fut, hidden_dim, 1, layers, activation)

    def forward(self, x, y):
        """Score all batch_size * batch_size pairs.

        Returns a (batch_size, batch_size) matrix with ``out[a, b] = f([x_a, y_b])``;
        rows are indexed by ``x`` and columns by ``y``.
        """
        n = x.size(0)
        # Flatten all trailing dims. Note that y is reshaped with x's batch size.
        flat_x = x.reshape(n, -1)
        flat_y = y.reshape(n, -1)

        # Broadcast views: x_grid[i, j] = x_j and y_grid[i, j] = y_i.
        x_grid = flat_x.unsqueeze(0).expand(n, -1, -1)
        y_grid = flat_y.unsqueeze(1).expand(-1, n, -1)
        pairs = torch.cat([x_grid, y_grid], dim=-1).reshape(n * n, -1)

        # grid[i, j] = f(x_j, y_i); transposing yields out[a, b] = f(x_a, y_b).
        return self._f(pairs).reshape(n, n).t()




class SequentialCritic(nn.Module):
    """Our current proposed critic for the density estimate.

    The past ``x`` is summarised by an LSTM (last output step) and compared
    with each of the first ``T_future`` future points ``y[:, i, :]``. The
    future ``y`` is summarised by a second LSTM run over the time-reversed
    sequence and compared with each of the last ``T_past`` past points
    ``x[:, -(i+1), :]``. The resulting ``T_future + T_past`` score matrices
    (negative squared Euclidean distances) are mixed with learned weights
    normalised by a softmax.
    """

    def __init__(self, dim_data, num_layer_lstm_projector, T_past=5, T_future=5):
        super().__init__()
        self.T_past = T_past      # number of past steps (from x) scored against the future summary
        self.T_future = T_future  # number of future steps (from y) scored against the past summary

        # Both encoders keep the feature size, so summaries are comparable to raw points.
        lstm_kwargs = dict(
            input_size=dim_data,
            hidden_size=dim_data,
            num_layers=num_layer_lstm_projector,
            batch_first=True,
        )
        # Encoder for the past, read forwards.
        self.seq_proj_past = nn.LSTM(**lstm_kwargs)
        # Encoder for the future, run on the time-reversed sequence.
        self.seq_proj_future = nn.LSTM(**lstm_kwargs)

        # One mixing logit per score matrix: the T_future future terms come first, then the T_past past terms.
        self.score_weights = nn.Parameter(torch.ones(T_future + T_past))

    def forward(self, x, y):
        """Return a (batch_x, batch_y) score matrix; rows index x, columns index y.

        Requires ``y`` to have at least ``T_future`` steps and ``x`` at least
        ``T_past`` steps along dim 1 (otherwise indexing raises).
        """
        # 1) Past summary: last output of the forward LSTM over x.
        past_out, _ = self.seq_proj_past(x)
        past_summary = past_out[:, -1, :]

        # 2) Future summary: last output of the LSTM over reversed y, i.e. the
        #    state after reading the future backwards down to y[:, 0].
        future_out, _ = self.seq_proj_future(torch.flip(y, dims=[1]))
        future_summary = future_out[:, -1, :]

        # Past summary vs. future steps 0 .. T_future-1.
        future_terms = [
            -torch.cdist(past_summary, y[:, step, :]).pow(2)
            for step in range(self.T_future)
        ]
        # Past steps -1, -2, ..., -T_past vs. future summary.
        past_terms = [
            -torch.cdist(x[:, -(step + 1), :], future_summary).pow(2)
            for step in range(self.T_past)
        ]
        # Shape: (T_future + T_past, batch_x, batch_y).
        stacked = torch.stack(future_terms + past_terms, dim=0)

        # Softmax makes the mixing weights positive and sum to 1.
        mix = F.softmax(self.score_weights, dim=0)
        return (mix[:, None, None] * stacked).sum(dim=0)
    


# Critics used in EvoRate: in the end, the EvoRate authors only slightly adapted their code for the different critics.

class SeparableCriticEvoRate(nn.Module):
    """EvoRate separable critic: dot product of an embedded past summary and future point.

    The past sequence ``x`` is summarised by the last output of an LSTM, and
    the future is reduced to its first point ``y[:, 0, :]``. The two are
    embedded by separate MLPs (``_g`` for the past summary, ``_h`` for the
    future point), and each pair is scored with a dot product.
    """

    def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
        super().__init__()
        # Embedding network for the LSTM summary of the past.
        self._g = mlp(dim, hidden_dim, embed_dim, layers, activation)
        # Embedding network for the first future point.
        self._h = mlp(dim, hidden_dim, embed_dim, layers, activation)

        # Sequence projector for x; its depth reuses the MLP `layers` setting.
        self.seq_proj = nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            num_layers=layers,
            batch_first=True,
        )

    def forward(self, x, y):
        """Return ``scores[i, j] = h(y_i[0]) . g(summary(x_j))``; rows index y, columns index x."""
        # Only the first future point is used (the EvoRate transformation).
        future_point = y[:, 0, :]
        sequence_out = self.seq_proj(x)[0]
        past_summary = sequence_out[:, -1, :]
        future_emb = self._h(future_point)
        past_emb = self._g(past_summary)
        return future_emb @ past_emb.t()


class ConcatCriticEvoRate(nn.Module):
    """EvoRate critic scoring pairs by negative squared Euclidean distance.

    Despite its name, ``forward`` does not concatenate its inputs and does
    not call an MLP. It encodes the past sequence ``x`` with an LSTM, keeps
    the last output step, and scores it against the first future point
    ``y[:, 0, :]`` with ``-||summary(x_a) - y_b[0]||^2``.
    """

    def __init__(self, dim, hidden_dim, layers, activation, **extra_kwargs):
        super(ConcatCriticEvoRate, self).__init__()
        # NOTE(review): `_f` is built but never used by `forward`. It is kept
        # so the module's parameters and state_dict keys stay unchanged.
        self._f = mlp(dim * 2, hidden_dim, 1, layers, activation)
        # Sequence projector for x; its depth reuses the MLP `layers` setting.
        self.seq_proj = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=layers, batch_first=True)

    def forward(self, x, y):
        """Score every (past, future) pair.

        Args:
            x: Past sequence fed to a batch_first LSTM, presumably
                (batch, seq, dim) — TODO confirm with callers.
            y: Future sequence; only its first step ``y[:, 0, :]`` is used.

        Returns:
            A (batch_x, batch_y) matrix with non-positive entries
            ``-||summary(x_a) - y_b[0]||^2``; rows index x, columns index y.
        """
        # Only the first future point is used (the EvoRate transformation).
        y_first = y[:, 0, :]
        output, _ = self.seq_proj(x)
        yhat = output[:, -1, :]
        return -torch.cdist(yhat, y_first) ** 2
