import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import tqdm

# 自定义数据集
class PDBbindDataset(Dataset):
    """Receptor/ligand point-cloud pairs stored as per-complex ``.pt`` files.

    For every ``pdb_id`` listed in ``pair_file`` the constructor looks for
    ``<pdb_id>_receptor.pt`` and ``<pdb_id>_ligand.pt`` inside ``data_dir``
    and keeps only the complexes for which both files exist.  The tensors
    themselves are loaded lazily in ``__getitem__``.

    Args:
        data_dir: Directory containing the ``*_receptor.pt`` and
            ``*_ligand.pt`` files.
        pair_file: CSV file with a ``pdb_id`` column.
    """

    def __init__(self, data_dir, pair_file):
        self.data_dir = data_dir
        # Read the IDs as text.  PDB codes such as "1e10" or "2e23" look like
        # floating point literals; if pandas infers a numeric column they are
        # turned into numbers and the file names built below never match.
        self.pairs = pd.read_csv(pair_file, dtype={'pdb_id': str})
        self.data = []
        for pdb_id in self.pairs['pdb_id']:
            receptor_file = os.path.join(data_dir, f"{pdb_id}_receptor.pt")
            ligand_file = os.path.join(data_dir, f"{pdb_id}_ligand.pt")
            # Complexes with a missing receptor or ligand file are skipped.
            if os.path.exists(receptor_file) and os.path.exists(ligand_file):
                self.data.append({
                    'receptor': receptor_file,
                    'ligand': ligand_file,
                    'pdb_id': pdb_id
                })

        # Per-point feature width before the type flag is appended
        # (value set by the original author; not validated against the files).
        self.feature_dim = 12
        print(f"有效蛋白质-小分子对数: {len(self.data)}")

    def __len__(self):
        """Return the number of complexes with both files present."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one complex and append a receptor/ligand type flag.

        Each ``.pt`` file must hold a mapping with ``'points'`` and
        ``'features'`` entries.  The original comments state the shapes as
        ``[5000, 3]`` / ``[5000, 12]`` for receptors and ``[500, 3]`` /
        ``[500, 12]`` for ligands -- TODO confirm against the data files.

        Returns:
            dict with ``receptor_points``, ``receptor_features``,
            ``ligand_points``, ``ligand_features`` and ``pdb_id``.  The
            feature tensors carry one extra trailing column: 0 for receptor
            points, 1 for ligand points.
        """
        data = self.data[idx]
        # NOTE: torch.load unpickles the file; only load files you trust.
        receptor = torch.load(data['receptor'], map_location=torch.device('cpu'))
        ligand = torch.load(data['ligand'], map_location=torch.device('cpu'))

        receptor_points = receptor['points']
        receptor_features = receptor['features']
        ligand_points = ligand['points']
        ligand_features = ligand['features']

        # Type flag column: 0 marks receptor points, 1 marks ligand points.
        receptor_type = torch.zeros(receptor_points.size(0), 1, device=receptor_points.device)
        ligand_type = torch.ones(ligand_points.size(0), 1, device=ligand_points.device)

        return {
            'receptor_points': receptor_points,
            'receptor_features': torch.cat([receptor_features, receptor_type], dim=-1),
            'ligand_points': ligand_points,
            'ligand_features': torch.cat([ligand_features, ligand_type], dim=-1),
            'pdb_id': data['pdb_id']
        }

