import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import knn_graph
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from supervised_fine_tuning import SE3EquivariantConv, PPIPredictor
from scipy.spatial.distance import cdist
import glob
from Bio.PDB import PDBParser
import open3d as o3d
from scipy.spatial import cKDTree
from plyfile import PlyData
import warnings
from pyrosetta import *
from pyrosetta.rosetta import *
from pyrosetta.teaching import *

# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Initialize PyRosetta at import time with the option string below. If
# initialization raises, report the error and terminate with exit status 1,
# because the rest of this module calls PyRosetta functions.
try:
    pyrosetta.init("-mute all -ignore_unrecognized_res")
    print("PyRosetta initialized successfully.")
except Exception as e:
    print(f"PyRosetta initialization failed: {e}")
    exit(1)

# Chemical property table for the 20 standard amino acids, keyed by
# one-letter code. select_likely_amino_acid() scores every entry against the
# averaged ligand surface features and picks the closest one.
# Fields per residue:
#   charge         - values in this table lie in [0, 1] (0.0 for D/E, 1.0 for R/K)
#   hydrophobicity - values in this table lie in [0, 1]
#   hbond          - 1.0 or 0.0 flag
#   oxygen         - 1.0 or 0.0 flag
#   nitrogen       - 1.0 or 0.0 flag
#   volume         - ranges from 48 (G) to 163 (W) in this table
AMINO_ACID_PROPERTIES = {
    'A': {'charge': 0.5, 'hydrophobicity': 0.616, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 67},
    'R': {'charge': 1.0, 'hydrophobicity': 0.000, 'hbond': 1.0, 'oxygen': 0.0, 'nitrogen': 1.0, 'volume': 148},
    'N': {'charge': 0.5, 'hydrophobicity': 0.085, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 1.0, 'volume': 96},
    'D': {'charge': 0.0, 'hydrophobicity': 0.028, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 0.0, 'volume': 91},
    'C': {'charge': 0.5, 'hydrophobicity': 0.680, 'hbond': 1.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 86},
    'E': {'charge': 0.0, 'hydrophobicity': 0.057, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 0.0, 'volume': 109},
    'Q': {'charge': 0.5, 'hydrophobicity': 0.085, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 1.0, 'volume': 114},
    'G': {'charge': 0.5, 'hydrophobicity': 0.501, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 48},
    'H': {'charge': 0.55, 'hydrophobicity': 0.165, 'hbond': 1.0, 'oxygen': 0.0, 'nitrogen': 1.0, 'volume': 118},
    'I': {'charge': 0.5, 'hydrophobicity': 0.943, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 124},
    'L': {'charge': 0.5, 'hydrophobicity': 0.943, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 124},
    'K': {'charge': 1.0, 'hydrophobicity': 0.283, 'hbond': 1.0, 'oxygen': 0.0, 'nitrogen': 1.0, 'volume': 135},
    'M': {'charge': 0.5, 'hydrophobicity': 0.648, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 124},
    'F': {'charge': 0.5, 'hydrophobicity': 0.876, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 135},
    'P': {'charge': 0.5, 'hydrophobicity': 0.159, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 90},
    'S': {'charge': 0.5, 'hydrophobicity': 0.359, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 0.0, 'volume': 73},
    'T': {'charge': 0.5, 'hydrophobicity': 0.450, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 0.0, 'volume': 93},
    'W': {'charge': 0.5, 'hydrophobicity': 0.878, 'hbond': 1.0, 'oxygen': 0.0, 'nitrogen': 1.0, 'volume': 163},
    'Y': {'charge': 0.5, 'hydrophobicity': 0.880, 'hbond': 1.0, 'oxygen': 1.0, 'nitrogen': 0.0, 'volume': 141},
    'V': {'charge': 0.5, 'hydrophobicity': 0.825, 'hbond': 0.0, 'oxygen': 0.0, 'nitrogen': 0.0, 'volume': 105}
}

# Load the trained model
def load_model(model_path, device):
    """Build a PPIPredictor, restore its weights and put it in eval mode.

    Args:
        model_path: Path to a checkpoint containing a ``state_dict`` that
            matches ``PPIPredictor(SE3EquivariantConv(11, 64))``.
        device: Device (string or ``torch.device``) the model is moved to.

    Returns:
        The model on ``device`` with ``eval()`` applied.
    """
    encoder = SE3EquivariantConv(in_channels=11, out_channels=64)
    model = PPIPredictor(encoder)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU can still be restored on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# Physical / chemical constraint function
def apply_constraints(lig_features, lig_points, rec_points, threshold=2.0, bond_length=1.5):
    """Mask ligand features with two distance-based filters, then re-seed zeros.

    A ligand point keeps its features only if (1) some other ligand point lies
    strictly between 0 and ``bond_length * 1.5`` away, and (2) its mean
    distance to all receptor points lies strictly between 1.0 and 5.0.
    Every feature entry that equals zero afterwards is replaced by Gaussian
    noise scaled by 0.1.

    Args:
        lig_features: Per-point ligand features, one row per ligand point.
        lig_points: Ligand coordinates; returned unchanged.
        rec_points: Receptor coordinates.
        threshold: Not used by this function.
        bond_length: Base length for the ligand-ligand neighbour cutoff.

    Returns:
        Tuple ``(constrained_features, lig_points)``.
    """
    neighbour_cutoff = bond_length * 1.5
    pairwise = torch.cdist(lig_points, lig_points)
    has_bonded_neighbour = ((pairwise > 0) & (pairwise < neighbour_cutoff)).any(dim=1)
    bonded_only = lig_features * has_bonded_neighbour.float().unsqueeze(-1)

    # Mean distance from each ligand point to the whole receptor cloud.
    mean_receptor_dist = torch.cdist(rec_points, lig_points).mean(dim=0)
    in_contact_shell = (mean_receptor_dist > 1.0) & (mean_receptor_dist < 5.0)
    masked = bonded_only * in_contact_shell.float().unsqueeze(-1)

    # Entries that are exactly zero are replaced with small random values.
    noise = torch.randn_like(masked) * 0.1
    return torch.where(masked == 0, noise, masked), lig_points

