import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from scipy.spatial import cKDTree
import open3d as o3d
from Bio.PDB import PDBParser
import torch.nn.functional as F

# Constants
R = 8.314 / 1000  # Gas constant in kJ/mol/K
T = 298.15  # Temperature in Kelvin
PROBE_RADIUS = 1.4  # Default probe radius used by compute_molecular_surface
N_POINTS = 500  # Number of surface points per molecule (set to 500)
K_NEIGHBORS = 10  # Neighbour count used by the density and curvature features
CHEMICAL_RADIUS = 5.0  # Neighbourhood radius used by compute_chemical_features

# Per-element atom property table (unchanged); read by select_likely_atom,
# which uses the 'charge', 'hbond_donor' and 'hbond_acceptor' entries.
ATOM_PROPERTIES = {
    'C': {'charge': 0.0, 'hydrophobicity': 0.74, 'valence': 4, 'hbond_donor': 0, 'hbond_acceptor': 0, 'radius': 0.77},
    'O': {'charge': -0.4, 'hydrophobicity': 0.23, 'valence': 2, 'hbond_donor': 0.5, 'hbond_acceptor': 1.0, 'radius': 0.73},
    'N': {'charge': -0.5, 'hydrophobicity': 0.28, 'valence': 3, 'hbond_donor': 1.0, 'hbond_acceptor': 1.0, 'radius': 0.70},
    'S': {'charge': 0.0, 'hydrophobicity': 0.87, 'valence': 2, 'hbond_donor': 0.2, 'hbond_acceptor': 0.5, 'radius': 1.03}
}

# Helper functions for loading PDB/SDF-style structure data
def load_pdb_coords(smiles):
    """Build a 3D conformer from a SMILES string and return its atoms.

    Despite the name, no PDB/SDF file is read: the molecule is parsed from
    ``smiles``, explicit hydrogens are added and one conformer is embedded
    with a fixed random seed.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        Tuple ``(coords, atom_types, mol)``: the conformer positions, a NumPy
        array of upper-cased element symbols (anything outside
        C/H/O/N/S/SE is mapped to ``'C'``), and the RDKit molecule with
        hydrogens.

    Raises:
        ValueError: If the SMILES cannot be parsed or no conformer can be
            embedded.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    mol = Chem.AddHs(mol)

    # EmbedMolecule returns -1 when it could not place the atoms; without this
    # check GetConformer() below fails with an unrelated-looking error.
    if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
        # Random starting coordinates are the usual fallback for hard cases.
        if AllChem.EmbedMolecule(mol, randomSeed=42, useRandomCoords=True) < 0:
            raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles}")

    coords = mol.GetConformer().GetPositions()

    # RDKit reports mixed-case symbols ('Se'); the feature code keys on
    # upper-case names ('SE'), so normalise before filtering.
    supported = {'C', 'H', 'O', 'N', 'S', 'SE'}
    symbols = (atom.GetSymbol().upper() for atom in mol.GetAtoms())
    atom_types = np.array([s if s in supported else 'C' for s in symbols])
    return coords, atom_types, mol  # mol is returned for aromaticity lookups

def compute_molecular_surface(coords, probe_radius=PROBE_RADIUS):
    """Sample a surface point cloud around a set of coordinates with Open3D.

    The coordinates are wrapped in a point cloud, normals are estimated, a
    triangle mesh is reconstructed by ball pivoting (radii ``probe_radius``
    and ``2 * probe_radius``) and ``N_POINTS`` points are sampled uniformly
    from that mesh.

    Args:
        coords: Input coordinates, one row per point.
        probe_radius: Base radius for normal estimation and ball pivoting.

    Returns:
        Tuple ``(points, normals)`` of NumPy arrays taken from the sampled
        point cloud.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(coords)

    normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=probe_radius * 2, max_nn=30)
    cloud.estimate_normals(search_param=normal_search)

    pivot_radii = o3d.utility.DoubleVector([probe_radius, probe_radius * 2])
    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(cloud, pivot_radii)

    sampled = mesh.sample_points_uniformly(number_of_points=N_POINTS)
    return np.asarray(sampled.points), np.asarray(sampled.normals)

def farthest_point_sampling(points, n_points=N_POINTS):
    """Pick a spread-out subset of ``points`` by farthest point sampling.

    Starting from a random seed point, each step adds the point whose
    distance to the already selected set is largest.

    Args:
        points: Array of point coordinates, one row per point.
        n_points: Number of points to keep.

    Returns:
        ``points`` unchanged when it has at most ``n_points`` rows, otherwise
        the selected rows in selection order.
    """
    total = len(points)
    if total <= n_points:
        return points

    chosen = [np.random.randint(0, total)]
    # Distance from every point to its closest already-chosen point.
    min_dist = np.full(total, np.inf)
    while len(chosen) < n_points:
        anchor = points[chosen[-1]]
        to_anchor = np.linalg.norm(points - anchor, axis=1)
        min_dist = np.minimum(min_dist, to_anchor)
        chosen.append(np.argmax(min_dist))
    return points[chosen]

def interpolate_features(points, normals, features, target_n=N_POINTS):
    """Pad a point set up to ``target_n`` rows by blending nearest neighbours.

    Each new row is a random convex combination of a randomly chosen point
    and its nearest neighbour; normals and features are blended with the same
    weight. Blended normals are not re-normalised.

    Args:
        points: Array of point coordinates, one row per point.
        normals: Array of normals aligned with ``points``.
        features: 2D feature array aligned with ``points``.
        target_n: Desired number of rows.

    Returns:
        Tuple ``(points, normals, features)``; the inputs are returned
        unchanged when they already have at least ``target_n`` rows.

    Raises:
        ValueError: If ``points`` is empty and padding is required.
    """
    n_existing = len(points)
    if n_existing >= target_n:
        return points, normals, features
    if n_existing == 0:
        raise ValueError("Cannot interpolate features from an empty point set")

    points = np.asarray(points)
    normals = np.asarray(normals)
    features = np.asarray(features)

    if n_existing == 1:
        # A k=2 query on a single point returns an out-of-range index; the
        # only sensible partner for the point is itself.
        pairs = np.zeros((1, 2), dtype=int)
    else:
        # One batched query replaces a tree lookup per generated point.
        _, pairs = cKDTree(points).query(points, k=2)

    new_points = []
    new_normals = []
    new_features = []
    for _ in range(target_n - n_existing):
        first, second = pairs[np.random.randint(0, n_existing)]
        w = np.random.random()
        new_points.append(w * points[first] + (1 - w) * points[second])
        new_normals.append(w * normals[first] + (1 - w) * normals[second])
        new_features.append(w * features[first] + (1 - w) * features[second])

    return (np.vstack([points, new_points]),
            np.vstack([normals, new_normals]),
            np.vstack([features, new_features]))

