import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, MessagePassing
from torch_geometric.data import Data, Batch, Dataset
from torch.utils.data import DataLoader
from transformers import EsmModel, AutoTokenizer
from typing import Tuple, Optional, List, Dict
import numpy as np
import os
import glob
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO
from itertools import combinations
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, hamming_loss
from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# ==========================================================================================
# 1. CORE MODEL DEFINITIONS
# ==========================================================================================
class WeightedGCNConv(MessagePassing):
    """Sum-aggregating message-passing layer with one learnable scalar per edge.

    Each message is the source-node feature row multiplied by the weight of the
    edge it travels along. The summed messages are then passed through a single
    linear layer.
    """

    def __init__(self, in_channels, out_channels, num_edges):
        """Build the layer.

        Args:
            in_channels: Size of each input node feature vector.
            out_channels: Size of each output node feature vector.
            num_edges: Number of edge weights to allocate (all initialised to 1).
                The weights are multiplied row-wise with the per-edge messages,
                so this is expected to equal the number of columns of the
                ``edge_index`` later passed to ``forward``.
        """
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)
        self.edge_weights = nn.Parameter(torch.ones(num_edges))

    def forward(self, x, edge_index):
        """Propagate node features ``x`` along the edges in ``edge_index``."""
        return self.propagate(edge_index, x=x)

    def message(self, x_j, edge_index):
        """Scale every source-node feature row by its edge weight."""
        per_edge_scale = self.edge_weights.view(-1, 1)
        return per_edge_scale * x_j

    def update(self, aggr_out):
        """Apply the linear transform to the aggregated messages."""
        return self.lin(aggr_out)

class ContrastiveCoAttentionBlock(nn.Module):
    """Bidirectional cross-attention between sequence and structure embeddings.

    The sequence stream first attends over the structure stream; the structure
    stream then attends over the already-updated sequence stream. Each stream
    finally goes through the same (shared) feed-forward network. Every
    sub-layer is wrapped in a residual connection followed by LayerNorm.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        """Create the attention, feed-forward and normalisation sub-layers.

        Args:
            embed_dim: Feature size of both input streams.
            num_heads: Number of heads in each multi-head attention layer.
        """
        super().__init__()
        self.seq_to_struct_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.struct_to_seq_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.norm4 = nn.LayerNorm(embed_dim)

    def forward(self, seq_embeds: torch.Tensor, struct_embeds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one round of co-attention.

        Args:
            seq_embeds: Batch-first sequence embeddings.
            struct_embeds: Batch-first structure embeddings.

        Returns:
            ``(seq_embeds, struct_embeds, seq_to_struct_weights,
            struct_to_seq_weights)`` — the two updated streams followed by the
            attention weights of each direction.
        """
        # Sequence queries attend over structure keys/values.
        attended_seq, s2s_attn_weights = self.seq_to_struct_attn(
            query=seq_embeds, key=struct_embeds, value=struct_embeds
        )
        seq_embeds = self.norm1(seq_embeds + attended_seq)

        # Structure queries attend over the *updated* sequence stream.
        attended_struct, s2q_attn_weights = self.struct_to_seq_attn(
            query=struct_embeds, key=seq_embeds, value=seq_embeds
        )
        struct_embeds = self.norm2(struct_embeds + attended_struct)

        # The same feed-forward network is applied to both streams.
        seq_embeds = self.norm3(seq_embeds + self.ffn(seq_embeds))
        struct_embeds = self.norm4(struct_embeds + self.ffn(struct_embeds))

        return seq_embeds, struct_embeds, s2s_attn_weights, s2q_attn_weights

class LabelGraphInfusedDecoder(nn.Module):
    """Linear classifier whose logits are refined over a label graph.

    A linear layer produces one logit per label. When the label graph has at
    least one edge, each sample's logits are additionally propagated along the
    graph with a ``WeightedGCNConv`` and the result is added back to the
    initial logits.
    """

    def __init__(self, in_features: int, num_labels: int, label_graph_edge_index: torch.Tensor):
        """Build the decoder.

        Args:
            in_features: Size of the fused representation fed to the decoder.
            num_labels: Number of output labels.
            label_graph_edge_index: Edge index of the label graph; its second
                dimension is read as the number of edges. Stored as a buffer.
        """
        super().__init__()
        self.num_labels = num_labels
        self.register_buffer('label_graph_edge_index', label_graph_edge_index)
        self.initial_projection = nn.Linear(in_features, num_labels)

        edge_count = label_graph_edge_index.shape[1]
        if edge_count > 0:
            self.gcn = WeightedGCNConv(in_channels=1, out_channels=1, num_edges=edge_count)
        else:
            # No label relations: the decoder degrades to a plain linear layer.
            self.gcn = None

    def forward(self, fused_rep: torch.Tensor) -> torch.Tensor:
        """Return label logits for ``fused_rep``.

        Args:
            fused_rep: Input features (assumed ``(batch, in_features)`` —
                TODO confirm against caller).

        Returns:
            The initial logits, plus the graph-refined logits when a label
            graph with edges is available.
        """
        initial_logits = self.initial_projection(fused_rep)
        graph_is_empty = self.label_graph_edge_index.shape[1] == 0
        if self.gcn is None or graph_is_empty:
            return initial_logits

        # Labels become graph nodes carrying a single scalar feature each.
        per_label = initial_logits.t().unsqueeze(-1)
        refined_columns = []
        for sample_idx in range(per_label.shape[1]):
            node_features = per_label[:, sample_idx, :]
            refined_columns.append(self.gcn(node_features, self.label_graph_edge_index))

        refined_logits = torch.cat(refined_columns, dim=1).squeeze(-1).t()
        return initial_logits + refined_logits

