import pandas as pd
import numpy as np
import torch
import open3d as o3d
from Bio.PDB import PDBParser
from rdkit import Chem
from rdkit.Chem import AllChem
from scipy.spatial import cKDTree
import os
from pathlib import Path
import random
import logging
from tqdm import tqdm

# Module-wide logging: timestamped messages at INFO level and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Silence all RDKit console logging channels (rdApp.*) so per-molecule parse
# warnings do not flood the output while processing many complexes.
from rdkit import rdBase
rdBase.DisableLog('rdApp.*')

R = 8.314 / 1000  # gas constant in kJ/mol/K (8.314 J/mol/K divided by 1000)
T = 298.15  # temperature in K used by compute_delta_g
# Ball-pivoting radius (and half the normal-estimation search radius) used by
# compute_molecular_surface; also stored in every saved .pt file.
# NOTE(review): presumably in Å, since the coordinates come from PDB/SDF files.
PROBE_RADIUS = 1.4
N_POINTS_RECEPTOR = 5000  # surface points sampled for each receptor
N_POINTS_LIGAND = 500  # surface points sampled for each ligand
K_NEIGHBORS = 10  # neighbour count for local density and curvature features
CHEMICAL_RADIUS = 5.0  # neighbourhood radius for compute_chemical_features

def load_pdb(pdb_file, chain_id=None):
    """Read atom coordinates and coarse element labels from a PDB file.

    Only the first model is used. Iterating every model would stack all
    conformers of a multi-model file (e.g. an NMR ensemble) on top of each
    other and duplicate every atom. Elements outside {C, H, O, N, S, SE} are
    relabelled 'C'.

    Args:
        pdb_file: path to the PDB file.
        chain_id: if given, keep only atoms belonging to the chain with this id.

    Returns:
        (coords, atom_types): a float64 array of shape (n_atoms, 3) and an
        array of n_atoms element strings, or (None, None) if no atom was read
        or parsing failed (the failure is logged).
    """
    allowed_elements = {'C', 'H', 'O', 'N', 'S', 'SE'}
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure('protein', pdb_file)
        model = next(iter(structure), None)
        if model is None:
            return None, None

        coords, atom_types = [], []
        for chain in model:
            if chain_id is not None and chain.id != chain_id:
                continue
            for residue in chain:
                for atom in residue:
                    coords.append(atom.get_coord())
                    element = atom.element
                    atom_types.append(element if element in allowed_elements else 'C')

        if not coords:
            return None, None
        # Biopython yields float32 coordinates; Open3D's Vector3dVector works on float64.
        return np.asarray(coords, dtype=np.float64), np.array(atom_types)
    except Exception as e:
        logger.error(f"Failed to load PDB {pdb_file}: {e}")
        return None, None

def load_sdf(sdf_file):
    """Load a ligand while keeping its deposited 3D pose.

    Despite the name, paths ending in ``.mol2`` are also accepted and are
    dispatched to RDKit's MOL2 reader. Anything else is read as MOL/SDF.
    Hydrogens present in the file are kept. Missing hydrogens are added with
    coordinates placed from their heavy-atom neighbours.

    The molecule is deliberately NOT re-embedded or force-field optimised.
    Either step replaces the file's coordinates with a new conformer, which
    moves the ligand out of the receptor's coordinate frame.

    Returns:
        (coords, atom_types, mol): an (n_atoms, 3) coordinate array, element
        labels (upper-cased, with anything outside {C, H, O, N, S, SE} mapped
        to 'C'), and the hydrogen-complete RDKit molecule. Returns
        (None, None, None) on failure; the failure is logged.
    """
    allowed_elements = {'C', 'H', 'O', 'N', 'S', 'SE'}
    try:
        path = str(sdf_file)
        if path.lower().endswith('.mol2'):
            mol = Chem.MolFromMol2File(path, sanitize=True, removeHs=False)
        else:
            mol = Chem.MolFromMolFile(path, sanitize=True, removeHs=False)
        if mol is None:
            raise ValueError(f"Load sdf failed {sdf_file}")
        if mol.GetNumConformers() == 0:
            raise ValueError(f"No coordinates in {sdf_file}")

        # addCoords=True positions new hydrogens relative to the existing pose.
        mol = Chem.AddHs(mol, addCoords=True)
        coords = mol.GetConformer().GetPositions()

        # RDKit symbols are title-case ('Se'); upper-case them to match the
        # PDB element vocabulary ('SE') used for receptors.
        symbols = (atom.GetSymbol().upper() for atom in mol.GetAtoms())
        atom_types = np.array([s if s in allowed_elements else 'C' for s in symbols])

        if len(coords) == 0 or len(coords) != len(atom_types):
            raise ValueError(f"Invalid coords or atom_types in {sdf_file}")
        return coords, atom_types, mol
    except Exception as e:
        logger.error(f"Failed to load SDF {sdf_file}: {e}")
        return None, None, None

