import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .layers import ESLayer, ETLayer, merge_time_dim, separate_time_dim
from torch_geometric.nn import global_mean_pool, global_add_pool



def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings for a batch of timesteps.

    The first ``embedding_dim // 2`` columns hold ``sin(t * f_k)`` and the
    next ``embedding_dim // 2`` columns hold ``cos(t * f_k)``, where the
    frequencies ``f_k`` decay geometrically from ``1`` to ``1 / max_positions``.
    When ``embedding_dim`` is odd, one trailing zero column is appended.

    :param timesteps: tensor indexed as ``timesteps[:, None]``, i.e. expected
        to be 1-D with shape [N]; it is cast to float before use.
    :param embedding_dim: width of the returned embedding.
    :param max_positions: ratio between the largest and smallest frequency.
    :return: float tensor of shape [N, embedding_dim] on ``timesteps.device``.
    """
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # Clamp the denominator: with embedding_dim in (2, 3) half_dim is 1 and an
    # unguarded ``half_dim - 1`` would raise ZeroDivisionError. For every other
    # size the result is identical to the unclamped formula.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    # Outer product of timesteps and frequencies -> [N, half_dim]
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb


class EAencoder(nn.Module):
    """Split a conditioning tensor into its "new" and "origin" time segments.

    The module holds no learnable parameters in the visible code; it only
    records the segment lengths used by :meth:`forward`.
    """

    def __init__(self, n_layers, new_cond_T,  origin_cond_T):
        """
        :param n_layers: stored on the instance; not used by ``forward``.
        :param new_cond_T: length of the leading segment along the last dim.
        :param origin_cond_T: length of the trailing segment along the last dim.
        """
        super().__init__()

        self.n_layers = n_layers
        self.new_cond_T = new_cond_T
        self.origin_cond_T = origin_cond_T

    def forward(self, x_cond, model_kwargs):
        """
        :param x_cond: tensor whose last dimension has size
            ``new_cond_T + origin_cond_T`` (annotated as [BN, 3, T_c])
        :param model_kwargs: accepted for interface compatibility; not read here
        :return: tuple ``(x_new_cond, x_origin_cond)`` with last-dim sizes
            ``new_cond_T`` and ``origin_cond_T`` respectively
        """
        segment_sizes = [self.new_cond_T, self.origin_cond_T]
        # Leading slice is the "new" condition, trailing slice the "origin" one.
        new_part, origin_part = torch.split(x_cond, segment_sizes, dim=-1)
        return new_part, origin_part


if __name__ == '__main__':
    # Smoke test: run EAencoder on a small random input and check the
    # shapes of the two returned segments.
    BN = 5  # number of nodes (first dimension of x_cond)
    He = 2  # edge feature width
    NEW_T = 5
    ORIGIN_T = 10

    # EAencoder.__init__ accepts only (n_layers, new_cond_T, origin_cond_T);
    # passing any other keyword argument raises TypeError.
    model = EAencoder(n_layers=0, new_cond_T=NEW_T, origin_cond_T=ORIGIN_T)

    batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
    row = torch.tensor([0, 0, 1, 3], dtype=torch.long)
    col = torch.tensor([1, 2, 2, 4], dtype=torch.long)
    x_cond = torch.rand(BN, 3, NEW_T + ORIGIN_T)
    edge_index = torch.stack((row, col), dim=0)  # [2, BM]
    BM = edge_index.size(-1)
    edge_attr = torch.rand(BM, He)
    model_kwargs = {
        'edge_index': edge_index,
        'edge_attr': edge_attr,
        'batch': batch,
    }

    x_new_cond, x_origin_cond = model(x_cond, model_kwargs)
    print(x_new_cond.shape, x_origin_cond.shape)
    assert x_new_cond.shape == (BN, 3, NEW_T)
    assert x_origin_cond.shape == (BN, 3, ORIGIN_T)
    print('Test successful')