# Load atom coordinates, element types and sequences from a PDB file
def load_pdb(file_path):
    """Parse a PDB file into atom arrays and one-letter chain sequences.

    Only atoms whose element is C, H, O, N, S or SE are kept, and only the
    20 standard residue names contribute to the sequences. Every model in
    the structure is traversed, so a chain id that occurs in several models
    keeps the sequence of the last one.

    Args:
        file_path: Path of the PDB file.

    Returns:
        Tuple ``(atoms, atom_types, unique_types, sequence, chain_sequences)``:
        float32 coordinates, int32 element indices, the sorted distinct
        element indices, the first non-empty chain sequence, and a dict
        mapping chain id to sequence. On any failure the error is printed
        and ``(None, None, None, "", {})`` is returned.
    """
    element_index = {"C": 0, "H": 1, "O": 2, "N": 3, "S": 4, "SE": 5}
    # Keyed by three-letter residue name so each residue is one dict lookup.
    three_to_one = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }
    try:
        parser = PDBParser()
        structure = parser.get_structure("protein", file_path)
        coords = []
        atom_types = []
        sequence = ""
        chain_sequences = {}
        for model in structure:
            for chain in model:
                residue_codes = []
                for residue in chain:
                    code = three_to_one.get(residue.get_resname())
                    if code is not None:
                        residue_codes.append(code)
                    for atom in residue:
                        type_index = element_index.get(atom.element)
                        if type_index is not None:
                            coords.append(atom.get_coord())
                            atom_types.append(type_index)
                chain_seq = "".join(residue_codes)
                chain_sequences[chain.get_id()] = chain_seq
                # The overall sequence is the first chain that has residues.
                if not sequence:
                    sequence = chain_seq
        atoms = np.array(coords, dtype=np.float32)
        unique_types = sorted(set(atom_types))
        atom_types = np.array(atom_types, dtype=np.int32)
        return atoms, atom_types, unique_types, sequence, chain_sequences
    except Exception as e:
        # Deliberate best-effort: callers test for the empty result.
        print(f"错误：无法解析 PDB 文件 {file_path}: {e}")
        return None, None, None, "", {}

# Sample a point cloud with Farthest Point Sampling
def farthest_point_sampling(points, num_points):
    """Return indices of ``num_points`` points chosen by farthest-point sampling.

    The first index is drawn with ``np.random.randint``; each later index is
    the point whose distance to the already selected set is largest.

    Args:
        points: Array of shape (N, D).
        num_points: Number of indices wanted.

    Returns:
        Integer index array. If ``N <= num_points`` every index
        (``np.arange(N)``) is returned; if ``num_points <= 0`` the result
        is empty.
    """
    n_points = len(points)
    if n_points <= num_points:
        return np.arange(n_points)
    if num_points <= 0:
        # Nothing requested: do not hand back the random seed point.
        return np.arange(0)

    points = np.asarray(points)
    selected = np.empty(num_points, dtype=np.intp)
    selected[0] = np.random.randint(n_points)
    # Distance of every point to the seed, computed in one vectorised call.
    distances = np.linalg.norm(points - points[selected[0]], axis=1)
    for slot in range(1, num_points):
        farthest_idx = np.argmax(distances)
        selected[slot] = farthest_idx
        # Keep, per point, the distance to its nearest selected point.
        distances = np.minimum(distances, np.linalg.norm(points - points[farthest_idx], axis=1))
    return selected

# Generate features for new points by KNN interpolation
def interpolate_features(points, features, target_num_points):
    """Extend ``features`` to ``target_num_points`` rows by neighbour averaging.

    Synthetic query positions are made by jittering randomly chosen existing
    points (noise scale ``0.1 * points.std()``); each new feature row is the
    mean of the features of the nearest existing points (up to three).

    Args:
        points: Array of shape (N, 3) with the existing coordinates.
        features: Array of shape (N, F) aligned with ``points``.
        target_num_points: Desired number of feature rows.

    Returns:
        ``features`` unchanged when ``N >= target_num_points`` or ``N == 0``;
        otherwise an array of shape (target_num_points, F) whose first N rows
        are the original features.
    """
    current_num_points = points.shape[0]
    if current_num_points == 0 or current_num_points >= target_num_points:
        return features

    n_new = target_num_points - current_num_points
    kdtree = cKDTree(points)
    seeds = np.random.choice(current_num_points, n_new)
    jitter = np.random.randn(n_new, 3) * points.std() * 0.1
    # Out-of-place add so integer coordinate arrays do not fail on casting.
    new_points = points[seeds] + jitter
    if np.issubdtype(points.dtype, np.floating):
        new_points = new_points.astype(points.dtype, copy=False)

    # Asking for more neighbours than exist makes cKDTree return the
    # out-of-range index N, which would break the fancy indexing below.
    k = min(3, current_num_points)
    indices = kdtree.query(new_points, k=k)[1]
    indices = np.asarray(indices).reshape(n_new, k)  # k == 1 comes back 1-D
    new_features = np.mean(features[indices], axis=1)
    return np.concatenate([features, new_features], axis=0)

# Pad a point cloud so that points, normals and features stay row-aligned
def _pad_point_cloud(points, normals, features, num_points):
    """Grow all three arrays to ``num_points`` rows from one set of new points.

    New points are jittered copies of randomly chosen existing points; their
    normals and features are the mean over the nearest existing points (up
    to three). Averaged normals are not re-normalised.
    """
    n_existing = len(points)
    n_new = num_points - n_existing
    seeds = np.random.choice(n_existing, n_new)
    jitter = np.random.randn(n_new, 3) * points.std() * 0.1
    new_points = (points[seeds] + jitter).astype(points.dtype)
    k = min(3, n_existing)
    neighbours = np.asarray(cKDTree(points).query(new_points, k=k)[1]).reshape(n_new, k)
    padded_points = np.concatenate([points, new_points], axis=0)
    padded_normals = np.concatenate([normals, normals[neighbours].mean(axis=1)], axis=0)
    padded_features = np.concatenate([features, features[neighbours].mean(axis=1)], axis=0)
    return padded_points, padded_normals, padded_features


# Load a point cloud and its features from a PLY file
def load_ply(file_path, num_points=5000):
    """Read a PLY surface and resample it to exactly ``num_points`` vertices.

    The ``vertex`` element must provide ``nx``, ``ny``, ``nz``, ``charge``,
    ``hbond`` and ``hphob``; coordinates are taken from the first three
    fields of each vertex record. Larger clouds are reduced with
    farthest_point_sampling; smaller clouds are padded with synthetic points.
    Each padded row of points, normals and features refers to the same
    synthetic point.

    Args:
        file_path: Path of the PLY file.
        num_points: Number of vertices in the returned arrays.

    Returns:
        Tuple ``(points, normals, features)`` of float32 arrays with shapes
        (num_points, 3), (num_points, 3) and (num_points, 3). On any failure
        the error is printed and ``(None, None, None)`` is returned.
    """
    try:
        vertex = PlyData.read(str(file_path))["vertex"]
        points = np.vstack([[v[0], v[1], v[2]] for v in vertex]).astype(np.float32)
        normals = np.stack([vertex["nx"], vertex["ny"], vertex["nz"]]).T.astype(np.float32)
        charge = vertex["charge"].astype(np.float32)
        hbond = vertex["hbond"].astype(np.float32)
        hphob = vertex["hphob"].astype(np.float32)
        features = np.stack([charge, hbond, hphob]).T
        if len(points) > num_points:
            indices = farthest_point_sampling(points, num_points)
            points = points[indices]
            normals = normals[indices]
            features = features[indices]
        elif len(points) < num_points:
            # One shared random draw keeps the padded rows aligned.
            points, normals, features = _pad_point_cloud(points, normals, features, num_points)
        return points, normals, features
    except Exception as e:
        # Deliberate best-effort: callers test for the None result.
        print(f"错误：无法解析 PLY 文件 {file_path}: {e}")
        return None, None, None