def _orient_normals_outward(pcd, k):
    """Give the estimated normals of ``pcd`` a consistent, outward-facing orientation.

    ``estimate_normals`` fixes each normal only up to sign, so neighbouring
    normals may point in opposite directions. Ball pivoting and the
    normal-based curvature features both need a consistent orientation.

    A consistent-tangent-plane pass runs first. The whole field is then
    flipped if most normals point towards the centroid. If that pass fails,
    each normal is individually pointed away from the centroid instead.
    """
    points = np.asarray(pcd.points)
    if len(points) < 3:
        return
    try:
        pcd.orient_normals_consistent_tangent_plane(min(k, len(points) - 1))
        consistent = True
    except RuntimeError as e:
        logger.debug(f"Consistent normal orientation failed, using centroid heuristic: {e}")
        consistent = False

    normals = np.asarray(pcd.normals).copy()
    outwardness = np.einsum('ij,ij->i', normals, points - points.mean(axis=0))
    if consistent:
        if np.count_nonzero(outwardness < 0) > np.count_nonzero(outwardness > 0):
            normals = -normals
    else:
        normals[outwardness < 0] *= -1
    pcd.normals = o3d.utility.Vector3dVector(normals)


def compute_molecular_surface(coords, probe_radius=PROBE_RADIUS, n_points=N_POINTS_RECEPTOR):
    """Mesh an atom point cloud with ball pivoting and sample points and normals on it.

    The mesh is built directly on the atom centres. ``probe_radius`` is the
    ball-pivoting radius (``probe_radius`` and ``2 * probe_radius`` are tried),
    not an offset added to atomic radii. The result is therefore an
    atom-centre hull rather than a solvent-excluded surface.

    Args:
        coords: (n_atoms, 3) atom coordinates.
        probe_radius: ball-pivoting radius. The normal-estimation search radius is twice this.
        n_points: number of points sampled uniformly on the mesh.

    Returns:
        (points, normals): (n_points, 3) arrays with unit-length normals, or
        (None, None) if the input is empty, the mesh has no triangles, or
        Open3D fails (the failure is logged).
    """
    try:
        coords = np.asarray(coords, dtype=np.float64)
        if len(coords) == 0:
            return None, None
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(coords)
        pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=probe_radius * 2, max_nn=30))
        _orient_normals_outward(pcd, K_NEIGHBORS)

        mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
            pcd, o3d.utility.DoubleVector([probe_radius, probe_radius * 2]))
        if len(mesh.triangles) == 0:
            logger.error("Failed to compute molecular surface: ball pivoting produced no triangles")
            return None, None

        surface_pcd = mesh.sample_points_uniformly(number_of_points=n_points)
        points = np.asarray(surface_pcd.points)
        normals = np.asarray(surface_pcd.normals)
        # Normals interpolated across a triangle are not unit length in general.
        lengths = np.linalg.norm(normals, axis=1, keepdims=True)
        normals = np.divide(normals, lengths, out=np.zeros_like(normals), where=lengths > 0)
        return points, normals
    except Exception as e:
        logger.error(f"Failed to compute molecular surface: {e}")
        return None, None

