import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import random
import logging
from torch.utils.data import Dataset, DataLoader
from scipy.spatial import cKDTree
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from Bio.PDB import PDBParser

# Configure logging: basicConfig sets up root-level DEBUG output with a
# timestamp / level / message format; `logger` is this module's logger.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Custom dataset
class PDBbindDataset(Dataset):
    """Receptor/ligand pairs with pocket, interaction and delta-G labels.

    Every row of ``pair_file`` (a CSV with a ``pdb_id`` column and an optional
    ``delta_g`` column) becomes a positive sample when these files exist:

    * ``<data_dir>/<pdb_id>_receptor.pt``
    * ``<data_dir>/<pdb_id>_ligand.pt``
    * ``<data_dir>/<pdb_id>/<pdb_id>_pocket.pdb``

    One negative sample per positive is then built by pairing the receptor of
    one valid complex with the ligand of a different valid complex.

    The ``.pt`` files are read with ``torch.load`` and must be dicts providing
    ``'points'``, ``'normals'``, ``'features'`` and ``'iface_labels'``.

    Raises:
        ValueError: if no row of ``pair_file`` yields a complete sample.
    """

    # Receptor points closer than this to any pocket atom are labelled as
    # pocket points (5 Angstrom per the original comment; assumes the stored
    # points share the PDB coordinate frame -- TODO confirm).
    POCKET_DISTANCE_CUTOFF = 5.0

    def __init__(self, data_dir, pair_file):
        self.data_dir = data_dir
        self.pairs = pd.read_csv(pair_file)
        self.data = []

        # Positive examples: receptor and ligand come from the same complex.
        valid_ids = []
        for _, row in self.pairs.iterrows():
            pdb_id = row['pdb_id']
            paths = self._paths_if_complete(pdb_id, pdb_id)
            if paths is None:
                logger.warning(f"Skipping invalid entry for pdb_id: {pdb_id}")
                continue
            self.data.append({
                **paths,
                'delta_g': row.get('delta_g', float('nan')),
                'interaction_label': 1.0
            })
            valid_ids.append(pdb_id)

        if not self.data:
            raise ValueError(f"No valid samples found in dataset. Check paths: data_dir={data_dir}, pair_file={pair_file}")

        # Negative examples: sample only from distinct ids whose files exist,
        # so the negative count really matches the positive count and a
        # duplicated CSV row can never produce a mislabelled "negative".
        num_positive = len(self.data)
        candidate_ids = list(dict.fromkeys(valid_ids))
        if len(candidate_ids) < 2:
            logger.warning("Fewer than two distinct valid pdb_ids; no negative examples were generated")
        else:
            for _ in range(num_positive):
                receptor_id, ligand_id = random.sample(candidate_ids, 2)
                paths = self._paths_if_complete(receptor_id, ligand_id)
                if paths is None:
                    logger.warning(f"Skipping invalid negative example for receptor_id: {receptor_id}, ligand_id: {ligand_id}")
                    continue
                self.data.append({
                    **paths,
                    'delta_g': float('nan'),  # no affinity label for decoys
                    'interaction_label': 0.0
                })

        logger.info(f"Loaded {len(self.data)} samples (positive + negative)")

    def _paths_if_complete(self, receptor_id, ligand_id):
        """Return the sample's file paths, or None if any file is missing."""
        paths = {
            'receptor': os.path.join(self.data_dir, f"{receptor_id}_receptor.pt"),
            'ligand': os.path.join(self.data_dir, f"{ligand_id}_ligand.pt"),
            # The pocket always belongs to the receptor's own complex.
            'pocket': os.path.join(self.data_dir, str(receptor_id), f"{receptor_id}_pocket.pdb"),
        }
        if all(os.path.exists(path) for path in paths.values()):
            return paths
        return None

    @staticmethod
    def _with_side_flag(features, flag):
        """Append a constant column (1 = receptor, 0 = ligand) to 2-D features."""
        features = features.to(torch.float32)
        flag_column = features.new_full((features.size(0), 1), flag)
        return torch.cat([features, flag_column], dim=-1)

    @staticmethod
    def _delta_g_tensor(value):
        """Convert a stored delta-G value to a scalar float32 tensor (NaN if unusable)."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            number = float('nan')
        return torch.tensor(number, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample and return it as a dict of float32 tensors.

        ``'delta_g'`` is NaN when no affinity is known; ``'smiles'`` is the
        ligand's SMILES string or ``''`` when the ligand file has none.
        """
        entry = self.data[idx]
        receptor = torch.load(entry['receptor'])
        ligand = torch.load(entry['ligand'])

        # Pocket labels: mark receptor points that lie near a pocket atom.
        pocket_coords, _ = load_pdb(entry['pocket'])
        if pocket_coords is None or len(pocket_coords) == 0:
            logger.warning(f"Failed to load pocket data for {entry['pocket']}, using default labels")
            pocket_labels = torch.zeros_like(receptor['iface_labels'], dtype=torch.float32)
        else:
            tree = cKDTree(pocket_coords)
            dist, _ = tree.query(receptor['points'].detach().cpu().numpy())
            pocket_labels = torch.tensor(dist < self.POCKET_DISTANCE_CUTOFF, dtype=torch.float32)

        # Extra channel distinguishing receptor (1) from ligand (0). The model
        # in this file expects 13 input channels, i.e. presumably 12 stored
        # feature channels plus this flag -- TODO confirm.
        receptor_features = self._with_side_flag(receptor['features'], 1.0)
        ligand_features = self._with_side_flag(ligand['features'], 0.0)

        return {
            'receptor_points': receptor['points'].to(torch.float32),
            'receptor_normals': receptor['normals'].to(torch.float32),
            'receptor_features': receptor_features,
            'receptor_iface': receptor['iface_labels'].to(torch.float32),
            'receptor_pocket_labels': pocket_labels,
            'ligand_points': ligand['points'].to(torch.float32),
            'ligand_normals': ligand['normals'].to(torch.float32),
            'ligand_features': ligand_features,
            'ligand_iface': ligand['iface_labels'].to(torch.float32),
            'delta_g': self._delta_g_tensor(entry['delta_g']),
            'interaction_label': torch.tensor(entry['interaction_label'], dtype=torch.float32),
            # The default DataLoader collate function cannot batch None, so a
            # missing SMILES is represented by an empty string.
            'smiles': ligand.get('smiles') or ''
        }

