
import math

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np 

from . import consts as C
from lib.tgt import Graph

class EmbedInput(nn.Module):
    """Embed a ``Graph`` into node, edge, angle and torsion representations.

    ``forward`` produces:
      * ``h`` - node embeddings (sum over node-feature embeddings) with a
        learned graph token prepended along the node axis,
      * ``e`` - pairwise edge embeddings (hop-distance embedding + summed
        edge-feature embeddings, plus a 3D distance term when enabled) with a
        graph-token row and column prepended,
      * ``p`` / ``t`` - angle / torsion embeddings with a graph token prepended,
    together with additive attention masks (0 where kept,
    ``torch.finfo(dtype).min`` where masked). Everything is stored on the
    returned graph as ``h, e, p, t, mask, mask_a, mask_t``.

    Only ``embed_3d_type='gaussian'`` builds the angle/torsion embedders, so
    ``forward`` (and ``embed_3d_angle`` / ``embed_3d_torsion``) raise
    ``ValueError`` for 'fourier' and 'none'.

    Args:
        node_width: width of node embeddings.
        edge_width: width of edge embeddings.
        angle_width: width of angle embeddings.
        torsion_angle_width: width of torsion embeddings.
        upto_hop: hop distances are clamped to ``upto_hop + 1``.
        embed_3d_type: 'gaussian', 'fourier' or 'none'.
        num_3d_kernels: number of basis kernels in the 3D embedders.

    Raises:
        ValueError: if ``embed_3d_type`` is not one of the values above.
    """

    def __init__(self,
                 node_width,
                 edge_width,
                 angle_width,
                 torsion_angle_width,
                 upto_hop=32,
                 embed_3d_type='gaussian',
                 num_3d_kernels=128,
                 ):
        super().__init__()

        self.node_width = node_width
        self.edge_width = edge_width
        self.angle_width = angle_width
        self.torsion_angle_width = torsion_angle_width
        self.upto_hop = upto_hop
        self.num_3d_kernels = num_3d_kernels
        self.embed_3d_type = embed_3d_type

        # Row 0 is the padding row (padding_idx=0).
        self.nodef_embed = nn.Embedding(C.NUM_NODE_FEATURES * C.NODE_FEATURES_OFFSET + 1,
                                        self.node_width, padding_idx=0)

        # forward clamps hop distances to at most upto_hop + 1, hence upto_hop + 2 rows.
        self.dist_embed = nn.Embedding(self.upto_hop + 2, self.edge_width)
        self.featm_embed = nn.Embedding(C.NUM_EDGE_FEATURES * C.EDGE_FEATURES_OFFSET + 1,
                                        self.edge_width, padding_idx=0)

        # Module creation order is significant: it fixes parameter-init RNG
        # consumption and state_dict key order.
        if self.embed_3d_type == 'gaussian':
            self.m3d_embed = Gaussian3DEmbed(self.edge_width,
                                             2 * C.NODE_FEATURES_OFFSET + 1,
                                             self.num_3d_kernels)
            # Per-position shifts so each slot of a pair/angle/torsion tuple
            # indexes a different slice of the type table.
            self._node_j_offset = C.NODE_FEATURES_OFFSET
            self._edge_jk_offset = C.EDGE_FEATURES_OFFSET
            self.angle_embed_3d = Gaussian3DAngleEmbed(
                self.angle_width,
                3 * C.NODE_FEATURES_OFFSET + 2 * C.EDGE_FEATURES_OFFSET + 1,
                self.num_3d_kernels)
            self.torsion_embed_3d = Gaussian3DAngleEmbed(
                self.torsion_angle_width,
                4 * C.NODE_FEATURES_OFFSET + 3 * C.EDGE_FEATURES_OFFSET + 1,
                self.num_3d_kernels)
        elif self.embed_3d_type == 'fourier':
            self.m3d_embed = Fourier3DEmbed(self.edge_width,
                                            self.num_3d_kernels)
        elif self.embed_3d_type != 'none':
            raise ValueError('Invalid 3D embedding type')

        self.graph_token_node = nn.Embedding(1, self.node_width)
        self.graph_token_edge = nn.Embedding(1, self.edge_width)
        self.graph_token_angle = nn.Embedding(1, self.angle_width)
        self.graph_token_torsion = nn.Embedding(1, self.torsion_angle_width)

        self._uses_3d = (self.embed_3d_type != 'none')

    def embed_3d_dist(self, dist_input, nodef):
        """Embed pairwise 3D distances.

        For 'gaussian', pair (i, j) is typed by the first node feature of i and
        of j; j's id is shifted by ``_node_j_offset`` so the two endpoints use
        different rows of the type table. For 'fourier', types are ignored.

        Raises:
            ValueError: for any other ``embed_3d_type``.
        """
        if self.embed_3d_type == 'gaussian':
            num_nodes = nodef.size(1)
            nodes_i = nodef[:, :, 0]                                  # (b, i)
            nodes_j = nodes_i + self._node_j_offset
            nodes_i = nodes_i.unsqueeze(2).expand(-1, -1, num_nodes)  # (b, i, j)
            nodes_j = nodes_j.unsqueeze(1).expand(-1, num_nodes, -1)  # (b, i, j)
            nodes_ij = torch.stack([nodes_i, nodes_j], dim=-1)        # (b, i, j, 2)
            return self.m3d_embed(dist_input, nodes_ij)
        elif self.embed_3d_type == 'fourier':
            return self.m3d_embed(dist_input)
        else:
            raise ValueError('Invalid 3D embedding type')

    def _require_angle_embedders(self):
        """Raise a clear error when the angle/torsion embedders were not built."""
        if self.embed_3d_type != 'gaussian':
            raise ValueError(
                "Angle/torsion embeddings require embed_3d_type='gaussian', "
                f"got {self.embed_3d_type!r}")

    def _tuple_type_ids(self, g, indices, num_atoms):
        """Build the interleaved type-id sequence for atom tuples.

        For every tuple ``(a_0, ..., a_{m-1})`` read from ``indices[:, :, :m]``
        (``m = num_atoms``) this returns
        ``[node(a_0), edge(a_0, a_1), node(a_1), ..., node(a_{m-1})]``.
        ``node(a_r)`` is the first node feature of ``a_r`` plus
        ``r * _node_j_offset``. ``edge(a_r, a_{r+1})`` is the first edge feature
        of that pair plus ``r * _edge_jk_offset``.

        NOTE(review): node ids and edge ids are looked up in the same table and
        both start at offset 0. A node id and an edge id with equal value can
        therefore share a row. The table sizes chosen in ``__init__`` leave room
        for disjoint ranges, which may have been the intent. Confirm before
        changing this, because it would alter the meaning of trained weights.

        Args:
            g: graph exposing ``node_features`` (b, n, f) and
                ``feature_matrix`` (b, n, n, f).
            indices: integer tensor (b, T, >= num_atoms) of atom indices.
            num_atoms: tuple length (3 for angles, 4 for torsions).

        Returns:
            Long tensor of shape (b, T, 2 * num_atoms - 1).
        """
        nodef = g.node_features.long()     # (b, n, f)
        edgef = g.feature_matrix.long()    # (b, n, n, f)
        # Batch selector that broadcasts against the (b, T) atom indices. It is
        # created on the indices' device so all advanced indices share a device.
        batch_idx = torch.arange(nodef.size(0), device=indices.device).unsqueeze(1)  # (b, 1)
        atoms = [indices[:, :, r] for r in range(num_atoms)]                         # m x (b, T)

        type_ids = []
        for r, atom in enumerate(atoms):
            if r > 0:
                prev = atoms[r - 1]
                type_ids.append(edgef[batch_idx, prev, atom, 0] + self._edge_jk_offset * (r - 1))
            type_ids.append(nodef[batch_idx, atom, 0] + self._node_j_offset * r)
        return torch.stack(type_ids, dim=-1)

    def embed_3d_angle(self, angle_input, g):
        """Embed angles i-j-k whose atoms are given by ``g.angle_indices[..., :3]``.

        Raises:
            ValueError: unless ``embed_3d_type == 'gaussian'``.
        """
        self._require_angle_embedders()
        type_ids = self._tuple_type_ids(g, g.angle_indices, num_atoms=3)    # (b, A, 5)
        return self.angle_embed_3d(angle_input, type_ids)

    def embed_3d_torsion(self, torsion_input, g):
        """Embed torsions i-j-k-l whose atoms are given by ``g.torsion_indices[..., :4]``.

        Raises:
            ValueError: unless ``embed_3d_type == 'gaussian'``.
        """
        self._require_angle_embedders()
        type_ids = self._tuple_type_ids(g, g.torsion_indices, num_atoms=4)  # (b, T, 7)
        return self.torsion_embed_3d(torsion_input, type_ids)

    @staticmethod
    def _prepend_graph_token(token_embed, x, raw_mask, batch_size, device):
        """Prepend a graph token to ``x`` and an always-kept slot to its mask.

        Args:
            token_embed: ``nn.Embedding(1, w)`` holding the learned token.
            x: features of shape (b, n, w).
            raw_mask: mask of shape (b, n). An entry of 1 is kept and 0 is masked.
            batch_size: number of graphs b.
            device: device the returned mask is moved to.

        Returns:
            ``(x, additive_mask)`` with shapes (b, n+1, w) and (b, n+1, 1). The
            additive mask is 0 for kept entries and ``finfo(x.dtype).min`` otherwise.
        """
        token = token_embed.weight.repeat(batch_size, 1, 1)                   # (b, 1, w)
        x = torch.cat([token, x], dim=1)                                      # (b, n+1, w)
        dtype = x.dtype
        mask = raw_mask.unsqueeze(-1).to(dtype)                               # (b, n, 1)
        always_kept = torch.ones((mask.size(0), 1, 1), dtype=mask.dtype, device=mask.device)
        mask = torch.cat([always_kept, mask], dim=1).to(device)               # (b, n+1, 1)
        return x, (1 - mask) * torch.finfo(dtype).min

    def forward(self, inputs):
        """Build all embeddings and masks and attach them to a new ``Graph``.

        Reads these attributes from ``Graph(inputs)``: ``node_features``,
        ``distance_matrix``, ``feature_matrix``, ``dist_input``,
        ``angle_input``, ``torsion_input``, ``angle_indices``,
        ``torsion_indices``, ``edge_mask``, ``angle_mask`` and ``torsion_mask``.

        Raises:
            ValueError: unless ``embed_3d_type == 'gaussian'``, because the
                angle/torsion embeddings below are always required.
        """
        g = Graph(inputs)

        nodef = g.node_features.long()                       # (b, i, f)
        h = self.nodef_embed(nodef).sum(dim=2)               # (b, i, f, w) -> (b, i, w)
        graph_token_node = self.graph_token_node.weight.repeat(h.size(0), 1, 1)
        h = torch.cat([graph_token_node, h], dim=1)          # (b, i+1, w)

        dm = g.distance_matrix.long().clamp(max=self.upto_hop + 1)     # (b, i, j)
        featm = g.feature_matrix.long()                                # (b, i, j, f)
        e = self.dist_embed(dm) + self.featm_embed(featm).sum(dim=-2)  # (b, i, j, w)

        if self._uses_3d:
            e = e + self.embed_3d_dist(g.dist_input, nodef)
        # p and t are always needed below. These calls raise a clear ValueError
        # when the angle/torsion embedders were not built.
        p = self.embed_3d_angle(g.angle_input.to(nodef.device), g)
        t = self.embed_3d_torsion(g.torsion_input.to(nodef.device), g)

        # Graph token as an extra row, then as an extra column, of the edge grid.
        batch_size = e.size(0)
        edge_token = self.graph_token_edge.weight.view(1, 1, 1, self.edge_width)
        e = torch.cat([edge_token.repeat(batch_size, 1, e.size(2), 1), e], dim=1)  # (b, i+1, j, w)
        e = torch.cat([edge_token.repeat(batch_size, e.size(1), 1, 1), e], dim=2)  # (b, i+1, j+1, w)

        mask_dtype = e.dtype
        edge_mask = g.edge_mask.unsqueeze(-1).to(mask_dtype)                       # (b, i, j, 1)
        ones_row = torch.ones(edge_mask.size(0), 1, edge_mask.size(2), 1,
                              dtype=edge_mask.dtype, device=edge_mask.device)
        edge_mask = torch.cat([ones_row, edge_mask], dim=1)                         # (b, i+1, j, 1)
        ones_col = torch.ones(edge_mask.size(0), edge_mask.size(1), 1, 1,
                              dtype=edge_mask.dtype, device=edge_mask.device)
        edge_mask = torch.cat([ones_col, edge_mask], dim=2)                         # (b, i+1, j+1, 1)
        mask = (1 - edge_mask) * torch.finfo(mask_dtype).min

        p, mask_a = self._prepend_graph_token(self.graph_token_angle, p, g.angle_mask,
                                              h.size(0), nodef.device)
        t, mask_t = self._prepend_graph_token(self.graph_token_torsion, t, g.torsion_mask,
                                              h.size(0), nodef.device)

        g.h, g.e, g.p, g.t = h, e, p, t
        g.mask, g.mask_a, g.mask_t = mask, mask_a, mask_t
        return g