class IMPeT(nn.Module):
    """Sequence + structure peptide classifier.

    A pretrained ESM model encodes the tokenised sequence, two GAT layers
    encode the residue graph, and a stack of ``ContrastiveCoAttentionBlock``
    layers exchanges information between the two streams. The pooled
    representations are concatenated and decoded into label logits by a
    ``LabelGraphInfusedDecoder``; 128-d L2-normalised projections of each
    stream are also returned for a contrastive objective.
    """

    def __init__(self, num_labels: int, label_graph_edge_index: torch.Tensor, esm_model_name: str = "facebook/esm2_t33_650M_UR50D", num_co_attention_blocks: int = 2, num_heads: int = 4, fine_tune_esm: bool = False):
        """Build the model.

        Args:
            num_labels: Number of output labels.
            label_graph_edge_index: Edge index of the label graph, forwarded
                to the decoder.
            esm_model_name: Name of the pretrained ESM checkpoint to load.
            num_co_attention_blocks: Number of stacked co-attention blocks.
            num_heads: Attention heads per co-attention block.
            fine_tune_esm: When False the ESM parameters are frozen and the
                encoder is run without gradient tracking.
        """
        super().__init__()
        self.sequence_encoder = EsmModel.from_pretrained(esm_model_name)
        self.fine_tune_esm = fine_tune_esm
        if not self.fine_tune_esm:
            for param in self.sequence_encoder.parameters():
                param.requires_grad = False

        self.esm_embed_dim = self.sequence_encoder.config.hidden_size
        self.aa_embedding = nn.Embedding(21, self.esm_embed_dim)
        self.gnn_layers = nn.ModuleList(
            [GATConv(self.esm_embed_dim, self.esm_embed_dim) for _ in range(2)]
        )
        self.co_attention_blocks = nn.ModuleList(
            [
                ContrastiveCoAttentionBlock(embed_dim=self.esm_embed_dim, num_heads=num_heads)
                for _ in range(num_co_attention_blocks)
            ]
        )
        self.seq_proj = nn.Linear(self.esm_embed_dim, 128)
        self.struct_proj = nn.Linear(self.esm_embed_dim, 128)
        self.decoder = LabelGraphInfusedDecoder(self.esm_embed_dim * 2, num_labels, label_graph_edge_index)

    @staticmethod
    def _unpad_nodes(padded: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Flatten ``(batch, max_len, dim)`` back to ``(total_nodes, dim)``.

        Only the first ``lengths[i]`` rows of graph ``i`` are kept, in graph
        order, so the result lines up with a sorted PyG ``batch`` vector.
        """
        positions = torch.arange(padded.shape[1], device=padded.device)
        valid = positions.unsqueeze(0) < lengths.to(padded.device).unsqueeze(1)
        return padded[valid]

    def forward(self, seq_tokens: dict, peptide_graphs: "Batch", return_attention: bool = False, occluded_indices=None):
        """Compute logits and contrastive embeddings for a batch.

        Args:
            seq_tokens: Tokeniser output, unpacked into the ESM encoder.
            peptide_graphs: Batched residue graphs; ``x`` is looked up in the
                21-entry embedding table, ``edge_index`` feeds the GAT layers.
            return_attention: When True, the sequence-to-structure attention
                weights of every co-attention block are returned (on CPU).
            occluded_indices: Optional node indices whose structure embeddings
                are zeroed after the GAT layers.

        Returns:
            Dict with ``"logits"``, ``"contrastive_seq_rep"``,
            ``"contrastive_struct_rep"`` and, if requested,
            ``"seq_to_struct_attention"`` (list with one tensor per block).
        """
        if self.fine_tune_esm:
            seq_embeds = self.sequence_encoder(**seq_tokens).last_hidden_state
        else:
            with torch.no_grad():
                seq_embeds = self.sequence_encoder(**seq_tokens).last_hidden_state

        struct_embeds = self.aa_embedding(peptide_graphs.x)
        for gnn_layer in self.gnn_layers:
            struct_embeds = F.gelu(gnn_layer(struct_embeds, peptide_graphs.edge_index))
        if occluded_indices is not None:
            struct_embeds[occluded_indices] = 0.0

        # One bincount instead of a Python loop of per-graph tensor sums.
        node_counts = torch.bincount(peptide_graphs.batch, minlength=peptide_graphs.num_graphs)
        struct_embeds_padded = torch.nn.utils.rnn.pad_sequence(
            torch.split(struct_embeds, node_counts.tolist()), batch_first=True
        )

        # NOTE(review): the co-attention blocks are called without any padding
        # mask, so zero-padded structure rows and padded sequence tokens still
        # take part in attention.
        s2s_attns = []
        for block in self.co_attention_blocks:
            seq_embeds, struct_embeds_padded, s2s_w, _ = block(seq_embeds, struct_embeds_padded)
            if return_attention:
                s2s_attns.append(s2s_w.detach().cpu())

        # First token position of the sequence stream (presumably ESM's CLS
        # token — confirm against the tokenizer).
        seq_rep = seq_embeds[:, 0, :]

        # Drop padded rows per graph. Slicing the first N rows of the flattened
        # padded tensor is only correct when all graphs have the same size;
        # otherwise padding leaks into the pooling and real nodes are lost.
        unpadded_struct_embeds = self._unpad_nodes(struct_embeds_padded, node_counts)
        struct_rep = global_mean_pool(unpadded_struct_embeds, peptide_graphs.batch)

        fused_rep = torch.cat([seq_rep, struct_rep], dim=-1)
        final_logits = self.decoder(fused_rep)

        contrastive_seq_rep = F.normalize(self.seq_proj(seq_rep), p=2, dim=1)
        contrastive_struct_rep = F.normalize(self.struct_proj(struct_rep), p=2, dim=1)

        output = {
            "logits": final_logits,
            "contrastive_seq_rep": contrastive_seq_rep,
            "contrastive_struct_rep": contrastive_struct_rep,
        }
        if return_attention:
            output["seq_to_struct_attention"] = s2s_attns
        return output

# ==========================================================================================
# 2. DATA HANDLING (USING PDBs DIRECTLY)
# ==========================================================================================
def parse_pdb_file(pdb_path: Path) -> np.ndarray:
    """Read the C-alpha coordinates from a PDB file.

    Only ``ATOM`` records whose atom-name field is ``CA`` are used. Reading
    stops at the first ``ENDMDL`` record so multi-model files contribute a
    single model, and a CA carrying an alternate-location flag is skipped when
    it repeats the residue (chain + number + insertion code) of the previously
    kept CA, so each residue yields one coordinate.

    Args:
        pdb_path: Path to the PDB file (a ``str`` is accepted as well).

    Returns:
        Float array of shape ``(n_residues, 3)``. An empty ``(0, 3)`` array is
        returned when no CA atom is found or when the file cannot be read or
        parsed; in the latter case an error message is printed.
    """
    pdb_path = Path(pdb_path)
    coords = []
    previous_residue = None
    try:
        with open(pdb_path, 'r') as f:
            for line in f:
                if line.startswith("ENDMDL"):
                    break  # keep the first model only
                if not line.startswith("ATOM") or line[12:16].strip() != "CA":
                    continue
                residue_key = line[21:27]
                if line[16:17] != " " and residue_key == previous_residue:
                    continue  # alternate conformation of a residue already kept
                coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                previous_residue = residue_key
    except (OSError, ValueError) as e:
        print(f"\n   ❌ Error parsing PDB file '{pdb_path.name}': {e}")
        return np.empty((0, 3))
    return np.array(coords, dtype=float).reshape(-1, 3)

def create_graph_from_coords(coords: np.ndarray) -> Data:
    """Build a residue contact graph from per-residue coordinates.

    Two distinct residues are connected (in both directions) when their
    Euclidean distance is below 8.0. With fewer than two residues the graph
    has no edges.

    NOTE(review): node features are the residue *position* modulo 21, not the
    amino-acid identity — confirm this is the intended input for the
    downstream embedding table.

    Args:
        coords: One coordinate row per residue.

    Returns:
        A ``Data`` object with integer node features ``x`` and a
        ``(2, num_edges)`` long ``edge_index``.
    """
    residue_count = len(coords)
    src, dst = [], []
    if residue_count > 1:
        from scipy.spatial.distance import pdist, squareform

        pairwise = squareform(pdist(coords))
        contacts = pairwise < 8.0
        np.fill_diagonal(contacts, 0)  # a residue is not its own neighbour
        src, dst = np.where(contacts)

    node_ids = torch.arange(residue_count) % 21
    edge_index = torch.tensor(np.array([src, dst]), dtype=torch.long)
    return Data(x=node_ids, edge_index=edge_index)

class PeptideDataset(Dataset):
    """Dataset that pairs each peptide sequence with a graph built from its PDB.

    Structures are looked up as ``seq<seq_index>.pdb`` inside ``pdb_dir`` and
    parsed on every access. Items whose structure is missing, or whose number
    of parsed residues differs from the sequence length, are returned as
    ``None`` (after printing a warning) so a collate function can drop them.
    """

    def __init__(self, samples: List[Dict], label_map: Dict[str, int], pdb_dir: Path):
        """Store the samples, the label-to-index map and the PDB directory.

        Each sample dict is read for the keys ``'id'``, ``'sequence'``,
        ``'seq_index'`` and ``'labels'``.
        """
        self.samples = samples
        self.label_map = label_map
        self.pdb_dir = pdb_dir

    def __len__(self):
        """Return the number of samples (including ones that may load as None)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``None`` if the structure is unavailable or inconsistent with the
            sequence; otherwise a dict with ``'id'``, ``'sequence'``,
            ``'graph'``, ``'labels'`` (multi-hot float vector) and
            ``'seq_index'``.
        """
        entry = self.samples[idx]
        name = entry['id']
        residues = entry['sequence']
        position = entry['seq_index']

        structure_path = self.pdb_dir / f"seq{position}.pdb"
        if not structure_path.exists():
            print(f"\n   ⚠️ Warning: PDB file not found for '{name}' (index {position}). Expected: {structure_path}. Skipping.")
            return None

        ca_coords = parse_pdb_file(structure_path)
        if ca_coords.shape[0] != len(residues):
            print(f"\n   ⚠️ Warning: Skipping '{name}' due to mismatched length (Seq: {len(residues)}, PDB: {ca_coords.shape[0]}).")
            return None

        # Multi-hot target; labels missing from the map are ignored.
        targets = torch.zeros(len(self.label_map))
        for label_name in entry['labels']:
            slot = self.label_map.get(label_name)
            if slot is not None:
                targets[slot] = 1

        return {
            'id': name,
            'sequence': residues,
            'graph': create_graph_from_coords(ca_coords),
            'labels': targets,
            'seq_index': position,
        }

def safe_collate_fn(batch):
    """Collate a batch after discarding items that loaded as ``None``.

    Returns ``None`` when nothing is left, otherwise whatever ``collate_fn``
    returns for the remaining items.
    """
    usable = list(filter(lambda item: item is not None, batch))
    return collate_fn(usable) if usable else None

class PeptideDataModule(pl.LightningDataModule):
    """Lightning data module for one sub-dataset directory.

    The directory is expected to contain ``seqs.fasta`` and ``classes.txt``
    with one comma-separated label line per FASTA record. Samples are split
    80/20 into train+val/test and the former 87.5/12.5 into train/val, both
    with a fixed random seed.
    """

    def __init__(self, data_dir: Path, pdb_dir: Path, tokenizer, batch_size: int = 4, num_workers: int = 4):
        """Store the configuration.

        Args:
            data_dir: Directory holding ``seqs.fasta`` and ``classes.txt``.
            pdb_dir: Directory holding the ``seq<N>.pdb`` structure files.
            tokenizer: Stored on the instance; not used by this class itself.
            batch_size: Batch size for all data loaders.
            num_workers: Worker processes for all data loaders.
        """
        super().__init__()
        self.data_dir = data_dir
        self.pdb_dir = pdb_dir
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        """Load the data, build the label graph and create the three splits.

        When no usable sample is found, all three datasets are set to ``None``
        and the method returns early.
        """
        all_peptides, self.label_map = self._load_data(self.data_dir)
        if not all_peptides:
            # Define every split so later attribute access cannot fail.
            self.train_dataset = self.val_dataset = self.test_dataset = None
            return

        self.num_labels = len(self.label_map)
        self.label_graph_edge_index = self._create_label_graph(all_peptides, self.label_map)

        train_val, self.test_peptides = train_test_split(all_peptides, test_size=0.2, random_state=42)
        train, val = train_test_split(train_val, test_size=0.125, random_state=42)

        self.train_dataset = PeptideDataset(train, self.label_map, self.pdb_dir)
        self.val_dataset = PeptideDataset(val, self.label_map, self.pdb_dir)
        self.test_dataset = PeptideDataset(self.test_peptides, self.label_map, self.pdb_dir)

    def _make_loader(self, dataset, shuffle=False):
        """Create a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=safe_collate_fn,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self.test_dataset)

    def _load_data(self, root_dir):
        """Read sequences and labels from ``root_dir``.

        Returns:
            ``(samples, label_map)`` where each sample is a dict with ``id``,
            ``sequence``, ``labels`` and a 1-based ``seq_index`` (the record's
            position in the FASTA file), and ``label_map`` maps every label
            seen to its index in sorted order. Problems are reported with a
            printed message and the offending file or entry is skipped.
        """
        data_samples, unique_labels = [], set()
        for f_path in glob.glob(os.path.join(root_dir, "seqs.fasta")):
            class_path = os.path.join(root_dir, "classes.txt")
            if not os.path.exists(class_path):
                print(f"   ⚠️ Warning: '{class_path}' not found. Skipping {f_path}.")
                continue
            try:
                with open(class_path) as handle:
                    class_lines = handle.read().strip().split('\n')
                fasta_sequences = list(SeqIO.parse(f_path, "fasta"))
                if len(fasta_sequences) != len(class_lines):
                    print(
                        f"   ⚠️ Warning: {len(fasta_sequences)} sequences but "
                        f"{len(class_lines)} label lines in {root_dir}. Skipping."
                    )
                    continue
                for i, (record, class_line) in enumerate(zip(fasta_sequences, class_lines)):
                    labels = [label.strip() for label in class_line.strip().split(',') if label.strip()]
                    peptide_id = record.id.strip()
                    if not labels or not peptide_id:
                        print(f"   ⚠️ Warning: Skipping malformed entry with empty ID or labels in {f_path}.")
                        continue
                    unique_labels.update(labels)
                    data_samples.append({
                        "id": peptide_id,
                        "sequence": str(record.seq),
                        "labels": labels,
                        "seq_index": i + 1,  # 1-based index, matches seq<N>.pdb
                    })
            except Exception as e:
                # Best-effort loading: report and carry on with what was read.
                print(f"   - Error processing files in {root_dir}: {e}")
        return data_samples, {label: i for i, label in enumerate(sorted(unique_labels))}

    def _create_label_graph(self, data_samples, label_map):
        """Build the undirected label co-occurrence graph.

        Two labels are connected when they appear together on at least one
        sample. Repeated labels on a sample are collapsed first so no
        self-loops are produced, and edges are emitted in sorted order so the
        result is reproducible.

        Returns:
            Long tensor of shape ``(2, 2 * num_pairs)`` holding each pair in
            both directions, or an empty ``(2, 0)`` tensor when no labels
            co-occur.
        """
        edges = set()
        for sample in data_samples:
            label_ids = sorted({label_map[label] for label in sample["labels"]})
            edges.update(combinations(label_ids, 2))
        if not edges:
            return torch.empty((2, 0), dtype=torch.long)
        source_nodes, target_nodes = (list(side) for side in zip(*sorted(edges)))
        return torch.tensor(
            [source_nodes + target_nodes, target_nodes + source_nodes], dtype=torch.long
        )
    
