import os
import pandas as pd
import numpy as np
import torch
from Bio.PDB import PDBParser
from scipy.spatial import cKDTree
import warnings
import tqdm
from pathlib import Path

# Suppress every warning for the whole process.
# NOTE(review): this also hides numerical RuntimeWarnings (e.g. from np.cov in
# compute_curvature); consider narrowing the filter while debugging.
warnings.filterwarnings("ignore")

# ---- Constants ---------------------------------------------------------------
NUM_POINTS = 5000            # surface points sampled for each side (ligand / receptor)
PROBE_RADIUS = 1.4           # probe radius (Å)
DISTANCE_THRESHOLD = 5.0     # interface-label distance cutoff (Å)
R = 8.314 / 4184             # gas constant: 8.314 J/(mol·K) / 4184 J/kcal -> kcal/(mol·K)
T = 298.15                   # temperature (K)

# ---- Data locations (left empty here; set before running) --------------------
PDB_DIR = ""      # directory holding "<PDB>.pdb" files
EXCEL_FILE = ""   # spreadsheet read by process_skempi_v2()
OUTPUT_DIR = ""   # root directory for the generated .pt files
TRAIN_DIR = os.path.join(OUTPUT_DIR, "train_ppi")
TEST_DIR = os.path.join(OUTPUT_DIR, "test_ppi")

# The split directories are created at import time; with an empty OUTPUT_DIR
# they are created relative to the current working directory.
for _split_dir in (TRAIN_DIR, TEST_DIR):
    os.makedirs(_split_dir, exist_ok=True)
del _split_dir

def load_pdb(file_path, ligand_chains, receptor_chains):
    """Load atom coordinates and element codes for the requested chains of a PDB file.

    Only the first model of the structure is read. Iterating every model of a
    multi-model file (e.g. an NMR ensemble) would stack all conformers on top of
    each other. Water residues are skipped so that their oxygens do not become
    part of the protein surface. Atoms whose element is not one of
    C, H, O, N, S, SE are ignored.

    Args:
        file_path: path to the PDB file.
        ligand_chains: container of chain IDs for the ligand side (tested with ``in``).
        receptor_chains: container of chain IDs for the receptor side. A chain
            listed in both containers is assigned to the ligand side.

    Returns:
        ``((ligand_xyz, ligand_types), (receptor_xyz, receptor_types), unique_types)``:

        - ``*_xyz`` are float32 arrays of shape ``(n, 3)``; ``n`` may be 0 when no
          requested chain matched.
        - ``*_types`` are int32 element codes (C=0, H=1, O=2, N=3, S=4, SE=5).
        - ``unique_types`` is ``[0, 1, 2, 3, 4, 5]``.

        On any error while parsing, the error is printed and
        ``(None, None, None)`` is returned.
    """
    ele2num = {"C": 0, "H": 1, "O": 2, "N": 3, "S": 4, "SE": 5}
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", file_path)
        atoms_ligand, atom_types_ligand = [], []
        atoms_receptor, atom_types_receptor = [], []

        # First model only (None for a structure without models).
        model = next(iter(structure), None)
        for chain in (model if model is not None else ()):
            if chain.id in ligand_chains:
                coords, types = atoms_ligand, atom_types_ligand
            elif chain.id in receptor_chains:
                coords, types = atoms_receptor, atom_types_receptor
            else:
                continue
            for residue in chain:
                # residue.id[0] is Biopython's hetero flag; "W" marks water.
                # Other hetero residues (e.g. selenomethionine) are kept.
                if residue.id[0] == "W":
                    continue
                for atom in residue:
                    code = ele2num.get(atom.element)
                    if code is not None:
                        coords.append(atom.get_coord())
                        types.append(code)

        unique_types = sorted(set(ele2num.values()))
        # reshape(-1, 3) keeps empty selections 2-D, so callers can rely on (n, 3).
        return (
            (np.array(atoms_ligand, dtype=np.float32).reshape(-1, 3),
             np.array(atom_types_ligand, dtype=np.int32)),
            (np.array(atoms_receptor, dtype=np.float32).reshape(-1, 3),
             np.array(atom_types_receptor, dtype=np.int32)),
            unique_types,
        )
    except Exception as e:
        # Best effort: report and let the caller skip this structure.
        print(f"错误：无法解析 PDB 文件 {file_path}: {e}")
        return None, None, None

