
import torch
import torch.nn as nn

def make_activation(name: str = "silu"):
    """Return a freshly constructed activation module chosen by name.

    Lookup is case-insensitive. Supported names are ``silu``, ``relu``,
    ``gelu``, ``elu``, ``leaky_relu`` (negative slope 0.1) and ``tanh``.
    The ``relu``, ``elu`` and ``leaky_relu`` variants are built with
    ``inplace=True``.

    ``None``, the empty string and any unrecognised name all produce
    ``nn.SiLU()``.
    NOTE(review): typos therefore silently select SiLU.
    """
    # Zero-argument factories, so every call returns a new module instance.
    builders = {
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": nn.GELU,
        "elu": lambda: nn.ELU(inplace=True),
        "leaky_relu": lambda: nn.LeakyReLU(0.1, inplace=True),
        "tanh": nn.Tanh,
    }
    key = name.lower() if name else "silu"
    # "silu" and unknown keys both fall through to the SiLU default.
    return builders.get(key, nn.SiLU)()

class MLP(nn.Module):
    """Feed-forward stack with optional LayerNorm and Dropout on hidden layers.

    Each hidden layer is ``Linear -> [LayerNorm] -> activation -> [Dropout]``.
    The final projection to ``out_dim`` is a bare ``nn.Linear`` with no
    norm, activation or dropout.

    Args:
        in_dim: Width of the input's last axis.
        hidden_dims: Iterable of hidden widths. If it is empty, the module
            is a single ``nn.Linear(in_dim, out_dim)``.
        out_dim: Width of the output's last axis.
        act: Activation name passed to ``make_activation``, which is called
            once per hidden layer.
        dropout: Dropout probability. ``nn.Dropout`` is inserted only when
            this is > 0.
        norm: When true, put ``nn.LayerNorm`` after each hidden ``nn.Linear``.
    """

    def __init__(self, in_dim, hidden_dims, out_dim, act="silu", dropout=0.0, norm=True):
        super().__init__()
        widths = [in_dim, *hidden_dims, out_dim]
        stack = []
        # Walk consecutive (fan_in, fan_out) pairs, stopping before the last
        # pair; that pair becomes the plain output projection below.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            stack.append(nn.Linear(fan_in, fan_out))
            if norm:
                stack.append(nn.LayerNorm(fan_out))
            stack.append(make_activation(act))
            if dropout > 0:
                stack.append(nn.Dropout(dropout))
        stack.append(nn.Linear(widths[-2], widths[-1]))
        # The attribute name ``net`` fixes the state_dict keys (``net.<idx>.*``).
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.net(x)




