import os
import glob
import numpy as np
import torch
from Bio.PDB import PDBParser
import open3d as o3d
from scipy.spatial import cKDTree
from plyfile import PlyData
import warnings
import tqdm
from pathlib import Path
warnings.filterwarnings("ignore")

def load_pdb(file_path):
    """Load atom coordinates and element codes from a PDB file.

    Every model, chain and residue in the file is visited. Only atoms whose
    element is C, H, O, N, S or SE are kept; all other atoms are skipped.

    Args:
        file_path: Path of the PDB file to parse.

    Returns:
        A tuple ``(atoms, atom_types, unique_types)``:
        ``atoms`` is a float32 array holding the coordinates of the kept
        atoms, ``atom_types`` is an int32 array with the element code of each
        kept atom (C=0, H=1, O=2, N=3, S=4, SE=5) in the same order, and
        ``unique_types`` is the sorted list of codes that actually occur.
        If anything goes wrong an error message is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    # Element symbol -> integer code used by the downstream feature columns.
    ele2num = {"C": 0, "H": 1, "O": 2, "N": 3, "S": 4, "SE": 5}
    try:
        parser = PDBParser()
        structure = parser.get_structure("protein", file_path)
        atoms = []
        atom_types = []
        for model in structure:
            for chain in model:
                for residue in chain:
                    for atom in residue:
                        code = ele2num.get(atom.element)
                        if code is None:
                            # Element outside the supported set: ignore it.
                            continue
                        atoms.append(atom.get_coord())
                        atom_types.append(code)
        unique_types = sorted(set(atom_types))
        atoms = np.array(atoms, dtype=np.float32)
        atom_types = np.array(atom_types, dtype=np.int32)
        return atoms, atom_types, unique_types
    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip the file.
        print(f"错误：无法解析 PDB 文件 {file_path}: {e}")
        return None, None, None

def farthest_point_sampling(points, num_points):
    """Select point indices with farthest point sampling (NumPy implementation).

    The first index is drawn at random with ``np.random.randint``; each
    following index is the point whose distance to the already selected set
    is largest.

    Args:
        points: Array of point coordinates, one point per row.
        num_points: Number of indices to select.

    Returns:
        An integer array of selected row indices. When ``points`` has at most
        ``num_points`` rows, all indices ``0 .. len(points) - 1`` are returned
        in order. When ``num_points`` is zero or negative (and ``points`` is
        larger), an empty index array is returned.
    """
    n_points = len(points)
    if n_points <= num_points:
        return np.arange(n_points)
    if num_points <= 0:
        # Nothing requested: do not hand back the random seed point.
        return np.empty(0, dtype=np.intp)

    points = np.asarray(points)
    selected_indices = np.empty(num_points, dtype=np.intp)

    # Random seed point, then the distance of every point to it in one call.
    selected_indices[0] = np.random.randint(n_points)
    distances = np.linalg.norm(points - points[selected_indices[0]], axis=1)

    for slot in range(1, num_points):
        farthest_idx = np.argmax(distances)
        selected_indices[slot] = farthest_idx
        # Keep, for every point, its distance to the closest selected point.
        np.minimum(
            distances,
            np.linalg.norm(points - points[farthest_idx], axis=1),
            out=distances,
        )

    return selected_indices

def interpolate_features(points, features, target_num_points):
    """Extend ``features`` to ``target_num_points`` rows by KNN interpolation.

    New sample locations are created by copying randomly chosen rows of
    ``points`` and adding Gaussian noise scaled by ``0.1 * points.std()``.
    Each new location receives the mean of the features of its nearest
    existing points (up to three of them).

    Args:
        points: Array of existing point coordinates, one 3-D point per row.
        features: Array with one feature row per existing point.
        target_num_points: Desired number of feature rows.

    Returns:
        ``features`` unchanged when there are already at least
        ``target_num_points`` points; otherwise a new array whose first rows
        are ``features`` and whose remaining rows are the interpolated
        features of the generated locations.

    Raises:
        ValueError: If ``points`` is empty while rows must be generated
            (raised by ``np.random.choice``).
    """
    current_num_points = points.shape[0]
    if current_num_points >= target_num_points:
        return features

    num_new = target_num_points - current_num_points
    kdtree = cKDTree(points)
    new_points = points[np.random.choice(current_num_points, num_new)]
    new_points += np.random.randn(num_new, 3) * points.std() * 0.1

    # Asking for more neighbours than there are points makes the tree return
    # the out-of-range index ``n``; clamp k so tiny clouds work as well.
    k = min(3, current_num_points)
    # With k == 1 the tree returns a 1-D index array; force (num_new, k) so
    # the mean below always averages over neighbours, never over features.
    indices = kdtree.query(new_points, k=k)[1].reshape(num_new, k)
    new_features = np.mean(features[indices], axis=1)
    return np.concatenate([features, new_features], axis=0)

def load_ply(file_path, num_points=5000):
    """Load a surface point cloud with features from a PLY file and resize it.

    The first three fields of every vertex are taken as its coordinates; the
    vertex properties ``nx``, ``ny``, ``nz`` give the normals and ``charge``,
    ``hbond``, ``hphob`` give the three feature columns.

    If the file holds more than ``num_points`` vertices they are reduced with
    ``farthest_point_sampling``. If it holds fewer, extra rows are generated
    with ``interpolate_features``.

    Args:
        file_path: Path of the PLY file.
        num_points: Number of points the returned arrays should contain.

    Returns:
        A tuple ``(points, normals, features)`` of float32 arrays with
        matching rows. On any failure an error message is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    try:
        vertex = PlyData.read(str(file_path))["vertex"]
        points = np.vstack([[v[0], v[1], v[2]] for v in vertex]).astype(np.float32)
        normals = np.stack([vertex["nx"], vertex["ny"], vertex["nz"]]).T.astype(np.float32)
        features = np.stack(
            [vertex["charge"], vertex["hbond"], vertex["hphob"]]
        ).T.astype(np.float32)

        if len(points) > num_points:
            keep = farthest_point_sampling(points, num_points)
            points, normals, features = points[keep], normals[keep], features[keep]
        elif len(points) < num_points:
            # Interpolate coordinates, normals and features in ONE call so
            # every padded row is the average of the same neighbouring
            # vertices. Separate calls draw independent random locations and
            # leave the three arrays describing unrelated places.
            stacked = np.hstack([points, normals, features])
            stacked = interpolate_features(points, stacked, num_points)
            points = np.ascontiguousarray(stacked[:, :3])
            normals = np.ascontiguousarray(stacked[:, 3:6])
            features = np.ascontiguousarray(stacked[:, 6:])
            # NOTE(review): padded normals are neighbour averages and are not
            # re-normalised here - confirm whether unit length is required.
        return points, normals, features
    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip the file.
        print(f"错误：无法解析 PLY 文件 {file_path}: {e}")
        return None, None, None

