import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch_geometric.nn import global_mean_pool, global_add_pool
# import sys
# sys.path.append("models")
from .layers import ESLayer, ETLayer, merge_time_dim, separate_time_dim 
from models.EAencoder import EAencoder
import copy
import os



def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Return sinusoidal embeddings of ``timesteps``.

    :param timesteps: 1-D tensor of time steps (int or float), shape [N]
    :param embedding_dim: embedding size; for odd sizes the last column is zero padding
    :param max_positions: sets the lowest frequency (1 / max_positions); 10000 as in Transformers
    :return: float32 tensor of shape [N, embedding_dim] laid out as [sin terms | cos terms]
    """
    half_dim = embedding_dim // 2
    # Frequencies decay geometrically from 1 to 1/max_positions. The max() guard
    # avoids a ZeroDivisionError when half_dim == 1 (embedding_dim 2 or 3); for
    # every other size the result is unchanged.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]  # [N, half_dim]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb


class EA(nn.Module):
    """Trajectory denoiser built from (ESLayer, ETLayer) pairs plus zero-gated trainable copies.

    The trunk is ``n_layers`` pairs of ``ESLayer`` (a graph layer called with
    ``edge_index``/``edge_attr``/``batch``) and ``ETLayer`` (called on node
    tensors only). It runs over the frames ``[origin_cond | target]``.

    For every index in ``n_copy_layer_list``, an extra (ESLayer, ETLayer) pair
    is built and initialised with the trunk pair's weights. It runs over
    ``[new_cond | origin_cond | target]``. Its output is multiplied by
    zero-initialised per-frame (``zero_x_params``) and per-channel
    (``zero_h_params``) gains and added to the trunk output. At initialisation
    the copies therefore leave the trunk result unchanged.
    """

    def __init__(self, n_layers, n_copy_layer_list, T_all, node_dim, edge_dim, hidden_dim, time_emb_dim, act_fn,
                 learn_ref_frame, n_layers_ref, num_w, scale=1, pre_norm=True):
        """
        :param n_layers: number of trunk (ESLayer, ETLayer) pairs
        :param n_copy_layer_list: indices of trunk layers that get a trainable copy
        :param T_all: together with ``num_w`` sizes ``zero_x_params``; ``num_w + T_all``
            must equal ``T_new_cond + T_origin_cond + T_target`` seen in ``forward``
            (in the ``__main__`` demo T_all is the number of condition frames)
        :param node_dim: channel size of the input node features ``h``
        :param edge_dim: channel size of ``edge_attr``
        :param hidden_dim: hidden channel size of all layers
        :param time_emb_dim: size of the diffusion-time embedding
        :param act_fn: activation name; only ``'silu'`` is supported
        :param learn_ref_frame: stored only; not used by this class
        :param n_layers_ref: stored only; not used by this class
        :param num_w: see ``T_all`` (in the ``__main__`` demo, the number of target frames)
        :param scale: ``x_target`` is multiplied by this on input and the output divided by it
        :param pre_norm: forwarded to every ESLayer
        :raises NotImplementedError: if ``act_fn`` is not ``'silu'``
        """
        super().__init__()
        # NOTE: attribute names and registration order define state_dict keys
        # and parameter order; keep them stable for checkpoint compatibility.
        self.s_modules = nn.ModuleList()
        self.t_modules = nn.ModuleList()
        self.copy_s_modules = nn.ModuleDict()
        self.copy_t_modules = nn.ModuleDict()
        self.n_layers = n_layers
        self.n_copy_layer_list = n_copy_layer_list
        self.time_emb_dim = time_emb_dim
        self.input_linear = nn.Linear(node_dim + time_emb_dim, hidden_dim)

        self.zero_x_params = nn.ParameterDict()
        self.zero_h_params = nn.ParameterDict()

        self.learn_ref_frame = learn_ref_frame
        self.n_layers_ref = n_layers_ref
        self.num_w = num_w

        self.scale = scale

        if act_fn == 'silu':
            act_fn = nn.SiLU()
        else:
            raise NotImplementedError(act_fn)

        for i in range(n_layers):
            self.s_modules.append(ESLayer(node_dim=hidden_dim, edge_dim=edge_dim, hidden_dim=hidden_dim,
                                          act_fn=act_fn, normalize=True, pre_norm=pre_norm))
            self.t_modules.append(ETLayer(node_dim=hidden_dim, hidden_dim=hidden_dim, act_fn=act_fn,
                                          time_emb_dim=time_emb_dim))

            if i in self.n_copy_layer_list:
                layer_idx = str(i)  # ModuleDict / ParameterDict keys must be strings
                copy_s_layer = ESLayer(node_dim=hidden_dim, edge_dim=edge_dim, hidden_dim=hidden_dim,
                                       act_fn=act_fn, normalize=True, pre_norm=pre_norm)
                copy_t_layer = ETLayer(node_dim=hidden_dim, hidden_dim=hidden_dim, act_fn=act_fn,
                                       time_emb_dim=time_emb_dim)

                # Zero-initialised gains: the copy branch contributes nothing until trained.
                self.zero_x_params[layer_idx] = nn.Parameter(torch.zeros(num_w + T_all))
                self.zero_h_params[layer_idx] = nn.Parameter(torch.zeros(hidden_dim))

                # Start each copy from the trunk weights (load_state_dict copies the tensors).
                copy_s_layer.load_state_dict(self.s_modules[i].state_dict())
                copy_t_layer.load_state_dict(self.t_modules[i].state_dict())

                self.copy_s_modules[layer_idx] = copy_s_layer
                self.copy_t_modules[layer_idx] = copy_t_layer

    def _project_inputs(self, diffusion_t, h, T_all):
        """Concatenate the sinusoidal diffusion-time embedding to ``h`` and apply ``input_linear``.

        :param diffusion_t: diffusion time step per node, shape [BN,]
        :param h: node features, shape [BN, Hh] or [BN, Hh, T_all]
        :param T_all: total number of frames
        :return: projected features, shape [BN, hidden_dim, T_all]
        """
        t_emb = get_timestep_embedding(diffusion_t, embedding_dim=self.time_emb_dim)  # [BN, Ht]
        t_emb = t_emb.unsqueeze(-1).repeat(1, 1, T_all)  # [BN, Ht, T_all]
        if h.dim() == 2:
            # Static node features: share them across all frames.
            h = h.unsqueeze(-1).repeat(1, 1, T_all)
        h = torch.cat((h, t_emb), dim=1)  # [BN, Hh+Ht, T_all]
        # Module.to moves parameters in place; a no-op when already on h's device.
        self.input_linear = self.input_linear.to(h.device)
        return separate_time_dim(self.input_linear(merge_time_dim(h)), t=T_all)

    def _copy_layer_forward(self, i, x_f, h_f, x_new_cond, h_new_cond, edge_index, edge_attr_f, edge_attr_c,
                            batch):
        """Run trunk layer ``i`` and its copy, then add the zero-gated copy output to the trunk output.

        :return: ``(x_f, h_f, x_new_cond, h_new_cond)`` to feed into the next layer
        """
        T_new_cond = x_new_cond.size(-1)
        T_f = x_f.size(-1)
        layer_idx = str(i)

        x_c = torch.cat((x_new_cond, x_f), dim=-1)
        h_c = torch.cat((h_new_cond, h_f), dim=-1)
        assert h_c.size(-1) == x_c.size(-1) == T_new_cond + T_f

        # Trunk path: only sees [origin_cond | target].
        x_freeze, h_freeze = self.s_modules[i](x_f, h_f, edge_index, edge_attr_f, batch)
        x_freeze, h_freeze = self.t_modules[i](x_freeze, h_freeze)

        # Copy path: additionally sees the new condition frames.
        x_c, h_c = self.copy_s_modules[layer_idx](x_c, h_c, edge_index, edge_attr_c, batch)
        x_c, h_c = self.copy_t_modules[layer_idx](x_c, h_c)

        # Subtract each graph's centre of mass (mean over its nodes and all frames)
        # before gating; x is then gated per frame, h per channel.
        x_com = global_mean_pool(x_c.mean(dim=-1), batch)[batch].unsqueeze(-1)
        z_x = (x_c - x_com) * self.zero_x_params[layer_idx].view(1, 1, -1)
        z_h = h_c * self.zero_h_params[layer_idx].view(1, -1, 1)

        # Only the [origin_cond | target] part of the copy output feeds the trunk.
        x_f = x_freeze + z_x[:, :, -T_f:]
        h_f = h_freeze + z_h[:, :, -T_f:]
        assert h_f.size(-1) == x_f.size(-1) == T_f

        return x_f, h_f, x_c[:, :, :T_new_cond], h_c[:, :, :T_new_cond]

    def forward(self, diffusion_t, x_target, x_origin_cond, x_new_cond, h, model_kwargs):
        """Run the trunk (and copy layers) and return the displacement of the target frames.

        :param diffusion_t: diffusion time step per node, shape [BN,]
        :param x_target: noisy target positions, shape [BN, 3, T_target]
        :param x_origin_cond: original condition frames, shape [BN, 3, T_origin_cond], or None
        :param x_new_cond: new condition frames, shape [BN, 3, T_new_cond]; consumed only by copy layers
        :param h: node features, shape [BN, Hh] (shared by all frames) or
            [BN, Hh, T_new_cond + T_origin_cond + T_target]
        :param model_kwargs: dict with 'edge_index' [2, BM], 'batch' [BN] and optionally
            'edge_attr' [BM, He] (missing or None means no edge features)
        :return: ``(x_out, h_out)``. ``x_out`` is ``(last T_target frames of x - x_target * scale) / scale``
            and has the shape of ``x_target``. ``h_out`` is the last T_target frames of the final hidden features.
        """
        edge_index, batch = model_kwargs['edge_index'], model_kwargs['batch']
        edge_attr = model_kwargs.get('edge_attr')
        # NOTE(review): only x_target is multiplied by self.scale; the condition
        # frames are used as given — confirm callers pre-scale them.
        x_target = x_target * self.scale

        T_target = x_target.size(-1)
        T_origin_cond = 0 if x_origin_cond is None else x_origin_cond.size(-1)
        T_new_cond = x_new_cond.size(-1)
        T_all = T_new_cond + T_origin_cond + T_target
        T_f = T_origin_cond + T_target  # frames processed by the trunk
        T_c = T_new_cond + T_f  # frames processed by the copy layers

        # Frame order along the last axis is [new_cond | origin_cond | target].
        h_all = self._project_inputs(diffusion_t, h, T_all)  # [BN, H, T_all]
        h_new_cond = h_all[:, :, :T_new_cond]
        h_origin_cond = h_all[:, :, T_new_cond:T_new_cond + T_origin_cond]
        h_target = h_all[:, :, -T_target:]

        if x_origin_cond is None:
            x_f, h_f = x_target, h_target
        else:
            x_f = torch.cat((x_origin_cond, x_target), dim=-1)
            h_f = torch.cat((h_origin_cond, h_target), dim=-1)
        assert h_f.size(-1) == x_f.size(-1) == T_f

        # Kept so it can be subtracted at the end (translation invariance).
        x_input = x_target

        if edge_attr is None:
            # NOTE(review): assumes ESLayer accepts edge_attr=None — confirm.
            edge_attr_f = edge_attr_c = None
        else:
            edge_attr_f = edge_attr.unsqueeze(-1).repeat(1, 1, T_f)  # [BM, He, T_f]
            edge_attr_c = edge_attr.unsqueeze(-1).repeat(1, 1, T_c)  # [BM, He, T_c]

        for i in range(self.n_layers):
            if i in self.n_copy_layer_list:
                x_f, h_f, x_new_cond, h_new_cond = self._copy_layer_forward(
                    i, x_f, h_f, x_new_cond, h_new_cond, edge_index, edge_attr_f, edge_attr_c, batch)
            else:
                x_f, h_f = self.s_modules[i](x_f, h_f, edge_index, edge_attr_f, batch)
                x_f, h_f = self.t_modules[i](x_f, h_f)

        # Return a displacement from the (scaled) input positions; this gives
        # translation invariance assuming the layers are translation equivariant.
        x_out = (x_f[:, :, -T_target:] - x_input) / self.scale
        h_out = h_f[:, :, -T_target:]
        assert h_out.size(-1) == x_out.size(-1) == T_target
        return x_out, h_out
    
def _smoke_test():
    """Build EA and EAencoder on a tiny two-graph batch, run one forward pass and check output shapes.

    Uses 15 condition frames and 20 target frames, so the copy layers see
    15 + 20 = T_all + num_w frames. The condition frames are presumably split
    by EAencoder into 5 new + 10 original frames (its new_cond_T / origin_cond_T).
    """
    torch.manual_seed(0)  # reproducible demo

    BN = 5  # number of nodes across the batch
    B = 2  # number of graphs
    Hh = 16  # input node-feature size
    He = 2  # edge-feature size
    H = 32  # hidden size

    ea = EA(n_layers=6, n_copy_layer_list=[0, 1, 2], T_all=15, node_dim=Hh, edge_dim=He, hidden_dim=H,
            time_emb_dim=64, act_fn='silu', learn_ref_frame=True, n_layers_ref=2, num_w=20,
            scale=1, pre_norm=True)

    # Nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
    batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
    edge_index = torch.tensor([[0, 0, 1, 3],
                               [1, 2, 2, 4]], dtype=torch.long)  # [2, BM]

    h = torch.rand(BN, Hh)
    x_cond = torch.rand(BN, 3, 15)
    x_target = torch.rand(BN, 3, 20)

    edge_attr = torch.rand(edge_index.size(-1), He)
    # One diffusion step per graph, broadcast to its nodes (cast to x_target's dtype).
    t = torch.randint(0, 1000, size=(B,)).to(x_target)[batch]
    model_kwargs = {'edge_index': edge_index, 'edge_attr': edge_attr, 'batch': batch}

    encoder = EAencoder(n_layers=0, new_cond_T=5, origin_cond_T=10, node_dim=Hh, edge_dim=He, hidden_dim=H,
                        time_emb_dim=64, act_fn='silu', scale=1, pre_norm=True)

    x_new_cond, x_origin_cond = encoder(x_cond, model_kwargs)
    x_out, h_out = ea(t, x_target, x_origin_cond, x_new_cond, h, model_kwargs)
    print("x_out shape:", x_out.shape)
    print("h_out shape:", h_out.shape)
    assert x_out.shape == x_target.shape
    assert h_out.size(1) == H
    assert h_out.size(-1) == x_target.size(-1)
    print('Test successful')


if __name__ == '__main__':
    _smoke_test()

