import warnings

import torch
import torch.nn as nn


# Column layout of the node-feature matrix ``data.x``; NodeEncoder._extract_slices
# indexes ``x[:, <slice>]`` with each of these. Total width: 20 + 8 + 4 + 6 + 5 = 43.
# NOTE(review): the meaning of each group is inferred from names only — confirm.
SEQ_SLICE = slice(0, 20)    # 20 cols; returned unchanged by NodeEncoder.forward as ``seq``
SS_SLICE  = slice(20, 28)   # 8 cols; fed to NodeEncoder.ss_mlp (called ``ss_onehot`` there — presumably one-hot SS classes)
CHEM_SLICE = slice(28, 32)  # 4 cols; fed to NodeEncoder.chem_mlp
ANGLE_SLICE = slice(32, 38) # 6 cols; fed to NodeEncoder.angle_mlp (encoding of the angles not visible here)
RHO_SLICE = slice(38, 43)   # 5 cols; fed to NodeEncoder.rho_mlp


# Number of heavy-atom slots per node expected by the heavy-atom MLP:
# NodeEncoder._heavy_to_local_flat flattens (N, A, 3) coordinates to A * 3 values,
# and heavy_mlp takes HEAVY_DIM inputs, so pos_heavyatom must have A == HEAVY_MAX.
HEAVY_MAX = 15
HEAVY_DIM = HEAVY_MAX * 3   # flattened (atom slot, xyz) width

def make_activation(name: str = "silu"):
    """Build a fresh activation module from its case-insensitive name.

    Supported names: ``silu``, ``relu``, ``gelu``, ``elu``, ``leaky_relu``
    (negative slope 0.1) and ``tanh``. ``relu``, ``elu`` and ``leaky_relu`` are
    built with ``inplace=True``. A falsy ``name`` (``None``, ``""``) selects SiLU.

    Unknown names still fall back to SiLU for backward compatibility, but a
    ``UserWarning`` is emitted so that a typo in a config is not silently ignored.

    Returns:
        A new ``nn.Module`` on every call (instances are never shared).
    """
    # Factories, not instances: each call must return an independent module.
    factories = {
        "silu": nn.SiLU,
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": nn.GELU,
        "elu": lambda: nn.ELU(inplace=True),
        "leaky_relu": lambda: nn.LeakyReLU(0.1, inplace=True),
        "tanh": nn.Tanh,
    }
    key = (name or "silu").lower()
    factory = factories.get(key)
    if factory is None:
        warnings.warn(
            f"Unknown activation {name!r}; falling back to SiLU.",
            stacklevel=2,
        )
        factory = nn.SiLU
    return factory()