def compute_molecular_surface(atoms, probe_radius=1.4, num_points=5000):
    """Sample a surface point cloud from atom coordinates using Open3D.

    The atom coordinates are loaded into an Open3D point cloud, normals are
    estimated with a hybrid KD-tree search (radius ``2 * probe_radius``, at
    most 30 neighbours), a triangle mesh is reconstructed by ball pivoting
    with radii ``probe_radius`` and ``2 * probe_radius``, and ``num_points``
    points are sampled uniformly from that mesh.

    Args:
        atoms: Array of atom coordinates, one 3-D point per row.
        probe_radius: Base radius used for normal estimation and ball pivoting.
        num_points: Number of surface points to sample.

    Returns:
        A tuple ``(points, normals)`` of float32 arrays; each normal is
        divided by its length (clamped below at 1e-6). On any failure an
        error message is printed and ``(None, None)`` is returned.
    """
    try:
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(atoms)

        wide_radius = probe_radius * 2
        search = o3d.geometry.KDTreeSearchParamHybrid(radius=wide_radius, max_nn=30)
        cloud.estimate_normals(search_param=search)

        pivot_radii = o3d.utility.DoubleVector([probe_radius, wide_radius])
        mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
            cloud, pivot_radii
        )

        sampled = mesh.sample_points_uniformly(number_of_points=num_points)
        surface_xyz = np.asarray(sampled.points, dtype=np.float32)
        surface_normals = np.asarray(sampled.normals, dtype=np.float32)

        # Scale every normal to unit length; the clamp avoids dividing by zero.
        lengths = np.linalg.norm(surface_normals, axis=1, keepdims=True)
        return surface_xyz, surface_normals / np.maximum(lengths, 1e-6)
    except Exception as e:
        print(f"错误：无法生成分子表面: {e}")
        return None, None

