import torch
import numpy as np
from torch import nn
from torch_geometric.graphgym.register import register_layer
from torch_geometric.graphgym.config import cfg
from ..utils import embed_1D_scalar, nodes_per_batch_sample

# TODO: conditional generation w/ sample_separate_t

@register_layer('graph_features_encoder')
class GraphFeaturesEncoder(nn.Module):
    """Encode graph conditioning features and, in "dfm" mode, denoising times.

    ``feature_mixer`` merges ``batch.c_init`` into ``batch.c``. When
    ``cfg.train.mode == "dfm"``, a fraction of the rows of ``batch.c`` may be
    replaced by a learned null token, and ``batch.c`` is then merged with
    sinusoidal embeddings of ``batch.t_x`` / ``batch.t_e`` (and ``batch.t_f``
    when ``cfg.gt.sizing`` is set) into ``batch.c_x`` / ``batch.c_e``
    (/ ``batch.c_f``).
    """

    def __init__(self, out_dim, hidden_dim=256, t_prec=1000):
        """
        Args:
            out_dim (int): Dimension of the output embeddings.
            hidden_dim (int): Dimension of intermediate embeddings; also the width
                requested from ``embed_1D_scalar`` and the size of the null token.
            t_prec (int): Factor applied to times before embedding; also passed to
                ``embed_1D_scalar``.
        """
        super().__init__()
        if cfg.train.mode == "dfm":
            # 3, 1 to 1, 1 if we condition only on a single feature.
            # No out_fc: batch.c keeps width hidden_dim so that it can later be
            # concatenated with the time embeddings inside time_mixer.
            self.feature_mixer = GymMlpMixer(cfg.gt.n_rbf_centroids, hidden_dim, 1, 1)
        else:
            self.feature_mixer = GymMlpMixer(cfg.gt.n_rbf_centroids, hidden_dim, 3, 1, out_fc_dim=out_dim)
        # Shared for t_x, t_e and t_f: mixes batch.c with one time embedding (in_dim=2).
        self.time_mixer = GymMlpMixer(hidden_dim, hidden_dim, 2, 1, norm=True, out_fc_dim=out_dim)

        # Learned null token. Its width is hidden_dim (the width of batch.c rows
        # produced by the "dfm" feature mixer), not out_dim.
        self.null_token = nn.Parameter(torch.zeros(1, hidden_dim), requires_grad=True)
        nn.init.normal_(self.null_token, mean=0.0, std=0.02)

        self.hidden_dim = hidden_dim
        self.t_prec = t_prec

    def forward(self, batch, unconditional_prop=0):
        """
        Args:
            batch (torch_geometric.data.Data): Data object containing the input
                features (``c_init``; plus ``t_x``, ``t_e`` and, with
                ``cfg.gt.sizing``, ``t_f`` in "dfm" mode).
            unconditional_prop (float): Proportion of rows of ``batch.c`` replaced
                by the learned null token ("dfm" mode only). 0 disables it; values
                above 1 make ``np.random.choice`` raise.

        Returns:
            torch_geometric.data.Data: The same object with ``batch.c`` set and,
            in "dfm" mode, ``batch.c_x``, ``batch.c_e`` (and ``batch.c_f``) set.
        """
        # Merge conditioning features into batch.c
        batch = self.feature_mixer(batch)

        # Non-"dfm" modes (e.g. the deep simulator) only need batch.c.
        if cfg.train.mode != "dfm":
            return batch

        # During training, replace a fraction of conditional inputs by the null
        # token (mimics learning an unconditional model).
        n_rows = len(batch.c)
        uncond_idx = np.random.choice(np.arange(n_rows), int(unconditional_prop * n_rows), replace=False)
        batch.c[uncond_idx] = self.null_token

        # Backward-compatibility hack (to be removed): configs predating
        # cfg.train.distortion_pow_e embedded integer time indices.
        cast_to_int = not hasattr(cfg.train, 'distortion_pow_e')
        batch = self._add_time_condition(batch, batch.t_x, 'c_x', cast_to_int)
        batch = self._add_time_condition(batch, batch.t_e, 'c_e', cast_to_int)
        if cfg.gt.get("sizing", False):
            # NOTE(review): t_f has never been cast to int, unlike t_x / t_e —
            # kept as-is; confirm whether this asymmetry is intended.
            batch = self._add_time_condition(batch, batch.t_f, 'c_f', cast_to_int=False)

        return batch

    def _add_time_condition(self, batch, t, dest_feat, cast_to_int):
        """Embed times ``t`` and mix them with ``batch.c`` into ``batch.<dest_feat>``."""
        # Map t values from [0, 1] to [0, t_prec] (optionally truncated to int)
        t = t * self.t_prec
        if cast_to_int:
            t = t.int()
        # Sine t positional encoding
        time_embed = embed_1D_scalar(t, self.hidden_dim, self.t_prec).to(t.device)
        return self.time_mixer(batch, time_embed.unsqueeze(1), dest_feat=dest_feat)
    

