import pandas as pd
import numpy as np
import torch
import open3d as o3d
from Bio.PDB import PDBParser
from rdkit import Chem
from rdkit.Chem import AllChem
from scipy.spatial import cKDTree
import os
from pathlib import Path

# Constants
R = 8.314 / 1000  # Gas constant in kJ/mol/K
T = 298.15  # Temperature in Kelvin
# Radius handed to compute_molecular_surface: used for normal estimation and as
# the smallest ball-pivoting radius (twice this value is the second radius).
# Presumably Angstrom, matching PDB coordinate units — TODO confirm.
PROBE_RADIUS = 1.4
N_POINTS = 5000  # target number of surface points per binding partner
K_NEIGHBORS = 10  # neighbours used by the local-density and curvature features
CHEMICAL_RADIUS = 5.0  # neighbourhood radius (coordinate units) for chemical features

def load_pdb(pdb_file, receptor_chains=None, ligand_chains=None, include_waters=False):
    """Load atom coordinates and coarse element types for selected chains of a PDB file.

    Atoms come from every chain whose ID appears in ``receptor_chains`` or
    ``ligand_chains``; the two selections are simply unioned. Only the first
    model is read, so multi-model (e.g. NMR) entries do not pile several
    conformers into one point cloud. Water residues (Biopython hetero flag
    ``'W'``) are skipped unless ``include_waters`` is True, because
    crystallographic waters would otherwise be meshed into the surface.

    Args:
        pdb_file: path to the PDB file.
        receptor_chains: iterable of chain IDs to include, or None.
        ligand_chains: iterable of chain IDs to include, or None.
        include_waters: keep water residues when True.

    Returns:
        ``(coords, atom_types)``: an ``(N, 3)`` coordinate array and an ``(N,)``
        array of element symbols. Elements outside C/H/O/N/S/SE are relabelled
        ``'C'``. An empty selection yields coordinates of shape ``(0, 3)``.
    """
    known_elements = ('C', 'H', 'O', 'N', 'S', 'SE')
    wanted_chains = set(receptor_chains or ()) | set(ligand_chains or ())

    structure = PDBParser(QUIET=True).get_structure('protein', pdb_file)
    first_model = next(iter(structure), None)

    coords, atom_types = [], []
    if first_model is not None:
        for chain in first_model:
            if chain.id not in wanted_chains:
                continue
            for residue in chain:
                # residue.id[0] is the hetero flag: ' ' standard, 'W' water, 'H_xxx' other HETATM.
                if residue.id[0] == 'W' and not include_waters:
                    continue
                for atom in residue:
                    coords.append(atom.get_coord())
                    element = atom.element
                    atom_types.append(element if element in known_elements else 'C')
    # reshape keeps the (N, 3) contract even when nothing was selected.
    return np.array(coords).reshape(-1, 3), np.array(atom_types)

def load_sdf(sdf_file):
    """Load atomic coordinates and coarse element types from an SDF/MOL file.

    The file is parsed with RDKit's ``Chem.MolFromMolFile`` (sanitised; RDKit's
    default ``removeHs=True`` applies, so explicit hydrogens are normally
    dropped). Symbols are upper-cased before matching so that RDKit's
    mixed-case ``'Se'`` lines up with the upper-case ``'SE'`` vocabulary used
    for PDB atoms; anything outside C/H/O/N/S/SE is relabelled ``'C'``.

    Args:
        sdf_file: path to the SDF/MOL file.

    Returns:
        ``(coords, atom_types)``: an ``(N, 3)`` float array of conformer
        positions and an ``(N,)`` array of element symbols, in RDKit atom order.

    Raises:
        ValueError: if RDKit cannot parse the file.
    """
    mol = Chem.MolFromMolFile(sdf_file, sanitize=True)
    if mol is None:
        raise ValueError(f"Failed to load SDF file: {sdf_file}")
    coords = np.asarray(mol.GetConformer().GetPositions(), dtype=float).reshape(-1, 3)
    known_elements = ('C', 'H', 'O', 'N', 'S', 'SE')
    symbols = (atom.GetSymbol().upper() for atom in mol.GetAtoms())
    atom_types = [s if s in known_elements else 'C' for s in symbols]
    return coords, np.array(atom_types)