def farthest_point_sampling(points, num_points, rng=None):
    """Select up to ``num_points`` well-spread indices of ``points`` by farthest point sampling.

    Starting from a random point, each step adds the point whose distance to its
    nearest already-selected point is largest.

    Args:
        points: array-like of shape (n, d).
        num_points: number of indices to select.
        rng: optional ``numpy.random.Generator`` for a reproducible start point.
            When ``None``, the global ``np.random`` state is used, as before.

    Returns:
        Integer index array. It is ``arange(n)`` when ``n <= num_points`` and
        empty when ``num_points <= 0``. Otherwise it holds ``num_points``
        indices in selection order.
    """
    points = np.asarray(points)
    n_points = len(points)
    if n_points <= num_points:
        return np.arange(n_points)
    if num_points <= 0:
        return np.zeros(0, dtype=np.int64)

    start = np.random.randint(n_points) if rng is None else int(rng.integers(n_points))
    selected = np.empty(num_points, dtype=np.int64)
    selected[0] = start

    # min_dist[j] = distance from point j to its closest selected point.
    min_dist = np.linalg.norm(points - points[start], axis=1)
    for step in range(1, num_points):
        farthest_idx = int(np.argmax(min_dist))
        selected[step] = farthest_idx
        min_dist = np.minimum(min_dist, np.linalg.norm(points - points[farthest_idx], axis=1))
    return selected

def interpolate_features(points, features, target_num_points, k=3):
    """Pad ``features`` to ``target_num_points`` rows by KNN-averaging around jittered points.

    Each extra row is made in two steps. First, an existing point is picked at
    random and jittered with Gaussian noise of scale ``0.1 * points.std()``.
    Second, the feature rows of its ``k`` nearest original points are averaged.
    When fewer than ``k`` points exist, all of them are used.

    Note that only the features are returned; the synthetic point positions are
    discarded.

    Args:
        points: array of shape (n, 3).
        features: array whose first axis has length n.
        target_num_points: desired number of feature rows.
        k: number of nearest neighbours to average (>= 1).

    Returns:
        ``features`` unchanged when ``n >= target_num_points``. Otherwise, the
        array of ``target_num_points`` rows (original rows first).

    Raises:
        ValueError: if interpolation is needed but ``points`` is empty, or ``k < 1``.
    """
    points = np.asarray(points, dtype=np.float64)
    features = np.asarray(features)
    current_num_points = points.shape[0]
    if current_num_points >= target_num_points:
        return features
    if current_num_points == 0:
        raise ValueError("cannot interpolate features from an empty point set")
    if k < 1:
        raise ValueError("k must be >= 1")

    n_new = target_num_points - current_num_points
    kdtree = cKDTree(points)
    new_points = points[np.random.choice(current_num_points, n_new)]
    new_points += np.random.randn(n_new, 3) * points.std() * 0.1

    # Asking for more neighbours than points yields out-of-range indices, so cap k.
    k_eff = min(k, current_num_points)
    _, indices = kdtree.query(new_points, k=k_eff)
    # With k_eff == 1 the query returns a 1-D array; force (n_new, k_eff) so the
    # mean below averages over neighbours, not over feature columns.
    indices = np.reshape(indices, (n_new, k_eff))
    new_features = features[indices].mean(axis=1)
    return np.concatenate([features, new_features], axis=0)