def compute_chemical_features(coords, atom_types, mol=None, ligand_type='small_molecule'):
    """Summarise each atom's chemical neighbourhood as three numbers.

    Neighbours are all atoms within CHEMICAL_RADIUS of the atom, including the
    atom itself. Output columns:
      0: mean polarity score of the neighbours (+1 for O/N, -1 for C/S, 0 otherwise).
      1: fraction of neighbours that are O/N (protein) or O/N/S (small molecule).
      2: protein: fraction of carbon neighbours. Small molecule: 1 if ``mol`` is
         given and RDKit flags atom i as aromatic, else 0.

    Returns:
        An (n_atoms, 3) float array, or an empty (0, 3) array when the inputs
        are empty or have mismatched lengths.
    """
    n_atoms = len(coords)
    if n_atoms == 0 or len(atom_types) == 0 or n_atoms != len(atom_types):
        return np.zeros((0, 3))

    is_protein = ligand_type == 'protein'
    polar_elements = ['O', 'N'] if is_protein else ['O', 'N', 'S']

    # Per-atom scores, so each neighbourhood statistic is a mean over an index list.
    polarity_score = np.select(
        [np.isin(atom_types, ['O', 'N']), np.isin(atom_types, ['C', 'S'])],
        [1, -1],
        default=0,
    )
    polar_flag = np.isin(atom_types, polar_elements).astype(int)
    carbon_flag = np.isin(atom_types, ['C']).astype(int)

    neighbour_lists = cKDTree(coords).query_ball_point(coords, r=CHEMICAL_RADIUS)

    result = np.zeros((n_atoms, 3))
    for atom_idx, neighbours in enumerate(neighbour_lists):
        if len(neighbours) == 0:
            continue
        result[atom_idx, 0] = polarity_score[neighbours].mean()
        result[atom_idx, 1] = polar_flag[neighbours].mean()
        if is_protein:
            result[atom_idx, 2] = carbon_flag[neighbours].mean()
        else:
            result[atom_idx, 2] = 1 if mol and atom_idx < mol.GetNumAtoms() and mol.GetAtomWithIdx(atom_idx).GetIsAromatic() else 0
    return result

def normalize_chemical_features(features, ligand_type='small_molecule'):
    """Clip the three chemical-feature columns into their nominal ranges, in place.

    Column 0 is clipped to [-1, 1]. Columns 1 and 2 are clipped to [0, 1].
    Any further columns are left untouched. Both ligand types use the same
    bounds (the original per-type branches were identical), so ``ligand_type``
    is accepted only for API compatibility.

    Args:
        features: (n, >=3) array, modified in place.
        ligand_type: ignored; kept for backward compatibility.

    Returns:
        The same ``features`` object. An empty input is returned unchanged.
    """
    if len(features) == 0:
        return features
    lower = np.array([-1.0, 0.0, 0.0])
    upper = np.array([1.0, 1.0, 1.0])
    features[:, :3] = np.clip(features[:, :3], lower, upper)
    return features

def compute_atom_type_features(atom_types):
    """One-hot encode element labels over the vocabulary (C, H, O, N, S, SE).

    Labels outside the vocabulary produce an all-zero row.

    Returns:
        A float array of shape (len(atom_types), 6).
    """
    vocabulary = ('C', 'H', 'O', 'N', 'S', 'SE')
    one_hot = np.zeros((len(atom_types), len(vocabulary)))
    for row, symbol in enumerate(atom_types):
        if symbol in vocabulary:
            one_hot[row, vocabulary.index(symbol)] = 1
    return one_hot

def compute_local_density(coords, k=None):
    """Per-point mean distance to the k nearest neighbours, min-max scaled to [0, 1].

    Despite the name, a larger value means a sparser neighbourhood: the raw
    quantity is the mean neighbour distance, not a count per volume.

    Args:
        coords: (n, d) point coordinates.
        k: number of neighbours, defaulting to the module-level K_NEIGHBORS.
            It is clamped to n - 1. Without the clamp, cKDTree pads missing
            neighbours with the out-of-range index n.

    Returns:
        An (n,) float array. All zeros when n <= 1.
    """
    coords = np.asarray(coords, dtype=float)
    n = len(coords)
    if n == 0:
        return np.zeros((0,))
    if n == 1:
        return np.zeros(1)
    if k is None:
        k = K_NEIGHBORS
    k = max(1, min(k, n - 1))

    dists, _ = cKDTree(coords).query(coords, k=k + 1)
    # Column 0 is the query point itself (distance 0), so drop it.
    mean_dist = dists[:, 1:].mean(axis=1)
    return (mean_dist - mean_dist.min()) / (mean_dist.max() - mean_dist.min() + 1e-8)