class Fourier3DEmbed(nn.Module):
    """Project distances onto log-spaced sinusoidal features.

    ``num_kernel // 2`` wavelengths are spaced geometrically between
    ``2 * min_dist`` and ``2 * max_dist``. Each distance yields one sine and one
    cosine per wavelength, and the ``num_kernel`` features are linearly
    projected to ``num_heads`` outputs.

    Raises:
        ValueError: if ``num_kernel`` is odd, since sin/cos pairs need an even count.
    """

    def __init__(self, num_heads, num_kernel,
                 min_dist=0.01, max_dist=20):
        # Validate explicitly. An ``assert`` is stripped under ``python -O``,
        # which would defer the failure to a shape mismatch inside ``proj``.
        if num_kernel % 2 != 0:
            raise ValueError(f'num_kernel must be even, got {num_kernel}')

        super().__init__()
        self.num_heads = num_heads
        self.num_kernel = num_kernel
        self.min_dist = min_dist
        self.max_dist = max_dist

        wave_lengths = torch.exp(torch.linspace(math.log(2*min_dist),
                                                math.log(2*max_dist),
                                                num_kernel // 2))
        angular_freqs = 2 * math.pi / wave_lengths
        # A buffer, not a parameter: it moves with .to(device) but is not trained.
        self.register_buffer('angular_freqs', angular_freqs)

        self.proj = nn.Linear(num_kernel, num_heads)

    def forward(self, dist):
        """Map ``dist`` of shape (*S) to features of shape (*S, num_heads)."""
        phase = dist.unsqueeze(-1) * self.angular_freqs                        # (*S, K/2)
        sinusoids = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)   # (*S, K)
        out = self.proj(sinusoids)
        return out


class Gaussian3DEmbed(nn.Module):
    """Embed pairwise distances with type-conditioned Gaussian basis functions.

    ``GaussianLayer`` expands each distance into ``num_kernel`` Gaussian
    responses, with a scale and shift that depend on the pair's type ids.
    ``NonLinear`` then maps these responses to ``num_heads`` outputs.
    """

    def __init__(self, num_heads, num_edges, num_kernel):
        super().__init__()
        self.num_heads, self.num_edges, self.num_kernel = num_heads, num_edges, num_kernel
        # gbf is built before gbf_proj so parameter init order and
        # state_dict keys stay the same.
        self.gbf = GaussianLayer(num_kernel, num_edges)
        self.gbf_proj = NonLinear(num_kernel, num_heads)

    def forward(self, dist, node_type_edge):
        """Return ``gbf_proj(gbf(dist, types))``, with the type ids cast to long."""
        return self.gbf_proj(self.gbf(dist, node_type_edge.long()))


class Gaussian3DAngleEmbed(nn.Module):
    """Embed angles with type-conditioned circular Gaussian basis functions.

    ``GaussianLayer4angle`` expands each angle into ``num_kernel`` responses,
    with a scale and shift that depend on the tuple's type ids. ``NonLinear``
    then maps these responses to ``num_heads`` outputs.
    """

    def __init__(self, num_heads, num_edges, num_kernel):
        super().__init__()
        self.num_heads, self.num_edges, self.num_kernel = num_heads, num_edges, num_kernel
        # gbf is built before gbf_proj so parameter init order and
        # state_dict keys stay the same.
        self.gbf = GaussianLayer4angle(num_kernel, num_edges)
        self.gbf_proj = NonLinear(num_kernel, num_heads)

    def forward(self, angle, node_type_angle):
        """Return ``gbf_proj(gbf(angle, types))``, with the type ids cast to long."""
        return self.gbf_proj(self.gbf(angle, node_type_angle.long()))



@torch.jit.script
def gaussian(x, mean, std):
    """Evaluate the normal density with parameters (mean, std) at ``x``, with broadcasting.

    The truncated constant pi = 3.14159 is used in the normalisation.
    """
    pi = 3.14159
    z = (x - mean) / std
    density = torch.exp(-0.5 * (z ** 2))
    return density / (((2 * pi) ** 0.5) * std)

@torch.jit.script
def circular_gaussian(x, mean, std):
    """Gaussian of the angular difference ``x - mean``, wrapped to [-pi, pi].

    The difference is wrapped with ``atan2(sin, cos)``. It is then evaluated
    under a normal density with scale ``std``. The truncated constant
    pi = 3.14159 is used in the normalisation.
    """
    pi = 3.14159
    delta = x - mean
    wrapped = torch.atan2(torch.sin(delta), torch.cos(delta))
    scale = std * (2 * pi) ** 0.5
    return torch.exp(-0.5 * ((wrapped / std) ** 2)) / scale


class GaussianLayer(nn.Module):
    """Type-conditioned Gaussian basis expansion of scalar inputs.

    Each scalar ``x`` is first mapped to ``mul * x + bias``. Here ``mul`` and
    ``bias`` are per-type scalars summed over the type ids in the last
    dimension of ``edge_types``. The result is then evaluated under ``K``
    learned Gaussians. Means are initialised in [0, 3] and stds in [0, 3];
    ``forward`` uses ``abs(std) + 1e-2``.

    Args:
        K: number of Gaussian kernels.
        edge_types: size of the type-id vocabulary. Id 0 is the padding id.
    """

    def __init__(self, K=128, edge_types=512*3):
        super().__init__()
        self.K = K
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        self.mul = nn.Embedding(edge_types, 1, padding_idx=0)
        self.bias = nn.Embedding(edge_types, 1, padding_idx=0)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        nn.init.constant_(self.bias.weight, 0)
        # This also sets the padding row to 1. padding_idx only stops that
        # row from receiving gradient.
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x, edge_types):
        """Expand ``x`` into ``K`` Gaussian features.

        Args:
            x: tensor of any shape ``S``.
            edge_types: integer tensor, expected shape ``(*S, m)``, holding
                ``m`` type ids per element of ``x``.

        Returns:
            Tensor of shape ``(*S, K)`` in the dtype of ``self.means.weight``.
        """
        mul = self.mul(edge_types).sum(dim=-2)      # (*S, 1)
        bias = self.bias(edge_types).sum(dim=-2)    # (*S, 1)
        x = mul * x.unsqueeze(-1) + bias            # (*S, 1)
        # Expand over whatever leading shape x has, rather than assuming 4-D.
        x = x.expand(*x.shape[:-1], self.K)         # (*S, K)
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-2
        return gaussian(x.float(), mean, std).type_as(self.means.weight)