# Compute the molecular surface
def compute_molecular_surface(atoms, probe_radius=1.4, num_points=5000):
    """Mesh the atom cloud with Open3D ball pivoting and sample surface points.

    Normals are estimated on the atom cloud with a hybrid KD-tree search
    (radius ``2 * probe_radius``, at most 30 neighbours), the mesh is built
    with pivot radii ``probe_radius`` and ``2 * probe_radius``, and
    ``num_points`` points are sampled uniformly from it.

    Args:
        atoms: Atom coordinates accepted by ``o3d.utility.Vector3dVector``.
        probe_radius: Base radius for normal estimation and ball pivoting.
        num_points: Number of surface samples requested.

    Returns:
        Tuple ``(points, normals)`` of float32 arrays; normals are divided by
        their length (floored at 1e-6). On any failure the error is printed
        and ``(None, None)`` is returned.
    """
    try:
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(atoms)
        normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=probe_radius * 2, max_nn=30)
        cloud.estimate_normals(search_param=normal_search)

        pivot_radii = o3d.utility.DoubleVector([probe_radius, probe_radius * 2])
        mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(cloud, pivot_radii)

        sampled = mesh.sample_points_uniformly(number_of_points=num_points)
        surface_xyz = np.asarray(sampled.points, dtype=np.float32)
        surface_normals = np.asarray(sampled.normals, dtype=np.float32)
        # Guard the division against zero-length normals.
        lengths = np.linalg.norm(surface_normals, axis=1, keepdims=True)
        return surface_xyz, surface_normals / np.maximum(lengths, 1e-6)
    except Exception as e:
        print(f"错误：无法生成分子表面: {e}")
        return None, None

# Compute curvature
def compute_curvature(points, normals, k=10):
    """Estimate two curvature-like descriptors for every point.

    For each point the neighbours within radius 5.0 (excluding the point
    itself) are collected. Column 0 is the mean norm of the difference
    between neighbour normals and the point's normal. Column 1 is
    ``prod(ev[:2]) / sum(ev[:2])`` for the eigenvalues ``ev`` of the
    neighbour-coordinate covariance, or 0 when that sum is <= 1e-6.

    NaN/Inf entries in ``points`` or ``normals`` are replaced by zero first.
    Points with no neighbour, or whose covariance/eigenvalues are not finite
    (for example a single neighbour), get ``[0, 0]`` and a printed warning.

    Args:
        points: Array of shape (N, 3).
        normals: Array of shape (N, 3) aligned with ``points``.
        k: Not used; the neighbourhood is the fixed 5.0 radius.

    Returns:
        float32 array of shape (N, 2).
    """
    if np.any(np.isnan(points)) or np.any(np.isinf(points)):
        print("警告：点云包含 NaN 或 Inf，替换为零")
        points = np.nan_to_num(points, nan=0.0, posinf=0.0, neginf=0.0)
    if np.any(np.isnan(normals)) or np.any(np.isinf(normals)):
        print("警告：法向量包含 NaN 或 Inf，替换为零")
        normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0)

    kdtree = cKDTree(points)
    curvatures = np.zeros((len(points), 2), dtype=np.float32)
    for i, point in enumerate(points):
        # query_ball_point gives no ordering guarantee, so the query point is
        # removed by index rather than by dropping the first hit.
        indices = [j for j in kdtree.query_ball_point(point, r=5.0) if j != i]
        if not indices:
            print(f"警告：点 {i} 的邻域点不足，设置曲率为零")
            continue  # row is already zero

        neighbors = points[indices]
        normal_diff = normals[indices] - normals[i]
        mean_curvature = np.mean(np.linalg.norm(normal_diff, axis=1))

        cov = np.cov(neighbors.T)
        if np.any(np.isnan(cov)) or np.any(np.isinf(cov)):
            print(f"警告：点 {i} 的协方差矩阵包含 NaN 或 Inf，设置曲率为零")
            continue
        # NOTE(review): eigvals returns eigenvalues in no particular order,
        # so "the first two" is not a sorted selection - confirm intent.
        eigenvalues = np.linalg.eigvals(cov).real
        if np.any(np.isnan(eigenvalues)) or np.any(np.isinf(eigenvalues)):
            print(f"警告：点 {i} 的特征值包含 NaN 或 Inf，设置曲率为零")
            continue

        eigen_sum = np.sum(eigenvalues[:2])
        gaussian_curvature = np.prod(eigenvalues[:2]) / eigen_sum if eigen_sum > 1e-6 else 0.0
        curvatures[i] = [mean_curvature, gaussian_curvature]
    return curvatures

# Compute chemical features
def compute_chemical_features(atoms, atom_types, unique_types, surface_points, radius=5.0):
    """Build a normalised atom-type histogram around every surface point.

    For each surface point the atoms within ``radius`` are counted per type
    index (only indices listed in ``unique_types`` are counted) and the
    counts are divided by their sum, or by 1 when the sum is below 1.

    Args:
        atoms: Atom coordinates, shape (A, 3).
        atom_types: Integer type index per atom (NumPy array of length A).
        unique_types: Iterable of the type indices that may be counted.
        surface_points: Query coordinates, shape (S, 3).
        radius: Search radius around each surface point.

    Returns:
        float32 array of shape (S, 6), or ``None`` (after printing the error)
        if anything raises.
    """
    try:
        tree = cKDTree(atoms)
        n_channels = 6
        histogram = np.zeros((len(surface_points), n_channels), dtype=np.float32)
        column_of = {t: t for t in unique_types}
        for row, probe in enumerate(surface_points):
            hits = tree.query_ball_point(probe, radius)
            if not hits:
                continue
            for element in atom_types[hits]:
                column = column_of.get(element)
                if column is not None:
                    histogram[row, column] += 1
            # Convert counts to fractions; an all-zero row stays zero.
            histogram[row] /= max(1, np.sum(histogram[row]))
        return histogram
    except Exception as e:
        print(f"错误：无法计算化学特征: {e}")
        return None

# Normalise chemical features
def normalize_chemical_features(features):
    """Rescale the three chemical feature columns to fixed ranges.

    Column 0 is clipped to [-3, 3] and mapped linearly to [-1, 1]; column 1
    is clipped to [0, 1]; column 2 is divided by 4.5 and clipped to [0, 1].

    Args:
        features: Array of shape (N, >=3); only the first three columns are read.

    Returns:
        Array of shape (N, 3) with the rescaled columns in the same order.
    """
    lower, upper = -3.0, 3.0
    raw_charge, raw_hbond, raw_hphob = features[:, 0], features[:, 1], features[:, 2]

    # Clip, map to [0, 1], then stretch to [-1, 1].
    unit_charge = (np.clip(raw_charge, lower, upper) - lower) / (upper - lower)
    columns = [
        2 * unit_charge - 1,
        np.clip(raw_hbond, 0.0, 1.0),
        np.clip(raw_hphob / 4.5, 0.0, 1.0),
    ]
    return np.stack(columns).T

