import torch
from torch import nn
import torch.nn.functional as F

from lib.tgt import TGT_Encoder

from . import layers

def print_statistics(tensor, name):
    """Print the min, max and mean of ``tensor`` for debugging.

    Output is emitted only for tensors on the primary device so that
    multi-GPU runs do not print one line per replica.

    Args:
        tensor: Tensor to summarise.
        name: Label placed at the start of the printed line.
    """
    # ``device.index`` is ``None`` for CPU tensors; treat that like device 0
    # so the helper is not silently a no-op when running without a GPU.
    if tensor.device.index not in (None, 0):
        return
    # min()/max() raise on zero-element tensors; report instead of crashing
    # the forward pass from inside a debug helper.
    if tensor.numel() == 0:
        print(f"{name} is empty (shape: {tuple(tensor.shape)})")
        return
    print(f"{name} min: {tensor.min()}, max: {tensor.max()}, mean: {tensor.mean()}")

class TGT_Distance(nn.Module):
    """Embed the inputs, run them through ``TGT_Encoder`` and project the results.

    The embedded graph object exposes three feature tensors, ``e``, ``p`` and
    ``t``. After encoding, each one is layer-normalised and passed through its
    own linear layer whose output width is ``num_dist_bins``.

    Args:
        model_height: Forwarded to ``TGT_Encoder``.
        layer_multiplier: Forwarded to ``TGT_Encoder``.
        upto_hop: Forwarded to ``layers.EmbedInput``.
        embed_3d_type: Forwarded to ``layers.EmbedInput``.
        num_3d_kernels: Forwarded to ``layers.EmbedInput``.
        num_dist_bins: Output width of all three prediction layers.
        **layer_configs: Forwarded unchanged to ``TGT_Encoder``. Must contain
            ``node_width``, ``edge_width``, ``angle_width`` and
            ``torsion_angle_width``.
    """

    def __init__(self,
                 model_height,
                 layer_multiplier=1,
                 upto_hop=32,
                 embed_3d_type='gaussian',
                 num_3d_kernels=128,
                 num_dist_bins=128,
                 **layer_configs):
        super().__init__()

        # Scalar hyper-parameters are kept as plain attributes.
        self.model_height = model_height
        self.layer_multiplier = layer_multiplier
        self.upto_hop = upto_hop
        self.embed_3d_type = embed_3d_type
        self.num_3d_kernels = num_3d_kernels
        self.num_dist_bins = num_dist_bins

        # Feature widths are mandatory entries of the layer configuration.
        self.node_width = layer_configs['node_width']
        self.edge_width = layer_configs['edge_width']
        self.angle_width = layer_configs['angle_width']
        self.torsion_angle_width = layer_configs['torsion_angle_width']
        self.layer_configs = layer_configs

        # Sub-modules are created in a fixed order (encoder, embedding, norms,
        # prediction layers) so parameter registration order stays stable.
        self.encoder = TGT_Encoder(
            model_height=self.model_height,
            layer_multiplier=self.layer_multiplier,
            node_ended=False,
            edge_ended=True,
            egt_simple=False,
            **self.layer_configs
        )

        self.input_embed = layers.EmbedInput(
            node_width=self.node_width,
            edge_width=self.edge_width,
            angle_width=self.angle_width,
            torsion_angle_width=self.torsion_angle_width,
            upto_hop=self.upto_hop,
            embed_3d_type=self.embed_3d_type,
            num_3d_kernels=self.num_3d_kernels,
        )

        self.final_ln_edge = nn.LayerNorm(self.edge_width)
        self.final_ln_angle = nn.LayerNorm(self.angle_width)
        self.final_ln_torsion_angle = nn.LayerNorm(self.torsion_angle_width)

        self.dist_pred = nn.Linear(self.edge_width, num_dist_bins)
        self.angle_pred = nn.Linear(self.angle_width, num_dist_bins)
        self.torsion_angle_pred = nn.Linear(self.torsion_angle_width, num_dist_bins)

    def forward(self, inputs, e_repr=None, p_repr=None, t_repr=None):
        """Run the model and return predictions plus normalised features.

        Args:
            inputs: Passed as-is to the input embedding layer.
            e_repr: Optional tensor added to ``e[:, 1:, 1:]`` before encoding.
            p_repr: Optional tensor added to ``p[:, 1:]`` before encoding.
            t_repr: Optional tensor added to ``t[:, 1:]`` before encoding.

        Returns:
            A 6-tuple ``(e_pred, p_pred, t_pred, e_feat, p_feat, t_feat)``.
            The ``*_pred`` entries are the linear-layer outputs and the
            ``*_feat`` entries are the layer-normalised encoder outputs. Index
            0 is dropped along dims 1 and 2 for ``e`` and along dim 1 for
            ``p`` and ``t``.
            NOTE(review): index 0 is presumably a reserved/virtual slot; confirm
            against ``layers.EmbedInput``.
        """
        graph = self.input_embed(inputs)

        # Optional external representations are summed into the embedded
        # features, skipping index 0, and written back by slice assignment.
        if e_repr is not None:
            edge_slice = graph.e[:, 1:, 1:]
            graph.e[:, 1:, 1:] = edge_slice + e_repr
        if p_repr is not None:
            angle_slice = graph.p[:, 1:]
            graph.p[:, 1:] = angle_slice + p_repr
        if t_repr is not None:
            torsion_slice = graph.t[:, 1:]
            graph.t[:, 1:] = torsion_slice + t_repr

        graph = self.encoder(graph)

        edge_feat = self.final_ln_edge(graph.e)
        edge_pred = self.dist_pred(edge_feat)

        angle_feat = self.final_ln_angle(graph.p)
        angle_pred = self.angle_pred(angle_feat)

        torsion_feat = self.final_ln_torsion_angle(graph.t)
        torsion_pred = self.torsion_angle_pred(torsion_feat)

        return (
            edge_pred[:, 1:, 1:, :],
            angle_pred[:, 1:, :],
            torsion_pred[:, 1:, :],
            edge_feat[:, 1:, 1:, :],
            angle_feat[:, 1:, :],
            torsion_feat[:, 1:, :],
        )
