import torch
from torch import nn
import torch.nn.functional as F

from lib.tgt import TGT_Encoder

from . import layers
from . import consts as C

def print_statistics(tensor, name):
    """Print the min, max and mean of ``tensor``, prefixed by ``name``.

    Output is produced only on the primary GPU, i.e. when the tensor's
    device index is 0 (such as ``cuda:0``). Any other device index is
    skipped. A CPU tensor has ``device.index is None``, so it is skipped too.
    """
    if tensor.device.index != 0:
        return
    lo, hi, avg = tensor.min(), tensor.max(), tensor.mean()
    print(f"{name} min: {lo}, max: {hi}, mean: {avg}")

class TGT_Multi(nn.Module):
    """TGT model with a graph-level scalar head and per-element bin heads.

    ``TGT_Encoder`` runs over the embedded input. The outputs are:

    - a scalar per graph, read from index 0 of the node stream and combined
      with scalar contributions from index 0 of the angle and torsion-angle
      streams;
    - ``num_dist_bins`` outputs for every edge-pair, angle and torsion-angle
      position, with the index-0 position dropped.

    Args:
        model_height: Forwarded to ``TGT_Encoder``.
        layer_multiplier: Forwarded to ``TGT_Encoder``.
        upto_hop: Forwarded to ``layers.EmbedInput``.
        embed_3d_type: Forwarded to ``layers.EmbedInput``.
        num_3d_kernels: Forwarded to ``layers.EmbedInput``.
        num_dist_bins: Output size of the edge, angle and torsion-angle
            prediction heads.
        **layer_configs: Must contain ``node_width``, ``edge_width``,
            ``angle_width`` and ``torsion_angle_width``. The whole dict is
            also forwarded to ``TGT_Encoder``.
    """

    def __init__(self,
                 model_height,
                 layer_multiplier    = 1         ,
                 upto_hop            = 32        ,
                 embed_3d_type       = 'gaussian',
                 num_3d_kernels      = 128       ,
                 num_dist_bins       = 128       ,
                 **layer_configs
                 ):
        super().__init__()

        self.model_height        = model_height
        self.layer_multiplier    = layer_multiplier
        self.upto_hop            = upto_hop
        self.embed_3d_type       = embed_3d_type
        self.num_3d_kernels      = num_3d_kernels
        self.num_dist_bins       = num_dist_bins

        # Required keys; a KeyError here means the config is incomplete.
        self.node_width          = layer_configs['node_width']
        self.edge_width          = layer_configs['edge_width']
        self.angle_width         = layer_configs['angle_width']
        self.torsion_angle_width = layer_configs['torsion_angle_width']

        self.layer_configs = layer_configs
        self.encoder = TGT_Encoder(model_height     = self.model_height      ,
                                   layer_multiplier = self.layer_multiplier  ,
                                   node_ended       = True                   ,
                                   edge_ended       = True                   ,
                                   egt_simple       = False                  ,
                                   **self.layer_configs)

        self.input_embed = layers.EmbedInput(node_width          = self.node_width          ,
                                             edge_width          = self.edge_width          ,
                                             angle_width         = self.angle_width         ,
                                             torsion_angle_width = self.torsion_angle_width ,
                                             upto_hop            = self.upto_hop            ,
                                             embed_3d_type       = self.embed_3d_type       ,
                                             num_3d_kernels      = self.num_3d_kernels      )

        # Graph-level scalar heads. Submodule construction order is kept
        # unchanged so parameter creation order / RNG consumption is
        # unchanged.
        self.final_ln_node = nn.LayerNorm(self.node_width)
        self.pred = nn.Linear(self.node_width, 1)
        self.angle_pred_prop = nn.Linear(self.angle_width, 1)
        self.torsion_angle_pred_prop = nn.Linear(self.torsion_angle_width, 1)
        # Start the scalar prediction at C.HL_MEAN rather than 0.
        nn.init.constant_(self.pred.bias, C.HL_MEAN)

        # Per-element heads, each producing num_dist_bins outputs.
        self.final_ln_edge = nn.LayerNorm(self.edge_width)
        self.final_ln_angle = nn.LayerNorm(self.angle_width)
        self.final_ln_torsion_angle = nn.LayerNorm(self.torsion_angle_width)
        self.dist_pred = nn.Linear(self.edge_width, num_dist_bins)
        self.angle_pred = nn.Linear(self.angle_width, num_dist_bins)
        self.torsion_angle_pred = nn.Linear(self.torsion_angle_width, num_dist_bins)

    def forward(self, inputs):
        """Embed ``inputs``, encode them, and apply all prediction heads.

        Assumes the encoder output ``g`` exposes the following (TODO:
        confirm against TGT_Encoder):

        - ``h`` of shape (batch, nodes, node_width);
        - ``e`` of shape (batch, nodes, nodes, edge_width);
        - ``p`` of shape (batch, angles, angle_width);
        - ``t`` of shape (batch, torsions, torsion_angle_width).

        Returns:
            tuple ``(h, e, p, t)``:

            - ``h``: one scalar per graph, from the node readout plus the
              angle and torsion-angle index-0 contributions.
            - ``e``: pair outputs with index 0 dropped on both pair axes.
            - ``p``: angle outputs with index 0 dropped.
            - ``t``: torsion-angle outputs with index 0 dropped.
        """
        g = self.input_embed(inputs)
        g = self.encoder(g)

        # Read the scalar from position 0 of the node stream, the same
        # position p4h/t4h read below and the one the per-element outputs
        # drop. Presumably a virtual/global token added by EmbedInput --
        # confirm. Without this line `h` was unbound and forward() always
        # raised.
        h = g.h[:, 0, :]
        h = self.final_ln_node(h)
        h = self.pred(h).squeeze(dim=-1)

        e = g.e
        e = self.final_ln_edge(e)
        e = self.dist_pred(e)

        p = g.p
        p = self.final_ln_angle(p)
        p4h = self.angle_pred_prop(p[:, 0, :]).squeeze(dim=-1)
        p = self.angle_pred(p)

        t = g.t
        t = self.final_ln_torsion_angle(t)
        t4h = self.torsion_angle_pred_prop(t[:, 0, :]).squeeze(dim=-1)
        t = self.torsion_angle_pred(t)

        # All three scalar terms are per-graph, so they add elementwise.
        h = h + p4h + t4h

        return h, e[:, 1:, 1:, :], p[:, 1:, :], t[:, 1:, :]