def load_pdb(pdb_file, chain_id=None):
    """Read atom coordinates and element symbols from a PDB file.

    Args:
        pdb_file: path handed to Biopython's ``PDBParser``.
        chain_id: when given, only atoms of chains with this id are kept;
            ``None`` keeps every chain.

    Returns:
        ``(coords, atom_types)`` as NumPy arrays, one entry per atom visited
        (all models in the structure are traversed). Elements other than
        C, H, O, N, S and SE are reported as ``'C'``. On any failure the
        error is logged and ``(None, None)`` is returned.
    """
    known_elements = ('C', 'H', 'O', 'N', 'S', 'SE')
    try:
        structure = PDBParser(QUIET=True).get_structure('protein', pdb_file)
        positions = []
        elements = []
        for model in structure:
            for chain in model:
                # Skip chains the caller did not ask for.
                if chain_id is not None and chain.id != chain_id:
                    continue
                for residue in chain:
                    for atom in residue:
                        positions.append(atom.get_coord())
                        element = atom.element
                        if element not in known_elements:
                            element = 'C'  # unknown elements fall back to carbon
                        elements.append(element)
        return np.array(positions), np.array(elements)
    except Exception as e:
        logger.error(f"Failed to load PDB file {pdb_file}: {e}")
        return None, None

