
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import tqdm
import pandas as pd
from torch_geometric.nn import MessagePassing, knn_graph

# Point-cloud message-passing layer used as the (pretrained) encoder.
class SE3EquivariantConv(MessagePassing):
    """kNN-graph message-passing layer over per-point features.

    A kNN graph (``k`` neighbours, no self loops) is built from ``pos``; every
    edge carries ``mlp([x_i, pos_i - pos_j])`` and messages are mean-aggregated
    at the receiving point (``aggr='mean'``). With PyG's default
    ``source_to_target`` flow, ``x_i`` / ``pos_i`` belong to the receiving
    (centre) point and ``x_j`` / ``pos_j`` to its neighbour.

    NOTE(review): the message uses the centre point's own features ``x_i`` and
    never the neighbour's ``x_j``, so neighbour features are not aggregated,
    only local geometry is. This may be unintended; it is left unchanged so
    existing pretrained weights keep their meaning.
    NOTE(review): ``pos_i - pos_j`` is translation-invariant but changes under
    rotation, so despite its name this layer is not SE(3)-equivariant. The
    class name is kept for checkpoint/caller compatibility.

    Args:
        in_channels: per-point input feature width.
        out_channels: per-point output feature width.
        k: number of nearest neighbours per point.
        num_points: points per sample, used only to unflatten a 2-D ``x``;
            for a 3-D ``x`` the point count is taken from ``x`` itself.
    """

    def __init__(self, in_channels=11, out_channels=64, k=16, num_points=5000):
        super(SE3EquivariantConv, self).__init__(aggr='mean')
        self.k = k
        # Plain attribute (not a parameter/buffer), so state dicts are unaffected.
        self.num_points = num_points
        self.mlp = nn.Sequential(
            nn.Linear(in_channels + 3, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, out_channels),
        )

    def forward(self, x, pos, batch):
        """Embed every point of a batch of point clouds.

        Args:
            x: per-point features, ``[B, N, in_channels]`` or already flattened
                ``[B * N, in_channels]``.
            pos: point coordinates; reshaped to ``[B * N, 3]``.
            batch: sample index of every flattened point, forwarded to
                ``knn_graph`` so neighbours are searched per sample (or ``None``).

        Returns:
            Tensor of shape ``[B, N, out_channels]``.
        """
        # Take N from a 3-D input instead of hard-coding it; 2-D input falls
        # back to the configured point count.
        num_points = x.size(-2) if x.dim() == 3 else self.num_points
        x = x.reshape(-1, x.size(-1))  # [B * N, in_channels]
        pos = pos.reshape(-1, 3)  # [B * N, 3]
        edge_index = knn_graph(pos, k=self.k, batch=batch, loop=False)
        out = self.propagate(edge_index, x=x, pos=pos)
        return out.view(-1, num_points, out.size(-1))  # [B, N, out_channels]

    def message(self, x_i, x_j, pos_i, pos_j):
        """Build one message per edge from centre features and relative position."""
        rel_pos = pos_i - pos_j  # [num_edges, 3]
        msg_input = torch.cat([x_i, rel_pos], dim=-1)  # [num_edges, in_channels + 3]
        return self.mlp(msg_input)

