import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from torch_scatter import scatter_add, scatter_softmax
from Bio.PDB import PDBParser, DSSP, NeighborSearch, Selection, Residue
from enum import Enum

class AA(Enum):
    """Three-letter residue names understood by the graph builder.

    Every member's value equals its name. Besides the 20 standard amino acids
    the enum also lists the alternative / modified names MSE, HSE, HSD, HSP,
    SEC, SEP, TPO, PTR and XLE.
    """
    ALA = 'ALA'
    ARG = 'ARG'
    ASN = 'ASN'
    ASP = 'ASP'
    CYS = 'CYS'
    GLN = 'GLN'
    GLU = 'GLU'
    GLY = 'GLY'
    HIS = 'HIS'
    ILE = 'ILE'
    LEU = 'LEU'
    LYS = 'LYS'
    MET = 'MET'
    PHE = 'PHE'
    PRO = 'PRO'
    SER = 'SER'
    THR = 'THR'
    TRP = 'TRP'
    TYR = 'TYR'
    VAL = 'VAL'
    MSE = 'MSE'
    HSE = 'HSE'
    HSD = 'HSD'
    HSP = 'HSP'
    SEC = 'SEC'
    SEP = 'SEP'
    TPO = 'TPO'
    PTR = 'PTR'
    XLE = 'XLE'

# Three-letter residue name -> one-letter amino-acid code. Besides the 20
# standard residues, alternative / modified names are folded onto their parent
# residue (e.g. MSE -> M, SEP -> S, HSE/HSD/HSP -> H, XLE -> L).
THREE_TO_ONE = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
        'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
        'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
        'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
        'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
        'MSE': 'M', 'HSE': 'H', 'HSD': 'H', 'HSP': 'H',
        'SEC': 'C', 'SEP': 'S', 'TPO': 'T', 'PTR': 'Y',
        'XLE': 'L',
    }


# --- Residue-pair interaction definitions -----------------------------------
# The atom lists below name side-chain atoms only, keyed by three-letter
# residue name.

# Charged residues and the side-chain atoms carrying the charge, used by the
# salt-bridge and cation-pi checks. HIS is counted as positive here.
NEG_RES = {'ASP', 'GLU'}
POS_RES = {'LYS', 'ARG', 'HIS'}
NEG_ATOMS = {'ASP': ['OD1', 'OD2'], 'GLU': ['OE1', 'OE2']}
POS_ATOMS = {'LYS': ['NZ'], 'ARG': ['NH1', 'NH2'], 'HIS': ['ND1', 'NE2']}