# SE3 等变卷积模块 (移除 KNN)
class SE3EquivariantConv(nn.Module):
    """Point-wise feature block that mixes projected features with coordinates.

    A kernel-size-1 ``Conv1d`` projects every point's features to
    ``out_channels``; the result is concatenated with the point's 3-D
    coordinates and passed through a two-layer MLP.  Every point is processed
    independently of all others (no neighbourhood aggregation).

    NOTE(review): the original comments describe this as preserving SE(3)
    equivariance.  Feeding absolute coordinates into an MLP does not by itself
    guarantee that property -- confirm before relying on it.

    Args:
        in_channels: Width of the incoming per-point features.
        out_channels: Width of the produced per-point features.
    """

    def __init__(self, in_channels=13, out_channels=64):
        super().__init__()
        # Kernel size 1 makes this a per-point linear projection.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        width = out_channels
        # The "+ 3" accounts for the xyz coordinates concatenated in forward().
        self.mlp = nn.Sequential(
            nn.Linear(width + 3, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, points, features):
        """Encode ``features`` ([B, N, C_in]) using ``points`` ([B, N, 3]).

        Returns:
            Tensor of shape [B, N, out_channels].
        """
        # Unpacking doubles as a rank check: ``points`` must be 3-dimensional.
        _batch, _num_points, _ = points.shape

        # Conv1d expects channels first, so swap to [B, C, N] and back again.
        channels_first = torch.transpose(features, 1, 2)
        projected = torch.transpose(self.conv(channels_first), 1, 2)

        # Append the raw coordinates to every point's projected features.
        with_coords = torch.cat((projected, points), dim=-1)
        return self.mlp(with_coords)

# 注意力融合模块
class AttentionFusion(nn.Module):
    """Cross-attention from receptor points to ligand points, then mean pooling.

    Receptor encodings provide the queries; ligand encodings provide keys and
    values.  The output concatenates the mean of the receptor encodings with
    the mean of the attended ligand values.

    Args:
        channels: Feature width of both inputs.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)
        # Scaled dot-product attention factor 1 / sqrt(channels).
        self.scale = 1 / (channels ** 0.5)

    def forward(self, receptor_enc, ligand_enc):
        """Fuse ``receptor_enc`` ([B, Nr, C]) with ``ligand_enc`` ([B, Nl, C]).

        Returns:
            Tensor of shape [B, 2 * C].
        """
        q = self.query(receptor_enc)   # [B, Nr, C]
        k = self.key(ligand_enc)       # [B, Nl, C]
        v = self.value(ligand_enc)     # [B, Nl, C]

        # Each receptor point distributes its attention over the ligand points.
        weights = (torch.matmul(q, k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attended = torch.matmul(weights, v)  # [B, Nr, C]

        pooled_receptor = receptor_enc.mean(dim=1)
        pooled_attended = attended.mean(dim=1)
        return torch.cat([pooled_receptor, pooled_attended], dim=-1)

# VQ-VAE 量化器
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantiser with a straight-through estimator.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Width of every codebook entry (must match the last
            dimension of the input to ``forward``).
        commitment_cost: Weight of the commitment term, i.e. the term that
            pulls the encoder output towards its selected code.
    """

    def __init__(self, num_embeddings=512, embedding_dim=64, commitment_cost=0.25):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)

    def forward(self, z):
        """Quantise ``z`` ([B, N, D]) to its nearest codebook entries.

        Returns:
            Tuple ``(z_q, loss, indices)`` where ``z_q`` ([B, N, D]) equals the
            selected codes in value but passes gradients straight through to
            ``z``, ``loss`` is the scalar codebook + commitment loss, and
            ``indices`` ([B, N]) holds the selected code index per point.
        """
        B, N, D = z.shape
        z_flattened = z.reshape(-1, D)

        # Code selection is not differentiable, so do it outside autograd and
        # use ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2.  This avoids building a
        # [B*N, K, D] difference tensor.  ||z||^2 is constant per row and
        # therefore irrelevant for the argmin.
        with torch.no_grad():
            codebook = self.embedding.weight
            distances = codebook.pow(2).sum(dim=1) - 2.0 * torch.matmul(z_flattened, codebook.t())
            encoding_indices = torch.argmin(distances, dim=1)

        z_q = self.embedding(encoding_indices).reshape(B, N, D)

        # Codebook term: gradient reaches the embeddings only (encoder detached).
        codebook_loss = torch.mean((z_q - z.detach())**2)
        # Commitment term: gradient reaches the encoder only (codes detached),
        # and this is the term that commitment_cost is meant to scale.
        commitment_loss = torch.mean((z - z_q.detach())**2)
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward value is the code, backward
        # gradient is copied to z unchanged.
        z_q = z + (z_q - z).detach()
        return z_q, loss, encoding_indices.reshape(B, N)