def compute_molecular_surface(atoms, probe_radius=1.4, num_points=5000):
    """Approximate a molecular surface as a sampled point cloud with unit normals.

    The atom centres get normals estimated by Open3D and are meshed with the
    ball-pivoting algorithm (ball radii ``probe_radius`` and ``2 * probe_radius``).
    The mesh is then sampled uniformly. This meshes atom *centres*; it is not a
    solvent-excluded surface in the strict sense.
    NOTE(review): the estimated normals are not explicitly oriented outward.

    Args:
        atoms: array-like of atom coordinates, reshaped to (n, 3).
        probe_radius: radius (Å) used for normal estimation and ball pivoting.
        num_points: number of surface points to return.

    Returns:
        ``(points, normals)``, both float32 arrays of shape (num_points, 3), with
        normals scaled to unit length. Returns ``(None, None)`` when no surface
        can be produced: fewer than 3 atoms, Open3D unavailable, empty mesh, or
        any other error (the reason is printed).
    """
    try:
        # Open3D's Vector3dVector is documented for float64 (n, 3) arrays.
        atoms = np.asarray(atoms, dtype=np.float64).reshape(-1, 3)
        if len(atoms) < 3:
            # Ball pivoting needs at least one triangle's worth of points.
            print(f"错误：无法生成分子表面: 原子数不足 ({len(atoms)})")
            return None, None

        import open3d as o3d
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(atoms)
        pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=probe_radius * 2, max_nn=30))
        mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
            pcd, o3d.utility.DoubleVector([probe_radius, probe_radius * 2])
        )
        if not mesh.has_triangles():
            print("错误：无法生成分子表面: 网格为空")
            return None, None
        # Sampled points only carry normals if the mesh has vertex normals.
        if not mesh.has_vertex_normals():
            mesh.compute_vertex_normals()

        pcd_surface = mesh.sample_points_uniformly(number_of_points=num_points)
        points = np.asarray(pcd_surface.points, dtype=np.float32)
        normals = np.asarray(pcd_surface.normals, dtype=np.float32)
        if len(points) == 0 or normals.shape != points.shape:
            print("错误：无法生成分子表面: 采样点或法向量缺失")
            return None, None

        norms = np.linalg.norm(normals, axis=1, keepdims=True)
        normals = normals / np.maximum(norms, 1e-6)

        # Enforce exactly num_points rows: subsample by FPS, or pad by resampling.
        if len(points) > num_points:
            indices = farthest_point_sampling(points, num_points)
            points = points[indices]
            normals = normals[indices]
        elif len(points) < num_points:
            indices = np.random.choice(len(points), num_points - len(points))
            points = np.concatenate([points, points[indices]], axis=0)
            normals = np.concatenate([normals, normals[indices]], axis=0)
        return points, normals
    except Exception as e:
        # Best effort: the caller treats (None, None) as "skip this side".
        print(f"错误：无法生成分子表面: {e}")
        return None, None

def compute_curvature(points, normals, k=10):
    """计算点云的平均曲率和高斯曲率"""
    if np.any(np.isnan(points)) or np.any(np.isinf(points)):
        print("警告：点云包含 NaN 或 Inf，替换为零")
        points = np.nan_to_num(points, nan=0.0, posinf=0.0, neginf=0.0)
    if np.any(np.isnan(normals)) or np.any(np.isinf(normals)):
        print("警告：法向量包含 NaN 或 Inf，替换为零")
        normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0)
    
    kdtree = cKDTree(points)
    curvatures = np.zeros((len(points), 2), dtype=np.float32)
    for i in range(len(points)):
        indices = kdtree.query_ball_point(points[i], r=5.0)
        if len(indices) > 1:
            indices = indices[1:]
            neighbors = points[indices]
            neighbor_normals = normals[indices]
            normal_diff = neighbor_normals - normals[i]
            mean_curvature = np.mean(np.linalg.norm(normal_diff, axis=1))
            cov = np.cov(neighbors.T)
            if np.any(np.isnan(cov)) or np.any(np.isinf(cov)):
                print(f"警告：点 {i} 的协方差矩阵包含 NaN 或 Inf，设置曲率为零")
                curvatures[i] = [0.0, 0.0]
                continue
            eigenvalues = np.linalg.eigvals(cov).real
            if np.any(np.isnan(eigenvalues)) or np.any(np.isinf(eigenvalues)):
                print(f"警告：点 {i} 的特征值包含 NaN 或 Inf，设置曲率为零")
                curvatures[i] = [0.0, 0.0]
                continue
            eigen_sum = np.sum(eigenvalues[:2])
            gaussian_curvature = np.prod(eigenvalues[:2]) / eigen_sum if eigen_sum > 1e-6 else 0.0
            curvatures[i] = [mean_curvature, gaussian_curvature]
        else:
            print(f"警告：点 {i} 的邻域点不足，设置曲率为零")
            curvatures[i] = [0.0, 0.0]
    return curvatures