# Compute interface labels
def compute_iface_labels(points1, points2, distance_threshold=8.0):
    """Flag the points of ``points1`` that lie close to ``points2``.

    Args:
        points1: Query coordinates, shape (N, 3).
        points2: Reference coordinates, shape (M, 3).
        distance_threshold: A point is labelled 1.0 when its nearest reference
            point is strictly closer than this value.

    Returns:
        float32 array of shape (N,) holding 1.0 or 0.0.
    """
    nearest_distance = cKDTree(points2).query(points1, k=1)[0]
    return (nearest_distance < distance_threshold).astype(np.float32)

# Generate an initial amino-acid sequence
def generate_initial_sequence(length=20):
    """Return a random sequence of ``length`` one-letter amino-acid codes.

    Letters are drawn with replacement from the 20 standard residues using
    NumPy's global random state.
    """
    alphabet = list('ACDEFGHIKLMNPQRSTVWY')
    drawn = np.random.choice(alphabet, length)
    return ''.join(drawn)

# Convert a sequence into a point cloud (structure generated with PyRosetta)
def sequence_to_point_cloud(sequence, num_points=5000, device='cuda:0', receptor_points=None,
                            receptor_iface_labels=None):
    """Build a surface point cloud for ``sequence`` and optionally dock it.

    A pose is created from the sequence and relaxed with FastRelax (a failing
    relax only prints a warning). The atom coordinates are turned into a
    surface with compute_molecular_surface; if that fails, the atoms
    themselves are sub-sampled with farthest_point_sampling.

    When ``receptor_points`` is given, the cloud is translated so that its
    centroid coincides with the receptor pocket centroid. The pocket is the
    set of receptor points whose interface label equals 1. Labels are taken
    from ``receptor_iface_labels``; if that is None, from a module-level
    ``receptor_data['iface_labels']`` when such a global exists (the previous
    behaviour). Without labels, or with no point labelled 1, the centroid of
    all receptor points is used.

    Args:
        sequence: One-letter amino-acid sequence.
        num_points: Number of surface points requested.
        device: Torch device for the returned tensor.
        receptor_points: Optional NumPy array (R, 3) of receptor coordinates.
        receptor_iface_labels: Optional tensor/array of length R with
            interface labels aligned with ``receptor_points``.

    Returns:
        Float tensor of point coordinates on ``device``.
    """
    pose = pose_from_sequence(sequence)
    scorefxn = get_score_function()

    # Optimise the initial structure with FastRelax
    relax = pyrosetta.rosetta.protocols.relax.FastRelax(scorefxn)
    try:
        relax.apply(pose)
    except Exception as e:
        print(f"警告：FastRelax 失败: {e}")

    # Extract atom coordinates
    atoms = []
    for i in range(1, pose.total_residue() + 1):
        for atom in pose.residue(i).atoms():
            xyz = atom.xyz()
            atoms.append([xyz[0], xyz[1], xyz[2]])
    atoms = np.array(atoms, dtype=np.float32)

    # Compute the molecular surface
    surface_points, _ = compute_molecular_surface(atoms, probe_radius=1.4, num_points=num_points)
    if surface_points is not None:
        points = surface_points
    else:
        print("警告：分子表面计算失败，使用最远点采样")
        indices = farthest_point_sampling(atoms, num_points)
        points = atoms[indices]

    # If receptor points are provided, translate onto the pocket centre
    if receptor_points is not None:
        labels = receptor_iface_labels
        if labels is None:
            # Backward compatibility: earlier code read this module global.
            legacy_receptor = globals().get('receptor_data')
            if legacy_receptor is not None:
                labels = legacy_receptor['iface_labels']

        anchor_points = receptor_points
        if labels is not None:
            if torch.is_tensor(labels):
                labels = labels.detach().cpu().numpy()
            pocket_mask = np.asarray(labels) == 1
            # An empty pocket would give a NaN centroid; keep all points then.
            if pocket_mask.any():
                anchor_points = receptor_points[pocket_mask]
        pocket_centroid = np.mean(anchor_points, axis=0)
        points = points + (pocket_centroid - np.mean(points, axis=0))

    return torch.from_numpy(points).float().to(device)

# Convert a sequence into features (chemical properties, atom-type features, curvature features)
def sequence_to_features(sequence, receptor_features, num_points=5000, device='cuda:0'):
    """Build an 11-column feature matrix for the ligand surface of ``sequence``.

    Columns 0-2 are seeded from the receptor (charge negated, hydrogen bond
    and hydrophobicity copied row by row). Columns 3-8 are the atom-type
    fractions from compute_chemical_features and columns 9-10 come from
    compute_curvature. The result is padded with small random values or
    truncated so that it always has 11 columns.

    Args:
        sequence: One-letter amino-acid sequence.
        receptor_features: Tensor with at least three columns. Rows beyond
            ``num_points`` are ignored; if it has fewer rows, the remaining
            ligand rows keep zeros in columns 0-2.
        num_points: Number of surface points / feature rows.
        device: Torch device for the returned tensor.

    Returns:
        Float tensor of shape (num_points, 11) on ``device``.
    """
    # Complementary initialisation from the receptor features, done as one
    # slice copy instead of a per-element loop over device tensors.
    features = torch.zeros(num_points, 3, device=device)
    n_rows = min(num_points, receptor_features.shape[0])
    receptor_block = receptor_features[:n_rows, :3].detach().to(device)
    features[:n_rows, 0] = -receptor_block[:, 0]  # Charge
    features[:n_rows, 1] = receptor_block[:, 1]   # Hydrogen bond
    features[:n_rows, 2] = receptor_block[:, 2]   # Hydrophobicity

    # Generate the initial structure
    pose = pose_from_sequence(sequence)
    scorefxn = get_score_function()
    relax = pyrosetta.rosetta.protocols.relax.FastRelax(scorefxn)
    try:
        relax.apply(pose)
    except Exception as e:
        print(f"警告：FastRelax 失败: {e}")

    # Extract atom coordinates and element types
    atoms = []
    atom_types = []
    ele2num = {"C": 0, "H": 1, "O": 2, "N": 3, "S": 4, "SE": 5}
    for i in range(1, pose.total_residue() + 1):
        residue = pose.residue(i)
        for j in range(1, residue.natoms() + 1):
            try:
                xyz = residue.xyz(j)
                ele = residue.atom_type(j).element()
                atoms.append([xyz[0], xyz[1], xyz[2]])
                # Elements outside the table are counted as type 0.
                atom_types.append(ele2num.get(ele, 0))
            except Exception as e:
                print(f"警告：无法获取第 {i} 个残基中第 {j} 个原子的元素或坐标: {e}")
                atoms.append([0.0, 0.0, 0.0])
                atom_types.append(0)

    atoms = np.array(atoms, dtype=np.float32)
    atom_types = np.array(atom_types, dtype=np.int32)
    unique_types = sorted(set(atom_types))
    print(f"提取的原子数: {len(atoms)}, 原子类型数: {len(unique_types)}")

    # Compute the molecular surface
    surface_points, normals = compute_molecular_surface(atoms, probe_radius=1.4, num_points=num_points)
    if surface_points is None:
        print("警告：分子表面计算失败，使用原始点云")
        surface_points = atoms
        normals = np.zeros((len(atoms), 3), dtype=np.float32)
        if len(surface_points) > num_points:
            indices = farthest_point_sampling(surface_points, num_points)
            surface_points = surface_points[indices]
            normals = normals[indices]
        elif len(surface_points) < num_points:
            # Fix the pad count before either array grows, so points and
            # normals end up with the same number of rows.
            n_pad = num_points - len(surface_points)
            padding_points = (np.random.randn(n_pad, 3) * 0.01).astype(np.float32)
            surface_points = np.concatenate([surface_points, padding_points], axis=0)
            normals = np.concatenate([normals, np.zeros((n_pad, 3), dtype=np.float32)], axis=0)

    # Compute chemical features
    atom_features = compute_chemical_features(atoms, atom_types, unique_types, surface_points, radius=5.0)
    if atom_features is None:
        print("警告：原子特征计算失败，使用零填充")
        atom_features = np.zeros((num_points, 6), dtype=np.float32)

    # Compute curvature features
    curvature_features = compute_curvature(surface_points, normals, k=10)

    # Merge the feature groups and force exactly 11 columns
    combined_features = np.concatenate([features.cpu().numpy(), atom_features, curvature_features], axis=1)
    if combined_features.shape[1] < 11:
        padding = np.random.randn(num_points, 11 - combined_features.shape[1]) * 0.1
        combined_features = np.concatenate([combined_features, padding], axis=1)
    elif combined_features.shape[1] > 11:
        combined_features = combined_features[:, :11]

    return torch.from_numpy(combined_features).float().to(device)

