import torch
import numpy as np
from torch import nn
from torch_geometric.graphgym.register import register_layer
from torch_geometric.graphgym.config import cfg
from ..utils import embed_1D_scalar, nodes_per_batch_sample


# @register_layer('graph_features_encoder')
# class GraphFeaturesEncoder(nn.Module):
#     def __init__(self, out_dim, hidden_dim=256, t_prec=1000):
#         """
#         Args:
#             out_dim (int): Dimension of the output embeddings.
#             hidden_dim (int): Dimension of intermediate embeddings.
#         """
#         super(GraphFeaturesEncoder, self).__init__()
#         if cfg.train.mode == "dfm":
#             self.feature_mixer = GymMlpMixer(cfg.gt.n_rbf_centroids, hidden_dim, 1, 1) # 3, 1 to 1, 1 if we condition only on a single feature
#         else:
#             self.feature_mixer = GymMlpMixer(cfg.gt.n_rbf_centroids, hidden_dim, 3, 1, out_fc_dim=out_dim)
#         self.time_mixer = GymMlpMixer(hidden_dim, hidden_dim, 2, 1, norm=True, out_fc_dim=out_dim)
        
#         # Learned null token embedding of the same dimension as output embeddings
#         self.null_token = nn.Parameter(torch.zeros(1, hidden_dim), requires_grad=True)
#         # Initialize the null token
#         nn.init.normal_(self.null_token, mean=0.0, std=0.02)

#         self.hidden_dim = hidden_dim
#         self.t_prec = t_prec
        
#     def forward(self, batch, unconditional_prop=0):
#         """
#         Args:
#             batch (torch_geometric.data.Data): Data object containing the input features.
#             unconditional (float): proportion of samples replaced by the learned null token
        
#         Returns:
#             torch.Geometric.data.Data: Data object with updated embedding in batch.c.
#         """

#         # Merge conditioning features
#         batch = self.feature_mixer(batch)

#         if cfg.train.mode == "dfm": # in order to use this class also for deep simulator

#             # During training, replace a fraction of conditional inputs by the null token (mimics learning an unconditional model)
#             uncond_idx = np.random.choice(np.arange(len(batch.c)), int(unconditional_prop * len(batch.c)), replace=False)
#             batch.c[uncond_idx] = self.null_token

#             # if cfg.gt.sample_separate_t:
#                 # Map t_x values from [0, 1] to integer values
#             t_x = batch.t_x * self.t_prec
#             if not hasattr(cfg.train, 'distortion_pow_e'): # cheat for backward compatibility --> to suppress
#                 # In fact these do not have to be integers
#                 t_x = t_x.int()
#             # Sine t positional encoding
#             time_embed_x = embed_1D_scalar(t_x, self.hidden_dim, self.t_prec).to(t_x.device)
#             batch = self.time_mixer(batch, time_embed_x.unsqueeze(1), dest_feat='c_x')
#             # Do the same for t_e
#             t_e = batch.t_e * self.t_prec
#             if not hasattr(cfg.train, 'distortion_pow_e'):
#                 t_e = t_e.int()
#             time_embed_e = embed_1D_scalar(t_e, self.hidden_dim, self.t_prec).to(t_e.device)
#             batch = self.time_mixer(batch, time_embed_e.unsqueeze(1), dest_feat='c_e')
#             # Do the same for f
#             if cfg.gt.get("sizing", False):
#                 t_f = batch.t_f * self.t_prec
#                 time_embed_f = embed_1D_scalar(t_f, self.hidden_dim, self.t_prec).to(t_f.device)
#                 batch = self.time_mixer(batch, time_embed_f.unsqueeze(1), dest_feat='c_f')
#             # else:
#             #     # Map t values from [0, 1] to integer values
#             #     t = (batch.t * self.t_prec).int()
#             #     # Sine t positional encoding
#             #     time_embed = embed_1D_scalar(t, self.hidden_dim, self.t_prec).to(t.device)
#             #     # Merge t with other conditioning features
#             #     batch = self.time_mixer(batch, time_embed.unsqueeze(1))

#         return batch
    