def compute_curvature(points, normals, k=10):
    """Compute two curvature-like descriptors per point (CPU version).

    For every point, all other points within a radius of 5.0 form its
    neighbourhood. Column 0 ("mean curvature") is the mean Euclidean norm of
    the difference between each neighbour normal and the point's own normal.
    Column 1 ("gaussian curvature") is ``prod(e[:2]) / sum(e[:2])`` where
    ``e`` are the eigenvalues of the covariance of the neighbour coordinates,
    or 0 when that sum is not larger than 1e-6.

    NaN/Inf entries in ``points`` or ``normals`` are replaced by zero, with a
    printed warning, before anything is computed.

    Args:
        points: Array of point coordinates, one 3-D point per row.
        normals: Array of normals, one row per point.
        k: Unused; the neighbourhood is defined by the fixed radius only.
            Kept so existing callers that pass it keep working.

    Returns:
        A float32 array of shape ``(len(points), 2)``. Points without any
        neighbour, or whose covariance/eigenvalues are not finite, get
        ``[0, 0]`` and a printed warning.
    """
    if np.any(np.isnan(points)) or np.any(np.isinf(points)):
        print("警告：点云包含 NaN 或 Inf，替换为零")
        points = np.nan_to_num(points, nan=0.0, posinf=0.0, neginf=0.0)
    if np.any(np.isnan(normals)) or np.any(np.isinf(normals)):
        print("警告：法向量包含 NaN 或 Inf，替换为零")
        normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0)

    kdtree = cKDTree(points)
    # [mean_curvature, gaussian_curvature]; rows stay zero unless filled below.
    curvatures = np.zeros((len(points), 2), dtype=np.float32)
    for i in range(len(points)):
        # Single-point ball queries come back unsorted, so the query point is
        # not guaranteed to be the first hit: remove it by index, not position.
        neighbor_idx = [j for j in kdtree.query_ball_point(points[i], r=5.0) if j != i]
        if not neighbor_idx:
            print(f"警告：点 {i} 的邻域点不足，设置曲率为零")
            continue

        normal_diff = normals[neighbor_idx] - normals[i]
        mean_curvature = np.mean(np.linalg.norm(normal_diff, axis=1))

        # With a single neighbour the covariance is NaN; the check below
        # catches that case and leaves the row at zero.
        cov = np.cov(points[neighbor_idx].T)
        if np.any(np.isnan(cov)) or np.any(np.isinf(cov)):
            print(f"警告：点 {i} 的协方差矩阵包含 NaN 或 Inf，设置曲率为零")
            continue

        # NOTE(review): eigvals returns the eigenvalues in no particular
        # order, so "the first two" is not a defined pair - confirm intent.
        eigenvalues = np.linalg.eigvals(cov).real
        if np.any(np.isnan(eigenvalues)) or np.any(np.isinf(eigenvalues)):
            print(f"警告：点 {i} 的特征值包含 NaN 或 Inf，设置曲率为零")
            continue

        eigen_sum = np.sum(eigenvalues[:2])
        gaussian_curvature = np.prod(eigenvalues[:2]) / eigen_sum if eigen_sum > 1e-6 else 0.0
        curvatures[i] = [mean_curvature, gaussian_curvature]
    return curvatures

