{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab61885-3ab0-43d7-a757-21c5e8c96966",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary size: 24\n",
      "Special tokens: PAD=0, UNK=1, SEP=2, EOS=3\n",
      "Starting Ablation Study...\n",
      "============================================================\n",
      "Loading data from data.csv\n",
      "Original data: 34618 samples\n",
      "Mapped columns: ['CDR3', 'MHC', 'Epitope'] -> ['cdr3', 'mhc', 'epitope']\n",
      "Removed 29594 duplicate samples\n",
      "Final data: 5024 unique positive samples\n",
      "Length ranges - TCR: 7-22, Peptide: 8-27, MHC: 34-62\n",
      "Creating balanced dataset with 1:10 expansion strategy\n",
      "Starting with 5024 positive samples\n",
      "Available for negative generation: 187 peptides, 35 MHCs, 4687 TCRs\n",
      "Generating 50240 negative samples...\n",
      "Generated 50240 negative samples\n",
      "Expanding positive samples to 50240\n",
      "Final balanced dataset: 100480 samples\n",
      "  - Positive: 50240 (expansion: 10.0x)\n",
      "  - Negative: 50240 (expansion: 10.0x)\n",
      "  - Balance ratio: 1.000\n",
      "Created 100480 samples, 50240 with generation tasks\n",
      "\n",
      "==================================================\n",
      "Running ablation: baseline\n",
      "Config: {'use_attention': True, 'use_residual': True, 'use_layer_norm': True, 'use_positional_encoding': True, 'use_discriminator': True, 'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'], 'shared_encoder': True}\n",
      "==================================================\n",
      "  Fold 1/3\n",
      "  Fold 2/3\n",
      "  Fold 3/3\n",
      "  Completed baseline\n",
      "\n",
      "==================================================\n",
      "Running ablation: no_attention\n",
      "Config: {'use_attention': False, 'use_residual': True, 'use_layer_norm': True, 'use_positional_encoding': True, 'use_discriminator': True, 'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'], 'shared_encoder': True}\n",
      "==================================================\n",
      "  Fold 1/3\n",
      "  Fold 2/3\n",
      "  Fold 3/3\n",
      "  Completed no_attention\n",
      "\n",
      "==================================================\n",
      "Running ablation: no_positional_encoding\n",
      "Config: {'use_attention': True, 'use_residual': True, 'use_layer_norm': True, 'use_positional_encoding': False, 'use_discriminator': True, 'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'], 'shared_encoder': True}\n",
      "==================================================\n",
      "  Fold 1/3\n",
      "  Fold 2/3\n",
      "  Fold 3/3\n",
      "  Completed no_positional_encoding\n",
      "\n",
      "==================================================\n",
      "Running ablation: no_layer_norm\n",
      "Config: {'use_attention': True, 'use_residual': True, 'use_layer_norm': False, 'use_positional_encoding': True, 'use_discriminator': True, 'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'], 'shared_encoder': True}\n",
      "==================================================\n",
      "  Fold 1/3\n",
      "  Fold 2/3\n",
      "  Fold 3/3\n",
      "  Completed no_layer_norm\n",
      "\n",
      "==================================================\n",
      "Running ablation: no_discriminator\n",
      "Config: {'use_attention': True, 'use_residual': True, 'use_layer_norm': True, 'use_positional_encoding': True, 'use_discriminator': False, 'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'], 'shared_encoder': True}\n",
      "==================================================\n",
      "  Fold 1/3\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score\n",
    "import os\n",
    "import json\n",
    "import copy\n",
    "import itertools\n",
    "from typing import Dict, List, Any, Optional\n",
    "from collections import defaultdict\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from data import ImmuneDataProcessor, MultiTaskDataset, collate_fn, create_cv_splits\n",
    "from model import MultiTaskImmuneModel\n",
    "\n",
    "class AblationModel(nn.Module):\n",
    "    \"\"\"可配置的消融模型\"\"\"\n",
    "    def __init__(self, \n",
    "                 vocab_size: int,\n",
    "                 d_model: int = 512,\n",
    "                 max_len: int = 150,\n",
    "                 n_encoder_layers: int = 6,\n",
    "                 n_decoder_layers: int = 4,\n",
    "                 n_heads: int = 8,\n",
    "                 dropout: float = 0.1,\n",
    "                 vocab_dict: Optional[Dict] = None,\n",
    "                 # 消融配置\n",
    "                 use_attention: bool = True,\n",
    "                 use_residual: bool = True,\n",
    "                 use_layer_norm: bool = True,\n",
    "                 use_positional_encoding: bool = True,\n",
    "                 use_discriminator: bool = True,\n",
    "                 enabled_tasks: List[str] = None,\n",
    "                 shared_encoder: bool = True):\n",
    "        \n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.d_model = d_model\n",
    "        self.max_len = max_len\n",
    "        self.vocab_dict = vocab_dict or {}\n",
    "        self.enabled_tasks = enabled_tasks or ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen']\n",
    "        self.use_attention = use_attention\n",
    "        self.use_discriminator = use_discriminator\n",
    "        self.shared_encoder = shared_encoder\n",
    "        \n",
    "        # Embedding\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)\n",
    "        \n",
    "        # Positional Encoding\n",
    "        if use_positional_encoding:\n",
    "            self.pos_encoding = self._create_positional_encoding(max_len, d_model)\n",
    "        else:\n",
    "            self.pos_encoding = None\n",
    "        \n",
    "        # Encoder\n",
    "        if shared_encoder:\n",
    "            if use_attention:\n",
    "                encoder_layer = nn.TransformerEncoderLayer(\n",
    "                    d_model=d_model,\n",
    "                    nhead=n_heads,\n",
    "                    dropout=dropout,\n",
    "                    batch_first=True,\n",
    "                    norm_first=use_layer_norm\n",
    "                )\n",
    "                self.encoder = nn.TransformerEncoder(encoder_layer, n_encoder_layers)\n",
    "            else:\n",
    "                # 使用简单的LSTM替代Transformer\n",
    "                self.encoder = nn.LSTM(d_model, d_model//2, n_encoder_layers, \n",
    "                                     batch_first=True, bidirectional=True, dropout=dropout)\n",
    "        else:\n",
    "            # 为每个任务创建独立的编码器\n",
    "            self.encoders = nn.ModuleDict()\n",
    "            for task in self.enabled_tasks:\n",
    "                if 'gen' not in task:  # 分类任务\n",
    "                    if use_attention:\n",
    "                        encoder_layer = nn.TransformerEncoderLayer(\n",
    "                            d_model=d_model, nhead=n_heads, dropout=dropout, batch_first=True\n",
    "                        )\n",
    "                        self.encoders[task] = nn.TransformerEncoder(encoder_layer, n_encoder_layers)\n",
    "                    else:\n",
    "                        self.encoders[task] = nn.LSTM(d_model, d_model//2, n_encoder_layers,\n",
    "                                                    batch_first=True, bidirectional=True, dropout=dropout)\n",
    "        \n",
    "        # 分类器\n",
    "        classifier_dim = d_model\n",
    "        for task in ['pt', 'pmt', 'pm']:\n",
    "            if task in self.enabled_tasks:\n",
    "                setattr(self, f'{task}_classifier', nn.Sequential(\n",
    "                    nn.Linear(classifier_dim, classifier_dim // 2),\n",
    "                    nn.ReLU(),\n",
    "                    nn.Dropout(dropout),\n",
    "                    nn.Linear(classifier_dim // 2, 2)\n",
    "                ))\n",
    "        \n",
    "        # 判别器\n",
    "        if use_discriminator and any(task in self.enabled_tasks for task in ['pt', 'pmt', 'pm']):\n",
    "            for task in ['pt', 'pmt', 'pm']:\n",
    "                if task in self.enabled_tasks:\n",
    "                    setattr(self, f'{task}_discriminator', nn.Sequential(\n",
    "                        nn.Linear(classifier_dim, classifier_dim // 4),\n",
    "                        nn.ReLU(),\n",
    "                        nn.Dropout(dropout),\n",
    "                        nn.Linear(classifier_dim // 4, 1),\n",
    "                        nn.Sigmoid()\n",
    "                    ))\n",
    "        \n",
    "        # 生成器\n",
    "        if use_attention:\n",
    "            decoder_layer = nn.TransformerDecoderLayer(\n",
    "                d_model=d_model,\n",
    "                nhead=n_heads,\n",
    "                dropout=dropout,\n",
    "                batch_first=True,\n",
    "                norm_first=use_layer_norm\n",
    "            )\n",
    "            \n",
    "            if 'tcr_gen' in self.enabled_tasks:\n",
    "                self.tcr_generator = nn.Sequential(\n",
    "                    nn.TransformerDecoder(decoder_layer, n_decoder_layers),\n",
    "                    nn.Linear(d_model, vocab_size)\n",
    "                )\n",
    "            \n",
    "            if 'pep_gen' in self.enabled_tasks:\n",
    "                self.pep_generator = nn.Sequential(\n",
    "                    nn.TransformerDecoder(copy.deepcopy(decoder_layer), n_decoder_layers),\n",
    "                    nn.Linear(d_model, vocab_size)\n",
    "                )\n",
    "        else:\n",
    "            # 使用简单的LSTM解码器\n",
    "            if 'tcr_gen' in self.enabled_tasks:\n",
    "                self.tcr_generator = nn.Sequential(\n",
    "                    nn.LSTM(d_model, d_model, n_decoder_layers, batch_first=True, dropout=dropout),\n",
    "                    nn.Linear(d_model, vocab_size)\n",
    "                )\n",
    "            \n",
    "            if 'pep_gen' in self.enabled_tasks:\n",
    "                self.pep_generator = nn.Sequential(\n",
    "                    nn.LSTM(d_model, d_model, n_decoder_layers, batch_first=True, dropout=dropout),\n",
    "                    nn.Linear(d_model, vocab_size)\n",
    "                )\n",
    "    \n",
    "    def _create_positional_encoding(self, max_len: int, d_model: int) -> torch.Tensor:\n",
    "        \"\"\"创建位置编码\"\"\"\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * \n",
    "                           (-np.log(10000.0) / d_model))\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        return pe.unsqueeze(0)\n",
    "    \n",
    "    def encode_sequence(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, task: str = None) -> torch.Tensor:\n",
    "        \"\"\"编码序列\"\"\"\n",
    "        x = self.embedding(input_ids)\n",
    "        \n",
    "        # 添加位置编码\n",
    "        if self.pos_encoding is not None:\n",
    "            seq_len = x.size(1)\n",
    "            pos_enc = self.pos_encoding[:, :seq_len, :].to(x.device)\n",
    "            x = x + pos_enc\n",
    "        \n",
    "        # 选择编码器\n",
    "        if self.shared_encoder:\n",
    "            encoder = self.encoder\n",
    "        else:\n",
    "            encoder = self.encoders.get(task, self.encoder)\n",
    "        \n",
    "        # 编码\n",
    "        if self.use_attention:\n",
    "            # Transformer编码器\n",
    "            src_key_padding_mask = (attention_mask == 0)\n",
    "            encoded = encoder(x, src_key_padding_mask=src_key_padding_mask)\n",
    "        else:\n",
    "            # LSTM编码器\n",
    "            if isinstance(encoder, nn.LSTM):\n",
    "                packed = nn.utils.rnn.pack_padded_sequence(\n",
    "                    x, attention_mask.sum(dim=1).cpu(), batch_first=True, enforce_sorted=False\n",
    "                )\n",
    "                encoded, _ = encoder(packed)\n",
    "                encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True)\n",
    "            else:\n",
    "                encoded = encoder(x)\n",
    "        \n",
    "        return encoded\n",
    "    \n",
    "    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"前向传播\"\"\"\n",
    "        outputs = {}\n",
    "        \n",
    "        # 分类任务\n",
    "        for task in ['pt', 'pmt', 'pm']:\n",
    "            if task in self.enabled_tasks and f'{task}_input' in batch:\n",
    "                # 编码\n",
    "                encoded = self.encode_sequence(\n",
    "                    batch[f'{task}_input'], \n",
    "                    batch[f'{task}_mask'],\n",
    "                    task\n",
    "                )\n",
    "                \n",
    "                # 池化\n",
    "                pooled = encoded.mean(dim=1)  # 简单平均池化\n",
    "                \n",
    "                # 分类\n",
    "                classifier = getattr(self, f'{task}_classifier')\n",
    "                logits = classifier(pooled)\n",
    "                confidence = torch.softmax(logits, dim=-1)[:, 1:2]\n",
    "                \n",
    "                outputs[f'{task}_logits'] = logits\n",
    "                outputs[f'{task}_confidence'] = confidence\n",
    "                \n",
    "                # 判别器\n",
    "                if self.use_discriminator and hasattr(self, f'{task}_discriminator'):\n",
    "                    discriminator = getattr(self, f'{task}_discriminator')\n",
    "                    disc_output = discriminator(pooled)\n",
    "                    outputs[f'{task}_discriminator'] = disc_output\n",
    "        \n",
    "        # 生成任务\n",
    "        if len(batch.get('positive_indices', [])) > 0:\n",
    "            for task in ['tcr_gen', 'pep_gen']:\n",
    "                if task in self.enabled_tasks and f'{task}_input' in batch:\n",
    "                    # 编码\n",
    "                    encoded = self.encode_sequence(\n",
    "                        batch[f'{task}_input'],\n",
    "                        batch[f'{task}_mask']\n",
    "                    )\n",
    "                    \n",
    "                    # 生成\n",
    "                    generator = getattr(self, f'{task[:-4]}_generator')  # 移除'_gen'后缀\n",
    "                    \n",
    "                    if self.use_attention and isinstance(generator[0], nn.TransformerDecoder):\n",
    "                        # Transformer解码器\n",
    "                        tgt = encoded  # 简化处理\n",
    "                        decoded = generator[0](tgt, encoded)\n",
    "                        gen_logits = generator[1](decoded)\n",
    "                    else:\n",
    "                        # LSTM解码器或简单线性层\n",
    "                        if isinstance(generator[0], nn.LSTM):\n",
    "                            decoded, _ = generator[0](encoded)\n",
    "                            gen_logits = generator[1](decoded)\n",
    "                        else:\n",
    "                            gen_logits = generator(encoded)\n",
    "                    \n",
    "                    outputs[f'{task}_logits'] = gen_logits\n",
    "                    outputs[f'{task}_targets'] = batch[f'{task}_target']\n",
    "        \n",
    "        return outputs\n",
    "\n",
    "class AblationTrainer:\n",
    "    \"\"\"消融实验训练器\"\"\"\n",
    "    def __init__(self, config: Dict[str, Any]):\n",
    "        self.config = config\n",
    "        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "        \n",
    "        # 数据处理器\n",
    "        self.data_processor = ImmuneDataProcessor(\n",
    "            data_path=config['data_path'],\n",
    "            max_len=config.get('max_len', 120),\n",
    "            random_seed=config['seed']\n",
    "        )\n",
    "        \n",
    "        # 创建反向词汇表\n",
    "        self.vocab_dict = {v: k for k, v in self.data_processor.token_to_id.items()}\n",
    "        \n",
    "        # 损失函数\n",
    "        self.classification_loss = nn.CrossEntropyLoss()\n",
    "        self.generation_loss = nn.CrossEntropyLoss(ignore_index=0)\n",
    "        self.discriminator_loss = nn.BCELoss()\n",
    "        \n",
    "        # 结果存储\n",
    "        self.ablation_results = []\n",
    "    \n",
    "    def create_model(self, ablation_config: Dict[str, Any]) -> AblationModel:\n",
    "        \"\"\"根据消融配置创建模型\"\"\"\n",
    "        model_config = {\n",
    "            'vocab_size': self.data_processor.vocab_size,\n",
    "            'd_model': self.config.get('d_model', 512),\n",
    "            'max_len': self.config.get('max_len', 120),\n",
    "            'n_encoder_layers': self.config.get('n_encoder_layers', 6),\n",
    "            'n_decoder_layers': self.config.get('n_decoder_layers', 4),\n",
    "            'n_heads': self.config.get('n_heads', 8),\n",
    "            'dropout': self.config.get('dropout', 0.1),\n",
    "            'vocab_dict': self.vocab_dict\n",
    "        }\n",
    "        model_config.update(ablation_config)\n",
    "        \n",
    "        return AblationModel(**model_config)\n",
    "    \n",
    "    def compute_metrics(self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, float]:\n",
    "        \"\"\"计算所有指标\"\"\"\n",
    "        metrics = {}\n",
    "        \n",
    "        # 分类指标\n",
    "        for task in ['pt', 'pmt', 'pm']:\n",
    "            if f'{task}_logits' in outputs:\n",
    "                with torch.no_grad():\n",
    "                    logits = outputs[f'{task}_logits']\n",
    "                    labels = batch['labels']\n",
    "                    \n",
    "                    probs = torch.softmax(logits, dim=-1)\n",
    "                    preds = torch.argmax(logits, dim=-1)\n",
    "                    \n",
    "                    labels_cpu = labels.cpu().numpy()\n",
    "                    preds_cpu = preds.cpu().numpy()\n",
    "                    probs_cpu = probs[:, 1].cpu().numpy()\n",
    "                    \n",
    "                    acc = accuracy_score(labels_cpu, preds_cpu)\n",
    "                    precision, recall, f1, _ = precision_recall_fscore_support(\n",
    "                        labels_cpu, preds_cpu, average='binary', zero_division=0\n",
    "                    )\n",
    "                    \n",
    "                    try:\n",
    "                        auc = roc_auc_score(labels_cpu, probs_cpu)\n",
    "                    except:\n",
    "                        auc = 0.0\n",
    "                    \n",
    "                    metrics.update({\n",
    "                        f'{task}_accuracy': acc,\n",
    "                        f'{task}_precision': precision,\n",
    "                        f'{task}_recall': recall,\n",
    "                        f'{task}_f1': f1,\n",
    "                        f'{task}_auc': auc\n",
    "                    })\n",
    "        \n",
    "        # 生成指标\n",
    "        for task in ['tcr_gen', 'pep_gen']:\n",
    "            if f'{task}_logits' in outputs:\n",
    "                with torch.no_grad():\n",
    "                    logits = outputs[f'{task}_logits']\n",
    "                    targets = outputs[f'{task}_targets']\n",
    "                    \n",
    "                    # 简化的生成指标\n",
    "                    shift_logits = logits[..., :-1, :].contiguous()\n",
    "                    shift_labels = targets[..., 1:].contiguous()\n",
    "                    \n",
    "                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))\n",
    "                    shift_labels = shift_labels.view(-1)\n",
    "                    \n",
    "                    loss = self.generation_loss(shift_logits, shift_labels)\n",
    "                    \n",
    "                    preds = torch.argmax(shift_logits, dim=-1)\n",
    "                    mask = shift_labels != 0\n",
    "                    \n",
    "                    if mask.sum() > 0:\n",
    "                        correct = (preds == shift_labels) & mask\n",
    "                        token_accuracy = correct.sum().float() / mask.sum().float()\n",
    "                    else:\n",
    "                        token_accuracy = torch.tensor(0.0)\n",
    "                    \n",
    "                    metrics.update({\n",
    "                        f'{task}_loss': loss.item(),\n",
    "                        f'{task}_token_accuracy': token_accuracy.item(),\n",
    "                        f'{task}_perplexity': torch.exp(loss).item()\n",
    "                    })\n",
    "        \n",
    "        return metrics\n",
    "    \n",
    "    def train_and_evaluate(self, model, train_loader, val_loader, ablation_name: str) -> Dict[str, float]:\n",
    "        \"\"\"训练并评估单个消融配置\"\"\"\n",
    "        model = model.to(self.device)\n",
    "        optimizer = optim.Adam(model.parameters(), lr=self.config.get('learning_rate', 1e-4))\n",
    "        \n",
    "        best_val_acc = 0\n",
    "        best_metrics = {}\n",
    "        \n",
    "        num_epochs = self.config.get('num_epochs', 10)\n",
    "        \n",
    "        for epoch in range(num_epochs):\n",
    "            # 训练\n",
    "            model.train()\n",
    "            train_metrics = []\n",
    "            \n",
    "            for batch in train_loader:\n",
    "                # 移动数据到设备\n",
    "                for key, value in batch.items():\n",
    "                    if isinstance(value, torch.Tensor):\n",
    "                        batch[key] = value.to(self.device)\n",
    "                \n",
    "                optimizer.zero_grad()\n",
    "                \n",
    "                try:\n",
    "                    outputs = model(batch)\n",
    "                    losses = []\n",
    "                    \n",
    "                    # 分类损失\n",
    "                    for task in ['pt', 'pmt', 'pm']:\n",
    "                        if f'{task}_logits' in outputs:\n",
    "                            cls_loss = self.classification_loss(outputs[f'{task}_logits'], batch['labels'])\n",
    "                            losses.append(cls_loss)\n",
    "                            \n",
    "                            # 判别器损失\n",
    "                            if f'{task}_discriminator' in outputs:\n",
    "                                disc_loss = self.discriminator_loss(\n",
    "                                    outputs[f'{task}_discriminator'].squeeze(),\n",
    "                                    batch['labels'].float()\n",
    "                                )\n",
    "                                losses.append(disc_loss * 0.1)\n",
    "                    \n",
    "                    # 生成损失\n",
    "                    for task in ['tcr_gen', 'pep_gen']:\n",
    "                        if f'{task}_logits' in outputs:\n",
    "                            gen_loss = self.generation_loss(\n",
    "                                outputs[f'{task}_logits'].view(-1, outputs[f'{task}_logits'].size(-1)),\n",
    "                                outputs[f'{task}_targets'].view(-1)\n",
    "                            )\n",
    "                            losses.append(gen_loss * 0.5)\n",
    "                    \n",
    "                    if losses:\n",
    "                        total_loss = sum(losses)\n",
    "                        total_loss.backward()\n",
    "                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "                        optimizer.step()\n",
    "                        \n",
    "                        # 计算指标\n",
    "                        batch_metrics = self.compute_metrics(outputs, batch)\n",
    "                        train_metrics.append(batch_metrics)\n",
    "                \n",
    "                except Exception as e:\n",
    "                    continue\n",
    "            \n",
    "            # 验证\n",
    "            model.eval()\n",
    "            val_metrics = []\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                for batch in val_loader:\n",
    "                    # 移动数据到设备\n",
    "                    for key, value in batch.items():\n",
    "                        if isinstance(value, torch.Tensor):\n",
    "                            batch[key] = value.to(self.device)\n",
    "                    \n",
    "                    try:\n",
    "                        outputs = model(batch)\n",
    "                        batch_metrics = self.compute_metrics(outputs, batch)\n",
    "                        val_metrics.append(batch_metrics)\n",
    "                    except Exception:\n",
    "                        continue\n",
    "            \n",
    "            # 计算epoch平均指标\n",
    "            if val_metrics:\n",
    "                avg_val_metrics = {}\n",
    "                for key in val_metrics[0].keys():\n",
    "                    values = [m[key] for m in val_metrics if key in m]\n",
    "                    avg_val_metrics[key] = np.mean(values) if values else 0.0\n",
    "                \n",
    "                # 选择最佳模型\n",
    "                val_acc = np.mean([avg_val_metrics.get(f'{task}_accuracy', 0) for task in ['pt', 'pmt', 'pm']])\n",
    "                \n",
    "                if val_acc > best_val_acc:\n",
    "                    best_val_acc = val_acc\n",
    "                    best_metrics = avg_val_metrics.copy()\n",
    "                    best_metrics['ablation'] = ablation_name\n",
    "                    best_metrics['epoch'] = epoch\n",
    "        \n",
    "        return best_metrics\n",
    "    \n",
    "    def run_ablation_study(self):\n",
    "        \"\"\"运行完整的消融研究\"\"\"\n",
    "        print(\"Starting Ablation Study...\")\n",
    "        print(\"=\" * 60)\n",
    "        \n",
    "        # 加载数据\n",
    "        df = self.data_processor.load_and_process_data()\n",
    "        df_balanced = self.data_processor.create_balanced_dataset(df, negative_ratio=1.0)\n",
    "        dataset = self.data_processor.create_five_task_dataset(df_balanced)\n",
    "        \n",
    "        # 创建交叉验证分割\n",
    "        cv_splits = create_cv_splits(df_balanced, n_splits=self.config.get('n_folds', 3))\n",
    "        \n",
    "        # 定义消融实验\n",
    "        ablation_configs = self.define_ablation_experiments()\n",
    "        \n",
    "        all_results = []\n",
    "        \n",
    "        for ablation_name, ablation_config in ablation_configs.items():\n",
    "            print(f\"\\n{'='*50}\")\n",
    "            print(f\"Running ablation: {ablation_name}\")\n",
    "            print(f\"Config: {ablation_config}\")\n",
    "            print(f\"{'='*50}\")\n",
    "            \n",
    "            fold_results = []\n",
    "            \n",
    "            for fold, (train_indices, val_indices) in enumerate(cv_splits):\n",
    "                print(f\"  Fold {fold + 1}/{len(cv_splits)}\")\n",
    "                \n",
    "                train_data = [dataset[i] for i in train_indices]\n",
    "                val_data = [dataset[i] for i in val_indices]\n",
    "                \n",
    "                train_dataset = MultiTaskDataset(train_data)\n",
    "                val_dataset = MultiTaskDataset(val_data)\n",
    "                \n",
    "                train_loader = DataLoader(\n",
    "                    train_dataset,\n",
    "                    batch_size=self.config.get('batch_size', 32),\n",
    "                    shuffle=True,\n",
    "                    collate_fn=collate_fn\n",
    "                )\n",
    "                \n",
    "                val_loader = DataLoader(\n",
    "                    val_dataset,\n",
    "                    batch_size=self.config.get('batch_size', 32),\n",
    "                    shuffle=False,\n",
    "                    collate_fn=collate_fn\n",
    "                )\n",
    "                \n",
    "                try:\n",
    "                    # 创建模型\n",
    "                    model = self.create_model(ablation_config)\n",
    "                    \n",
    "                    # 训练和评估\n",
    "                    fold_metrics = self.train_and_evaluate(model, train_loader, val_loader, ablation_name)\n",
    "                    fold_metrics['fold'] = fold\n",
    "                    fold_results.append(fold_metrics)\n",
    "                    \n",
    "                except Exception as e:\n",
    "                    print(f\"    Error in fold {fold}: {e}\")\n",
    "                    continue\n",
    "            \n",
    "            # 计算该消融的平均结果\n",
    "            if fold_results:\n",
    "                avg_result = self.compute_average_results(fold_results, ablation_name)\n",
    "                all_results.append(avg_result)\n",
    "                print(f\"  Completed {ablation_name}\")\n",
    "            else:\n",
    "                print(f\"  Failed {ablation_name}\")\n",
    "        \n",
    "        # 保存结果\n",
    "        self.save_ablation_results(all_results)\n",
    "        return all_results\n",
    "    \n",
    "    def define_ablation_experiments(self) -> Dict[str, Dict[str, Any]]:\n",
    "        \"\"\"定义消融实验配置\"\"\"\n",
    "        ablations = {\n",
    "            # 基线完整模型\n",
    "            'baseline': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            # 架构消融\n",
    "            'no_attention': {\n",
    "                'use_attention': False,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'no_positional_encoding': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': False,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'no_layer_norm': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': False,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'no_discriminator': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': False,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'separate_encoders': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': False\n",
    "            },\n",
    "            \n",
    "            # 任务组合消融\n",
    "            'classification_only': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'generation_only': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': False,\n",
    "                'enabled_tasks': ['tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'pt_only': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'pmt_only': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pmt'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            'tcr_gen_only': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': False,\n",
    "                'enabled_tasks': ['tcr_gen'],\n",
    "                'shared_encoder': True\n",
    "            },\n",
    "            \n",
    "            # 模型大小消融\n",
    "            'small_model': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True,\n",
    "                'd_model': 256,\n",
    "                'n_encoder_layers': 3,\n",
    "                'n_decoder_layers': 2,\n",
    "                'n_heads': 4\n",
    "            },\n",
    "            \n",
    "            'large_model': {\n",
    "                'use_attention': True,\n",
    "                'use_residual': True,\n",
    "                'use_layer_norm': True,\n",
    "                'use_positional_encoding': True,\n",
    "                'use_discriminator': True,\n",
    "                'enabled_tasks': ['pt', 'pmt', 'pm', 'tcr_gen', 'pep_gen'],\n",
    "                'shared_encoder': True,\n",
    "                'd_model': 768,\n",
    "                'n_encoder_layers': 8,\n",
    "                'n_decoder_layers': 6,\n",
    "                'n_heads': 12\n",
    "            }\n",
    "        }\n",
    "        \n",
    "        return ablations\n",
    "    \n",
    "    def compute_average_results(self, fold_results: List[Dict], ablation_name: str) -> Dict[str, float]:\n",
    "        \"\"\"计算多折结果的平均值和置信区间\"\"\"\n",
    "        if not fold_results:\n",
    "            return {'ablation': ablation_name, 'status': 'failed'}\n",
    "        \n",
    "        avg_result = {'ablation': ablation_name}\n",
    "        \n",
    "        # 收集所有指标\n",
    "        all_metrics = defaultdict(list)\n",
    "        for result in fold_results:\n",
    "            for key, value in result.items():\n",
    "                if key not in ['fold', 'ablation', 'epoch']:\n",
    "                    all_metrics[key].append(value)\n",
    "        \n",
    "        # 计算统计量\n",
    "        for metric, values in all_metrics.items():\n",
    "            if values:\n",
    "                mean_val = np.mean(values)\n",
    "                std_val = np.std(values)\n",
    "                ci_95 = 1.96 * std_val / np.sqrt(len(values))\n",
    "                \n",
    "                avg_result[f'{metric}_mean'] = mean_val\n",
    "                avg_result[f'{metric}_std'] = std_val\n",
    "                avg_result[f'{metric}_ci95'] = ci_95\n",
    "                avg_result[f'{metric}_min'] = np.min(values)\n",
    "                avg_result[f'{metric}_max'] = np.max(values)\n",
    "        \n",
    "        return avg_result\n",
    "    \n",
    "    def save_ablation_results(self, results: List[Dict]):\n",
    "        \"\"\"保存消融研究结果\"\"\"\n",
    "        if not results:\n",
    "            print(\"No results to save!\")\n",
    "            return\n",
    "        \n",
    "        # 保存详细结果\n",
    "        results_df = pd.DataFrame(results)\n",
    "        output_path = 'ablation_study_results.csv'\n",
    "        results_df.to_csv(output_path, index=False)\n",
    "        \n",
    "        print(f\"\\nAblation study results saved to {output_path}\")\n",
    "        \n",
    "        # 生成汇总报告\n",
    "        self.generate_ablation_report(results_df)\n",
    "    \n",
    "    def generate_ablation_report(self, results_df: pd.DataFrame):\n",
    "        \"\"\"生成消融研究报告\"\"\"\n",
    "        print(\"\\n\" + \"=\"*80)\n",
    "        print(\"ABLATION STUDY REPORT\")\n",
    "        print(\"=\"*80)\n",
    "        \n",
    "        # 找到基线结果\n",
    "        baseline_idx = results_df[results_df['ablation'] == 'baseline'].index\n",
    "        if len(baseline_idx) > 0:\n",
    "            baseline = results_df.loc[baseline_idx[0]]\n",
    "            print(f\"\\nBaseline Performance:\")\n",
    "            for task in ['pt', 'pmt', 'pm']:\n",
    "                acc_col = f'{task}_accuracy_mean'\n",
    "                f1_col = f'{task}_f1_mean'\n",
    "                if acc_col in baseline:\n",
    "                    print(f\"  {task.upper()}: Acc={baseline[acc_col]:.4f}±{baseline.get(f'{task}_accuracy_std', 0):.4f}, \"\n",
    "                          f\"F1={baseline.get(f1_col, 0):.4f}±{baseline.get(f'{task}_f1_std', 0):.4f}\")\n",
    "            \n",
    "            for task in ['tcr_gen', 'pep_gen']:\n",
    "                acc_col = f'{task}_token_accuracy_mean'\n",
    "                if acc_col in baseline:\n",
    "                    print(f\"  {task.upper()}: Token Acc={baseline[acc_col]:.4f}±{baseline.get(f'{task}_token_accuracy_std', 0):.4f}\")\n",
    "        \n",
    "        print(f\"\\nAblation Results (showing main classification accuracy):\")\n",
    "        print(\"-\" * 80)\n",
    "        \n",
    "        # 按性能排序\n",
    "        if 'pt_accuracy_mean' in results_df.columns:\n",
    "            results_df_sorted = results_df.sort_values('pt_accuracy_mean', ascending=False)\n",
    "            \n",
    "            for _, row in results_df_sorted.iterrows():\n",
    "                ablation = row['ablation']\n",
    "                pt_acc = row.get('pt_accuracy_mean', 0)\n",
    "                pt_std = row.get('pt_accuracy_std', 0)\n",
    "                \n",
    "                if ablation == 'baseline':\n",
    "                    indicator = \" (BASELINE)\"\n",
    "                elif pt_acc > baseline.get('pt_accuracy_mean', 0):\n",
    "                    indicator = \" ↗ (BETTER)\"\n",
    "                else:\n",
    "                    indicator = \" ↘ (WORSE)\"\n",
    "                \n",
    "                print(f\"{ablation:20} | PT Acc: {pt_acc:.4f}±{pt_std:.4f}{indicator}\")\n",
    "        \n",
    "        # 关键发现\n",
    "        print(f\"\\nKey Findings:\")\n",
    "        print(\"-\" * 40)\n",
    "        \n",
    "        if len(baseline_idx) > 0:\n",
    "            baseline_acc = baseline.get('pt_accuracy_mean', 0)\n",
    "            \n",
    "            # 找出最佳和最差的配置\n",
    "            best_row = results_df.loc[results_df['pt_accuracy_mean'].idxmax()] if 'pt_accuracy_mean' in results_df.columns else None\n",
    "            worst_row = results_df.loc[results_df['pt_accuracy_mean'].idxmin()] if 'pt_accuracy_mean' in results_df.columns else None\n",
    "            \n",
    "            if best_row is not None:\n",
    "                improvement = best_row['pt_accuracy_mean'] - baseline_acc\n",
    "                print(f\"• Best configuration: {best_row['ablation']} (+{improvement:.4f})\")\n",
    "            \n",
    "            if worst_row is not None:\n",
    "                degradation = baseline_acc - worst_row['pt_accuracy_mean']\n",
    "                print(f\"• Worst configuration: {worst_row['ablation']} (-{degradation:.4f})\")\n",
    "            \n",
    "            # 分析各个组件的影响\n",
    "            component_analysis = {\n",
    "                'no_attention': 'Attention mechanism',\n",
    "                'no_positional_encoding': 'Positional encoding',\n",
    "                'no_layer_norm': 'Layer normalization',\n",
    "                'no_discriminator': 'Discriminator',\n",
    "                'separate_encoders': 'Encoder sharing'\n",
    "            }\n",
    "            \n",
    "            print(f\"• Component importance:\")\n",
    "            for ablation, component in component_analysis.items():\n",
    "                ablation_row = results_df[results_df['ablation'] == ablation]\n",
    "                if len(ablation_row) > 0:\n",
    "                    impact = baseline_acc - ablation_row.iloc[0].get('pt_accuracy_mean', 0)\n",
    "                    print(f\"  - {component}: {impact:+.4f}\")\n",
    "\n",
    "def main():\n",
    "    \"\"\"主函数\"\"\"\n",
    "    config = {\n",
    "        'data_path': 'data.csv',\n",
    "        'seed': 42,\n",
    "        'n_folds': 3,\n",
    "        'batch_size': 32,\n",
    "        'num_epochs': 8,  # 减少epoch以加快实验速度\n",
    "        'learning_rate': 1e-4,\n",
    "        'max_len': 120,\n",
    "        'd_model': 512,\n",
    "        'n_encoder_layers': 6,\n",
    "        'n_decoder_layers': 4,\n",
    "        'n_heads': 8,\n",
    "        'dropout': 0.1,\n",
    "    }\n",
    "    \n",
    "    # 设置随机种子\n",
    "    torch.manual_seed(config['seed'])\n",
    "    np.random.seed(config['seed'])\n",
    "    \n",
    "    # 运行消融研究\n",
    "    trainer = AblationTrainer(config)\n",
    "    results = trainer.run_ablation_study()\n",
    "    \n",
    "    print(f\"\\nAblation study completed!\")\n",
    "    print(f\"Check ablation_study_results.csv for detailed results.\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef33594-b1ab-4378-868b-a7deb4884012",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "conda-base-py",
   "name": "workbench-notebooks.m128",
   "type": "gcloud",
   "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel) (Local)",
   "language": "python",
   "name": "conda-base-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