# Surface-VQMAE 模型
class SurfaceVQMAE(nn.Module):
    """Masked autoencoder with a shared vector-quantised latent space.

    Receptor and ligand features are randomly masked, encoded by separate
    encoders, quantised with one shared codebook and decoded by one shared
    decoder back to the input feature width.

    Args:
        in_channels: Width of the per-point input features.
        latent_dim: Width of the latent codes.
        num_embeddings: Number of codebook entries.
        mask_ratio: Probability that a point's features are hidden (zeroed).
    """

    def __init__(self, in_channels=13, latent_dim=64, num_embeddings=512, mask_ratio=0.5):
        super(SurfaceVQMAE, self).__init__()
        self.mask_ratio = mask_ratio
        self.receptor_encoder = SE3EquivariantConv(in_channels, latent_dim)
        self.ligand_encoder = SE3EquivariantConv(in_channels, latent_dim)
        self.vq = VectorQuantizer(num_embeddings, latent_dim)
        self.fusion = AttentionFusion(latent_dim)
        # Decoder maps latent codes back to the input feature width.
        self.decoder = SE3EquivariantConv(latent_dim, in_channels)

    @staticmethod
    def _hidden_point_mse(recon, target, visible_mask):
        """MSE over the points whose features were hidden from the encoder."""
        hidden = ~visible_mask
        if not bool(hidden.any()):
            # Nothing was hidden: return a graph-connected zero rather than
            # the NaN that a mean over an empty selection would produce.
            return recon.sum() * 0.0
        return F.mse_loss(recon[hidden], target[hidden])

    def forward(self, receptor_points, receptor_features, ligand_points, ligand_features):
        """Run masking, encoding, quantisation and reconstruction.

        Args:
            receptor_points: [B, Nr, 3] coordinates.
            receptor_features: [B, Nr, in_channels] features.
            ligand_points: [B, Nl, 3] coordinates.
            ligand_features: [B, Nl, in_channels] features.

        Returns:
            Tuple ``(receptor_recon, ligand_recon, recon_loss, vq_loss,
            receptor_mask, ligand_mask)``.  The masks are boolean with
            ``True`` marking points whose features stayed visible.
        """
        B, Nr, _ = receptor_points.shape
        _, Nl, _ = ligand_points.shape

        # True = visible; a point is hidden with probability mask_ratio.
        receptor_mask = torch.rand(B, Nr, device=receptor_points.device) > self.mask_ratio
        ligand_mask = torch.rand(B, Nl, device=ligand_points.device) > self.mask_ratio
        receptor_masked = receptor_features * receptor_mask.unsqueeze(-1)
        ligand_masked = ligand_features * ligand_mask.unsqueeze(-1)

        receptor_z = self.receptor_encoder(receptor_points, receptor_masked)
        ligand_z = self.ligand_encoder(ligand_points, ligand_masked)

        # Both modalities share one codebook.
        receptor_z_q, vq_loss_r, _ = self.vq(receptor_z)
        ligand_z_q, vq_loss_l, _ = self.vq(ligand_z)
        vq_loss = vq_loss_r + vq_loss_l

        # The fusion module attends over points, so it needs the per-point
        # codes ([B, N, C]).  Passing mean-pooled [B, C] tensors made it
        # attend across the batch dimension instead.
        # NOTE(review): the fused encoding is not used by the losses and is
        # not returned; kept to preserve the 6-tuple return interface.
        fused_enc = self.fusion(receptor_z_q, ligand_z_q)

        receptor_recon = self.decoder(receptor_points, receptor_z_q)
        ligand_recon = self.decoder(ligand_points, ligand_z_q)

        # Score the reconstruction on the hidden points.  Encoder and decoder
        # work point by point, so scoring only the visible points (as before)
        # made the masking irrelevant to the objective.
        recon_loss = (self._hidden_point_mse(receptor_recon, receptor_features, receptor_mask) +
                      self._hidden_point_mse(ligand_recon, ligand_features, ligand_mask)) / 2

        return receptor_recon, ligand_recon, recon_loss, vq_loss, receptor_mask, ligand_mask