def compute_molecular_surface(coords, probe_radius=PROBE_RADIUS):
    """Approximate a molecular surface as ``N_POINTS`` sampled points with normals.

    Atom centres are treated as a point cloud: normals are estimated with a
    hybrid KD-tree search (radius ``2 * probe_radius``, at most 30 neighbours),
    a mesh is built by ball pivoting with radii ``probe_radius`` and
    ``2 * probe_radius``, and that mesh is sampled uniformly. This meshes the
    atom centres rather than computing a true solvent-excluded surface.

    Args:
        coords: ``(N, 3)`` atom coordinates.
        probe_radius: base radius for normal estimation and ball pivoting.

    Returns:
        ``(points, normals)``, each an ``(N_POINTS, 3)`` array.

    Raises:
        ValueError: if ball pivoting produces no triangles (e.g. too few atoms).
    """
    pcd = o3d.geometry.PointCloud()
    # Vector3dVector is documented as taking float64 (n, 3) data; Biopython
    # coordinates are float32, so convert explicitly.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(coords, dtype=np.float64).reshape(-1, 3))
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=probe_radius * 2, max_nn=30))

    radii = o3d.utility.DoubleVector([probe_radius, probe_radius * 2])
    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, radii)
    if len(mesh.triangles) == 0:
        raise ValueError("ball pivoting produced an empty mesh; cannot sample a surface")
    # Make sure the sampled points inherit normals.
    if not mesh.has_vertex_normals():
        mesh.compute_vertex_normals()

    surface_pcd = mesh.sample_points_uniformly(number_of_points=N_POINTS)
    return np.asarray(surface_pcd.points), np.asarray(surface_pcd.normals)

def farthest_point_sampling(points, n_points=N_POINTS, return_indices=False, rng=None):
    """Select ``n_points`` well-spread points by greedy farthest-point sampling.

    Starting from a random point, each step adds the point whose distance to
    the already-selected set is largest.

    Args:
        points: ``(N, D)`` array of points.
        n_points: number of points to keep.
        return_indices: when True, also return the selected row indices so that
            parallel arrays (normals, features) can be subsampled consistently.
        rng: optional ``numpy.random.Generator`` for the starting point; when
            None the global ``np.random`` state is used.

    Returns:
        The selected points (``(min(N, n_points), D)``). If ``return_indices``
        is True, a ``(points, indices)`` tuple. When ``N <= n_points`` all points
        are returned in their original order.
    """
    points = np.asarray(points)
    total = len(points)
    if total <= n_points:
        return (points, np.arange(total)) if return_indices else points

    selected = np.empty(max(n_points, 0), dtype=np.intp)
    if selected.size:
        selected[0] = np.random.randint(0, total) if rng is None else int(rng.integers(total))
        # Squared distance to the nearest selected point; argmax is the same as
        # with true distances, so the sqrt is unnecessary.
        closest_sq = np.full(total, np.inf)
        for step in range(1, selected.size):
            diff = points - points[selected[step - 1]]
            closest_sq = np.minimum(closest_sq, np.einsum('ij,ij->i', diff, diff))
            selected[step] = int(np.argmax(closest_sq))

    sampled = points[selected]
    return (sampled, selected) if return_indices else sampled