def compute_chemical_features(atoms, atom_types, unique_types, surface_points, radius=5.0):
    """Compute, for each surface point, the element composition of nearby atoms.

    For every surface point, the atoms within ``radius`` are counted per element
    code. Only codes present in ``unique_types`` (and inside the fixed 6-column
    layout) are counted. Each row is then divided by its total, with a
    minimum divisor of 1.

    Args:
        atoms: (n, 3) atom coordinates.
        atom_types: length-n integer element codes (C=0, H=1, O=2, N=3, S=4, SE=5).
        unique_types: element codes to count.
        surface_points: (m, 3) query points.
        radius: neighbourhood radius, in the units of the coordinates.

    Returns:
        float32 array of shape (m, 6) with columns (C, H, O, N, S, SE). Rows with
        no counted neighbour are all zero. Returns ``None`` if an error occurs
        (the error is printed).
    """
    num_types = 6  # C, H, O, N, S, SE - fixed width keeps the feature layout stable
    try:
        atom_types = np.asarray(atom_types)
        kdtree = cKDTree(atoms)
        atom_features = np.zeros((len(surface_points), num_types), dtype=np.float32)

        allowed = [t for t in unique_types if 0 <= t < num_types]
        counted = np.isin(atom_types, allowed)  # per-atom mask of element codes to count

        # One batched query returns the neighbour list of every surface point.
        for i, neighbor_idx in enumerate(kdtree.query_ball_point(surface_points, radius)):
            if not neighbor_idx:
                continue
            neighbor_idx = np.asarray(neighbor_idx)
            types = atom_types[neighbor_idx][counted[neighbor_idx]]
            counts = np.bincount(types, minlength=num_types)
            atom_features[i] = counts / max(1, counts.sum())
        return atom_features
    except Exception as e:
        print(f"错误：无法计算化学特征: {e}")
        return None

def normalize_chemical_features(atom_features):
    """Derive three bounded chemistry channels from per-element features.

    Input columns are (C, H, O, N, S, SE). Let ``polar = O + N`` and
    ``apolar = C + S + SE`` (row sums). Then:

    * pb = polar - apolar, clipped to [-3, 3] and mapped linearly onto [-1, 1];
    * hbond = polar, clipped to [0, 1];
    * hphob = apolar / 4.5, clipped to [0, 1].

    NOTE(review): when rows are fractions summing to <= 1 (as produced by
    compute_chemical_features), pb only spans [-1/3, 1/3] and hphob only
    [0, 0.22]. Confirm the 3.0 / 4.5 scales are intended.

    Returns:
        Array of shape (n, 3) with columns (pb, hbond, hphob).
    """
    polar = np.sum(atom_features[:, [2, 3]], axis=1)        # O, N
    apolar = np.sum(atom_features[:, [0, 4, 5]], axis=1)    # C, S, SE

    lo, hi = -3.0, 3.0
    pb_unit = (np.clip(polar - apolar, lo, hi) - lo) / (hi - lo)   # -> [0, 1]
    pb = 2 * pb_unit - 1                                           # -> [-1, 1]
    hbond = np.clip(polar, 0.0, 1.0)
    hphob = np.clip(apolar / 4.5, 0.0, 1.0)
    return np.column_stack((pb, hbond, hphob))

def compute_iface_labels(points1, points2, distance_threshold=2.0):
    """Flag each point of ``points1`` that lies close to ``points2``.

    Returns:
        float32 array with one entry per row of ``points1``. Each entry is 1.0
        when the nearest point of ``points2`` is strictly closer than
        ``distance_threshold``, else 0.0.
    """
    nearest_distance = cKDTree(points2).query(points1, k=1)[0]
    is_interface = nearest_distance < distance_threshold
    return is_interface.astype(np.float32)

def compute_delta_g(kd):
    """Convert a dissociation constant to a binding free energy, ΔG = R·T·ln(KD).

    Args:
        kd: dissociation constant in mol/L. Numeric strings are accepted,
            because spreadsheet cells are not always typed as numbers.

    Returns:
        ΔG in kcal/mol, computed with the module constants ``R`` (kcal/mol/K)
        and ``T`` (K). Returns ``nan`` when ``kd`` is missing, non-numeric,
        non-positive or non-finite.
    """
    try:
        kd = float(kd)
    except (TypeError, ValueError):
        return np.nan
    if not np.isfinite(kd) or kd <= 0:
        return np.nan
    return R * T * np.log(kd)

def _parse_chains(value):
    """Split a comma-separated chain-ID cell into stripped, non-empty IDs.

    Non-string cells (e.g. NaN for an empty spreadsheet cell) yield ``[]``.
    """
    if not isinstance(value, str):
        return []
    return [chain_id.strip() for chain_id in value.split(",") if chain_id.strip()]