class MLP(nn.Module):
    """Feed-forward stack with an un-normalized linear output layer.

    Each hidden layer is ``Linear -> [LayerNorm] -> activation -> [Dropout]``,
    and the stack ends with a plain ``Linear`` to ``out_dim``.

    Args:
        in_dim: input feature size.
        hidden_dims: iterable of hidden sizes. It may be empty, in which case the
            MLP is a single ``Linear(in_dim, out_dim)``.
        out_dim: output feature size.
        act: activation name passed to ``make_activation`` for every hidden layer.
        dropout: dropout probability after each hidden activation; ``0`` disables it.
        norm: insert a ``LayerNorm`` after each hidden ``Linear``.
    """

    def __init__(self, in_dim, hidden_dims, out_dim, act="silu", dropout=0.0, norm=True):
        super().__init__()
        widths = [in_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            if norm:
                modules.append(nn.LayerNorm(fan_out))
            modules.append(make_activation(act))
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        # Output projection: no normalization or activation after it.
        modules.append(nn.Linear(widths[-1], out_dim))
        # Keep this exact attribute name and module order; state_dict keys depend on it.
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack (``nn.Linear`` acts on the last dimension)."""
        return self.net(x)

class NodeEncoder(nn.Module):
    """Encode node feature rows into raw sequence columns, an SS embedding and a fused embedding.

    The columns of ``data.x`` are split with the module-level ``*_SLICE``
    constants. The secondary-structure block is embedded by its own MLP and
    returned separately. The chemical, angle and rho blocks are each encoded by a
    small MLP. When enabled, heavy-atom coordinates expressed in each node's local
    frame are encoded as well. All encodings are concatenated and passed through a
    fusion MLP.

    Config keys (all optional, defaults in parentheses):
        include_heavy_atom_coords (False): also encode heavy-atom coordinates.
            ``data`` must then provide ``pos_heavyatom``, ``mask_heavyatom``,
            ``R`` and ``t``.
        activation ("silu"), dropout (0.0), use_layernorm (True): forwarded to
            every ``MLP``.
        chem_hidden (32), angle_hidden (32), rho_hidden (32), heavy_hidden (64):
            hidden and output width of the matching per-group MLP.
        ss_hidden (64), ss_embed_dim (64): hidden and output width of the SS MLP.
        fuse_hidden (128), node_embed_dim (64): hidden and output width of the
            fusion MLP.
    """

    # Attributes ``data`` must carry when heavy-atom encoding is enabled.
    _HEAVY_ATTRS = ("pos_heavyatom", "mask_heavyatom", "R", "t")

    def __init__(self, config: dict):
        super().__init__()

        self.include_heavy = bool(config.get("include_heavy_atom_coords", False))
        act = config.get("activation", "silu")
        dropout = float(config.get("dropout", 0.0))
        use_ln = bool(config.get("use_layernorm", True))

        # Per-group encoder widths (each group MLP outputs its hidden width).
        d_chem_hidden = config.get("chem_hidden", 32)
        d_angle_hidden = config.get("angle_hidden", 32)
        d_rho_hidden = config.get("rho_hidden", 32)
        d_heavy_hidden = config.get("heavy_hidden", 64)
        d_ss_hidden = config.get("ss_hidden", 64)

        # Fusion and output widths.
        d_fuse_hidden = config.get("fuse_hidden", 128)
        d_out = config.get("node_embed_dim", 64)
        d_ss_out = config.get("ss_embed_dim", 64)

        # Sub-modules are created in the original order, so seeded parameter init
        # and state_dict key order are unchanged. Input widths are derived from the
        # column slices rather than duplicated as literals.
        self.ss_mlp = MLP(in_dim=self._width(SS_SLICE), hidden_dims=[d_ss_hidden], out_dim=d_ss_out,
                          act=act, dropout=dropout, norm=use_ln)

        self.chem_mlp = MLP(in_dim=self._width(CHEM_SLICE), hidden_dims=[d_chem_hidden],
                            out_dim=d_chem_hidden, act=act, dropout=dropout, norm=use_ln)
        self.angle_mlp = MLP(in_dim=self._width(ANGLE_SLICE), hidden_dims=[d_angle_hidden],
                             out_dim=d_angle_hidden, act=act, dropout=dropout, norm=use_ln)
        self.rho_mlp = MLP(in_dim=self._width(RHO_SLICE), hidden_dims=[d_rho_hidden],
                           out_dim=d_rho_hidden, act=act, dropout=dropout, norm=use_ln)

        if self.include_heavy:
            # Input is the flattened local-frame coordinates of HEAVY_MAX atom slots.
            self.heavy_mlp = MLP(in_dim=HEAVY_DIM, hidden_dims=[d_heavy_hidden], out_dim=d_heavy_hidden,
                                 act=act, dropout=dropout, norm=use_ln)

        # The fusion input is the concatenation of every group encoding.
        fuse_in = d_chem_hidden + d_angle_hidden + d_rho_hidden + (d_heavy_hidden if self.include_heavy else 0)
        self.fuse_mlp = MLP(in_dim=fuse_in, hidden_dims=[d_fuse_hidden], out_dim=d_out,
                            act=act, dropout=dropout, norm=use_ln)

    @staticmethod
    def _width(columns: slice) -> int:
        """Return the number of columns selected by a contiguous ``slice(start, stop)``."""
        return columns.stop - columns.start

    @staticmethod
    def _extract_slices(x: torch.Tensor):
        """Split ``x`` along dim 1 into its (seq, ss, chem, angle, rho) column groups."""
        seq = x[:, SEQ_SLICE]
        ss = x[:, SS_SLICE]
        chem = x[:, CHEM_SLICE]
        angle = x[:, ANGLE_SLICE]
        rho = x[:, RHO_SLICE]
        return seq, ss, chem, angle, rho

    @staticmethod
    def _heavy_to_local_flat(pos_heavyatom: torch.Tensor,
                             mask_heavyatom: torch.Tensor,
                             R: torch.Tensor,
                             t: torch.Tensor) -> torch.Tensor:
        """Express heavy-atom coordinates in each node's frame, zero masked slots, and flatten.

        For every node ``n`` and atom slot ``a`` this computes
        ``R[n]^T @ (pos_heavyatom[n, a] - t[n])``. Slots whose mask is false/zero
        become exactly 0, even if their coordinates are NaN or inf.

        Args:
            pos_heavyatom: (N, A, 3) coordinates.
            mask_heavyatom: (N, A) mask of any dtype; non-zero / True means present.
            R: (N, 3, 3) per-node matrices (presumably rotations whose columns are
                the local axes — confirm against the data pipeline).
            t: (N, 3) per-node origins.

        Returns:
            (N, A * 3) tensor of local coordinates.
        """
        present = mask_heavyatom.bool().unsqueeze(-1)          # (N, A, 1)
        rel = pos_heavyatom - t.unsqueeze(1)                   # (N, A, 3)
        # Zero missing slots *before* rotating. Multiplying afterwards would let
        # NaN/inf placeholders through (NaN * 0 == NaN) and poison the features
        # and the gradients.
        rel = rel.masked_fill(~present, 0.0)
        R_T = R.transpose(1, 2)
        x_local = torch.einsum('nij,naj->nai', R_T, rel)       # R^T @ rel per node
        return x_local.reshape(x_local.size(0), -1)

    def forward(self, data):
        """Encode one batch of nodes.

        Args:
            data: object with a node-feature tensor ``data.x``. It is assumed to be
                (N, F) with F >= 43 (the end of RHO_SLICE) — TODO confirm. When
                heavy-atom encoding is enabled, ``data`` must also provide
                ``pos_heavyatom``, ``mask_heavyatom``, ``R`` and ``t``.

        Returns:
            ``(seq, ss_embed, node_embed)``: the raw ``SEQ_SLICE`` columns of
            ``data.x``, the ``ss_mlp`` output, and the ``fuse_mlp`` output.

        Raises:
            ValueError: heavy-atom encoding is enabled but ``data`` lacks one of
                the required attributes.
        """
        x = data.x
        device, dtype = x.device, x.dtype

        seq, ss_onehot, chem, angle, rho = self._extract_slices(x)

        ss_embed = self.ss_mlp(ss_onehot)

        # Order matters only for dropout RNG consumption; it matches the original.
        feats = [self.chem_mlp(chem), self.angle_mlp(angle), self.rho_mlp(rho)]

        if self.include_heavy:
            # Explicit check (not assert) so it also runs under `python -O`.
            missing = [name for name in self._HEAVY_ATTRS if not hasattr(data, name)]
            if missing:
                raise ValueError(
                    "include_heavy_atom_coords=True requires data to provide "
                    f"pos_heavyatom, mask_heavyatom, R and t; missing: {missing}"
                )
            heavy_flat = self._heavy_to_local_flat(
                data.pos_heavyatom.to(device=device, dtype=dtype),
                data.mask_heavyatom.to(device=device),
                data.R.to(device=device, dtype=dtype),
                data.t.to(device=device, dtype=dtype),
            )
            feats.append(self.heavy_mlp(heavy_flat))

        node_embed = self.fuse_mlp(torch.cat(feats, dim=-1))
        return seq, ss_embed, node_embed