# Select the most likely amino acid
def select_likely_amino_acid(lig_features):
    """Pick the amino acid whose properties best match the mean ligand features.

    The column means of ``lig_features`` are read at index 0 (charge),
    1 (hydrogen bond), 2 (hydrophobicity), 5 (oxygen density), 6 (nitrogen
    density) and 9 (mean curvature). Charge, hydrogen bond, hydrophobicity
    and curvature are mapped with ``(x + 1) / 2``; the two densities are
    clipped to [0, 1]; curvature is mapped linearly to a target volume
    between 48 and 163.

    Every entry of AMINO_ACID_PROPERTIES gets a weighted sum of absolute
    mismatches and the entry with the lowest score wins (ties go to the
    earlier table entry). The oxygen and nitrogen terms are mismatches too,
    so a residue carrying that atom is favoured when the density is high.

    Args:
        lig_features: Tensor of shape (N, >=10).

    Returns:
        One-letter code of the best-scoring amino acid.
    """
    # One reduction and one host transfer for all columns.
    column_means = lig_features.detach().mean(dim=0).cpu().tolist()

    # Rescale towards [0, 1]
    charge = (column_means[0] + 1.0) / 2.0
    hbond = (column_means[1] + 1.0) / 2.0
    hydrophobicity = (column_means[2] + 1.0) / 2.0
    mean_curvature = (column_means[9] + 1.0) / 2.0
    oxygen_density = np.clip(column_means[5], 0.0, 1.0)
    nitrogen_density = np.clip(column_means[6], 0.0, 1.0)

    # Map curvature to a volume (48 to 163 Å³)
    target_volume = 48 + mean_curvature * (163 - 48)

    def mismatch(aa):
        props = AMINO_ACID_PROPERTIES[aa]
        return (
            0.3 * abs(props['charge'] - charge) +
            0.3 * abs(props['hydrophobicity'] - hydrophobicity) +
            0.2 * abs(props['hbond'] - hbond) +
            # Distance, not product: the score is minimised, so a product
            # would penalise exactly the residues that match the surface.
            0.1 * abs(props['oxygen'] - oxygen_density) +
            0.1 * abs(props['nitrogen'] - nitrogen_density) +
            0.1 * abs(props['volume'] - target_volume) / 115
        )

    # Choose the amino acid with the lowest score
    return min(AMINO_ACID_PROPERTIES, key=mismatch)

# Generate a sequence
def generate_sequence(current_sequence, lig_features, operation, pos=None):
    """Apply one edit to ``current_sequence`` using the most likely residue.

    The residue is chosen by select_likely_amino_acid(lig_features).

    * ``'add'``: insert it at ``pos`` unless the sequence already has 74 or
      more residues.
    * ``'delete'``: remove the residue at ``pos`` only if it equals the
      likely residue; otherwise return ``current_sequence`` untouched.
    * ``'modify'``: overwrite the residue at ``pos`` only if it differs from
      the likely residue; otherwise return ``current_sequence`` untouched.

    A missing ``pos`` is drawn with ``np.random.randint``. Any other
    operation, or delete/modify on an empty sequence, leaves the residues
    unchanged.

    Args:
        current_sequence: Sequence of one-letter codes.
        lig_features: Ligand feature tensor used to choose the residue.
        operation: ``'add'``, ``'delete'`` or ``'modify'``.
        pos: Optional position of the edit.

    Returns:
        The edited sequence as a string, or ``current_sequence`` itself when
        a delete/modify was rejected.
    """
    residues = list(current_sequence)
    target_aa = select_likely_amino_acid(lig_features)

    def resolve_position(upper_bound):
        # Honour an explicit position; otherwise draw one in [0, upper_bound).
        if pos is not None:
            return pos
        return np.random.randint(0, upper_bound)

    if operation == 'add':
        if len(residues) < 74:
            residues.insert(resolve_position(len(residues) + 1), target_aa)
        return ''.join(residues)

    if residues and operation == 'delete':
        where = resolve_position(len(residues))
        if residues[where] != target_aa:
            return current_sequence
        del residues[where]
    elif residues and operation == 'modify':
        where = resolve_position(len(residues))
        if residues[where] == target_aa:
            return current_sequence
        residues[where] = target_aa

    return ''.join(residues)