def compute_chemical_features(atoms, atom_types, unique_types, surface_points, radius=5.0):
    """Describe each surface point by the element mix of the atoms around it.

    For every surface point the atoms within ``radius`` are collected and
    counted per element code; the counts are then divided by their total so
    each non-empty row sums to 1.

    Args:
        atoms: Array of atom coordinates, one 3-D point per row.
        atom_types: Integer element code of every atom (codes 0..5).
        unique_types: Codes that may be counted; atoms whose code is not
            listed here are ignored.
        surface_points: Array of surface point coordinates, one per row.
        radius: Search radius around each surface point.

    Returns:
        A float32 array of shape ``(len(surface_points), 6)`` with one column
        per element code. Rows of points with no counted atom nearby stay
        zero. On any failure (including a code outside 0..5) an error message
        is printed and ``None`` is returned.
    """
    try:
        kdtree = cKDTree(atoms)
        num_types = 6  # fixed columns: C, H, O, N, S, SE
        atom_types = np.asarray(atom_types)
        # Which atoms may be counted at all; computed once instead of testing
        # every neighbour of every surface point.
        countable = np.isin(atom_types, list(unique_types))
        atom_features = np.zeros((len(surface_points), num_types), dtype=np.float32)
        for i, point in enumerate(surface_points):
            indices = kdtree.query_ball_point(point, radius)
            if not indices:
                continue
            nearby_types = atom_types[indices][countable[indices]]
            counts = np.bincount(nearby_types, minlength=num_types)
            total = counts.sum()
            if total > 0:
                # A code >= num_types makes ``counts`` too long and this
                # assignment raise, which the handler below reports.
                atom_features[i] = counts / total
        return atom_features
    except Exception as e:
        print(f"错误：无法计算化学特征: {e}")
        return None

def normalize_chemical_features(features):
    """Rescale the three per-point chemical feature columns.

    Column 0 (named ``pb`` in this code) is clipped to [-3, 3] and mapped
    linearly onto [-1, 1]. Column 1 (``hbond``) is clipped to [0, 1].
    Column 2 (``hphob``) is divided by 4.5 and then clipped to [0, 1].

    Args:
        features: 2-D array with at least three columns, one row per point.

    Returns:
        An array with one row per input row and the three rescaled columns.
    """
    pb_lower, pb_upper = -3.0, 3.0

    # Clip, map onto [0, 1], then stretch to [-1, 1].
    scaled_pb = np.clip(features[:, 0], pb_lower, pb_upper)
    scaled_pb = (scaled_pb - pb_lower) / (pb_upper - pb_lower)
    scaled_pb = 2 * scaled_pb - 1

    scaled_hbond = np.clip(features[:, 1], 0.0, 1.0)
    scaled_hphob = np.clip(features[:, 2] / 4.5, 0.0, 1.0)

    return np.stack([scaled_pb, scaled_hbond, scaled_hphob], axis=1)

def compute_iface_labels(points1, points2, distance_threshold=8.0):
    """Label the points of one cloud that lie close to another cloud (CPU).

    Args:
        points1: Points to label, one 3-D point per row.
        points2: Reference points, one 3-D point per row.
        distance_threshold: A point is labelled 1 when its nearest reference
            point is strictly closer than this value.

    Returns:
        A float32 array with one entry per row of ``points1``: 1.0 for points
        within the threshold, 0.0 otherwise.
    """
    # Distance from every point of the first cloud to its nearest neighbour
    # in the second cloud.
    nearest_dist = cKDTree(points2).query(points1, k=1)[0]
    return (nearest_dist < distance_threshold).astype(np.float32)

def _transfer_ply_features(ply_points, ply_features, surface_points):
    """Give every surface point the feature row of its nearest PLY vertex."""
    # load_ply may pad small surfaces up to num_points. Keep only the first
    # occurrence of any repeated coordinate so a padded copy can never win a
    # nearest-neighbour tie against the original vertex.
    _, first_seen = np.unique(ply_points, axis=0, return_index=True)
    first_seen.sort()
    nearest = cKDTree(ply_points[first_seen]).query(surface_points, k=1)[1]
    return ply_features[first_seen][nearest]