def interpolate_features(points, normals, features, target_n=N_POINTS, rng=None):
    """Pad a point cloud up to ``target_n`` points by interpolating between neighbours.

    Each new point is a random convex combination (weight ``w`` in [0, 1)) of
    a randomly chosen existing point and its nearest neighbour. Its normal
    and feature vector are blended with the same weight. Blended normals are
    rescaled to unit length; when two opposite normals cancel, the first
    endpoint's normal is used instead.

    Args:
        points: ``(N, 3)`` coordinates.
        normals: ``(N, 3)`` normals (assumed unit length).
        features: ``(N, F)`` per-point features.
        target_n: desired number of points.
        rng: optional ``numpy.random.Generator``; the global ``np.random`` state
            is used when None.

    Returns:
        ``(points, normals, features)`` with ``max(N, target_n)`` rows. The
        original rows come first. When ``N >= target_n`` the inputs are
        returned unchanged.

    Raises:
        ValueError: if there are no points to interpolate from.
    """
    n_existing = len(points)
    if n_existing >= target_n:
        return points, normals, features
    if n_existing == 0:
        raise ValueError("cannot interpolate features for an empty point cloud")

    points = np.asarray(points)
    normals = np.asarray(normals)
    features = np.asarray(features)
    n_new = target_n - n_existing

    if rng is None:
        seeds = np.random.randint(0, n_existing, size=n_new)
        weights = np.random.random(n_new)
    else:
        seeds = rng.integers(0, n_existing, size=n_new)
        weights = rng.random(n_new)

    if n_existing == 1:
        # No neighbour exists; a k=2 query would return the out-of-range index N.
        pairs = np.zeros((n_new, 2), dtype=np.intp)
    else:
        # k=2 yields the seed itself (or a coincident duplicate) and its nearest neighbour.
        _, pairs = cKDTree(points).query(points[seeds], k=2)
    first, second = pairs[:, 0], pairs[:, 1]

    w = weights[:, None]
    new_points = w * points[first] + (1 - w) * points[second]
    new_normals = w * normals[first] + (1 - w) * normals[second]
    new_features = w * features[first] + (1 - w) * features[second]

    lengths = np.linalg.norm(new_normals, axis=1, keepdims=True)
    degenerate = lengths[:, 0] < 1e-8
    new_normals = new_normals / np.where(degenerate[:, None], 1.0, lengths)
    new_normals[degenerate] = normals[first[degenerate]]

    return (np.vstack([points, new_points]),
            np.vstack([normals, new_normals]),
            np.vstack([features, new_features]))

def compute_chemical_features(coords, atom_types, ligand_type='protein', mol=None):
    """Compute three per-atom chemical features averaged over a spherical neighbourhood.

    For every atom, the neighbourhood is all atoms within ``CHEMICAL_RADIUS``,
    including the atom itself.

    ``ligand_type == 'protein'`` columns:
        0. polarity: mean of +1 (O/N), -1 (C/S), 0 (other)
        1. H-bond: fraction of O/N neighbours
        2. hydrophobicity: fraction of C neighbours
    Any other ``ligand_type`` (small molecule) columns:
        0. charge proxy: same +1/-1/0 scoring as polarity
        1. electrophilicity proxy: fraction of O/N/S neighbours
        2. aromaticity: 1 if the atom itself is aromatic in ``mol``. This needs
           ``mol``; without it the column is 0.

    Args:
        coords: ``(N, 3)`` atom coordinates.
        atom_types: ``(N,)`` element symbols aligned with ``coords``.
        ligand_type: ``'protein'`` or anything else for small molecules.
        mol: optional RDKit molecule whose atom order matches ``coords``. This
            holds when both come from the same molecule, as in ``load_sdf``.

    Returns:
        ``(N, 3)`` float array.

    Raises:
        ValueError: if ``mol`` is given but its atom count differs from ``len(coords)``.
    """
    coords = np.asarray(coords)
    atom_types = np.asarray(atom_types)
    features = np.zeros((len(coords), 3))
    if len(coords) == 0:
        return features

    neighbourhoods = cKDTree(coords).query_ball_point(coords, r=CHEMICAL_RADIUS)
    is_o_or_n = np.isin(atom_types, ['O', 'N'])
    is_c_or_s = np.isin(atom_types, ['C', 'S'])
    signed_polarity = np.where(is_o_or_n, 1.0, np.where(is_c_or_s, -1.0, 0.0))

    if ligand_type == 'protein':
        hbond = is_o_or_n.astype(float)
        hydrophobic = (atom_types == 'C').astype(float)
        for i, nbrs in enumerate(neighbourhoods):
            features[i, 0] = signed_polarity[nbrs].mean()  # Polarity
            features[i, 1] = hbond[nbrs].mean()  # H-bond
            features[i, 2] = hydrophobic[nbrs].mean()  # Hydrophobicity
    else:
        if mol is not None and mol.GetNumAtoms() != len(coords):
            raise ValueError(
                f"mol has {mol.GetNumAtoms()} atoms but {len(coords)} coordinates were given")
        electrophilic = np.isin(atom_types, ['O', 'N', 'S']).astype(float)
        for i, nbrs in enumerate(neighbourhoods):
            features[i, 0] = signed_polarity[nbrs].mean()  # Charge
            features[i, 1] = electrophilic[nbrs].mean()  # Electrophilicity
        if mol is not None:
            features[:, 2] = [1.0 if atom.GetIsAromatic() else 0.0 for atom in mol.GetAtoms()]
    return features