def generate_ligand(model, receptor_data, num_iterations=400, target_delta_g=-15.0, 
                    device='cuda:0', reinit_every=25, ligand_file=None):
    # === 准备 receptor 数据 ===
    batch_size = 1
    rec_batch_idx = torch.arange(batch_size).repeat_interleave(receptor_data['points'].size(0) // batch_size).to(device)
    rec_points = receptor_data['points'].to(device)
    rec_features = receptor_data['features'].to(device)
    rec_labels = receptor_data['iface_labels'].to(device)

    # === 初始化 ligand ===
    current_sequence = generate_initial_sequence(length=20)
    lig_points = sequence_to_point_cloud(current_sequence, device=device, receptor_points=rec_points.cpu().numpy())
    lig_features = sequence_to_features(current_sequence, rec_features, device=device)
    lig_points = lig_points.detach().requires_grad_(True)
    lig_features = lig_features.detach().requires_grad_(True)

    optimizer_points = torch.optim.Adam([lig_points], lr=1e-3)
    optimizer_features = torch.optim.Adam([lig_features], lr=2e-3)

    # === 记录最优状态 ===
    best_delta_g = float('inf')
    best_sequence = current_sequence
    best_lig_points = lig_points.detach().clone()
    best_lig_features = lig_features.detach().clone()
    best_lig_pocket_logits = None

    for iteration in tqdm(range(num_iterations), desc="Generating Ligand"):
        optimizer_points.zero_grad()
        optimizer_features.zero_grad()

        # === 计算当前 ΔG ===
        lig_pocket_logits, _, _, _ = model(
            rec_points.unsqueeze(0), rec_features.unsqueeze(0),
            lig_points.unsqueeze(0), lig_features.unsqueeze(0),
            rec_batch_idx
        )
        rec_z = model.encoder(rec_features.unsqueeze(0), rec_points.unsqueeze(0), rec_batch_idx)
        lig_z = model.encoder(lig_features.unsqueeze(0), lig_points.unsqueeze(0), rec_batch_idx)
        rec_pocket_logits = model.pocket_head(rec_z).squeeze(-1)
        rec_pocket_mask = (torch.sigmoid(rec_pocket_logits) > 0.5).float().squeeze(0)
        lig_pocket_mask = (torch.sigmoid(lig_pocket_logits.squeeze(0)) > 0.5).float()
        rec_sum = rec_pocket_mask.sum(dim=0, keepdim=True).clamp(min=1.0)
        lig_sum = lig_pocket_mask.sum(dim=0, keepdim=True).clamp(min=1.0)
        rec_pocket_features = (rec_z.squeeze(0) * rec_pocket_mask.unsqueeze(-1)).sum(dim=0) / rec_sum
        lig_pocket_features = (lig_z.squeeze(0) * lig_pocket_mask.unsqueeze(-1)).sum(dim=0) / lig_sum
        pair_features = torch.cat([rec_pocket_features, lig_pocket_features], dim=-1)
        current_delta_g = model.delta_g_head(pair_features.unsqueeze(0)).squeeze(0).item()

        # === 尝试所有操作 ===
        likely_aa = select_likely_amino_acid(lig_features)
        candidate_sequences = []
        candidate_operations = []

        # 添加：在每个位置添加最有可能的氨基酸
        if len(current_sequence) < 74:
            for pos in range(len(current_sequence) + 1):
                new_sequence = generate_sequence(current_sequence, lig_features, 'add', pos=pos)
                candidate_sequences.append(new_sequence)
                candidate_operations.append(('add', pos, likely_aa))

        # 删除：仅删除与最有可能氨基酸匹配的位置
        if len(current_sequence) > 1:
            for pos in range(len(current_sequence)):
                if current_sequence[pos] == likely_aa:
                    new_sequence = generate_sequence(current_sequence, lig_features, 'delete', pos=pos)
                    if new_sequence != current_sequence:
                        candidate_sequences.append(new_sequence)
                        candidate_operations.append(('delete', pos, likely_aa))

        # 修改：将每个位置修改为最有可能的氨基酸
        for pos in range(len(current_sequence)):
            if current_sequence[pos] != likely_aa:
                new_sequence = generate_sequence(current_sequence, lig_features, 'modify', pos=pos)
                if new_sequence != current_sequence:
                    candidate_sequences.append(new_sequence)
                    candidate_operations.append(('modify', pos, likely_aa))

        # === 评估候选序列 ===
        best_new_sequence = current_sequence
        best_new_lig_points = lig_points
        best_new_lig_features = lig_features
        best_delta_delta_g = float('inf')
        best_operation = None
        best_pos = None
        best_new_aa = None
        all_delta_delta_g_positive = True
        add_negative_ddgs = []
        modify_negative_ddgs = []
        delete_negative_ddgs = []

        for idx, (new_sequence, (operation, pos, new_aa)) in enumerate(zip(candidate_sequences, candidate_operations)):
            if not new_sequence or new_sequence == current_sequence:
                continue

            try:
                # 生成新点云和特征
                new_lig_points = sequence_to_point_cloud(new_sequence, device=device, receptor_points=rec_points.cpu().numpy())
                new_lig_features = sequence_to_features(new_sequence, rec_features, device=device)
                new_lig_points = new_lig_points.detach().requires_grad_(True)
                new_lig_features = new_lig_features.detach().requires_grad_(True)

                # 计算新 ΔG
                with torch.no_grad():
                    lig_pocket_logits, _, _, _ = model(
                        rec_points.unsqueeze(0), rec_features.unsqueeze(0),
                        new_lig_points.unsqueeze(0), new_lig_features.unsqueeze(0),
                        rec_batch_idx
                    )
                    lig_z = model.encoder(new_lig_features.unsqueeze(0), new_lig_points.unsqueeze(0), rec_batch_idx)
                    lig_pocket_mask = (torch.sigmoid(lig_pocket_logits.squeeze(0)) > 0.5).float()
                    lig_sum = lig_pocket_mask.sum(dim=0, keepdim=True).clamp(min=1.0)
                    lig_pocket_features = (lig_z.squeeze(0) * lig_pocket_mask.unsqueeze(-1)).sum(dim=0) / lig_sum
                    pair_features = torch.cat([rec_pocket_features, lig_pocket_features], dim=-1)
                    new_delta_g = model.delta_g_head(pair_features.unsqueeze(0)).squeeze(0).item()

                # 计算 ΔΔG
                delta_delta_g = new_delta_g - current_delta_g
                print(f"[Operation {operation}, Pos {pos}, AA {new_aa}] Iter {iteration:03d} | Sequence Length: {len(new_sequence)} | ΔG: {new_delta_g:.4f} | ΔΔG: {delta_delta_g:.4f}")

                # 按操作类型收集 ΔΔG < 0 的操作
                if delta_delta_g < 0:
                    if operation == 'add':
                        add_negative_ddgs.append((delta_delta_g, new_sequence, new_lig_points, new_lig_features, pos, new_aa))
                    elif operation == 'modify':
                        modify_negative_ddgs.append((delta_delta_g, new_sequence, new_lig_points, new_lig_features, pos, new_aa))
                    elif operation == 'delete':
                        delete_negative_ddgs.append((delta_delta_g, new_sequence, new_lig_points, new_lig_features, pos, new_aa))
                    all_delta_delta_g_positive = False

            except Exception as e:
                print(f"[Warning] Operation {operation} at pos {pos} with AA {new_aa} failed: {e}")
                continue

        # === 检查退出条件 ===
        if all_delta_delta_g_positive:
            print(f"[Exit] Iter {iteration:03d} | All operations have ΔΔG > 0, returning current sequence")
            break

        # === 按优先级选择操作：add > modify > delete ===
        if add_negative_ddgs:
            # 选择 ΔΔG 最小的添加操作
            best_delta_delta_g, best_new_sequence, best_new_lig_points, best_new_lig_features, best_pos, best_new_aa = min(
                add_negative_ddgs, key=lambda x: x[0]
            )
            best_operation = 'add'
            print(f"[Accepted] Iter {iteration:03d} | Operation: {best_operation}, Pos: {best_pos}, AA: {best_new_aa} | Sequence Length: {len(best_new_sequence)} | ΔG: {current_delta_g + best_delta_delta_g:.4f} | ΔΔG: {best_delta_delta_g:.4f}")
        elif modify_negative_ddgs:
            # 选择 ΔΔG 最小的修改操作
            best_delta_delta_g, best_new_sequence, best_new_lig_points, best_new_lig_features, best_pos, best_new_aa = min(
                modify_negative_ddgs, key=lambda x: x[0]
            )
            best_operation = 'modify'
            print(f"[Accepted] Iter {iteration:03d} | Operation: {best_operation}, Pos: {best_pos}, AA: {best_new_aa} | Sequence Length: {len(best_new_sequence)} | ΔG: {current_delta_g + best_delta_delta_g:.4f} | ΔΔG: {best_delta_delta_g:.4f}")
        elif delete_negative_ddgs:
            # 选择 ΔΔG 最小的删除操作
            best_delta_delta_g, best_new_sequence, best_new_lig_points, best_new_lig_features, best_pos, best_new_aa = min(
                delete_negative_ddgs, key=lambda x: x[0]
            )
            best_operation = 'delete'
            print(f"[Accepted] Iter {iteration:03d} | Operation: {best_operation}, Pos: {best_pos}, AA: {best_new_aa} | Sequence Length: {len(best_new_sequence)} | ΔG: {current_delta_g + best_delta_delta_g:.4f} | ΔΔG: {best_delta_delta_g:.4f}")
        else:
            print(f"[Exit] Iter {iteration:03d} | No operations with ΔΔG < 0, returning current sequence")
            break

        # 更新序列和结构
        current_sequence = best_new_sequence
        lig_points = best_new_lig_points
        lig_features = best_new_lig_features
        current_delta_g = current_delta_g + best_delta_delta_g

        # === 更新最优状态 ===
        if current_delta_g < best_delta_g:
            best_delta_g = current_delta_g
            best_sequence = current_sequence
            best_lig_points = lig_points.detach().clone()
            best_lig_features = lig_features.detach().clone()
            best_lig_pocket_logits = lig_pocket_logits.detach().clone()
            print(f"[Best ΔG Updated] Iter {iteration:03d} → ΔG: {best_delta_g:.4f} | Sequence: {best_sequence}")

        # === 约束序列长度 ===
        if len(current_sequence) > 74:
            current_sequence = current_sequence[:74]
            lig_points = sequence_to_point_cloud(current_sequence, device=device, receptor_points=rec_points.cpu().numpy())
            lig_features = sequence_to_features(current_sequence, rec_features, device=device)
            lig_points = lig_points.detach().requires_grad_(True)
            lig_features = lig_features.detach().requires_grad_(True)

        # === 优化点云和特征 ===
        optimizer_points.zero_grad()
        optimizer_features.zero_grad()
        lig_pocket_logits, _, _, _ = model(
            rec_points.unsqueeze(0), rec_features.unsqueeze(0),
            lig_points.unsqueeze(0), lig_features.unsqueeze(0),
            rec_batch_idx
        )
        rec_z = model.encoder(rec_features.unsqueeze(0), rec_points.unsqueeze(0), rec_batch_idx)
        rec_pocket_logits = model.pocket_head(rec_z).squeeze(-1)
        rec_pocket_mask = (torch.sigmoid(rec_pocket_logits) > 0.5).float().squeeze(0)
        lig_pocket_mask = (torch.sigmoid(lig_pocket_logits.squeeze(0)) > 0.5).float()
        rec_sum = rec_pocket_mask.sum(dim=0, keepdim=True).clamp(min=1.0)
        lig_sum = lig_pocket_mask.sum(dim=0, keepdim=True).clamp(min=1.0)
        rec_pocket_features = (rec_z.squeeze(0) * rec_pocket_mask.unsqueeze(-1)).sum(dim=0) / rec_sum
        lig_pocket_features = (lig_z.squeeze(0) * lig_pocket_mask.unsqueeze(-1)).sum(dim=0) / lig_sum
        pair_features = torch.cat([rec_pocket_features, lig_pocket_features], dim=-1)
        delta_g_loss = F.mse_loss(model.delta_g_head(pair_features.unsqueeze(0)).squeeze(0), torch.tensor([target_delta_g], device=device))
        pocket_match_loss = 10.0 * (1.0 - F.cosine_similarity(rec_pocket_mask.unsqueeze(0), lig_pocket_mask.unsqueeze(0)).mean())
        rec_pocket_points = rec_points[rec_pocket_mask.bool()]
        dist_loss = 5.0 * torch.mean(torch.cdist(lig_points, rec_pocket_points).min(dim=1)[0])
        feature_diversity = torch.var(lig_features + 1e-4, dim=0).mean()
        feature_l2 = torch.norm(lig_features, p=2) / lig_features.numel()
        feature_reg = 5.0 * feature_diversity + 0.1 * feature_l2
        total_loss = 5 * delta_g_loss + pocket_match_loss + dist_loss + feature_reg
        total_loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_([lig_points, lig_features], max_norm=1.0)
        optimizer_points.step()
        optimizer_features.step()
        lig_features, lig_points = apply_constraints(lig_features, lig_points, rec_points)

    # === PyRosetta ΔG 评估 ===
    scorefxn = get_score_function()
    receptor_pdb = ""
    _, _, _, receptor_sequence, chain_sequences = load_pdb(receptor_pdb)
    if not receptor_sequence:
        receptor_sequence = "A" * 74
    try:
        receptor_pose = pose_from_sequence(receptor_sequence)
    except Exception as e:
        print(f"错误：无法创建 receptor_pose: {e}")
        receptor_pose = pose_from_sequence("A" * 74)
    
    try:
        ligand_pose = pose_from_sequence(best_sequence)
    except Exception as e:
        print(f"错误：无法创建 ligand_pose: {e}")
        ligand_pose = pose_from_sequence("A" * 74)

    # 平移配体
    ligand_atoms_coords = np.array([[atom.xyz()[0], atom.xyz()[1], atom.xyz()[2]] for i in range(1, ligand_pose.total_residue()+1)
                                    for atom in ligand_pose.residue(i).atoms()])
    receptor_pocket_mask = (receptor_data['iface_labels'] == 1).cpu().numpy()
    pocket_points = rec_points.cpu().numpy()[receptor_pocket_mask]
    pocket_center = np.mean(pocket_points, axis=0)
    ligand_center = np.mean(ligand_atoms_coords, axis=0)
    translation = pocket_center - ligand_center
    from pyrosetta.rosetta.numeric import xyzVector_double_t
    for i in range(1, ligand_pose.total_residue() + 1):
        res = ligand_pose.residue(i)
        for j in range(1, res.natoms() + 1):
            old_xyz = res.xyz(j)
            new_xyz = xyzVector_double_t(old_xyz[0] + translation[0],
                                        old_xyz[1] + translation[1],
                                        old_xyz[2] + translation[2])
            res.set_xyz(j, new_xyz)

    # 构建复合物
    complex_pose = Pose()
    complex_pose.assign(receptor_pose)
    jump_anchor = receptor_pose.total_residue()
    try:
        complex_pose.append_pose_by_jump(ligand_pose, jump_anchor)
    except Exception as e:
        print(f"[Error] append_pose_by_jump 失败：{e}")
        complex_pose = pose_from_sequence(receptor_sequence + best_sequence)

    # Relax 和 ΔG 计算
    try:
        relax = pyrosetta.rosetta.protocols.relax.FastRelax(scorefxn)
        relax.apply(complex_pose)
        complex_energy = scorefxn(complex_pose)
        receptor_energy = scorefxn(receptor_pose)
        ligand_energy = scorefxn(ligand_pose)
        delta_g_rosetta = complex_energy - (receptor_energy + ligand_energy)
        print(f"[Rosetta ΔG] = {delta_g_rosetta:.4f}")
    except Exception as e:
        print(f"[Warning] ΔG 计算失败: {e}")
        delta_g_rosetta = float('nan')

    return best_sequence, best_lig_points.cpu().numpy(), rec_points.detach(), rec_pocket_logits.detach().squeeze(0), best_lig_pocket_logits.detach().squeeze(0)



def predict_pair_delta_g(model, rec_features, rec_points, lig_features, lig_points):
    """Predict ΔG for a single receptor/ligand pair with the model's ΔG head.

    Each point cloud is encoded with ``model.encoder``, points whose pocket
    probability (sigmoid of ``model.pocket_head``) exceeds 0.5 are selected,
    the selected embeddings are mean-pooled per cloud, and the concatenated
    receptor/ligand vector is passed to ``model.delta_g_head``.

    Args:
        model: Object exposing ``encoder``, ``pocket_head`` and
            ``delta_g_head`` callables.
        rec_features: Unbatched receptor features; the first dimension is the
            point count (presumably ``(N_rec, F)`` -- confirm against the
            encoder).
        rec_points: Unbatched receptor coordinates, same leading dimension.
        lig_features: Unbatched ligand features.
        lig_points: Unbatched ligand coordinates.

    Returns:
        The predicted ΔG as a Python float.
    """

    def batch_index_for(points):
        # One all-zero (single-sample) batch index per cloud, sized from THAT
        # cloud.  The previous code sized a single index from the ligand and
        # reused it for the receptor, which only lines up when both clouds
        # happen to contain the same number of points.
        return points.new_zeros(points.size(0), dtype=torch.long)

    def pool_pocket(z):
        # Mean of the embeddings flagged as pocket; the count is clamped to 1
        # so an empty pocket yields zeros instead of a division by zero.
        mask = (torch.sigmoid(model.pocket_head(z).squeeze(-1)) > 0.5).float()
        count = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
        return (z.squeeze(0) * mask.unsqueeze(-1)).sum(dim=1) / count

    with torch.no_grad():
        rec_z = model.encoder(rec_features.unsqueeze(0), rec_points.unsqueeze(0),
                              batch_index_for(rec_points))
        lig_z = model.encoder(lig_features.unsqueeze(0), lig_points.unsqueeze(0),
                              batch_index_for(lig_points))
        # Each cloud is normalised by its OWN pocket size (the generated-ligand
        # path used to divide the ligand features by the receptor count).
        pair_features = torch.cat([pool_pocket(rec_z), pool_pocket(lig_z)], dim=-1)
        return model.delta_g_head(pair_features.unsqueeze(0)).item()


if __name__ == "__main__":
    # Must be set before the first CUDA call to take effect.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # NOTE(review): the GPU index is hard-coded; this fails on hosts with
    # fewer than three visible GPUs.
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): both paths are empty placeholders and must be filled in.
    model_path = ""
    test_data_dir = ""

    model = load_model(model_path, device)
    receptor_id = "5E6P_receptor_wt"
    receptor_file = os.path.join(test_data_dir, f"{receptor_id}.pt")
    receptor_data = torch.load(receptor_file, map_location=device)

    # === Debug: load the reference ligand (if any) and predict its ΔG ===
    ligand_file = os.path.join(test_data_dir, f"{receptor_id.replace('receptor', 'ligand')}.pt")
    if os.path.exists(ligand_file):
        ligand_data = torch.load(ligand_file, map_location=device)
        lig_points = ligand_data['points'].to(device)
        lig_features = ligand_data['features'].to(device)

        pred_delta_g = predict_pair_delta_g(
            model, receptor_data['features'], receptor_data['points'], lig_features, lig_points
        )
        print(f"[Debug] Standard Ligand ΔG Prediction for {receptor_id}: {pred_delta_g:.4f}")
        print(f"[Debug] lig_points shape: {lig_points.shape}, lig_features shape: {lig_features.shape}")
        if 'delta_g' in ligand_data:
            print(f"[Debug] True ΔG from ligand_data: {ligand_data['delta_g'].item():.4f}")
    else:
        # No reference ligand on disk: build a starting ligand from a
        # generated sequence and score that instead.
        print(f"[Debug] No standard ligand file found for {receptor_id}, using sequence-based generation.")
        initial_sequence = generate_initial_sequence()
        lig_points = sequence_to_point_cloud(initial_sequence, device=device, receptor_points=receptor_data['points'].cpu().numpy())
        lig_features = sequence_to_features(initial_sequence, receptor_data['features'], device=device)

        pred_delta_g = predict_pair_delta_g(
            model, receptor_data['features'], receptor_data['points'], lig_features, lig_points
        )
        print(f"[Debug] Generated Standard Ligand ΔG Prediction for {receptor_id}: {pred_delta_g:.4f}")
        print(f"[Debug] lig_points shape: {lig_points.shape}, lig_features shape: {lig_features.shape}")

    # Generate the ligand
    sequence, structure, rec_points, rec_pocket_logits, lig_pocket_logits = generate_ligand(
        model, receptor_data, num_iterations=10, target_delta_g=-15.0, device=device, ligand_file=ligand_file
    )

    print(f"Generated Ligand Sequence: {sequence}")
    print(f"Generated Ligand Structure Shape: {structure.shape}")

    # Persist the results next to the working directory.
    np.save(f"{receptor_id}_generated_ligand_structure.npy", structure)
    with open(f"{receptor_id}_generated_ligand_sequence.txt", "w") as f:
        f.write(sequence)
    