def compute_curvature(points, normals, k=None):
    """Estimate two normal-variation curvature features per surface point.

    For each point, take the angles between its normal and the normals of its
    k nearest neighbours:
      column 0 = mean(angle) / mean(neighbour distance)
      column 1 = std(angle)  / mean(neighbour distance)

    Args:
        points: (n, 3) surface points.
        normals: (n, 3) normals aligned with ``points``.
        k: neighbour count, defaulting to the module-level K_NEIGHBORS. It is
            clamped to n - 1, because cKDTree pads missing neighbours with
            the out-of-range index n.

    Returns:
        An (n, 2) float array. Rows whose neighbours all coincide with the
        point (mean distance 0) are left at 0 instead of becoming inf/nan.
    """
    points = np.asarray(points, dtype=float)
    normals = np.asarray(normals, dtype=float)
    n = len(points)
    curvatures = np.zeros((n, 2))
    if n < 2 or len(normals) == 0:
        return curvatures
    if k is None:
        k = K_NEIGHBORS
    k = max(1, min(k, n - 1))

    dists, idx = cKDTree(points).query(points, k=k + 1)
    # Column 0 is the query point itself, so drop it.
    dists, idx = dists[:, 1:], idx[:, 1:]
    cos_angles = np.einsum('ij,ikj->ik', normals, normals[idx])
    angles = np.arccos(np.clip(cos_angles, -1.0, 1.0))
    mean_dist = dists.mean(axis=1)

    valid = mean_dist > 0
    curvatures[valid, 0] = angles[valid].mean(axis=1) / mean_dist[valid]
    curvatures[valid, 1] = angles[valid].std(axis=1) / mean_dist[valid]
    return curvatures

def interpolate_atom_features_to_surface(atom_coords, atom_features, surface_points, k=3):
    """Inverse-distance-weighted interpolation of per-atom features onto surface points.

    Each surface point receives a weighted average of the features of its k
    nearest atoms, with weights proportional to 1 / (distance + 1e-8).

    Args:
        atom_coords: (n_atoms, 3) atom coordinates.
        atom_features: (n_atoms, F) per-atom features.
        surface_points: (n_surface, 3) query points.
        k: neighbour count, clamped to n_atoms. cKDTree would otherwise pad
            with the out-of-range index n_atoms.

    Returns:
        An (n_surface, F) array. When any input is empty it is all zeros,
        with F taken from ``atom_features``' second dimension (3 if that
        array is not 2-D).
    """
    atom_features = np.asarray(atom_features)
    n_surface = len(surface_points)
    n_features = atom_features.shape[1] if atom_features.ndim == 2 else 3
    if len(atom_coords) == 0 or len(atom_features) == 0 or n_surface == 0:
        return np.zeros((n_surface, n_features))

    k_eff = max(1, min(k, len(atom_coords)))
    dists, idxs = cKDTree(atom_coords).query(surface_points, k=k_eff)
    # cKDTree returns 1-D arrays when k == 1; force the (n_surface, k) layout.
    dists = np.reshape(dists, (n_surface, k_eff))
    idxs = np.reshape(idxs, (n_surface, k_eff))

    weights = 1.0 / (dists + 1e-8)
    weights /= np.sum(weights, axis=1, keepdims=True)
    return np.sum(atom_features[idxs] * weights[:, :, None], axis=1)


def compute_iface_labels(receptor_points, ligand_points, threshold=5.0):
    """Label each point of one surface by whether it lies near another surface.

    A point gets 1.0 when its nearest ``ligand_points`` neighbour is strictly
    closer than ``threshold``, and 0.0 otherwise.

    Args:
        receptor_points: (n, d) points to label.
        ligand_points: (m, d) points of the partner surface.
        threshold: distance cut-off (exclusive).

    Returns:
        A float array with one entry per receptor point. When
        ``ligand_points`` is empty every label is 0, and the length still
        matches ``receptor_points`` so downstream stacking stays aligned.
    """
    if len(receptor_points) == 0:
        return np.zeros(0)
    if len(ligand_points) == 0:
        return np.zeros(len(receptor_points))
    receptor_points = np.asarray(receptor_points, dtype=float)
    # Neighbours beyond the threshold are reported as inf, which prunes the search.
    dists, _ = cKDTree(ligand_points).query(receptor_points, distance_upper_bound=threshold)
    return (dists < threshold).astype(float)

def compute_delta_g(kd):
    """Convert an affinity constant into a binding free energy.

    Computes ``R * T * ln(kd)`` with the module-level ``R`` (kJ/mol/K) and
    ``T`` (K), so the result is in kJ/mol. Any ``kd`` that is not strictly
    positive yields NaN. This includes NaN itself, since ``nan > 0`` is False.

    NOTE(review): ``kd`` is presumably in mol/L (1 M standard state); confirm with the caller.
    """
    if not kd > 0:
        return np.nan
    return R * T * np.log(kd)