class GaussianLayer4angle(nn.Module):
    """Type-conditioned circular Gaussian basis expansion for angles in radians.

    Each angle is mapped to ``mul * angle + bias``. Here ``mul`` and ``bias``
    are per-type scalars summed over the type ids in the last dimension of
    ``angle_types``. The result is evaluated under ``K`` learned circular
    Gaussians, with the difference to each mean wrapped to [-pi, pi].

    Args:
        K: number of kernels.
        angle_types: size of the type-id vocabulary. Id 0 is the padding id.
    """

    def __init__(self, K=128, angle_types=512*3):
        super().__init__()
        self.K = K
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        self.mul = nn.Embedding(angle_types, 1, padding_idx=0)
        self.bias = nn.Embedding(angle_types, 1, padding_idx=0)

        # Kernel centres span one full turn (radians).
        nn.init.uniform_(self.means.weight, -3.14159, 3.14159)
        # The initial std may be exactly 0. forward uses abs(std) + 1e-2,
        # which is what keeps the effective std strictly positive.
        nn.init.uniform_(self.stds.weight, 0, 3.14159/2)

        # Every type id starts with mul=1 and bias=0. This includes padding
        # id 0, whose row then stays fixed because of padding_idx.
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, angle_input, angle_types):
        """Expand angles into ``K`` circular-Gaussian features.

        Args:
            angle_input: angles of any shape ``S``, for example (b, max_angles).
            angle_types: integer tensor, expected shape ``(*S, m)``, for example
                (b, max_angles, 5).

        Returns:
            Tensor of shape ``(*S, K)`` in the dtype of ``self.means.weight``.
        """
        mul = self.mul(angle_types).sum(dim=-2)     # (*S, 1)
        bias = self.bias(angle_types).sum(dim=-2)   # (*S, 1)

        # Expand over whatever leading shape the input has, rather than assuming 3-D.
        transformed = mul * angle_input.unsqueeze(-1) + bias                   # (*S, 1)
        transformed = transformed.expand(*transformed.shape[:-1], self.K)     # (*S, K)

        # (K,) broadcasts against any leading shape. For 3-D inputs this is
        # equivalent to the former view(1, 1, K).
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-2

        return circular_gaussian(transformed.float(), mean, std).type_as(self.means.weight)

class NonLinear(nn.Module):
    """Two-layer MLP: ``Linear -> GELU -> Linear``.

    Args:
        input: input feature size.
        output_size: output feature size.
        hidden: hidden size. Defaults to ``input`` when omitted.
    """

    def __init__(self, input, output_size, hidden=None):
        super().__init__()
        width = input if hidden is None else hidden
        self.layer1 = nn.Linear(input, width)
        self.layer2 = nn.Linear(width, output_size)

    def forward(self, x):
        """Apply ``layer2(gelu(layer1(x)))``."""
        return self.layer2(F.gelu(self.layer1(x)))

