# """
# Create dataset for epiformer encoder combining AsEP graphs with PDB structure files.

# This script processes the AsEP dataset and PDB structure files to create comprehensive
# data tensors containing atom-level, residue-level, and edge-level graphs for the epiformer encoder.

# Graph Hierarchy:
# 1. AtomMP: processes atom_graph → atom embeddings
# 2. EdgeMP: processes edge_graph (line graph of residue edges) → edge embeddings  
# 3. ResMP: processes residue_graph + aggregated atom + injected edge embeddings → final residue embeddings
# 4. Decoder: processes final antigen & antibody residue embeddings → interaction prediction
# """

# """
# TODO:
# - add the plm embeddings to the residue graph node features
# - remove the bipartite matrix, except edge_index_ag_ab
# """


"""

HeteroData(
  complex_id='complex_identifier',  # e.g., '1s78_0P'
  
  # ==================== NODE STORES ====================
  
  # Antigen Atom Level
  ag_atom={
    x=[num_ag_atoms, 28],     # Atom features (7 elements + 21 residue types)
    pos=[num_ag_atoms, 3],    # 3D atomic coordinates
  },
  
  # Antigen Residue Level  
  ag_res={
    x=[num_ag_residues, 105],                    # Residue features (RAAD: 20+16+12+48+9)
    plm=[num_ag_residues, plm_dim],              # PLM embeddings (ESM-2: 480/1280 dim)
    y=[num_ag_residues],                         # Epitope labels (0/1)
    pos=[num_ag_residues, 3],                    # CA coordinates
    edge_lists={                                 # Edge lists by relation type
      0=[num_r0_edges],                          # r0 edges (sequential ±1)
      1=[num_r1_edges],                          # r1 edges (sequential ±2)  
      2=[num_r2_edges],                          # r2 edges (k-NN)
      3=[num_r3_edges],                          # r3 edges (spatial)
    },
    edge_features_tensor=[total_edge_features, 100],  # All edge features stacked
    edge_keys_tensor=[total_edge_features, 3],        # Edge keys (src, dst, rel)
  },
  
  # Antibody Atom Level
  ab_atom={
    x=[num_ab_atoms, 28],     # Atom features (7 elements + 21 residue types)
    pos=[num_ab_atoms, 3],    # 3D atomic coordinates
  },
  
  # Antibody Residue Level
  ab_res={
    x=[num_ab_residues, 105],                    # Residue features (RAAD: 20+16+12+48+9)
    plm=[num_ab_residues, plm_dim],              # PLM embeddings (AbLang/AntiBERTy: 512 dim)
    y=[num_ab_residues],                         # Paratope labels (0/1)
    pos=[num_ab_residues, 3],                    # CA coordinates
    edge_lists={                                 # Edge lists by relation type
      0=[num_r0_edges],                          # r0 edges (sequential ±1)
      1=[num_r1_edges],                          # r1 edges (sequential ±2)  
      2=[num_r2_edges],                          # r2 edges (k-NN)
      3=[num_r3_edges],                          # r3 edges (spatial)
    },
    edge_features_tensor=[total_edge_features, 100],  # All edge features stacked
    edge_keys_tensor=[total_edge_features, 3],        # Edge keys (src, dst, rel)
  },
  
  # Edge Graph Nodes (Line Graph for EdgeMP)
  ag_edge={ 
    x=[num_ag_edge_nodes, 100],           # Edge node features (from residue graph edges)
    edge_mapping=[num_ag_edge_nodes, 3],  # Mapping to residue edges (src, dst, rel)
  },
  ab_edge={ 
    x=[num_ab_edge_nodes, 100],           # Edge node features (from residue graph edges)
    edge_mapping=[num_ab_edge_nodes, 3],  # Mapping to residue edges (src, dst, rel)
  },
  
  # ==================== EDGE STORES ====================
  
  # Atom-Level Bonds
  (ag_atom, atom_bond, ag_atom)={
    edge_index=[2, num_ag_atom_bonds],    # Spatial bonds within cutoff (4.5Å)
    edge_attr=[num_ag_atom_bonds, 17],    # Distance + RBF encoding (1+16)
  },
  (ab_atom, atom_bond, ab_atom)={
    edge_index=[2, num_ab_atom_bonds],    # Spatial bonds within cutoff (4.5Å)
    edge_attr=[num_ab_atom_bonds, 17],    # Distance + RBF encoding (1+16)
  },
  
  # Residue-Level Multi-Relational Edges
  
  # r0: Sequential ±1 neighbors
  (ag_res, r0, ag_res)={
    edge_index=[2, num_ag_r0_edges],      # Sequential distance = 1
    edge_attr=[num_ag_r0_edges, 100],     # RAAD edge features
  },
  (ab_res, r0, ab_res)={
    edge_index=[2, num_ab_r0_edges],      # Sequential distance = 1
    edge_attr=[num_ab_r0_edges, 100],     # RAAD edge features
  },
  
  # r1: Sequential ±2 neighbors
  (ag_res, r1, ag_res)={
    edge_index=[2, num_ag_r1_edges],      # Sequential distance = 2
    edge_attr=[num_ag_r1_edges, 100],     # RAAD edge features
  },
  (ab_res, r1, ab_res)={
    edge_index=[2, num_ab_r1_edges],      # Sequential distance = 2
    edge_attr=[num_ab_r1_edges, 100],     # RAAD edge features
  },
  
  # r2: k-NN spatial neighbors (k=10)
  (ag_res, r2, ag_res)={
    edge_index=[2, num_ag_r2_edges],      # k-nearest neighbors
    edge_attr=[num_ag_r2_edges, 100],     # RAAD edge features
  },
  (ab_res, r2, ab_res)={
    edge_index=[2, num_ab_r2_edges],      # k-nearest neighbors
    edge_attr=[num_ab_r2_edges, 100],     # RAAD edge features
  },
  
  # r3: Spatial proximity (within 8.0Å, not in r0-r2)
  (ag_res, r3, ag_res)={
    edge_index=[2, num_ag_r3_edges],      # Spatial cutoff edges
    edge_attr=[num_ag_r3_edges, 100],     # RAAD edge features
  },
  (ab_res, r3, ab_res)={
    edge_index=[2, num_ab_r3_edges],      # Spatial cutoff edges
    edge_attr=[num_ab_r3_edges, 100],     # RAAD edge features
  },
  
  # Line Graph (Edge-Edge) Connections
  (ag_edge, edge_connect, ag_edge)={
    edge_index=[2, num_ag_line_edges],    # Edges between edge nodes
    edge_attr=[num_ag_line_edges, 1],     # Edge type encoding
  },
  (ab_edge, edge_connect, ab_edge)={
    edge_index=[2, num_ab_line_edges],    # Edges between edge nodes
    edge_attr=[num_ab_line_edges, 1],     # Edge type encoding
  },
  
  # Cross-Hierarchy Connections
  
  # Atom-to-Residue mapping
  (ag_atom, belongs_to, ag_res)={ 
    edge_index=[2, num_ag_atoms]          # Each atom belongs to one residue
  },
  (ab_atom, belongs_to, ab_res)={ 
    edge_index=[2, num_ab_atoms]          # Each atom belongs to one residue
  },
  
  # Residue-to-Edge mapping
  (ag_res, connected_by, ag_edge)={ 
    edge_index=[2, num_ag_res_edge_connections]  # Residues connected by edge nodes
  },
  (ab_res, connected_by, ab_edge)={ 
    edge_index=[2, num_ab_res_edge_connections]  # Residues connected by edge nodes
  },
  
  # Inter-Protein Interactions
  (ag_res, interacts, ab_res)={ 
    edge_index=[2, num_interaction_pairs]  # Bipartite antigen-antibody interactions
  }
)

"""



"""
Create dataset for epiformer encoder combining AsEP graphs with PDB structure files.
Modified to create HeteroData objects for each complex sample.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, HeteroData
import numpy as np
import pandas as pd
from Bio.PDB import PDBParser, Select
from Bio.PDB.PDBIO import PDBIO
from biopandas.pdb import PandasPdb
import pickle
from typing import Dict, List, Tuple, Optional, Any
import warnings
warnings.filterwarnings("ignore")

# Make modules in the current working directory importable
import sys
sys.path.append(os.getcwd())

from scipy.spatial.transform import Rotation as R
import math
from torch_scatter import scatter_add

# ==================== ATOM GRAPH CONSTRUCTION ====================

class PDBProcessor:
    """Processes PDB files to extract atom information and build graphs.

    The structure is parsed once in ``__init__``; ``extract_atoms`` then fills
    ``self.atoms`` / ``self.coords`` and ``build_graph`` turns them into a PyG
    ``Data`` object with one node per atom.
    """

    # One-hot vocabularies. The trailing 'OTHER' slot absorbs every value
    # that is not listed explicitly.
    ELEMENTS = ('C', 'N', 'O', 'S', 'P', 'H', 'OTHER')
    RESIDUES = ('ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY',
                'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER',
                'THR', 'TRP', 'TYR', 'VAL', 'OTHER')

    def __init__(self, pdb_path, include_hydrogens=False):
        """Parse ``pdb_path`` and prepare empty atom/coordinate buffers.

        Args:
            pdb_path: Path of the PDB file handed to ``PDBParser.get_structure``.
            include_hydrogens: When False, atoms whose element is 'H' are
                dropped by ``extract_atoms``.
        """
        self.pdb_path = pdb_path
        self.include_hydrogens = include_hydrogens
        self.parser = PDBParser(QUIET=True)
        self.structure = self.parser.get_structure("pdb", pdb_path)
        self.atoms = []
        self.coords = []

    def extract_atoms(self):
        """Extract atoms and coordinates from the parsed structure.

        Residues whose hetero flag (``residue.id[0]``) is not a blank are
        skipped (heterogens and water). Every model in the structure is
        traversed.

        Returns:
            ``self``, so calls can be chained.
        """
        # Start from a clean slate so a second call does not duplicate atoms.
        self.atoms = []
        self.coords = []

        for model in self.structure:
            for chain in model:
                for residue in chain:
                    hetero_flag, res_seq, insertion_code = residue.id
                    if hetero_flag != ' ':
                        continue

                    for atom in residue:
                        if not self.include_hydrogens and atom.element == 'H':
                            continue
                        coord = atom.get_coord()
                        self.atoms.append({
                            'name': atom.get_name(),
                            'element': atom.element,
                            'residue': residue.resname,
                            'residue_id': res_seq,
                            # Kept so residues that share a number but differ
                            # by insertion code stay distinct.
                            'icode': insertion_code,
                            'chain_id': chain.id,
                            'coord': coord,
                        })
                        self.coords.append(coord)
        return self

    def get_atom_features(self):
        """Create one-hot feature vectors for the extracted atoms.

        Returns:
            Float tensor of shape ``(n_atoms, 28)``: 7 element slots followed
            by 21 residue-type slots. Unknown elements / residue names fall
            into the respective 'OTHER' slot. With no atoms the result has
            shape ``(0, 28)``.
        """
        n_elem = len(self.ELEMENTS)
        n_res = len(self.RESIDUES)
        elem_slot = {name: pos for pos, name in enumerate(self.ELEMENTS)}
        res_slot = {name: pos for pos, name in enumerate(self.RESIDUES)}

        features = torch.zeros((len(self.atoms), n_elem + n_res))
        for row, atom in enumerate(self.atoms):
            features[row, elem_slot.get(atom['element'], n_elem - 1)] = 1
            features[row, n_elem + res_slot.get(atom['residue'], n_res - 1)] = 1
        return features

    def _residue_indices(self):
        """Map every atom to a zero-based index of the residue it belongs to.

        Residues are identified by (chain id, residue number, insertion code)
        and numbered in order of first appearance in ``self.atoms``. Keying on
        the residue number alone would merge residues from different chains
        that happen to share a number.
        """
        index_of = {}
        indices = []
        for atom in self.atoms:
            key = (atom['chain_id'], atom['residue_id'], atom.get('icode', ' '))
            indices.append(index_of.setdefault(key, len(index_of)))
        return indices

    def build_graph(self, cutoff=4.5):
        """Build the atom graph with edges based on spatial proximity.

        Args:
            cutoff: Two distinct atoms are connected (in both directions) when
                their distance is strictly below this value.

        Returns:
            PyG ``Data`` with ``x`` (n_atoms, 28), ``pos`` (n_atoms, 3),
            ``edge_index`` (2, n_edges), ``edge_attr`` (n_edges, 17: raw
            distance + 16 RBF values) and ``residue_indices`` (n_atoms,).
        """
        # Reshape in NumPy so an empty atom list still yields a (0, 3) array.
        coords = torch.as_tensor(
            np.asarray(self.coords, dtype=np.float32).reshape(-1, 3)
        )

        dist_matrix = torch.cdist(coords, coords)

        # Exclude self-loops explicitly rather than relying on the diagonal
        # being exactly zero; cdist is not guaranteed to return exact zeros
        # there for larger inputs. Coincident atoms (distance 0) stay
        # excluded as before.
        not_self = ~torch.eye(coords.shape[0], dtype=torch.bool)
        connected = (dist_matrix < cutoff) & (dist_matrix > 0) & not_self
        src, dst = torch.nonzero(connected, as_tuple=True)

        # Edge attributes: raw distance followed by a 16-centre Gaussian RBF
        # expansion with centres spread over [0, 20] and unit width.
        distances = dist_matrix[src, dst].unsqueeze(1)
        offsets = torch.linspace(0, 20, 16)
        widths = torch.ones_like(offsets) * 1.0
        rbf_expanded = torch.exp(-0.5 * ((distances - offsets) / widths) ** 2)
        edge_attr = torch.cat([distances, rbf_expanded], dim=1)

        return Data(
            x=self.get_atom_features(),          # (n_atoms, 28)
            pos=coords,                          # (n_atoms, 3)
            edge_index=torch.stack([src, dst]),  # (2, n_edges)
            edge_attr=edge_attr,                 # (n_edges, 17)
            # Explicit dtype keeps this an integer tensor even when empty.
            residue_indices=torch.tensor(self._residue_indices(), dtype=torch.long),
        )

def make_sequential(lst):
    """Replace each value by its rank among the distinct values of ``lst``.

    The smallest distinct value becomes 0, the next one 1, and so on; equal
    inputs receive equal outputs and the input order is kept.
    """
    rank_of = {}
    for rank, value in enumerate(sorted(set(lst))):
        rank_of[value] = rank

    return [rank_of[value] for value in lst]


def process_pdb_file(pdb_path, include_hydrogens=False):
    """
    Turn a PDB file into its atom-level graph.

    Args:
        pdb_path: Path to PDB file
        include_hydrogens: Whether to include hydrogen atoms
    Returns:
        PyG Data object with atom graph (built with the default cutoff of
        ``PDBProcessor.build_graph``)
    """
    # extract_atoms() returns the processor itself, so the calls can be chained.
    return PDBProcessor(pdb_path, include_hydrogens).extract_atoms().build_graph()


# ==================== RESIDUE GRAPH CONSTRUCTION ====================

class ResidueGraphBuilder:
    """Builds multi-relational residue graphs from PDB structures with full RAAD features"""
    
    def __init__(self, k_nn=10, spatial_cutoff=8.0, sequential_cutoffs=[1, 2]):
        """Configure the residue-graph builder.

        Args:
            k_nn: Number of nearest neighbours used for the k-NN relation.
            spatial_cutoff: CA-CA distance threshold for the spatial relation.
            sequential_cutoffs: Sequence offsets for the sequential relations.
                NOTE(review): the value is stored, but
                ``build_multi_relational_graph`` currently hard-codes the
                offsets 1 and 2 and does not read this attribute.
        """
        self.k_nn = k_nn
        self.spatial_cutoff = spatial_cutoff
        # Copy the argument: storing the default list itself would make every
        # instance share (and be able to mutate) one list object.
        self.sequential_cutoffs = list(sequential_cutoffs)
        self.parser = PDBParser(QUIET=True)

        # RBF parameters: 16 Gaussian centres spread over [0, 20], unit width
        self.rbf_centers = torch.linspace(0, 20, 16)
        self.rbf_width = 1.0

    def extract_residues(self, pdb_path):
        """Collect the residues of a PDB file that carry a CA atom.

        Heterogens and water (non-blank hetero flag) as well as residues
        without a CA atom are left out. Every model of the structure is
        traversed, in model -> chain -> residue order.

        Returns:
            List of dicts with the keys ``resname``, ``chain_id``, ``res_id``,
            ``ca_coord`` and ``residue_obj`` (the Biopython residue itself).
        """
        structure = self.parser.get_structure("pdb", pdb_path)

        collected = []
        for model in structure:
            for chain in model:
                for residue in chain:
                    is_standard = residue.id[0] == ' '
                    if not (is_standard and 'CA' in residue):
                        continue

                    collected.append(dict(
                        resname=residue.resname,
                        chain_id=chain.id,
                        res_id=residue.id[1],
                        ca_coord=residue['CA'].get_coord(),
                        residue_obj=residue,
                    ))

        return collected
        
    def compute_rbf(self, distances):
        """Compute radial basis function encoding"""
        if isinstance(distances, (int, float)):
            distances = torch.tensor([distances], dtype=torch.float)
        elif not isinstance(distances, torch.Tensor):
            distances = torch.tensor(distances, dtype=torch.float)
        
        # Expand distances to match RBF centers
        if distances.dim() == 0:
            distances = distances.unsqueeze(0)
        
        rbf = torch.exp(-0.5 * ((distances.unsqueeze(-1) - self.rbf_centers) / self.rbf_width) ** 2)
        return rbf.squeeze(0) if rbf.shape[0] == 1 else rbf

    def compute_positional_encoding(self, position, max_len=1000):
        """Compute sinusoidal positional encoding"""
        pe = torch.zeros(16)
        position = torch.tensor([position], dtype=torch.float)
        
        div_term = torch.exp(torch.arange(0, 16, 2).float() * 
                           -(math.log(10000.0) / 16))
        
        pe[0::2] = torch.sin(position * div_term)
        pe[1::2] = torch.cos(position * div_term)
        return pe

    def compute_local_coordinate_frame(self, ca_coord, c_coord, n_coord):
        """Compute local coordinate frame Q_i from CA, C, N atoms"""
        ca = torch.tensor(ca_coord, dtype=torch.float)
        c = torch.tensor(c_coord, dtype=torch.float)
        n = torch.tensor(n_coord, dtype=torch.float)
        
        # Create coordinate frame
        v1 = c - ca  # CA -> C
        v2 = n - ca  # CA -> N
        
        # Gram-Schmidt orthogonalization
        u1 = F.normalize(v1, dim=0)
        u2_temp = v2 - torch.dot(v2, u1) * u1
        u2 = F.normalize(u2_temp, dim=0)
        u3 = torch.cross(u1, u2)
        
        Q = torch.stack([u1, u2, u3], dim=1)  # 3x3 matrix
        return Q

    def compute_dihedral_angles(self, residue, prev_residue=None, next_residue=None):
        """Compute 6 backbone angles: phi, psi, omega, alpha, beta, gamma"""
        angles = torch.zeros(6)
        
        try:
            # Current residue atoms
            n_curr = torch.tensor(residue['N'], dtype=torch.float)
            ca_curr = torch.tensor(residue['CA'], dtype=torch.float)
            c_curr = torch.tensor(residue['C'], dtype=torch.float)
            
            # Phi angle (previous C - current N - current CA - current C)
            if prev_residue is not None and 'C' in prev_residue:
                c_prev = torch.tensor(prev_residue['C'], dtype=torch.float)
                phi = self.compute_dihedral(c_prev, n_curr, ca_curr, c_curr)
                angles[0] = phi
            
            # Psi angle (current N - current CA - current C - next N)
            if next_residue is not None and 'N' in next_residue:
                n_next = torch.tensor(next_residue['N'], dtype=torch.float)
                psi = self.compute_dihedral(n_curr, ca_curr, c_curr, n_next)
                angles[1] = psi
            
            # Omega angle (previous CA - previous C - current N - current CA)
            if prev_residue is not None and 'CA' in prev_residue and 'C' in prev_residue:
                ca_prev = torch.tensor(prev_residue['CA'], dtype=torch.float)
                c_prev = torch.tensor(prev_residue['C'], dtype=torch.float)
                omega = self.compute_dihedral(ca_prev, c_prev, n_curr, ca_curr)
                angles[2] = omega
            
            # Bond angles alpha, beta, gamma
            if 'O' in residue:
                o_curr = torch.tensor(residue['O'], dtype=torch.float)
                # Alpha: N-CA-C angle
                alpha = self.compute_bond_angle(n_curr, ca_curr, c_curr)
                angles[3] = alpha
                
                # Beta: CA-C-O angle  
                beta = self.compute_bond_angle(ca_curr, c_curr, o_curr)
                angles[4] = beta
                
                # Gamma: C-CA-N angle
                gamma = self.compute_bond_angle(c_curr, ca_curr, n_curr)
                angles[5] = gamma
                
        except Exception as e:
            print(f"Warning: Could not compute all angles: {e}")
        
        return angles

    def compute_dihedral(self, p1, p2, p3, p4):
        """Compute dihedral angle between 4 points"""
        v1 = p2 - p1
        v2 = p3 - p2  
        v3 = p4 - p3
        
        n1 = torch.cross(v1, v2)
        n2 = torch.cross(v2, v3)
        
        n1 = F.normalize(n1, dim=0)
        n2 = F.normalize(n2, dim=0)
        
        cos_angle = torch.clamp(torch.dot(n1, n2), -1, 1)
        angle = torch.acos(cos_angle)
        
        # Check sign
        if torch.dot(torch.cross(n1, n2), F.normalize(v2, dim=0)) < 0:
            angle = -angle
            
        return angle

    def compute_bond_angle(self, p1, p2, p3):
        """Compute bond angle at p2"""
        v1 = p1 - p2
        v2 = p3 - p2
        
        cos_angle = torch.dot(F.normalize(v1, dim=0), F.normalize(v2, dim=0))
        cos_angle = torch.clamp(cos_angle, -1, 1)
        return torch.acos(cos_angle)

    def compute_residue_features(self, residues):
        """Compute 105-dimensional node features for each residue"""
        aa_types = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 
                    'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 
                    'THR', 'TRP', 'TYR', 'VAL']
        
        features = []
        for idx, res in enumerate(residues):
            feature_parts = []
            
            # 1. Residue type embedding (20-D)
            res_type = res['resname'] if res['resname'] in aa_types else 'ALA'  # fallback
            res_idx = aa_types.index(res_type)
            res_onehot = torch.zeros(20)
            res_onehot[res_idx] = 1
            feature_parts.append(res_onehot)
            
            # 2. Positional encoding (16-D)
            pos_enc = self.compute_positional_encoding(idx)
            feature_parts.append(pos_enc)
            
            # Get atomic coordinates
            residue_atoms = res['residue_obj']
            coord_dict = {}
            for atom in residue_atoms:
                if atom.element != 'H':  # Skip hydrogens
                    coord_dict[atom.name] = atom.coord
            
            # 3. Bond/dihedral angles (12-D: 6 angles × 2 for sin/cos)
            prev_res = residues[idx-1] if idx > 0 else None
            next_res = residues[idx+1] if idx < len(residues)-1 else None
            
            prev_coord = {}
            next_coord = {}
            if prev_res is not None:
                for atom in prev_res['residue_obj']:
                    if atom.element != 'H':
                        prev_coord[atom.name] = atom.coord
            if next_res is not None:
                for atom in next_res['residue_obj']:
                    if atom.element != 'H':
                        next_coord[atom.name] = atom.coord
            
            angles = self.compute_dihedral_angles(coord_dict, prev_coord, next_coord)
            angle_features = torch.stack([torch.sin(angles), torch.cos(angles)], dim=1).flatten()
            feature_parts.append(angle_features)
            
            # 4. RBF distances (48-D: 3 distances × 16 RBF each)
            if 'CA' in coord_dict:
                ca_pos = torch.tensor(coord_dict['CA'], dtype=torch.float)
                distance_features = []
                
                for atom_name in ['C', 'N', 'O']:
                    if atom_name in coord_dict:
                        atom_pos = torch.tensor(coord_dict[atom_name], dtype=torch.float)
                        dist = torch.norm(ca_pos - atom_pos)
                        rbf_dist = self.compute_rbf(dist)
                        distance_features.append(rbf_dist)
                    else:
                        # Use default distance if atom missing
                        default_dist = 1.5  # Typical bond length
                        rbf_dist = self.compute_rbf(default_dist)
                        distance_features.append(rbf_dist)
                
                feature_parts.append(torch.cat(distance_features))
            else:
                # Fallback if CA missing
                feature_parts.append(torch.zeros(48))
            
            # 5. Local coordinate frame directions (9-D)
            if all(atom in coord_dict for atom in ['CA', 'C', 'N']):
                Q = self.compute_local_coordinate_frame(
                    coord_dict['CA'], coord_dict['C'], coord_dict['N']
                )
                frame_features = Q.flatten()
                feature_parts.append(frame_features)
            else:
                # Identity matrix as fallback
                feature_parts.append(torch.eye(3).flatten())
            
            # Concatenate all features (should sum to 105)
            full_feature = torch.cat(feature_parts)
            assert full_feature.shape[0] == 105, f"Feature dimension mismatch: {full_feature.shape[0]} != 105"
            features.append(full_feature)
        
        return torch.stack(features)

    def compute_edge_features(self, i, j, residues, ca_coords):
        """Compute 100-dimensional edge features between residues i and j"""
        feature_parts = []
        
        # Get residue objects
        res_i = residues[i]
        res_j = residues[j]
        
        # Get atomic coordinates
        coord_i = {}
        coord_j = {}
        for atom in res_i['residue_obj']:
            if atom.element != 'H':
                coord_i[atom.name] = torch.tensor(atom.coord, dtype=torch.float)
        for atom in res_j['residue_obj']:
            if atom.element != 'H':
                coord_j[atom.name] = torch.tensor(atom.coord, dtype=torch.float)
        
        # 1. Edge type encoding (4-D) - will be set by calling function
        # For now, create placeholder - this should be set based on relation type
        edge_type_onehot = torch.zeros(4)
        feature_parts.append(edge_type_onehot)
        
        # 2. Relative positional encoding (16-D)
        rel_pos = j - i if res_i['chain_id'] == res_j['chain_id'] else 0
        rel_pos_enc = self.compute_positional_encoding(rel_pos)
        feature_parts.append(rel_pos_enc)
        
        # 3. RBF distances to 4 backbone atoms (64-D: 4 × 16)
        ca_i = ca_coords[i]
        distance_features = []
        
        for atom_name in ['CA', 'C', 'N', 'O']:
            if atom_name in coord_j:
                atom_j = coord_j[atom_name]
                dist = torch.norm(ca_i - atom_j)
                rbf_dist = self.compute_rbf(dist)
                distance_features.append(rbf_dist)
            else:
                # Use CA-CA distance as fallback
                dist = torch.norm(ca_i - ca_coords[j])
                rbf_dist = self.compute_rbf(dist)
                distance_features.append(rbf_dist)
        
        feature_parts.append(torch.cat(distance_features))
        
        # 4. Direction vectors in local frame (12-D: 4 directions × 3D)
        if all(atom in coord_i for atom in ['CA', 'C', 'N']):
            Q_i = self.compute_local_coordinate_frame(
                coord_i['CA'], coord_i['C'], coord_i['N']
            )
            
            direction_features = []
            for atom_name in ['CA', 'C', 'N', 'O']:
                if atom_name in coord_j:
                    atom_j = coord_j[atom_name]
                    direction = atom_j - ca_i
                    direction_norm = torch.norm(direction)
                    if direction_norm > 1e-6:
                        direction = direction / direction_norm
                    else:
                        direction = torch.zeros(3)
                    
                    # Transform to local coordinate frame
                    local_direction = Q_i.T @ direction
                    direction_features.append(local_direction)
                else:
                    direction_features.append(torch.zeros(3))
            
            feature_parts.append(torch.cat(direction_features))
        else:
            # Fallback to zero directions
            feature_parts.append(torch.zeros(12))
        
        # 5. Quaternion orientation (4-D)
        if (all(atom in coord_i for atom in ['CA', 'C', 'N']) and 
            all(atom in coord_j for atom in ['CA', 'C', 'N'])):
            
            Q_i = self.compute_local_coordinate_frame(
                coord_i['CA'], coord_i['C'], coord_i['N']
            )
            Q_j = self.compute_local_coordinate_frame(
                coord_j['CA'], coord_j['C'], coord_j['N']
            )
            
            # Compute relative rotation Q_i^T @ Q_j
            rel_rotation = Q_i.T @ Q_j
            
            # Convert to quaternion
            quaternion = self.rotation_matrix_to_quaternion(rel_rotation)
            feature_parts.append(quaternion)
        else:
            # Identity quaternion as fallback
            feature_parts.append(torch.tensor([1.0, 0.0, 0.0, 0.0]))
        
        # Concatenate all features (should sum to 100)
        full_feature = torch.cat(feature_parts)
        
        # Handle the edge type encoding separately since it depends on relation type
        return full_feature

    def rotation_matrix_to_quaternion(self, R):
        """Convert rotation matrix to quaternion"""
        # Ensure it's a valid rotation matrix
        R = R.detach().numpy()
        try:
            rot = R.from_matrix(R)
            quat = rot.as_quat()  # Returns [x, y, z, w]
            # Convert to [w, x, y, z] format
            quat = np.array([quat[3], quat[0], quat[1], quat[2]])
        except:
            # Fallback to identity quaternion
            quat = np.array([1.0, 0.0, 0.0, 0.0])
        
        return torch.tensor(quat, dtype=torch.float)

    def build_multi_relational_graph(self, residues):
        """Build multi-relational residue graph with complete features"""
        n_residues = len(residues)
        ca_coords = torch.tensor([r['ca_coord'] for r in residues], dtype=torch.float)
        
        # Initialize edge lists for different relations
        edge_lists = {r: [] for r in range(4)}  # r0, r1, r2, r3
        
        # Sequential edges (r0: seq±1, r1: seq±2)  
        for i in range(n_residues):
            for j in range(n_residues):
                if i == j:
                    continue
                
                # Check if same chain
                if residues[i]['chain_id'] == residues[j]['chain_id']:
                    seq_dist = abs(residues[i]['res_id'] - residues[j]['res_id'])
                    
                    if seq_dist == 1:
                        edge_lists[0].append([i, j])  # r0
                    elif seq_dist == 2:
                        edge_lists[1].append([i, j])  # r1
        
        # Spatial edges
        dist_matrix = torch.cdist(ca_coords, ca_coords)
        
        # r2: k-NN edges
        for i in range(n_residues):
            distances = dist_matrix[i]
            # Exclude self and get k nearest
            _, indices = torch.topk(-distances, k=min(self.k_nn + 1, n_residues))
            for j in indices[1:self.k_nn + 1]:  # Skip self
                if j != i:
                    edge_lists[2].append([i, j.item()])
        
        # r3: spatial proximity (within cutoff, not already connected)
        spatial_edges = torch.nonzero((dist_matrix < self.spatial_cutoff) & (dist_matrix > 0)).tolist()
        for i, j in spatial_edges:
            # Check if not already connected by other relations
            already_connected = any([i, j] in edge_lists[r] for r in range(3))
            if not already_connected:
                edge_lists[3].append([i, j])

        
        # Compute edge features with relation types
        edge_features = {}
        relation_names = ['seq±1', 'seq±2', 'k-NN', 'spatial']

        for rel_type, edges in edge_lists.items():
            for src, dst in edges:
                edge_key = (src, dst, rel_type)
                edge_feat = self.compute_edge_features(src, dst, residues, ca_coords)
                
                # Set the edge type encoding (first 4 dimensions)
                edge_type_onehot = torch.zeros(4)
                edge_type_onehot[rel_type] = 1
                edge_feat[:4] = edge_type_onehot
                
                edge_features[edge_key] = edge_feat
        
        return edge_lists, ca_coords, edge_features

    def build_residue_graph(self, pdb_path):
        """Assemble the residue-level PyG graph for one PDB file.

        Runs residue extraction, node featurisation and multi-relational edge
        construction, then flattens the per-relation edge lists (in relation
        order) into a single ``edge_index`` / ``edge_attr`` pair. The
        per-relation lists and the feature dict are attached to the result
        as well.
        """
        residues = self.extract_residues(pdb_path)
        node_features = self.compute_residue_features(residues)
        edge_lists, ca_coords, edge_features = self.build_multi_relational_graph(residues)

        # Flatten to (edge, relation) pairs, keeping relation order
        tagged_edges = [
            (edge, rel_type)
            for rel_type, edges in edge_lists.items()
            for edge in edges
        ]

        if tagged_edges:
            edge_index = torch.tensor([edge for edge, _ in tagged_edges]).T
            edge_attr = torch.stack([
                edge_features[(edge[0], edge[1], rel_type)]
                for edge, rel_type in tagged_edges
            ])
        else:
            # No edges at all: emit correctly shaped empty tensors
            edge_index = torch.zeros((2, 0), dtype=torch.long)
            edge_attr = torch.zeros((0, 100))

        return Data(
            x=node_features,             # (n_residues, 105)
            pos=ca_coords,               # (n_residues, 3)
            edge_index=edge_index,       # (2, n_edges)
            edge_attr=edge_attr,         # (n_edges, 100)
            edge_lists=edge_lists,       # Dict of edge lists by relation type
            edge_features=edge_features  # Original edge features dict
        )


# ==================== EDGE GRAPH (LINE GRAPH) CONSTRUCTION ====================

class LineGraphBuilder:
    """Constructs the line graph of a residue graph for edge message passing.

    Every residue-graph edge ``(src, dst, rel)`` becomes one line-graph node.
    Two line-graph nodes are connected when their residue edges share at least
    one residue.
    """

    def __init__(self, num_angle_bins=8):
        """Store ``num_angle_bins``, the multiplier used to space line-edge type ids."""
        self.num_angle_bins = num_angle_bins

    @staticmethod
    def _listless(item):
        """Return ``item`` with list containers replaced by tuples (one level deep).

        Non-list items are returned unchanged. Inside a list item, a nested
        list with more than one element becomes a tuple and a shorter nested
        list is replaced by its first element.
        """
        if not isinstance(item, list):
            return item
        return tuple(
            (tuple(part) if len(part) > 1 else part[0]) if isinstance(part, list) else part
            for part in item
        )

    def fix_edge_graph_data(self, edge_graph):
        """Make ``node_to_edge`` / ``edge_to_node`` on ``edge_graph`` hashable.

        Args:
            edge_graph: Object that may carry ``node_to_edge`` (iterable of
                edge descriptors) and ``edge_to_node`` (dict) attributes.

        Returns:
            The same ``edge_graph`` object, modified in place.
        """
        node_to_edge = getattr(edge_graph, 'node_to_edge', None)
        if node_to_edge is not None:
            edge_graph.node_to_edge = [self._listless(item) for item in node_to_edge]

        edge_to_node = getattr(edge_graph, 'edge_to_node', None)
        if isinstance(edge_to_node, dict):
            edge_graph.edge_to_node = {
                (tuple(key) if isinstance(key, list) else key): value
                for key, value in edge_to_node.items()
            }

        return edge_graph

    @staticmethod
    def _shared_residue_pairs(node_to_edge):
        """Return sorted ``(i, j)`` pairs, ``i < j``, of nodes sharing a residue.

        ``node_to_edge[k]`` is the ``(src, dst, rel)`` key of line-graph node
        ``k``. Nodes are grouped by the residues they touch and only nodes in
        the same group are paired, which avoids comparing every node with
        every other node. The result equals what an exhaustive ``i < j`` scan
        would produce, in the same order.
        """
        from collections import defaultdict
        from itertools import combinations

        incident = defaultdict(list)
        for node, (src, dst, _rel) in enumerate(node_to_edge):
            incident[src].append(node)
            if dst != src:  # a self-loop touches one residue only
                incident[dst].append(node)

        # Each per-residue list is ascending, so combinations() yields i < j.
        # The set removes pairs that share both endpoints.
        pairs = set()
        for nodes in incident.values():
            pairs.update(combinations(nodes, 2))
        return sorted(pairs)

    def build_line_graph(self, residue_edge_lists, ca_coords, edge_features):
        """Build the line graph whose nodes are the edges of the residue graph.

        Args:
            residue_edge_lists: Dict mapping relation type to a list of
                ``(src, dst)`` residue edges (tuples or lists).
            ca_coords: Residue coordinates. Currently unused because angle
                binning between edges is not implemented yet.
            edge_features: Dict mapping ``(src, dst, rel_type)`` to the edge
                feature tensor. Edges without an entry get ``torch.zeros(100)``.

        Returns:
            ``Data`` with
            ``x``           stacked edge features, one row per residue edge,
            ``edge_index``  ``(2, n_line_edges)`` pairs ``i < j`` of line-graph
                            nodes that share a residue,
            ``edge_attr``   ``(n_line_edges, 1)`` integer type id
                            ``rel_i * 4 * num_angle_bins + rel_j * num_angle_bins``
                            (no angle-bin offset is added yet),
            ``node_to_edge`` list of ``(src, dst, rel_type)`` keys, one per
                            line-graph node.
        """
        node_to_edge = []
        node_features = []
        for rel_type, edges in residue_edge_lists.items():
            for edge in edges:
                # Indexing works for both tuple and list edges; the key is a tuple.
                edge_key = (edge[0], edge[1], rel_type)
                node_to_edge.append(edge_key)

                feature = edge_features.get(edge_key)
                if feature is None:
                    feature = torch.zeros(100)
                node_features.append(feature)

        if not node_features:
            # No residue edges: return an empty but well-formed graph. edge_attr
            # is long so it matches the dtype of real line-edge type ids.
            return Data(
                x=torch.zeros((0, 100)),
                edge_index=torch.zeros((2, 0), dtype=torch.long),
                edge_attr=torch.zeros((0, 1), dtype=torch.long),
                node_to_edge=[]
            )

        line_edges = []
        line_edge_types = []
        for i, j in self._shared_residue_pairs(node_to_edge):
            rel1 = node_to_edge[i][2]
            rel2 = node_to_edge[j][2]
            line_edges.append([i, j])
            line_edge_types.append(
                rel1 * 4 * self.num_angle_bins + rel2 * self.num_angle_bins
            )

        if line_edges:
            line_edge_index = torch.tensor(line_edges).t()
            line_edge_attr = torch.tensor(line_edge_types).unsqueeze(1)
        else:
            line_edge_index = torch.empty((2, 0), dtype=torch.long)
            line_edge_attr = torch.empty((0, 1), dtype=torch.long)

        return Data(
            x=torch.stack(node_features),
            edge_index=line_edge_index,
            edge_attr=line_edge_attr,
            node_to_edge=node_to_edge  # list of tuples; build_hetero_data converts it to a tensor
        )




# ==================== HETERODATA CONSTRUCTION ====================


def _attach_node_features(hetero_data, protein_data, prefix, label_key):
    """Copy atom- and residue-level node tensors of one protein into ``hetero_data``."""
    hetero_data[f'{prefix}_atom'].x = protein_data['atom_graph'].x
    hetero_data[f'{prefix}_atom'].pos = protein_data['atom_graph'].pos
    hetero_data[f'{prefix}_res'].x = protein_data['residue_graph'].x
    hetero_data[f'{prefix}_res'].plm = protein_data['plm_embeddings']
    hetero_data[f'{prefix}_res'].y = protein_data[label_key]
    hetero_data[f'{prefix}_res'].pos = protein_data['coordinates']


def _attach_edge_graph_nodes(hetero_data, protein_data, prefix):
    """Store line-graph node features and their ``(src, dst, rel)`` mapping as tensors."""
    edge_graph = protein_data['edge_graph']
    if getattr(edge_graph, 'x', None) is None:
        return

    hetero_data[f'{prefix}_edge'].x = edge_graph.x

    node_to_edge = getattr(edge_graph, 'node_to_edge', None)
    if not node_to_edge:
        return
    try:
        # Only the tensor form is stored; Python-side mappings cause DataLoader issues.
        hetero_data[f'{prefix}_edge'].edge_mapping = torch.tensor(node_to_edge, dtype=torch.long)
    except (TypeError, ValueError, RuntimeError) as e:
        print(f"Warning: Could not convert {prefix}_edge node_to_edge to tensor: {e}")


def _attach_residue_edge_data(hetero_data, residue_graph, node_type):
    """Store per-relation edge lists and stacked edge features on a residue store."""
    if hasattr(residue_graph, 'edge_lists'):
        hetero_data[node_type].edge_lists = residue_graph.edge_lists

    if hasattr(residue_graph, 'edge_features'):
        edge_features = residue_graph.edge_features
        if edge_features:
            # Row i of edge_keys_tensor is the (src, dst, rel) key of row i of
            # edge_features_tensor.
            keys = [list(key) for key in edge_features]
            hetero_data[node_type].edge_features_tensor = torch.stack(list(edge_features.values()))
            hetero_data[node_type].edge_keys_tensor = torch.tensor(keys, dtype=torch.long)


def build_hetero_data(complex_data):
    """Create a ``HeteroData`` object for one antibody-antigen complex.

    Args:
        complex_data: Dict with keys
            ``'complex_id'``,
            ``'antigen'`` / ``'antibody'`` (each a dict with ``'atom_graph'``,
            ``'residue_graph'``, ``'edge_graph'``, ``'plm_embeddings'``,
            ``'coordinates'`` and ``'epitope_labels'`` / ``'paratope_labels'``),
            ``'interactions'`` (dict that may hold ``'interaction_pairs'``).

    Returns:
        The populated ``HeteroData``. Node stores with zero nodes are removed
        and every expected edge type is present (possibly empty).

    Note:
        ``interaction_pairs`` is the AsEP ``edge_index_bg`` tensor, whose first
        row holds antibody indices and second row antigen indices. It is
        stored under ``('ag_res', 'interacts', 'ab_res')``, so the rows are
        swapped here to put the source (antigen) indices in row 0.
    """
    hetero_data = HeteroData()

    ag_data = complex_data['antigen']
    ab_data = complex_data['antibody']
    interaction_data = complex_data['interactions']
    proteins = (('ag', ag_data), ('ab', ab_data))

    # Node stores (creation order: ag_atom, ag_res, ab_atom, ab_res, ag_edge, ab_edge)
    _attach_node_features(hetero_data, ag_data, 'ag', 'epitope_labels')
    _attach_node_features(hetero_data, ab_data, 'ab', 'paratope_labels')
    for prefix, protein_data in proteins:
        _attach_edge_graph_nodes(hetero_data, protein_data, prefix)

    # Edge lists / edge features on residue nodes for EdgeMP processing
    for prefix, protein_data in proteins:
        _attach_residue_edge_data(hetero_data, protein_data['residue_graph'], f'{prefix}_res')

    # Atom bonds with attributes
    for prefix, protein_data in proteins:
        hetero_data = add_edges_with_attrs(
            hetero_data,
            protein_data['atom_graph'],
            (f'{prefix}_atom', 'atom_bond', f'{prefix}_atom'),
            ['edge_index', 'edge_attr']
        )

    # Residue relations with attributes
    for rel_type in [0, 1, 2, 3]:
        rel_name = f"r{rel_type}"
        for prefix, protein_data in proteins:
            hetero_data = add_residue_relation(
                hetero_data,
                protein_data['residue_graph'],
                f'{prefix}_res',
                rel_type,
                rel_name
            )

    # Line-graph (edge graph) connections
    for prefix, protein_data in proteins:
        if hasattr(protein_data['edge_graph'], 'edge_index'):
            hetero_data = add_edges_with_attrs(
                hetero_data,
                protein_data['edge_graph'],
                (f'{prefix}_edge', 'edge_connect', f'{prefix}_edge'),
                ['edge_index', 'edge_attr']
            )

    # Cross connections between hierarchy levels
    for prefix, protein_data in proteins:
        hetero_data = add_cross_connections(hetero_data, protein_data, prefix)

    # Interaction edges: swap rows from (antibody, antigen) to (antigen, antibody)
    # so row 0 indexes the source node type of the edge type.
    if 'interaction_pairs' in interaction_data and interaction_data['interaction_pairs'].numel() > 0:
        hetero_data['ag_res', 'interacts', 'ab_res'].edge_index = (
            interaction_data['interaction_pairs'].flip(0)
        )

    # Metadata
    hetero_data.complex_id = complex_data['complex_id']

    # Drop empty node stores, then make sure every expected edge type exists
    clean_hetero_data(hetero_data)
    initialize_missing_components(hetero_data)

    return hetero_data


def add_cross_connections(hetero_data, protein_data, prefix):
    """Link the atom, residue and line-graph levels of one protein.

    Adds ``(prefix_atom, 'belongs_to', prefix_res)`` edges when the atom graph
    exposes ``residue_indices``, and ``(prefix_res, 'connected_by',
    prefix_edge)`` edges when the ``prefix_edge`` store carries a non-empty
    ``edge_mapping`` tensor (columns ``[src, dst, rel]``).

    Args:
        hetero_data: Container to extend (modified in place).
        protein_data: Dict holding at least ``'atom_graph'``.
        prefix: Node-type prefix, e.g. ``'ag'`` or ``'ab'``.

    Returns:
        ``hetero_data``.
    """
    atom_type = f'{prefix}_atom'
    res_type = f'{prefix}_res'
    line_type = f'{prefix}_edge'

    # Atom -> residue membership
    atom_graph = protein_data['atom_graph']
    if hasattr(atom_graph, 'residue_indices'):
        owners = atom_graph.residue_indices
        atom_ids = torch.arange(len(owners))
        hetero_data[atom_type, 'belongs_to', res_type].edge_index = torch.stack([atom_ids, owners])

    # Residue -> line-graph node incidence (tensor-based mapping only)
    has_mapping = (
        line_type in hetero_data.node_types
        and hasattr(hetero_data[line_type], 'edge_mapping')
    )
    if not has_mapping:
        return hetero_data

    mapping = hetero_data[line_type].edge_mapping
    if mapping.numel() == 0:
        return hetero_data

    # Every line-graph node is linked to its source residue and to its
    # destination residue: first all source links, then all destination links.
    line_nodes = torch.arange(mapping.size(0))
    residues = torch.cat([mapping[:, 0], mapping[:, 1]])
    res_to_edge = torch.stack([residues, torch.cat([line_nodes, line_nodes])])
    hetero_data[res_type, 'connected_by', line_type].edge_index = res_to_edge

    return hetero_data




def add_edges_with_attrs(hetero_data, graph_data, edge_type, attr_names):
    """Copy selected attributes of ``graph_data`` onto ``hetero_data[edge_type]``.

    Attributes that are missing on ``graph_data`` or set to ``None`` are
    skipped. Returns ``hetero_data``.
    """
    for name in attr_names:
        value = getattr(graph_data, name, None)
        if value is None:
            continue
        hetero_data[edge_type][name] = value
    return hetero_data

def add_residue_relation(hetero_data, residue_graph, prefix, rel_type, rel_name):
    """Add the edges of one residue relation, with attributes, to ``hetero_data``.

    Args:
        hetero_data: Container to extend (modified in place).
        residue_graph: Object exposing ``edge_lists`` (dict: relation type ->
            list of ``(src, dst)`` edges) and ``edge_features`` (dict:
            ``(src, dst, rel_type)`` -> feature tensor).
        prefix: Node type on both ends of the relation, e.g. ``'ag_res'``.
        rel_type: Key into ``edge_lists``.
        rel_name: Relation name used in the edge type, e.g. ``'r0'``.

    Returns:
        ``hetero_data``. Nothing is added when the relation has no edges.
        ``edge_attr`` is set only when at least one edge has a feature. In
        that case edges without a feature get a zero row, so ``edge_attr``
        always has exactly one row per column of ``edge_index``.
    """
    edges = residue_graph.edge_lists.get(rel_type)
    if not edges:
        return hetero_data

    relation = hetero_data[prefix, rel_name, prefix]
    relation.edge_index = torch.tensor(edges).t().contiguous()

    features = [
        residue_graph.edge_features.get((edge[0], edge[1], rel_type))
        for edge in edges
    ]
    template = next((feat for feat in features if feat is not None), None)
    if template is not None:
        # Silently dropping missing rows would misalign edge_attr with edge_index.
        relation.edge_attr = torch.stack([
            torch.zeros_like(template) if feat is None else feat
            for feat in features
        ])

    return hetero_data



def clean_hetero_data(hetero_data):
    """Drop node stores with zero nodes, together with every edge type that names them.

    Works in place on ``hetero_data``; nothing is returned.
    """
    for node_type in list(hetero_data.node_types):
        if hetero_data[node_type].num_nodes != 0:
            continue

        # Collect the edge types that mention this node type, then remove them
        # before removing the node store itself.
        doomed = [edge_type for edge_type in list(hetero_data.edge_types) if node_type in edge_type]
        for edge_type in doomed:
            del hetero_data[edge_type]
        del hetero_data[node_type]


def initialize_missing_components(hetero_data):
    """Ensure every expected edge type exists on ``hetero_data``.

    Any expected edge type that is absent gets an empty ``(2, 0)`` long
    ``edge_index``. Residue relations ``r0``..``r3`` additionally get an empty
    ``(0, 100)`` float ``edge_attr``. ``edge_connect`` gets an empty ``(0, 1)``
    long ``edge_attr``, matching the integer type ids stored for real
    line-graph edges so that samples share one dtype.

    Args:
        hetero_data: Container to complete (modified in place).

    Returns:
        ``hetero_data``.
    """
    proteins = ('ag', 'ab')
    residue_relations = ('r0', 'r1', 'r2', 'r3')

    # Order matters: it determines the order of newly created edge stores.
    expected_edge_types = [
        (f'{p}_res', rel, f'{p}_res') for p in proteins for rel in residue_relations
    ]
    expected_edge_types += [(f'{p}_edge', 'edge_connect', f'{p}_edge') for p in proteins]
    expected_edge_types += [(f'{p}_atom', 'belongs_to', f'{p}_res') for p in proteins]
    expected_edge_types += [(f'{p}_res', 'connected_by', f'{p}_edge') for p in proteins]
    expected_edge_types.append(('ag_res', 'interacts', 'ab_res'))

    for edge_type in expected_edge_types:
        if edge_type in hetero_data.edge_types:
            continue

        store = hetero_data[edge_type]
        store.edge_index = torch.empty((2, 0), dtype=torch.long)

        relation = edge_type[1]
        if relation in residue_relations:
            store.edge_attr = torch.empty((0, 100), dtype=torch.float)
        elif relation == 'edge_connect':
            store.edge_attr = torch.empty((0, 1), dtype=torch.long)

    return hetero_data



# ==================== DATASET CREATION ====================

class EpiformerDatasetCreator:
    """Creates epiformer dataset combining AsEP data with PDB structures"""

    # Complex ids that create_dataset always skips.
    EXCLUDED_COMPLEX_IDS = frozenset({"5nj6_0P", "5ies_0P"})

    # A checkpoint is written after this many newly processed complexes.
    SAVE_EVERY = 10

    def __init__(self,
                 asep_data_path: str,
                 ag_pdb_dir: str,
                 ab_pdb_dir: str,
                 output_path: str,
                 antiberty_path: str):
        """
        Initialize the dataset creator.

        Args:
            asep_data_path: Path to AsEP preprocessed data pickle file
            ag_pdb_dir: Directory containing antigen PDB files (surf)
            ab_pdb_dir: Directory containing antibody PDB files (cdr)
            output_path: Path to save the processed dataset
            antiberty_path: Path to a ``torch.save``-d mapping from complex id
                to antibody embeddings (loaded in ``create_dataset``)
        """
        self.asep_data_path = asep_data_path
        self.ag_pdb_dir = ag_pdb_dir
        self.ab_pdb_dir = ab_pdb_dir
        self.output_path = output_path
        self.antiberty_path = antiberty_path

        # Initialize graph builders
        self.atom_processor = PDBProcessor  # class; instantiated per PDB file
        self.residue_builder = ResidueGraphBuilder()
        self.edge_builder = LineGraphBuilder()  # Add edge graph builder

        # AA mapping for sequence processing
        self.AA_MAP = {
            "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F", "GLY": "G",
            "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N",
            "PRO": "P", "GLN": "Q", "ARG": "R", "SER": "S", "THR": "T", "VAL": "V",
            "TRP": "W", "TYR": "Y"
        }

    @staticmethod
    def _as_float_tensor(values):
        """Return an independent float32 CPU copy of ``values`` as a tensor."""
        if isinstance(values, torch.Tensor):
            # Avoids the copy-construct warning of torch.tensor(tensor).
            return values.detach().cpu().clone().to(torch.float)
        return torch.tensor(values, dtype=torch.float)

    def _save_dataset(self, processed_data):
        """Write ``processed_data`` to ``output_path`` via a temp file and rename.

        Writing to a sibling file first means an interrupted save cannot leave
        a truncated dataset at ``output_path`` (which resume relies on).
        """
        tmp_path = f"{self.output_path}.tmp"
        torch.save(processed_data, tmp_path)
        os.replace(tmp_path, self.output_path)

    def _build_protein_graphs(self, pdb_path):
        """Build the atom graph, residue graph and line (edge) graph of one PDB file."""
        atom_processor = self.atom_processor(pdb_path)
        atom_processor.extract_atoms()
        atom_graph = atom_processor.build_graph()

        residue_graph = self.residue_builder.build_residue_graph(pdb_path)

        # Line graph for EdgeMP
        edge_graph = self.edge_builder.build_line_graph(
            residue_graph.edge_lists,
            residue_graph.pos,
            residue_graph.edge_features
        )
        return atom_graph, residue_graph, edge_graph

    def load_asep_data(self):
        """Load and return the AsEP dataset stored at ``asep_data_path``."""
        print(f"Loading AsEP data from {self.asep_data_path}")

        # NOTE: torch.load unpickles the file; only load trusted data.
        asep_graphs = torch.load(self.asep_data_path)

        return asep_graphs

    def extract_plm_embeddings(self, asep_graphs, pdb_id, antiberty_embeddings):
        """Extract PLM embeddings for one complex.

        The antigen embedding is ``asep_graphs[pdb_id]["x_g"]``. The antibody
        embedding is ``antiberty_embeddings[pdb_id]`` when present, otherwise
        ``asep_graphs[pdb_id]["x_b"]``.

        Returns:
            ``{'antigen': Tensor, 'antibody': Tensor}`` (float32), or ``None``
            when ``pdb_id`` is not in ``asep_graphs``.
        """
        if pdb_id not in asep_graphs:
            print(f"Warning: PLM embeddings not found for {pdb_id}")
            return None

        ag_plm = asep_graphs[pdb_id]["x_g"]  # Shape: (n_ag_residues, esm_dim)
        if pdb_id in antiberty_embeddings:
            ab_plm = antiberty_embeddings[pdb_id]
        else:
            # Fallback to the antibody embeddings shipped with AsEP
            ab_plm = asep_graphs[pdb_id]["x_b"]

        return {
            'antigen': self._as_float_tensor(ag_plm),
            'antibody': self._as_float_tensor(ab_plm)
        }

    def extract_labels_and_interactions(self, asep_graphs, pdb_id):
        """Extract epitope/paratope labels and interaction data for one complex.

        Returns:
            Dict with ``'epitope_labels'``, ``'paratope_labels'``,
            ``'bipartite_matrix'`` (dense ``(n_ag, n_ab)`` 0/1 matrix) and
            ``'interaction_pairs'`` (the raw ``edge_index_bg``: row 0 holds
            antibody indices, row 1 antigen indices; an empty ``(2, 0)`` long
            tensor when absent), or ``None`` when ``pdb_id`` is not in
            ``asep_graphs``.
        """
        if pdb_id not in asep_graphs:
            print(f"Warning: Labels not found for {pdb_id}")
            return None

        data = asep_graphs[pdb_id]
        n_ag = data["x_g"].shape[0]
        n_ab = data["x_b"].shape[0]

        # Labels default to all-zero when the sample carries none
        epitope_labels = data.get("y_g", torch.zeros(n_ag))
        paratope_labels = data.get("y_b", torch.zeros(n_ab))

        bipartite_matrix = torch.zeros(n_ag, n_ab)
        if "edge_index_bg" in data:
            interaction_pairs = data["edge_index_bg"]
            if interaction_pairs.shape[1] > 0:
                ab_indices = interaction_pairs[0]  # First row: antibody indices
                ag_indices = interaction_pairs[1]  # Second row: antigen indices
                bipartite_matrix[ag_indices, ab_indices] = 1
        else:
            # Index tensors must be integer typed even when empty
            interaction_pairs = torch.zeros((2, 0), dtype=torch.long)

        return {
            'epitope_labels': epitope_labels,
            'paratope_labels': paratope_labels,
            'bipartite_matrix': bipartite_matrix,
            'interaction_pairs': interaction_pairs
        }

    def process_complex(self, pdb_id: str, asep_graphs: dict, antiberty_embeddings):
        """Process a single antibody-antigen complex into HeteroData.

        Returns:
            The ``HeteroData`` built by ``build_hetero_data``, or ``None`` when
            a PDB file is missing, the complex is absent from ``asep_graphs``,
            or any processing step raises.
        """
        print(f"Processing complex: {pdb_id}")

        # File paths
        ag_pdb_path = os.path.join(self.ag_pdb_dir, f"{pdb_id}_surf.pdb")
        ab_pdb_path = os.path.join(self.ab_pdb_dir, f"{pdb_id}_cdr.pdb")

        # Check if files exist
        if not os.path.exists(ag_pdb_path):
            print(f"Warning: Antigen PDB not found: {ag_pdb_path}")
            return None
        if not os.path.exists(ab_pdb_path):
            print(f"Warning: Antibody PDB not found: {ab_pdb_path}")
            return None

        try:
            # Cheap dictionary lookups first, so a complex without AsEP data
            # is rejected before the expensive graph construction.
            plm_embeddings = self.extract_plm_embeddings(asep_graphs, pdb_id, antiberty_embeddings)
            if plm_embeddings is None:
                return None

            labels_interactions = self.extract_labels_and_interactions(asep_graphs, pdb_id)
            if labels_interactions is None:
                return None

            ag_atom_graph, ag_residue_graph, ag_edge_graph = self._build_protein_graphs(ag_pdb_path)
            ab_atom_graph, ab_residue_graph, ab_edge_graph = self._build_protein_graphs(ab_pdb_path)

            # Create comprehensive data structure
            complex_data = {
                'complex_id': pdb_id,
                'antigen': {
                    'atom_graph': ag_atom_graph,
                    'residue_graph': ag_residue_graph,
                    'edge_graph': ag_edge_graph,
                    'plm_embeddings': plm_embeddings['antigen'],
                    'epitope_labels': labels_interactions['epitope_labels'],
                    'coordinates': ag_residue_graph.pos
                },
                'antibody': {
                    'atom_graph': ab_atom_graph,
                    'residue_graph': ab_residue_graph,
                    'edge_graph': ab_edge_graph,
                    'plm_embeddings': plm_embeddings['antibody'],
                    'paratope_labels': labels_interactions['paratope_labels'],
                    'coordinates': ab_residue_graph.pos
                },
                'interactions': {
                    'bipartite_matrix': labels_interactions['bipartite_matrix'],
                    'interaction_pairs': labels_interactions['interaction_pairs']
                }
            }

            # Build HeteroData object
            return build_hetero_data(complex_data)

        except Exception as e:
            # Per-complex boundary: one bad structure must not abort the whole run.
            print(f"Error processing {pdb_id}: {e}")
            return None

    def create_dataset(self, max_examples: Optional[int] = None):
        """Create the complete epiformer dataset as HeteroData objects, with resume.

        If ``output_path`` already holds a dataset, its complexes are kept and
        only the remaining ids are processed. The dataset is checkpointed
        after every ``SAVE_EVERY`` newly processed complexes and once more at
        the end.

        Args:
            max_examples: If truthy, process at most this many remaining
                complexes.

        Returns:
            The list of ``HeteroData`` objects (existing plus newly processed).
        """
        print("Creating epiformer dataset...")

        # Load AsEP data
        asep_graphs = self.load_asep_data()
        antiberty_embeddings = torch.load(self.antiberty_path)

        # Get list of PDB IDs; filtering (unlike list.remove) tolerates datasets
        # that do not contain the excluded ids.
        pdb_ids = [
            pid for pid in asep_graphs.keys()
            if pid not in self.EXCLUDED_COMPLEX_IDS
        ]

        # Check for existing dataset
        processed_data = []
        existing_ids = set()
        if os.path.exists(self.output_path):
            try:
                # NOTE(review): if this load fails (for example under a
                # torch.load default that refuses pickled objects), the run
                # starts fresh and later overwrites the existing file.
                processed_data = torch.load(self.output_path)
                existing_ids = {data.complex_id for data in processed_data}
                print(f"Found existing dataset with {len(processed_data)} complexes. Resuming...")
            except Exception as e:
                print(f"Error loading existing dataset: {e}. Starting fresh.")
                processed_data = []

        # Filter out already processed complexes
        remaining_ids = [pid for pid in pdb_ids if pid not in existing_ids]
        if max_examples:
            remaining_ids = remaining_ids[:max_examples]
        print(f"Total complexes to process: {len(remaining_ids)}")

        successful_count = 0
        total_count = len(remaining_ids)

        for i, pdb_id in enumerate(remaining_ids):
            print(f"\nProcessing {i+1}/{total_count}: {pdb_id}")

            hetero_data = self.process_complex(pdb_id, asep_graphs, antiberty_embeddings)
            if hetero_data is None:
                continue

            processed_data.append(hetero_data)
            successful_count += 1

            # Save periodically
            if successful_count % self.SAVE_EVERY == 0:
                print(f"Temporary save after {successful_count} complexes")
                self._save_dataset(processed_data)

        print(f"\nDataset creation complete!")
        print(f"Successfully processed: {successful_count}/{total_count} complexes")

        # Final save
        print(f"Saving dataset to {self.output_path}")
        self._save_dataset(processed_data)

        return processed_data

    def print_statistics(self, hetero_data):
        """Print node and interaction counts for a processed complex"""
        print(f"Complex ID: {hetero_data.complex_id}")
        print(f"  Antigen atoms: {hetero_data['ag_atom'].x.shape[0]}")
        print(f"  Antigen residues: {hetero_data['ag_res'].x.shape[0]}")
        print(f"  Antigen edge nodes: {hetero_data['ag_edge'].x.shape[0]}")
        print(f"  Antibody atoms: {hetero_data['ab_atom'].x.shape[0]}")
        print(f"  Antibody residues: {hetero_data['ab_res'].x.shape[0]}")
        print(f"  Antibody edge nodes: {hetero_data['ab_edge'].x.shape[0]}")
        print(f"  Interaction edges: {hetero_data['ag_res', 'interacts', 'ab_res'].edge_index.shape[1]}")


# ==================== UTILITY FUNCTIONS ====================

def load_epiformer_dataset(dataset_path: str):
    """Read the processed epiformer dataset stored at ``dataset_path`` and return it."""
    dataset = torch.load(dataset_path)
    return dataset

def _summarize_counts(counts):
    """Return mean / std / (min, max) of ``counts``; NaN and ``(None, None)`` when empty."""
    if not counts:
        return {'mean': float('nan'), 'std': float('nan'), 'range': (None, None)}
    return {
        'mean': np.mean(counts),
        'std': np.std(counts),
        'range': (min(counts), max(counts)),
    }


def get_dataset_statistics(dataset):
    """Get comprehensive statistics about the dataset.

    Args:
        dataset: Sized iterable of samples. Each sample is indexed by node
            type (``'ag_atom'``, ``'ag_res'``, ``'ag_edge'``, ``'ab_atom'``,
            ``'ab_res'``, ``'ab_edge'``), reading ``.x.shape[0]``, and by the
            edge type ``('ag_res', 'interacts', 'ab_res')``, reading
            ``.edge_index.shape[1]``.

    Returns:
        Dict with ``'n_complexes'`` followed by one
        ``{'mean', 'std', 'range'}`` summary per category. For an empty
        dataset the means and stds are NaN and every range is ``(None, None)``.
    """
    node_categories = (
        ('antigen_atoms', 'ag_atom'),
        ('antigen_residues', 'ag_res'),
        ('antigen_edges', 'ag_edge'),
        ('antibody_atoms', 'ab_atom'),
        ('antibody_residues', 'ab_res'),
        ('antibody_edges', 'ab_edge'),
    )
    interaction_type = ('ag_res', 'interacts', 'ab_res')

    # Insertion order here is the key order of the returned dict.
    counts = {category: [] for category, _ in node_categories}
    counts['interactions'] = []

    for hetero_data in dataset:
        for category, node_type in node_categories:
            counts[category].append(hetero_data[node_type].x.shape[0])
        counts['interactions'].append(hetero_data[interaction_type].edge_index.shape[1])

    stats = {'n_complexes': len(dataset)}
    for category, values in counts.items():
        stats[category] = _summarize_counts(values)

    return stats

def print_dataset_statistics(stats):
    """Pretty-print the statistics dict: a banner, the complex count, then one
    mean/std/range section per remaining category."""
    rule = "=" * 60
    print("\n" + rule)
    print("HIERARCHICAL DATASET STATISTICS")
    print(rule)
    print(f"Total complexes: {stats['n_complexes']}")
    print()

    sections = (
        (name, summary)
        for name, summary in stats.items()
        if name != 'n_complexes'
    )
    for name, summary in sections:
        heading = name.replace('_', ' ').title()
        print(f"{heading}:")
        mean, spread = summary['mean'], summary['std']
        print(f"  Mean: {mean:.1f} ± {spread:.1f}")
        bounds = summary['range']
        print(f"  Range: {bounds[0]} - {bounds[1]}")
        print()


# ==================== TESTING AND VALIDATION ====================

def validate_dataset_integrity(dataset):
    """Check each complex in ``dataset`` for internal consistency.

    Checks performed per complex:
      * antigen and antibody residue counts match the row counts of the
        corresponding ``plm`` embeddings and ``y`` labels;
      * antigen atom features are 28 wide and antigen residue features
        are 105 wide;
      * antigen / antibody edge-graph node features are 100 wide whenever
        the edge graph has at least one node.

    Antibody atom and residue feature widths are not checked. Any
    exception raised while inspecting a complex (for example a missing
    store) is recorded as an issue for that complex rather than
    propagated; issues found before the exception are kept.

    Args:
        dataset: Iterable of per-complex objects indexable by store name
            (``'ag_atom'``, ``'ag_res'``, ``'ag_edge'``, ``'ab_atom'``,
            ``'ab_res'``, ``'ab_edge'``). A ``complex_id`` attribute is
            used for messages when present, otherwise ``complex_<index>``.

    Returns:
        List of issue messages (empty when everything passed). A summary
        showing at most the first 10 issues is printed to stdout.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("VALIDATING DATASET INTEGRITY")
    print(rule)

    issues = []
    report = issues.append

    def inspect(sample, tag):
        """Append every issue found in one complex; may raise on malformed data."""
        # --- Antigen row counts ---
        # The atom store is read only so that a missing store surfaces as an error.
        _ = sample['ag_atom'].x.shape[0]
        n_ag_res = sample['ag_res'].x.shape[0]
        n_ag_edge_nodes = sample['ag_edge'].x.shape[0]
        n_ag_plm = sample['ag_res'].plm.shape[0]
        n_ag_labels = sample['ag_res'].y.shape[0]

        if n_ag_res != n_ag_plm:
            report(f"{tag}: Antigen residue count mismatch (graph: {n_ag_res}, PLM: {n_ag_plm})")
        if n_ag_res != n_ag_labels:
            report(f"{tag}: Antigen epitope label count mismatch (residues: {n_ag_res}, labels: {n_ag_labels})")

        # --- Antibody row counts ---
        _ = sample['ab_atom'].x.shape[0]
        n_ab_res = sample['ab_res'].x.shape[0]
        n_ab_edge_nodes = sample['ab_edge'].x.shape[0]
        n_ab_plm = sample['ab_res'].plm.shape[0]
        n_ab_labels = sample['ab_res'].y.shape[0]

        if n_ab_res != n_ab_plm:
            report(f"{tag}: Antibody residue count mismatch (graph: {n_ab_res}, PLM: {n_ab_plm})")
        if n_ab_res != n_ab_labels:
            report(f"{tag}: Antibody paratope label count mismatch (residues: {n_ab_res}, labels: {n_ab_labels})")

        # --- Antigen feature widths ---
        ag_atom_width = sample['ag_atom'].x.shape[1]
        ag_res_width = sample['ag_res'].x.shape[1]
        if ag_atom_width != 28:  # 7 elements + 21 residues
            report(f"{tag}: Antigen atom feature dimension mismatch (expected: 28, got: {ag_atom_width})")
        if ag_res_width != 105:
            report(f"{tag}: Antigen residue feature dimension mismatch (expected: 105, got: {ag_res_width})")

        # --- Edge-graph node feature widths (skipped for empty edge graphs) ---
        if n_ag_edge_nodes > 0:
            ag_edge_width = sample['ag_edge'].x.shape[1]
            if ag_edge_width != 100:
                report(f"{tag}: Antigen edge node feature dimension mismatch (expected: 100, got: {ag_edge_width})")

        if n_ab_edge_nodes > 0:
            ab_edge_width = sample['ab_edge'].x.shape[1]
            if ab_edge_width != 100:
                report(f"{tag}: Antibody edge node feature dimension mismatch (expected: 100, got: {ab_edge_width})")

    for position, sample in enumerate(dataset):
        tag = getattr(sample, 'complex_id', f'complex_{position}')
        try:
            inspect(sample, tag)
        except Exception as exc:
            # Deliberately broad: a malformed complex must not abort validation.
            report(f"{tag}: Error during validation - {str(exc)}")

    if not issues:
        print("✅ Dataset integrity validation passed!")
        return issues

    print(f"Found {len(issues)} integrity issues:")
    for message in issues[:10]:  # Only the first 10 are listed
        print(f"  ❌ {message}")
    hidden = len(issues) - 10
    if hidden > 0:
        print(f"  ... and {hidden} more issues")

    return issues


# ==================== MAIN EXECUTION ====================

if __name__ == "__main__":
    import os
    import sys

    # Extend the import search path. Entries are resolved against the current
    # working directory, so the script depends on where it is launched from.
    for extra_path in ('m3epi', '../../', '../../../walle'):
        sys.path.append(os.path.join(os.getcwd(), extra_path))

    # Configuration - update these paths to match the local setup.
    proj_dir = os.path.join(os.getcwd(), '../../../../')
    asep_data_path = os.path.join(proj_dir, "data/asep/processed/dict_pre_cal_esm2_esm2.pt")
    ag_pdb_dir = os.path.join(proj_dir, "data/asep/antigen/atmseq2surf")
    ab_pdb_dir = os.path.join(proj_dir, "data/asep/antibody/atmseq2cdr")
    output_path = os.path.join(proj_dir, "data/asep/m3epi/epiformer_dataset_test.pkl")
    antiberty_path = os.path.join(proj_dir, "data/asep/antibody/antiberty_embeddings/asep_antiberty_embeddings.pt")

    # The output directory may not exist yet.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    dataset_creator = EpiformerDatasetCreator(
        asep_data_path=asep_data_path,
        ag_pdb_dir=ag_pdb_dir,
        ab_pdb_dir=ab_pdb_dir,
        output_path=output_path,
        antiberty_path=antiberty_path,
    )

    # Build a small dataset (capped at 15 examples) as a smoke test.
    print("Creating epiformer dataset...")
    dataset = dataset_creator.create_dataset(max_examples=15)

    # Summarise the freshly built dataset.
    print("\nAnalyzing dataset...")
    stats = get_dataset_statistics(dataset)
    print_dataset_statistics(stats)

    # Round-trip check: reload what was written to output_path.
    print("\nTesting dataset loading...")
    loaded_dataset = load_epiformer_dataset(output_path)
    print(f"Successfully loaded {len(loaded_dataset)} complexes")

    # Show the layout of the first complex, if any were loaded.
    if loaded_dataset:
        example = loaded_dataset[0]
        print("\nExample complex structure:")
        print(f"Complex ID: {example.complex_id}")
        print(f"Node types: {example.node_types}")
        print(f"Edge types: {example.edge_types}")
        print(f"Antigen residue features: {example['ag_res'].x.shape}")
        print(f"Antibody residue features: {example['ab_res'].x.shape}")

    # Integrity checks run on the reloaded dataset.
    section_rule = "=" * 60
    print("\n" + section_rule)
    print("RUNNING TESTS")
    print(section_rule)

    print("\nValidating dataset...")
    issues = validate_dataset_integrity(loaded_dataset)

    print("DATASET CREATION COMPLETE!")

    print(f"\nDataset saved to: {output_path}")
    print(f"Total complexes: {len(loaded_dataset)}")






"""
NOTE:
- Two complexes are excluded from the AsEP data:
  1. 5nj6_0P.pdb
  2. 5ies_0P.pdb, because its seqres2cdr_seq and atmseq2cdr have different lengths

"""

"""
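Example usage:

# Run in the foreground
python data/graph_construction.py

# Run in the background, writing stdout and stderr to graph_construct_output.log
nohup python data/graph_construction.py \
    > graph_construct_output.log 2>&1 &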

"""







"""

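# Disabled helper, kept for reference: warns about empty residue edge lists
# and replaces each one with self-loop edges.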
def validate_heterodata(data):
    for chain in ['ag', 'ab']:
        res_key = f'{chain}_res'
        if res_key in data:
            for rel, edges in data[res_key].edge_lists.items():
                if isinstance(edges, torch.Tensor) and edges.nelement() == 0:
                    print(f"Warning: Empty edge list in {data['complex_id']} {res_key} relation {rel}")
                    # Add default self-loop edges
                    num_nodes = data[res_key].x.shape[0]
                    data[res_key].edge_lists[rel] = torch.stack([
                        torch.arange(num_nodes),
                        torch.arange(num_nodes)
                    ])

"""