def _read_pdbbind_index(index_file):
    """Read a PDBbind ``INDEX_*_data`` file and add a molar ``kd_value`` column.

    The PDBbind index header lists the columns as "PDB code, resolution,
    release year, -logKd/Ki, Kd/Ki, reference, ligand name". The fourth
    column is therefore already -log10 of the affinity in mol/L, so the
    constant is ``10 ** (-pK)``. It must not be read as a raw value in the
    unit shown in the fifth column. Rows whose pK cannot be parsed get NaN.
    """
    index_df = pd.read_csv(
        index_file, sep=r'\s+', comment='#',
        names=['pdb_id', 'resolution', 'year', 'neg_log_affinity', 'affinity', 'separator', 'reference'],
        dtype={'pdb_id': str}, usecols=[0, 1, 2, 3, 4, 5, 6],
    )
    neg_log_affinity = pd.to_numeric(index_df['neg_log_affinity'], errors='coerce').to_numpy(dtype=float)
    index_df['kd_value'] = np.power(10.0, -neg_log_affinity)
    return index_df


def _sample_train_ids(pdb_ids, train_fraction, seed):
    """Return a reproducible random subset, as a set, holding ``train_fraction`` of ``pdb_ids``."""
    ids = list(pdb_ids)
    return set(random.Random(seed).sample(ids, int(train_fraction * len(ids))))


def _surface_features(atom_coords, atom_types, points, normals, mol=None, ligand_type='small_molecule'):
    """Build the per-surface-point feature matrix for one binding partner.

    Columns: 3 interpolated chemical features, 6 interpolated element
    one-hots, 2 curvature values and 1 local-density value.
    """
    chem_atom = compute_chemical_features(atom_coords, atom_types, mol, ligand_type=ligand_type)
    element_atom = compute_atom_type_features(atom_types)
    return np.hstack([
        interpolate_atom_features_to_surface(atom_coords, chem_atom, points),
        interpolate_atom_features_to_surface(atom_coords, element_atom, points),
        compute_curvature(points, normals),
        compute_local_density(points)[:, None],
    ])


def _featurize_complex(pdb_id, pdb_dir):
    """Load one complex and compute surface arrays for receptor and ligand.

    The full protein file is tried before the pocket file, and the SDF ligand
    before the MOL2 ligand.

    Returns:
        (receptor, ligand, ligand_mol), where receptor and ligand are dicts of
        numpy arrays keyed 'points', 'normals', 'features' and 'iface_labels'.
        Returns None when loading or surface construction fails; the failure
        is logged.
    """
    receptor_coords, receptor_atom_types = load_pdb(os.path.join(pdb_dir, f"{pdb_id}_protein.pdb"))
    if receptor_coords is None:
        receptor_coords, receptor_atom_types = load_pdb(os.path.join(pdb_dir, f"{pdb_id}_pocket.pdb"))
        if receptor_coords is None:
            logger.error(f"Failed to load receptor for {pdb_id}")
            return None

    ligand_coords, ligand_atom_types, ligand_mol = load_sdf(os.path.join(pdb_dir, f"{pdb_id}_ligand.sdf"))
    if ligand_coords is None:
        ligand_coords, ligand_atom_types, ligand_mol = load_sdf(os.path.join(pdb_dir, f"{pdb_id}_ligand.mol2"))
        if ligand_coords is None:
            logger.error(f"Failed to load ligand for {pdb_id}")
            return None

    logger.info(f"Loaded {pdb_id}: receptor_coords shape={receptor_coords.shape}, ligand_coords shape={ligand_coords.shape}")

    receptor_points, receptor_normals = compute_molecular_surface(receptor_coords, n_points=N_POINTS_RECEPTOR)
    ligand_points, ligand_normals = compute_molecular_surface(ligand_coords, n_points=N_POINTS_LIGAND)
    if receptor_points is None or ligand_points is None:
        logger.error(f"Failed to compute molecular surface for {pdb_id}")
        return None

    receptor = {
        'points': receptor_points,
        'normals': receptor_normals,
        'features': _surface_features(receptor_coords, receptor_atom_types, receptor_points,
                                      receptor_normals, ligand_type='protein'),
        'iface_labels': compute_iface_labels(receptor_points, ligand_points),
    }
    ligand = {
        'points': ligand_points,
        'normals': ligand_normals,
        'features': _surface_features(ligand_coords, ligand_atom_types, ligand_points,
                                      ligand_normals, mol=ligand_mol, ligand_type='small_molecule'),
        'iface_labels': compute_iface_labels(ligand_points, receptor_points),
    }
    return receptor, ligand, ligand_mol