# SE3 equivariant convolution module (matches the pre-trained model)
class SE3EquivariantConv(nn.Module):
    """Per-point encoder: a 1x1 convolution followed by a coordinate-aware MLP.

    The attribute names ``conv`` and ``mlp`` define the state-dict keys, so
    they must stay as they are for checkpoints to load.

    NOTE(review): the forward pass feeds raw coordinates into an MLP; confirm
    that the SE(3) equivariance implied by the class name is intended here.

    Args:
        in_channels: number of input feature channels per point.
        out_channels: number of output feature channels per point.
    """

    def __init__(self, in_channels=13, out_channels=64):
        super().__init__()
        # A kernel of size 1 mixes channels independently for every point.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        width = out_channels
        self.mlp = nn.Sequential(
            nn.Linear(width + 3, width),  # +3 for the xyz coordinates
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, points, features):
        """Encode ``features`` [B, N, C_in] with ``points`` [B, N, 3] to [B, N, C_out]."""
        # Unpacking doubles as a check that points is a rank-3 tensor.
        _batch, _num_points, _xyz = points.shape
        # Conv1d wants channels first, so swap axes around the convolution.
        projected = self.conv(features.transpose(1, 2)).transpose(1, 2)
        with_coords = torch.cat((projected, points), dim=-1)
        return self.mlp(with_coords)

# Attention fusion module
class AttentionFusion(nn.Module):
    """Cross-attention from receptor points to ligand points, then max-pooling.

    Receptor encodings act as queries and ligand encodings as keys/values.
    The result concatenates the max-pooled receptor encoding with the
    max-pooled attended ligand encoding.

    Args:
        channels: feature width of both encodings.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)
        # Scaled dot-product attention factor 1/sqrt(channels).
        self.scale = 1 / (channels ** 0.5)

    def forward(self, receptor_enc, ligand_enc):
        """Fuse [B, N_r, C] and [B, N_l, C] encodings into a [B, 2*C] vector."""
        q = self.query(receptor_enc)
        k = self.key(ligand_enc)
        v = self.value(ligand_enc)

        # Attention weights are normalised over the ligand points.
        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attended = weights @ v

        # Max-pool over the point axis of each stream, then concatenate.
        pooled = [stream.max(dim=1).values for stream in (receptor_enc, attended)]
        return torch.cat(pooled, dim=-1)

# Binding prediction model
class BindingPredictor(nn.Module):
    """Multi-task model: per-point pocket logits, interaction logit and delta-G.

    Two ``SE3EquivariantConv`` encoders embed the receptor and ligand point
    sets. A shared ``pocket_head`` scores every point, and ``AttentionFusion``
    produces a 128-dim pair descriptor that feeds the interaction and delta-G
    heads.

    Args:
        pretrained_path: optional checkpoint (a state dict saved with
            ``torch.save``). Entries whose name AND shape match this model are
            loaded; everything else is skipped with a warning.
    """

    def __init__(self, pretrained_path=None):
        super().__init__()
        # Reuse the encoders from unsupervised pre-training
        self.receptor_encoder = SE3EquivariantConv(in_channels=13, out_channels=64)
        self.ligand_encoder = SE3EquivariantConv(in_channels=13, out_channels=64)
        self.fusion = AttentionFusion(64)
        self.pocket_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, 1)
        )
        self.interaction_head = nn.Sequential(
            nn.Linear(128, 128),  # 2*64 from fusion
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.delta_g_head = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        # Load pre-trained weights (with diagnostics)
        if pretrained_path and os.path.exists(pretrained_path):
            self._load_pretrained(pretrained_path)
            # Optional: to freeze the pre-trained layers, set
            # requires_grad = False on the parameters of receptor_encoder and
            # ligand_encoder here.
        else:
            logger.warning(f"Pre-trained model not found or not loaded from {pretrained_path}")

    def _load_pretrained(self, pretrained_path):
        """Copy every name- and shape-compatible tensor from the checkpoint."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        pretrained_state = torch.load(pretrained_path, map_location=torch.device('cpu'))
        logger.info(f"Pre-trained state dictionary keys: {pretrained_state.keys()}")

        model_dict = self.state_dict()
        compatible = {}
        skipped = []
        for name, tensor in pretrained_state.items():
            target = model_dict.get(name)
            # A matching name is not enough: load_state_dict raises on a size
            # mismatch even with strict=False, so compare shapes as well.
            if target is not None and getattr(tensor, 'shape', None) == target.shape:
                compatible[name] = tensor
            else:
                skipped.append(name)

        if skipped:
            logger.warning(f"Skipped {len(skipped)} incompatible pre-trained entries: {skipped}")
        if not compatible:
            logger.warning(f"No compatible pre-trained weights found in {pretrained_path}")
            return

        # strict=False: the checkpoint covers only part of this model.
        self.load_state_dict(compatible, strict=False)
        logger.info("Successfully loaded compatible pre-trained weights")
        logger.info(f"Loaded {len(compatible)} of {len(model_dict)} model tensors from {pretrained_path}")

    def forward(self, receptor_points, receptor_features, ligand_points, ligand_features):
        """Return ``(receptor_pocket, ligand_pocket, interaction, delta_g)``.

        The pocket outputs are per-point logits ([B, N]); ``interaction`` and
        ``delta_g`` come from the fused pair descriptor and are [B, 1].
        """
        receptor_enc = self.receptor_encoder(receptor_points, receptor_features)
        ligand_enc = self.ligand_encoder(ligand_points, ligand_features)
        # The same pocket head scores receptor and ligand points.
        receptor_pocket = self.pocket_head(receptor_enc).squeeze(-1)
        ligand_pocket = self.pocket_head(ligand_enc).squeeze(-1)
        interaction_features = self.fusion(receptor_enc, ligand_enc)
        interaction = self.interaction_head(interaction_features)
        delta_g = self.delta_g_head(interaction_features)
        return receptor_pocket, ligand_pocket, interaction, delta_g