@register_layer('time_encoder')
class TimeEncoder(nn.Module):
    """Time-only variant of GraphFeaturesEncoder: conditions solely on denoising times.

    Sinusoidal embeddings of ``batch.t_x``, ``batch.t_e`` (and ``batch.t_f`` when
    ``cfg.gt.sizing`` is set) are each passed through one shared ``GymMlpMixer``
    and stored in ``batch.c_x``, ``batch.c_e`` (and ``batch.c_f``).
    """

    def __init__(self, out_dim, hidden_dim=256, t_prec=1000):
        """
        Simplified version of GraphFeaturesEncoder, in case one is only interested
        in conditioning on the denoising time index.

        Args:
            out_dim (int): Dimension of the output embeddings.
            hidden_dim (int): Dimension of intermediate embeddings; also the width
                requested from ``embed_1D_scalar``.
            t_prec (int): Factor applied to times before embedding; also passed to
                ``embed_1D_scalar``.
        """
        super().__init__()

        self.time_mixer = GymMlpMixer(hidden_dim, hidden_dim, 1, 1, norm=True, out_fc_dim=out_dim)
        self.hidden_dim = hidden_dim
        self.t_prec = t_prec

    def forward(self, batch):
        """
        Args:
            batch (torch_geometric.data.Data): Data object containing ``t_x``,
                ``t_e`` and, with ``cfg.gt.sizing``, ``t_f``.

        Returns:
            torch_geometric.data.Data: The same object with ``batch.c_x``,
            ``batch.c_e`` (and ``batch.c_f``) set. Side effect: ``batch.c_init``
            is overwritten and ends up holding the last time embedding used.
        """
        # Backward-compatibility hack (to be removed): configs predating
        # cfg.train.distortion_pow_e embedded integer time indices.
        cast_to_int = not hasattr(cfg.train, 'distortion_pow_e')
        batch = self._add_time_condition(batch, batch.t_x, 'c_x', cast_to_int)
        batch = self._add_time_condition(batch, batch.t_e, 'c_e', cast_to_int)
        if cfg.gt.get("sizing", False):
            # NOTE(review): t_f has never been cast to int, unlike t_x / t_e —
            # kept as-is; confirm whether this asymmetry is intended.
            batch = self._add_time_condition(batch, batch.t_f, 'c_f', cast_to_int=False)
        return batch

    def _add_time_condition(self, batch, t, dest_feat, cast_to_int):
        """Embed times ``t`` through ``time_mixer`` into ``batch.<dest_feat>``."""
        # Map t values from [0, 1] to [0, t_prec] (optionally truncated to int)
        t = t * self.t_prec
        if cast_to_int:
            t = t.int()
        # Sine t positional encoding
        time_embed = embed_1D_scalar(t, self.hidden_dim, self.t_prec).to(t.device)
        # Without an extra modality, GymMlpMixer reads its input from batch.c_init.
        batch.c_init = time_embed.unsqueeze(1)
        return self.time_mixer(batch, dest_feat=dest_feat)
    