def normalize_chemical_features(features, ligand_type='protein'):
    """Clip the first three chemical-feature columns into their valid ranges, in place.

    Column 0 (polarity for proteins, charge for small molecules) is clipped to
    [-1, 1]. Columns 1 and 2 are clipped to [0, 1]. The bounds are identical
    for every ``ligand_type``, which therefore only changes how the columns are
    interpreted. Any columns after the third are left untouched.

    Returns:
        The same array object that was passed in, after modification.
    """
    column_bounds = ((-1, 1), (0, 1), (0, 1))
    for column, (low, high) in enumerate(column_bounds):
        features[:, column] = np.clip(features[:, column], low, high)
    return features

def compute_local_density(coords, k=K_NEIGHBORS):
    """Mean distance to the ``k`` nearest neighbours, min-max scaled to [0, 1].

    Despite the name, larger values mean a *sparser* neighbourhood: the raw
    quantity is mean neighbour spacing, not a count per volume.

    Args:
        coords: ``(N, 3)`` points.
        k: neighbours per point (excluding the point itself). It is capped at
            ``N - 1`` so small clouds do not run out of neighbours.

    Returns:
        ``(N,)`` float array in [0, 1]. All zeros when ``N < 2`` or ``k < 1``.
    """
    coords = np.asarray(coords, dtype=float)
    n = len(coords)
    if n < 2 or k < 1:
        return np.zeros(n)
    # Ask for one extra neighbour because the nearest hit is the point itself.
    n_query = min(k, n - 1) + 1
    distances, _ = cKDTree(coords).query(coords, k=n_query)
    spacing = distances[:, 1:].mean(axis=1)
    lo, hi = spacing.min(), spacing.max()
    return (spacing - lo) / (hi - lo + 1e-8)

def compute_curvature(points, normals, k=K_NEIGHBORS):
    """Estimate two normal-variation curvature descriptors per point.

    For each point, take its ``k`` nearest neighbours (excluding itself) and
    the angles between its normal and each neighbour's normal. Each angle
    statistic is divided by the mean distance to those neighbours.

    Column 0 is the mean angle per unit distance, which tracks how fast the
    normal turns (a mean-curvature-like proxy). Column 1 is the standard
    deviation of the angles per unit distance. That is a dispersion heuristic,
    not the Gaussian curvature (the product of principal curvatures). If the
    normals are not consistently oriented, flipped neighbours contribute angles
    near pi — NOTE(review): confirm normal orientation upstream.

    Args:
        points: ``(N, 3)`` coordinates.
        normals: ``(N, 3)`` normals.
        k: neighbours per point, capped at ``N - 1``.

    Returns:
        ``(N, 2)`` float array. Rows are 0 where there are no neighbours or all
        neighbours coincide with the point.
    """
    points = np.asarray(points, dtype=float)
    normals = np.asarray(normals, dtype=float)
    n = len(points)
    curvatures = np.zeros((n, 2))
    k_eff = min(k, n - 1)
    if k_eff < 1:
        return curvatures

    _, idx = cKDTree(points).query(points, k=k_eff + 1)
    neighbours = idx[:, 1:]  # drop the self match in column 0
    dots = np.einsum('ij,ikj->ik', normals, normals[neighbours])
    angles = np.arccos(np.clip(dots, -1, 1))
    spacing = np.linalg.norm(points[neighbours] - points[:, None, :], axis=2).mean(axis=1)

    valid = spacing > 0  # coincident duplicates would divide by zero
    safe_spacing = np.where(valid, spacing, 1.0)
    curvatures[:, 0] = np.where(valid, angles.mean(axis=1) / safe_spacing, 0.0)
    curvatures[:, 1] = np.where(valid, angles.std(axis=1) / safe_spacing, 0.0)
    return curvatures