def compute_chemical_features(coords, atom_types, mol, ligand_type='protein'):
    """Compute three neighbourhood-based chemical descriptors per atom.

    For every atom the elements within ``CHEMICAL_RADIUS`` are inspected:

    * column 0: mean of +1 for N/O, -1 for C/S and 0 otherwise,
    * column 1: fraction of O/N/S atoms,
    * column 2: 1 if the atom with the same index in ``mol`` is aromatic.

    Args:
        coords: Atom coordinates, one row per atom.
        atom_types: NumPy array of element symbols aligned with ``coords``.
        mol: RDKit molecule used for the aromaticity flag (may be falsy).
        ligand_type: Accepted for interface compatibility; not used here.

    Returns:
        Array of shape ``(len(coords), 3)``.
    """
    tree = cKDTree(coords)
    result = np.zeros((len(coords), 3))  # 3D chemical features
    for row, center in enumerate(coords):
        neighbours = atom_types[tree.query_ball_point(center, r=CHEMICAL_RADIUS)]

        # Polarity/Charge: +1 for N/O, -1 for C/S, 0 for everything else.
        polar = np.isin(neighbours, ['N', 'O']).astype(int)
        apolar = np.isin(neighbours, ['C', 'S']).astype(int)
        result[row, 0] = np.mean(polar - apolar)

        # H-bond/Electrophilicity: share of O/N/S atoms nearby.
        result[row, 1] = np.mean(np.isin(neighbours, ['O', 'N', 'S']).astype(int))

        # Aromaticity is read from the mol object for the matching atom index.
        aromatic = mol and row < mol.GetNumAtoms() and mol.GetAtomWithIdx(row).GetIsAromatic()
        result[row, 2] = 1 if aromatic else 0
    return result

def normalize_chemical_features(features, ligand_type='protein'):
    """Clamp the three chemical feature columns to their valid ranges.

    The array is modified in place and also returned.

    Args:
        features: 2D array whose first three columns are polarity/charge,
            H-bond/electrophilicity and aromaticity.
        ligand_type: Accepted for interface compatibility; not used here.

    Returns:
        The same ``features`` array with column 0 clipped to [-1, 1] and
        columns 1 and 2 clipped to [0, 1].
    """
    # (column, lower bound, upper bound)
    column_bounds = (
        (0, -1, 1),  # Polarity/Charge
        (1, 0, 1),   # H-bond/Electrophilicity
        (2, 0, 1),   # Aromaticity
    )
    for column, lower, upper in column_bounds:
        features[:, column] = np.clip(features[:, column], lower, upper)
    return features

def compute_local_density(coords):
    """Compute a min-max normalised local spacing feature per point.

    For every point the mean distance to its ``K_NEIGHBORS`` nearest
    neighbours is computed and the values are rescaled to [0, 1]. A larger
    value therefore means a sparser neighbourhood.

    Args:
        coords: Point coordinates, one row per point.

    Returns:
        1D array with one value per point. Inputs with fewer than two points
        yield zeros.
    """
    coords = np.asarray(coords, dtype=float)
    n_points = len(coords)
    if n_points == 0:
        return np.zeros(0)

    # Asking the tree for more neighbours than exist yields out-of-range
    # indices and infinite distances, so cap k at the number available.
    k = min(K_NEIGHBORS, n_points - 1)
    if k == 0:
        return np.zeros(n_points)

    # The first hit of every query is the point itself (distance 0).
    distances, _ = cKDTree(coords).query(coords, k=k + 1)
    densities = distances[:, 1:].mean(axis=1)

    lowest = densities.min()
    return (densities - lowest) / (densities.max() - lowest + 1e-8)

def compute_curvature(points, normals, k=K_NEIGHBORS):
    """Estimate two curvature-like descriptors from normal variation.

    For every point the angles between its normal and the normals of its
    ``k`` nearest neighbours are computed. Column 0 is the mean angle and
    column 1 the standard deviation of the angles, both divided by the mean
    neighbour distance.

    NOTE(review): the original comments label these "mean" and "Gaussian"
    curvature; they are heuristic proxies, not differential-geometry values.

    Args:
        points: Surface point coordinates, one row per point.
        normals: Normals aligned with ``points``.
        k: Number of neighbours to use (capped at ``len(points) - 1``).

    Returns:
        Array of shape ``(len(points), 2)``. Rows stay zero when a point has
        no neighbours or all of its neighbours coincide with it.
    """
    points = np.asarray(points, dtype=float)
    normals = np.asarray(normals, dtype=float)
    n_points = len(points)
    curvatures = np.zeros((n_points, 2))

    # Requesting more neighbours than exist makes the tree return
    # out-of-range indices, so cap k first.
    k_eff = min(k, n_points - 1)
    if k_eff < 1:
        return curvatures

    # Column 0 of every query result is the point itself.
    distances, indices = cKDTree(points).query(points, k=k_eff + 1)
    neighbour_idx = indices[:, 1:]
    mean_dist = distances[:, 1:].mean(axis=1)

    cosines = np.einsum('ij,ikj->ik', normals, normals[neighbour_idx])
    angles = np.arccos(np.clip(cosines, -1, 1))

    # Skip points whose neighbours all coincide with them to avoid 0-division.
    valid = mean_dist > 0
    curvatures[valid, 0] = angles[valid].mean(axis=1) / mean_dist[valid]
    curvatures[valid, 1] = angles[valid].std(axis=1) / mean_dist[valid]
    return curvatures