class GymMlpMixer(nn.Module):
    """Small MLP-Mixer-style block merging a stack of conditioning vectors.

    An input of shape ``(bs, in_dim, in_pts)`` is mixed along its last axis by
    ``depth_wise`` (``in_pts -> out_pts``), then along axis 1 by ``point_wise``
    (``in_dim -> out_dim``), optionally layer-normalized over the last axis,
    passed through SiLU and optionally projected by ``out_fc``
    (``out_pts -> out_fc_dim``).
    """

    # Attribute names of ``batch`` that ``forward`` may write its output to.
    _DEST_FEATS = ('c', 'c_x', 'c_e', 'c_f')

    def __init__(self, in_pts, out_pts, in_dim, out_dim, norm=False, out_fc_dim=None):
        """
        Args:
            in_pts (int): Size of the last input axis (e.g. number of RBF centroids).
            out_pts (int): Size of the last axis after ``depth_wise``.
            in_dim (int): Number of stacked conditioning vectors (input axis 1).
            out_dim (int): Number of vectors kept by ``point_wise``.
            norm (bool): If True, apply LayerNorm over the last axis before SiLU.
            out_fc_dim (int, optional): If given, add a final linear projection to
                this size and squeeze axis 1 (only removed when ``out_dim == 1``).
        """
        super().__init__()

        self.depth_wise = nn.Linear(in_pts, out_pts)
        self.point_wise = nn.Linear(in_dim, out_dim)
        if norm:
            self.norm = nn.LayerNorm(out_pts)
        self.act = nn.SiLU()
        if out_fc_dim is not None:
            self.out_fc = nn.Linear(out_pts, out_fc_dim)

    def forward(self, batch, extra_modality=None, dest_feat='c'):
        """Mix conditioning vectors and store the result on ``batch``.

        Args:
            batch: Graph batch. Without ``extra_modality`` the input is read from
                ``batch.c_init``; otherwise ``batch.c`` is used (and must exist).
            extra_modality (torch.Tensor, optional): Extra stack of vectors
                concatenated with ``batch.c`` along axis 1. ``batch.c`` is first
                broadcast to the rows of ``extra_modality`` according to
                ``dest_feat``.
            dest_feat (str): One of 'c', 'c_x', 'c_e', 'c_f': the attribute of
                ``batch`` that receives the output.

        Returns:
            The same ``batch`` with ``batch.<dest_feat>`` set.

        Raises:
            ValueError: If ``dest_feat`` is not one of the supported names.
        """
        if dest_feat not in self._DEST_FEATS:
            raise ValueError(
                f"Unsupported dest_feat {dest_feat!r}; expected one of {self._DEST_FEATS}"
            )

        if extra_modality is None:
            # Expected size: bs, n_feats (3 for OCB), dim (typically n_rbf_centroids)
            c_embed = batch.c_init
        else:
            # Concat with the extra input along the feature axis (dim=1).
            c_embed = self._broadcast_condition(batch, dest_feat)
            c_embed = torch.cat([c_embed, extra_modality], dim=1)

        c_embed = self.depth_wise(c_embed)
        c_embed = self.point_wise(c_embed.transpose(1, 2)).transpose(1, 2)

        # Depth-wise LN, i.e. separately for each feature
        if hasattr(self, 'norm'):
            c_embed = self.norm(c_embed)
        c_embed = self.act(c_embed)

        if hasattr(self, 'out_fc'):
            # squeeze(1), not squeeze(): a batch of size 1 must keep its leading
            # axis. Output size: bs, out_fc_dim (when out_dim == 1).
            c_embed = self.out_fc(c_embed).squeeze(1)

        setattr(batch, dest_feat, c_embed)
        return batch

    @staticmethod
    def _broadcast_condition(batch, dest_feat):
        """Expand per-sample ``batch.c`` to the rows targeted by ``dest_feat``."""
        c_embed = batch.c
        if dest_feat in ('c_x', 'c_f'):
            # One row per node, gathered through the node -> sample index.
            return c_embed[batch.batch]
        if dest_feat == 'c_e':
            nodes_per_graph = nodes_per_batch_sample(batch)
            # Assumes n * (n - 1) edge rows per sample (ordered node pairs without
            # self-loops) — TODO confirm against the edge construction.
            edges_per_graph = nodes_per_graph * (nodes_per_graph - 1)
            return c_embed.repeat_interleave(edges_per_graph, dim=0)
        return c_embed
