import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import tqdm
import pandas as pd
from torch_geometric.nn import MessagePassing, knn_graph
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Label-smoothed variant of BCEWithLogitsLoss.
class LabelSmoothingBCEWithLogitsLoss(nn.Module):
    """Binary cross-entropy on logits with targets pulled towards 0.5.

    Each target ``y`` is replaced by ``(1 - smoothing) * y + smoothing * 0.5``
    before calling ``F.binary_cross_entropy_with_logits``.

    Args:
        smoothing: Amount of smoothing in ``[0, 1]``.
        pos_weight: Optional positive-class weight forwarded unchanged to
            ``binary_cross_entropy_with_logits``.
    """

    def __init__(self, smoothing=0.05, pos_weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.pos_weight = pos_weight

    def forward(self, pred, target):
        """Return the mean smoothed BCE between logits *pred* and *target*."""
        eps = self.smoothing
        soft_target = (1 - eps) * target + eps * 0.5
        return F.binary_cross_entropy_with_logits(
            pred, soft_target, pos_weight=self.pos_weight
        )

# k-NN point-cloud convolution used as the (pretrained) encoder.
# NOTE(review): despite its name this layer is not SE(3)-equivariant. Positions
# enter only through the k-NN graph and the raw offsets pos_i - pos_j fed to an
# MLP, so outputs are invariant to translating `pos` but not to rotating it.
class SE3EquivariantConv(MessagePassing):
    """Mean-aggregated k-NN graph convolution over per-point features.

    For every point, a k-NN graph (no self loops) is built from ``pos``;
    each edge message is ``mlp(cat[x_i, pos_i - pos_j])`` and messages are
    averaged per receiving point.

    Args:
        in_channels: Per-point input feature size.
        out_channels: Per-point output feature size.
        k: Number of nearest neighbours per point.
        num_points: Points per cloud, used to reshape the output only when
            ``forward`` receives already-flattened ``[B * N, C]`` features.
    """

    def __init__(self, in_channels=11, out_channels=64, k=16, num_points=5000):
        super(SE3EquivariantConv, self).__init__(aggr='mean')
        self.k = k
        self.num_points = num_points
        self.mlp = nn.Sequential(
            nn.Linear(in_channels + 3, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, out_channels)
        )

    def forward(self, x, pos, batch):
        """Encode point clouds.

        Args:
            x: Point features, ``[B, N, in_channels]`` (or flat ``[B * N, in_channels]``).
            pos: Point coordinates, reshaped to ``[B * N, 3]``.
            batch: Cloud index of each flattened point, length ``B * N``.

        Returns:
            ``[B, N, out_channels]`` point embeddings.
        """
        # Take N from a batched input instead of assuming 5000 points per cloud.
        num_points = x.size(-2) if x.dim() == 3 else self.num_points
        x = x.reshape(-1, x.size(-1))  # [B * N, in_channels]
        pos = pos.reshape(-1, 3)  # [B * N, 3]
        edge_index = knn_graph(pos, k=self.k, batch=batch, loop=False)
        out = self.propagate(edge_index, x=x, pos=pos)
        return out.view(-1, num_points, out.size(-1))  # [B, N, out_channels]

    def message(self, x_i, x_j, pos_i, pos_j):
        """Build one message per edge from the receiver's features and the offset."""
        # NOTE(review): x_j (neighbour features) is unused, so neighbours
        # contribute only via their relative position. Pretrained weights
        # depend on this, so confirm intent before switching to x_j.
        rel_pos = pos_i - pos_j  # [num_edges, 3]
        edge_input = torch.cat([x_i, rel_pos], dim=-1)  # [num_edges, in_channels + 3]
        return self.mlp(edge_input)

# Receptor/ligand pair dataset backed by per-protein .pt files.
class PPIDataset(Dataset):
    """Pairs of wild-type receptor/ligand point clouds with pair-level labels.

    Every ``*.pt`` file in ``data_dir`` is keyed by its file name up to the
    first ``.``; each file is read for ``'points'``, ``'features'`` and
    ``'iface_labels'``. Rows of ``pair_file`` (CSV with ``receptor_id_wt``,
    ``ligand_id_wt``, ``interaction_label`` and ``delta_g`` columns) are kept
    only when both referenced files exist.

    Args:
        data_dir: Directory containing the ``*.pt`` files.
        pair_file: Path to the pair CSV.

    Raises:
        FileNotFoundError: If ``pair_file`` does not exist.
    """

    def __init__(self, data_dir, pair_file):
        self.data_dir = data_dir
        self.pt_files = {os.path.basename(f).split('.')[0]: f for f in glob.glob(os.path.join(data_dir, "*.pt"))}
        self.feature_dim = 11  # features [5000, 11]
        self.valid_files = list(self.pt_files.values())
        print(f"有效文件数: {len(self.valid_files)}")

        # Median fraction of interface points across all files (reported only).
        # Load on CPU like __getitem__, so files saved from a GPU still load
        # on hosts without that device.
        cpu = torch.device('cpu')
        pos_ratios = [
            torch.load(f, map_location=cpu)['iface_labels'].float().mean().item()
            for f in self.valid_files
        ]
        avg_pos_ratio = np.median(pos_ratios) if pos_ratios else 0.05
        print(f"平均正例比例 (中位数): {avg_pos_ratio:.4f}")

        if not os.path.isfile(pair_file):
            raise FileNotFoundError(f"配对文件 {pair_file} 不存在，请先生成！")

        # Keep only pairs whose receptor and ligand files are both present.
        self.pairs = pd.read_csv(pair_file)
        self.valid_pairs = []
        for idx, row in self.pairs.iterrows():
            rec_id_wt = row['receptor_id_wt']
            lig_id_wt = row['ligand_id_wt']
            if rec_id_wt in self.pt_files and lig_id_wt in self.pt_files:
                self.valid_pairs.append(idx)
            else:
                print(f"警告：配对 {rec_id_wt}-{lig_id_wt} 中蛋白质文件缺失，将被忽略")
        print(f"有效配对数: {len(self.valid_pairs)}")

    def __len__(self):
        """Number of pairs whose files were found."""
        return len(self.valid_pairs)

    def __getitem__(self, idx):
        """Load the wild-type receptor and ligand tensors for pair *idx*.

        Returns:
            ``(rec_points, rec_features, rec_labels, rec_id, lig_points,
            lig_features, lig_labels, lig_id, interaction_label, delta_g)``,
            with tensors taken unchanged from the ``.pt`` files.

        Raises:
            AssertionError: If labels or features contain NaN.
        """
        pair_idx = self.valid_pairs[idx]
        row = self.pairs.iloc[pair_idx]
        rec_id_wt = row['receptor_id_wt']
        lig_id_wt = row['ligand_id_wt']
        interaction_label = row['interaction_label']
        delta_g = row['delta_g']

        # Load wild-type data on CPU.
        rec_data_wt = torch.load(self.pt_files[rec_id_wt], map_location=torch.device('cpu'))
        lig_data_wt = torch.load(self.pt_files[lig_id_wt], map_location=torch.device('cpu'))

        rec_points_wt = rec_data_wt['points']  # [5000, 3]
        rec_features_wt = rec_data_wt['features']  # [5000, 11]
        rec_labels_wt = rec_data_wt['iface_labels']  # [5000]
        lig_points_wt = lig_data_wt['points']  # [5000, 3]
        lig_features_wt = lig_data_wt['features']  # [5000, 11]
        lig_labels_wt = lig_data_wt['iface_labels']  # [5000]

        # Reject samples containing NaN labels/features.
        assert not torch.isnan(rec_labels_wt).any(), f"Invalid iface_labels in {rec_id_wt}"
        assert not torch.isnan(lig_labels_wt).any(), f"Invalid iface_labels in {lig_id_wt}"
        assert not torch.isnan(rec_features_wt).any(), f"Invalid features in {rec_id_wt}"
        assert not torch.isnan(lig_features_wt).any(), f"Invalid features in {lig_id_wt}"

        return (rec_points_wt, rec_features_wt, rec_labels_wt, rec_id_wt,
                lig_points_wt, lig_features_wt, lig_labels_wt, lig_id_wt,
                interaction_label, delta_g)

# Fine-tuning model: shared encoder plus pocket, interaction and ΔG heads.
class PPIPredictor(nn.Module):
    """Predict per-point pockets, pair interaction and ΔG from two point clouds.

    Receptor and ligand are encoded with the same ``pretrained_encoder``.
    A shared pocket head scores every point; embeddings are mean-pooled over
    the points predicted as pocket, and the two pooled vectors are
    concatenated for the interaction (logit) and ΔG (regression) heads.

    Args:
        pretrained_encoder: Module called as ``encoder(features, points, batch)``;
            assumed to return ``[B, N, out_channels]`` embeddings.
        out_channels: Embedding size produced by the encoder.
        pocket_threshold: Sigmoid probability above which a point counts as
            pocket during pooling.
    """

    def __init__(self, pretrained_encoder, out_channels=64, pocket_threshold=0.5):
        super(PPIPredictor, self).__init__()
        self.encoder = pretrained_encoder
        self.pocket_threshold = pocket_threshold
        # Per-point pocket classifier.
        self.pocket_head = nn.Sequential(
            nn.Linear(out_channels, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )
        # Pair-level interaction classifier.
        self.interaction_head = nn.Sequential(
            nn.Linear(out_channels * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
        # Pair-level ΔG regressor.
        self.delta_g_head = nn.Sequential(
            nn.Linear(out_channels * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )

    @staticmethod
    def _pool_pocket(z, pocket_logits, threshold=0.5):
        """Mean-pool embeddings *z* ``[B, N, C]`` over predicted pocket points.

        A point is pocket when ``sigmoid(pocket_logits) > threshold``. Samples
        with no predicted pocket point fall back to mean pooling over all
        points; pooling over an empty mask would give an all-zero vector that
        carries no information about the protein.

        Returns:
            ``[B, C]`` pooled features.
        """
        mask = (torch.sigmoid(pocket_logits) > threshold).float()  # [B, N]
        empty = mask.sum(dim=1, keepdim=True) == 0  # [B, 1]
        mask = torch.where(empty, torch.ones_like(mask), mask)
        count = mask.sum(dim=1, keepdim=True)  # >= 1 for every sample
        return (z * mask.unsqueeze(-1)).sum(dim=1) / count

    def forward(self, rec_points_wt, rec_features_wt, lig_points_wt, lig_features_wt, batch):
        """Run both proteins through the encoder and all heads.

        Returns:
            ``(rec_pocket_logits [B, N], lig_pocket_logits [B, N],
            interaction_logits [B], delta_g [B])``.
        """
        rec_z_wt = self.encoder(rec_features_wt, rec_points_wt, batch)  # [B, N, C]
        lig_z_wt = self.encoder(lig_features_wt, lig_points_wt, batch)  # [B, N, C]

        rec_pocket_logits_wt = self.pocket_head(rec_z_wt).squeeze(-1)  # [B, N]
        lig_pocket_logits_wt = self.pocket_head(lig_z_wt).squeeze(-1)  # [B, N]

        rec_pocket_features_wt = self._pool_pocket(rec_z_wt, rec_pocket_logits_wt, self.pocket_threshold)
        lig_pocket_features_wt = self._pool_pocket(lig_z_wt, lig_pocket_logits_wt, self.pocket_threshold)

        pair_features = torch.cat([rec_pocket_features_wt, lig_pocket_features_wt], dim=-1)  # [B, 2C]

        interaction_logits = self.interaction_head(pair_features).squeeze(-1)  # [B]
        delta_g = self.delta_g_head(pair_features).squeeze(-1)  # [B]

        return rec_pocket_logits_wt, lig_pocket_logits_wt, interaction_logits, delta_g

# Training and evaluation loop.
def _batch_to_device(batch, device):
    """Unpack one ``PPIDataset`` batch and move its tensors to *device*.

    Returns the dataset's 10-tuple layout; ID lists pass through unchanged and
    ``interaction_label`` / ``delta_g`` are cast to float.
    """
    (rec_points, rec_features, rec_labels, rec_ids,
     lig_points, lig_features, lig_labels, lig_ids,
     interaction_label, delta_g) = batch
    return (rec_points.to(device), rec_features.to(device), rec_labels.to(device), rec_ids,
            lig_points.to(device), lig_features.to(device), lig_labels.to(device), lig_ids,
            interaction_label.to(device).float(), delta_g.to(device).float())


def _point_batch_index(batch_size, num_points, device):
    """Return the cloud index of every flattened point: ``[0]*N + [1]*N + ...``."""
    return torch.arange(batch_size, device=device).repeat_interleave(num_points)


def _masked_delta_g_mse(pred_delta_g, delta_g):
    """MSE between predicted and target ΔG over entries whose target is not NaN.

    Negative pairs carry NaN ΔG. Multiplying by a 0/1 mask does not remove
    them (``nan * 0`` is ``nan``) and would make the whole loss NaN, so valid
    entries are selected by boolean indexing.

    Returns:
        ``(loss, count)``: ``count`` is the number of valid entries and
        ``loss`` is a zero tensor when ``count`` is 0.
    """
    valid = ~torch.isnan(delta_g)
    count = int(valid.sum().item())
    if count == 0:
        return pred_delta_g.new_zeros(()), 0
    return F.mse_loss(pred_delta_g[valid], delta_g[valid]), count


def _train_one_epoch(model, loader, optimizer, pocket_criterion, interaction_criterion, device, desc):
    """Run one optimisation pass; return per-sample mean (pocket, interaction, ΔG) losses."""
    model.train()
    total_pocket_loss = total_interaction_loss = total_delta_g_loss = 0.0
    total_samples = 0

    for batch in tqdm.tqdm(loader, desc=desc, total=len(loader), unit="batch"):
        (rec_points, rec_features, rec_labels, _,
         lig_points, lig_features, lig_labels, _,
         interaction_label, delta_g) = _batch_to_device(batch, device)
        batch_size = rec_points.size(0)
        batch_index = _point_batch_index(batch_size, rec_points.size(1), device)

        optimizer.zero_grad()
        rec_logits, lig_logits, interaction_logits, pred_delta_g = model(
            rec_points, rec_features, lig_points, lig_features, batch_index
        )

        pocket_loss = (pocket_criterion(rec_logits, rec_labels) +
                       pocket_criterion(lig_logits, lig_labels)) / 2
        interaction_loss = interaction_criterion(interaction_logits, interaction_label)
        delta_g_loss, _ = _masked_delta_g_mse(pred_delta_g, delta_g)

        # Fixed task weights: pocket x5, interaction x50, ΔG x1.
        loss = 5 * pocket_loss + 50 * interaction_loss + delta_g_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_pocket_loss += pocket_loss.item() * batch_size
        total_interaction_loss += interaction_loss.item() * batch_size
        total_delta_g_loss += delta_g_loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        nan = float('nan')
        return nan, nan, nan
    return (total_pocket_loss / total_samples,
            total_interaction_loss / total_samples,
            total_delta_g_loss / total_samples)


def _evaluate(model, loader, device, desc):
    """Evaluate receptor pocket metrics and ΔG MSE, printing per-sample results.

    Pocket accuracy/F1/precision/recall are computed per batch over all
    receptor points and averaged with batch-size weights. ΔG MSE is averaged
    over samples with a non-NaN target (NaN if there are none).

    Returns:
        ``(accuracy, f1, precision, recall, delta_g_mse)``.
    """
    model.eval()
    acc_sum = f1_sum = precision_sum = recall_sum = 0.0
    num_samples = 0
    delta_g_sq_err = 0.0
    delta_g_count = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm.tqdm(loader, desc=desc, total=len(loader), unit="batch")):
            (rec_points, rec_features, rec_labels, rec_ids,
             lig_points, lig_features, _, lig_ids,
             _, delta_g) = _batch_to_device(batch, device)
            batch_size = rec_points.size(0)
            batch_index = _point_batch_index(batch_size, rec_points.size(1), device)

            rec_logits, _, _, pred_delta_g = model(
                rec_points, rec_features, lig_points, lig_features, batch_index
            )

            rec_pred = torch.sigmoid(rec_logits) > 0.5  # [B, N] bool
            y_true = rec_labels.cpu().numpy().flatten()
            y_pred = rec_pred.cpu().numpy().flatten()
            acc_sum += accuracy_score(y_true, y_pred) * batch_size
            f1_sum += f1_score(y_true, y_pred, zero_division=0) * batch_size
            precision_sum += precision_score(y_true, y_pred, zero_division=0) * batch_size
            recall_sum += recall_score(y_true, y_pred, zero_division=0) * batch_size

            batch_mse, batch_count = _masked_delta_g_mse(pred_delta_g, delta_g)
            delta_g_sq_err += batch_mse.item() * batch_count
            delta_g_count += batch_count
            num_samples += batch_size

            # Per-sample report.
            valid = ~torch.isnan(delta_g)
            rec_pred_f = rec_pred.float()
            for i in range(batch_size):
                f1 = f1_score(rec_labels[i].cpu().numpy(), rec_pred_f[i].cpu().numpy(), zero_division=0)
                # .float() so mean() also works for integer label tensors.
                print(f"[Test Batch {batch_idx}] rec_id: {rec_ids[i]:<20}, True Positives: {rec_labels[i].sum().item():>4}, Pred Positives: {rec_pred_f[i].sum().item():>4}, True Pos Ratio: {rec_labels[i].float().mean().item():.4f}, Batch F1: {f1:.4f}")
                if valid[i]:  # only pairs with a known ΔG
                    print(f"[Test Batch {batch_idx}] pair: {rec_ids[i]:<20} - {lig_ids[i]:<20}, True ΔG: {delta_g[i].item():>7.4f}, Pred ΔG: {pred_delta_g[i].item():>7.4f}")

    avg_delta_g_mse = delta_g_sq_err / delta_g_count if delta_g_count else float('nan')
    if num_samples == 0:
        nan = float('nan')
        return nan, nan, nan, nan, avg_delta_g_mse
    return (acc_sum / num_samples, f1_sum / num_samples,
            precision_sum / num_samples, recall_sum / num_samples, avg_delta_g_mse)


def train_ppi_predictor(model, train_loader, test_loader, epochs=50, lr=1e-4, device='cuda:1'):
    """Fine-tune *model* on pocket, interaction and ΔG objectives.

    Each epoch trains on ``train_loader`` with Adam (weight decay 1e-5) and
    gradient clipping at norm 1.0, then evaluates on ``test_loader``. Metrics
    are printed.

    Args:
        model: ``PPIPredictor``-style module returning
            ``(rec_pocket_logits, lig_pocket_logits, interaction_logits, delta_g)``.
        train_loader: Loader of ``PPIDataset`` batches for training.
        test_loader: Loader of ``PPIDataset`` batches for evaluation.
        epochs: Number of epochs.
        lr: Adam learning rate.
        device: Device to train on.

    Returns:
        The trained model (same object, moved to *device*).
    """
    model.to(device)
    pos_weight = torch.tensor([3.0], device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    pocket_criterion = LabelSmoothingBCEWithLogitsLoss(smoothing=0.1, pos_weight=pos_weight)
    interaction_criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"
        avg_pocket_loss, avg_interaction_loss, avg_delta_g_loss = _train_one_epoch(
            model, train_loader, optimizer, pocket_criterion, interaction_criterion,
            device, f"{tag} [Train]"
        )
        print(f"Epoch {epoch+1}, Train Pocket Loss: {avg_pocket_loss:.4f}, Train Interaction Loss: {avg_interaction_loss:.4f}, Train ΔG Loss: {avg_delta_g_loss:.4f}")

        (avg_pocket_acc, avg_pocket_f1, avg_pocket_precision,
         avg_pocket_recall, avg_delta_g_mse) = _evaluate(model, test_loader, device, f"{tag} [Test]")
        print(f"Epoch {epoch+1}, Test Pocket Accuracy: {avg_pocket_acc:.4f}, Test Pocket F1: {avg_pocket_f1:.4f}, Test Pocket Precision: {avg_pocket_precision:.4f}, Test Pocket Recall: {avg_pocket_recall:.4f}, Test ΔG MSE: {avg_delta_g_mse:.4f}")

    return model

# Persist model weights.
def save_model(model, save_path):
    """Write ``model.state_dict()`` to *save_path* via ``torch.save`` and report it.

    Args:
        model: Module whose parameters and buffers are saved.
        save_path: Destination file path.
    """
    weights = model.state_dict()
    torch.save(weights, save_path)
    print(f"模型权重已保存至 {save_path}")

# Build the pair CSV consumed by PPIDataset.
def generate_pair_file(excel_file, data_dir, output_file, train=True, seed=None):
    """Write a receptor/ligand pair CSV with 1:1 positive and negative pairs.

    Table rows are split 80/20 with ``random_state=42``; ``train`` selects the
    split. After de-duplicating on ``PDB``, every PDB id with both
    ``{PDB}_receptor_wt.pt`` and ``{PDB}_ligand_wt.pt`` in ``data_dir`` becomes
    a positive pair whose ``delta_g`` is read from the receptor file. One
    negative pair per positive is drawn at random from two ``.pt`` files whose
    names differ in the prefix before the first ``_``.

    Args:
        excel_file: Path to an Excel table with a ``PDB`` column, or an
            already-loaded ``pandas.DataFrame`` with the same content.
        data_dir: Directory containing the ``*.pt`` files.
        output_file: Destination CSV path.
        train: Use the 80 % split if True, otherwise the remaining 20 %.
        seed: Optional seed for negative sampling; ``None`` uses the global
            ``random`` state.
    """
    df = excel_file if isinstance(excel_file, pd.DataFrame) else pd.read_excel(excel_file)
    train_df = df.sample(frac=0.8, random_state=42)
    target_df = train_df if train else df.drop(train_df.index)

    # One positive pair per PDB identifier.
    target_df = target_df.drop_duplicates(subset=['PDB'], keep='first')

    pairs = []
    pt_files = {os.path.basename(f).split('.')[0]: f for f in glob.glob(os.path.join(data_dir, "*.pt"))}
    for _, row in target_df.iterrows():
        pdb_id = row['PDB']
        rec_id_wt = f"{pdb_id}_receptor_wt"
        lig_id_wt = f"{pdb_id}_ligand_wt"
        if rec_id_wt in pt_files and lig_id_wt in pt_files:
            delta_g = torch.load(pt_files[rec_id_wt], map_location='cpu')['delta_g']
            pairs.append({
                'receptor_id_wt': rec_id_wt,
                'ligand_id_wt': lig_id_wt,
                'interaction_label': 1,  # SKEMPI entries are positive pairs
                # float() keeps NaN as NaN and also accepts plain numbers.
                'delta_g': float(delta_g),
            })
        else:
            print(f"警告: PDB {pdb_id} 的 .pt 文件缺失 (receptor_wt 或 ligand_wt)")

    # Negative pairs (1:1 with positives).
    rng = random if seed is None else random.Random(seed)
    pt_file_keys = list(pt_files.keys())
    positive_keys = {(p['receptor_id_wt'], p['ligand_id_wt']) for p in pairs}
    negative_pairs = []
    # The rejection loop below needs two files with different prefixes;
    # without them it would never terminate.
    if len({key.split('_')[0] for key in pt_file_keys}) < 2:
        if pairs:
            print("警告: 数据目录中不同 PDB 少于两个，无法生成负样本")
    else:
        for _ in range(len(pairs)):
            while True:
                rec_id = rng.choice(pt_file_keys)
                lig_id = rng.choice(pt_file_keys)
                if (rec_id != lig_id
                        and (rec_id, lig_id) not in positive_keys
                        and rec_id.split('_')[0] != lig_id.split('_')[0]):
                    break
            negative_pairs.append({
                'receptor_id_wt': rec_id,
                'ligand_id_wt': lig_id,
                'interaction_label': 0,
                'delta_g': float('nan')
            })
    pairs.extend(negative_pairs)

    # Explicit columns so an empty result still writes a valid header.
    pairs_df = pd.DataFrame(pairs, columns=['receptor_id_wt', 'ligand_id_wt', 'interaction_label', 'delta_g'])
    pairs_df.to_csv(output_file, index=False)
    print(f"生成配对文件: {output_file}, 配对数: {len(pairs)}, 正样本: {len(pairs) - len(negative_pairs)}, 负样本: {len(negative_pairs)}")

# Visualise ground-truth interface labels of every test pair.
def _scatter_labelled_cloud(ax, points, labels, name, other_color, pocket_color):
    """Scatter *points* on *ax*: label-0 points small/translucent, label-1 points larger."""
    ax.scatter(points[labels == 0, 0], points[labels == 0, 1], points[labels == 0, 2],
               c=other_color, label=f'{name} Non-Pocket', s=1, alpha=0.5)
    ax.scatter(points[labels == 1, 0], points[labels == 1, 1], points[labels == 1, 2],
               c=pocket_color, label=f'{name} Pocket', s=5)


def visualize_test_data(test_loader, output_dir):
    """Save one 3-D scatter plot per receptor/ligand pair in *test_loader*.

    Points are coloured by the dataset's ``iface_labels`` (not by model
    predictions): receptor blue/red, ligand green/yellow for label 0/1.
    Each image is written to ``{output_dir}/{rec_id}_{lig_id}_visualization.png``.

    Args:
        test_loader: Iterable of ``PPIDataset``-style batches supporting ``len``.
        output_dir: Output directory, created if missing; an empty string
            writes into the current working directory.
    """
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"创建可视化输出目录: {output_dir}")

    num_pairs = 0  # counted explicitly: the last batch may be smaller
    for batch in tqdm.tqdm(test_loader, desc="Visualizing Test Data", total=len(test_loader), unit="batch"):
        (rec_points_wt, _, rec_labels_wt, rec_id_wt,
         lig_points_wt, _, lig_labels_wt, lig_id_wt,
         _, _) = batch

        for i in range(rec_points_wt.size(0)):
            output_path = os.path.join(output_dir, f"{rec_id_wt[i]}_{lig_id_wt[i]}_visualization.png")
            fig = plt.figure(figsize=(10, 8))
            try:
                ax = fig.add_subplot(111, projection='3d')
                _scatter_labelled_cloud(ax, rec_points_wt[i].cpu().numpy(), rec_labels_wt[i].cpu().numpy(),
                                        'Receptor', 'blue', 'red')
                _scatter_labelled_cloud(ax, lig_points_wt[i].cpu().numpy(), lig_labels_wt[i].cpu().numpy(),
                                        'Ligand', 'green', 'yellow')
                ax.set_title(f"Pair: {rec_id_wt[i]} - {lig_id_wt[i]}")
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                ax.set_zlabel('Z')
                ax.legend()
                fig.savefig(output_path)
            finally:
                # Release the figure even if plotting or saving fails.
                plt.close(fig)
            print(f"保存可视化图像: {output_path}")
            num_pairs += 1

    print(f"完成所有测试数据可视化，共处理 {num_pairs} 个配对。")

# Entry point: build pair files, load data, fine-tune and save the model.
if __name__ == "__main__":
    # Data paths (placeholders; fill in before running).
    train_data_dir = ""
    test_data_dir = ""
    pair_dir = ""
    train_pair_file = os.path.join(pair_dir, "train_pairs.csv")
    test_pair_file = os.path.join(pair_dir, "test_pairs.csv")
    excel_file = ""

    # Generate pair CSVs on first run.
    if not os.path.isfile(train_pair_file):
        print(f"警告：{train_pair_file} 不存在，生成中...")
        generate_pair_file(excel_file, train_data_dir, train_pair_file, train=True)
    if not os.path.isfile(test_pair_file):
        print(f"警告：{test_pair_file} 不存在，生成中...")
        generate_pair_file(excel_file, test_data_dir, test_pair_file, train=False)

    # Data loading.
    train_dataset = PPIDataset(train_data_dir, train_pair_file)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    test_dataset = PPIDataset(test_data_dir, test_pair_file)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)

    # Plot ground-truth interface labels of the test pairs.
    output_dir = ""
    visualize_test_data(test_loader, output_dir)

    # Load the pretrained encoder; fall back to random initialisation on failure.
    encoder = SE3EquivariantConv(in_channels=11, out_channels=64)
    try:
        # map_location lets GPU-saved weights load on any host.
        encoder.load_state_dict(torch.load("", map_location='cpu'))
    except Exception as e:
        print(f"警告：无法加载预训练编码器: {e}. 使用随机初始化。")

    model = PPIPredictor(encoder)

    # Use the requested GPU only if that index actually exists.
    cuda_index = 2
    cuda_device = f'cuda:{cuda_index}'
    use_cuda = torch.cuda.is_available() and torch.cuda.device_count() > cuda_index
    device = torch.device(cuda_device if use_cuda else 'cpu')
    if device.type == 'cpu':
        print(f"警告：{cuda_device} 不可用，回退到 CPU")
    model = train_ppi_predictor(model, train_loader, test_loader, epochs=20, lr=5e-3, device=device)

    # Save trained weights.
    save_path = ""
    save_model(model, save_path)