# 对比损失 (可选)
class ContrastiveLoss(nn.Module):
    """Symmetric temperature-scaled contrastive loss for receptor/ligand pairs.

    Row ``i`` of ``receptor_enc`` and row ``i`` of ``ligand_enc`` form a
    positive pair; every other embedding in the batch is a negative.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, receptor_enc, ligand_enc):
        """Compute the loss for ``receptor_enc`` and ``ligand_enc`` ([B, C] each).

        Returns:
            Scalar tensor: mean cross-entropy over all 2B anchors.
        """
        B, _ = receptor_enc.shape

        # Cosine similarity keeps the logits bounded by 1 / temperature.
        all_enc = F.normalize(torch.cat([receptor_enc, ligand_enc], dim=0), dim=-1)
        logits = torch.matmul(all_enc, all_enc.T) / self.temperature

        # An anchor must not be able to pick itself.  Previously the target of
        # row i was column i, i.e. its own self-similarity.
        self_mask = torch.eye(2 * B, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float('-inf'))

        # Receptor i (row i) pairs with ligand i (column i + B) and vice versa.
        index = torch.arange(B, device=logits.device)
        labels = torch.cat([index + B, index], dim=0)

        loss = self.criterion(logits, labels)
        return loss

# 训练函数
def train_vqmae(model, train_loader, test_loader, epochs=50, lr=2e-4, device='cuda',
                best_encoder_path=""):
    """Train ``model`` and report the final loss on ``test_loader``.

    The per-batch objective is ``recon_loss + vq_loss + 0.1 * cont_loss``,
    where the contrastive term is computed from mean-pooled encoder outputs
    on the unmasked features.

    Args:
        model: Module exposing ``receptor_encoder`` / ``ligand_encoder`` and
            returning ``(_, _, recon_loss, vq_loss, _, _)`` from ``forward``.
        train_loader: Iterable of batches with ``receptor_points``,
            ``receptor_features``, ``ligand_points`` and ``ligand_features``.
        test_loader: Loader evaluated once after the last epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the model and batches are moved to.
        best_encoder_path: Where to store the ``state_dict`` whenever the
            average training loss improves.  Empty (the default) disables
            checkpointing; previously an empty path was passed straight to
            ``torch.save``, which fails.

    Returns:
        The trained model (final weights, not necessarily the best epoch).

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_loss = float('inf')
    best_epoch = 0
    contrastive_loss = ContrastiveLoss()

    for epoch in range(epochs):
        # Re-assert training mode every epoch in case anything switched it.
        model.train()
        total_loss = 0.0
        total_recon_loss = 0.0
        total_vq_loss = 0.0
        total_cont_loss = 0.0
        num_batches = 0
        for batch in tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)"):
            receptor_points = batch['receptor_points'].to(device)
            receptor_features = batch['receptor_features'].to(device)
            ligand_points = batch['ligand_points'].to(device)
            ligand_features = batch['ligand_features'].to(device)

            optimizer.zero_grad()
            _, _, recon_loss, vq_loss, _, _ = model(
                receptor_points, receptor_features, ligand_points, ligand_features
            )

            # Second encoder pass on the unmasked features for the
            # contrastive term.
            receptor_enc = model.receptor_encoder(receptor_points, receptor_features).mean(dim=1)
            ligand_enc = model.ligand_encoder(ligand_points, ligand_features).mean(dim=1)
            cont_loss = contrastive_loss(receptor_enc, ligand_enc)

            loss = recon_loss + vq_loss + 0.1 * cont_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_vq_loss += vq_loss.item()
            total_cont_loss += cont_loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader produced no batches; nothing to train on")

        avg_loss = total_loss / num_batches
        avg_recon_loss = total_recon_loss / num_batches
        avg_vq_loss = total_vq_loss / num_batches
        avg_cont_loss = total_cont_loss / num_batches
        print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}, Recon Loss: {avg_recon_loss:.4f}, VQ Loss: {avg_vq_loss:.4f}, Cont Loss: {avg_cont_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_epoch = epoch + 1
            # Only write a checkpoint when a destination was configured.
            if best_encoder_path:
                torch.save(model.state_dict(), best_encoder_path)
                print(f"保存最低训练损失模型 (Epoch {best_epoch}, Loss: {best_loss:.4f}) 到 {best_encoder_path}")

    avg_test_loss, avg_test_recon_loss, avg_test_vq_loss = evaluate_test_loss(model, test_loader, device)
    print(f"测试集最终损失: Total Loss: {avg_test_loss:.4f}, Recon Loss: {avg_test_recon_loss:.4f}, VQ Loss: {avg_test_vq_loss:.4f}")

    return model

# 测试集评估函数
def evaluate_test_loss(model, test_loader, device):
    """Average the reconstruction and VQ losses of ``model`` over ``test_loader``.

    The model is switched to eval mode and left there; gradients are disabled
    during the pass.

    Args:
        model: Module whose ``forward`` returns
            ``(_, _, recon_loss, vq_loss, _, _)``.
        test_loader: Iterable of batches with ``receptor_points``,
            ``receptor_features``, ``ligand_points`` and ``ligand_features``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_total_loss, avg_recon_loss, avg_vq_loss)`` of floats,
        where total is ``recon_loss + vq_loss``.  All three are NaN when the
        loader yields no batches, so an empty test set does not abort the
        caller with a ``ZeroDivisionError``.
    """
    model.eval()
    total_test_loss = 0.0
    total_test_recon_loss = 0.0
    total_test_vq_loss = 0.0
    # Count batches while iterating instead of relying on len(test_loader).
    num_batches = 0
    with torch.no_grad():
        for batch in tqdm.tqdm(test_loader, desc="Evaluating Test Set"):
            receptor_points = batch['receptor_points'].to(device)
            receptor_features = batch['receptor_features'].to(device)
            ligand_points = batch['ligand_points'].to(device)
            ligand_features = batch['ligand_features'].to(device)
            _, _, recon_loss, vq_loss, _, _ = model(receptor_points, receptor_features, ligand_points, ligand_features)
            loss = recon_loss + vq_loss
            total_test_loss += loss.item()
            total_test_recon_loss += recon_loss.item()
            total_test_vq_loss += vq_loss.item()
            num_batches += 1

    if num_batches == 0:
        undefined = float('nan')
        return undefined, undefined, undefined

    avg_test_loss = total_test_loss / num_batches
    avg_test_recon_loss = total_test_recon_loss / num_batches
    avg_test_vq_loss = total_test_vq_loss / num_batches
    return avg_test_loss, avg_test_recon_loss, avg_test_vq_loss

# 保存编码器权重
def save_encoder(model, save_path):
    """Write ``model.state_dict()`` to ``save_path``.

    Despite the name, the state dict of the whole ``model`` that is passed in
    is saved, not only an encoder sub-module.

    Args:
        model: Module whose parameters are saved.
        save_path: Destination accepted by ``torch.save``.  When it is a
            filesystem path, missing parent directories are created first.
    """
    # torch.save does not create directories; do it here for path-like
    # destinations (file-like objects are passed through untouched).
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"编码器权重已保存至 {save_path}")

# 主程序
if __name__ == "__main__":
    # Script entry point: build the datasets, train the model, save the weights.
    # TODO: the paths below are empty placeholders and must be filled in
    # before the script can run.
    train_data_dir = ""
    train_pair_file = ""
    test_data_dir = ""
    test_pair_file = ""

    train_dataset = PDBbindDataset(train_data_dir, train_pair_file)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
    test_dataset = PDBbindDataset(test_data_dir, test_pair_file)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)

    in_channels = 13  # 12 feature columns plus the receptor/ligand type flag
    latent_dim = 64
    num_embeddings = 512
    mask_ratio = 0.5

    model = SurfaceVQMAE(in_channels=in_channels, latent_dim=latent_dim, num_embeddings=num_embeddings, mask_ratio=mask_ratio)

    # Prefer GPU 2, but fall back to GPU 0 when the machine has fewer than
    # three GPUs; requesting a missing device index would otherwise fail when
    # the model is moved.
    preferred_gpu = 2
    if torch.cuda.is_available():
        gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')
    if device.type == 'cpu':
        print("警告：CUDA 不可用，回退到 CPU")
    model = train_vqmae(model, train_loader, test_loader, epochs=50, lr=2e-4, device=device)

    save_path = ""
    save_encoder(model, save_path)