def generate_dmasif_point_cloud(file_path, file_type="pdb", pdb_dir=None, num_points=5000, probe_radius=1.4):
    """Build a dMaSIF-style surface point cloud with per-point features.

    Two input modes are supported:

    * ``file_type="pdb"``: the surface is sampled from the atoms of
      ``file_path``. Chemical features come from a PLY file with the same
      base name, looked up in ``pdb_dir`` with ``01-benchmark_pdbs`` replaced
      by ``01-benchmark_surfaces``; each surface point takes the normalised
      features of its nearest PLY vertex. Without a usable PLY file the three
      chemical columns are zero.
    * any other ``file_type``: points, normals and chemical features are read
      from the PLY file ``file_path``; the atoms come from the PDB file with
      the same base name inside ``pdb_dir``.

    Args:
        file_path: Path of the input PDB or PLY file.
        file_type: ``"pdb"`` (case-insensitive) or anything else for PLY.
        pdb_dir: Directory of the PDB files. Defaults to the directory of
            ``file_path`` when omitted.
        num_points: Number of surface points to produce.
        probe_radius: Radius forwarded to ``compute_molecular_surface``.

    Returns:
        A dict with the tensors ``points``, ``normals`` and ``features``
        (3 chemical + 6 atom-type + 2 curvature columns) plus ``protein_id``
        and ``probe_radius``; or ``None`` when a required file is missing or
        a processing step fails.
    """
    # Normalise once so the base name and the branch below cannot disagree.
    file_type = file_type.lower()
    base_name = os.path.splitext(os.path.basename(file_path))[0]
    if pdb_dir is None:
        pdb_dir = os.path.dirname(file_path)

    if file_type == "pdb":
        atoms, atom_types, unique_types = load_pdb(file_path)
        if atoms is None:
            return None
        surface_points, normals = compute_molecular_surface(atoms, probe_radius, num_points)
        if surface_points is None:
            return None

        # Default: no chemical information available for this surface.
        features = np.zeros((len(surface_points), 3), dtype=np.float32)
        ply_path = os.path.join(pdb_dir.replace("01-benchmark_pdbs", "01-benchmark_surfaces"), f"{base_name}.ply")
        if os.path.exists(ply_path):
            ply_points, _, ply_features = load_ply(ply_path, num_points=num_points)
            if ply_features is not None:
                # The PLY vertices and the sampled surface are different point
                # sets, so features must be matched by position, not by row.
                features = _transfer_ply_features(
                    ply_points, normalize_chemical_features(ply_features), surface_points
                )
        else:
            print(f"警告：未找到对应的 PLY 文件 {ply_path}，使用零特征")
    else:
        surface_points, normals, ply_features = load_ply(file_path, num_points=num_points)
        if surface_points is None:
            return None
        features = normalize_chemical_features(ply_features)
        pdb_path = os.path.join(pdb_dir, f"{base_name}.pdb")
        if not os.path.exists(pdb_path):
            print(f"错误：未找到匹配的 PDB 文件 {pdb_path}")
            return None
        atoms, atom_types, unique_types = load_pdb(pdb_path)
        if atoms is None:
            return None

    # Atom-type composition around every surface point.
    atom_features = compute_chemical_features(atoms, atom_types, unique_types, surface_points, radius=5.0)
    if atom_features is None:
        return None

    # Curvature descriptors.
    curvature_features = compute_curvature(surface_points, normals, k=10)

    # Chemical, atom-type and geometric columns side by side.
    features = np.concatenate([features, atom_features, curvature_features], axis=1)

    return {
        "points": torch.from_numpy(surface_points).float(),  # [num_points, 3]
        "normals": torch.from_numpy(normals).float(),  # [num_points, 3]
        "features": torch.from_numpy(features).float(),  # [num_points, 3 + 6 + 2 = 11]
        "protein_id": base_name,
        "probe_radius": probe_radius
    }