# ==========================================================================================
# 3. PYTORCH LIGHTNING SYSTEM
# ==========================================================================================
class IMPeTLightning(pl.LightningModule):
    """Lightning wrapper that trains and evaluates an ``IMPeT`` model.

    The training loss is binary cross-entropy on the logits plus 0.1 times the
    sequence/structure contrastive loss; validation logs the BCE term only.
    The module owns its tokenizer, so it does not depend on a module-global
    ``tokenizer`` and can be imported or restored from a checkpoint anywhere.
    """

    def __init__(self, num_labels, label_graph_edge_index, label_map, learning_rate=1e-5, esm_model_name="facebook/esm2_t33_650M_UR50D"):
        """Build the wrapped model, its tokenizer and the loss.

        Args:
            num_labels: Number of output labels.
            label_graph_edge_index: Label-graph edge index passed to ``IMPeT``.
            label_map: Mapping from label name to index (stored on the module).
            learning_rate: Learning rate for AdamW.
            esm_model_name: Pretrained ESM checkpoint used for both the
                encoder and the tokenizer.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.label_map = label_map
        self.model = IMPeT(num_labels, label_graph_edge_index, esm_model_name=esm_model_name, fine_tune_esm=True)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
        self.classification_criterion = nn.BCEWithLogitsLoss()

    def _tokenize(self, sequences):
        """Tokenise raw sequences and move the tensors to the module's device."""
        return self.tokenizer(
            sequences, return_tensors="pt", padding=True, truncation=True, max_length=1022
        ).to(self.device)

    def forward(self, seq_tokens, graphs, return_attention=False, occluded_indices=None):
        """Delegate to the wrapped ``IMPeT`` model."""
        return self.model(seq_tokens, graphs, return_attention, occluded_indices)

    def training_step(self, batch, batch_idx):
        """Return the combined classification + contrastive loss (None for empty batches)."""
        if batch is None:
            return None
        tokens = self._tokenize(batch['sequences'])
        graphs = batch['graphs'].to(self.device)
        labels = batch['labels'].to(self.device)
        output = self(tokens, graphs)
        class_loss = self.classification_criterion(output['logits'], labels)
        cont_loss = contrastive_loss(output['contrastive_seq_rep'], output['contrastive_struct_rep'])
        loss = class_loss + 0.1 * cont_loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=len(batch['sequences']))
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the classification loss as ``val_loss`` (empty batches are skipped)."""
        if batch is None:
            return None
        tokens = self._tokenize(batch['sequences'])
        graphs = batch['graphs'].to(self.device)
        labels = batch['labels'].to(self.device)
        output = self(tokens, graphs)
        loss = self.classification_criterion(output['logits'], labels)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=len(batch['sequences']))

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return sigmoid probabilities together with the batch's labels."""
        if batch is None:
            return None
        tokens = self._tokenize(batch['sequences'])
        graphs = batch['graphs'].to(self.device)
        output = self(tokens, graphs)
        return {'preds': torch.sigmoid(output['logits']), 'labels': batch['labels']}

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at ``self.learning_rate``."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