class ChainEdgeEncoder(nn.Module):
    """Embed intra-chain edge features with one MLP per feature group plus a fusion MLP.

    The last axis of ``edge_attr`` is read as four consecutive column groups,
    in this order:

    * seq: ``seq_dist_cut + 1`` columns, fed to ``seq_mlp``. The group is
      named ``seq_oh`` in ``forward``; it is presumably a one-hot of the
      clipped sequence separation (confirm against the featurizer).
    * contact: 1 column, fed to ``contact_mlp``.
    * rbf: ``rbf_dim`` columns, fed to ``rbf_mlp``.
    * ori: ``ori_dim`` columns, fed to ``ori_mlp``.

    The four group embeddings are concatenated and passed through
    ``fuse_mlp``, whose output width is ``edge_embed_dim``.

    Config keys (defaults): ``activation`` ("silu"), ``dropout`` (0.0),
    ``use_layernorm`` (True), ``seq_dist_cut`` (32), ``rbf_dim`` (15),
    ``ori_dim`` (12), ``edge_seq_hidden`` (32), ``edge_contact_hidden`` (16),
    ``edge_rbf_hidden`` (64), ``edge_ori_hidden`` (64),
    ``edge_fuse_hidden`` (128), ``edge_embed_dim`` (128).
    """

    def __init__(self, config: dict):
        super().__init__()
        act = config.get("activation", "silu")
        dropout = float(config.get("dropout", 0.0))
        use_ln = bool(config.get("use_layernorm", True))

        self.seq_dist_cut = int(config.get("seq_dist_cut", 32))
        # The seq group spans seq_dist_cut + 1 columns.
        seq_len = self.seq_dist_cut + 1
        self.rbf_dim = int(config.get("rbf_dim", 15))
        self.ori_dim = int(config.get("ori_dim", 12))

        # Cast widths to int, as the other config reads do, so numeric values
        # such as 64.0 or "64" never reach nn.Linear / nn.LayerNorm as non-ints.
        d_seq_hidden = int(config.get("edge_seq_hidden", 32))
        d_contact_hidden = int(config.get("edge_contact_hidden", 16))
        d_rbf_hidden = int(config.get("edge_rbf_hidden", 64))
        d_ori_hidden = int(config.get("edge_ori_hidden", 64))
        d_fuse_hidden = int(config.get("edge_fuse_hidden", 128))
        d_out = int(config.get("edge_embed_dim", 128))

        mlp_kwargs = dict(act=act, dropout=dropout, norm=use_ln)
        # Construction order is unchanged, so parameter initialisation consumes
        # the RNG in the same sequence as before.
        self.seq_mlp = MLP(seq_len, [d_seq_hidden], d_seq_hidden, **mlp_kwargs)
        self.contact_mlp = MLP(1, [d_contact_hidden], d_contact_hidden, **mlp_kwargs)
        self.rbf_mlp = MLP(self.rbf_dim, [d_rbf_hidden], d_rbf_hidden, **mlp_kwargs)
        self.ori_mlp = MLP(self.ori_dim, [d_ori_hidden], d_ori_hidden, **mlp_kwargs)

        fuse_in = d_seq_hidden + d_contact_hidden + d_rbf_hidden + d_ori_hidden
        self.fuse_mlp = MLP(fuse_in, [d_fuse_hidden], d_out, **mlp_kwargs)

        # Column widths of each group, in the order forward() reads them.
        self._seq_len = seq_len
        self._contact_len = 1
        self._rbf_len = self.rbf_dim
        self._ori_len = self.ori_dim

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        """Embed a tensor of edge features.

        Args:
            edge_attr: Tensor whose last axis holds the seq, contact, rbf and
                ori columns, in that order. Any leading axes (e.g. ``(E, F)``
                or ``(B, E, F)``) are preserved. Columns beyond the expected
                width are ignored.

        Returns:
            The ``fuse_mlp`` output, with the same leading axes as ``edge_attr``.

        Raises:
            ValueError: If the last axis of ``edge_attr`` is narrower than the
                total width of the four groups.
        """
        widths = (self._seq_len, self._contact_len, self._rbf_len, self._ori_len)
        needed = sum(widths)
        got = edge_attr.size(-1)
        if got < needed:
            # Fail clearly rather than let short slices surface as a shape
            # mismatch deep inside one of the MLPs.
            raise ValueError(
                f"ChainEdgeEncoder expected edge_attr with at least {needed} "
                f"features on the last axis, got {got}"
            )

        # Split on the last axis, so batched (3-D+) inputs also work.
        seq_oh, contact, rbf, ori = torch.split(edge_attr[..., :needed], widths, dim=-1)

        s = self.seq_mlp(seq_oh)
        c = self.contact_mlp(contact)
        r = self.rbf_mlp(rbf)
        o = self.ori_mlp(ori)

        fused = torch.cat([s, c, r, o], dim=-1)
        return self.fuse_mlp(fused)