def _save_surface(arrays, pdb_id, delta_g, path, **extra):
    """Write one partner's surface arrays plus shared metadata to ``path`` with torch.save."""
    torch.save({
        'points': torch.tensor(arrays['points'], dtype=torch.float32),
        'normals': torch.tensor(arrays['normals'], dtype=torch.float32),
        'features': torch.tensor(arrays['features'], dtype=torch.float32),
        'iface_labels': torch.tensor(arrays['iface_labels'], dtype=torch.float32),
        'delta_g': torch.tensor(delta_g, dtype=torch.float32),
        'pdb_id': pdb_id,
        'probe_radius': PROBE_RADIUS,
        **extra,
    }, path)


def process_pdbbind(data_dir, output_dir, train_fraction=0.8, seed=42):
    """Convert a PDBbind refined set into per-complex surface point-cloud files.

    Expects ``data_dir/index/INDEX_refined_data.2020`` and one
    ``data_dir/<pdb_id>/`` folder per complex.

    For every complex this writes ``<pdb_id>_receptor.pt`` and
    ``<pdb_id>_ligand.pt`` into ``output_dir/train`` or ``output_dir/test``.
    It also writes ``train_pairs.csv`` and ``test_pairs.csv`` listing pdb_id,
    ligand SMILES and delta_g (kJ/mol, from compute_delta_g). Complexes that
    fail are logged and skipped.

    Args:
        data_dir: PDBbind refined-set root directory.
        output_dir: destination directory. Subdirectories are created as needed.
        train_fraction: fraction of index entries assigned to the training split.
        seed: RNG seed that makes the train/test split reproducible.
    """
    index_file = os.path.join(data_dir, 'index', 'INDEX_refined_data.2020')
    if not os.path.exists(index_file):
        logger.error(f"Index file not found at {index_file}")
        return

    index_df = _read_pdbbind_index(index_file)
    train_ids = _sample_train_ids(index_df['pdb_id'], train_fraction, seed)

    for split in ('train', 'test'):
        os.makedirs(os.path.join(output_dir, split), exist_ok=True)

    train_pairs = []
    test_pairs = []

    for _, row in tqdm(index_df.iterrows(), total=len(index_df), desc="Processing PDBbind"):
        pdb_id = str(row['pdb_id'])
        kd = row['kd_value']
        delta_g = compute_delta_g(kd) if not np.isnan(kd) else np.nan

        pdb_dir = os.path.join(data_dir, pdb_id)
        if not os.path.exists(pdb_dir):
            logger.warning(f"Directory not found for {pdb_id}: {pdb_dir}")
            continue

        try:
            featurized = _featurize_complex(pdb_id, pdb_dir)
            if featurized is None:
                continue
            receptor, ligand, ligand_mol = featurized

            split = 'train' if pdb_id in train_ids else 'test'
            smiles = Chem.MolToSmiles(ligand_mol) if ligand_mol else None
            _save_surface(receptor, pdb_id, delta_g,
                          os.path.join(output_dir, split, f"{pdb_id}_receptor.pt"))
            _save_surface(ligand, pdb_id, delta_g,
                          os.path.join(output_dir, split, f"{pdb_id}_ligand.pt"),
                          smiles=smiles)

            pair_entry = {'pdb_id': pdb_id, 'ligand_smiles': smiles, 'delta_g': delta_g}
            (train_pairs if split == 'train' else test_pairs).append(pair_entry)
        except Exception as e:
            # logger.exception also records the traceback for debugging.
            logger.exception(f"Error processing {pdb_id}: {e}")
            continue

    pd.DataFrame(train_pairs).to_csv(os.path.join(output_dir, 'train_pairs.csv'), index=False)
    pd.DataFrame(test_pairs).to_csv(os.path.join(output_dir, 'test_pairs.csv'), index=False)

if __name__ == "__main__":
    # Paths come from the command line instead of hard-coded placeholders.
    # The defaults ("") keep the previous behaviour when no options are given.
    import argparse

    cli = argparse.ArgumentParser(description="Convert a PDBbind refined set into surface point-cloud .pt files.")
    cli.add_argument("--data-dir", default="", help="PDBbind refined-set root (contains index/ and one folder per PDB id)")
    cli.add_argument("--output-dir", default="", help="destination for train/, test/ and the pair CSV files")
    args = cli.parse_args()
    process_pdbbind(args.data_dir, args.output_dir)