def _process_role(pdb_id, role, atoms, atom_types, partner_atoms, unique_types, delta_g, output_dir):
    """Build surface points, features and interface labels for one side, then save a ``.pt`` file."""
    points, normals = compute_molecular_surface(atoms, probe_radius=PROBE_RADIUS, num_points=NUM_POINTS)
    if points is None:
        tqdm.tqdm.write(f"警告: PDB {pdb_id} {role} wt 点云生成失败")
        return

    atom_features = compute_chemical_features(atoms, atom_types, unique_types, points, radius=5.0)
    if atom_features is None:
        return
    chem_features = normalize_chemical_features(atom_features)
    curvature_features = compute_curvature(points, normals, k=10)

    # Feature layout: 3 (pb, hbond, hphob) + 6 (atom-type fractions) + 2 (curvature) = 11.
    features = np.concatenate([chem_features, atom_features, curvature_features], axis=1)

    # 1.0 where an atom of the partner side lies closer than DISTANCE_THRESHOLD.
    iface_labels = compute_iface_labels(points, partner_atoms, distance_threshold=DISTANCE_THRESHOLD)

    output_pt = os.path.join(output_dir, f"{pdb_id}_{role}_wt.pt")
    torch.save({
        'points': torch.tensor(points, dtype=torch.float32),
        'normals': torch.tensor(normals, dtype=torch.float32),
        'features': torch.tensor(features, dtype=torch.float32),
        'iface_labels': torch.tensor(iface_labels, dtype=torch.float32),
        'delta_g': torch.tensor(delta_g, dtype=torch.float32),
        'protein_id': f"{pdb_id}_{role}_wt",
        'probe_radius': PROBE_RADIUS
    }, output_pt)
    tqdm.tqdm.write(f"保存 {role} wt .pt 文件: {output_pt}")


def _process_wildtype_entry(row, output_dir):
    """Load one complex's wild-type structure and write its ligand and receptor ``.pt`` files.

    Problems (missing file, unparsable PDB, no atoms for a side) are reported and
    the entry is skipped. Messages go through ``tqdm.tqdm.write`` so they do not
    break the progress bars.
    """
    pdb_id = row['PDB']
    ligand_chains = _parse_chains(row['Ligand Chains'])
    receptor_chains = _parse_chains(row['Receptor Chains'])

    pdb_file = os.path.join(PDB_DIR, f"{pdb_id}.pdb")
    if not os.path.exists(pdb_file):
        tqdm.tqdm.write(f"警告: PDB 文件 {pdb_file} 不存在")
        return

    # load_pdb signals a parse failure with (None, None, None), so check before
    # unpacking the (coords, types) pairs. Zero-length coordinates mean that no
    # requested chain matched.
    ligand, receptor, unique_types = load_pdb(pdb_file, ligand_chains, receptor_chains)
    if ligand is None or receptor is None or len(ligand[0]) == 0 or len(receptor[0]) == 0:
        tqdm.tqdm.write(f"警告: PDB {pdb_id} 链数据为空")
        return
    atoms_ligand, atom_types_ligand = ligand
    atoms_receptor, atom_types_receptor = receptor

    # ΔG depends only on the row, so compute it once for both sides.
    delta_g = compute_delta_g(row['KD(M)'])

    with tqdm.tqdm(total=2, desc=f"处理 {pdb_id} wt", leave=False, unit="chain") as sub_pbar:
        for role, atoms, atom_types, partner_atoms in (
            ('ligand', atoms_ligand, atom_types_ligand, atoms_receptor),
            ('receptor', atoms_receptor, atom_types_receptor, atoms_ligand),
        ):
            _process_role(pdb_id, role, atoms, atom_types, partner_atoms, unique_types, delta_g, output_dir)
            sub_pbar.update(1)


def process_skempi_v2():
    """Convert the SKEMPI v2.0 spreadsheet into wild-type point-cloud ``.pt`` files.

    Output files are named by PDB id, so the spreadsheet is first reduced to one
    row per PDB (the first occurrence). The 80/20 train/test split
    (``random_state=42``) is then drawn over complexes, not rows. A row-wise split
    would let the same complex land in both splits (train/test leakage) and would
    recompute the same structure once per row.
    """
    df = pd.read_excel(EXCEL_FILE)
    complexes = df.dropna(subset=['PDB']).drop_duplicates(subset='PDB', keep='first')

    train_df = complexes.sample(frac=0.8, random_state=42)
    test_df = complexes.drop(train_df.index)

    with tqdm.tqdm(total=len(complexes), desc="处理 SKEMPI v2.0 数据集", unit="PDB") as pbar:
        for split, split_df, output_dir in (
            ('train', train_df, TRAIN_DIR),
            ('test', test_df, TEST_DIR),
        ):
            pbar.set_description(f"处理 {split} 集")
            for _, row in split_df.iterrows():
                _process_wildtype_entry(row, output_dir)
                pbar.update(1)

# Run the full processing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    process_skempi_v2()