def compute_atom_type_features(atom_types):
    """One-hot encode element symbols into a 6-column matrix.

    Columns are C, H, O, N, S, SE, the same vocabulary that ``load_pdb`` and
    ``load_sdf`` emit. Matching is case-insensitive, so RDKit-style ``'Se'``
    hits the SE column. Any other symbol gets an all-zero row. This includes
    ``'Cl'``: the loaders relabel chlorine as ``'C'``, so a chlorine column
    could never fire.

    Args:
        atom_types: sequence of element symbols.

    Returns:
        ``(len(atom_types), 6)`` float array.
    """
    column_of = {'C': 0, 'H': 1, 'O': 2, 'N': 3, 'S': 4, 'SE': 5}
    features = np.zeros((len(atom_types), 6))
    for row, symbol in enumerate(atom_types):
        column = column_of.get(str(symbol).upper())
        if column is not None:
            features[row, column] = 1
    return features

def compute_iface_labels(receptor_points, ligand_points, threshold=5.0):
    """Mark receptor points lying strictly closer than ``threshold`` to any ligand point.

    Args:
        receptor_points: ``(N, 3)`` points to label.
        ligand_points: ``(M, 3)`` partner points.
        threshold: distance cut-off; the comparison is strict (``<``).

    Returns:
        ``(N,)`` float64 array of 1.0 (interface) / 0.0 (non-interface).
    """
    ligand_tree = cKDTree(ligand_points)
    if len(receptor_points) == 0:
        return np.zeros(0)
    nearest_distance, _ = ligand_tree.query(receptor_points)
    return (nearest_distance < threshold).astype(np.float64)

def compute_delta_g(kd):
    """Return the binding free energy ``R * T * ln(kd)`` in kJ/mol.

    Args:
        kd: dissociation constant, presumably in mol/L (1 M standard state).

    Returns:
        The free energy as a float. NaN when ``kd`` is not strictly positive,
        which includes NaN input.
    """
    if not kd > 0:
        return np.nan
    return R * T * np.log(kd)

def _parse_chain_ids(value):
    """Split a comma-separated chain field such as ``"A,B"`` into stripped chain IDs."""
    if not isinstance(value, str):
        # Empty spreadsheet cells arrive from pandas as NaN (a float), which has no .split().
        raise ValueError(f"expected comma-separated chain IDs, got {value!r}")
    chain_ids = [part.strip() for part in value.split(',') if part.strip()]
    if not chain_ids:
        raise ValueError(f"no chain IDs in {value!r}")
    return chain_ids


def _downsample_surface(points, normals):
    """Farthest-point-sample a surface to N_POINTS, keeping each normal with its point."""
    if len(points) <= N_POINTS:
        return points, normals
    sampled = farthest_point_sampling(points)
    # FPS returns rows of `points`; recover their row indices so matching normals are kept.
    _, rows = cKDTree(points).query(sampled)
    return sampled, normals[rows]


def _partner_features(coords, atom_types, points, normals, ligand_type):
    """Build the 12-D per-surface-point features: 3 chem + 6 atom type + 2 curvature + 1 density."""
    chem = normalize_chemical_features(
        compute_chemical_features(coords, atom_types, ligand_type), ligand_type)
    atom_onehot = compute_atom_type_features(atom_types)
    # Chemical and atom-type features are per atom; each surface point takes its nearest atom's.
    _, nearest_atom = cKDTree(coords).query(points)
    curvatures = compute_curvature(points, normals)
    density = compute_local_density(points)
    return np.hstack([chem[nearest_atom], atom_onehot[nearest_atom], curvatures, density[:, None]])


def _save_partner(path, points, normals, features, iface_labels, delta_g, pdb_id):
    """Write one binding partner's tensors and metadata to ``path`` with ``torch.save``."""
    torch.save({
        'points': torch.tensor(points, dtype=torch.float32),
        'normals': torch.tensor(normals, dtype=torch.float32),
        'features': torch.tensor(features, dtype=torch.float32),
        'iface_labels': torch.tensor(iface_labels, dtype=torch.float32),
        'delta_g': torch.tensor(delta_g, dtype=torch.float32),
        'protein_id': pdb_id,
        'probe_radius': PROBE_RADIUS,
    }, path)