class InterfaceEdgeEncoder(nn.Module):
    """Embed inter-chain (interface) edge features with per-group MLPs plus a fusion MLP.

    The last axis of ``edge_attr`` is read as consecutive column groups, in
    this order:

    * rbf: ``rbf_dim`` columns, fed to ``rbf_mlp``.
    * ori: ``ori_dim`` columns, fed to ``ori_mlp``.
    * chem: ``chem_dim`` (4) columns. The first ``_CHEM_BIN_LEN`` (2) go to
      ``chem_bin_mlp`` and the rest to ``chem_cont_mlp``. The names suggest
      binary vs. continuous descriptors (unverified; confirm with the
      featurizer).
    * inter: ``inter_dim`` (4) columns, fed to ``inter_mlp``.

    The five group embeddings are concatenated and passed through
    ``fuse_mlp``, whose output width is ``if_edge_embed_dim``.

    Config keys (defaults): ``activation`` ("silu"), ``dropout`` (0.0),
    ``use_layernorm`` (True), ``rbf_dim`` (15), ``ori_dim`` (12),
    ``if_edge_rbf_hidden`` (64), ``if_edge_ori_hidden`` (64),
    ``if_edge_chem_bin_hidden`` (16), ``if_edge_chem_cont_hidden`` (16),
    ``if_edge_inter_hidden`` (32), ``if_edge_fuse_hidden`` (128),
    ``if_edge_embed_dim`` (64).
    """

    # Number of leading chem columns routed to chem_bin_mlp. The remaining
    # chem_dim - _CHEM_BIN_LEN columns go to chem_cont_mlp.
    _CHEM_BIN_LEN = 2

    def __init__(self, config: dict):
        super().__init__()
        act = config.get("activation", "silu")
        dropout = float(config.get("dropout", 0.0))
        use_ln = bool(config.get("use_layernorm", True))

        self.rbf_dim = int(config.get("rbf_dim", 15))
        self.ori_dim = int(config.get("ori_dim", 12))
        self.chem_dim = 4
        self.inter_dim = 4

        # Cast widths to int, as the other config reads do, so numeric values
        # such as 64.0 or "64" never reach nn.Linear / nn.LayerNorm as non-ints.
        d_rbf_hidden = int(config.get("if_edge_rbf_hidden", 64))
        d_ori_hidden = int(config.get("if_edge_ori_hidden", 64))
        d_chem_bin_hidden = int(config.get("if_edge_chem_bin_hidden", 16))
        d_chem_cont_hidden = int(config.get("if_edge_chem_cont_hidden", 16))
        d_inter_hidden = int(config.get("if_edge_inter_hidden", 32))
        d_fuse_hidden = int(config.get("if_edge_fuse_hidden", 128))
        d_out = int(config.get("if_edge_embed_dim", 64))

        chem_cont_len = self.chem_dim - self._CHEM_BIN_LEN
        mlp_kwargs = dict(act=act, dropout=dropout, norm=use_ln)
        # Construction order is unchanged, so parameter initialisation consumes
        # the RNG in the same sequence as before.
        self.rbf_mlp = MLP(self.rbf_dim, [d_rbf_hidden], d_rbf_hidden, **mlp_kwargs)
        self.ori_mlp = MLP(self.ori_dim, [d_ori_hidden], d_ori_hidden, **mlp_kwargs)
        self.chem_bin_mlp = MLP(self._CHEM_BIN_LEN, [d_chem_bin_hidden], d_chem_bin_hidden, **mlp_kwargs)
        self.chem_cont_mlp = MLP(chem_cont_len, [d_chem_cont_hidden], d_chem_cont_hidden, **mlp_kwargs)
        self.inter_mlp = MLP(self.inter_dim, [d_inter_hidden], d_inter_hidden, **mlp_kwargs)

        fuse_in = d_rbf_hidden + d_ori_hidden + d_chem_bin_hidden + d_chem_cont_hidden + d_inter_hidden
        self.fuse_mlp = MLP(fuse_in, [d_fuse_hidden], d_out, **mlp_kwargs)

        # Column widths of each group, in the order forward() reads them.
        self._rbf_len = self.rbf_dim
        self._ori_len = self.ori_dim
        self._chem_len = self.chem_dim
        self._inter_len = self.inter_dim

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        """Embed a tensor of interface edge features.

        Args:
            edge_attr: Tensor whose last axis holds the rbf, ori, chem and
                inter columns, in that order. Any leading axes (e.g.
                ``(E, F)`` or ``(B, E, F)``) are preserved. Columns beyond
                the expected width are ignored.

        Returns:
            The ``fuse_mlp`` output, with the same leading axes as ``edge_attr``.

        Raises:
            ValueError: If the last axis of ``edge_attr`` is narrower than the
                total width of the groups.
        """
        widths = (
            self._rbf_len,
            self._ori_len,
            self._CHEM_BIN_LEN,
            self._chem_len - self._CHEM_BIN_LEN,
            self._inter_len,
        )
        needed = sum(widths)
        got = edge_attr.size(-1)
        if got < needed:
            # Fail clearly rather than let short slices surface as a shape
            # mismatch deep inside one of the MLPs.
            raise ValueError(
                f"InterfaceEdgeEncoder expected edge_attr with at least {needed} "
                f"features on the last axis, got {got}"
            )

        # Split on the last axis, so batched (3-D+) inputs also work. The chem
        # group is split directly into its bin/cont halves.
        rbf, ori, chem_bin, chem_cont, inter = torch.split(
            edge_attr[..., :needed], widths, dim=-1
        )

        r = self.rbf_mlp(rbf)
        o = self.ori_mlp(ori)
        cb = self.chem_bin_mlp(chem_bin)
        cc = self.chem_cont_mlp(chem_cont)
        it = self.inter_mlp(inter)

        fused = torch.cat([r, o, cb, cc, it], dim=-1)
        return self.fuse_mlp(fused)