def load_split_files(train_file, test_file):
    """Read the train and test split files.

    Every non-empty line is split on ``_``. Lines with exactly three parts
    (``<protein>_<receptor chain>_<ligand chains>``) become one pair; all
    other lines are ignored.

    Args:
        train_file: Path of the training split file.
        test_file: Path of the test split file.

    Returns:
        A tuple ``(train_pairs, test_pairs)``. Each pair is a dict with the
        keys ``protein_id``, ``receptor`` (``<protein>_<receptor chain>``)
        and ``ligand`` (``<protein>_<ligand chains>``).
    """
    def parse_split_file(file_path):
        with open(file_path, 'r') as handle:
            stripped = (raw.strip() for raw in handle)
            fields = (entry.split('_') for entry in stripped if entry)
            # Only well-formed three-part entries are turned into pairs.
            return [
                {
                    'protein_id': parts[0],
                    'receptor': f"{parts[0]}_{parts[1]}",
                    'ligand': f"{parts[0]}_{parts[2]}"
                }
                for parts in fields
                if len(parts) == 3
            ]

    return parse_split_file(train_file), parse_split_file(test_file)

def load_and_merge_ligand_points(ligand_id, point_cloud_dir):
    """Load the saved point clouds of all ligand chains and stack them.

    ``ligand_id`` is split at its last ``_``: the text before it is the
    protein id and every character after it is treated as one chain. For
    each chain the file ``<protein>_<chain>.pt`` in ``point_cloud_dir`` is
    loaded with ``torch.load`` and its ``points`` tensor is collected.
    Missing files are reported with a warning and skipped.

    Args:
        ligand_id: Identifier of the form ``<protein>_<chains>``.
        point_cloud_dir: Directory containing the ``.pt`` files.

    Returns:
        A NumPy array with the points of all found chains concatenated
        row-wise, or ``None`` when no chain file was found.
    """
    protein_id, _, chains = ligand_id.rpartition('_')

    chain_clouds = []
    for chain in chains:
        pt_file = os.path.join(point_cloud_dir, f"{protein_id}_{chain}.pt")
        if not os.path.exists(pt_file):
            print(f"警告：未找到点云文件 {pt_file}")
            continue
        chain_clouds.append(torch.load(pt_file)['points'].numpy())

    return np.concatenate(chain_clouds, axis=0) if chain_clouds else None

def _load_or_generate_cloud(pt_file, pdb_file, pdb_dir, num_points, probe_radius, role):
    """Return the cached point cloud in ``pt_file`` or build it from ``pdb_file``."""
    if os.path.exists(pt_file):
        # NOTE: torch.load unpickles the file; only point it at files written
        # by this pipeline.
        return torch.load(pt_file)
    if not os.path.exists(pdb_file):
        print(f"错误：未找到{role} PDB 文件 {pdb_file}")
        return None
    return generate_dmasif_point_cloud(pdb_file, file_type="pdb", pdb_dir=pdb_dir,
                                       num_points=num_points, probe_radius=probe_radius)