def compute_atom_type_features(atom_types):
    """One-hot encode element symbols into six channels.

    Channel order is C, H, O, N, S, SE. Symbols outside that set produce an
    all-zero row.

    Args:
        atom_types: Sequence of element symbols.

    Returns:
        Float array of shape ``(len(atom_types), 6)``.
    """
    channel_of = {'C': 0, 'H': 1, 'O': 2, 'N': 3, 'S': 4, 'SE': 5}
    one_hot = np.zeros((len(atom_types), 6))
    hits = [
        (row, channel_of[symbol])
        for row, symbol in enumerate(atom_types)
        if symbol in channel_of
    ]
    if hits:
        rows, cols = zip(*hits)
        one_hot[list(rows), list(cols)] = 1
    return one_hot

def sequence_to_features(smiles, receptor_features=None, n_points=N_POINTS):
    """Turn a SMILES string into a surface point cloud with 13D features.

    Pipeline: embed the molecule, sample a surface point cloud, compute
    per-atom chemical and atom-type features, copy them to each surface point
    from its nearest atom, append curvature and density, and finally add a
    constant 1 column that marks the points as ligand.

    Feature layout per point: 3 chemical + 6 atom type + 2 curvature +
    1 density + 1 ligand label.

    Args:
        smiles: SMILES string of the ligand.
        receptor_features: Accepted for interface compatibility; not used.
        n_points: Number of surface points to return.

    Returns:
        Tuple ``(points, features)`` of ``float32`` tensors.
    """
    # Load coordinates, atom types, and molecule from SMILES
    coords, atom_types, mol = load_pdb_coords(smiles)

    # Generate molecular surface point cloud
    points, normals = compute_molecular_surface(coords)
    points = np.asarray(points)
    normals = np.asarray(normals)
    if len(points) > n_points:
        sampled = farthest_point_sampling(points, n_points)
        # Recover which rows were kept so the normals stay aligned with the
        # sampled points (a plain slice would pair points with wrong normals).
        _, kept = cKDTree(points).query(sampled, k=1)
        points = points[kept]
        normals = normals[kept]

    # Compute features at atomic level
    chem_features = normalize_chemical_features(compute_chemical_features(coords, atom_types, mol))
    atom_features = compute_atom_type_features(atom_types)

    # Copy atom-level features onto each surface point from its nearest atom.
    # Sizes follow len(points), which may be smaller than n_points.
    _, nearest_atom = cKDTree(coords).query(points, k=1)
    interp_chem_features = chem_features[nearest_atom]
    interp_atom_features = atom_features[nearest_atom]

    # Compute geometric features
    density = compute_local_density(points)
    curvatures = compute_curvature(points, normals)

    # Combine features (12D: 3 chem + 6 atom + 2 curvature + 1 density) and
    # add the 13th dimension (all 1s) as the ligand label.
    features = np.hstack([
        interp_chem_features,
        interp_atom_features,
        curvatures,
        density[:, None],
        np.ones((len(points), 1)),
    ])

    # Pad with interpolated points when the surface yielded too few.
    if len(points) < n_points:
        points, normals, features = interpolate_features(points, normals, features, n_points)

    # Convert to tensor
    points = torch.tensor(points, dtype=torch.float32)
    features = torch.tensor(features, dtype=torch.float32)
    return points, features

# The remaining model and generation functions are unchanged
class SE3EquivariantConv(nn.Module):
    """Point-cloud encoder that pools features over a per-point neighbourhood.

    For each point, ``k`` other points are gathered, their features are
    projected with a 1x1 convolution, concatenated with the relative
    position to the centre point, passed through an MLP and max-pooled.

    NOTE(review): relative positions are translation invariant, but nothing
    in this block enforces rotation equivariance despite the class name.
    """

    def __init__(self, in_channels=13, out_channels=64):
        """Create the projection and MLP layers.

        Args:
            in_channels: Number of input feature channels per point.
            out_channels: Number of output channels per point.
        """
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, 1)
        self.mlp = nn.Sequential(
            nn.Linear(out_channels + 3, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )

    def forward(self, points, features, k=16):
        """Encode a point cloud.

        Args:
            points: Tensor of shape ``(N, 3)``.
            features: Tensor of shape ``(N, in_channels)``.
            k: Neighbourhood size; capped at ``N`` so small clouds do not
                make ``topk`` fail.

        Returns:
            Tensor of shape ``(N, out_channels)``.
        """
        k = min(k, points.size(0))
        dist = torch.cdist(points, points, p=2)
        # NOTE(review): taking the smallest entries of the *negated* distance
        # selects the k FARTHEST points, not the nearest. This is kept as-is
        # because trained checkpoints may depend on it; confirm against the
        # training code before switching to topk(dist, largest=False).
        _, idx = torch.topk(-dist, k=k, dim=-1, largest=False)
        relative_pos = points[idx] - points.unsqueeze(1)          # (N, k, 3)
        grouped = features[idx]                                   # (N, k, C_in)
        grouped = self.conv(grouped.permute(0, 2, 1)).permute(0, 2, 1)
        grouped = torch.cat([grouped, relative_pos], dim=-1)      # (N, k, C_out + 3)
        grouped = self.mlp(grouped)
        return grouped.max(dim=1)[0]

class AttentionFusion(nn.Module):
    """Cross-attention from receptor points (queries) to ligand points."""

    def __init__(self, channels=64):
        """Create the query/key/value projections.

        Args:
            channels: Channel count of both encodings.
        """
        super().__init__()
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)
        self.scale = 1 / (channels ** 0.5)

    def forward(self, receptor_enc, ligand_enc):
        """Attend from every receptor row over all ligand rows.

        Args:
            receptor_enc: Tensor ``(N_r, channels)`` or ``(channels,)``.
            ligand_enc: Tensor ``(N_l, channels)`` or ``(channels,)``.

        Returns:
            Tensor ``(N_r, 2 * channels)``: the receptor encoding concatenated
            with its attention-weighted ligand summary.
        """
        # Promote single vectors to one-row matrices.
        if receptor_enc.dim() == 1:
            receptor_enc = receptor_enc.unsqueeze(0)
        if ligand_enc.dim() == 1:
            ligand_enc = ligand_enc.unsqueeze(0)
        query = self.query(receptor_enc)
        key = self.key(ligand_enc)
        value = self.value(ligand_enc)
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        attn = torch.softmax(scores, dim=-1)
        fused = torch.matmul(attn, value)
        # No squeeze here: dropping a size-1 channel axis would make the
        # concatenation fail for channels == 1 and is a no-op otherwise.
        return torch.cat([receptor_enc, fused], dim=-1)