def visualize_pockets(receptor_points, receptor_pocket, ligand_points, ligand_pocket, output_path,
                      receptor_threshold=0.5, ligand_threshold=0.2):
    """Save a 3-D scatter plot of the points predicted to be pocket points.

    Args:
        receptor_points: [N_r, 3] receptor coordinates.
        receptor_pocket: [N_r] receptor pocket logits.
        ligand_points: [N_l, 3] ligand coordinates.
        ligand_pocket: [N_l] ligand pocket logits.
        output_path: image file written via ``Figure.savefig``.
        receptor_threshold: sigmoid probability above which a receptor point
            is drawn (red).
        ligand_threshold: sigmoid probability above which a ligand point is
            drawn (blue).
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        series = (
            (receptor_points, receptor_pocket, receptor_threshold, 'r', 'Receptor Pocket'),
            (ligand_points, ligand_pocket, ligand_threshold, 'b', 'Ligand Pocket'),
        )
        for points, logits, threshold, color, label in series:
            mask = torch.sigmoid(logits.detach()) > threshold
            # Move to host memory once; matplotlib cannot read GPU tensors.
            selected = points[mask].detach().cpu().numpy()
            ax.scatter(selected[:, 0], selected[:, 1], selected[:, 2], c=color, label=label)
        ax.legend()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

def _resolve_device(device=None, preferred_cuda_index=2):
    """Pick the training device, preferring ``cuda:<preferred_cuda_index>``."""
    if device is not None:
        return torch.device(device)
    if not torch.cuda.is_available():
        return torch.device('cpu')
    # Fall back to the first GPU when the preferred index does not exist.
    index = preferred_cuda_index if torch.cuda.device_count() > preferred_cuda_index else 0
    return torch.device(f'cuda:{index}')


def train_model(data_dir, pair_file, output_path, pretrained_path=None, num_epochs=50, device=None):
    """Train ``BindingPredictor`` on ``PDBbindDataset`` and save its state dict.

    Args:
        data_dir: directory with the pre-processed receptor/ligand/pocket files.
        pair_file: CSV listing the complexes (see ``PDBbindDataset``).
        output_path: destination of the trained state dict. Pocket plots from
            the final epoch are written next to it as
            ``<output stem>_pockets_batch<k>.png``.
        pretrained_path: optional checkpoint for ``BindingPredictor``.
        num_epochs: number of passes over the dataset.
        device: torch device (or device string). ``None`` selects ``cuda:2``
            when that GPU exists, otherwise the first GPU, otherwise the CPU.
    """
    device = _resolve_device(device)
    if device.type == 'cuda':
        print(f"Training on {device}")

    dataset = PDBbindDataset(data_dir, pair_file)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
    model = BindingPredictor(pretrained_path).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    # pos_weight must live on the same device as the logits it is applied to.
    pocket_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0, dtype=torch.float32, device=device))
    interaction_criterion = nn.BCEWithLogitsLoss()
    delta_g_criterion = nn.MSELoss(reduction='none')

    figure_stem = os.path.splitext(output_path)[0]

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, batch in enumerate(dataloader):
            receptor_points = batch['receptor_points'].to(device)
            receptor_features = batch['receptor_features'].to(device)
            ligand_points = batch['ligand_points'].to(device)
            ligand_features = batch['ligand_features'].to(device)
            receptor_iface = batch['receptor_iface'].to(device)
            ligand_iface = batch['ligand_iface'].to(device)
            delta_g = batch['delta_g'].to(device)
            interaction_label = batch['interaction_label'].to(device)
            receptor_pocket_labels = batch['receptor_pocket_labels'].to(device)

            receptor_pocket, ligand_pocket, interaction, pred_delta_g = model(
                receptor_points, receptor_features, ligand_points, ligand_features)

            # Pocket loss: interface labels on both molecules, plus the
            # distance-derived receptor pocket labels at half weight.
            pocket_loss = (pocket_criterion(receptor_pocket, receptor_iface) +
                           pocket_criterion(ligand_pocket, ligand_iface)) / 2
            true_pocket_loss = pocket_criterion(receptor_pocket, receptor_pocket_labels)
            pocket_loss = pocket_loss + 0.5 * true_pocket_loss

            # Interaction loss
            interaction_loss = interaction_criterion(interaction.squeeze(-1), interaction_label)

            # Delta-G loss: only samples with a known (non-NaN) affinity count.
            delta_g_mask = ~torch.isnan(delta_g)
            if delta_g_mask.any():
                pred_known = pred_delta_g.squeeze(-1)[delta_g_mask]
                true_known = delta_g[delta_g_mask]
                delta_g_loss = delta_g_criterion(pred_known, true_known).mean()
                true_delta_g = true_known.detach().cpu().numpy()
                pred_delta_g_values = pred_known.detach().cpu().numpy()
            else:
                delta_g_loss = torch.zeros((), dtype=torch.float32, device=device)
                true_delta_g = []
                pred_delta_g_values = []

            # Total loss
            loss = 5 * pocket_loss + 50 * interaction_loss + delta_g_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()

            # Print individual losses and delta-G values
            print(f"Epoch {epoch+1}, Batch {i+1}, Pocket Loss: {pocket_loss.item():.4f}, "
                  f"Interaction Loss: {interaction_loss.item():.4f}, "
                  f"Delta-G Loss: {delta_g_loss.item():.4f}, Total Loss: {loss.item():.4f}, "
                  f"True Delta-G: {true_delta_g}, "
                  f"Predicted Delta-G: {pred_delta_g_values}")

            if epoch == num_epochs - 1:
                # Plot the first sample of every batch in the final epoch. A
                # plotting failure must not cost the trained model, so it is
                # logged instead of raised.
                figure_path = f"{figure_stem}_pockets_batch{i + 1}.png"
                try:
                    visualize_pockets(receptor_points[0], receptor_pocket[0],
                                      ligand_points[0], ligand_pocket[0], figure_path)
                except (OSError, ValueError) as exc:
                    logger.warning(f"Could not write pocket visualisation {figure_path}: {exc}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), output_path)

if __name__ == "__main__":
    # Script entry point. All four paths are empty placeholders and must be
    # filled in before the script can run.
    run_config = {
        "data_dir": "",
        "pair_file": "",
        "output_path": "",
        "pretrained_path": "",
    }
    train_model(**run_config)