def generate_and_save_ppi_data(pairs, pdb_dir, ply_dir, output_dir, num_points=5000, probe_radius=1.4, distance_threshold=5.0):
    """Create point clouds and interface labels for protein pairs and save them.

    For every pair the receptor cloud and one cloud per ligand chain are
    loaded from ``output_dir`` if already saved there, otherwise generated
    from the matching PDB file in ``pdb_dir``. Receptor points are labelled
    against the union of all available ligand chains, each ligand chain is
    labelled against the receptor, and every cloud is written to
    ``output_dir`` as ``<id>.pt`` with an added ``iface_labels`` tensor.

    Pairs whose receptor cannot be obtained, or for which no ligand chain is
    available, are skipped; individual missing ligand chains are skipped too.

    Args:
        pairs: Iterable of dicts with the keys ``protein_id``, ``receptor``
            and ``ligand`` (the last ``_``-separated part of ``ligand`` lists
            the ligand chains, one character each).
        pdb_dir: Directory containing ``<id>.pdb`` files.
        ply_dir: Currently unused; kept for interface compatibility.
        output_dir: Directory for the ``.pt`` files; created if missing.
        num_points: Number of surface points per cloud.
        probe_radius: Probe radius forwarded to the surface generation.
        distance_threshold: Distance below which a point counts as
            interface. Used for both receptor and ligand labels.

    Returns:
        None. Results are written to ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for pair in tqdm.tqdm(pairs, desc="处理蛋白质对"):
        receptor_id = pair['receptor']
        ligand_id = pair['ligand']
        protein_id = pair['protein_id']

        # Receptor: cached cloud or freshly generated one.
        receptor_file = os.path.join(output_dir, f"{receptor_id}.pt")
        receptor_data = _load_or_generate_cloud(
            receptor_file, os.path.join(pdb_dir, f"{receptor_id}.pdb"),
            pdb_dir, num_points, probe_radius, "受体",
        )
        if receptor_data is None:
            continue

        # Ligand: one cloud per chain character.
        ligand_data_list = []
        for chain in ligand_id.split('_')[-1]:
            data = _load_or_generate_cloud(
                os.path.join(output_dir, f"{protein_id}_{chain}.pt"),
                os.path.join(pdb_dir, f"{protein_id}_{chain}.pdb"),
                pdb_dir, num_points, probe_radius, "配体",
            )
            if data is not None:
                ligand_data_list.append(data)

        if not ligand_data_list:
            print(f"错误：无法加载配体点云 {ligand_id}")
            continue
        ligand_points = np.concatenate(
            [data['points'].numpy() for data in ligand_data_list], axis=0
        )
        receptor_points = receptor_data['points'].numpy()

        # Receptor labels against all ligand chains together. The caller's
        # distance_threshold is honoured instead of a fixed constant.
        receptor_iface_labels = compute_iface_labels(
            receptor_points, ligand_points, distance_threshold=distance_threshold
        )
        receptor_data['iface_labels'] = torch.from_numpy(receptor_iface_labels).float()

        # NOTE(review): a receptor shared by several pairs is overwritten with
        # the labels of the most recent pair - confirm that this is intended.
        torch.save(receptor_data, receptor_file)
        print(f"已保存受体数据: {receptor_file}")

        # Ligand labels, chain by chain, against the receptor.
        for data in ligand_data_list:
            ligand_iface_labels = compute_iface_labels(
                data['points'].numpy(), receptor_points, distance_threshold=distance_threshold
            )
            data['iface_labels'] = torch.from_numpy(ligand_iface_labels).float()
            ligand_file = os.path.join(output_dir, f"{data['protein_id']}.pt")
            torch.save(data, ligand_file)
            print(f"已保存配体数据: {ligand_file}")

def batch_process_files(pdb_dir, ply_dir, train_output_dir, test_output_dir, train_file, test_file, num_points=5000, probe_radius=1.4, distance_threshold=2.0):
    """Process the training and the test split and store them separately.

    The two split files are read with ``load_split_files``; the sizes of both
    sets are printed, and ``generate_and_save_ppi_data`` is then run once for
    the training pairs and once for the test pairs.

    Args:
        pdb_dir: Directory containing the PDB files.
        ply_dir: Directory containing the PLY files.
        train_output_dir: Output directory for the training set.
        test_output_dir: Output directory for the test set.
        train_file: Path of the training split file.
        test_file: Path of the test split file.
        num_points: Forwarded to ``generate_and_save_ppi_data``.
        probe_radius: Forwarded to ``generate_and_save_ppi_data``.
        distance_threshold: Forwarded to ``generate_and_save_ppi_data``.

    Returns:
        None.
    """
    train_pairs, test_pairs = load_split_files(train_file, test_file)
    print(f"训练集: {len(train_pairs)} 个蛋白质对")
    print(f"测试集: {len(test_pairs)} 个蛋白质对")

    # (progress message, pairs, destination) for each split, in run order.
    jobs = (
        ("处理训练集...", train_pairs, train_output_dir),
        ("处理测试集...", test_pairs, test_output_dir),
    )
    for banner, split_pairs, split_output_dir in jobs:
        print(banner)
        generate_and_save_ppi_data(
            split_pairs, pdb_dir, ply_dir, split_output_dir,
            num_points, probe_radius, distance_threshold,
        )

if __name__ == "__main__":
    # Input/output locations; all of them are empty placeholders here.
    locations = {
        "pdb_dir": "",
        "ply_dir": "",
        "train_output_dir": "",
        "test_output_dir": "",
        "train_file": "",
        "test_file": "",
    }

    batch_process_files(
        **locations,
        num_points=5000,
        probe_radius=1.4,
        distance_threshold=2.0,
    )