# Aromatic residues and their ring atoms; the mean of the ring-atom
# coordinates is used as the ring centroid.
AROMATIC_RES = {'PHE', 'TYR', 'TRP'}
AROMATIC_ATOMS = {
    'PHE': ['CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'TYR': ['CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'TRP': ['CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
}


# Side-chain hydrogen-bond donor / acceptor heavy atoms. The H-bond check
# appends backbone 'N' (donor) and 'O' (acceptor) for every residue. Residues
# absent from a table contribute no side-chain atoms of that role.
HBOND_DONORS = {
    'ARG': ['NE', 'NH1', 'NH2'], 'HIS': ['ND1', 'NE2'], 'LYS': ['NZ'],
    'ASN': ['ND2'], 'GLN': ['NE2'], 'SER': ['OG'], 'THR': ['OG1'],
    'TRP': ['NE1'], 'TYR': ['OH']
}
HBOND_ACCEPTORS = {
    'ASP': ['OD1', 'OD2'], 'GLU': ['OE1', 'OE2'], 'HIS': ['ND1', 'NE2'],
    'ASN': ['OD1'], 'GLN': ['OE1'], 'SER': ['OG'], 'THR': ['OG1'],
    'TYR': ['OH']
}

def get_atom_coords(res, atom_names):
    """Return coordinates of the atoms named in *atom_names* that *res* contains.

    Atoms missing from *res* are skipped silently. The rows follow the order of
    *atom_names*. When no atom is present the result is an empty array of
    shape ``(0,)``.
    """
    present = [name for name in atom_names if name in res]
    return np.array([res[name].get_coord() for name in present])

def get_centroid(coords):
    """Return the mean of *coords* along axis 0, or ``None`` if *coords* is empty."""
    if len(coords) == 0:
        return None
    return np.mean(coords, axis=0)

def check_salt_bridge(res1, res2, cutoff=4.0):
    """Return 1 if *res1* and *res2* form a salt bridge, else 0.

    One residue must be in ``NEG_RES`` and the other in ``POS_RES`` (in either
    order). The bridge is present when the closest pair of their charged atoms
    (``NEG_ATOMS`` / ``POS_ATOMS``) is strictly closer than *cutoff*. Returns 0
    when either side has none of its charged atoms resolved.
    """
    name1, name2 = res1.get_resname(), res2.get_resname()

    if name1 in NEG_RES and name2 in POS_RES:
        anion, cation = res1, res2
    elif name1 in POS_RES and name2 in NEG_RES:
        anion, cation = res2, res1
    else:
        return 0

    anion_xyz = get_atom_coords(anion, NEG_ATOMS[anion.get_resname()])
    cation_xyz = get_atom_coords(cation, POS_ATOMS[cation.get_resname()])
    if len(anion_xyz) == 0 or len(cation_xyz) == 0:
        return 0

    closest = np.min([np.linalg.norm(a - b) for a in anion_xyz for b in cation_xyz])
    return 1 if closest < cutoff else 0

def check_hydrogen_bond(res1, res2, cutoff=3.5):
    """Return 1 if a donor of one residue is within *cutoff* of an acceptor of the other.

    Donors are the side-chain atoms from ``HBOND_DONORS`` plus backbone ``N``.
    Acceptors are those from ``HBOND_ACCEPTORS`` plus backbone ``O``. Both
    directions (res1 -> res2, then res2 -> res1) are tried. This is a pure
    heavy-atom distance test; angles are not checked.
    """
    name1, name2 = res1.get_resname(), res2.get_resname()

    donors1 = get_atom_coords(res1, HBOND_DONORS.get(name1, []) + ['N'])
    acceptors1 = get_atom_coords(res1, HBOND_ACCEPTORS.get(name1, []) + ['O'])
    donors2 = get_atom_coords(res2, HBOND_DONORS.get(name2, []) + ['N'])
    acceptors2 = get_atom_coords(res2, HBOND_ACCEPTORS.get(name2, []) + ['O'])

    for donors, acceptors in ((donors1, acceptors2), (donors2, acceptors1)):
        if len(donors) == 0 or len(acceptors) == 0:
            continue
        closest = np.min([np.linalg.norm(d - a) for d in donors for a in acceptors])
        if closest < cutoff:
            return 1

    return 0

def check_aromatic_interaction(res1, res2, cutoff=6.0):
    """Return 1 if two aromatic residues have ring centroids closer than *cutoff*, else 0.

    Both residues must be in ``AROMATIC_RES``. Each centroid is the mean of the
    resolved ring atoms listed in ``AROMATIC_ATOMS``. If either ring has no
    resolved atoms the result is 0.
    """
    name1, name2 = res1.get_resname(), res2.get_resname()
    if not (name1 in AROMATIC_RES and name2 in AROMATIC_RES):
        return 0

    ring1 = get_centroid(get_atom_coords(res1, AROMATIC_ATOMS[name1]))
    ring2 = get_centroid(get_atom_coords(res2, AROMATIC_ATOMS[name2]))
    if ring1 is None or ring2 is None:
        return 0

    return 1 if np.linalg.norm(ring1 - ring2) < cutoff else 0

def check_cation_pi_interaction(res1, res2, cutoff=6.0):
    """Return 1 for a cation-pi contact between *res1* and *res2*, else 0.

    One residue must be in ``POS_RES`` and the other in ``AROMATIC_RES`` (in
    either order). The contact is present when any charged atom of the cationic
    residue (``POS_ATOMS``) lies strictly closer than *cutoff* to the ring
    centroid of the aromatic residue.
    """
    name1, name2 = res1.get_resname(), res2.get_resname()

    if name1 in POS_RES and name2 in AROMATIC_RES:
        cation_res, ring_res = res1, res2
    elif name1 in AROMATIC_RES and name2 in POS_RES:
        cation_res, ring_res = res2, res1
    else:
        return 0

    cation_xyz = get_atom_coords(cation_res, POS_ATOMS[cation_res.get_resname()])
    if len(cation_xyz) == 0:
        return 0

    ring_center = get_centroid(
        get_atom_coords(ring_res, AROMATIC_ATOMS[ring_res.get_resname()])
    )
    if ring_center is None:
        return 0

    closest = np.min([np.linalg.norm(xyz - ring_center) for xyz in cation_xyz])
    return 1 if closest < cutoff else 0


def dihedral(c1, c2, c3, c4) -> torch.Tensor:
    """Signed dihedral angle (radians, in [-pi, pi]) defined by four points.

    Points are xyz along the last axis. They may be single vectors of shape (3,)
    or broadcastable batches (..., 3), giving an output of shape (...).
    The bond-length normalisation of the middle bond is clamped at 1e-8 to
    avoid division by zero.
    """
    bond_a, bond_b, bond_c = c2 - c1, c3 - c2, c4 - c3

    # Normals of the two planes (c1, c2, c3) and (c2, c3, c4).
    normal_a = torch.linalg.cross(bond_a, bond_b)
    normal_b = torch.linalg.cross(bond_b, bond_c)

    # Third axis of an orthogonal frame (unit middle bond, normal_a, ortho),
    # used to recover the sign of the angle.
    unit_b = bond_b / torch.linalg.norm(bond_b, dim=-1, keepdim=True).clamp(min=1e-8)
    ortho = torch.linalg.cross(unit_b, normal_a)

    cos_part = torch.sum(normal_a * normal_b, dim=-1)
    sin_part = torch.sum(ortho * normal_b, dim=-1)
    return torch.atan2(sin_part, cos_part)



def normalize_vector(v, dim=-1, eps=1e-8):
    """Scale *v* to unit length along *dim*.

    Norms smaller than *eps* are clamped to *eps*, so (near-)zero vectors stay
    (near-)zero instead of producing NaNs.
    """
    length = torch.linalg.norm(v, dim=dim, keepdim=True)
    return v / length.clamp(min=eps)

def project_v2v(v, u, dim=-1):
    """Return ``<v, u> * u`` computed along *dim*.

    *u* is not normalised here. The result is the orthogonal projection of *v*
    onto *u* only when *u* already has unit length.
    """
    coeff = torch.sum(v * u, dim=dim, keepdim=True)
    return coeff * u

def construct_3d_basis(center: torch.Tensor, p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    """Build a right-handed orthonormal frame at *center* by Gram-Schmidt.

    The axes are:

    * ``e1``: unit vector from *center* towards *p1*.
    * ``e2``: unit component of ``p2 - center`` orthogonal to ``e1``.
    * ``e3``: ``e1 x e2``.

    Inputs have xyz on the last axis, shape (..., 3). The result has shape
    (..., 3, 3) with ``e1, e2, e3`` as its columns.
    """
    e1 = normalize_vector(p1 - center, dim=-1)

    toward_p2 = p2 - center
    e2 = normalize_vector(toward_p2 - project_v2v(toward_p2, e1, dim=-1), dim=-1)

    e3 = torch.linalg.cross(e1, e2, dim=-1)

    # Stacking on a new last axis makes each basis vector a column.
    return torch.stack([e1, e2, e3], dim=-1)

class ChemicalProperty(nn.Module):
    """Map a one-letter amino-acid code to a 4-dim chemical descriptor.

    The output is ``[hydropathy, charge, volume / 100, polarity]`` as a float
    tensor. Codes not present in the tables (e.g. ``'X'``) get 0 for every
    entry. The module has no parameters; it only looks values up in the tables.
    """

    # Hydropathy index per residue.
    hydropathy = {
        'A': 1.8, 'R': -4.5, 'N': -3.5, 'D': -3.5, 'C': 2.5,
        'Q': -3.5, 'E': -3.5, 'G': -0.4, 'H': -3.2, 'I': 4.5,
        'L': 3.8, 'K': -3.9, 'M': 1.9, 'F': 2.8, 'P': -1.6,
        'S': -0.8, 'T': -0.7, 'W': -0.9, 'Y': -1.3, 'V': 4.2,
    }
    # Side-chain charge; HIS is given a small partial positive charge.
    charge = {
        'A': 0, 'R': 1, 'N': 0, 'D': -1, 'C': 0,
        'E': -1, 'Q': 0, 'G': 0, 'H': 0.1, 'I': 0,
        'L': 0, 'K': 1, 'M': 0, 'F': 0, 'P': 0,
        'S': 0, 'T': 0, 'W': 0, 'Y': 0, 'V': 0,
    }
    # Residue volume; divided by 100 in forward().
    volume = {
        'A': 67, 'R': 148, 'N': 96, 'D': 91, 'C': 86,
        'E': 109, 'Q': 114, 'G': 48, 'H': 118, 'I': 124,
        'L': 124, 'K': 135, 'M': 124, 'F': 135, 'P': 90,
        'S': 73, 'T': 93, 'W': 163, 'Y': 141, 'V': 105,
    }
    # Binary polar (1) / non-polar (0) flag.
    polarity = {
        'A': 0, 'R': 1, 'N': 1, 'D': 1, 'C': 1,
        'E': 1, 'Q': 1, 'G': 0, 'H': 1, 'I': 0,
        'L': 0, 'K': 1, 'M': 0, 'F': 0,
        'P': 0, 'S': 1, 'T': 1, 'W': 0, 'Y': 1, 'V': 0,
    }

    def forward(self, aa: str) -> torch.Tensor:
        """Return ``[hydropathy, charge, volume / 100, polarity]`` for residue *aa*."""
        descriptor = [
            self.hydropathy.get(aa, 0.),
            self.charge.get(aa, 0.),
            self.volume.get(aa, 0) / 100.,
            self.polarity.get(aa, 0),
        ]
        return torch.tensor(descriptor, dtype=torch.float)

class SequenceSSChemicalFeature:
    """Per-residue features.

    Produces an amino-acid one-hot, a DSSP secondary-structure one-hot, and the
    4-dim chemical descriptor from ``ChemicalProperty``.
    """

    AA_TYPES = "ARNDCQEGHILKMFPSTWYV"
    # DSSP secondary-structure codes. '-' doubles as the fallback for unknown
    # codes and for residues without a DSSP record.
    SS_TYPES = "HBEGITS-"

    def __init__(self):
        self.chem_extractor = ChemicalProperty()

    def __call__(self, residue, dssp_dict: dict) -> tuple:
        """Featurise one Biopython residue.

        Args:
            residue: Biopython residue. Its name is mapped through
                ``THREE_TO_ONE``. ``(parent chain id, residue id)`` is the DSSP
                lookup key.
            dssp_dict: mapping from that key to a DSSP record, whose index 2 is
                read as the secondary-structure letter. May be empty (e.g. when
                DSSP failed) or lack some residues; those residues fall back
                to '-'.

        Returns:
            ``(seq_onehot, ss_onehot, chem_feature)`` float tensors of sizes
            20, 8 and 4. Residues that do not map to one of the 20 standard
            letters get an all-zero sequence one-hot.
        """
        one_letter = THREE_TO_ONE.get(residue.get_resname(), "X")
        if one_letter in self.AA_TYPES:
            idx = self.AA_TYPES.index(one_letter)
            seq_onehot = F.one_hot(torch.tensor(idx), num_classes=len(self.AA_TYPES)).float()
        else:
            seq_onehot = torch.zeros(len(self.AA_TYPES), dtype=torch.float)

        # Previously a missing key raised KeyError, which also made the
        # "DSSP failed -> {}" fallback unusable. Treat missing records as '-'.
        dssp_key = (residue.get_parent().id, residue.get_id())
        record = dssp_dict[dssp_key] if dssp_key in dssp_dict else None
        ss_letter = str(record[2]) if record is not None and len(record) > 2 else "-"
        # Require exactly one known code. A substring test alone would accept
        # '' (mapping it to 'H') or multi-letter strings.
        if len(ss_letter) != 1 or ss_letter not in self.SS_TYPES:
            ss_letter = '-'
        ss_idx = self.SS_TYPES.index(ss_letter)
        ss_onehot = F.one_hot(torch.tensor(ss_idx), num_classes=len(self.SS_TYPES)).float()

        chem_feature = self.chem_extractor(one_letter)

        return seq_onehot, ss_onehot, chem_feature

class PDBGraphBuilder:
    """Turn a protein-complex PDB file into residue-level PyG graphs.

    ``build`` keeps the two longest protein chains of the first model. For each
    chain it creates a k-NN graph over CA atoms, with backbone frames ``R``,
    translations ``t`` and heavy-atom positions attached. It also creates an
    interface graph whose edges link residues of the longer chain to residues
    of the shorter chain that are in atomic contact.
    """

    class NoInterfaceError(ValueError):
        """Raised by ``compute_interface_edges`` when the two chains share no contacts."""

    max_num_heavyatoms = 15
    # Heavy-atom slot layout per residue type: '' marks an unused slot and the
    # last slot (index 14) is the C-terminal OXT.
    restype_to_heavyatom_names = {
        AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
        AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    '', 'OXT'],
        AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    '', 'OXT'],
        AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    '', 'OXT'],
        AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
        AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    '', 'OXT'],
        AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    '', 'OXT'],
        AA.GLY: ['N', 'CA', 'C', 'O', '',    '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
        AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    '', 'OXT'],
        AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    '', 'OXT'],
        AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    '', 'OXT'],
        AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    '', 'OXT'],
        AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    '', 'OXT'],
        AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    '', 'OXT'],
        AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    '', 'OXT'],
        AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
        AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],
        AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'],
        AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    '', 'OXT'],
        AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],
    }

    def __init__(self, seq_dist_cut=32):
        self.parser = PDBParser(QUIET=True)
        # Sequence separations |i - j| above this are clipped before one-hot encoding.
        self.seq_dist_cut = seq_dist_cut
        self.seq_ss_chem = SequenceSSChemicalFeature()
        # Bandwidths of the surface-aware rho feature, shape (1, K) with K = 5.
        self.lambdas = torch.tensor([1., 2., 5., 10., 30.]).view(1, -1)

    # ------------------------------------------------------------------ atoms

    def _get_residue_heavyatom_info(self, res: Residue):
        """Return ``(pos[15, 3], mask[15])`` for *res* following ``restype_to_heavyatom_names``.

        Residues whose name is not an ``AA`` member, or which have no slot
        layout (e.g. MSE), get all-zero positions and an all-False mask.
        """
        pos_heavyatom = torch.zeros([self.max_num_heavyatoms, 3], dtype=torch.float)
        mask_heavyatom = torch.zeros([self.max_num_heavyatoms, ], dtype=torch.bool)
        try:
            atom_names = self.restype_to_heavyatom_names[AA(res.get_resname())]
        except (KeyError, ValueError):
            # ValueError: unknown residue name; KeyError: AA member without a layout.
            return pos_heavyatom, mask_heavyatom

        for idx, atom_name in enumerate(atom_names):
            if atom_name and atom_name in res:
                pos_heavyatom[idx] = torch.as_tensor(res[atom_name].get_coord(), dtype=pos_heavyatom.dtype)
                mask_heavyatom[idx] = True
        return pos_heavyatom, mask_heavyatom

    def _compute_dssp(self, model, pdb_path: str) -> dict:
        """Run DSSP (``mkdssp`` executable) on *model*; return ``{}`` on any failure.

        The keys are expected to be ``(chain_id, residue_id)``, which is how
        ``SequenceSSChemicalFeature`` looks them up.
        """
        try:
            dssp = DSSP(model, pdb_path, dssp="mkdssp")
            return dict(dssp)
        except Exception as e:
            # Best effort: secondary structure then falls back to '-' per residue.
            print(f"DSSP calculation failed: {e}")
            return {}

    def extract_bb_coords_from_res(self, residues):
        """Return float tensors ``(ca, c, n)``, each (N, 3), for the given residues."""
        def stack(atom_name):
            return torch.tensor(np.array([r[atom_name].get_coord() for r in residues]), dtype=torch.float)

        return stack('CA'), stack('C'), stack('N')

    # ------------------------------------------------------------ node level

    def compute_node_features(self, residues, bb_coords, dssp_dict):
        """Node features of width 20 + 8 + 4 + 6 + K (K = number of lambdas).

        Layout: sequence one-hot, SS one-hot, chemistry descriptor, then
        sin/cos of phi, psi and omega. The last K columns are zeros reserved for
        rho, which ``build_chain_graph`` fills in.
        """
        ca, c, n = bb_coords
        num_res = len(residues)
        K = self.lambdas.numel()

        # Backbone dihedrals for all residues at once. Undefined terminal
        # angles stay 0 (sin 0, cos 1):
        #   phi_i   = C(i-1)-N(i)-CA(i)-C(i)      (undefined for the first residue)
        #   psi_i   = N(i)-CA(i)-C(i)-N(i+1)      (undefined for the last residue)
        #   omega_i = CA(i)-C(i)-N(i+1)-CA(i+1)   (undefined for the last residue)
        # NOTE(review): consecutive list entries are treated as bonded even
        # across chain breaks / missing residues.
        phi = torch.zeros(num_res)
        psi = torch.zeros(num_res)
        omega = torch.zeros(num_res)
        if num_res > 1:
            phi[1:] = dihedral(c[:-1], n[1:], ca[1:], c[1:])
            psi[:-1] = dihedral(n[:-1], ca[:-1], c[:-1], n[1:])
            omega[:-1] = dihedral(ca[:-1], c[:-1], n[1:], ca[1:])
        angles = torch.stack(
            [phi.sin(), phi.cos(), psi.sin(), psi.cos(), omega.sin(), omega.cos()], dim=-1
        )

        per_residue = torch.stack([torch.cat(self.seq_ss_chem(r, dssp_dict)) for r in residues], dim=0)
        rho_placeholder = torch.zeros(num_res, K, dtype=torch.float)

        return torch.cat([per_residue, angles, rho_placeholder], dim=1)

    # ------------------------------------------------------------ chain edges

    def compute_chain_edges(self, bb_coords, R, t):
        """Build the k-NN graph over CA atoms, its edge features, and per-node rho.

        k is ``max(N - 1, 1)`` for chains shorter than 16 residues, otherwise
        ``int(0.04 * N)`` clamped to [8, 16].
        """
        ca = bb_coords[0]
        num_res = len(ca)
        if num_res < 16:
            k = max(num_res - 1, 1)
        else:
            k = min(max(int(0.04 * num_res), 8), 16)

        edge_index = knn_graph(ca, k=k, loop=False)
        src_list, dst_list = edge_index[0], edge_index[1]
        dist_tensor = torch.linalg.norm(ca[src_list] - ca[dst_list], dim=-1)

        edge_attr = self.compute_edge_features(src_list, dst_list, dist_tensor, R, t)
        rho = self.compute_surface_aware(ca, src_list, dst_list, dist_tensor)
        return edge_index, edge_attr, rho

    @staticmethod
    def _relative_pose_features(R, t, src, dst):
        """Pose of each source residue in its destination residue's local frame, shape (E, 12).

        ``R`` stores the local basis vectors as columns (as built by
        ``construct_3d_basis``), so ``R^T`` maps a global displacement into
        local coordinates. Output columns are ``R_dst^T (t_src - t_dst)`` (3),
        then the flattened ``R_dst^T R_src`` (9).
        """
        R_dst, R_src = R[dst], R[src]
        # 'bij,bi->bj' contracts over the row axis, i.e. R_dst^T @ rel_pos. The
        # earlier 'bji,bi->bj' computed R_dst @ rel_pos (local -> global), which
        # made the feature depend on the structure's global orientation.
        p = torch.einsum('bij,bi->bj', R_dst, t[src] - t[dst])
        R_rel = torch.matmul(R_dst.transpose(1, 2), R_src)
        # Explicit 9 instead of -1 so that zero-edge inputs reshape cleanly.
        R_rel_flat = R_rel.transpose(-1, -2).reshape(R_rel.shape[0], 9)
        return torch.cat([p, R_rel_flat], dim=-1)

    def compute_edge_features(self, src_list, dst_list, dist_tensor, R, t):
        """Intra-chain edge features, concatenated per edge:

        * one-hot of ``|i - j|`` clipped at ``seq_dist_cut`` (``seq_dist_cut + 1`` dims);
        * a CA-CA contact flag (``dist <= 8``);
        * 15 Gaussian RBFs of the CA-CA distance;
        * the 12-dim relative pose from ``_relative_pose_features``.
        """
        seq_d = torch.clamp(torch.abs(src_list - dst_list), max=self.seq_dist_cut)
        seq_oh = F.one_hot(seq_d, num_classes=self.seq_dist_cut + 1).float()

        contact = (dist_tensor <= 8).float().unsqueeze(-1)
        dist_feat = self.distance_featurizer(dist_tensor, divisor=5.0)
        ori_tensor = self._relative_pose_features(R, t, src_list, dst_list)

        return torch.cat([seq_oh, contact, dist_feat, ori_tensor], dim=1)

    def distance_featurizer(self, dist_list: torch.Tensor, divisor: float, num_rbf: int = 15) -> torch.Tensor:
        """Gaussian RBF expansion centred at 0.

        Each feature is ``exp(-(d / divisor)^2 / (2 * sigma_r^2))`` with
        ``sigma_r = 1.5 ** r`` for ``r`` in ``range(num_rbf)``. A 1-D input of
        length E gives an output of shape (E, num_rbf).
        """
        if dist_list.ndim == 1:
            dist_list = dist_list.unsqueeze(-1)

        r = torch.arange(num_rbf, device=dist_list.device, dtype=dist_list.dtype)
        sigmas = 1.5 ** r
        d2 = (dist_list / divisor) ** 2
        return torch.exp(-d2 / (2.0 * (sigmas ** 2)))

    def compute_surface_aware(self, ca, src_list, dst_list, dist_tensor):
        """Surface-aware node feature rho, shape (N, K), one column per bandwidth in ``self.lambdas``.

        For node i, ``rho_i = ||sum_e w_e (x_src - x_dst)|| / sum_e w_e ||x_src - x_dst||``
        over the edges e whose source is i. The weights w are softmax-normalised
        over those edges. Values lie in [0, 1]; a node that is the source of no
        edge gets 0.
        """
        N = ca.size(0)
        diff = ca[src_list] - ca[dst_list]
        # NOTE(review): w_raw is already exp(-d^2 / lambda) and then goes through
        # a softmax, so the weights are proportional to exp(exp(-d^2 / lambda)),
        # which is nearly uniform. softmax(-d^2 / lambda) may have been intended;
        # confirm before changing, since that alters the features.
        w_raw = torch.exp(-dist_tensor.pow(2).view(-1, 1) / self.lambdas.view(1, -1))
        # NOTE(review): edges are grouped by src_list. If knn_graph puts the
        # neighbour in edge_index[0] and the centre in edge_index[1] (as the pose
        # features assume), this aggregates over "reverse" neighbours and
        # grouping by dst_list may be intended; confirm.
        w = scatter_softmax(w_raw, src_list, dim=0, dim_size=N)

        weighted_vec = scatter_add(w.unsqueeze(-1) * diff.unsqueeze(1), src_list, dim=0, dim_size=N)
        weighted_dist = scatter_add(w * dist_tensor.unsqueeze(-1), src_list, dim=0, dim_size=N)

        return weighted_vec.norm(dim=-1) / (weighted_dist + 1e-8)

    # -------------------------------------------------------- interface edges

    def compute_interface_edge_features(self, src, dst, dist_list, chain1, chain2, res_long, res_short):
        """Features for long-chain -> short-chain interface edges.

        ``src`` indexes ``res_long``. ``dst`` indexes the concatenated node set,
        so ``dst - len(res_long)`` indexes ``res_short``. Per edge the output
        concatenates:

        * 15 distance RBFs;
        * the 12-dim relative pose;
        * 4 pairwise chemistry terms (same polarity, opposite charge,
          |delta volume|, |delta hydropathy|);
        * 4 interaction flags (H-bond, salt bridge, aromatic, cation-pi).
        """
        R_all = torch.cat([chain1.R, chain2.R], dim=0)
        t_all = torch.cat([chain1.t, chain2.t], dim=0)
        dist_tensor = torch.tensor(dist_list, dtype=torch.float)

        dist_feat = self.distance_featurizer(dist_tensor, divisor=5.0)
        ori_tensor = self._relative_pose_features(R_all, t_all, src, dst)

        offset_short = len(res_long)
        pairs = [(res_long[i], res_short[j - offset_short]) for i, j in zip(src, dst)]

        chem = self.seq_ss_chem.chem_extractor
        cp_src = torch.stack([chem(THREE_TO_ONE.get(r1.get_resname(), 'X')) for r1, _ in pairs])
        cp_dst = torch.stack([chem(THREE_TO_ONE.get(r2.get_resname(), 'X')) for _, r2 in pairs])
        # Descriptor columns: 0 hydropathy, 1 charge, 2 volume / 100, 3 polarity.
        chem_tensor = torch.stack([
            (cp_src[:, 3] == cp_dst[:, 3]).float(),
            ((cp_src[:, 1] * cp_dst[:, 1]) < 0).float(),
            torch.abs(cp_src[:, 2] - cp_dst[:, 2]),
            torch.abs(cp_src[:, 0] - cp_dst[:, 0]),
        ], dim=1)

        interaction_tensor = torch.tensor(
            [
                [
                    check_hydrogen_bond(r1, r2),
                    check_salt_bridge(r1, r2),
                    check_aromatic_interaction(r1, r2),
                    check_cation_pi_interaction(r1, r2),
                ]
                for r1, r2 in pairs
            ],
            dtype=torch.float,
        )

        return torch.cat([dist_feat, ori_tensor, chem_tensor, interaction_tensor], dim=1)

    @staticmethod
    def _find_contact_pairs(res_long, res_short, cutoff):
        """Return sorted, unique ``(i, j)`` contacts between ``res_long[i]`` and ``res_short[j]``.

        A pair is in contact when any non-hydrogen atom of the long-chain
        residue lies within *cutoff* of any atom of the short-chain residue.
        """
        if not res_long or not res_short:
            return []
        # Index only the short chain's atoms, so every hit is a candidate partner.
        ns = NeighborSearch(Selection.unfold_entities(res_short, 'A'))
        # Dict lookup replaces the O(len(res_short)) `in` / `.index()` per hit.
        short_index = {res: j for j, res in enumerate(res_short)}

        pairs = set()
        for i, rl in enumerate(res_long):
            for atom in rl:
                if atom.element == 'H':
                    continue
                for res in ns.search(atom.get_coord(), cutoff, 'R'):
                    j = short_index.get(res)
                    if j is not None:
                        pairs.add((i, j))
        # Sorting makes the edge order deterministic (set order depends on hashing).
        return sorted(pairs)

    def compute_interface_edges(self, chain1, chain2, res_long, res_short, cutoff=6.0):
        """Directed long -> short interface edges and their features.

        Short-chain nodes are offset by ``len(res_long)``. Each edge's distance
        is the CA-CA distance.

        Raises:
            PDBGraphBuilder.NoInterfaceError: if no residue pair is in contact.
        """
        pairs = self._find_contact_pairs(res_long, res_short, cutoff)
        if not pairs:
            # chain1/chain2 are PyG Data objects without an `id`, so take the
            # chain ids from the residues instead.
            long_id = res_long[0].get_parent().id if res_long else '?'
            short_id = res_short[0].get_parent().id if res_short else '?'
            raise self.NoInterfaceError(f"No interface edges found between {long_id} and {short_id}.")

        offset_short = len(res_long)
        src = [i for i, _ in pairs]
        dst = [j + offset_short for _, j in pairs]
        dist_list = [
            float(np.linalg.norm(res_long[i]['CA'].get_coord() - res_short[j]['CA'].get_coord()))
            for i, j in pairs
        ]

        edge_attr = self.compute_interface_edge_features(
            src, dst, dist_list, chain1, chain2, res_long, res_short
        )
        edge_index = torch.tensor([src, dst], dtype=torch.long)
        return edge_index, edge_attr

    # ---------------------------------------------------------------- graphs

    def build_chain_graph(self, residues, bb_coords, dssp_dict):
        """Single-chain graph with node/edge features, frames ``R``/``t`` and heavy atoms."""
        ca, c, n = bb_coords
        R = construct_3d_basis(ca, c, n)
        t = ca

        node_attr = self.compute_node_features(residues, bb_coords, dssp_dict)
        edge_index, edge_attr, rho = self.compute_chain_edges(bb_coords, R, t)
        # The last K node-feature columns were reserved (zeros) for rho.
        node_attr[:, -rho.shape[1]:] = rho

        heavy = [self._get_residue_heavyatom_info(r) for r in residues]
        pos_heavyatom = torch.stack([pos for pos, _ in heavy], dim=0)
        mask_heavyatom = torch.stack([mask for _, mask in heavy], dim=0)

        return Data(
            x=node_attr,
            edge_index=edge_index,
            edge_attr=edge_attr,
            R=R,
            t=t,
            pos_heavyatom=pos_heavyatom,
            mask_heavyatom=mask_heavyatom,
        )

    def build_interface_graph(self, chain1, chain2, res_long, res_short):
        """Interface graph over the concatenated nodes of both chain graphs (chain1 first)."""
        node_attr_all = torch.cat([chain1.x, chain2.x], dim=0)
        pos_heavyatom_all = torch.cat([chain1.pos_heavyatom, chain2.pos_heavyatom], dim=0)
        mask_heavyatom_all = torch.cat([chain1.mask_heavyatom, chain2.mask_heavyatom], dim=0)

        edge_index, edge_attr = self.compute_interface_edges(chain1, chain2, res_long, res_short)

        R_all = torch.cat([chain1.R, chain2.R], dim=0)
        t_all = torch.cat([chain1.t, chain2.t], dim=0)

        return Data(
            x=node_attr_all,
            edge_index=edge_index,
            edge_attr=edge_attr,
            R=R_all,
            t=t_all,
            pos_heavyatom=pos_heavyatom_all,
            mask_heavyatom=mask_heavyatom_all,
        )

    def build(self, pdb_path):
        """Build chain and interface graphs for the PDB file at *pdb_path*.

        Only residues whose name is in ``THREE_TO_ONE`` and that have N, CA and
        C atoms are kept. The two chains with the most such residues are used.

        Returns:
            ``{'id', 'chain_1_graph', 'chain_2_graph', 'interface_graph'}``
            (chain 1 is the longer chain), or ``None`` when fewer than two
            protein chains remain or the chains have no interface contacts.
        """
        file_name = os.path.basename(pdb_path)
        pdb_id = file_name[:-len('.pdb')] if file_name.endswith('.pdb') else file_name
        structure = self.parser.get_structure(pdb_id, pdb_path)
        model = structure[0]
        print(f"\nProcessing {pdb_id}...")

        # Filter before counting chains: water/ligand-only chains used to pass the
        # "< 2 chains" check and then fail with IndexError. Requiring N and C as
        # well as CA avoids a KeyError in extract_bb_coords_from_res.
        chains = {}
        for ch in model:
            residues = [
                r for r in ch
                if r.get_resname() in THREE_TO_ONE and all(a in r for a in ('N', 'CA', 'C'))
            ]
            if residues:
                chains[ch.id] = residues

        if len(chains) < 2:
            print(f"Warning: Less than two chains in {pdb_path}. Skipping.")
            return None

        dssp_complex = self._compute_dssp(model, pdb_path)

        sorted_chains = sorted(chains.values(), key=len, reverse=True)
        res_long, res_short = sorted_chains[0], sorted_chains[1]

        bb_coords_long = self.extract_bb_coords_from_res(res_long)
        bb_coords_short = self.extract_bb_coords_from_res(res_short)

        chain1_graph = self.build_chain_graph(res_long, bb_coords_long, dssp_complex)
        chain2_graph = self.build_chain_graph(res_short, bb_coords_short, dssp_complex)

        try:
            interface_graph = self.build_interface_graph(chain1_graph, chain2_graph, res_long, res_short)
        except self.NoInterfaceError as err:
            # Previously the process called exit() here; skip this structure instead.
            print(f"Warning: {err} Skipping.")
            return None

        return {
            'id': pdb_id,
            'chain_1_graph': chain1_graph,
            'chain_2_graph': chain2_graph,
            'interface_graph': interface_graph,
        }

if __name__ == "__main__":
    import sys

    # The old hard-coded "." pointed at a directory, so parsing always failed;
    # take the PDB path from the command line instead.
    if len(sys.argv) < 2:
        sys.exit(f"usage: python {sys.argv[0]} path/to/complex.pdb")
    pdb_path = sys.argv[1]
    builder = PDBGraphBuilder()
    result = builder.build(pdb_path)
    print(result)
    