# Dataset of wild-type / mutant protein-pair samples.
class PPIDataset(Dataset):
    """Wild-type / mutant receptor-ligand pairs backed by per-protein ``.pt`` files.

    ``data_dir`` is scanned for ``*.pt`` files; each file stem is a protein ID.
    Each file is loaded with ``torch.load`` and read through the keys
    ``'points'``, ``'features'``, ``'iface_labels'`` and, optionally,
    ``'delta_g'``. Per the original comments, the first three have shapes
    [5000, 3], [5000, 11] and [5000]; this is assumed, not checked here.

    ``pair_file`` is a CSV with columns ``receptor_id_wt``, ``ligand_id_wt``,
    ``receptor_id_mut``, ``ligand_id_mut``, ``interaction_label`` and
    optionally ``delta_delta_g``. Rows that reference a missing protein file
    are skipped with a warning.

    Raises:
        FileNotFoundError: if ``pair_file`` does not exist.
    """

    # CSV columns holding the four protein IDs of one sample, in output order.
    _ID_COLUMNS = ('receptor_id_wt', 'ligand_id_wt', 'receptor_id_mut', 'ligand_id_mut')

    def __init__(self, data_dir, pair_file):
        self.data_dir = data_dir
        # File stem -> path. splitext strips only the ".pt" suffix, so IDs that
        # contain dots stay whole (split('.')[0] would truncate them).
        self.pt_files = {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in glob.glob(os.path.join(data_dir, "*.pt"))
        }
        self.feature_dim = 11  # width of each point's 'features' vector
        self.valid_files = list(self.pt_files.values())
        print(f"有效文件数: {len(self.valid_files)}")

        if not os.path.isfile(pair_file):
            raise FileNotFoundError(f"配对文件 {pair_file} 不存在，请先生成！")

        self.pairs = pd.read_csv(pair_file)
        # Positional row numbers (used with .iloc in __getitem__) of usable pairs.
        self.valid_pairs = []
        id_rows = self.pairs[list(self._ID_COLUMNS)].itertuples(index=False, name=None)
        for position, (rec_id_wt, lig_id_wt, rec_id_mut, lig_id_mut) in enumerate(id_rows):
            if all(pid in self.pt_files for pid in (rec_id_wt, lig_id_wt, rec_id_mut, lig_id_mut)):
                self.valid_pairs.append(position)
            else:
                print(f"警告：配对 {rec_id_wt}-{lig_id_wt}-{rec_id_mut}-{lig_id_mut} 中蛋白质文件缺失，将被忽略")
        print(f"有效配对数: {len(self.valid_pairs)}")

    def __len__(self):
        return len(self.valid_pairs)

    def _load_protein(self, protein_id):
        """Load one protein's ``.pt`` contents with all tensors mapped to the CPU."""
        return torch.load(self.pt_files[protein_id], map_location=torch.device('cpu'))

    @staticmethod
    def _as_float_tensor(value):
        """Convert a scalar-like value to a float32 0-d tensor; unparsable values become NaN."""
        try:
            value = float(value)
        except (TypeError, ValueError):
            value = float('nan')
        return torch.tensor(value, dtype=torch.float32)

    def _delta_g(self, data):
        """Return the protein's ``'delta_g'`` (tensor or Python number) as a float32 scalar, NaN if absent."""
        return self._as_float_tensor(data.get('delta_g', float('nan')))

    def __getitem__(self, idx):
        """Return one sample as an 18-tuple.

        For receptor-wt, ligand-wt, receptor-mut and ligand-mut, in that order:
        ``(points, features, iface_labels, protein_id)``. These are followed by
        ``interaction_label`` (taken from the CSV as-is) and ``delta_delta_g``
        as a float32 scalar tensor (NaN when unknown).
        """
        row = self.pairs.iloc[self.valid_pairs[idx]]
        protein_ids = [row[column] for column in self._ID_COLUMNS]
        proteins = [self._load_protein(pid) for pid in protein_ids]
        rec_data_wt, _, rec_data_mut, _ = proteins
        interaction_label = row['interaction_label']

        raw_ddg = row['delta_delta_g'] if 'delta_delta_g' in row.index else float('nan')
        delta_delta_g = self._as_float_tensor(raw_ddg)
        if torch.isnan(delta_delta_g):
            # Missing/NaN/unparsable CSV value: derive from the receptor files.
            # NaN propagates through the subtraction if either delta_g is unknown.
            delta_delta_g = self._delta_g(rec_data_mut) - self._delta_g(rec_data_wt)

        sample = []
        for protein_id, data in zip(protein_ids, proteins):
            sample.extend((data['points'], data['features'], data['iface_labels'], protein_id))
        sample.extend((interaction_label, delta_delta_g))
        return tuple(sample)