@register_layer('time_encoder')
class TimeEncoder(nn.Module):
    """Encode the denoising times stored on a batch as conditioning embeddings.

    Only time indices are encoded here; no other conditioning features are used.
    Node times ``batch.t_x`` and edge times ``batch.t_e`` are always encoded.
    ``batch.t_f`` is also encoded when ``cfg.gt.sizing`` is enabled. All of them
    go through the same (shared-weight) ``GymMlpMixer``.
    """

    def __init__(self, out_dim, hidden_dim=256, t_prec=1000):
        """
        Args:
            out_dim (int): Dimension of the output embeddings.
            hidden_dim (int): Dimension of intermediate embeddings; also the size
                requested from ``embed_1D_scalar`` for each time encoding.
            t_prec (int): Multiplier applied to raw times before encoding; it is
                also passed to ``embed_1D_scalar`` as its third argument.
        """
        super(TimeEncoder, self).__init__()

        # One mixer, reused for every time variable (x, e and optionally f).
        self.time_mixer = GymMlpMixer(hidden_dim, hidden_dim, 1, 1, norm=True, out_fc_dim=out_dim)
        self.hidden_dim = hidden_dim
        self.t_prec = t_prec

    def _mix_time(self, batch, t, dest_feat):
        """Scale ``t``, encode it, and let the mixer store it as ``batch.<dest_feat>``."""
        scaled = t * self.t_prec
        # Sine positional encoding of the scaled time, on the same device as the input.
        encoding = embed_1D_scalar(scaled, self.hidden_dim, self.t_prec).to(scaled.device)
        # Add a singleton axis 1 so the mixer sees a (N, 1, hidden_dim) tensor.
        return self.time_mixer(batch, encoding.unsqueeze(1), dest_feat=dest_feat)

    def forward(self, batch):
        """
        Args:
            batch (torch_geometric.data.Data): Data object carrying ``t_x``, ``t_e``
                and, when ``cfg.gt.sizing`` is enabled, ``t_f``.

        Returns:
            torch_geometric.data.Data: The batch returned by the mixer, with the
            embeddings written under ``'c_x'``, ``'c_e'`` and, when sizing is
            enabled, ``'c_f'``.
        """
        batch = self._mix_time(batch, batch.t_x, 'c_x')
        batch = self._mix_time(batch, batch.t_e, 'c_e')

        if cfg.gt.get("sizing", False):
            batch = self._mix_time(batch, batch.t_f, 'c_f')

        return batch
    

class GymMlpMixer(nn.Module):
    """Small MLP-Mixer-style block that embeds a conditioning tensor onto a batch.

    The input is mixed along two axes:

    * ``depth_wise``: ``nn.Linear(in_pts, out_pts)`` over the last axis;
    * ``point_wise``: ``nn.Linear(in_dim, out_dim)`` over axis 1 (via a transpose).

    The result then passes through an optional ``LayerNorm(out_pts)``, a SiLU
    activation and an optional output projection ``out_pts -> out_fc_dim``.
    """

    # Batch attribute names the mixed embedding may be written to.
    _DEST_FEATS = ('c_spec', 'c_x', 'c_e', 'c_f')

    def __init__(self, in_pts, out_pts, in_dim, out_dim, norm=False, out_fc_dim=None):
        """
        Args:
            in_pts (int): Size of the last axis of the input.
            out_pts (int): Size of the last axis after ``depth_wise``.
            in_dim (int): Size of axis 1 of the input.
            out_dim (int): Size of axis 1 after ``point_wise``.
            norm (bool): If True, apply ``nn.LayerNorm(out_pts)`` after mixing.
            out_fc_dim (int | None): If given, project the last axis to this size
                and remove axis 1 when it has size 1 (i.e. ``out_dim == 1``).
        """
        super().__init__()

        self.depth_wise = nn.Linear(in_pts, out_pts)
        self.point_wise = nn.Linear(in_dim, out_dim)
        if norm:
            self.norm = nn.LayerNorm(out_pts)
        self.act = nn.SiLU()
        if out_fc_dim is not None:
            self.out_fc = nn.Linear(out_pts, out_fc_dim)

    def forward(self, batch, c_embed, dest_feat):
        """Mix ``c_embed`` and store the result on ``batch`` as attribute ``dest_feat``.

        Args:
            batch: Object that receives the result (e.g. a torch_geometric batch).
            c_embed (torch.Tensor): Expected 3-D, shape ``(bs, in_dim, in_pts)``.
            dest_feat (str): One of ``'c_spec'``, ``'c_x'``, ``'c_e'``, ``'c_f'``.

        Returns:
            The same ``batch`` object, with ``batch.<dest_feat>`` set. When
            ``out_fc_dim`` is set and ``out_dim == 1``, the stored tensor has shape
            ``(bs, out_fc_dim)`` (including ``bs == 1``). Otherwise it has shape
            ``(bs, out_dim, out_pts)``, or ``(bs, out_dim, out_fc_dim)`` when
            ``out_fc_dim`` is set.

        Raises:
            ValueError: If ``dest_feat`` is not a supported attribute name.
        """
        # Fail fast instead of computing an embedding that would be silently dropped.
        if dest_feat not in self._DEST_FEATS:
            raise ValueError(
                f"Unsupported dest_feat {dest_feat!r}; expected one of {self._DEST_FEATS}"
            )

        # (bs, in_dim, in_pts) -> (bs, in_dim, out_pts)
        c_embed = self.depth_wise(c_embed)
        # Mix across axis 1: (bs, in_dim, out_pts) -> (bs, out_dim, out_pts)
        c_embed = self.point_wise(c_embed.transpose(1, 2)).transpose(1, 2)

        # LayerNorm over the last axis, separately for each position on axis 1.
        if hasattr(self, 'norm'):
            c_embed = self.norm(c_embed)
        c_embed = self.act(c_embed)

        if hasattr(self, 'out_fc'):
            # Squeeze only axis 1. A bare .squeeze() would also drop the batch axis
            # when bs == 1 (and the feature axis when out_fc_dim == 1).
            c_embed = self.out_fc(c_embed).squeeze(1)

        setattr(batch, dest_feat, c_embed)
        return batch