# ==========================================================================================
# 4. INTERPRETABILITY SUITE & HELPERS
# ==========================================================================================
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Graphs are merged with ``Batch.from_data_list`` and label vectors are
    stacked; ids, sequences and sequence indices are kept as plain lists.
    """
    return {
        'ids': [item['id'] for item in batch],
        'sequences': [item['sequence'] for item in batch],
        'graphs': Batch.from_data_list([item['graph'] for item in batch]),
        'labels': torch.stack([item['labels'] for item in batch]),
        'seq_indices': [item['seq_index'] for item in batch],
    }
def contrastive_loss(seq_reps, struct_reps, temperature=0.07):
    """Symmetric cross-entropy loss over the sequence/structure similarity matrix.

    Row ``i`` of ``seq_reps`` is treated as the positive partner of row ``i``
    of ``struct_reps``; all other rows in the batch act as negatives. The loss
    is the mean of the sequence-to-structure and structure-to-sequence
    cross-entropies.

    Args:
        seq_reps: ``(batch, dim)`` sequence embeddings.
        struct_reps: ``(batch, dim)`` structure embeddings.
        temperature: Divisor applied to the similarity matrix.
    """
    similarity = torch.matmul(seq_reps, struct_reps.T) / temperature
    targets = torch.arange(len(seq_reps), device=seq_reps.device)
    seq_to_struct = F.cross_entropy(similarity, targets)
    struct_to_seq = F.cross_entropy(similarity.T, targets)
    return (seq_to_struct + struct_to_seq) / 2
def visualize_attention_heatmap(model_output: Dict, sequence: str, peptide_id: str, output_dir: Path):
    """Save the last block's sequence-to-structure attention as a PDF heatmap.

    The figure is square and its side grows with the sequence length (never
    below 8 inches); font sizes are fixed. Both axes are labelled with the
    residue letters of ``sequence``.

    Args:
        model_output (Dict): Must contain ``"seq_to_struct_attention"``, a list
            of attention tensors; the last entry is plotted after squeezing a
            leading singleton dimension.
        sequence (str): The amino acid sequence string.
        peptide_id (str): Identifier used in the title and the file name.
        output_dir (Path): Directory for the PDF (created if missing).
    """
    output_dir.mkdir(exist_ok=True, parents=True)

    seq_len = len(sequence)
    residues = list(sequence)
    tick_positions = np.arange(seq_len) + 0.5

    last_block_attention = model_output["seq_to_struct_attention"][-1]
    # Query row 0 is dropped (presumably the tokenizer's leading special
    # token — confirm); keys are cut to the sequence length.
    attention_map = last_block_attention.squeeze(0).numpy()[1:seq_len + 1, :seq_len]

    title_fontsize, label_fontsize, tick_fontsize = 20, 16, 12
    side = max(8, seq_len * 0.45)
    fig, ax = plt.subplots(figsize=(side, side))

    sns.heatmap(
        attention_map,
        ax=ax,
        cmap="viridis",
        cbar=True,
        square=True,
        cbar_kws={"shrink": 0.75},
    )
    ax.set_xlabel("Structural Residue Index", fontweight='bold', fontsize=label_fontsize)
    ax.set_ylabel("Sequence Residue Index", fontweight='bold', fontsize=label_fontsize)
    fig.suptitle(f"Attention Heatmap for {peptide_id}", fontsize=title_fontsize, fontweight='bold')

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(residues, fontsize=tick_fontsize)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(residues, rotation=0, fontsize=tick_fontsize)

    # Leave headroom for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    save_path = output_dir / f"{peptide_id}_attention_heatmap.pdf"
    plt.savefig(save_path, format='pdf', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"    -> Saved attention heatmap to {save_path}")
def run_advanced_interpretability_suite(model, tokenizer, test_peptides, label_map, pdb_dir, output_dir):
    """Pick example peptides and run the per-peptide analyses on them.

    The first key of ``label_map`` is used as the target class. Among the
    first 20 test peptides, the first correctly predicted positive
    (probability > 0.5) and the first correctly predicted negative
    (probability < 0.5) are selected. The positive gets a mutational heatmap
    and a structural-occlusion plot; the negative gets a counterfactual
    search. Outputs go to ``output_dir / "advanced_analysis" / <peptide id>``.
    """
    print("\n   🔬 Running Advanced Interpretability Suite...")
    target_class_name = list(label_map.keys())[0]
    target_class_idx = label_map[target_class_name]
    positive_candidate = negative_candidate = None

    model.eval()
    with torch.no_grad():
        for peptide in tqdm(test_peptides[:20], desc="     [Finding Candidates]    "):
            is_positive = target_class_name in peptide['labels']

            # Skip peptides without a usable structure.
            pdb_file = pdb_dir / f"seq{peptide['seq_index']}.pdb"
            if not pdb_file.exists():
                continue
            coords = parse_pdb_file(pdb_file)
            if coords.shape[0] != len(peptide['sequence']):
                continue

            graph_batch = Batch.from_data_list([create_graph_from_coords(coords)])
            tokens = tokenizer(
                [peptide['sequence']], return_tensors='pt', padding=True, truncation=True, max_length=1022
            ).to(model.device)
            graph_batch = graph_batch.to(model.device)
            probabilities = torch.sigmoid(model(tokens, graph_batch)['logits'])
            pred_prob = probabilities[0, target_class_idx].item()

            if is_positive:
                if pred_prob > 0.5 and not positive_candidate:
                    positive_candidate = peptide
            elif pred_prob < 0.5 and not negative_candidate:
                negative_candidate = peptide

            if positive_candidate and negative_candidate:
                break

    if positive_candidate:
        print(f"   Analyzing positive peptide '{positive_candidate['id']}' for class '{target_class_name}'")
        analysis_dir = output_dir / "advanced_analysis" / positive_candidate['id']
        analysis_dir.mkdir(parents=True, exist_ok=True)
        generate_mutational_heatmap(model, tokenizer, positive_candidate, target_class_idx, pdb_dir, analysis_dir)
        run_structural_occlusion(model, tokenizer, positive_candidate, target_class_idx, pdb_dir, analysis_dir)

    if negative_candidate:
        print(f"   Analyzing negative peptide '{negative_candidate['id']}' for counterfactuals")
        analysis_dir = output_dir / "advanced_analysis" / negative_candidate['id']
        analysis_dir.mkdir(parents=True, exist_ok=True)
        run_greedy_counterfactual_search(model, tokenizer, negative_candidate, target_class_idx, pdb_dir, analysis_dir)
def generate_mutational_heatmap(model, tokenizer, peptide, target_class_idx, pdb_dir, output_dir):
    """Plot how every single-residue substitution shifts the target logit.

    For each position and each of the 20 standard amino acids, the sequence is
    mutated (the structure graph is left unchanged) and the change in the
    target-class logit relative to the unmutated sequence is recorded. Cells
    for the wild-type residue stay at 0. The heatmap is written to
    ``output_dir / "mutational_heatmap.png"``. Nothing is produced when the
    parsed structure length does not match the sequence length.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    seq = peptide['sequence']
    logit_shift = np.zeros((len(seq), len(amino_acids)))

    def _tokenize(text):
        return tokenizer(
            [text], return_tensors='pt', padding=True, truncation=True, max_length=1022
        ).to(model.device)

    with torch.no_grad():
        pdb_file = pdb_dir / f"seq{peptide['seq_index']}.pdb"
        coords = parse_pdb_file(pdb_file)
        if coords.shape[0] != len(seq):
            return
        graph = Batch.from_data_list([create_graph_from_coords(coords)]).to(model.device)
        baseline_logit = model(_tokenize(seq), graph)['logits'][0, target_class_idx].item()

        for pos in tqdm(range(len(seq)), desc="     [In Silico Mutagenesis]"):
            for col, new_aa in enumerate(amino_acids):
                if new_aa == seq[pos]:
                    continue
                mutated_seq = seq[:pos] + new_aa + seq[pos + 1:]
                mutant_logit = model(_tokenize(mutated_seq), graph)['logits'][0, target_class_idx].item()
                logit_shift[pos, col] = mutant_logit - baseline_logit

    plt.figure(figsize=(12, max(8, len(seq) // 3)))
    sns.heatmap(
        logit_shift,
        cmap="RdBu_r",
        center=0,
        xticklabels=list(amino_acids),
        yticklabels=list(seq),
    )
    plt.xlabel("Mutate to Amino Acid")
    plt.ylabel("Position in Sequence")
    plt.title(f"Mutational Sensitivity for {peptide['id']}")

    save_path = output_dir / "mutational_heatmap.png"
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"     -> Saved mutational heatmap to {save_path}")
def run_structural_occlusion(model, tokenizer, peptide, target_class_idx, pdb_dir, output_dir, motif_size=10):
    """Plot the prediction drop caused by occluding consecutive residue windows.

    The sequence is cut into windows of ``motif_size`` residues. For each
    window the model is run with those node indices passed as
    ``occluded_indices`` and the drop in target-class probability relative to
    the un-occluded prediction is recorded. A bar chart is written to
    ``output_dir / "structural_occlusion.png"``. Nothing is produced when the
    parsed structure length does not match the sequence length.

    Args:
        model: Model exposing ``.device`` and accepting ``occluded_indices``.
        tokenizer: Tokenizer applied to the sequence.
        peptide: Sample dict; ``'sequence'``, ``'seq_index'`` and ``'id'`` are read.
        target_class_idx: Column of the logits to analyse.
        pdb_dir: Directory holding ``seq<N>.pdb`` files.
        output_dir: Directory the plot is written to.
        motif_size: Window length in residues (default 10).

    Raises:
        ValueError: If ``motif_size`` is smaller than 1.
    """
    if motif_size < 1:
        raise ValueError("motif_size must be at least 1")

    seq = peptide['sequence']
    motifs = {}
    for start in range(0, len(seq), motif_size):
        end = min(start + motif_size, len(seq))
        # Label with the real last residue so the final window does not
        # advertise positions beyond the end of the sequence.
        motifs[f"Res {start + 1}-{end}"] = list(range(start, end))

    score_drops = {}
    with torch.no_grad():
        tokens = tokenizer([seq], return_tensors='pt', padding=True, truncation=True, max_length=1022).to(model.device)
        pdb_file = pdb_dir / f"seq{peptide['seq_index']}.pdb"
        coords = parse_pdb_file(pdb_file)
        if coords.shape[0] != len(seq):
            return
        graph = Batch.from_data_list([create_graph_from_coords(coords)]).to(model.device)
        baseline_score = torch.sigmoid(model(tokens, graph)['logits'])[0, target_class_idx].item()
        for motif_name, indices in tqdm(motifs.items(), desc="     [Structural Occlusion]  "):
            occluded_logits = model(tokens, graph, occluded_indices=indices)['logits']
            occluded_score = torch.sigmoid(occluded_logits)[0, target_class_idx].item()
            score_drops[motif_name] = baseline_score - occluded_score

    plt.figure(figsize=(10, 6))
    plt.bar(score_drops.keys(), score_drops.values())
    plt.ylabel("Prediction Score Drop (Importance)")
    plt.xticks(rotation=45, ha="right")
    plt.title(f"Structural Occlusion Importance for {peptide['id']}")
    plt.tight_layout()

    save_path = output_dir / "structural_occlusion.png"
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"     -> Saved structural occlusion plot to {save_path}")
def run_greedy_counterfactual_search(model, tokenizer, peptide, target_class_idx, pdb_dir, output_dir):
    """Search all single-point mutations for one that flips the prediction.

    Every position is substituted with each of the 20 standard amino acids
    (the structure graph is left unchanged). The mutation with the highest
    target-class probability above 0.5 is reported, if any, in
    ``output_dir / "counterfactual_report.txt"``. Nothing is written when the
    parsed structure length does not match the sequence length.
    """
    seq = peptide['sequence']
    best_counterfactual = None
    max_prob = -1

    def _probability(text, graph_batch):
        encoded = tokenizer(
            [text], return_tensors='pt', padding=True, truncation=True, max_length=1022
        ).to(model.device)
        return torch.sigmoid(model(encoded, graph_batch)['logits'])[0, target_class_idx].item()

    with torch.no_grad():
        pdb_file = pdb_dir / f"seq{peptide['seq_index']}.pdb"
        coords = parse_pdb_file(pdb_file)
        if coords.shape[0] != len(seq):
            return
        graph = Batch.from_data_list([create_graph_from_coords(coords)]).to(model.device)
        original_prob = _probability(seq, graph)

        for pos in tqdm(range(len(seq)), desc="     [Counterfactual Search] "):
            for new_aa in "ACDEFGHIKLMNPQRSTVWY":
                if new_aa == seq[pos]:
                    continue
                mutated_seq = seq[:pos] + new_aa + seq[pos + 1:]
                prob = _probability(mutated_seq, graph)
                if prob > 0.5 and prob > max_prob:
                    max_prob = prob
                    best_counterfactual = (f"{seq[pos]}{pos + 1}{new_aa}", mutated_seq)

    report_path = output_dir / "counterfactual_report.txt"
    with open(report_path, "w") as f:
        f.write(f"Counterfactual Analysis for Peptide: {peptide['id']}\n")
        f.write("=" * 40)
        f.write(f"\nOriginal Sequence: {seq}\nOriginal Prediction Probability: {original_prob:.4f} (Inactive)\n\n")
        if best_counterfactual:
            mutation, new_sequence = best_counterfactual
            f.write(f"Best single mutation to flip prediction to ACTIVE:\nMutation: {mutation}\nNew Sequence: {new_sequence}\nNew Prediction Probability: {max_prob:.4f} (Active)\n")
        else:
            f.write("No single-point mutation was found to flip the prediction to active.\n")
    print(f"     -> Saved counterfactual report to {report_path}")

# ==========================================================================================
# 5. MAIN SCRIPT
# ==========================================================================================
if __name__ == '__main__':
    # Pipeline: for every sub-dataset directory, fine-tune a model, evaluate
    # the best checkpoint on the test split, save attention heatmaps for the
    # first test batch and finally write one combined classification report.
    torch.set_float32_matmul_precision('medium')
    pl.seed_everything(42, workers=True)

    PARENT_DATA_DIR = Path("./Subset")
    PDB_SOURCE_DIR = Path("./Structures")
    OUTPUT_DIR = Path("./UpdatedFiguresInitial")
    NUM_EPOCHS = 30

    if not PARENT_DATA_DIR.exists():
        print("Creating dummy data directory...")
        dummy_dir = PARENT_DATA_DIR / "Antimicrobial_Peptides"
        dummy_dir.mkdir(parents=True)
        with open(dummy_dir / "seqs.fasta", "w") as f:
            f.write(">AMP_1\nGLWSKIKEVGKEAAKAAAKAAGKAALGAVSEAV\n>AMP_2\nRKSNLRRIKKGIHIIKKYG\n>AMP_3_inactive\nADEFGHILMNPQSTVWY\n")
        with open(dummy_dir / "classes.txt", "w") as f:
            f.write("antimicrobial, hemolytic\nantimicrobial\ninactive\n")

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    sub_datasets = sorted([d for d in PARENT_DATA_DIR.iterdir() if d.is_dir()])
    all_reports_df = []

    for sub_dataset_path in sub_datasets:
        sub_dataset_name = sub_dataset_path.name
        print(f"\n--- 🧬 Processing Sub-dataset: {sub_dataset_name} ---")

        dataset_pdb_dir = PDB_SOURCE_DIR / sub_dataset_name

        datamodule = PeptideDataModule(data_dir=sub_dataset_path, pdb_dir=dataset_pdb_dir, tokenizer=tokenizer, batch_size=4, num_workers=2)
        datamodule.setup()
        if not datamodule.train_dataset:
            print("   No valid data found. Skipping.\n")
            continue
        print(f"   Data split: {len(datamodule.train_dataset)} train, {len(datamodule.val_dataset)} val, {len(datamodule.test_dataset)} test.")

        model = IMPeTLightning(num_labels=datamodule.num_labels, label_graph_edge_index=datamodule.label_graph_edge_index, label_map=datamodule.label_map)
        checkpoint_callback = ModelCheckpoint(dirpath=OUTPUT_DIR / "models" / sub_dataset_name, filename='best-model', save_top_k=1, monitor='val_loss', mode='min')
        trainer = pl.Trainer(accelerator='auto', devices='auto', max_epochs=NUM_EPOCHS, callbacks=[checkpoint_callback], default_root_dir=OUTPUT_DIR / "logs", logger=True, log_every_n_steps=10)

        print(f"   Fine-tuning model for {NUM_EPOCHS} epochs...")
        trainer.fit(model, datamodule)

        print("\n   Loading best model for final evaluation on test set...")
        best_model_path = checkpoint_callback.best_model_path
        if not best_model_path:
            print("   No best model was saved. Evaluating last model. Results may be suboptimal.")
            best_model_for_eval = model
        else:
            # Load onto CPU; the evaluation trainer moves the model to whichever
            # accelerator is available, so CPU-only hosts work as well.
            best_model_for_eval = IMPeTLightning.load_from_checkpoint(best_model_path, map_location=torch.device('cpu'))

        eval_trainer = pl.Trainer(accelerator='auto', devices=1, logger=False)
        test_outputs = eval_trainer.predict(best_model_for_eval, datamodule.test_dataloader())
        if test_outputs and any(test_outputs):
            # Batches that collated to None yield None predictions; drop them.
            all_preds = torch.cat([x['preds'] for x in test_outputs if x is not None]).cpu().numpy()
            all_true = torch.cat([x['labels'] for x in test_outputs if x is not None]).cpu().numpy()
            binary_preds = (all_preds > 0.5).astype(int)

            report_dict = classification_report(all_true, binary_preds, target_names=list(datamodule.label_map.keys()), output_dict=True, zero_division=0)
            report_df = pd.DataFrame(report_dict).transpose().reset_index().rename(columns={'index': 'class'})
            report_df['dataset'] = sub_dataset_name
            all_reports_df.append(report_df)
            print(f"   📈 Evaluation complete. Accuracy (Subset): {accuracy_score(all_true, binary_preds):.4f}")

            best_model_cpu = best_model_for_eval.to('cpu')
            best_model_cpu.eval()  # inference mode for the attention visualisation

            print("   Generating standard visualizations for first batch of test set...")
            first_batch = next(iter(datamodule.test_dataloader()), None)
            if first_batch:
                with torch.no_grad():
                    output = best_model_cpu(tokenizer(first_batch['sequences'], return_tensors="pt", padding=True, truncation=True, max_length=1022), first_batch['graphs'], return_attention=True)
                for i in range(len(first_batch['ids'])):
                    peptide_id, sequence = first_batch['ids'][i], first_batch['sequences'][i]
                    # Slice out sample i but keep the batch dimension.
                    single_output = {'seq_to_struct_attention': [att[i:i+1] for att in output['seq_to_struct_attention']]}
                    visualize_attention_heatmap(single_output, sequence, peptide_id, OUTPUT_DIR / "visualizations" / sub_dataset_name / peptide_id)

    if all_reports_df:
        final_report = pd.concat(all_reports_df)
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        report_path = OUTPUT_DIR / "classification_reports.csv"
        final_report.to_csv(report_path, index=False)
        print(f"\n✅ All datasets processed. Combined classification report saved to: {report_path}")

