{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MCal Experiment: How Calibration Improves Explanation Quality\n",
    "\n",
    "This notebook evaluates whether MCal calibration improves the quality of feature-attribution explanations (LIME and KernelSHAP) by:\n",
    "1. Reducing missingness bias (KL divergence)\n",
    "2. Improving sufficiency scores\n",
    "3. Improving comprehensiveness scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import timm  # For loading finetuned ViT models\n",
    "\n",
    "# Add parent directory to path\n",
    "sys.path.append('..')\n",
    "\n",
    "# Import dataset classes and components\n",
    "from experiments.all_data_loaders import MRIPatchedProbDataset, MRICleanDataset\n",
    "from experiments.vision.mri_mcal_explanation_quality import load_mri_model\n",
    "from src.calibrators.mcal_ce import SimpleMCalCE\n",
    "from experiments.explanations import ImageLIME\n",
    "from experiments.metrics import (\n",
    "    ImageSufficiency, ImageComprehensiveness, ImageKLDivergence,\n",
    "    compute_metric_improvement\n",
    ")\n",
    "\n",
    "# Select the compute device: CUDA when available, otherwise CPU\n",
    "_use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda') if _use_cuda else torch.device('cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Seed the global torch and NumPy RNGs so code drawing from them is reproducible\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration\n",
    "config = dict(\n",
    "    n_train=2000,                 # ablated training images used to fit MCal\n",
    "    n_test=200,                   # clean test images that are explained and evaluated\n",
    "    ablation_rate=0.5,            # per-patch masking probability of the ablated training data\n",
    "    patch_size=32,                # patch size in pixels, matching the ps32 checkpoint (224 / 32 = 7x7 grid)\n",
    "    explanation_samples=100,      # perturbation budget for LIME / SHAP explanations\n",
    "    k_values=list(range(1, 26)),  # top-K patch counts evaluated: 1..25 of the 49 patches\n",
    "    batch_size=32,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load Model and Prepare Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the ViT fine-tuned WITHOUT patch ablation (p_ablate=0.00): it has never seen\n",
    "# ablated inputs, so it is expected to exhibit the missingness bias MCal should reduce.\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "MODEL_FILENAME = 'vit_mri_ps32_ablate0.00_best.pth'\n",
    "notebook_dir = Path.cwd()\n",
    "\n",
    "# Candidate checkpoint locations, tried in order; the first existing one is used\n",
    "candidate_model_paths = [\n",
    "    notebook_dir.parent / 'saved_models' / 'mri' / MODEL_FILENAME,\n",
    "    Path('..') / 'saved_models' / 'mri' / MODEL_FILENAME,\n",
    "]\n",
    "model_path = next((p for p in candidate_model_paths if p.exists()), None)\n",
    "if model_path is None:\n",
    "    searched = ', '.join(str(p) for p in candidate_model_paths)\n",
    "    raise FileNotFoundError(f\"Model checkpoint '{MODEL_FILENAME}' not found. Searched: {searched}\")\n",
    "\n",
    "print(f\"Loading model from: {model_path}\")\n",
    "\n",
    "# Checkpoint dict; this cell reads 'epoch', 'best_val_loss', 'config' and 'model_state_dict'\n",
    "checkpoint = torch.load(model_path, map_location=device)\n",
    "trained_p_ablate = checkpoint['config']['p_ablate']\n",
    "print(f\"Model trained for {checkpoint['epoch']+1} epochs with best val loss: {checkpoint['best_val_loss']:.4f}\")\n",
    "print(f\"Model was trained with p_ablate={trained_p_ablate:.2f}\")\n",
    "if trained_p_ablate > 0:\n",
    "    print(\"WARNING: checkpoint was trained WITH ablation; it may not exhibit the missingness bias this experiment relies on.\")\n",
    "else:\n",
    "    print(\"(clean data only - expected to exhibit missingness bias)\")\n",
    "\n",
    "# Rebuild the architecture and load the fine-tuned weights.\n",
    "# NOTE(review): the backbone uses 16px ViT tokens; 'ps32' in the checkpoint name presumably\n",
    "# refers to the 32px ablation patch size used during fine-tuning - confirm.\n",
    "base_model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4)\n",
    "base_model.load_state_dict(checkpoint['model_state_dict'])\n",
    "base_model = base_model.to(device)\n",
    "base_model.eval()\n",
    "uncal_model = base_model\n",
    "\n",
    "# Ablated training data (to fit MCal) and clean test data (to evaluate explanations)\n",
    "train_ablated_dataset = MRIPatchedProbDataset(\n",
    "    split='train',\n",
    "    n_samples=config['n_train'],\n",
    "    p_ablate=config['ablation_rate'],\n",
    "    patch_size=config['patch_size'],\n",
    "    seed=42\n",
    ")\n",
    "test_dataset = MRICleanDataset(split='test', n_samples=config['n_test'])\n",
    "\n",
    "train_ablated_loader = DataLoader(train_ablated_dataset, batch_size=config['batch_size'], shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Train MCal Calibrator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the (uncalibrated) base model over the ablated training set and keep its logits\n",
    "# together with the labels; MCal is fitted on these (logits, label) pairs.\n",
    "label_chunks, logit_chunks = [], []\n",
    "with torch.no_grad():\n",
    "    for images, labels in tqdm(train_ablated_loader, desc=\"Computing ablated predictions\"):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        label_chunks.append(labels)\n",
    "        logit_chunks.append(base_model(images))\n",
    "\n",
    "train_labels = torch.cat(label_chunks)\n",
    "ablated_logits = torch.cat(logit_chunks)\n",
    "\n",
    "# Fit MCal; max_steps and lr are left at SimpleMCalCE's defaults\n",
    "calibrator = SimpleMCalCE(num_classes=4).to(device)\n",
    "stats = calibrator.fit(ablated_logits=ablated_logits, target_labels=train_labels, verbose=True)\n",
    "\n",
    "print(f\"Training complete! Final Loss: {stats['loss'][-1]:.4f}, Final Accuracy: {stats['acc'][-1]:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagnostic: how strongly does the UNCALIBRATED base model react to patch ablation?\n",
    "# Compares predictions on a few clean test images with predictions on the same images\n",
    "# after random patch ablation. Intermediate logits use diag_* names so that the\n",
    "# training logits `ablated_logits` from the MCal cell are not overwritten.\n",
    "from experiments.all_data_loaders import _mask_random_patches_prob\n",
    "\n",
    "print(\"Analyzing base model's response to missingness...\")\n",
    "\n",
    "N_DIAG = 10  # number of test images used for this quick check\n",
    "\n",
    "with torch.no_grad():\n",
    "    # First batch of clean test data, truncated to N_DIAG samples\n",
    "    test_batch_images, test_batch_labels = next(iter(test_loader))\n",
    "    test_batch_images = test_batch_images[:N_DIAG].to(device)\n",
    "    test_batch_labels = test_batch_labels[:N_DIAG].to(device)\n",
    "\n",
    "    # Predictions on clean data\n",
    "    clean_logits = base_model(test_batch_images)\n",
    "    clean_probs = torch.softmax(clean_logits, dim=1)\n",
    "\n",
    "    # Ablate each image with the training ablation rate and patch size.\n",
    "    # NOTE(review): fill_val=0 is assumed to match the training dataset's fill value - confirm.\n",
    "    ablated_batch = torch.stack([\n",
    "        _mask_random_patches_prob(\n",
    "            img.cpu(),\n",
    "            mask_prob=config['ablation_rate'],\n",
    "            patch_size=config['patch_size'],\n",
    "            fill_val=0,\n",
    "            seed=None  # Random ablation\n",
    "        )\n",
    "        for img in test_batch_images\n",
    "    ]).to(device)\n",
    "\n",
    "    # Predictions on ablated data\n",
    "    diag_ablated_logits = base_model(ablated_batch)\n",
    "    ablated_probs = torch.softmax(diag_ablated_logits, dim=1)\n",
    "\n",
    "    # F.kl_div(log q, p) = sum p * (log p - log q), i.e. KL(clean || ablated)\n",
    "    kl_div = torch.nn.functional.kl_div(\n",
    "        torch.log(ablated_probs + 1e-10),\n",
    "        clean_probs,\n",
    "        reduction='batchmean'\n",
    "    )\n",
    "\n",
    "    clean_preds = clean_probs.argmax(1)\n",
    "    ablated_preds = ablated_probs.argmax(1)\n",
    "\n",
    "    print(f\"Base model KL divergence (clean vs ablated): {kl_div.item():.4f}\")\n",
    "    print(f\"Clean accuracy: {(clean_preds == test_batch_labels).float().mean():.2%}\")\n",
    "    print(f\"Ablated accuracy: {(ablated_preds == test_batch_labels).float().mean():.2%}\")\n",
    "\n",
    "    # Prediction stability: fraction of argmax predictions flipped by ablation\n",
    "    print(f\"\\nPrediction changes due to ablation:\")\n",
    "    pred_changes = (clean_preds != ablated_preds).float().mean()\n",
    "    print(f\"Percentage of predictions that changed: {pred_changes:.1%}\")\n",
    "\n",
    "    # Mean absolute change in class probabilities\n",
    "    prob_shift = torch.abs(clean_probs - ablated_probs).mean()\n",
    "    print(f\"Average probability shift: {prob_shift:.4f}\")\n",
    "\n",
    "print(\"\\nNote: If KL divergence is already very low (<0.01), the model is already robust to missingness.\")\n",
    "print(\"Additional calibration might not help or could even hurt.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MCal optimisation history: loss (log-scaled y axis) and accuracy per step\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "for ax, series_key, ylabel, title in [\n",
    "    (ax1, 'loss', 'Loss', 'MCal Training Loss'),\n",
    "    (ax2, 'acc', 'Accuracy', 'MCal Training Accuracy'),\n",
    "]:\n",
    "    ax.plot(stats[series_key])\n",
    "    ax.set_xlabel('Step')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_title(title)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "ax1.set_yscale('log')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 5. Generate All Explanations (LIME and SHAP)"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrated model = base ViT followed by the fitted MCal calibrator\n",
    "calib_model = nn.Sequential(base_model, calibrator)\n",
    "calib_model.eval()\n",
    "\n",
    "# Materialise the whole clean test set on the device with a single pass over the loader\n",
    "test_batches = list(test_loader)\n",
    "test_images = torch.cat([batch_images for batch_images, _ in test_batches]).to(device)\n",
    "test_labels = torch.cat([batch_labels for _, batch_labels in test_batches]).to(device)\n",
    "del test_batches\n",
    "\n",
    "print(f\"Loaded {len(test_images)} test samples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One LIME and one KernelSHAP explainer per model (uncalibrated vs MCal-calibrated)\n",
    "from experiments.explanations import ImageLIME, ImageKernelSHAP\n",
    "\n",
    "print(\"Initializing explainers...\")\n",
    "\n",
    "# Shared perturbation budget for all explainers, read from the config instead of a\n",
    "# second hard-coded copy so the two cannot drift apart\n",
    "explanation_samples = config['explanation_samples']\n",
    "\n",
    "\n",
    "def _make_explainer(explainer_cls, model):\n",
    "    \"\"\"Build ``explainer_cls`` for ``model`` with the notebook's shared explainer settings.\"\"\"\n",
    "    return explainer_cls(\n",
    "        model=model,\n",
    "        num_samples=explanation_samples,\n",
    "        patch_size=config['patch_size'],\n",
    "        image_size=224\n",
    "    )\n",
    "\n",
    "\n",
    "uncal_lime_explainer = _make_explainer(ImageLIME, uncal_model)\n",
    "calib_lime_explainer = _make_explainer(ImageLIME, calib_model)\n",
    "uncal_shap_explainer = _make_explainer(ImageKernelSHAP, uncal_model)\n",
    "calib_shap_explainer = _make_explainer(ImageKernelSHAP, calib_model)\n",
    "\n",
    "print(f\"✓ LIME explainers initialized with {explanation_samples} samples\")\n",
    "print(f\"✓ SHAP explainers initialized with {explanation_samples} samples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate patch attributions for the clean test images with both explainers and both\n",
    "# models. Reuses `test_images` / `test_labels` loaded when the calibrated model was built.\n",
    "# NOTE(review): assumes ImageKernelSHAP exposes the same explain_instance(image, label)\n",
    "# API as ImageLIME - confirm.\n",
    "\n",
    "# KernelSHAP is the expensive part: optionally restrict it to the first N test images via\n",
    "# config['n_shap_test'] (defaults to the full test set).\n",
    "n_shap_images = min(config.get('n_shap_test', len(test_images)), len(test_images))\n",
    "shap_test_images = test_images[:n_shap_images]\n",
    "shap_test_labels = test_labels[:n_shap_images]\n",
    "\n",
    "uncal_lime_attrs, calib_lime_attrs = [], []\n",
    "for image, label in tqdm(zip(test_images, test_labels), total=len(test_images), desc=\"Generating LIME explanations\"):\n",
    "    label_idx = label.item()\n",
    "    uncal_lime_attrs.append(uncal_lime_explainer.explain_instance(image, label_idx))\n",
    "    calib_lime_attrs.append(calib_lime_explainer.explain_instance(image, label_idx))\n",
    "\n",
    "uncal_shap_attrs, calib_shap_attrs = [], []\n",
    "for image, label in tqdm(zip(shap_test_images, shap_test_labels), total=len(shap_test_images), desc=\"Generating SHAP explanations\"):\n",
    "    label_idx = label.item()\n",
    "    uncal_shap_attrs.append(uncal_shap_explainer.explain_instance(image, label_idx))\n",
    "    calib_shap_attrs.append(calib_shap_explainer.explain_instance(image, label_idx))\n",
    "\n",
    "# Backward-compatible aliases for the LIME attributions\n",
    "uncal_attrs, calib_attrs = uncal_lime_attrs, calib_lime_attrs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Evaluate Metrics Across K Values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explanation-quality metrics shared by the LIME and SHAP evaluations\n",
    "kl_metric = ImageKLDivergence(n_classes=4)\n",
    "patch_kwargs = dict(patch_size=config['patch_size'], image_size=224)\n",
    "suff_metric = ImageSufficiency(**patch_kwargs)\n",
    "comp_metric = ImageComprehensiveness(**patch_kwargs)\n",
    "\n",
    "# Per-K LIME results: each metric holds one list for the uncalibrated and one for the calibrated model\n",
    "results = {'k_values': config['k_values']}\n",
    "for metric_name in ('missingness_bias', 'sufficiency', 'comprehensiveness'):\n",
    "    results[metric_name] = {'uncal': [], 'calib': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the pre-generated LIME attributions for every K. For each K this computes\n",
    "# per-image sufficiency / comprehensiveness and, over the whole test set, the KL divergence\n",
    "# between predictions on the full images and on the top-K-masked images.\n",
    "print(\"Evaluating LIME explanations across K values...\")\n",
    "\n",
    "# Start from empty per-K lists so re-running this cell does not append a second copy of\n",
    "# every result (which would break the plots' x/y lengths).\n",
    "for metric_name in ('missingness_bias', 'sufficiency', 'comprehensiveness'):\n",
    "    results[metric_name] = {'uncal': [], 'calib': []}\n",
    "\n",
    "with torch.no_grad():  # evaluation only - no autograd graphs needed\n",
    "    for k in tqdm(config['k_values'], desc=\"Evaluating LIME k values\"):\n",
    "        suff_uncal_list, suff_calib_list = [], []\n",
    "        comp_uncal_list, comp_calib_list = [], []\n",
    "        masked_images_uncal, masked_images_calib = [], []\n",
    "\n",
    "        for i, (image, label) in enumerate(zip(test_images, test_labels)):\n",
    "            label_idx = label.item()\n",
    "            uncal_importance = uncal_lime_attrs[i]\n",
    "            calib_importance = calib_lime_attrs[i]\n",
    "\n",
    "            suff_uncal_list.append(suff_metric.compute(uncal_model, image, uncal_importance, k, label_idx))\n",
    "            suff_calib_list.append(suff_metric.compute(calib_model, image, calib_importance, k, label_idx))\n",
    "            comp_uncal_list.append(comp_metric.compute(uncal_model, image, uncal_importance, k, label_idx))\n",
    "            comp_calib_list.append(comp_metric.compute(calib_model, image, calib_importance, k, label_idx))\n",
    "\n",
    "            # Apply the mask built from the top-K patches by |importance| (mask assumed (H, W),\n",
    "            # cast to the image dtype and broadcast over channels, so any channel count works)\n",
    "            uncal_top_k = torch.argsort(torch.abs(uncal_importance), descending=True)[:k]\n",
    "            calib_top_k = torch.argsort(torch.abs(calib_importance), descending=True)[:k]\n",
    "            uncal_mask = suff_metric.create_mask_from_indices(uncal_top_k).to(device=image.device, dtype=image.dtype)\n",
    "            calib_mask = suff_metric.create_mask_from_indices(calib_top_k).to(device=image.device, dtype=image.dtype)\n",
    "            masked_images_uncal.append(image * uncal_mask)\n",
    "            masked_images_calib.append(image * calib_mask)\n",
    "\n",
    "        kl_uncal = kl_metric.compute(uncal_model, test_images, torch.stack(masked_images_uncal))\n",
    "        kl_calib = kl_metric.compute(calib_model, test_images, torch.stack(masked_images_calib))\n",
    "\n",
    "        results['missingness_bias']['uncal'].append(kl_uncal)\n",
    "        results['missingness_bias']['calib'].append(kl_calib)\n",
    "        results['sufficiency']['uncal'].append(np.mean(suff_uncal_list))\n",
    "        results['sufficiency']['calib'].append(np.mean(suff_calib_list))\n",
    "        results['comprehensiveness']['uncal'].append(np.mean(comp_uncal_list))\n",
    "        results['comprehensiveness']['calib'].append(np.mean(comp_calib_list))\n",
    "\n",
    "print(\"LIME evaluation complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Visualize Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot each explanation-quality metric against K: uncalibrated vs MCal-calibrated model\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "\n",
    "k_vals = config['k_values']\n",
    "\n",
    "# (results key, y-axis label, panel title, log-scale y?) for each panel\n",
    "lime_panels = [\n",
    "    ('missingness_bias', 'KL Divergence (log scale)', 'Missingness Bias', True),\n",
    "    ('sufficiency', 'Sufficiency Score (↓ better)', 'Sufficiency', False),\n",
    "    ('comprehensiveness', 'Comprehensiveness Score (↑ better)', 'Comprehensiveness', False),\n",
    "]\n",
    "for ax, (metric_key, ylabel, title, log_y) in zip(axes, lime_panels):\n",
    "    ax.plot(k_vals, results[metric_key]['uncal'], 'o-', label='Uncalibrated', linewidth=2)\n",
    "    ax.plot(k_vals, results[metric_key]['calib'], 's-', label='MCal Calibrated', color='green', linewidth=2)\n",
    "    ax.set_xlabel('Top-K Features Selected')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_title(title)\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    if log_y:\n",
    "        ax.set_yscale('log')\n",
    "\n",
    "plt.suptitle('MCal Impact on LIME Explanation Quality', fontsize=16, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Summary Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average, over all K, of compute_metric_improvement(uncal, calib, metric) for each LIME metric\n",
    "def _mean_improvement(metric_results, metric_name):\n",
    "    \"\"\"Mean of compute_metric_improvement over the paired per-K uncal / calib values.\"\"\"\n",
    "    paired = zip(metric_results['uncal'], metric_results['calib'])\n",
    "    return np.mean([compute_metric_improvement(u, c, metric_name) for u, c in paired])\n",
    "\n",
    "\n",
    "avg_kl_improve = _mean_improvement(results['missingness_bias'], 'kl_divergence')\n",
    "avg_suff_improve = _mean_improvement(results['sufficiency'], 'sufficiency')\n",
    "avg_comp_improve = _mean_improvement(results['comprehensiveness'], 'comprehensiveness')\n",
    "\n",
    "print(\"=\"*60)\n",
    "print(\"SUMMARY: MCal Improvement on LIME Explanations\")\n",
    "print(\"=\"*60)\n",
    "print(f\"\\nAverage Improvements Across All K Values:\")\n",
    "print(f\"  Missingness Bias (KL): {avg_kl_improve:+.1f}%\")\n",
    "print(f\"  Sufficiency: {avg_suff_improve:+.1f}%\")\n",
    "print(f\"  Comprehensiveness: {avg_comp_improve:+.1f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Evaluate the pre-generated SHAP attributions for every K (same procedure as for LIME),\n",
    "# restricted to the test images for which SHAP explanations were generated.\n",
    "print(\"Evaluating SHAP explanations across K values...\")\n",
    "\n",
    "# Fresh per-K storage, so this cell is safe to re-run\n",
    "shap_results = {\n",
    "    'k_values': config['k_values'],\n",
    "    'missingness_bias': {'uncal': [], 'calib': []},\n",
    "    'sufficiency': {'uncal': [], 'calib': []},\n",
    "    'comprehensiveness': {'uncal': [], 'calib': []}\n",
    "}\n",
    "\n",
    "with torch.no_grad():  # evaluation only - no autograd graphs needed\n",
    "    for k in tqdm(config['k_values'], desc=\"Evaluating SHAP k values\"):\n",
    "        suff_uncal_list, suff_calib_list = [], []\n",
    "        comp_uncal_list, comp_calib_list = [], []\n",
    "        masked_images_uncal, masked_images_calib = [], []\n",
    "\n",
    "        for i, (image, label) in enumerate(zip(shap_test_images, shap_test_labels)):\n",
    "            label_idx = label.item()\n",
    "            uncal_importance = uncal_shap_attrs[i]\n",
    "            calib_importance = calib_shap_attrs[i]\n",
    "\n",
    "            suff_uncal_list.append(suff_metric.compute(uncal_model, image, uncal_importance, k, label_idx))\n",
    "            suff_calib_list.append(suff_metric.compute(calib_model, image, calib_importance, k, label_idx))\n",
    "            comp_uncal_list.append(comp_metric.compute(uncal_model, image, uncal_importance, k, label_idx))\n",
    "            comp_calib_list.append(comp_metric.compute(calib_model, image, calib_importance, k, label_idx))\n",
    "\n",
    "            # Apply the mask built from the top-K patches by |importance| (mask assumed (H, W),\n",
    "            # cast to the image dtype and broadcast over channels, so any channel count works)\n",
    "            uncal_top_k = torch.argsort(torch.abs(uncal_importance), descending=True)[:k]\n",
    "            calib_top_k = torch.argsort(torch.abs(calib_importance), descending=True)[:k]\n",
    "            uncal_mask = suff_metric.create_mask_from_indices(uncal_top_k).to(device=image.device, dtype=image.dtype)\n",
    "            calib_mask = suff_metric.create_mask_from_indices(calib_top_k).to(device=image.device, dtype=image.dtype)\n",
    "            masked_images_uncal.append(image * uncal_mask)\n",
    "            masked_images_calib.append(image * calib_mask)\n",
    "\n",
    "        kl_uncal = kl_metric.compute(uncal_model, shap_test_images, torch.stack(masked_images_uncal))\n",
    "        kl_calib = kl_metric.compute(calib_model, shap_test_images, torch.stack(masked_images_calib))\n",
    "\n",
    "        shap_results['missingness_bias']['uncal'].append(kl_uncal)\n",
    "        shap_results['missingness_bias']['calib'].append(kl_calib)\n",
    "        shap_results['sufficiency']['uncal'].append(np.mean(suff_uncal_list))\n",
    "        shap_results['sufficiency']['calib'].append(np.mean(suff_calib_list))\n",
    "        shap_results['comprehensiveness']['uncal'].append(np.mean(comp_uncal_list))\n",
    "        shap_results['comprehensiveness']['calib'].append(np.mean(comp_calib_list))\n",
    "\n",
    "print(\"SHAP evaluation complete!\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Side-by-side summary of the average MCal improvements for LIME and SHAP\n",
    "print(\"=\"*60)\n",
    "print(\"COMPARISON: LIME vs SHAP Explanations\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# Mean over K of compute_metric_improvement(uncal, calib, metric) for each SHAP metric\n",
    "shap_metric_sources = {\n",
    "    'kl_divergence': shap_results['missingness_bias'],\n",
    "    'sufficiency': shap_results['sufficiency'],\n",
    "    'comprehensiveness': shap_results['comprehensiveness'],\n",
    "}\n",
    "shap_avg = {\n",
    "    metric: np.mean([compute_metric_improvement(u, c, metric) for u, c in zip(src['uncal'], src['calib'])])\n",
    "    for metric, src in shap_metric_sources.items()\n",
    "}\n",
    "shap_avg_kl_improve = shap_avg['kl_divergence']\n",
    "shap_avg_suff_improve = shap_avg['sufficiency']\n",
    "shap_avg_comp_improve = shap_avg['comprehensiveness']\n",
    "\n",
    "print(f\"\\nLIME-based MCal Improvements:\")\n",
    "print(f\"  Missingness Bias (KL): {avg_kl_improve:+.1f}%\")\n",
    "print(f\"  Sufficiency: {avg_suff_improve:+.1f}%\")\n",
    "print(f\"  Comprehensiveness: {avg_comp_improve:+.1f}%\")\n",
    "\n",
    "print(f\"\\nSHAP-based MCal Improvements:\")\n",
    "print(f\"  Missingness Bias (KL): {shap_avg_kl_improve:+.1f}%\")\n",
    "print(f\"  Sufficiency: {shap_avg_suff_improve:+.1f}%\")\n",
    "print(f\"  Comprehensiveness: {shap_avg_comp_improve:+.1f}%\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"INSIGHT: MCal improvements should be consistent across\")\n",
    "print(\"different explanation methods, demonstrating that the\")\n",
    "print(\"calibration benefits are explanation-agnostic.\")\n",
    "print(\"=\"*60)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# SHAP-based explanation-quality metrics versus K (uncalibrated vs MCal-calibrated)\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "\n",
    "k_vals = config['k_values']\n",
    "\n",
    "# (results key, y-axis label, panel title, log-scale y?) for each panel\n",
    "shap_panels = [\n",
    "    ('missingness_bias', 'KL Divergence (log scale)', 'Missingness Bias (SHAP)', True),\n",
    "    ('sufficiency', 'Sufficiency Score (↓ better)', 'Sufficiency (SHAP)', False),\n",
    "    ('comprehensiveness', 'Comprehensiveness Score (↑ better)', 'Comprehensiveness (SHAP)', False),\n",
    "]\n",
    "for ax, (metric_key, ylabel, title, log_y) in zip(axes, shap_panels):\n",
    "    ax.plot(k_vals, shap_results[metric_key]['uncal'], 'o-', label='Uncalibrated', linewidth=2, markersize=6)\n",
    "    ax.plot(k_vals, shap_results[metric_key]['calib'], 's-', label='MCal Calibrated', color='green', linewidth=2, markersize=6)\n",
    "    ax.set_xlabel('Top-K Features Selected', fontsize=12)\n",
    "    ax.set_ylabel(ylabel, fontsize=12)\n",
    "    ax.set_title(title, fontsize=14, fontweight='bold')\n",
    "    ax.legend(fontsize=10)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    if log_y:\n",
    "        ax.set_yscale('log')\n",
    "\n",
    "plt.suptitle('MCal Impact on SHAP Explanation Quality', fontsize=16, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Mean over K of the per-K SHAP improvement for each metric\n",
    "shap_improvements = {}\n",
    "for result_key, metric_name in [\n",
    "    ('missingness_bias', 'kl_divergence'),\n",
    "    ('sufficiency', 'sufficiency'),\n",
    "    ('comprehensiveness', 'comprehensiveness'),\n",
    "]:\n",
    "    per_k = [\n",
    "        compute_metric_improvement(u, c, metric_name)\n",
    "        for u, c in zip(shap_results[result_key]['uncal'], shap_results[result_key]['calib'])\n",
    "    ]\n",
    "    shap_improvements[result_key] = np.mean(per_k)\n",
    "shap_avg_kl_improve = shap_improvements['missingness_bias']\n",
    "shap_avg_suff_improve = shap_improvements['sufficiency']\n",
    "shap_avg_comp_improve = shap_improvements['comprehensiveness']\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"SHAP-based MCal Improvements:\")\n",
    "print(\"=\"*60)\n",
    "print(f\"  Missingness Bias (KL): {shap_avg_kl_improve:+.1f}%\")\n",
    "print(f\"  Sufficiency: {shap_avg_suff_improve:+.1f}%\")\n",
    "print(f\"  Comprehensiveness: {shap_avg_comp_improve:+.1f}%\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": "## 9. Compare LIME vs SHAP Results",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "# 2x3 overview: rows = explanation method (LIME, SHAP), columns = metric\n",
    "fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
    "\n",
    "k_vals = config['k_values']\n",
    "\n",
    "# (results key, y-axis label, log-scale y?) for each column\n",
    "metric_columns = [\n",
    "    ('missingness_bias', 'KL Divergence (log)', True),\n",
    "    ('sufficiency', 'Sufficiency (↓ better)', False),\n",
    "    ('comprehensiveness', 'Comprehensiveness (↑ better)', False),\n",
    "]\n",
    "# (results dict, panel titles) for each row\n",
    "method_rows = [\n",
    "    (results, ['Missingness Bias (LIME)', 'Sufficiency (LIME)', 'Comprehensiveness (LIME)']),\n",
    "    (shap_results, ['Missingness Bias (SHAP)', 'Sufficiency (SHAP)', 'Comprehensiveness (SHAP)']),\n",
    "]\n",
    "\n",
    "for row, (method_results, titles) in enumerate(method_rows):\n",
    "    for col, ((metric_key, ylabel, log_y), title) in enumerate(zip(metric_columns, titles)):\n",
    "        ax = axes[row, col]\n",
    "        ax.plot(k_vals, method_results[metric_key]['uncal'], 'o-', label='Uncalibrated', linewidth=2)\n",
    "        ax.plot(k_vals, method_results[metric_key]['calib'], 's-', label='MCal Calibrated', color='green', linewidth=2)\n",
    "        ax.set_xlabel('Top-K Features')\n",
    "        ax.set_ylabel(ylabel)\n",
    "        ax.set_title(title, fontweight='bold')\n",
    "        ax.legend()\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        if log_y:\n",
    "            ax.set_yscale('log')\n",
    "\n",
    "plt.suptitle('MCal Impact on Explanation Quality: LIME vs SHAP Comparison', fontsize=16, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Compare the average MCal improvements obtained with LIME and with SHAP explanations.\n",
    "# The closing insight is derived from the numbers (sign agreement per metric) instead of\n",
    "# being printed unconditionally, so the notebook cannot claim agreement the data do not show.\n",
    "print(\"=\"*60)\n",
    "print(\"COMPARISON: LIME vs SHAP Explanations\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "\n",
    "def _avg_improvement(method_results, result_key, metric_name):\n",
    "    \"\"\"Mean over K of compute_metric_improvement(uncal, calib, metric_name).\"\"\"\n",
    "    pairs = zip(method_results[result_key]['uncal'], method_results[result_key]['calib'])\n",
    "    return np.mean([compute_metric_improvement(u, c, metric_name) for u, c in pairs])\n",
    "\n",
    "\n",
    "shap_avg_kl_improve = _avg_improvement(shap_results, 'missingness_bias', 'kl_divergence')\n",
    "shap_avg_suff_improve = _avg_improvement(shap_results, 'sufficiency', 'sufficiency')\n",
    "shap_avg_comp_improve = _avg_improvement(shap_results, 'comprehensiveness', 'comprehensiveness')\n",
    "\n",
    "print(f\"\\nLIME-based MCal Improvements:\")\n",
    "print(f\"  Missingness Bias (KL): {avg_kl_improve:+.1f}%\")\n",
    "print(f\"  Sufficiency: {avg_suff_improve:+.1f}%\")\n",
    "print(f\"  Comprehensiveness: {avg_comp_improve:+.1f}%\")\n",
    "\n",
    "print(f\"\\nSHAP-based MCal Improvements:\")\n",
    "print(f\"  Missingness Bias (KL): {shap_avg_kl_improve:+.1f}%\")\n",
    "print(f\"  Sufficiency: {shap_avg_suff_improve:+.1f}%\")\n",
    "print(f\"  Comprehensiveness: {shap_avg_comp_improve:+.1f}%\")\n",
    "\n",
    "# Metrics where LIME and SHAP disagree on the sign of the average improvement\n",
    "# (NaN never compares equal, so an undefined value also counts as a disagreement)\n",
    "lime_vs_shap = {\n",
    "    'Missingness Bias (KL)': (avg_kl_improve, shap_avg_kl_improve),\n",
    "    'Sufficiency': (avg_suff_improve, shap_avg_suff_improve),\n",
    "    'Comprehensiveness': (avg_comp_improve, shap_avg_comp_improve),\n",
    "}\n",
    "disagreeing_metrics = [\n",
    "    name for name, (lime_val, shap_val) in lime_vs_shap.items()\n",
    "    if not np.sign(lime_val) == np.sign(shap_val)\n",
    "]\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"KEY INSIGHT:\")\n",
    "if not disagreeing_metrics:\n",
    "    print(\"LIME and SHAP agree on the sign of the MCal effect for every metric,\")\n",
    "    print(\"consistent with MCal's effect being explanation-method agnostic.\")\n",
    "else:\n",
    "    print(\"LIME and SHAP disagree on the sign of the MCal effect for:\")\n",
    "    for name in disagreeing_metrics:\n",
    "        print(f\"  - {name}\")\n",
    "    print(\"so these results do not show an explanation-method agnostic effect.\")\n",
    "print(\"=\"*60)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# This cell repeats the SHAP evaluation performed earlier in the notebook. Recomputing it\n",
    "# would cost another full pass over every K and silently replace numbers that have already\n",
    "# been plotted and printed above, so it only evaluates when `shap_results` is missing or\n",
    "# incomplete.\n",
    "_shap_metric_keys = ('missingness_bias', 'sufficiency', 'comprehensiveness')\n",
    "shap_already_evaluated = 'shap_results' in globals() and all(\n",
    "    len(shap_results.get(metric_key, {}).get(variant, ())) == len(config['k_values'])\n",
    "    for metric_key in _shap_metric_keys\n",
    "    for variant in ('uncal', 'calib')\n",
    ")\n",
    "\n",
    "if shap_already_evaluated:\n",
    "    print(\"SHAP metrics already computed above; skipping re-evaluation.\")\n",
    "else:\n",
    "    shap_results = {'k_values': config['k_values']}\n",
    "    for metric_key in _shap_metric_keys:\n",
    "        shap_results[metric_key] = {'uncal': [], 'calib': []}\n",
    "\n",
    "    print(\"Evaluating SHAP-based metrics...\")\n",
    "    with torch.no_grad():  # evaluation only - no autograd graphs needed\n",
    "        for k in tqdm(config['k_values'], desc=\"Evaluating SHAP k values\"):\n",
    "            suff_uncal_list, suff_calib_list = [], []\n",
    "            comp_uncal_list, comp_calib_list = [], []\n",
    "            masked_images_uncal, masked_images_calib = [], []\n",
    "\n",
    "            for i, (image, label) in enumerate(zip(shap_test_images, shap_test_labels)):\n",
    "                label_idx = label.item()\n",
    "                uncal_importance = uncal_shap_attrs[i]\n",
    "                calib_importance = calib_shap_attrs[i]\n",
    "\n",
    "                suff_uncal_list.append(suff_metric.compute(uncal_model, image, uncal_importance, k, label_idx))\n",
    "                suff_calib_list.append(suff_metric.compute(calib_model, image, calib_importance, k, label_idx))\n",
    "                comp_uncal_list.append(comp_metric.compute(uncal_model, image, uncal_importance, k, label_idx))\n",
    "                comp_calib_list.append(comp_metric.compute(calib_model, image, calib_importance, k, label_idx))\n",
    "\n",
    "                # Apply the mask built from the top-K patches by |importance| (mask assumed (H, W),\n",
    "                # cast to the image dtype and broadcast over channels)\n",
    "                uncal_top_k = torch.argsort(torch.abs(uncal_importance), descending=True)[:k]\n",
    "                calib_top_k = torch.argsort(torch.abs(calib_importance), descending=True)[:k]\n",
    "                uncal_mask = suff_metric.create_mask_from_indices(uncal_top_k).to(device=image.device, dtype=image.dtype)\n",
    "                calib_mask = suff_metric.create_mask_from_indices(calib_top_k).to(device=image.device, dtype=image.dtype)\n",
    "                masked_images_uncal.append(image * uncal_mask)\n",
    "                masked_images_calib.append(image * calib_mask)\n",
    "\n",
    "            kl_uncal = kl_metric.compute(uncal_model, shap_test_images, torch.stack(masked_images_uncal))\n",
    "            kl_calib = kl_metric.compute(calib_model, shap_test_images, torch.stack(masked_images_calib))\n",
    "\n",
    "            shap_results['missingness_bias']['uncal'].append(kl_uncal)\n",
    "            shap_results['missingness_bias']['calib'].append(kl_calib)\n",
    "            shap_results['sufficiency']['uncal'].append(np.mean(suff_uncal_list))\n",
    "            shap_results['sufficiency']['calib'].append(np.mean(suff_calib_list))\n",
    "            shap_results['comprehensiveness']['uncal'].append(np.mean(comp_uncal_list))\n",
    "            shap_results['comprehensiveness']['calib'].append(np.mean(comp_calib_list))"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Re-creates the SHAP explainers with the SAME class and sample budget that the explainer\n",
    "# initialisation cell uses (ImageKernelSHAP, shared explanation sample budget). Building them\n",
    "# here with a different class (ImageSHAP) and a smaller budget (50) after the explanations were\n",
    "# generated would make any later re-run of the explanation cell silently use differently\n",
    "# configured explainers than the ones the reported results came from.\n",
    "from experiments.explanations import ImageKernelSHAP\n",
    "\n",
    "shap_samples = config['explanation_samples']\n",
    "\n",
    "uncal_shap_explainer = ImageKernelSHAP(\n",
    "    model=uncal_model,\n",
    "    num_samples=shap_samples,\n",
    "    patch_size=config['patch_size'],\n",
    "    image_size=224\n",
    ")\n",
    "\n",
    "calib_shap_explainer = ImageKernelSHAP(\n",
    "    model=calib_model,\n",
    "    num_samples=shap_samples,\n",
    "    patch_size=config['patch_size'],\n",
    "    image_size=224\n",
    ")\n",
    "\n",
    "print(f\"Initialized SHAP explainers with {shap_samples} samples\")\n",
    "print(f\"Patch size: {config['patch_size']}x{config['patch_size']}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}