# Fine-tuning model: shared encoder + pocket / interaction / ΔΔG heads.
class PPIPredictor(nn.Module):
    """Multi-task head on top of a shared point-cloud encoder.

    The same ``pretrained_encoder`` embeds all four structures (receptor and
    ligand, wild-type and mutant). A shared linear head scores every point as
    pocket / non-pocket. Points with a positive logit are mean-pooled into one
    vector per structure. The four pooled vectors are concatenated to predict
    an interaction logit and ΔΔG.

    Args:
        pretrained_encoder: module called as ``encoder(features, points, batch)``
            that returns per-point embeddings whose last axis has size
            ``out_channels``.
        out_channels: embedding width produced by the encoder.
    """

    def __init__(self, pretrained_encoder, out_channels=64):
        super(PPIPredictor, self).__init__()
        self.encoder = pretrained_encoder
        # Per-point pocket classifier, shared across all four structures.
        self.pocket_head = nn.Linear(out_channels, 1)
        # Both pair heads consume [rec_wt, lig_wt, rec_mut, lig_mut] pooled features.
        self.interaction_head = nn.Sequential(
            nn.Linear(out_channels * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.delta_g_head = nn.Sequential(
            nn.Linear(out_channels * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    @staticmethod
    def _pool_pocket(z, pocket_logits):
        """Mean-pool ``z`` ([B, N, C]) over points whose pocket logit is positive.

        For samples where no point is predicted as pocket, fall back to the mean
        over all points instead of returning an all-zero vector, which would
        carry no information to the pair heads.
        """
        # Hard threshold: no gradient flows through the mask itself.
        mask = (pocket_logits > 0).to(z.dtype)  # [B, N]
        count = mask.sum(dim=1, keepdim=True)  # [B, 1]
        pooled = (z * mask.unsqueeze(-1)).sum(dim=1) / count.clamp(min=1.0)  # [B, C]
        return torch.where(count > 0, pooled, z.mean(dim=1))

    def _encode(self, points, features, batch):
        """Encode one structure; return (per-point pocket logits [B, N], pooled features [B, C])."""
        z = self.encoder(features, points, batch)
        pocket_logits = self.pocket_head(z).squeeze(-1)
        return pocket_logits, self._pool_pocket(z, pocket_logits)

    def forward(self, rec_points_wt, rec_features_wt, lig_points_wt, lig_features_wt,
                rec_points_mut, rec_features_mut, lig_points_mut, lig_features_mut, batch):
        """Run all four structures through the shared encoder and heads.

        Returns:
            ``(rec_pocket_logits_wt, lig_pocket_logits_wt, rec_pocket_logits_mut,
            lig_pocket_logits_mut, interaction_logits, delta_delta_g)``. The pocket
            logits are per point ([B, N]); the last two are per sample ([B]).
        """
        rec_pocket_logits_wt, rec_pocket_features_wt = self._encode(rec_points_wt, rec_features_wt, batch)
        lig_pocket_logits_wt, lig_pocket_features_wt = self._encode(lig_points_wt, lig_features_wt, batch)
        rec_pocket_logits_mut, rec_pocket_features_mut = self._encode(rec_points_mut, rec_features_mut, batch)
        lig_pocket_logits_mut, lig_pocket_features_mut = self._encode(lig_points_mut, lig_features_mut, batch)

        pair_features = torch.cat([rec_pocket_features_wt, lig_pocket_features_wt,
                                   rec_pocket_features_mut, lig_pocket_features_mut], dim=-1)  # [B, 4C]

        interaction_logits = self.interaction_head(pair_features).squeeze(-1)  # [B]
        delta_delta_g = self.delta_g_head(pair_features).squeeze(-1)  # [B]

        return (rec_pocket_logits_wt, lig_pocket_logits_wt, rec_pocket_logits_mut, lig_pocket_logits_mut,
                interaction_logits, delta_delta_g)

# Training loop for the multi-task PPI model.
def _nan_masked_mse(pred, target):
    """Return the MSE between ``pred`` and ``target`` over entries whose target is not NaN.

    Multiplying by a 0/1 mask is not enough, because ``NaN * 0`` is still NaN
    and would poison the loss and every gradient. Selecting the valid entries
    avoids that. When no target is valid, a zero scalar (without gradient) is
    returned.
    """
    valid = ~torch.isnan(target)
    if not bool(valid.any()):
        return pred.new_zeros(())
    return F.mse_loss(pred[valid], target[valid])


def train_ppi_predictor(model, train_loader, epochs=50, lr=1e-4, device='cuda:1'):
    """Fine-tune ``model`` with Adam on the pocket, interaction and ΔΔG objectives.

    Every batch from ``train_loader`` must unpack into the 18 fields
    (receptor/ligand × wild-type/mutant points, features, labels, IDs, then
    interaction label and ΔΔG). The total loss per step is the mean of the four
    pocket BCE losses, plus the interaction BCE, plus the ΔΔG MSE computed only
    over non-NaN ΔΔG targets.

    Args:
        model: module returning the six outputs unpacked below.
        train_loader: iterable of batches with a ``len``.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device the model and batches are moved to.

    Returns:
        The trained model (after ``model.to(device)``).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    pocket_criterion = nn.BCEWithLogitsLoss()
    interaction_criterion = nn.BCEWithLogitsLoss()
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_pocket_loss = 0.0
        total_interaction_loss = 0.0
        total_delta_g_loss = 0.0
        total_samples = 0
        progress = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", total=len(train_loader), unit="batch")
        for batch_data in progress:
            (rec_points_wt, rec_features_wt, rec_labels_wt, _rec_id_wt,
             lig_points_wt, lig_features_wt, lig_labels_wt, _lig_id_wt,
             rec_points_mut, rec_features_mut, rec_labels_mut, _rec_id_mut,
             lig_points_mut, lig_features_mut, lig_labels_mut, _lig_id_mut,
             interaction_label, delta_delta_g) = batch_data

            rec_features_wt = rec_features_wt.to(device)
            lig_features_wt = lig_features_wt.to(device)
            rec_points_wt = rec_points_wt.to(device)
            lig_points_wt = lig_points_wt.to(device)
            rec_features_mut = rec_features_mut.to(device)
            lig_features_mut = lig_features_mut.to(device)
            rec_points_mut = rec_points_mut.to(device)
            lig_points_mut = lig_points_mut.to(device)
            # BCEWithLogitsLoss needs floating-point targets; cast in case labels are stored as ints.
            rec_labels_wt = rec_labels_wt.to(device).float()
            lig_labels_wt = lig_labels_wt.to(device).float()
            rec_labels_mut = rec_labels_mut.to(device).float()
            lig_labels_mut = lig_labels_mut.to(device).float()
            interaction_label = interaction_label.to(device).float()
            delta_delta_g = delta_delta_g.to(device).float()

            # Sample index of every flattened point, used by the encoder's kNN graph.
            # Assumes all four structures have as many points as rec_points_wt.
            batch_size, num_points = rec_points_wt.shape[:2]
            batch_index = torch.arange(batch_size, device=device).repeat_interleave(num_points)

            optimizer.zero_grad()
            (rec_pocket_logits_wt, lig_pocket_logits_wt, rec_pocket_logits_mut, lig_pocket_logits_mut,
             interaction_logits, pred_delta_delta_g) = model(
                rec_points_wt, rec_features_wt, lig_points_wt, lig_features_wt,
                rec_points_mut, rec_features_mut, lig_points_mut, lig_features_mut, batch_index
            )

            # Pocket loss averaged over the four structures.
            pocket_loss = (pocket_criterion(rec_pocket_logits_wt, rec_labels_wt) +
                           pocket_criterion(lig_pocket_logits_wt, lig_labels_wt) +
                           pocket_criterion(rec_pocket_logits_mut, rec_labels_mut) +
                           pocket_criterion(lig_pocket_logits_mut, lig_labels_mut)) / 4

            interaction_loss = interaction_criterion(interaction_logits, interaction_label)

            # ΔΔG regression only over samples with a known (non-NaN) target.
            delta_g_loss = _nan_masked_mse(pred_delta_delta_g, delta_delta_g)

            loss = pocket_loss + interaction_loss + delta_g_loss
            loss.backward()
            optimizer.step()

            total_pocket_loss += pocket_loss.item() * batch_size
            total_interaction_loss += interaction_loss.item() * batch_size
            total_delta_g_loss += delta_g_loss.item() * batch_size
            total_samples += batch_size

        # Guard against an empty loader (would otherwise divide by zero).
        num_seen = max(total_samples, 1)
        avg_pocket_loss = total_pocket_loss / num_seen
        avg_interaction_loss = total_interaction_loss / num_seen
        avg_delta_g_loss = total_delta_g_loss / num_seen
        print(f"Epoch {epoch+1}, Pocket Loss: {avg_pocket_loss:.4f}, Interaction Loss: {avg_interaction_loss:.4f}, ΔΔG Loss: {avg_delta_g_loss:.4f}")

    return model

# Persist a model's weights to disk.
def save_model(model, save_path):
    """Serialize ``model.state_dict()`` to ``save_path`` via ``torch.save`` and report the path.

    Only the parameters and buffers are written, not the module object, so the
    same architecture must be rebuilt before calling ``load_state_dict``.
    """
    weights = model.state_dict()
    torch.save(weights, save_path)
    print(f"模型权重已保存至 {save_path}")

# Build the pair CSV consumed by PPIDataset.
def _load_delta_g(pt_path):
    """Return the ``'delta_g'`` entry of a saved ``.pt`` dict as a Python float (NaN if absent).

    Tensors are mapped to the CPU so files saved from a GPU load on CPU-only machines.
    """
    data = torch.load(pt_path, map_location=torch.device('cpu'))
    return float(data.get('delta_g', float('nan')))


def generate_pair_file(excel_file, data_dir, output_file, train=True, train_frac=0.8, random_state=42):
    """Write a CSV of wild-type/mutant receptor-ligand pairs for one data split.

    Spreadsheet rows are split with ``DataFrame.sample(frac=train_frac,
    random_state=random_state)``. ``train=True`` keeps the sampled rows;
    ``train=False`` keeps the remaining rows. For every kept row, the IDs
    ``{PDB}_receptor_wt``, ``{PDB}_ligand_wt``, ``{PDB}_receptor_mut`` and
    ``{PDB}_ligand_mut`` are looked up among the ``*.pt`` file stems in
    ``data_dir``; rows with any file missing are skipped. ``delta_delta_g`` is
    ``delta_g(mut) - delta_g(wt)`` from the two receptor files, written as a
    plain float (NaN if either is unknown). ``interaction_label`` is always 1.

    NOTE(review): the split is by spreadsheet row; if a PDB complex has several
    rows, the same complex can land in both the train and the test split.

    Args:
        excel_file: path to the spreadsheet (read with ``pd.read_excel``) or an
            already-loaded ``pd.DataFrame``; must contain a ``PDB`` column.
        data_dir: directory containing the ``*.pt`` protein files.
        output_file: destination CSV path.
        train: select the training split (True) or the held-out split (False).
        train_frac: fraction of rows assigned to the training split.
        random_state: split seed; use the same value for both splits.
    """
    df = excel_file if isinstance(excel_file, pd.DataFrame) else pd.read_excel(excel_file)
    train_df = df.sample(frac=train_frac, random_state=random_state)
    target_df = train_df if train else df.drop(train_df.index)

    # File stem -> path; splitext strips only the ".pt" suffix.
    pt_files = {
        os.path.splitext(os.path.basename(path))[0]: path
        for path in glob.glob(os.path.join(data_dir, "*.pt"))
    }
    columns = ['receptor_id_wt', 'ligand_id_wt', 'receptor_id_mut', 'ligand_id_mut',
               'interaction_label', 'delta_delta_g']
    pairs = []
    for pdb_id in target_df['PDB']:
        rec_id_wt = f"{pdb_id}_receptor_wt"
        lig_id_wt = f"{pdb_id}_ligand_wt"
        rec_id_mut = f"{pdb_id}_receptor_mut"
        lig_id_mut = f"{pdb_id}_ligand_mut"
        if not all(pid in pt_files for pid in (rec_id_wt, lig_id_wt, rec_id_mut, lig_id_mut)):
            continue
        # Plain floats (not tensors), so to_csv writes numbers rather than "tensor(...)";
        # NaN propagates through the subtraction when either value is unknown.
        delta_delta_g = _load_delta_g(pt_files[rec_id_mut]) - _load_delta_g(pt_files[rec_id_wt])
        pairs.append({
            'receptor_id_wt': rec_id_wt,
            'ligand_id_wt': lig_id_wt,
            'receptor_id_mut': rec_id_mut,
            'ligand_id_mut': lig_id_mut,
            'interaction_label': 1,  # every SKEMPI entry is treated as a positive sample
            'delta_delta_g': delta_delta_g,
        })
    # Explicit columns keep the header even when no pair qualifies, so
    # pd.read_csv can still parse the file.
    pairs_df = pd.DataFrame(pairs, columns=columns)
    pairs_df.to_csv(output_file, index=False)
    print(f"生成配对文件: {output_file}, 配对数: {len(pairs)}")

# Script entry point: build pair files if needed, fine-tune, then save weights.
if __name__ == "__main__":
    # Data paths -- fill these in before running.
    train_data_dir = ""
    test_data_dir = ""
    pair_dir = ""
    train_pair_file = os.path.join(pair_dir, "train_pairs.csv")
    test_pair_file = os.path.join(pair_dir, "test_pairs.csv")
    excel_file = ""
    encoder_weights_path = ""
    save_path = ""

    # Fail fast: otherwise an empty output path is only discovered after all epochs.
    if not save_path:
        raise SystemExit("错误：未设置 save_path，训练结束后将无法保存模型")

    # Generate the pair files on first run.
    if not os.path.isfile(train_pair_file):
        print(f"警告：{train_pair_file} 不存在，生成中...")
        generate_pair_file(excel_file, train_data_dir, train_pair_file, train=True)
    if not os.path.isfile(test_pair_file):
        print(f"警告：{test_pair_file} 不存在，生成中...")
        generate_pair_file(excel_file, test_data_dir, test_pair_file, train=False)

    train_dataset = PPIDataset(train_data_dir, train_pair_file)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)

    encoder = SE3EquivariantConv(in_channels=11, out_channels=64)
    try:
        # map_location lets GPU-saved weights load on any machine.
        encoder.load_state_dict(torch.load(encoder_weights_path, map_location='cpu'))
    except Exception as e:
        # Deliberate best-effort: keep training from random init if loading fails.
        print(f"警告：无法加载预训练编码器: {e}. 使用随机初始化。")

    model = PPIPredictor(encoder)

    # Prefer cuda:1 only if it actually exists; otherwise fall back to cuda:0, then CPU.
    if torch.cuda.device_count() > 1:
        device = torch.device('cuda:1')
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
        print("警告：CUDA:1 不可用，回退到 CUDA:0")
    else:
        device = torch.device('cpu')
        print("警告：CUDA:1 不可用，回退到 CPU")
    model = train_ppi_predictor(model, train_loader, epochs=50, lr=1e-4, device=device)

    save_model(model, save_path)