import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .layers import ESLayer, ETLayer, merge_time_dim, separate_time_dim
from torch_geometric.nn import global_mean_pool, global_add_pool



def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    The first ``embedding_dim // 2`` columns hold ``sin`` features and the
    next ``embedding_dim // 2`` columns hold ``cos`` features, using
    geometrically spaced frequencies from ``1`` down to ``1 / max_positions``.
    When ``embedding_dim`` is odd, one trailing zero column is appended.

    :param timesteps: 1-D tensor of timesteps (indexed as ``timesteps[:, None]``);
        values are cast to float.
    :param embedding_dim: width of the returned embedding.
    :param max_positions: largest wavelength scale used for the frequencies.
    :return: float tensor of shape ``[len(timesteps), embedding_dim]`` on the
        same device as ``timesteps``.
    """
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # Clamp the denominator: with half_dim == 1 (embedding_dim of 2 or 3) the
    # plain `half_dim - 1` would be zero and raise ZeroDivisionError. For
    # half_dim >= 2 the value is unchanged, and for half_dim == 0 the frequency
    # vector is empty so the denominator is irrelevant.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb


class Cencoder(nn.Module):
    """Encoder for conditioning frames built from alternating ES/ET layers.

    The conditioning trajectory is split along its last (time) axis into the
    first ``addition_T`` frames, which are encoded by the layer stack, and the
    remaining frames, which are returned without passing through the stack.
    """

    def __init__(self, n_layers, addition_T, em_T_out, node_dim, edge_dim, hidden_dim,time_emb_dim , act_fn,
                  scale=1, pre_norm=True):
        """
        :param n_layers: number of (ESLayer, ETLayer) pairs to stack.
        :param addition_T: number of leading frames of ``x_cond`` that are encoded.
        :param em_T_out: stored on the module; not used by the code in this class.
        :param node_dim: input node feature width fed to ``input_linear``.
        :param edge_dim: edge feature width forwarded to each ``ESLayer``.
        :param hidden_dim: hidden width used by ``input_linear`` and every layer.
        :param time_emb_dim: forwarded to each ``ETLayer``.
        :param act_fn: activation name; only ``'silu'`` is supported.
        :param scale: divisor applied to the encoded coordinates on output.
        :param pre_norm: forwarded to each ``ESLayer``.
        :raises NotImplementedError: if ``act_fn`` is not ``'silu'``.
        """
        super().__init__()
        self.s_modules = nn.ModuleList()
        self.t_modules = nn.ModuleList()
        self.n_layers = n_layers
        self.em_T_out = em_T_out
        self.addition_T = addition_T
        self.time_emb_dim = time_emb_dim
        self.input_linear = nn.Linear(node_dim, hidden_dim)
        self.scale = scale

        # Parse activation
        if act_fn != 'silu':
            raise NotImplementedError(act_fn)
        activation = nn.SiLU()

        for _ in range(n_layers):
            self.s_modules.append(
                ESLayer(node_dim=hidden_dim, edge_dim=edge_dim, hidden_dim=hidden_dim, act_fn=activation,
                        normalize=True, pre_norm=pre_norm)
            )
            self.t_modules.append(
                ETLayer(node_dim=hidden_dim, hidden_dim=hidden_dim, act_fn=activation, time_emb_dim=time_emb_dim)
            )

    def forward(self, x_cond, h, model_kwargs):
        """Encode the first ``addition_T`` conditioning frames.

        :param x_cond: shape [BN, 3, T_c]
        :param h: shape [BN, Hh] or [BN, Hh, T_c]
        :param model_kwargs: dict with the keys
            ``'edge_index'`` (shape [2, BM]),
            ``'edge_attr'`` (shape [BM, He], or ``None``) and
            ``'batch'`` (shape [BN]).
        :return: tuple ``(x_new_cond_em, h_new_cond_em, x_origin_cond, h_origin_cond)``:
            the encoded coordinates (divided by ``scale``) and features for the
            first ``addition_T`` frames, followed by the untouched remaining
            coordinates and the linearly projected features for the remaining
            frames.
        :raises ValueError: if ``h`` is neither 2-D nor 3-D.
        """
        edge_index = model_kwargs['edge_index']
        edge_attr = model_kwargs['edge_attr']
        batch = model_kwargs['batch']

        # Leading `addition_T` frames get encoded; the rest are passed through.
        T_new = self.addition_T
        T_origin = x_cond.size(-1) - self.addition_T
        x_new_cond, x_origin_cond = x_cond.split([T_new, T_origin], dim=-1)  # [BN, 3, T_new], [BN, 3, T_origin]

        if h.dim() == 2:
            # Time-independent features: broadcast one copy per frame.
            h_new_cond = h.unsqueeze(-1).repeat(1, 1, T_new)
            h_origin_cond = h.unsqueeze(-1).repeat(1, 1, T_origin)  # [BN, Hh, T]
        elif h.dim() == 3:
            # Per-frame features: split along time exactly like the coordinates.
            # Previously this case left both tensors undefined and crashed.
            h_new_cond, h_origin_cond = h.split([T_new, T_origin], dim=-1)
            # split() returns strided views; make them contiguous in case the
            # time-merging helper relies on view-style reshapes.
            h_new_cond = h_new_cond.contiguous()
            h_origin_cond = h_origin_cond.contiguous()
        else:
            raise ValueError(
                f'h must have shape [BN, Hh] or [BN, Hh, T_c], got {h.dim()} dimension(s)'
            )

        # Project node features to the hidden width, frame by frame.
        h_new_cond = separate_time_dim(self.input_linear(merge_time_dim(h_new_cond)), t=T_new)  # [BN, H, T_new]
        h_origin_cond = separate_time_dim(self.input_linear(merge_time_dim(h_origin_cond)), t=T_origin)  # [BN, H, T_origin]

        if edge_attr is not None:
            # Only the encoded frames go through the layers, so T_new copies suffice.
            edge_attr = edge_attr.unsqueeze(-1).repeat(1, 1, T_new)  # [BM, He, T_new]

        for s_layer, t_layer in zip(self.s_modules, self.t_modules):
            x_new_cond, h_new_cond = s_layer(x_new_cond, h_new_cond, edge_index, edge_attr, batch)
            x_new_cond, h_new_cond = t_layer(x_new_cond, h_new_cond)  # [BN, 3, T_new], [BN, H, T_new]

        # NOTE(review): the output is divided by `scale`, but the input is never
        # multiplied by it (that step was disabled). Harmless for scale == 1;
        # confirm the intent before using any other scale.
        x_new_cond_em = x_new_cond / self.scale
        h_new_cond_em = h_new_cond

        return x_new_cond_em, h_new_cond_em, x_origin_cond, h_origin_cond


if __name__ == '__main__':
    # Smoke test: push a tiny two-graph batch through Cencoder and print the
    # shapes of the encoded outputs.
    # NOTE: this module uses a relative import (`from .layers import ...`), so
    # it has to be launched as a module of its package (`python -m ...`).

    BN = 5  # total number of nodes across the batch

    Hh = 16  # input node feature width
    He = 2  # edge feature width
    Hid = 32  # hidden width
    T = 5  # number of conditioning frames that get encoded

    model = Cencoder(n_layers=3, addition_T=T, em_T_out=15, node_dim=Hh, edge_dim=He, hidden_dim=Hid, time_emb_dim=64, act_fn='silu',
                  scale=1, pre_norm=True)

    # Nodes 0-2 belong to graph 0 and nodes 3-4 to graph 1.
    batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
    row = torch.tensor([0, 0, 1, 3], dtype=torch.long)
    col = torch.tensor([1, 2, 2, 4], dtype=torch.long)
    h = torch.rand(BN, Hh)
    x_new_cond = torch.rand(BN, 3, 5)
    edge_index = torch.stack((row, col), dim=0)  # [2, BM]
    BM = edge_index.size(-1)
    edge_attr = torch.rand(BM, He)
    model_kwargs = {
        'edge_index': edge_index,
        'edge_attr': edge_attr,
        'batch': batch,
    }

    # forward() returns four tensors; unpacking only two raised
    # "too many values to unpack".
    x_out, h_out, x_origin_out, h_origin_out = model(x_new_cond, h, model_kwargs)
    print(x_out.shape,h_out.shape)
    print('Test successful')