def process_skempi(data_dir, output_dir, ligand_type='protein'):
    """Convert every row of ``<data_dir>/skempi_v2.xlsx`` into receptor/ligand point-cloud files.

    Columns read per row: ``PDB``, ``Receptor_Chains``, ``Ligand_Chains``
    (protein mode only) and, if present, ``Kd_(nM)``. Kd is converted to M
    before computing delta G. The receptor always comes from
    ``<PDB>.pdb``. The ligand comes from the same PDB file
    (``ligand_type == 'protein'``) or from ``<PDB>_ligand.sdf`` otherwise.
    For each partner, ``<PDB>_receptor.pt`` / ``<PDB>_ligand.pt`` is written
    to ``output_dir``.

    A row that fails at any step is reported with ``print`` and skipped. The
    remaining rows are still processed.

    NOTE(review): output names depend only on the PDB ID, so several rows
    sharing a PDB (e.g. different mutations) overwrite each other. The last
    row's delta G wins.
    """
    skempi_df = pd.read_excel(os.path.join(data_dir, 'skempi_v2.xlsx'))
    os.makedirs(output_dir, exist_ok=True)

    for _, row in skempi_df.iterrows():
        pdb_id = row['PDB']
        try:
            kd = row['Kd_(nM)'] * 1e-9 if 'Kd_(nM)' in row else np.nan
            delta_g = compute_delta_g(kd)
            pdb_file = os.path.join(data_dir, f"{pdb_id}.pdb")

            receptor_coords, receptor_atom_types = load_pdb(
                pdb_file, receptor_chains=_parse_chain_ids(row['Receptor_Chains']))
            if ligand_type == 'protein':
                ligand_coords, ligand_atom_types = load_pdb(
                    pdb_file, ligand_chains=_parse_chain_ids(row['Ligand_Chains']))
            else:
                ligand_coords, ligand_atom_types = load_sdf(
                    os.path.join(data_dir, f"{pdb_id}_ligand.sdf"))
            if len(receptor_coords) == 0 or len(ligand_coords) == 0:
                raise ValueError("no atoms found for the requested receptor/ligand selection")

            # Surface point clouds, with normals kept aligned to their points.
            receptor_points, receptor_normals = _downsample_surface(
                *compute_molecular_surface(receptor_coords))
            ligand_points, ligand_normals = _downsample_surface(
                *compute_molecular_surface(ligand_coords))

            receptor_features = _partner_features(
                receptor_coords, receptor_atom_types, receptor_points, receptor_normals, ligand_type)
            ligand_features = _partner_features(
                ligand_coords, ligand_atom_types, ligand_points, ligand_normals, ligand_type)

            # Pad up to N_POINTS if the surface came out smaller.
            receptor_points, receptor_normals, receptor_features = interpolate_features(
                receptor_points, receptor_normals, receptor_features)
            ligand_points, ligand_normals, ligand_features = interpolate_features(
                ligand_points, ligand_normals, ligand_features)

            receptor_iface = compute_iface_labels(receptor_points, ligand_points)
            ligand_iface = compute_iface_labels(ligand_points, receptor_points)

            _save_partner(os.path.join(output_dir, f"{pdb_id}_receptor.pt"),
                          receptor_points, receptor_normals, receptor_features,
                          receptor_iface, delta_g, pdb_id)
            _save_partner(os.path.join(output_dir, f"{pdb_id}_ligand.pt"),
                          ligand_points, ligand_normals, ligand_features,
                          ligand_iface, delta_g, pdb_id)
        except Exception as e:
            # Top-level per-row boundary: report and move on to the next complex.
            print(f"Error processing {pdb_id}: {e}")

if __name__ == "__main__":
    data_dir = "skempi_v2"
    output_dir = "processed_data"
    # (output sub-directory, partner type) pairs, processed in this order.
    # The small_molecule run expects "<PDB>_ligand.sdf" files inside data_dir.
    runs = (
        ("train_ppi", 'protein'),
        ("test_ppi", 'small_molecule'),
    )
    for split_name, partner_kind in runs:
        process_skempi(data_dir, os.path.join(output_dir, split_name), ligand_type=partner_kind)