class BindingPredictor(nn.Module):
    """Two-branch predictor for pocket scores, interaction and delta G.

    Receptor and ligand point clouds are encoded separately, fused with
    cross-attention and fed to three heads.
    """

    def __init__(self, in_channels=13):
        """Build encoders, fusion module and prediction heads.

        Args:
            in_channels: Feature channels per input point.
        """
        super().__init__()
        hidden = 64
        self.receptor_encoder = SE3EquivariantConv(in_channels, hidden)
        self.ligand_encoder = SE3EquivariantConv(in_channels, hidden)
        self.fusion = AttentionFusion(hidden)
        # Per-point pocket score, shared by receptor and ligand encodings.
        self.pocket_head = nn.Sequential(
            nn.Linear(hidden, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, 1),
        )
        # Both pair-level heads consume the fused (2 * hidden) representation.
        self.interaction_head = nn.Sequential(
            nn.Linear(hidden * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.delta_g_head = nn.Sequential(
            nn.Linear(hidden * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, receptor_points, receptor_features, ligand_points, ligand_features):
        """Run the full model.

        Returns:
            Tuple ``(receptor_pocket, ligand_pocket, interaction, delta_g)``.
            The pocket outputs have their trailing size-1 axis removed; the
            other two are the raw head outputs on the fused features.
        """
        encoded_receptor = self.receptor_encoder(receptor_points, receptor_features)
        encoded_ligand = self.ligand_encoder(ligand_points, ligand_features)

        pocket_receptor = self.pocket_head(encoded_receptor).squeeze(-1)
        pocket_ligand = self.pocket_head(encoded_ligand).squeeze(-1)

        fused = self.fusion(encoded_receptor, encoded_ligand)
        interaction_out = self.interaction_head(fused)
        delta_g_out = self.delta_g_head(fused)
        return pocket_receptor, pocket_ligand, interaction_out, delta_g_out

class LigandOptimizer(nn.Module):
    """Holds ligand points and features as trainable parameters."""

    def __init__(self, points, features):
        """Copy the given tensors into independent parameters.

        Args:
            points: Initial ligand point coordinates.
            features: Initial ligand point features.
        """
        super().__init__()
        # Detached copies: later updates must not touch the caller's tensors.
        initial_points = points.detach().clone()
        initial_features = features.detach().clone()
        self.points = nn.Parameter(initial_points)
        self.features = nn.Parameter(initial_features)

def modify_smiles(smiles, operation, pos=None, new_atom=None, double_bond=False):
    """Apply one random-ish graph edit to a molecule and return its SMILES.

    Supported operations:

    * ``'add'``: append ``new_atom`` and bond it to atom ``pos``; for
      molecules with more than five carbons, a 5/6/7-membered ring may also
      be attached (50% chance).
    * ``'modify'``: replace atom ``pos`` with ``new_atom``.
    * ``'delete'``: remove atom ``pos``.

    Args:
        smiles: Input SMILES string.
        operation: ``'add'``, ``'modify'`` or ``'delete'``.
        pos: Atom index to edit; chosen at random when ``None``.
        new_atom: Element symbol to insert; chosen by ``select_likely_atom``
            from the molecule's features when not given.
        double_bond: For ``'add'`` with oxygen, bond with a double bond if the
            anchor atom is C or N.

    Returns:
        The edited molecule's SMILES, or the original ``smiles`` unchanged
        when the input is invalid, the edit is impossible, the result is
        disconnected or invalid, or the edit left the molecule unchanged.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return smiles
    if operation not in ('add', 'modify', 'delete'):
        print(f"Warning: unknown operation {operation}, skipping modification")
        return smiles

    editable = Chem.RWMol(mol)
    num_atoms = editable.GetNumAtoms()
    if num_atoms == 0 and operation != 'add':
        return smiles

    if pos is None:
        upper = num_atoms + 1 if operation == 'add' else num_atoms
        pos = np.random.randint(0, upper)
    # RDKit's bindings reject NumPy integer types, so use a plain int.
    pos = int(pos)

    if pos < 0 or (pos >= num_atoms and operation != 'add'):
        print(f"Warning: pos {pos} out of range for {operation}, skipping modification")
        return smiles

    current_symbol = editable.GetAtomWithIdx(pos).GetSymbol() if pos < num_atoms else None

    if operation == 'add':
        if not new_atom:
            new_atom = select_likely_atom(sequence_to_features(smiles)[1])
        new_atom = str(new_atom)
        new_atom_idx = editable.AddAtom(Chem.Atom(new_atom))
        if pos < num_atoms:
            if not double_bond or new_atom != 'O':
                editable.AddBond(pos, new_atom_idx, Chem.BondType.SINGLE)
            elif editable.GetAtomWithIdx(pos).GetSymbol() in ('C', 'N'):
                editable.AddBond(pos, new_atom_idx, Chem.BondType.DOUBLE)
        # Ring strategy: with more than 5 carbons, attach a 5-, 6- or
        # 7-membered ring with a 50% chance, containing at most one heteroatom.
        carbon_count = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() == 'C')
        if carbon_count > 5 and np.random.random() < 0.5:
            ring_size = int(np.random.choice([5, 6, 7]))
            try:
                # Heteroatoms already present in the input molecule.
                hetero_count = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() not in ['C', 'H'])
                # Build a carbon chain and close it into a ring.
                ring_editable = Chem.RWMol(Chem.MolFromSmiles('C' * ring_size))
                if not ring_editable.GetBondBetweenAtoms(ring_size - 1, 0):
                    ring_editable.AddBond(ring_size - 1, 0, Chem.BondType.SINGLE)
                # Swap one ring carbon for O/N only if the molecule has no
                # heteroatom yet.
                if hetero_count < 1 and np.random.random() < 0.5:
                    hetero_atom = str(np.random.choice(['O', 'N']))
                    ring_editable.ReplaceAtom(int(np.random.randint(0, ring_size)), Chem.Atom(hetero_atom))
                # Merge the ring into the molecule and bond its first atom to
                # the anchor position.
                editable = Chem.RWMol(Chem.CombineMols(editable, ring_editable))
                first_ring_atom = editable.GetNumAtoms() - ring_size
                if not editable.GetBondBetweenAtoms(pos, first_ring_atom):
                    editable.AddBond(pos, first_ring_atom, Chem.BondType.SINGLE)
            except Exception as e:
                # Best effort: a failed ring attachment is caught by the
                # connectivity/validity checks below.
                print(f"Ring formation failed: {e}")

    elif operation == 'modify':
        if not new_atom:
            new_atom = select_likely_atom(sequence_to_features(smiles)[1])
        new_atom = str(new_atom)
        # Log after the replacement atom is known so the message is accurate.
        print(f"Modifying atom at pos={pos}, num_atoms={num_atoms}, from {current_symbol} to {new_atom}")
        editable.ReplaceAtom(pos, Chem.Atom(new_atom))
        # When changing to O, try a double bond to the previous atom if that
        # atom is C or N.
        # NOTE(review): this assumes atom pos-1 is the intended partner; if
        # the two atoms are not bonded a new bond is created.
        if new_atom == 'O' and pos > 0:
            prev_atom = editable.GetAtomWithIdx(pos - 1)
            if prev_atom.GetSymbol() in ['C', 'N']:
                if editable.GetBondBetweenAtoms(pos - 1, pos):
                    editable.RemoveBond(pos - 1, pos)
                editable.AddBond(pos - 1, pos, Chem.BondType.DOUBLE)

    else:  # 'delete'
        editable.RemoveAtom(pos)

    try:
        new_smiles = Chem.MolToSmiles(editable)
    except Exception as e:
        print(f"Warning: could not serialise modified molecule: {e}")
        return smiles

    # An edit that leaves the molecule unchanged must not look like a new
    # candidate just because the canonical SMILES is spelled differently.
    if new_smiles == Chem.MolToSmiles(mol):
        return smiles

    # Reject disconnected results.
    new_mol = Chem.MolFromSmiles(new_smiles)
    if new_mol and len(Chem.GetMolFrags(new_mol)) > 1:
        print(f"Warning: Generated SMILES {new_smiles} contains multiple fragments, rejected")
        return smiles
    return new_smiles if validate_molecule(new_smiles) else smiles

from rdkit.Chem import rdmolops

def is_aromatic_ring_formed(smiles):
    """Return True if the molecule contains a ring made only of aromatic atoms.

    Args:
        smiles: SMILES string to inspect.

    Returns:
        ``True`` when at least one ring in the SSSR consists solely of
        aromatic atoms; ``False`` otherwise, including for SMILES that cannot
        be parsed or processed.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False
        return any(
            all(mol.GetAtomWithIdx(i).GetIsAromatic() for i in ring)
            for ring in rdmolops.GetSymmSSSR(mol)
        )
    except Exception:
        # Best effort by design, but do not swallow KeyboardInterrupt or
        # SystemExit the way a bare ``except`` would.
        return False


def validate_molecule(smiles):
    """Check that a SMILES string parses to a non-empty, sanitisable molecule.

    Args:
        smiles: SMILES string to validate.

    Returns:
        ``True`` if RDKit can parse and sanitise the molecule and it has at
        least one atom, ``False`` otherwise.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return False
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        # Sanitisation errors mean "invalid"; interpreter-level signals such
        # as KeyboardInterrupt still propagate.
        return False
    return True

def select_likely_atom(lig_features):
    """Sample an element symbol that fits the ligand's average chemistry.

    Each element in ``ATOM_PROPERTIES`` gets a non-positive match score from
    its H-bond and charge properties compared with the mean of the first
    three feature columns. Element preferences (less S, more O, slightly more
    C) are applied as multiplicative weights on the sampling probability.

    The previous version multiplied the non-positive score itself by the
    weight, which inverted the intent: a factor of 5 made oxygen *less*
    likely and a factor of 0.2 made sulfur *more* likely.

    Args:
        lig_features: 2D tensor whose first three columns are
            polarity/charge, H-bond tendency and aromaticity.

    Returns:
        One of the keys of ``ATOM_PROPERTIES`` as a plain ``str``.
    """
    hbond_tendency = float(lig_features[:, 0].mean())
    hbond_composite = float(lig_features[:, 1].mean())
    aromaticity = float(lig_features[:, 2].mean())

    # Relative sampling weights; elements not listed keep weight 1.
    usage_weight = {
        'S': 0.2,  # use S less often
        'O': 5.0,  # use O more often
        'C': 1.5,
    }

    atoms = list(ATOM_PROPERTIES)
    logits = []
    for atom in atoms:
        props = ATOM_PROPERTIES[atom]
        hbond_score = abs(props['hbond_donor'] - hbond_composite) + abs(props['hbond_acceptor'] - hbond_composite)
        aromatic_score = abs(0.5 - aromaticity)
        polarity_score = abs(props['charge'] - hbond_tendency)
        base_score = -(0.4 * hbond_score + 0.3 * aromatic_score + 0.3 * polarity_score)
        # Adding log(w) to a logit multiplies the softmax probability by w.
        logits.append(base_score + np.log(usage_weight.get(atom, 1.0)))

    # Numerically stable softmax in float64 so the probabilities sum to 1
    # within np.random.choice's tolerance.
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Plain str: RDKit constructors are picky about NumPy scalar types.
    return str(np.random.choice(atoms, p=probs))

def generate_ligand(receptor_file, target_delta_g=-50.0, num_iterations=20, model_path=''):
    """Search for a ligand SMILES with a low predicted binding free energy.

    Each iteration takes one gradient step on a continuous copy of the ligand
    points/features towards ``target_delta_g``, then enumerates single-edit
    SMILES candidates (add / modify / delete / ring) and moves to the
    candidate with the lowest predicted delta G if it beats the current one.

    Args:
        receptor_file: Path to a ``torch.save``d dict with ``'points'`` and
            ``'features'`` tensors. A zero column is appended to the features
            as the receptor label (presumably taking them from 12 to 13
            channels — confirm against the preprocessing code).
        target_delta_g: Target value for the continuous optimisation step.
        num_iterations: Number of search iterations.
        model_path: Path to the ``BindingPredictor`` state dict. The default
            empty string mirrors the original hard-coded placeholder and must
            be replaced with a real checkpoint path.

    Returns:
        Tuple ``(best_smiles, best_points, best_delta_g)`` with the best
        SMILES found, its surface points as a NumPy array, and its predicted
        delta G as a float.
    """
    if torch.cuda.is_available():
        # Prefer GPU 3 as before, but fall back when fewer GPUs are present.
        gpu_index = 3 if torch.cuda.device_count() > 3 else 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')

    model = BindingPredictor().to(device)
    # map_location lets checkpoints saved on another device load here.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # The predictor is the scoring function: it must stay fixed. Optimising
    # its weights together with the ligand would let it reach the target by
    # changing the model instead of the ligand.
    for param in model.parameters():
        param.requires_grad_(False)

    receptor = torch.load(receptor_file, map_location=device)
    receptor_points = receptor['points'].to(device)
    receptor_features = receptor['features'].to(device)
    receptor_features = torch.cat([receptor_features, torch.zeros(receptor_features.size(0), 1, device=device)], dim=1)

    initial_pool = [
        'CCC', 'CCCC', 'c1ccccc1',
        'C1CCCC1', 'C1COCC1', 'C1CNCC1', 'C1COCCC1',
        'CC(=O)O', 'C(=O)O', 'CCC(=O)O',
        'CN(C)C', 'CCN', 'CC(N)C'
    ]
    smiles = str(np.random.choice(initial_pool))

    points, features = sequence_to_features(smiles)
    points = points.to(device)
    features = features.to(device)

    ligand_opt = LigandOptimizer(points, features).to(device)
    optimizer = optim.Adam(ligand_opt.parameters(), lr=1e-3)
    delta_g_criterion = nn.MSELoss()
    target = torch.tensor([target_delta_g], dtype=torch.float32, device=device)

    best_smiles = smiles
    best_points = ligand_opt.points.detach().clone()
    best_delta_g = float('inf')

    def add_candidate(candidates, seen, new_smiles, op):
        """Record a candidate once, skipping no-ops and invalid molecules."""
        if new_smiles in seen or not validate_molecule(new_smiles):
            return
        seen.add(new_smiles)
        candidates.append((new_smiles, op))

    for iteration in range(num_iterations):
        optimizer.zero_grad()

        _, _, _, delta_g = model(receptor_points, receptor_features, ligand_opt.points, ligand_opt.features)
        delta_g = delta_g.mean()
        loss = delta_g_criterion(delta_g.unsqueeze(0), target)
        loss.backward()

        print(f"Loss: {loss.item()}")
        if ligand_opt.points.grad is None or ligand_opt.features.grad is None:
            print(f"Iteration {iteration}: Gradient is None for points or features")
        else:
            optimizer.step()

        current_delta_g = delta_g.item()

        mol = Chem.MolFromSmiles(smiles)
        carbon_count = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() == 'C')
        candidates = []
        # Both spellings of the current molecule count as "no change".
        seen = {smiles, Chem.MolToSmiles(mol)}

        # The atom is chosen here and passed down so the logged atom is the
        # one actually used and modify_smiles need not re-featurise.
        for pos in range(mol.GetNumAtoms() + 1):
            atom = select_likely_atom(features)
            add_candidate(candidates, seen,
                          modify_smiles(smiles, 'add', pos, new_atom=atom),
                          ('add', pos, atom))

        for pos in range(mol.GetNumAtoms()):
            atom = select_likely_atom(features)
            add_candidate(candidates, seen,
                          modify_smiles(smiles, 'modify', pos, new_atom=atom),
                          ('modify', pos, atom))

            add_candidate(candidates, seen,
                          modify_smiles(smiles, 'delete', pos),
                          ('delete', pos, mol.GetAtomWithIdx(pos).GetSymbol()))

            # Ring formation: a further 'add' attempt, which may attach a
            # ring for molecules with more than five carbons.
            if carbon_count > 5:
                add_candidate(candidates, seen,
                              modify_smiles(smiles, 'add', pos, new_atom='C'),
                              ('ring', pos, 'C'))

        best_new_smiles = smiles
        best_delta_g_value = current_delta_g

        for new_smiles, (operation, pos, new_atom) in candidates:
            new_mol = Chem.MolFromSmiles(new_smiles)
            if new_mol and len(Chem.GetMolFrags(new_mol)) > 1:
                continue  # skip multi-fragment structures
            try:
                new_points, new_features = sequence_to_features(new_smiles)
            except Exception as exc:
                # One candidate that cannot be embedded or meshed must not
                # abort the whole search.
                print(f"Skipping candidate {new_smiles}: {exc}")
                continue
            new_points = new_points.to(device)
            new_features = new_features.to(device)
            with torch.no_grad():
                _, _, _, predicted = model(receptor_points, receptor_features, new_points, new_features)
            new_delta_g = predicted.mean().item()
            delta_delta_g = new_delta_g - current_delta_g
            # NOTE(review): the aromatic bonus only affects the logged value;
            # selection below uses the raw prediction, as before.
            if is_aromatic_ring_formed(new_smiles):
                delta_delta_g -= 0.5
            print(f"[操作 {operation}, 位置 {pos}, 原子 {new_atom}] 迭代 {iteration:03d} | ΔG: {new_delta_g:.4f} | ΔΔG: {delta_delta_g:.4f}")
            if new_delta_g < best_delta_g_value:
                best_delta_g_value = new_delta_g
                best_new_smiles = new_smiles

        if best_delta_g_value < current_delta_g:
            smiles = best_new_smiles
            points, features = sequence_to_features(smiles)
            points = points.to(device)
            features = features.to(device)
            current_delta_g = best_delta_g_value

        if current_delta_g < best_delta_g:
            best_delta_g = current_delta_g
            best_smiles = smiles
            best_points = points.clone()
            print(f"[最佳 ΔG 更新] 迭代 {iteration:03d} → ΔG: {best_delta_g:.4f} | SMILES: {best_smiles}")

        print(f"迭代 {iteration+1}, 当前 ΔG: {current_delta_g:.4f}")

    return best_smiles, best_points.detach().cpu().numpy(), best_delta_g

import pyrosetta
from rdkit import Chem
from rdkit.Chem import AllChem
import os
import subprocess

# Initialise PyRosetta (only needs to be called once). This runs at import
# time as a module-level side effect.
pyrosetta.init(extra_options="-mute all")  # adjust the options as needed

def prepare_ligand_params(smiles, ligand_name="LIG", out_dir="tmp/ligand"):
    """Write a ligand PDB file and a params-style text file from a SMILES.

    Args:
        smiles: SMILES string of the ligand.
        ligand_name: Name used for the output files and the NAME record.
        out_dir: Output directory (created if missing).

    Returns:
        Tuple ``(params_file, pdb_file)`` with the paths of the written files.

    Raises:
        ValueError: If the SMILES is invalid or no 3D conformer can be
            embedded.
        RuntimeError: If PyRosetta cannot load the generated PDB file.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Step 1: SMILES -> 3D MOL -> PDB
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 on failure; optimising without a conformer
    # would otherwise fail with a confusing error.
    if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles}")
    AllChem.UFFOptimizeMolecule(mol)

    # Write with PDBWriter and set residue information explicitly.
    # NOTE(review): these are generic atom properties; confirm that PDBWriter
    # actually reads them (it may only honour AtomPDBResidueInfo).
    pdb_file = os.path.join(out_dir, f"{ligand_name}_0001.pdb")
    writer = Chem.PDBWriter(pdb_file)
    mol.SetProp("_Name", ligand_name)
    for atom in mol.GetAtoms():
        atom.SetProp("residueName", ligand_name)
        atom.SetProp("residueNumber", "1")
        atom.SetProp("chainId", "L")  # add chain ID
    writer.write(mol)
    writer.close()
    with open(pdb_file, 'a') as f:  # append TER record
        f.write("TER\nEND")

    # Step 2: validate and load the Pose
    try:
        pose = pyrosetta.pose_from_file(pdb_file)
        if pose.total_residue() == 0:
            with open(pdb_file, 'r') as f:
                print(f"Debug: PDB content: {f.read()}")
            raise ValueError(f"Generated PDB file {pdb_file} contains no residues")
        print(f"Debug: Pose loaded with {pose.total_residue()} residues")
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to load pose from {pdb_file}: {str(e)}") from e

    # Step 3: generate the params file
    params_file = os.path.join(out_dir, f"{ligand_name}.params")
    residue = pose.residue(1)

    from pyrosetta.rosetta.core.chemical import ResidueTypeSet, ResidueType
    from pyrosetta.rosetta.core.chemical import AtomTypeSet
    from pyrosetta.rosetta.core.conformation import Residue

    chem_info = pyrosetta.rosetta.core.chemical.ChemicalManager.get_instance().residue_type_set("fa_standard")
    # NOTE(review): new_residue_type is built but never used afterwards;
    # verify the constructor and add_atom signatures against the installed
    # PyRosetta version.
    new_residue_type = ResidueType(chem_info)

    # Manually populate the ligand residue type.
    for i in range(1, residue.natoms() + 1):
        atom_name = residue.atom_name(i).strip()
        atom_type = residue.atom_type(i).element()
        new_residue_type.add_atom(atom_name, atom_type)

    # Re-read the PDB for bond information (based on CONECT records). A
    # failed parse returns None instead of raising.
    # NOTE(review): MolFromPDBFile removes hydrogens by default, so bond
    # indices only line up with the residue if heavy atoms come first there.
    pdb_mol = Chem.MolFromPDBFile(pdb_file)
    if pdb_mol is None:
        print(f"Warning: could not re-read {pdb_file}; no BOND records written")

    # NOTE(review): the records below are a simplified layout; confirm that
    # Rosetta accepts them as a params file.
    with open(params_file, 'w') as f:
        f.write(f"NAME {ligand_name}\n")
        f.write("TYPE LIGAND\n")
        f.write("IO_STRING :AUTO\n")
        f.write("PROPERTIES VARIANT_TYPE=TRANSLATION\n")  # small-molecule variant
        for i in range(1, residue.natoms() + 1):
            atom_name = residue.atom_name(i).strip()
            x, y, z = residue.xyz(i)
            element = residue.atom_type(i).element()
            f.write(f"ATOM {atom_name} {element} {x} {y} {z}\n")
        if pdb_mol is not None:
            for bond in pdb_mol.GetBonds():
                atom1_idx = bond.GetBeginAtomIdx() + 1  # residue atoms are 1-based
                atom2_idx = bond.GetEndAtomIdx() + 1
                f.write(f"BOND {residue.atom_name(atom1_idx).strip()} {residue.atom_name(atom2_idx).strip()}\n")

    return params_file, pdb_file

def smiles_to_pose(smiles, tmp_dir="tmp"):
    """Build a 3D structure with hydrogens from a SMILES and load it as a Pose.

    The structure is written to ``<tmp_dir>/ligand.pdb`` and read back with
    PyRosetta.

    Args:
        smiles: SMILES string of the ligand.
        tmp_dir: Directory for the intermediate PDB file (created if missing).

    Returns:
        The PyRosetta pose loaded from the generated PDB file.

    Raises:
        ValueError: If the SMILES is invalid or no 3D conformer can be
            embedded.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(tmp_dir, exist_ok=True)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 on failure; the optimiser needs a conformer.
    if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles}")
    AllChem.UFFOptimizeMolecule(mol)
    ligand_pdb = os.path.join(tmp_dir, "ligand.pdb")
    Chem.MolToPDBFile(mol, ligand_pdb)

    return pyrosetta.pose_from_pdb(ligand_pdb)

def load_protein_pose(pdb_path):
    """Load a protein structure from a PDB file as a PyRosetta pose.

    Args:
        pdb_path: Path to the protein PDB file.

    Returns:
        The pose returned by ``pyrosetta.pose_from_pdb``.
    """
    protein_pose = pyrosetta.pose_from_pdb(pdb_path)
    return protein_pose

def assemble_complex(protein_pose, ligand_pose):
    """Combine protein and ligand into one pose for energy calculation.

    The protein is copied into a fresh pose and the ligand is appended by a
    jump anchored at the protein's last residue.

    Args:
        protein_pose: Pose of the protein.
        ligand_pose: Pose of the ligand.

    Returns:
        A new pose containing the protein followed by the ligand.
    """
    merged = pyrosetta.Pose()
    merged.assign(protein_pose)
    anchor_residue = protein_pose.total_residue()
    merged.append_pose_by_jump(ligand_pose, anchor_residue)
    return merged

def calculate_binding_energy(protein_pose, ligand_pose, complex_pose):
    """Compute the binding energy dG = E_complex - (E_protein + E_ligand).

    All three poses are scored with the standard Rosetta full-atom score
    function.

    Args:
        protein_pose: Pose of the isolated protein.
        ligand_pose: Pose of the isolated ligand.
        complex_pose: Pose of the protein-ligand complex.

    Returns:
        The energy difference in Rosetta energy units.
    """
    score = pyrosetta.get_fa_scorefxn()
    # Scored in the same order as before: complex, protein, ligand.
    complex_energy = score(complex_pose)
    protein_energy = score(protein_pose)
    ligand_energy = score(ligand_pose)
    unbound_energy = protein_energy + ligand_energy
    return complex_energy - unbound_energy

def generate_and_evaluate_ligand(receptor_file, protein_pdb, target_delta_g=-50.0, num_iterations=5):
    """
    Generate a ligand and compute its binding energy to the protein with Rosetta.

    Args:
        receptor_file: Path passed to ``generate_ligand`` (torch-saved
            receptor points and features).
        protein_pdb: Path to the protein PDB file.
        target_delta_g: Target delta G forwarded to ``generate_ligand``.
        num_iterations: Number of search iterations forwarded to
            ``generate_ligand``.

    Returns:
        Tuple ``(smiles, predicted_delta_g, delta_G)``: the generated SMILES,
        the model's predicted delta G and the Rosetta binding energy.
    """
    # Step 1: model-guided ligand generation; the surface points are unused here.
    smiles, points, predicted_delta_g = generate_ligand(receptor_file, target_delta_g=target_delta_g, num_iterations=num_iterations)
    print(f"生成的小分子配体 - SMILES: {smiles}, 预测 ΔG: {predicted_delta_g}")

    print("加载蛋白质结构...")
    protein_pose = load_protein_pose(protein_pdb)

    # NOTE(review): this pose is overwritten below by the one loaded from the
    # prepared PDB, so the call only has the side effect of writing
    # tmp/ligand.pdb.
    print("从SMILES生成配体Pose...")
    ligand_pose = smiles_to_pose(smiles)

    # NOTE(review): params_path is not passed to PyRosetta anywhere in this
    # function; confirm whether the params file is meant to be loaded.
    print("准备配体参数文件...")
    params_path, pdb_path = prepare_ligand_params(smiles, ligand_name="LIG")

    # Reload the ligand pose from the PDB written by prepare_ligand_params.
    print("重新加载配体Pose...")
    ligand_pose = pyrosetta.pose_from_file(pdb_path)

    # Append the ligand to a copy of the protein pose.
    print("组装蛋白-配体复合体Pose...")
    complex_pose = assemble_complex(protein_pose, ligand_pose)

    # Score complex, protein and ligand and take the difference.
    print("计算结合能...")
    delta_G = calculate_binding_energy(protein_pose, ligand_pose, complex_pose)
    print(f"Rosetta 计算结合能 ΔG = {delta_G:.3f} Rosetta 能量单位")

    return smiles, predicted_delta_g, delta_G

if __name__ == "__main__":
    # Both paths are empty placeholders and must be filled in before running.
    receptor_file = ""
    protein_pdb = ""

    # Generate and evaluate a ligand
    best_smiles, predicted_delta_g, rosetta_delta_g = generate_and_evaluate_ligand(
        receptor_file, protein_pdb, target_delta_g=-50.0, num_iterations=20
    )
    print(f"最终结果 - SMILES: {best_smiles}, 预测 ΔG: {predicted_delta_g}, Rosetta ΔG: {rosetta_delta_g:.3f}")
