{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Immediately terminate this kernel process with signal 9 (SIGKILL); all in-memory\r\n",
     "# state is lost. NOTE(review): presumably used as a manual 'restart runtime' step -\r\n",
     "# it must not be part of a Run-All sequence.\r\n",
     "import os\r\n",
     "os.kill(os.getpid(), 9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# @title Mount Google Drive\r\n",
     "\r\n",
     "# Mount Google Drive at /content/drive; the Dependencies cell places checkpoint_dir\r\n",
     "# and figure_dir under /content/drive/MyDrive/model_calibration/.\r\n",
     "from google.colab import drive\r\n",
     "drive.mount('/content/drive')\r\n",
     "\r\n",
     "print(\"Success\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Dependencies\r\n",
    "\r\n",
    "!pip install timm wget\r\n",
    "!pip install torchmetrics\r\n",
    "!pip install pytorchcv\r\n",
    "\r\n",
    "import sys\r\n",
    "sys.path.append(\"/content/model\")\r\n",
    "from IPython.display import clear_output\r\n",
    "clear_output()\r\n",
    "\r\n",
    "import os, shutil, random, subprocess\r\n",
    "import wget\r\n",
    "from tqdm import tqdm\r\n",
    "from functools import partial\r\n",
    "import json\r\n",
    "import matplotlib.pyplot as plt\r\n",
    "import time\r\n",
    "\r\n",
    "\r\n",
    "drive_root = \"/content/drive/MyDrive/model_calibration/\"\r\n",
    "local_root = \"/content/\"\r\n",
    "dataset_dir = os.path.join(local_root, \"data\")\r\n",
    "checkpoint_dir = os.path.join(drive_root,\"checkpoints\")\r\n",
    "figure_dir = os.path.join(drive_root,\"figures\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Download CIFAR-10 and CIFAR-100 with balanced train/val/test split (deterministic)\r\n",
    "import os\r\n",
    "import shutil\r\n",
    "import numpy as np # Use numpy for efficient array operations and random number generation\r\n",
    "from torchvision.datasets import CIFAR10, CIFAR100\r\n",
    "from torchvision.datasets.utils import check_integrity, download_and_extract_archive\r\n",
    "from tqdm import tqdm # For progress bars\r\n",
    "from collections import Counter # For verification\r\n",
    "\r\n",
     "# --- Configuration ---\r\n",
     "# NOTE(review): relative path, resolved against the kernel's current working directory.\r\n",
     "DATASET_DIR = \"./data\"  # Root directory to save datasets\r\n",
     "VAL_SPLIT_RATIO = 0.1   # Use 10% of the original training data from EACH class for validation\r\n",
     "RANDOM_SEED = 42        # Seed for reproducible train/val split\r\n",
     "\r\n",
     "# Datasets to process. Per entry:\r\n",
     "#   name        - sub-directory of DATASET_DIR that receives the train/val/test folders\r\n",
     "#   cls         - torchvision dataset class used to load the extracted archive\r\n",
     "#   url, filename, tgz_md5 - archive source, local file name and MD5 for check_integrity\r\n",
     "#   base_folder - folder the archive extracts to (used to detect a missing extraction)\r\n",
     "DATASETS_CONFIG = [\r\n",
     "    {\r\n",
     "        \"name\": \"cifar-10\",\r\n",
     "        \"cls\": CIFAR10,\r\n",
     "        \"url\": \"https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\",\r\n",
     "        \"filename\": \"cifar-10-python.tar.gz\",\r\n",
     "        \"tgz_md5\": \"c58f30108f718f92721af3b95e74349a\",\r\n",
     "        \"base_folder\": \"cifar-10-batches-py\",\r\n",
     "    },\r\n",
     "    {\r\n",
     "        \"name\": \"cifar-100\",\r\n",
     "        \"cls\": CIFAR100,\r\n",
     "        \"url\": \"https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz\",\r\n",
     "        \"filename\": \"cifar-100-python.tar.gz\",\r\n",
     "        \"tgz_md5\": \"eb9058c3a382ffc7106e4002c42a8d85\",\r\n",
     "        \"base_folder\": \"cifar-100-python\",\r\n",
     "    },\r\n",
     "]\r\n",
    "\r\n",
    "# --- Helper Functions ---\r\n",
    "\r\n",
    "def save_images_to_folders(dataset, indices, target_dir, class_names, dataset_name, split_name):\r\n",
    "    \"\"\"\r\n",
    "    Saves images from a dataset subset into the target directory structure.\r\n",
    "    Structure: target_dir / class_name / original_index.png\r\n",
    "    \"\"\"\r\n",
    "    print(f\"📦 [{dataset_name}] Processing and saving {split_name} set ({len(indices)} images)...\")\r\n",
    "    os.makedirs(target_dir, exist_ok=True) # Ensure target directory exists\r\n",
    "\r\n",
    "    saved_count = 0\r\n",
    "    for idx in tqdm(indices, desc=f\"Saving {split_name} images\"):\r\n",
    "        try:\r\n",
    "            img, label_index = dataset[idx] # dataset[idx] returns (PIL Image, label_index)\r\n",
    "            class_name = class_names[label_index]\r\n",
    "            class_dir = os.path.join(target_dir, class_name)\r\n",
    "            os.makedirs(class_dir, exist_ok=True) # Ensure class directory exists\r\n",
    "            img_path = os.path.join(class_dir, f\"{idx}.png\") # Use original index in filename\r\n",
    "            img.save(img_path)\r\n",
    "            saved_count += 1\r\n",
    "        except IndexError:\r\n",
    "            print(f\"Warning: Index {idx} out of bounds for dataset {dataset_name}. Skipping.\")\r\n",
    "        except Exception as e:\r\n",
    "            print(f\"Error processing index {idx} for {dataset_name}: {e}. Skipping.\")\r\n",
    "\r\n",
    "    print(f\"✅ [{dataset_name}] Saved {saved_count} images to {target_dir}\")\r\n",
    "    # Verification step (optional but good)\r\n",
    "    verify_saved_count(target_dir, len(indices), split_name)\r\n",
    "\r\n",
    "\r\n",
    "def verify_saved_count(split_dir, expected_count, split_name):\r\n",
    "    \"\"\"Counts files and verifies against expected count.\"\"\"\r\n",
    "    actual_count = 0\r\n",
    "    if not os.path.isdir(split_dir):\r\n",
    "        print(f\"❌ Error: Directory not found {split_dir} for verification.\")\r\n",
    "        return False\r\n",
    "\r\n",
    "    for class_folder in os.listdir(split_dir):\r\n",
    "        class_path = os.path.join(split_dir, class_folder)\r\n",
    "        if os.path.isdir(class_path):\r\n",
    "            try:\r\n",
    "                num_files = len([f for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))])\r\n",
    "                actual_count += num_files\r\n",
    "            except OSError as e:\r\n",
    "                print(f\"Could not access files in {class_path} for verification: {e}\")\r\n",
    "\r\n",
    "    print(f\"📊 [{split_name}] Verification: Expected {expected_count}, Found {actual_count} images.\")\r\n",
    "    if actual_count != expected_count:\r\n",
    "        print(f\"⚠️ Warning: Mismatch in expected vs found images for {split_name} set in {split_dir}!\")\r\n",
    "        return False\r\n",
    "    return True\r\n",
    "\r\n",
     "# --- Main Script ---\r\n",
     "# For each dataset: download/verify the archive, split the original training set per\r\n",
     "# class into train/val, and export train/val/test as class-folder PNG files.\r\n",
     "\r\n",
     "# Use a seeded random number generator for reproducibility\r\n",
     "# NOTE: this single generator is shared by all datasets, so a dataset's split depends on\r\n",
     "# the draws consumed by the datasets processed before it (keep DATASETS_CONFIG order).\r\n",
     "rng = np.random.default_rng(RANDOM_SEED)\r\n",
     "\r\n",
     "# Ensure the root data directory exists\r\n",
     "os.makedirs(DATASET_DIR, exist_ok=True)\r\n",
     "\r\n",
     "for cfg in DATASETS_CONFIG:\r\n",
     "    dataset_name = cfg[\"name\"]\r\n",
     "    ds_class = cfg[\"cls\"]\r\n",
     "    print(f\"\\n{'='*10} Processing {dataset_name} {'='*10}\")\r\n",
     "\r\n",
     "    # --- Define Paths ---\r\n",
     "    dataset_root = os.path.join(DATASET_DIR, dataset_name)\r\n",
     "    train_dir = os.path.join(dataset_root, \"train\")\r\n",
     "    val_dir = os.path.join(dataset_root, \"val\")\r\n",
     "    test_dir = os.path.join(dataset_root, \"test\")\r\n",
     "    # Archives and their extracted folders live directly under DATASET_DIR.\r\n",
     "    download_root = DATASET_DIR\r\n",
     "\r\n",
     "    # --- Clean and Create Directories ---\r\n",
     "    # Any previous export of this dataset is deleted, so each run rebuilds it from scratch.\r\n",
     "    if os.path.exists(dataset_root):\r\n",
     "        print(f\"🧹 Cleaning up existing directory: {dataset_root}\")\r\n",
     "        shutil.rmtree(dataset_root)\r\n",
     "    print(f\"📁 Creating directory structure in: {dataset_root}\")\r\n",
     "    os.makedirs(train_dir, exist_ok=True)\r\n",
     "    os.makedirs(val_dir, exist_ok=True)\r\n",
     "    os.makedirs(test_dir, exist_ok=True)\r\n",
     "\r\n",
     "    # --- Download and Load Datasets (as PIL Images) ---\r\n",
     "    print(f\"📥 Downloading/Verifying {dataset_name}...\")\r\n",
     "    archive_path = os.path.join(download_root, cfg[\"filename\"])\r\n",
     "    extract_path = os.path.join(download_root, cfg[\"base_folder\"])\r\n",
     "\r\n",
     "    # Manual download check and extraction logic (kept from original)\r\n",
     "    # The second download_and_extract_archive call only runs when the archive passes the\r\n",
     "    # MD5 check but its extracted folder (base_folder) is missing.\r\n",
     "    if not check_integrity(archive_path, cfg[\"tgz_md5\"]):\r\n",
     "         print(f\"File {cfg['filename']} not found or corrupt, downloading...\")\r\n",
     "         download_and_extract_archive(cfg[\"url\"], download_root, filename=cfg[\"filename\"], md5=cfg[\"tgz_md5\"])\r\n",
     "    else:\r\n",
     "         print(f\"File {cfg['filename']} found and verified.\")\r\n",
     "         if not os.path.exists(extract_path):\r\n",
     "             print(f\"Extracting {cfg['filename']}...\")\r\n",
     "             download_and_extract_archive(cfg[\"url\"], download_root, filename=cfg[\"filename\"], md5=cfg[\"tgz_md5\"])\r\n",
     "\r\n",
     "    # Load datasets\r\n",
     "    try:\r\n",
     "        ds_train_original = ds_class(root=download_root, train=True, download=False)\r\n",
     "        ds_test_original = ds_class(root=download_root, train=False, download=False)\r\n",
     "    except Exception as e:\r\n",
     "        print(f\"❌ Error loading dataset {dataset_name}. Check download/extraction at '{download_root}'.\")\r\n",
     "        print(f\"Error details: {e}\")\r\n",
     "        continue # Skip to next dataset\r\n",
     "\r\n",
     "    class_names = ds_train_original.classes\r\n",
     "    num_classes = len(class_names)\r\n",
     "    print(f\"ℹ️ Found {num_classes} classes for {dataset_name}. First few: {class_names[:5]}...\")\r\n",
     "\r\n",
     "    # --- Perform Balanced Train/Validation Split (Per Class) ---\r\n",
     "    print(f\"🔪 Splitting original training data ({len(ds_train_original)} samples) by class ({VAL_SPLIT_RATIO*100}% validation)...\")\r\n",
     "    targets = np.array(ds_train_original.targets) # Convert targets to numpy array\r\n",
     "    train_indices = []\r\n",
     "    val_indices = []\r\n",
     "    original_train_indices_by_class = [np.where(targets == i)[0] for i in range(num_classes)]\r\n",
     "\r\n",
     "    for i, indices_i in enumerate(original_train_indices_by_class):\r\n",
     "        n_samples_class = len(indices_i)\r\n",
     "        if n_samples_class == 0:\r\n",
     "            print(f\"  Warning: Class {i} ({class_names[i]}) has 0 samples.\")\r\n",
     "            continue\r\n",
     "\r\n",
     "        # Calculate number of validation samples for this class\r\n",
     "        # Ensure at least 1 sample is left for training if possible\r\n",
     "        n_val_class = int(round(n_samples_class * VAL_SPLIT_RATIO))\r\n",
     "        # NOTE(review): the clamp below already gives 0 when n_samples_class == 1, so the\r\n",
     "        # following `if` is redundant; n_train_class is computed but never used.\r\n",
     "        n_val_class = max(0, min(n_val_class, n_samples_class - 1)) # Ensure n_val is valid\r\n",
     "        if n_samples_class <= 1 : # Cannot split if 0 or 1 samples\r\n",
     "             n_val_class = 0\r\n",
     "        n_train_class = n_samples_class - n_val_class\r\n",
     "\r\n",
     "        # Shuffle indices for this class using the seeded RNG\r\n",
     "        shuffled_indices_i = rng.permutation(indices_i)\r\n",
     "\r\n",
     "        # Append indices to respective lists\r\n",
     "        val_indices.extend(shuffled_indices_i[:n_val_class])\r\n",
     "        train_indices.extend(shuffled_indices_i[n_val_class:])\r\n",
     "\r\n",
     "    num_train = len(train_indices)\r\n",
     "    num_val = len(val_indices)\r\n",
     "    num_test = len(ds_test_original)\r\n",
     "\r\n",
     "    print(f\"  ➡️ New Train set size: {num_train}\")\r\n",
     "    print(f\"  ➡️ Validation set size: {num_val} (split per class)\")\r\n",
     "    print(f\"  ➡️ Test set size: {num_test}\")\r\n",
     "\r\n",
     "\r\n",
     "    # --- Save Images to Folders ---\r\n",
     "    save_images_to_folders(ds_train_original, train_indices, train_dir, class_names, dataset_name, \"train\")\r\n",
     "    save_images_to_folders(ds_train_original, val_indices, val_dir, class_names, dataset_name, \"val\")\r\n",
     "    save_images_to_folders(ds_test_original, list(range(num_test)), test_dir, class_names, dataset_name, \"test\") # Test indices are 0..N-1\r\n",
     "\r\n",
     "    # --- Final Verification (Optional but Recommended) ---\r\n",
     "    print(f\"\\n📊 Final verification for {dataset_name}:\")\r\n",
     "    verify_saved_count(train_dir, num_train, \"Train\")\r\n",
     "    verify_saved_count(val_dir, num_val, \"Val\")\r\n",
     "    verify_saved_count(test_dir, num_test, \"Test\")\r\n",
     "\r\n",
     "    # Example: Check class balance in val set\r\n",
     "    # targets holds original-train labels, so indexing with val_indices yields the val labels.\r\n",
     "    val_targets_final = targets[val_indices]\r\n",
     "    val_class_counts = Counter(val_targets_final)\r\n",
     "    print(f\"  Validation set class counts (Top 5): {val_class_counts.most_common(5)}\")\r\n",
     "\r\n",
     "\r\n",
     "print(\"\\n✅ All datasets processed.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Define L2-ECE and ATB\r\n",
    "import torch\r\n",
    "import torch.nn as nn\r\n",
    "import torch.nn.functional as F\r\n",
    "from torch.nn.functional import softmax\r\n",
    "from torchmetrics.classification import MulticlassCalibrationError\r\n",
    "import torch.optim as optim\r\n",
    "from torchvision import datasets, transforms\r\n",
    "from torch.utils.data import DataLoader, Subset, random_split\r\n",
    "from torchmetrics.classification import MulticlassCalibrationError\r\n",
    "import timm\r\n",
    "import matplotlib.pyplot as plt\r\n",
    "import numpy as np\r\n",
    "from matplotlib import cm\r\n",
    "import torch.optim as optim\r\n",
    "from timm.data import create_transform\r\n",
    "from timm.scheduler import CosineLRScheduler\r\n",
    "from pytorchcv.model_provider import get_model\r\n",
    "import torch.nn.utils as nn_utils\r\n",
    "import itertools\r\n",
    "\r\n",
    "def ece_fix_base(confidences: torch.Tensor, correct: torch.Tensor, sq, n_bins) -> torch.Tensor:\r\n",
    "    assert len(confidences) == len(correct), \"Inputs must have the same length.\"\r\n",
    "    device = confidences.device\r\n",
    "\r\n",
    "    bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=device)\r\n",
    "    bin_lowers = bin_boundaries[:-1]\r\n",
    "    bin_uppers = bin_boundaries[1:]\r\n",
    "\r\n",
    "    ece = torch.zeros(1, device=device)\r\n",
    "    total_samples = len(confidences)\r\n",
    "\r\n",
    "    for i in range(n_bins):\r\n",
    "        in_bin = (confidences > bin_lowers[i]) & (confidences <= bin_uppers[i])\r\n",
    "        prop_in_bin = in_bin.float().mean() # The weight of the current bin (n_k / N)\r\n",
    "\r\n",
    "        if prop_in_bin.item() > 0:\r\n",
    "            accuracy_in_bin = correct[in_bin].float().mean()\r\n",
    "            avg_confidence_in_bin = confidences[in_bin].mean()\r\n",
    "            if not sq:\r\n",
    "                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin\r\n",
    "            else:\r\n",
    "                ece += (avg_confidence_in_bin - accuracy_in_bin).pow(2) * (prop_in_bin ** 2)\r\n",
    "\r\n",
    "    return ece\r\n",
    "\r\n",
    "def ece_quantile_base(confidences: torch.Tensor, correct: torch.Tensor, sq, n_bins) -> torch.Tensor:\r\n",
    "    \"\"\"\r\n",
    "    Vectorised ATB implementation (O(N log N) dominated by sort).\r\n",
    "\r\n",
    "    Args:\r\n",
    "        probs   : (N, C) soft‑max probabilities\r\n",
    "        targets : (N,)   ground‑truth labels (int64)\r\n",
    "\r\n",
    "    Returns:\r\n",
    "        scalar ATB value\r\n",
    "    \"\"\"\r\n",
    "    assert(len(confidences) == len(correct))\r\n",
    "    device, dtype = confidences.device, confidences.dtype\r\n",
    "\r\n",
    "    # 1. confidence & correctness -------------------------------------------\r\n",
    "    # confidences, preds = probs.max(dim=1)              # (N,)\r\n",
    "    # correct = preds.eq(targets).to(dtype)              # (N,)\r\n",
    "\r\n",
    "    # 2. sort by confidence (ascending, same as original code) --------------\r\n",
    "    conf_sorted, idx = confidences.sort()              # (N,)\r\n",
    "    acc_sorted  = correct[idx]                         # (N,)\r\n",
    "\r\n",
    "    bin_size = len(confidences) // n_bins\r\n",
    "    ece = torch.zeros(1, device=device, dtype=confidences.dtype).squeeze()\r\n",
    "\r\n",
    "    for i in range(n_bins):\r\n",
    "        start = i * bin_size\r\n",
    "        end = (i + 1) * bin_size if i < n_bins - 1 else len(confidences)\r\n",
    "\r\n",
    "        bin_conf = conf_sorted[start:end]\r\n",
    "        bin_acc = acc_sorted[start:end]\r\n",
    "\r\n",
    "        if len(bin_conf) == 0:\r\n",
    "            continue\r\n",
    "\r\n",
    "        avg_acc = bin_acc.mean()\r\n",
    "        avg_conf = bin_conf.mean()\r\n",
    "        bin_weight = len(bin_conf) / len(confidences)\r\n",
    "\r\n",
    "        # L2 version: (acc - conf)^2\r\n",
    "        if not sq:\r\n",
    "            ece += torch.abs(avg_acc - avg_conf) / n_bins\r\n",
    "        else:\r\n",
    "            ece += (avg_acc - avg_conf).pow(2) / (n_bins ** 2)\r\n",
    "\r\n",
    "    return ece\r\n",
    "\r\n",
    "def atb_base(confidences: torch.Tensor, correct: torch.Tensor) -> torch.Tensor:\r\n",
    "    \"\"\"\r\n",
    "    Vectorised ATB implementation (O(N log N) dominated by sort).\r\n",
    "\r\n",
    "    Args:\r\n",
    "        probs   : (N, C) soft‑max probabilities\r\n",
    "        targets : (N,)   ground‑truth labels (int64)\r\n",
    "\r\n",
    "    Returns:\r\n",
    "        scalar ATB value\r\n",
    "    \"\"\"\r\n",
    "    assert(len(confidences) == len(correct))\r\n",
    "    device, dtype = confidences.device, confidences.dtype\r\n",
    "\r\n",
    "    # 1. confidence & correctness -------------------------------------------\r\n",
    "    # confidences, preds = probs.max(dim=1)              # (N,)\r\n",
    "    # correct = preds.eq(targets).to(dtype)              # (N,)\r\n",
    "\r\n",
    "    # 2. sort by confidence (ascending, same as original code) --------------\r\n",
    "    conf_sorted, idx = confidences.sort()              # (N,)\r\n",
    "    acc_sorted  = correct[idx]                         # (N,)\r\n",
    "\r\n",
    "    # 3. per‑sample delta = p_i − 1_{correct_i} -----------------------------\r\n",
    "    delta = conf_sorted - acc_sorted                   # (N,)\r\n",
    "    total_delta = delta.sum()                          #  scalar\r\n",
    "\r\n",
    "    # prefix sums BEFORE adding delta[i] (len=N+1) --------------------------\r\n",
    "    # prefix_left[k]  = Σ_{j<k} delta[j]\r\n",
    "    prefix_left = torch.cat([\r\n",
    "        torch.zeros(1, device=device, dtype=dtype),\r\n",
    "        delta.cumsum(dim=0)\r\n",
    "    ])                                                 # (N+1,)\r\n",
    "\r\n",
    "    prefix_right = total_delta - prefix_left           # (N+1,)\r\n",
    "\r\n",
    "    # 4. interval widths -----------------------------------------------------\r\n",
    "    widths = torch.cat([\r\n",
    "        conf_sorted[:1],                               # [0, conf0]\r\n",
    "        conf_sorted[1:] - conf_sorted[:-1],            # (conf_{i} - conf_{i-1})\r\n",
    "        1.0 - conf_sorted[-1:]                         # [conf_{N-1}, 1]\r\n",
    "    ])                                                 # (N+1,)\r\n",
    "\r\n",
    "    # 5. integral ------------------------------------------------------------\r\n",
    "    atb = ((prefix_left**2 + prefix_right**2) * widths).sum()\r\n",
    "    atb = atb / (conf_sorted.size(0) ** 2)\r\n",
    "    return atb\r\n",
    "\r\n",
    "def classification_error(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    device, dtype = logits.device, logits.dtype\r\n",
    "    probs = F.softmax(logits, dim=1)\r\n",
    "    confidences, preds = probs.max(dim=1)\r\n",
    "    correct = preds.eq(targets).to(dtype)\r\n",
    "    return 1. - correct.sum() / len(correct)\r\n",
    "\r\n",
    "def squared_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    \"\"\"\r\n",
    "    Calculates the squared loss.\r\n",
    "    \"\"\"\r\n",
    "    probs = F.softmax(logits, dim=1)\r\n",
    "    # Create one-hot encoding of targets\r\n",
    "    targets_one_hot = F.one_hot(targets, num_classes=probs.size(1)).float().to(probs.device)\r\n",
    "    # Calculate squared error between probabilities and one-hot targets\r\n",
    "    squared_loss = F.mse_loss(probs, targets_one_hot, reduction='mean')\r\n",
    "    return squared_loss\r\n",
    "\r\n",
    "def spherical_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    \"\"\"\r\n",
    "    Calculates the spherical loss: p_i / ||p_i||_2, where i is the true label.\r\n",
    "    \"\"\"\r\n",
    "    probs = F.softmax(logits, dim=1)\r\n",
    "    # Get probabilities of the true class\r\n",
    "    true_class_probs = probs.gather(1, targets.unsqueeze(1)).squeeze(1)\r\n",
    "    # Calculate the L2 norm of the probability vector for each sample\r\n",
    "    probs_norm = torch.norm(probs, p=2, dim=1)\r\n",
    "    # Calculate the loss for each sample\r\n",
    "    loss_per_sample = true_class_probs / probs_norm\r\n",
    "    # The spherical loss is often defined as minimizing the negative of this value\r\n",
    "    # or maximizing this value. So we can return the negative mean.\r\n",
    "    return -torch.mean(loss_per_sample)\r\n",
    "\r\n",
    "\r\n",
    "def method_confidence(method, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    device, dtype = logits.device, logits.dtype\r\n",
    "    probs = F.softmax(logits, dim=1)\r\n",
    "    confidences, preds = probs.max(dim=1)\r\n",
    "    correct = preds.eq(targets).to(dtype)\r\n",
    "    return method(confidences, correct)\r\n",
    "def atb_confidence(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    return method_confidence(atb_base, logits, targets)\r\n",
    "def l1ecefix_confidence(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_confidence(partial(ece_fix_base, sq=False, n_bins=n_bins), logits, targets)\r\n",
    "def l1ecequa_confidence(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_confidence(partial(ece_quantile_base, sq=False, n_bins=n_bins), logits, targets)\r\n",
    "def l2ecefix_confidence(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_confidence(partial(ece_fix_base, sq=True, n_bins=n_bins), logits, targets)\r\n",
    "def l2ecequa_confidence(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_confidence(partial(ece_quantile_base, sq=True, n_bins=n_bins), logits, targets)\r\n",
    "\r\n",
    "\r\n",
    "def method_classwise(method, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    device, dtype = logits.device, logits.dtype\r\n",
    "    probs = F.softmax(logits, dim=1)\r\n",
    "    total = 0.0\r\n",
    "    for c in range(probs.size(1)):\r\n",
    "        confidences = probs[:, c]\r\n",
    "        correct = (targets == c).to(probs.dtype)\r\n",
    "        total = total + method(confidences, correct)\r\n",
    "    return total / probs.size(1)\r\n",
    "def atb_classwise(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\r\n",
    "    return method_classwise(atb_base, logits, targets)\r\n",
    "def l1ecefix_classwise(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_classwise(partial(ece_fix_base, sq=False, n_bins=n_bins), logits, targets)\r\n",
    "def l1ecequa_classwise(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_classwise(partial(ece_quantile_base, sq=False, n_bins=n_bins), logits, targets)\r\n",
    "def l2ecefix_classwise(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_classwise(partial(ece_fix_base, sq=True, n_bins=n_bins), logits, targets)\r\n",
    "def l2ecequa_classwise(logits: torch.Tensor, targets: torch.Tensor, n_bins) -> torch.Tensor:\r\n",
    "    return method_classwise(partial(ece_quantile_base, sq=True, n_bins=n_bins), logits, targets)\r\n",
    "\r\n",
    "\r\n",
    "def mmce_weighted(probs: torch.Tensor, targets: torch.Tensor, seed = None) -> torch.Tensor:\r\n",
    "    \"\"\"\r\n",
    "    Compute weighted MMCE (Maximum Mean Calibration Error).\r\n",
    "\r\n",
    "    Args:\r\n",
    "        probs (torch.Tensor): Tensor of shape (batch_size, num_classes), softmax outputs.\r\n",
    "        targets (torch.Tensor): Tensor of shape (batch_size,), ground-truth integer labels.\r\n",
    "        sigma (float): Width of the Laplacian kernel.\r\n",
    "\r\n",
    "    Returns:\r\n",
    "        torch.Tensor: Scalar tensor representing MMCE regularization loss.\r\n",
    "    \"\"\"\r\n",
    "\r\n",
    "    sigma = 0.4\r\n",
    "    # Step 1: Get predicted class and confidence\r\n",
    "    confs, preds = probs.max(dim=1)  # shape: (batch_size,)\r\n",
    "    correct = (preds == targets).float()  # shape: (batch_size,)\r\n",
    "\r\n",
    "    # Step 2: Split into correct (c=1) and incorrect (c=0) predictions\r\n",
    "    pos_mask = (correct == 1)\r\n",
    "    neg_mask = (correct == 0)\r\n",
    "\r\n",
    "    r_pos = confs[pos_mask]  # correct predictions\r\n",
    "    r_neg = confs[neg_mask]  # incorrect predictions\r\n",
    "\r\n",
    "    n_pos = len(r_pos)\r\n",
    "    n_neg = len(r_neg)\r\n",
    "\r\n",
    "    if n_pos == 0 or n_neg == 0:\r\n",
    "        # Degenerate case: skip MMCE penalty\r\n",
    "        return torch.tensor(0.0, device=probs.device)\r\n",
    "\r\n",
    "    # Step 3: Define Laplacian kernel function\r\n",
    "    def kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\r\n",
    "        return torch.exp(-torch.abs(x.unsqueeze(1) - y.unsqueeze(0)) / sigma)\r\n",
    "\r\n",
    "    # Step 4: Compute terms\r\n",
    "    # --- Correct-positive term\r\n",
    "    K_pos = kernel(r_pos, r_pos)\r\n",
    "    pos_term = ((1 - r_pos).unsqueeze(1) @ (1 - r_pos).unsqueeze(0)) * K_pos\r\n",
    "    pos_term = pos_term.sum() / (n_pos ** 2)\r\n",
    "\r\n",
    "    # --- Incorrect-negative term\r\n",
    "    K_neg = kernel(r_neg, r_neg)\r\n",
    "    neg_term = (r_neg.unsqueeze(1) @ r_neg.unsqueeze(0)) * K_neg\r\n",
    "    neg_term = neg_term.sum() / (n_neg ** 2)\r\n",
    "\r\n",
    "    # --- Cross term (between correct and incorrect)\r\n",
    "    K_cross = kernel(r_pos, r_neg)\r\n",
    "    cross_term = ((1 - r_pos).unsqueeze(1) @ r_neg.unsqueeze(0)) * K_cross\r\n",
    "    cross_term = 2 * cross_term.sum() / (n_pos * n_neg)\r\n",
    "\r\n",
    "    # Step 5: Combine terms and return sqrt(MMCE^2)\r\n",
    "    mmce_sq = pos_term + neg_term - cross_term\r\n",
    "    return mmce_sq.sqrt()\r\n",
    "\r\n",
    "\r\n",
    "def plot_reliability_diagrams( # Renamed function slightly for clarity\r\n",
    "    model,\r\n",
    "    loaders: list,\r\n",
    "    loader_names: list,\r\n",
    "    num_classes: int,\r\n",
    "    n_bins: int = 20,\r\n",
    "    device: str = \"cuda\",\r\n",
    "    cmap_name: str = \"Blues\",\r\n",
    "):\r\n",
    "    num_loaders = len(loaders)\r\n",
    "    if loader_names is None:\r\n",
    "        loader_names = [f\"Loader {i+1}\" for i in range(num_loaders)]\r\n",
    "    elif len(loader_names) != num_loaders:\r\n",
    "        raise ValueError(\"Length of loader_names must match the length of loaders.\")\r\n",
    "\r\n",
    "    # Create subplots: 1 row, num_loaders columns\r\n",
    "    # Adjust figsize based on the number of loaders for better horizontal layout\r\n",
    "    # `squeeze=False` ensures axes is always a 2D array, even if num_loaders is 1\r\n",
    "    fig, axes = plt.subplots(1, num_loaders, figsize=(num_loaders * 6, 5), squeeze=False)\r\n",
    "\r\n",
    "    model.eval()\r\n",
    "    model.to(device)\r\n",
    "\r\n",
    "    # Variables to store cmap and norm for the shared colorbar later\r\n",
    "    last_cmap = None\r\n",
    "    last_norm = None\r\n",
    "\r\n",
    "    for i, val_loader in enumerate(loaders):\r\n",
    "        ax = axes[0, i] # Get the current subplot axis\r\n",
    "\r\n",
    "\r\n",
    "        # 1. Gather probabilities, predictions, labels (for this loader)\r\n",
    "        probs, preds, labels = [], [], []\r\n",
    "        with torch.no_grad():\r\n",
    "            for x, y in val_loader:\r\n",
    "                x, y = x.to(device), y.to(device)\r\n",
    "                out = model(x)\r\n",
    "                p = F.softmax(out, dim=1)\r\n",
    "                probs.append(p)\r\n",
    "                preds.append(p.argmax(dim=1))\r\n",
    "                labels.append(y)\r\n",
    "\r\n",
    "        # Check if loader returned any data\r\n",
    "        if not probs:\r\n",
    "             print(f\"Warning: Loader '{loader_names[i]}' provided no data. Skipping plot.\")\r\n",
    "             ax.set_title(f\"{loader_names[i]}\\n(No data)\")\r\n",
    "             ax.text(0.5, 0.5, 'No data', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)\r\n",
    "             continue # Skip to the next loader\r\n",
    "\r\n",
    "        probs   = torch.cat(probs)\r\n",
    "        preds   = torch.cat(preds)\r\n",
    "        labels  = torch.cat(labels)\r\n",
    "        conf, _ = probs.max(dim=1)      # model confidence\r\n",
    "\r\n",
    "        # 2. Bin statistics (for this loader)\r\n",
    "        bin_edges = torch.linspace(0.0, 1.0, n_bins + 1, device=device)\r\n",
    "        lowers, uppers = bin_edges[:-1], bin_edges[1:]\r\n",
    "\r\n",
    "        accs, bin_counts = [], []\r\n",
    "        conf_in_bins = [] # Store average confidence in each bin for ECE calculation if needed manually\r\n",
    "        for lo, up in zip(lowers, uppers):\r\n",
    "            mask  = (conf >= lo) & (conf < up)\r\n",
    "            count = mask.sum().item()\r\n",
    "            bin_counts.append(count)\r\n",
    "            if count > 0:\r\n",
    "                acc = (preds[mask] == labels[mask]).float().mean().item()\r\n",
    "                avg_conf = conf[mask].mean().item()\r\n",
    "                accs.append(acc)\r\n",
    "                conf_in_bins.append(avg_conf)\r\n",
    "            else:\r\n",
    "                accs.append(0.0)\r\n",
    "                conf_in_bins.append(0.0) # Assign 0 confidence if bin is empty\r\n",
    "\r\n",
    "        # 3. Expected Calibration Error (for this loader)\r\n",
    "        try:\r\n",
    "            ece = MulticlassCalibrationError(num_classes=num_classes, n_bins=n_bins, norm=\"l1\").to(device)(\r\n",
    "                probs, labels\r\n",
    "            ).item()\r\n",
    "        except Exception as e:\r\n",
    "            print(f\"Warning: Could not calculate ECE for {loader_names[i]} using torchmetrics. Error: {e}\")\r\n",
    "            total_samples = probs.shape[0]\r\n",
    "            ece_manual = 0.0\r\n",
    "            for k in range(n_bins):\r\n",
    "                 if bin_counts[k] > 0:\r\n",
    "                      ece_manual += (bin_counts[k] / total_samples) * abs(accs[k] - conf_in_bins[k])\r\n",
    "            ece = ece_manual\r\n",
    "\r\n",
    "\r\n",
    "        # 4. Colour‑map based on normalised bin counts (for this loader)\r\n",
    "        counts_np   = np.asarray(bin_counts, dtype=float)\r\n",
    "        norm_counts = counts_np / counts_np.max() if counts_np.max() > 0 else counts_np\r\n",
    "        cmap        = cm.get_cmap(cmap_name)\r\n",
    "        bar_colors  = cmap(norm_counts)\r\n",
    "        last_cmap   = cmap # Store for shared colorbar\r\n",
    "        last_norm   = plt.Normalize(vmin=0, vmax=1) # Store for shared colorbar\r\n",
    "\r\n",
    "        # 5. Plot\r\n",
    "        bin_centers = (lowers + uppers) / 2\r\n",
    "\r\n",
    "        ax.bar(\r\n",
    "            bin_centers.cpu().numpy(),\r\n",
    "            accs,\r\n",
    "            width=1 / n_bins,\r\n",
    "            edgecolor=\"black\",\r\n",
    "            align=\"center\",\r\n",
    "            color=bar_colors,\r\n",
    "        )\r\n",
    "        ax.plot([0, 1], [0, 1], \"k--\", label=\"Perfect calibration\")\r\n",
    "        ax.set_xlabel(\"Confidence\")\r\n",
    "        ax.set_ylabel(\"Accuracy\")\r\n",
    "        ax.set_title(f\"{loader_names[i]}\\nECE = {ece:.4f}\")\r\n",
    "        ax.grid(True, linestyle=\":\")\r\n",
    "        ax.legend()\r\n",
    "        ax.set_xlim(0, 1)\r\n",
    "        ax.set_ylim(0, 1)\r\n",
    "\r\n",
    "    # 6. Add a single colorbar for the whole figure\r\n",
    "    if last_cmap is not None and last_norm is not None:\r\n",
    "        sm = cm.ScalarMappable(cmap=last_cmap, norm=last_norm)\r\n",
    "        sm.set_array([])\r\n",
    "        fig.colorbar(sm, ax=axes.ravel().tolist(), label=\"Relative # samples per bin\", pad=0.02, aspect=30, shrink=0.7)\r\n",
    "    else:\r\n",
    "        print(\"Warning: No data plotted, cannot add colorbar.\")\r\n",
    "\r\n",
    "    fig.tight_layout(rect=[0, 0, 1, 0.96])\r\n",
    "    plt.suptitle(\"Reliability Diagrams\", fontsize=16, y=0.99)\r\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Print Info\r\n",
    "def print_info(model, loader, prefix = \"\"):\r\n",
    "    \"\"\"Evaluate `model` on `loader` and print accuracy and mean log-loss.\r\n",
    "\r\n",
    "    Inputs are moved to the global `device` (set in the Training cell) and the\r\n",
    "    model is left in eval mode afterwards. If the loader yields no samples a\r\n",
    "    short notice is printed instead of failing on an empty concatenation or a\r\n",
    "    division by zero.\r\n",
    "\r\n",
    "    Args:\r\n",
    "        model: classifier whose output holds the class scores along dim 1.\r\n",
    "        loader: iterable of (inputs, integer labels) batches.\r\n",
    "        prefix: tag printed in front of the metrics, e.g. \"Train\".\r\n",
    "    \"\"\"\r\n",
    "    model.eval()\r\n",
    "    n_correct, n_seen = 0, 0\r\n",
    "    all_logits, all_labels = [], []\r\n",
    "\r\n",
    "    with torch.no_grad():\r\n",
    "        for x, y in loader:\r\n",
    "            x, y = x.to(device), y.to(device)\r\n",
    "            logits = model(x)\r\n",
    "            preds = logits.argmax(dim=1)\r\n",
    "\r\n",
    "            n_correct += (preds == y).sum().item()\r\n",
    "            n_seen += len(x)\r\n",
    "\r\n",
    "            all_logits.append(logits)\r\n",
    "            all_labels.append(y)\r\n",
    "\r\n",
    "    if n_seen == 0:\r\n",
    "        print(f\"        {prefix} no samples in loader; skipped\")\r\n",
    "        return\r\n",
    "\r\n",
    "    all_logits = torch.cat(all_logits, dim=0)\r\n",
    "    all_labels = torch.cat(all_labels, dim=0)\r\n",
    "\r\n",
    "    accuracy = n_correct / n_seen\r\n",
    "    # Mean cross-entropy over the whole loader; softmax is applied inside\r\n",
    "    # F.cross_entropy, so no separate probability tensor is materialised.\r\n",
    "    log_loss = F.cross_entropy(all_logits, all_labels).item()\r\n",
    "\r\n",
    "    print(f\"        {prefix} Accuracy: {accuracy:.4f} | log_loss: {log_loss:.4f}\")\r\n",
    "\r\n",
    "import torch\r\n",
    "import torch.nn.functional as F\r\n",
    "import torchvision.models\r\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Training {\"form-width\":\"20%\"}\r\n",
    "\r\n",
    "def initialize_weights(model):\r\n",
    "    \"\"\"Re-initialise the weights of common layer types in `model` in place.\r\n",
    "\r\n",
    "    - nn.Conv2d: Kaiming-normal weight (fan_out, ReLU gain), zero bias.\r\n",
    "    - nn.BatchNorm2d / nn.GroupNorm / nn.LayerNorm: weight 1, bias 0.\r\n",
    "    - nn.Linear: N(0, 0.001) weight, zero bias.\r\n",
    "\r\n",
    "    Parameters that are None (bias=False layers, or norm layers built with\r\n",
    "    affine=False / elementwise_affine=False) are skipped instead of crashing.\r\n",
    "    \"\"\"\r\n",
    "    def _fill(param, value):\r\n",
    "        # Affine-free / bias-free layers expose None instead of a Parameter.\r\n",
    "        if param is not None:\r\n",
    "            nn.init.constant_(param, value)\r\n",
    "\r\n",
    "    for m in model.modules():\r\n",
    "        if isinstance(m, nn.Conv2d):\r\n",
    "            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\r\n",
    "            _fill(m.bias, 0.0)\r\n",
    "\r\n",
    "        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):\r\n",
    "            _fill(m.weight, 1.0)\r\n",
    "            _fill(m.bias, 0.0)\r\n",
    "\r\n",
    "        elif isinstance(m, nn.Linear):\r\n",
    "            nn.init.normal_(m.weight, mean=0, std=0.001)\r\n",
    "            _fill(m.bias, 0.0)\r\n",
    "\r\n",
    "def adapt_resnet_for_cifar(model: nn.Module):\r\n",
    "    \"\"\"\r\n",
    "    Modify a ResNet-style backbone to work better with CIFAR (32x32 input).\r\n",
    "    - Replace conv1 (e.g. 7x7 stride=2) with a 3x3 stride=1 conv that keeps the\r\n",
    "      original conv1's in/out channel counts, so stems narrower or wider than 64\r\n",
    "      still match the layer that follows.\r\n",
    "    - Remove maxpool (replaced by nn.Identity).\r\n",
    "    Only a plain nn.Conv2d `conv1` is swapped; other stem types are left as-is.\r\n",
    "    The model is modified in place and returned.\r\n",
    "    \"\"\"\r\n",
    "    stem = getattr(model, 'conv1', None)\r\n",
    "    if isinstance(stem, nn.Conv2d):\r\n",
    "        model.conv1 = nn.Conv2d(stem.in_channels, stem.out_channels,\r\n",
    "                                kernel_size=3, stride=1, padding=1, bias=False)\r\n",
    "    if hasattr(model, 'maxpool'):\r\n",
    "        model.maxpool = nn.Identity()\r\n",
    "    return model\r\n",
    "\r\n",
    "def adapt_model_for_num_classes(model, num_classes, device=None):\r\n",
    "    \"\"\"Make the model's output layer produce `num_classes` logits.\r\n",
    "\r\n",
    "    The first attribute among fc / output / linear / classifier that is an\r\n",
    "    nn.Linear, or an nn.Sequential containing an nn.Linear, is taken as the head.\r\n",
    "    For a Sequential only its LAST nn.Linear is considered, so hidden layers of\r\n",
    "    an MLP head are never resized. If that layer's out_features differs from\r\n",
    "    num_classes it is replaced by a fresh nn.Linear (weights ~ N(0, 0.001),\r\n",
    "    zero bias), moved to `device` when one is given. Returns the model\r\n",
    "    (modified in place).\r\n",
    "    \"\"\"\r\n",
    "    def _new_head(in_features):\r\n",
    "        head = nn.Linear(in_features, num_classes)\r\n",
    "        nn.init.normal_(head.weight, mean=0, std=0.001)\r\n",
    "        nn.init.zeros_(head.bias)\r\n",
    "        return head.to(device) if device else head\r\n",
    "\r\n",
    "    layer_candidates = ['fc', 'output', 'linear', 'classifier']\r\n",
    "    found, replaced = False, False\r\n",
    "\r\n",
    "    for name in layer_candidates:\r\n",
    "        layer = getattr(model, name, None)\r\n",
    "\r\n",
    "        if isinstance(layer, nn.Sequential):\r\n",
    "            linear_positions = [i for i, sub in enumerate(layer) if isinstance(sub, nn.Linear)]\r\n",
    "            if not linear_positions:\r\n",
    "                continue\r\n",
    "            last = linear_positions[-1]\r\n",
    "            if layer[last].out_features != num_classes:\r\n",
    "                layer[last] = _new_head(layer[last].in_features)\r\n",
    "                print(\"Replacement\")\r\n",
    "                replaced = True\r\n",
    "            found = True\r\n",
    "            break\r\n",
    "\r\n",
    "        if isinstance(layer, nn.Linear):\r\n",
    "            if layer.out_features != num_classes:\r\n",
    "                setattr(model, name, _new_head(layer.in_features))\r\n",
    "                print(\"Replacement\")\r\n",
    "                replaced = True\r\n",
    "            found = True\r\n",
    "            break\r\n",
    "\r\n",
    "    if found:\r\n",
    "        print(\r\n",
    "            f\"✅ Output layer {'replaced' if replaced else 'already matches'} \"\r\n",
    "            f\"num_classes = {num_classes}\"\r\n",
    "        )\r\n",
    "    else:\r\n",
    "        print(f\"⚠️ No output layer found among {layer_candidates}; model unchanged (num_classes = {num_classes})\")\r\n",
    "\r\n",
    "    return model\r\n",
    "\r\n",
    "def findword(text: str, word: str) -> str:\r\n",
    "    \"\"\"\r\n",
    "    Return the first underscore-separated segment of `text` that contains\r\n",
    "    `word` as a substring, or \"\" if no segment does.\r\n",
    "\r\n",
    "    Example: findword(\"largedecay_2xepochs\", \"epochs\") -> \"2xepochs\".\r\n",
    "    \"\"\"\r\n",
    "    segments = text.split('_')\r\n",
    "    return next((seg for seg in segments if word in seg), \"\")\r\n",
    "\r\n",
    "# Model and Dataset\r\n",
    "\r\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n",
    "model_name = \"resnet101\"  # @param {type:\"string\"}\r\n",
    "# \"mobilenetv3_small_100\" \"resnet10t\" \"resnet18\" \"resnet34\" \"resnet50\" \"resnetv2_50x1_bit\"\r\n",
    "dataset_name = \"cifar-100\"  # @param [\"tiny-imagenet-200\",\"cifar-10\",\"cifar-100\"]\r\n",
    "num_classes = {\"tiny-imagenet-200\": 200, \"cifar-10\": 10, \"cifar-100\": 100}[dataset_name]\r\n",
    "tmpadd = \"\" # @param\r\n",
    "sch_choice = \"cos\" #param [\"cos\",\"step\"]\r\n",
    "param_decay = {\"nodecay\":0.,\"smalldecay\":0.2,\"\":1.,\"largedecay\":5.,\"morelargedecay\":10.,\"extralargedecay\":25.}[findword(tmpadd,\"decay\")]\r\n",
    "param_epochs = {\"\":1, \"1xepochs\":1, \"2xepochs\":2, }[findword(tmpadd,\"epochs\")]\r\n",
    "regularizer_name = \"None\" # @param\r\n",
    "regularizer = {\"None\": None,\r\n",
    "               \"None2\": None,\r\n",
    "               \"MMCE\": mmce_weighted,\r\n",
    "               }[regularizer_name]\r\n",
    "experimental = \"\" # @param [\"\", \"_val\", \"_half\", \"_both\"]\r\n",
    "pretrained = True # @param {type:\"boolean\"}\r\n",
    "finetune_resize = True # @param {type:\"boolean\"} # resize 32*32 to 224*224, when using models rather than resnet, set True\r\n",
    "# -1 for latest, 0 for not use, x for epoch = x\r\n",
    "use_checkpoint_epoch = -1 # @param {type:\"integer\"}\r\n",
    "ignore_no_checkpoint = True\r\n",
    "pix = 32 if model_name == \"resnet32\" or model_name == \"resnet110\" else 224\r\n",
    "\r\n",
    "\r\n",
    "pretrained_version_index = 0 # @param {type:\"integer\"}\r\n",
    "# get_pretrained_cfg() returns None for architectures timm does not register\r\n",
    "# (e.g. the resnet32/resnet110 branches further down), so guard before `.tag`.\r\n",
    "model_config = timm.models.get_pretrained_cfg(model_name)\r\n",
    "default_version_name = model_config.tag if model_config is not None else None\r\n",
    "pretrained_versions_list = timm.list_models(f'{model_name}.*', pretrained=True)\r\n",
    "# list_models() returns tagged names (\"<arch>.<tag>\") whereas the cfg holds the\r\n",
    "# bare tag; qualify it so the default really ends up at index [0].\r\n",
    "if default_version_name and not default_version_name.startswith(f\"{model_name}.\"):\r\n",
    "    default_version_name = f\"{model_name}.{default_version_name}\"\r\n",
    "if default_version_name and default_version_name in pretrained_versions_list:\r\n",
    "    pretrained_versions_list.remove(default_version_name)\r\n",
    "    pretrained_versions_list.insert(0, default_version_name)\r\n",
    "print(\"✅ Available pretrained versions (default is at index [0]):\")\r\n",
    "for i, version in enumerate(pretrained_versions_list):\r\n",
    "    print(f\"[{i}] {version}\")\r\n",
    "selected_model_name = model_name # Default fallback\r\n",
    "if pretrained:\r\n",
    "    if 0 <= pretrained_version_index < len(pretrained_versions_list):\r\n",
    "        selected_model_name = pretrained_versions_list[pretrained_version_index]\r\n",
    "        print(f\"\\n✅ Using selected pretrained version: {selected_model_name}\")\r\n",
    "    else:\r\n",
    "        raise IndexError(f\"Index {pretrained_version_index} is out of bounds for the list of available versions.\")\r\n",
    "\r\n",
    "\r\n",
    "# Init Data\r\n",
    "\r\n",
    "# `transforms` is used below, before the Dataset Setting section imports it.\r\n",
    "from torchvision import transforms\r\n",
    "\r\n",
    "dataset_mean_std = {\r\n",
    "    \"cifar-10\": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\r\n",
    "    \"cifar-100\": ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\r\n",
    "}\r\n",
    "# ImageNet statistics: the normalisation of the resize/finetune pipeline, and the\r\n",
    "# fallback for datasets without their own entry above (e.g. tiny-imagenet-200).\r\n",
    "imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\r\n",
    "mean, std = dataset_mean_std.get(dataset_name, (imagenet_mean, imagenet_std))\r\n",
    "if not finetune_resize:\r\n",
    "    train_transform = transforms.Compose([\r\n",
    "        transforms.RandomCrop(32, padding=4),\r\n",
    "        transforms.RandomHorizontalFlip(),\r\n",
    "        transforms.ToTensor(),\r\n",
    "        transforms.Normalize(mean, std),\r\n",
    "    ])\r\n",
    "    val_transform = transforms.Compose([\r\n",
    "        transforms.ToTensor(),\r\n",
    "        transforms.Normalize(mean, std),\r\n",
    "    ])\r\n",
    "else:\r\n",
    "    # Train and eval share one normalisation (ImageNet statistics) so the model\r\n",
    "    # sees identically distributed inputs at evaluation time.\r\n",
    "    train_transform = transforms.Compose([\r\n",
    "        transforms.Resize(pix),\r\n",
    "        transforms.RandomCrop(pix, padding=4),\r\n",
    "        transforms.RandomHorizontalFlip(),\r\n",
    "        # transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),\r\n",
    "        transforms.ToTensor(),\r\n",
    "        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),\r\n",
    "    ])\r\n",
    "    val_transform = transforms.Compose([\r\n",
    "        transforms.Resize((pix, pix)),\r\n",
    "        transforms.ToTensor(),\r\n",
    "        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),\r\n",
    "    ])\r\n",
    "\r\n",
    "# Init model\r\n",
    "# resnet32 / resnet110 are not created through timm here; presumably defined in an earlier cell.\r\n",
    "if model_name == \"resnet32\":\r\n",
    "    model = resnet32().to(device)\r\n",
    "    model = adapt_model_for_num_classes(model, num_classes, device=device)\r\n",
    "elif model_name == \"resnet110\":\r\n",
    "    model = resnet110().to(device)\r\n",
    "    model = adapt_model_for_num_classes(model, num_classes, device=device)\r\n",
    "elif not finetune_resize:\r\n",
    "    # Native 32x32 input: timm builds the head for num_classes, then the stem is\r\n",
    "    # adapted. Only `.to(device)` is used so CPU-only runtimes work as well.\r\n",
    "    model = timm.create_model(selected_model_name, pretrained=pretrained, num_classes=num_classes)\r\n",
    "    model = adapt_resnet_for_cifar(model).to(device)\r\n",
    "else:\r\n",
    "    model = timm.create_model(selected_model_name, pretrained=pretrained).to(device)\r\n",
    "    model = adapt_model_for_num_classes(model, num_classes, device=device)\r\n",
    "\r\n",
    "\r\n",
    "\r\n",
    "\r\n",
    "\r\n",
    "# Hyperparams\r\n",
    "\r\n",
    "epochs = 300 if not pretrained else 50 * param_epochs\r\n",
    "batch_size = 256\r\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-1 if not pretrained else 1e-2, momentum=0.9, weight_decay=5e-4 * param_decay)\r\n",
    "if sch_choice == \"step\":\r\n",
    "    # 10x LR drops at 50% and 75% of training.\r\n",
    "    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[epochs // 2, epochs * 3 // 4], gamma=0.1)\r\n",
    "elif sch_choice == \"cos\":\r\n",
    "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)\r\n",
    "else:\r\n",
    "    raise ValueError(f\"Unknown sch_choice {sch_choice!r}; expected 'cos' or 'step'.\")\r\n",
    "\r\n",
    "# Dataset Setting\r\n",
    "\r\n",
    "from PIL import Image\r\n",
    "from torch.utils.data import Dataset, DataLoader\r\n",
    "from torchvision import datasets, transforms\r\n",
    "\r\n",
    "class CorrectedInMemoryDataset(Dataset):\r\n",
    "    \"\"\"\r\n",
    "    Image-folder dataset held entirely in memory.\r\n",
    "\r\n",
    "    torchvision's ImageFolder is used only to discover (path, label) pairs; every\r\n",
    "    image is decoded once at construction into an RGB PIL.Image, which is what the\r\n",
    "    transform pipeline expects, avoiding repeated disk reads and T->P->T conversion.\r\n",
    "    `classes` / `class_to_idx` mirror the attributes of the underlying ImageFolder.\r\n",
    "    \"\"\"\r\n",
    "    def __init__(self, root_dir, transform=None):\r\n",
    "        image_folder = datasets.ImageFolder(root_dir)\r\n",
    "        self.transform = transform\r\n",
    "        self.classes = image_folder.classes\r\n",
    "        self.class_to_idx = image_folder.class_to_idx\r\n",
    "\r\n",
    "        print(f\"✅ Pre-loading all images from {root_dir} into memory as PIL.Image objects...\")\r\n",
    "\r\n",
    "        self.images, self.labels = [], []\r\n",
    "        for path, label in image_folder.samples:\r\n",
    "            # The `with` block closes the file handle right away; convert() loads the\r\n",
    "            # pixels and returns a new image that stays valid after the file is closed.\r\n",
    "            with Image.open(path) as img:\r\n",
    "                self.images.append(img.convert('RGB'))\r\n",
    "            self.labels.append(label)\r\n",
    "        print(\"✅ Pre-loading complete.\")\r\n",
    "\r\n",
    "    def __len__(self):\r\n",
    "        return len(self.labels)\r\n",
    "\r\n",
    "    def __getitem__(self, idx):\r\n",
    "        image = self.images[idx]  # already-decoded PIL image\r\n",
    "        if self.transform:\r\n",
    "            image = self.transform(image)\r\n",
    "        return image, self.labels[idx]\r\n",
    "\r\n",
    "train_dataset = CorrectedInMemoryDataset(\r\n",
    "    os.path.join(dataset_dir, f\"{dataset_name}/train\"),\r\n",
    "    transform=train_transform\r\n",
    ")\r\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10, pin_memory=True)\r\n",
    "\r\n",
    "# Evaluation loaders keep a fixed order so repeated evaluations are reproducible.\r\n",
    "val_dataset = datasets.ImageFolder(os.path.join(dataset_dir, f\"{dataset_name}/val\"), transform=val_transform)\r\n",
    "val_loader = DataLoader(val_dataset, batch_size=250, shuffle=False, num_workers=8)\r\n",
    "\r\n",
    "test_dataset = datasets.ImageFolder(os.path.join(dataset_dir, f\"{dataset_name}/test\"), transform=val_transform)\r\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)\r\n",
    "\r\n",
    "this_checkpoint_dir = os.path.join(checkpoint_dir, dataset_name)\r\n",
    "model_suffix = ''\r\n",
    "if pretrained: model_suffix = 'PT' + ('' if finetune_resize else 'T') + ('' if pretrained_version_index == 0 else f'{pretrained_version_index}')\r\n",
    "this_checkpoint_dir = os.path.join(this_checkpoint_dir, model_name + model_suffix + \"_\" + sch_choice + \"_\" + regularizer_name)\r\n",
    "this_checkpoint_dir = os.path.join(this_checkpoint_dir, f\"variant{experimental + tmpadd}\")\r\n",
    "\r\n",
    "os.makedirs(this_checkpoint_dir, exist_ok=True)\r\n",
    "print(f\"✅ Checkpoint will be saved at {this_checkpoint_dir}.\")\r\n",
    "lam = 1.0\r\n",
    "if use_checkpoint_epoch == -1:\r\n",
    "    # \"-1 for latest\": resume from the highest saved epoch not exceeding `epochs`.\r\n",
    "    # With no such file, fall back to `epochs` so the not-found handling below applies.\r\n",
    "    saved_epochs = []\r\n",
    "    for fname in os.listdir(this_checkpoint_dir):\r\n",
    "        if fname.startswith(\"model_epoch_\") and fname.endswith(\".pth\"):\r\n",
    "            stem = fname[len(\"model_epoch_\"):-len(\".pth\")]\r\n",
    "            if stem.isdecimal() and int(stem) <= epochs:\r\n",
    "                saved_epochs.append(int(stem))\r\n",
    "    use_checkpoint_epoch = max(saved_epochs, default=epochs)\r\n",
    "if use_checkpoint_epoch > 0:\r\n",
    "    checkpoint_path = os.path.join(this_checkpoint_dir, f\"model_epoch_{use_checkpoint_epoch}.pth\")\r\n",
    "    if not os.path.exists(checkpoint_path):\r\n",
    "        print(f\"❌ Model weights checkpoint not found: {checkpoint_path}\")\r\n",
    "        if not ignore_no_checkpoint:\r\n",
    "            raise Exception(f\"Model weights checkpoint not found: {checkpoint_path}\")\r\n",
    "        use_checkpoint_epoch = 0\r\n",
    "    else:\r\n",
    "        print(f\"🔄 Loading model weights from checkpoint: {checkpoint_path}\")\r\n",
    "        # weights_only=False unpickles arbitrary objects: only load checkpoints you wrote.\r\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\r\n",
    "\r\n",
    "        model.load_state_dict(checkpoint['model_state_dict'])\r\n",
    "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\r\n",
    "        if checkpoint.get('scheduler_state_dict') is not None:\r\n",
    "            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\r\n",
    "\r\n",
    "        # Tolerate checkpoints saved without a 'lambda' entry; keep the default then.\r\n",
    "        lam = checkpoint.get('lambda', lam)\r\n",
    "\r\n",
    "        print(f\"✅ Checkpoint loaded successfully from epoch {checkpoint['epoch']}.\")\r\n",
    "\r\n",
    "scaler = torch.cuda.amp.GradScaler()\r\n",
    "\r\n",
    "# Debug switch: when True no training happens and the model loaded above is only\r\n",
    "# evaluated (epochs is clamped to the loaded checkpoint epoch). Set False to train.\r\n",
    "not_train_debug = True\r\n",
    "if not_train_debug: epochs = use_checkpoint_epoch\r\n",
    "\r\n",
    "# Periodic full diagnostics (train/val/test metrics + reliability diagrams) every\r\n",
    "# `eval_every` epochs; 0 disables them. The final epoch is always evaluated.\r\n",
    "eval_every = 0\r\n",
    "save_every = 1  # checkpoint frequency in epochs\r\n",
    "\r\n",
    "def _cycle_loader(loader):\r\n",
    "    \"\"\"Yield batches from `loader` forever, re-iterating it on each pass.\r\n",
    "\r\n",
    "    Unlike itertools.cycle this neither caches every batch in memory nor starts\r\n",
    "    the loader (and its workers) until the first batch is requested. It stops\r\n",
    "    (next() raises StopIteration) if the loader yields nothing.\r\n",
    "    \"\"\"\r\n",
    "    while True:\r\n",
    "        produced = False\r\n",
    "        for batch in loader:\r\n",
    "            produced = True\r\n",
    "            yield batch\r\n",
    "        if not produced:\r\n",
    "            return\r\n",
    "\r\n",
    "for epoch in range(use_checkpoint_epoch + 1, epochs + 1):\r\n",
    "    epoch_start_time = time.time()\r\n",
    "    computation_time = 0.\r\n",
    "    model.train()\r\n",
    "    sum_train_loss, sum_train_reg = 0., 0.\r\n",
    "    total_confidence = 0.0\r\n",
    "    total_samples = 0\r\n",
    "    # Source of regularisation batches for the \"_val\" / \"_both\" modes (lazy).\r\n",
    "    reg_iter = _cycle_loader(val_loader)\r\n",
    "    seed = epoch\r\n",
    "\r\n",
    "    for (x, y) in train_loader:\r\n",
    "        computation_start_time = time.time()\r\n",
    "        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\r\n",
    "        optimizer.zero_grad()\r\n",
    "        with torch.cuda.amp.autocast():\r\n",
    "            logits = model(x)\r\n",
    "            probs = F.softmax(logits, dim=1)\r\n",
    "            confs, _ = probs.max(dim=1)\r\n",
    "            total_confidence += confs.sum().item()\r\n",
    "            total_samples += x.size(0)\r\n",
    "\r\n",
    "            if regularizer is None:\r\n",
    "                train_loss = F.cross_entropy(logits, y)\r\n",
    "                train_reg = torch.tensor(0.0, device=device)\r\n",
    "            elif experimental == \"\":\r\n",
    "                train_loss = F.cross_entropy(logits, y)\r\n",
    "                train_reg = regularizer(logits, y, seed)\r\n",
    "            elif experimental == \"_half\":\r\n",
    "                mid_point = x.size(0) // 2\r\n",
    "                logits1, logits2 = logits[:mid_point], logits[mid_point:]\r\n",
    "                y1, y2 = y[:mid_point], y[mid_point:]\r\n",
    "                train_loss = F.cross_entropy(logits1, y1)\r\n",
    "                train_reg = regularizer(logits2, y2, seed)\r\n",
    "            elif experimental == \"_val\":\r\n",
    "                reg_x, reg_y = next(reg_iter)\r\n",
    "                reg_x, reg_y = reg_x.to(device, non_blocking=True), reg_y.to(device, non_blocking=True)\r\n",
    "                reg_logits = model(reg_x)\r\n",
    "                train_loss = F.cross_entropy(logits, y)\r\n",
    "                train_reg = regularizer(reg_logits, reg_y, seed)\r\n",
    "            elif experimental == \"_both\":\r\n",
    "                reg_x, reg_y = next(reg_iter)\r\n",
    "                reg_x, reg_y = reg_x.to(device, non_blocking=True), reg_y.to(device, non_blocking=True)\r\n",
    "                reg_logits = model(reg_x)\r\n",
    "                combined_logits = torch.cat((logits, reg_logits), dim=0)\r\n",
    "                combined_y = torch.cat((y, reg_y), dim=0)\r\n",
    "                train_loss = F.cross_entropy(logits, y)\r\n",
    "                train_reg = regularizer(combined_logits, combined_y, seed)\r\n",
    "            else:\r\n",
    "                raise ValueError(f\"Unknown experimental mode {experimental!r}\")\r\n",
    "\r\n",
    "            reg_scale = min(epoch * 2.5, batch_size * 0.25) if \"atb\" in regularizer_name.lower() else 1.\r\n",
    "            loss = train_loss + reg_scale * train_reg\r\n",
    "\r\n",
    "        scaler.scale(loss).backward()\r\n",
    "        # unscale_ is only required by the (currently disabled) gradient clipping below.\r\n",
    "        scaler.unscale_(optimizer)\r\n",
    "        # nn_utils.clip_grad_norm_(model.parameters(), 1.0)\r\n",
    "        scaler.step(optimizer)\r\n",
    "        scaler.update()\r\n",
    "        computation_time += time.time() - computation_start_time\r\n",
    "\r\n",
    "        sum_train_loss += train_loss.detach() * x.size(0)\r\n",
    "        sum_train_reg += train_reg.detach() * x.size(0)\r\n",
    "\r\n",
    "    scheduler.step(epoch)\r\n",
    "    print(f\"[Epoch {epoch}] | Train Loss: {(sum_train_loss / total_samples):.6f} | \"\r\n",
    "          f\"Train Reg: {(sum_train_reg / total_samples):.6f} | \"\r\n",
    "          f\"Avg Conf: {total_confidence / total_samples:.4f} | \"\r\n",
    "          f\"Full T: {time.time() - epoch_start_time:.2f}s | \"\r\n",
    "          f\"Compute T: {computation_time:.2f}s\")\r\n",
    "\r\n",
    "    model.eval()\r\n",
    "    with torch.no_grad():\r\n",
    "        if eval_every and epoch % eval_every == 0 and epoch != epochs:\r\n",
    "            print_info(model, train_loader, \"Train\")\r\n",
    "            print_info(model, val_loader, \"Val  \")\r\n",
    "            print_info(model, test_loader, \"Test \")\r\n",
    "            plot_reliability_diagrams(model, [train_loader, val_loader, test_loader], [\"Train\", \"Val\", \"Test\"], num_classes=num_classes, n_bins=20, device=device)\r\n",
    "\r\n",
    "        if save_every and epoch % save_every == 0:\r\n",
    "            checkpoint_path = os.path.join(this_checkpoint_dir, f\"model_epoch_{epoch}.pth\")\r\n",
    "            torch.save({\r\n",
    "                'epoch': epoch,\r\n",
    "                'model_state_dict': model.state_dict(),\r\n",
    "                'optimizer_state_dict': optimizer.state_dict(),\r\n",
    "                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,\r\n",
    "                'lambda': lam\r\n",
    "            }, checkpoint_path)\r\n",
    "            print(f\"💾 Saved checkpoint: {checkpoint_path}\")\r\n",
    "\r\n",
    "        if epoch == epochs:\r\n",
    "            print_info(model, train_loader, \"Train\")\r\n",
    "            print_info(model, val_loader, \"Val  \")\r\n",
    "            print_info(model, test_loader, \"Test \")\r\n",
    "            plot_reliability_diagrams(model, [train_loader, val_loader, test_loader], [\"Train\", \"Val\", \"Test\"], num_classes=num_classes, n_bins=20, device=device)\r\n",
    "\r\n",
    "\r\n",
    "\r\n",
    "# Only write model.pth when weights were trained or loaded: with not_train_debug\r\n",
    "# and no checkpoint found, `epochs` is 0 and the model was never fitted.\r\n",
    "if epochs > 0:\r\n",
    "    checkpoint_path = os.path.join(this_checkpoint_dir, \"model.pth\")\r\n",
    "    torch.save(model.state_dict(), checkpoint_path)\r\n",
    "    print(\"✅ Model saved\")\r\n",
    "else:\r\n",
    "    print(\"⚠️ No epochs trained or loaded; model.pth not written\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every split, then draw the reliability diagrams for all three.\r\n",
    "for split_loader, split_tag in ((train_loader, \"Train\"), (val_loader, \"Val  \"), (test_loader, \"Test \")):\r\n",
    "    print_info(model, split_loader, split_tag)\r\n",
    "plot_reliability_diagrams(\r\n",
    "    model,\r\n",
    "    [train_loader, val_loader, test_loader],\r\n",
    "    [\"Train\", \"Val\", \"Test\"],\r\n",
    "    num_classes=num_classes,\r\n",
    "    n_bins=20,\r\n",
    "    device=device,\r\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TemperatureScaler(nn.Module):\r\n",
    "    \"\"\"Temperature scaling around a trained classifier: forward returns logits / T.\r\n",
    "\r\n",
    "    `set_temperature` fits the scalar T on a held-out loader by minimising\r\n",
    "    `loss_fn` (nn.CrossEntropyLoss by default) with L-BFGS from several\r\n",
    "    starting values, keeping the best one.\r\n",
    "    \"\"\"\r\n",
    "\r\n",
    "    # Floor for T, used both inside the optimisation and for the stored value, so\r\n",
    "    # the temperature kept is exactly the one whose loss was measured.\r\n",
    "    MIN_TEMPERATURE = 1e-3\r\n",
    "\r\n",
    "    def __init__(self, model, loss_fn=None):\r\n",
    "        super().__init__()\r\n",
    "        self.model = model\r\n",
    "        self.temperature = None  # fitted by set_temperature()\r\n",
    "        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()\r\n",
    "\r\n",
    "    def forward(self, x):\r\n",
    "        if self.temperature is None:\r\n",
    "            raise RuntimeError(\"TemperatureScaler: call set_temperature() before forward().\")\r\n",
    "        logits = self.model(x)\r\n",
    "        return logits / self.temperature\r\n",
    "\r\n",
    "    def set_temperature(self, loader, device, guess = None):\r\n",
    "        \"\"\"Fit the temperature on `loader` and store it as `self.temperature`.\r\n",
    "\r\n",
    "        Args:\r\n",
    "            loader: iterable of (inputs, labels) batches, e.g. the validation set.\r\n",
    "            device: device the inputs are moved to and the fitted parameter lives on.\r\n",
    "            guess: optional prior temperature; when given only guess * {0.75, 1, 1.33}\r\n",
    "                are used as starting points instead of the default grid.\r\n",
    "        \"\"\"\r\n",
    "        self.eval()\r\n",
    "        self.model.eval()\r\n",
    "\r\n",
    "        logits_list, labels_list = [], []\r\n",
    "        with torch.no_grad():\r\n",
    "            for x, y in loader:\r\n",
    "                x, y = x.to(device), y.to(device)\r\n",
    "                logits_list.append(self.model(x))\r\n",
    "                labels_list.append(y)\r\n",
    "\r\n",
    "        logits = torch.cat(logits_list)\r\n",
    "        labels = torch.cat(labels_list)\r\n",
    "\r\n",
    "        initial_temps = [0.3, 0.55, 0.75, 1., 1.25, 1.5, 1.75, 2., 2.5, 3., 4.] if guess is None else [guess * 0.75, guess, guess * 1.33]\r\n",
    "\r\n",
    "        # Unscaled loss used to normalise the objective; it does not depend on the\r\n",
    "        # start value, so compute it once. Floored to avoid dividing by zero.\r\n",
    "        base = self.loss_fn(logits, labels).clamp_min(1e-12)\r\n",
    "\r\n",
    "        best_loss = float('inf')\r\n",
    "        best_temp = 1.0\r\n",
    "\r\n",
    "        for start_temp in initial_temps:\r\n",
    "            temp_c = torch.nn.Parameter(torch.tensor([start_temp], device=logits.device))\r\n",
    "            optimizer = torch.optim.LBFGS([temp_c], lr=0.05, max_iter=500)\r\n",
    "\r\n",
    "            def eval_loss():\r\n",
    "                optimizer.zero_grad()\r\n",
    "                scaled_logits = logits / torch.clamp(temp_c, min=self.MIN_TEMPERATURE)\r\n",
    "                loss = self.loss_fn(scaled_logits, labels) / base\r\n",
    "                loss.backward()\r\n",
    "                return loss\r\n",
    "\r\n",
    "            optimizer.step(eval_loss)\r\n",
    "\r\n",
    "            final_loss = eval_loss().item()\r\n",
    "            # The raw parameter can drift below the floor (where the clamp has zero\r\n",
    "            # gradient); keep the temperature that actually produced final_loss,\r\n",
    "            # never a tiny or negative one that would flip the logits.\r\n",
    "            final_temp = max(temp_c.item(), self.MIN_TEMPERATURE)\r\n",
    "\r\n",
    "            print(f\"  -> Finished. Final Temp: {final_temp:.4f}, Final Loss: {final_loss:.4f}\")\r\n",
    "\r\n",
    "            if final_loss < best_loss:\r\n",
    "                best_loss = final_loss\r\n",
    "                best_temp = final_temp\r\n",
    "                print(f\"  ** New best temperature found: {best_temp:.4f}\")\r\n",
    "\r\n",
    "        self.temperature = nn.Parameter(torch.tensor([best_temp], device=device))\r\n",
    "        print(f\"✅ Optimal temperature: {self.temperature.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.gridspec as gridspec\r\n",
    "import matplotlib.lines as mlines\r\n",
    "import matplotlib.ticker as ticker\r\n",
    "import statsmodels.api as sm\r\n",
    "import re\r\n",
    "\r\n",
    "def plot_multiple_scatter(outs, title, prefix=0, large=False, s=10, plot_key_setting=None, font_scale=2.0, x_axis_option = \"log_loss\", paint_dom = True):\r\n",
    "    \"\"\"\r\n",
    "    Draws multiple scatter plots using GridSpec for a custom layout.\r\n",
    "\r\n",
    "    This function creates a dedicated subplot at the top for a horizontal colorbar\r\n",
    "    and displays a main title for the entire figure.\r\n",
    "\r\n",
    "    Args:\r\n",
    "        mp (dict): A dict mapping feature names (including log_loss) to lists of y-values.\r\n",
    "        title (str): The main title for the entire figure.\r\n",
    "        prefix (int): Starting index offset for plotting.\r\n",
    "        large (bool): If True, use a single-column layout for larger plots.\r\n",
    "        s (int): The size of the scatter plot points.\r\n",
    "        font_scale (float): Factor to scale font sizes.\r\n",
    "    \"\"\"\r\n",
    "    colbar = True\r\n",
    "    rename_yaxis = None\r\n",
    "    rename_xaxis = None\r\n",
    "    rename_title = None\r\n",
    "    if plot_key_setting is None:\r\n",
    "        num_plots = 5 * 4 + 2\r\n",
    "        cols = 4\r\n",
    "        scatter_rows = 6\r\n",
    "        plot_keys = {\"log_loss\": (0,0),\r\n",
    "                \"classification_error\": (0,1),\r\n",
    "                \"squared_loss\": (0,2),\r\n",
    "                \"spherical_loss\": (0,3),\r\n",
    "\r\n",
    "                \"l1ecequantile_confidence_2000\": (1,0),\r\n",
    "                \"l1ecequantile_confidence_20\": (1,1),\r\n",
    "                \"l1ecequantile_classwise_2000\": (1,2),\r\n",
    "                \"l1ecequantile_classwise_20\": (1,3),\r\n",
    "                \"l2ecequantile_confidence_2000\": (2,0),\r\n",
    "                \"l2ecequantile_confidence_20\": (2,1),\r\n",
    "                \"l2ecequantile_classwise_2000\": (2,2),\r\n",
    "                \"l2ecequantile_classwise_20\": (2,3),\r\n",
    "\r\n",
    "                \"l1ecefix_confidence_2000\": (3,0),\r\n",
    "                \"l1ecefix_confidence_20\": (3,1),\r\n",
    "                \"l1ecefix_classwise_2000\": (3,2),\r\n",
    "                \"l1ecefix_classwise_20\": (3,3),\r\n",
    "                \"l2ecefix_confidence_2000\": (4,0),\r\n",
    "                \"l2ecefix_confidence_20\": (4,1),\r\n",
    "                \"l2ecefix_classwise_2000\": (4,2),\r\n",
    "                \"l2ecefix_classwise_20\": (4,3),\r\n",
    "\r\n",
    "                \"atb_confidence\": (5,0),\r\n",
    "                \"legend\": (5,1),\r\n",
    "                \"atb_classwise_full\": (5,2),\r\n",
    "\r\n",
    "                }\r\n",
    "    elif plot_key_setting == \"l1_subset\":\r\n",
    "        cols = 6\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l1ecequantile_confidence_2000\": (0,0),\r\n",
    "            \"l1ecequantile_confidence_20\": (0,1),\r\n",
    "            \"l1ecequantile_classwise_2000\": (0,2),\r\n",
    "            \"l1ecequantile_classwise_20\": (0,3),\r\n",
    "            \"atb_confidence\": (0,4),\r\n",
    "            \"atb_classwise_full\": (0,5),\r\n",
    "        }\r\n",
    "    elif plot_key_setting == \"l2_subset\":\r\n",
    "        cols = 6\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l2ecequantile_confidence_2000\": (0,0),\r\n",
    "            \"l2ecequantile_confidence_20\": (0,1),\r\n",
    "            \"l2ecequantile_classwise_2000\": (0,2),\r\n",
    "            \"l2ecequantile_classwise_20\": (0,3),\r\n",
    "            \"atb_confidence\": (0,4),\r\n",
    "            \"atb_classwise_full\": (0,5),\r\n",
    "        }\r\n",
    "    elif plot_key_setting == \"paper_loss\":\r\n",
    "        cols = 3\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"classification_error\": (0,0),\r\n",
    "            \"squared_loss\": (0,1),\r\n",
    "            \"spherical_loss\": (0,2),\r\n",
    "        }\r\n",
    "        rename_title = {\r\n",
    "            \"classification_error\": \"Classification Error\",\r\n",
    "            \"squared_loss\": \"Squared Loss\",\r\n",
    "            \"spherical_loss\": \"Spherical Loss\",\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_yaxis = {\r\n",
    "            \"classification_error\": \"Classification Error\",\r\n",
    "            \"squared_loss\": \"Squared Loss\",\r\n",
    "            \"spherical_loss\": \"Spherical Loss\",\r\n",
    "        }\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "    elif plot_key_setting == \"paper_conf\":\r\n",
    "        cols = 4\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l1ecequantile_confidence_2000\": (0,0),\r\n",
    "            \"l1ecequantile_confidence_20\": (0,1),\r\n",
    "            \"l2ecequantile_confidence_2000\": (0,2),\r\n",
    "            \"l2ecequantile_confidence_20\": (0,3),\r\n",
    "        }\r\n",
    "        rename_title = {\r\n",
    "            \"l1ecequantile_confidence_2000\": r\"$\\ell_1\\mathrm{-}\\text{QECE}_{2000}^{\\mathsf{conf}}$\",\r\n",
    "            \"l1ecequantile_confidence_20\": r\"$\\ell_1\\mathrm{-}\\text{QECE}_{20}^{\\mathsf{conf}}$\",\r\n",
    "            \"l2ecequantile_confidence_2000\": r\"$\\ell_2\\mathrm{-}\\text{QECE}_{2000}^{\\mathsf{conf}}$\",\r\n",
    "            \"l2ecequantile_confidence_20\": r\"$\\ell_2\\mathrm{-}\\text{QECE}_{20}^{\\mathsf{conf}}$\",\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_yaxis = \"Calibration Error\"\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "    elif plot_key_setting == \"paper_class\":\r\n",
    "        cols = 4\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l1ecequantile_classwise_2000\": (0,0),\r\n",
    "            \"l1ecequantile_classwise_20\": (0,1),\r\n",
    "            \"l2ecequantile_classwise_2000\": (0,2),\r\n",
    "            \"l2ecequantile_classwise_20\": (0,3),\r\n",
    "        }\r\n",
    "        rename_title = {\r\n",
    "            \"l1ecequantile_classwise_2000\": r\"$\\ell_1\\mathrm{-}\\text{QECE}_{2000}^{\\mathsf{classwise}}$\",\r\n",
    "            \"l1ecequantile_classwise_20\": r\"$\\ell_1\\mathrm{-}\\text{QECE}_{20}^{\\mathsf{classwise}}$\",\r\n",
    "            \"l2ecequantile_classwise_2000\": r\"$\\ell_2\\mathrm{-}\\text{QECE}_{2000}^{\\mathsf{classwise}}$\",\r\n",
    "            \"l2ecequantile_classwise_20\": r\"$\\ell_2\\mathrm{-}\\text{QECE}_{20}^{\\mathsf{classwise}}$\",\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_yaxis = \"Calibration Error\"\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "    elif plot_key_setting == \"appendix_conf\":\r\n",
    "        cols = 4\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l1ecefix_confidence_2000\": (0,0),\r\n",
    "            \"l1ecefix_confidence_20\": (0,1),\r\n",
    "            \"l2ecefix_confidence_2000\": (0,2),\r\n",
    "            \"l2ecefix_confidence_20\": (0,3),\r\n",
    "        }\r\n",
    "        rename_title = {\r\n",
    "            \"l1ecefix_confidence_2000\": r\"$\\ell_1\\mathrm{-}\\text{ECE}_{2000}^{\\mathsf{conf}}$\",\r\n",
    "            \"l1ecefix_confidence_20\": r\"$\\ell_1\\mathrm{-}\\text{ECE}_{20}^{\\mathsf{conf}}$\",\r\n",
    "            \"l2ecefix_confidence_2000\": r\"$\\ell_2\\mathrm{-}\\text{ECE}_{2000}^{\\mathsf{conf}}$\",\r\n",
    "            \"l2ecefix_confidence_20\": r\"$\\ell_2\\mathrm{-}\\text{ECE}_{20}^{\\mathsf{conf}}$\",\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_yaxis = \"Calibration Error\"\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "    elif plot_key_setting == \"appendix_class\":\r\n",
    "        cols = 4\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l1ecefix_classwise_2000\": (0,0),\r\n",
    "            \"l1ecefix_classwise_20\": (0,1),\r\n",
    "            \"l2ecefix_classwise_2000\": (0,2),\r\n",
    "            \"l2ecefix_classwise_20\": (0,3),\r\n",
    "        }\r\n",
    "        rename_title = {\r\n",
    "            \"l1ecefix_classwise_2000\": r\"$\\ell_1\\mathrm{-}\\text{ECE}_{2000}^{\\mathsf{classwise}}$\",\r\n",
    "            \"l1ecefix_classwise_20\": r\"$\\ell_1\\mathrm{-}\\text{ECE}_{20}^{\\mathsf{classwise}}$\",\r\n",
    "            \"l2ecefix_classwise_2000\": r\"$\\ell_2\\mathrm{-}\\text{ECE}_{2000}^{\\mathsf{classwise}}$\",\r\n",
    "            \"l2ecefix_classwise_20\": r\"$\\ell_2\\mathrm{-}\\text{ECE}_{20}^{\\mathsf{classwise}}$\",\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_yaxis = \"Calibration Error\"\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "    elif plot_key_setting == \"l1ecequantile_confidence\":\r\n",
    "        cols = 2\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l1ecequantile_confidence_2000\": (0,0),\r\n",
    "            \"l1ecequantile_confidence_20\": (0,1),\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_title = {\r\n",
    "            \"l1ecequantile_confidence_2000\": \"2000 Bins\",\r\n",
    "            \"l1ecequantile_confidence_20\": \"20 Bins\",\r\n",
    "        }\r\n",
    "        rename_yaxis = \"Calibration Error\"\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "    elif plot_key_setting == \"l2ecequantile_classwise\":\r\n",
    "        cols = 2\r\n",
    "        scatter_rows = 1\r\n",
    "        plot_keys = {\r\n",
    "            \"l2ecequantile_classwise_2000\": (0,0),\r\n",
    "            \"l2ecequantile_classwise_20\": (0,1),\r\n",
    "        }\r\n",
    "        colbar = False\r\n",
    "        rename_title = {\r\n",
    "            \"l2ecequantile_classwise_2000\": \"2000 Bins\",\r\n",
    "            \"l2ecequantile_classwise_20\": \"20 Bins\",\r\n",
    "        }\r\n",
    "        rename_yaxis = \"Calibration Error\"\r\n",
    "        rename_xaxis = \"Log Loss\"\r\n",
    "\r\n",
    "\r\n",
    "    # Determine the grid shape for the main scatter plots\r\n",
    "\r\n",
    "    scale = 10 if large else 5\r\n",
    "    fig = plt.figure(figsize=(scale *(1.06 if plot_key_setting != \"paper_loss\" else 1.12)* cols, scale *0.82* scatter_rows + (2 if colbar else 0))) # Adjust figure size\r\n",
    "\r\n",
    "\r\n",
    "    # Add a main title for the entire figure\r\n",
    "    if plot_key_setting is None:\r\n",
    "        fig.suptitle(title, fontsize=16 * font_scale, y=0.98) # Scale suptitle font size\r\n",
    "\r\n",
    "    # Create a GridSpec layout. One extra row for the colorbar.\r\n",
    "    # The top row for the colorbar will be much shorter than the plot rows.\r\n",
    "    gs = gridspec.GridSpec(\r\n",
    "        scatter_rows + (1 if colbar else 0),\r\n",
    "        cols,\r\n",
    "        height_ratios=([0.1] if colbar else []) + [1] * scatter_rows,\r\n",
    "        hspace=0.4\r\n",
    "    )\r\n",
    "\r\n",
    "    cax = fig.add_subplot(gs[0, :]) if colbar else None\r\n",
    "    axes_mp = {key: fig.add_subplot(gs[a + (1 if colbar else 0), b]) for key, (a, b) in plot_keys.items()}\r\n",
    "\r\n",
    "\r\n",
    "    # params\r\n",
    "\r\n",
    "    do_lowess = False\r\n",
    "    dom_loss_keys = [\r\n",
    "        \"classification_error\",\r\n",
    "        \"squared_loss\",\r\n",
    "        \"spherical_loss\",\r\n",
    "        \"log_loss\"\r\n",
    "    ]\r\n",
    "\r\n",
    "    # PART 1: GLOBAL LIS & LOWESS PRE-COMPUTATION\r\n",
    "    global_dom_points_set = set()\r\n",
    "    dominant_coords = {name: {'x': [], 'y': []} for name in plot_keys}\r\n",
    "    smoothed_curves = {}\r\n",
    "    markers = ['o', '+', '*', 'x', '^', 'v', '<', '>', 'd', 's', 'p', 'h', 'H', 'D', '|', '_'][:len(outs)]\r\n",
    "\r\n",
    "\r\n",
    "    if paint_dom:\r\n",
    "        all_points_for_lis = []\r\n",
    "        for marker_key, mp in outs.items():\r\n",
    "            losses = [np.array(mp.get(key, [])[prefix:]) for key in dom_loss_keys]\r\n",
    "\r\n",
    "            # Ensure data is valid for this marker group before adding its points\r\n",
    "            if not (len(losses[0]) > 0 and all(len(l) == len(losses[0]) for l in losses)):\r\n",
    "                print(f\"Warning: Skipping marker group '{marker_key}' in LIS calculation due to invalid data.\")\r\n",
    "                continue\r\n",
    "\r\n",
    "            num_points_in_mp = len(losses[0])\r\n",
    "            for i in range(num_points_in_mp):\r\n",
    "                # For each point, store its loss values, its source marker_key, and its index within that source\r\n",
    "                point_losses = tuple(l[i] for l in losses)\r\n",
    "                all_points_for_lis.append((point_losses, marker_key, i))\r\n",
    "\r\n",
    "        if all_points_for_lis:\r\n",
    "            num_losses = len(dom_loss_keys)\r\n",
    "            # Sort all points globally by the first loss metric\r\n",
    "            points = sorted(all_points_for_lis, key=lambda x: x[0])\r\n",
    "\r\n",
    "            n = len(points)\r\n",
    "            dp = [1] * n\r\n",
    "            predecessor = [-1] * n\r\n",
    "            for i in range(n):\r\n",
    "                for j in range(i):\r\n",
    "                    is_dominated = all(points[j][0][k] < points[i][0][k] for k in range(1, num_losses))\r\n",
    "                    if is_dominated and dp[j] + 1 > dp[i]:\r\n",
    "                        dp[i] = dp[j] + 1\r\n",
    "                        predecessor[i] = j\r\n",
    "\r\n",
    "\r\n",
    "            if dp:\r\n",
    "                max_len = max(dp)\r\n",
    "                end_index = dp.index(max_len)\r\n",
    "                current_index = end_index\r\n",
    "                while current_index != -1:\r\n",
    "                    # points[i][1] is marker_key, points[i][2] is index_in_mp\r\n",
    "                    global_dom_points_set.add((points[current_index][1], points[current_index][2]))\r\n",
    "                    current_index = predecessor[current_index]\r\n",
    "\r\n",
    "        if do_lowess:\r\n",
    "            for marker_key, idx in global_dom_points_set:\r\n",
    "                mp = outs[marker_key]\r\n",
    "                x_val = mp[x_axis_option][prefix:][idx]\r\n",
    "                for plot_name in plot_keys:\r\n",
    "                    if plot_name in mp:\r\n",
    "                        y_val = mp[plot_name][prefix:][idx]\r\n",
    "                        dominant_coords[plot_name]['x'].append(x_val)\r\n",
    "                        dominant_coords[plot_name]['y'].append(y_val)\r\n",
    "\r\n",
    "            for plot_name, coords in dominant_coords.items():\r\n",
    "                if len(coords['x']) > 2: # lowess requires enough points\r\n",
    "                    x_vals = np.array(coords['x'])\r\n",
    "                    y_vals = np.array(coords['y'])\r\n",
    "\r\n",
    "                    order = np.argsort(x_vals)\r\n",
    "                    sorted_x = x_vals[order]\r\n",
    "                    sorted_y = y_vals[order]\r\n",
    "\r\n",
    "                    smoothed = sm.nonparametric.lowess(sorted_y, sorted_x, frac=0.8)\r\n",
    "                    smoothed_curves[plot_name] = smoothed\r\n",
    "\r\n",
    "    # PART 2: PLOTTING\r\n",
    "\r\n",
    "    from matplotlib.colors import LinearSegmentedColormap\r\n",
    "    original_plasma = cm.get_cmap('plasma', 256)\r\n",
    "    start_fraction = 0.0\r\n",
    "    stop_fraction = 0.75\r\n",
    "    new_colors = original_plasma(np.linspace(start_fraction, stop_fraction, 256))\r\n",
    "    custom_plasma_no_yellow = LinearSegmentedColormap.from_list(\"custom_plasma_no_yellow\", new_colors)\r\n",
    "\r\n",
    "\r\n",
    "    for marker, (name, mp) in zip(markers, outs.items()):\r\n",
    "        x_loss = np.array(mp[x_axis_option][prefix:])\r\n",
    "        point_indices = np.arange(len(x_loss))\r\n",
    "\r\n",
    "        # For this specific marker group, find which of its points are in the global dominant set\r\n",
    "        num_points_in_mp = len(x_loss)\r\n",
    "        dom_indices_in_mp = np.array([i for i in range(num_points_in_mp) if (name, i) in global_dom_points_set])\r\n",
    "\r\n",
    "        all_indices_set = set(range(num_points_in_mp))\r\n",
    "        non_dom_indices_in_mp = np.array(list(all_indices_set - set(dom_indices_in_mp)))\r\n",
    "\r\n",
    "        for plot_name in mp.keys():\r\n",
    "            if plot_name not in plot_keys: continue\r\n",
    "            y_vals = np.array(mp[plot_name][prefix:])\r\n",
    "            if plot_name == \"spherical_loss\": y_vals = y_vals + 1.\r\n",
    "            ax = axes_mp[plot_name]\r\n",
    "\r\n",
    "            if paint_dom:\r\n",
    "                # Plot non-dominant points (semi-transparent gray)\r\n",
    "                if non_dom_indices_in_mp.size > 0:\r\n",
    "                    ax.scatter(\r\n",
    "                        x_loss[non_dom_indices_in_mp], y_vals[non_dom_indices_in_mp],\r\n",
    "                        c='black', alpha=0.1, s=s*0.33, marker=marker\r\n",
    "                    )\r\n",
    "                # Plot dominant points (colored)\r\n",
    "                if dom_indices_in_mp.size > 0:\r\n",
    "                    scatter_plot = ax.scatter(\r\n",
    "                        x_loss[dom_indices_in_mp], y_vals[dom_indices_in_mp],\r\n",
    "                        c=point_indices[dom_indices_in_mp] + prefix,\r\n",
    "                        cmap=custom_plasma_no_yellow, alpha=0.75, s=s, marker=marker\r\n",
    "                    )\r\n",
    "            else:\r\n",
    "                scatter_plot = ax.scatter(\r\n",
    "                        x_loss, y_vals,\r\n",
    "                        c=point_indices + prefix,\r\n",
    "                        cmap=custom_plasma_no_yellow, alpha=0.75, s=s, marker=marker\r\n",
    "                    )\r\n",
    "\r\n",
    "            ax.set_xlabel(x_axis_option if rename_xaxis is None else (rename_xaxis if type(rename_xaxis)==str else rename_xaxis[plot_name]), fontsize=7 * font_scale)\r\n",
    "            ax.set_ylabel(plot_name if rename_yaxis is None else (rename_yaxis if type(rename_yaxis)==str else rename_yaxis[plot_name]), fontsize=7 * font_scale)\r\n",
    "            ax.set_title(plot_name if rename_title is None else (rename_title if type(rename_title)==str else rename_title[plot_name]), fontsize=10 * font_scale)\r\n",
    "\r\n",
    "            class CustomFormatter(ticker.ScalarFormatter):\r\n",
    "                \"\"\"\r\n",
    "                This formatter forces tick labels to display one decimal place.\r\n",
    "                For example, an integer tick '2' will be formatted as '2.0'.\r\n",
    "                \"\"\"\r\n",
    "                def _set_format(self):\r\n",
    "                    self.format = '%1.1f'\r\n",
    "            for axis_obj in [ax.xaxis, ax.yaxis]:\r\n",
    "                axis_obj.set_major_formatter(CustomFormatter(useMathText=True))\r\n",
    "            ax.ticklabel_format(axis='both', style='sci', scilimits=(-1,1))\r\n",
    "\r\n",
    "            ax.tick_params(axis='x', labelsize=8 * font_scale)\r\n",
    "            ax.tick_params(axis='y', labelsize=8 * font_scale)\r\n",
    "\r\n",
    "            if len(outs) > 1:\r\n",
    "                ax.set_xticks([0.6,0.8,1.0,1.2,1.4])\r\n",
    "                if plot_name == \"l1ecequantile_confidence_20\": ax.set_yticks([0.00,0.03,0.06,0.09,0.12])\r\n",
    "                if plot_name == \"l2ecequantile_confidence_20\": ax.set_yticks([0.0000,0.0003,0.0006,0.0009])\r\n",
    "            else:\r\n",
    "                aaa = 1\r\n",
    "\r\n",
    "\r\n",
    "\r\n",
    "    # PART 3: DRAW THE SINGLE, GLOBAL LOWESS CURVE\r\n",
    "    if paint_dom and do_lowess:\r\n",
    "        for plot_name, smoothed in smoothed_curves.items():\r\n",
    "            ax = axes_mp[plot_name]\r\n",
    "            # Plot the smooth curve with the same style as non-dominant points\r\n",
    "            ax.plot(smoothed[:, 0], smoothed[:, 1], color=\"gray\", alpha=0.6, linewidth=2, zorder=1)\r\n",
    "\r\n",
    "\r\n",
    "    if scatter_plot and (cax is not None):\r\n",
    "        cbar = fig.colorbar(\r\n",
    "            scatter_plot,\r\n",
    "            cax=cax,\r\n",
    "            orientation='horizontal'\r\n",
    "        )\r\n",
    "        cbar.set_label(f'Point Index (Offset by {prefix})', rotation=0, labelpad=5, fontsize=10 * font_scale)\r\n",
    "        cax.xaxis.set_ticks_position('top')\r\n",
    "        cax.xaxis.set_label_position('top')\r\n",
    "        cax.tick_params(axis='x', labelsize=8 * font_scale)\r\n",
    "\r\n",
    "    if \"legend\" in axes_mp:\r\n",
    "        legend_handles = [\r\n",
    "            mlines.Line2D([], [], color='black', marker=marker, linestyle='None',\r\n",
    "                        markersize=8 * font_scale, label=name)\r\n",
    "            for marker, name in zip(markers, outs.keys())\r\n",
    "        ]\r\n",
    "        axes_mp[\"legend\"].legend(handles=legend_handles, title='Marker Types', loc='center', fontsize=8 * font_scale)\r\n",
    "        axes_mp[\"legend\"].axis('off')\r\n",
    "\r\n",
    "    gs.tight_layout(fig) # , rect=[0, 0.03, 1, 0.95]\r\n",
    "\r\n",
    "    plt.savefig(os.path.join(figure_dir, f\"{title}.pdf\"))\r\n",
    "    print(\"Figure saved to: \", os.path.join(figure_dir, f\"{title}.pdf\"))\r\n",
    "    plt.show()\r\n",
    "\r\n",
    "def show_all(outs, title, plot_key_setting=None):\r\n",
    "    \"\"\"Print the last recorded value of every tracked metric for each model.\r\n",
    "\r\n",
    "    Args:\r\n",
    "        outs: mapping of model name -> metrics dict (metric name -> list of\r\n",
    "            per-epoch values).\r\n",
    "        title: heading printed before the results.\r\n",
    "        plot_key_setting: None for the full metric list, or \"l2ecequantile\"\r\n",
    "            for the l2 quantile-ECE subset.\r\n",
    "\r\n",
    "    Raises:\r\n",
    "        ValueError: if plot_key_setting is not one of the presets above.\r\n",
    "    \"\"\"\r\n",
    "    # Only the dict keys (and their order) are used below; the (row, col)\r\n",
    "    # values are not read by this function.\r\n",
    "    if plot_key_setting is None:\r\n",
    "        plot_keys = {\r\n",
    "            \"classification_error\": (0, 0),\r\n",
    "            \"squared_loss\": (0, 1),\r\n",
    "            \"spherical_loss\": (0, 2),\r\n",
    "            \"legend\": (0, 3),\r\n",
    "            \"l1ecequantile_confidence_2000\": (1, 0),\r\n",
    "            \"l1ecequantile_confidence_20\": (1, 1),\r\n",
    "            \"l1ecequantile_classwise_2000\": (1, 2),\r\n",
    "            \"l1ecequantile_classwise_20\": (1, 3),\r\n",
    "            \"l2ecequantile_confidence_2000\": (2, 0),\r\n",
    "            \"l2ecequantile_confidence_20\": (2, 1),\r\n",
    "            \"l2ecequantile_classwise_2000\": (2, 2),\r\n",
    "            \"l2ecequantile_classwise_20\": (2, 3),\r\n",
    "            \"l1ecefix_confidence_2000\": (3, 0),\r\n",
    "            \"l1ecefix_confidence_20\": (3, 1),\r\n",
    "            \"l1ecefix_classwise_2000\": (3, 2),\r\n",
    "            \"l1ecefix_classwise_20\": (3, 3),\r\n",
    "            \"l2ecefix_confidence_2000\": (4, 0),\r\n",
    "            \"l2ecefix_confidence_20\": (4, 1),\r\n",
    "            \"l2ecefix_confidence_5\": (4, 2),\r\n",
    "            \"l2ecefix_classwise_2000\": (4, 2),\r\n",
    "            \"l2ecefix_classwise_20\": (4, 3),\r\n",
    "            \"atb_confidence\": (5,0),\r\n",
    "            \"atb_classwise_full\": (5,1),\r\n",
    "        }\r\n",
    "    elif plot_key_setting == \"l2ecequantile\":\r\n",
    "        plot_keys = {\r\n",
    "            \"l2ecequantile_confidence_2000\": (0, 0),\r\n",
    "            \"l2ecequantile_confidence_20\": (0, 1),\r\n",
    "            \"l2ecequantile_classwise_2000\": (1, 0),\r\n",
    "            \"l2ecequantile_classwise_20\": (1, 1),\r\n",
    "        }\r\n",
    "    else:\r\n",
    "        raise ValueError(f\"Unknown plot_key_setting: {plot_key_setting!r}\")\r\n",
    "\r\n",
    "    print(f\"--- Results for {title} ---\")\r\n",
    "    for model_name, mp in outs.items():\r\n",
    "        print(f\"\\nModel: {model_name}\")\r\n",
    "        for name in plot_keys:\r\n",
    "            if name not in mp:\r\n",
    "                continue\r\n",
    "            series = mp[name]\r\n",
    "            # Report the final epoch. 0.0 is a legitimate metric value, so only an\r\n",
    "            # empty series or a None/null entry is shown as N/A.\r\n",
    "            value = series[-1] if series else None\r\n",
    "            if value is not None:\r\n",
    "                print(f\"{name}: {value:.3e}\")\r\n",
    "            else:\r\n",
    "                print(f\"{name}: N/A\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concurrent.futures\r\n",
    "\r\n",
    "# Compute every metric in method_map for each evaluated checkpoint and cache the results\r\n",
    "# to res_{0|1}.json (and, when use_tem == 2, the fitted temperatures to tem.json).\r\n",
    "# Presets (use_output_cache, use_tem, load_mp): reuse caches -> 1 1 1 ; fresh run -> 0 2 0\r\n",
    "\r\n",
    "calc_step = 1  # evaluate every calc_step-th checkpoint epoch\r\n",
    "\r\n",
    "use_output_cache = 1  # 1: load cached logits (modeloutputs_*.pt), 0: recompute them from checkpoints\r\n",
    "use_tem = 1 # 0 for not scaling, 1 for scaling (cached tem), 2 for scaling (calc tem)\r\n",
    "load_mp = 1 # 0 for not loading, 1 for loading\r\n",
    "force_recompute = []  # metric names to drop from the loaded results so they are recomputed\r\n",
    "if use_tem == 2 and use_output_cache == 1:\r\n",
    "    raise ValueError(\"Cannot use both use_tem == 2 and use_output_cache == 1\")\r\n",
    "method_map = {\"classification_error\": classification_error,\r\n",
    "              \"squared_loss\": squared_loss,\r\n",
    "              \"spherical_loss\": spherical_loss,\r\n",
    "              \"log_loss\": F.cross_entropy,\r\n",
    "              \"l1ecequantile_confidence_2000\": (l1ecequa_confidence,2000),\r\n",
    "              \"l1ecequantile_confidence_20\": (l1ecequa_confidence,20),\r\n",
    "              \"l1ecequantile_confidence_5\": (l1ecequa_confidence,5),\r\n",
    "              \"l1ecequantile_classwise_2000\": (l1ecequa_classwise,2000),\r\n",
    "              \"l1ecequantile_classwise_20\": (l1ecequa_classwise,20),\r\n",
    "              \"l2ecequantile_confidence_2000\": (l2ecequa_confidence,2000),\r\n",
    "              \"l2ecequantile_confidence_20\": (l2ecequa_confidence,20),\r\n",
    "              \"l2ecequantile_confidence_5\": (l2ecequa_confidence,5),\r\n",
    "              \"l2ecequantile_classwise_2000\": (l2ecequa_classwise,2000),\r\n",
    "              \"l2ecequantile_classwise_20\": (l2ecequa_classwise,20),\r\n",
    "              \"atb_classwise_full\": atb_classwise,\r\n",
    "              \"atb_confidence\": atb_confidence,\r\n",
    "              \"l1ecefix_confidence_2000\": (l1ecefix_confidence,2000),\r\n",
    "              \"l1ecefix_confidence_20\": (l1ecefix_confidence,20),\r\n",
    "              \"l1ecefix_confidence_5\": (l1ecefix_confidence,5),\r\n",
    "              \"l1ecefix_classwise_2000\": (l1ecefix_classwise,2000),\r\n",
    "              \"l1ecefix_classwise_20\": (l1ecefix_classwise,20),\r\n",
    "              \"l2ecefix_confidence_2000\": (l2ecefix_confidence,2000),\r\n",
    "              \"l2ecefix_confidence_20\": (l2ecefix_confidence,20),\r\n",
    "              \"l2ecefix_confidence_5\": (l2ecefix_confidence,5),\r\n",
    "              \"l2ecefix_classwise_2000\": (l2ecefix_classwise,2000),\r\n",
    "              \"l2ecefix_classwise_20\": (l2ecefix_classwise,20),}\r\n",
    "tem = [None]\r\n",
    "if use_tem == 1:\r\n",
    "    with open(os.path.join(this_checkpoint_dir, \"tem.json\"), \"r\", encoding=\"utf-8\") as f:\r\n",
    "        tem = json.load(f)\r\n",
    "\r\n",
    "res_path = os.path.join(this_checkpoint_dir, f\"res_{(use_tem+1)//2}.json\")\r\n",
    "if load_mp == 0:\r\n",
    "    mp = {}\r\n",
    "    upd = list(method_map.keys())\r\n",
    "else:\r\n",
    "    try:\r\n",
    "        with open(res_path, \"r\", encoding=\"utf-8\") as f:\r\n",
    "            mp = json.load(f)\r\n",
    "        for it in force_recompute:\r\n",
    "            mp.pop(it, None)\r\n",
    "        # Only compute metrics that are not already cached.\r\n",
    "        upd = [it for it in method_map.keys() if it not in mp]\r\n",
    "    except (OSError, json.JSONDecodeError):\r\n",
    "        # Missing or unreadable cache: start from scratch (without swallowing unrelated errors).\r\n",
    "        print(\"No mp found\")\r\n",
    "        mp = {}\r\n",
    "        upd = list(method_map.keys())\r\n",
    "\r\n",
    "\r\n",
    "print(len(upd), upd)\r\n",
    "methods = [method_map[it] for it in upd]\r\n",
    "\r\n",
    "if len(upd) > 0:\r\n",
    "    outputs_path = os.path.join(this_checkpoint_dir, f\"modeloutputs_{(use_tem+1)//2}.pt\")\r\n",
    "    if use_output_cache == 1:\r\n",
    "        model_outputs = torch.load(outputs_path)\r\n",
    "    else:\r\n",
    "        # Slot 0 is a placeholder so model_outputs[k] is the k-th evaluated checkpoint (1-based).\r\n",
    "        model_outputs = [None]\r\n",
    "        for use_checkpoint_epoch in range(1, epochs+1, calc_step):\r\n",
    "            checkpoint_path = os.path.join(this_checkpoint_dir, f\"model_epoch_{use_checkpoint_epoch}.pth\")\r\n",
    "            print(f\"🔄 Loading model weights from checkpoint: {checkpoint_path}\")\r\n",
    "            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\r\n",
    "            model.load_state_dict(checkpoint['model_state_dict'])\r\n",
    "            model.eval()\r\n",
    "\r\n",
    "            scaled_model = TemperatureScaler(model).to(device)\r\n",
    "            if use_tem == 0:\r\n",
    "                scaled_model.temperature = nn.Parameter(torch.tensor([1.], device=device))\r\n",
    "            elif use_tem == 1:\r\n",
    "                scaled_model.temperature = nn.Parameter(torch.tensor([tem[use_checkpoint_epoch]], device=device))\r\n",
    "            elif use_tem == 2:\r\n",
    "                scaled_model.set_temperature(val_loader, device, guess=tem[-1])\r\n",
    "\r\n",
    "            print(f\"[{'AfterTS' if use_tem else 'BeforeTS'}] epoch = {use_checkpoint_epoch}\")\r\n",
    "            if use_tem == 2:\r\n",
    "                tem.append(scaled_model.temperature.item())\r\n",
    "\r\n",
    "            scaled_model.eval()\r\n",
    "            all_logits, all_labels = [], []\r\n",
    "\r\n",
    "            with torch.no_grad():\r\n",
    "                for x, y in test_loader:\r\n",
    "                    x, y = x.to(device), y.to(device)\r\n",
    "                    # Run through the temperature-scaling wrapper (not the bare model) so the\r\n",
    "                    # temperature configured above is reflected in the cached logits.\r\n",
    "                    # NOTE(review): assumes TemperatureScaler.forward applies its temperature.\r\n",
    "                    logits = scaled_model(x)\r\n",
    "\r\n",
    "                    all_logits.append(logits)\r\n",
    "                    all_labels.append(y)\r\n",
    "\r\n",
    "            all_logits = torch.cat(all_logits, dim=0)\r\n",
    "            all_labels = torch.cat(all_labels, dim=0)\r\n",
    "            model_outputs.append((all_logits, all_labels))\r\n",
    "        torch.save(model_outputs, outputs_path)\r\n",
    "\r\n",
    "    print(\"length of model_outputs: \", len(model_outputs))\r\n",
    "\r\n",
    "\r\n",
    "    def process_epoch(epoch_data: tuple):\r\n",
    "        \"\"\"Evaluate every metric in `methods` on one epoch's cached (logits, labels).\r\n",
    "\r\n",
    "        epoch_data is (epoch, model_outputs, methods, upd). Returns (metric_name, value)\r\n",
    "        pairs aligned with upd; a metric that raises is printed and recorded as None so\r\n",
    "        the remaining names and values stay correctly paired.\r\n",
    "        \"\"\"\r\n",
    "        epoch, model_outputs, methods, upd = epoch_data\r\n",
    "        all_logits, all_labels = model_outputs[(epoch - 1) // calc_step + 1]\r\n",
    "        values = []\r\n",
    "        for method in methods:\r\n",
    "            # Binned metrics are stored as (func, n_bins); the others are plain callables.\r\n",
    "            if isinstance(method, tuple):\r\n",
    "                method_func, extra_args = method[0], method[1:]\r\n",
    "            else:\r\n",
    "                method_func, extra_args = method, ()\r\n",
    "            try:\r\n",
    "                values.append(method_func(all_logits, all_labels, *extra_args).item())\r\n",
    "            except Exception as e:\r\n",
    "                print(e)\r\n",
    "                values.append(None)\r\n",
    "        print(epoch, values)\r\n",
    "        return list(zip(upd, values))\r\n",
    "\r\n",
    "    print(epochs)\r\n",
    "\r\n",
    "    tasks_data = [(epoch, model_outputs, methods, upd) for epoch in range(1, epochs + 1, calc_step)]\r\n",
    "    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:\r\n",
    "        for result_list in executor.map(process_epoch, tasks_data):\r\n",
    "            for key, value in result_list:\r\n",
    "                mp.setdefault(key, []).append(value)\r\n",
    "\r\n",
    "    if use_tem == 2:\r\n",
    "        out_name = os.path.join(this_checkpoint_dir, \"tem.json\")\r\n",
    "        with open(out_name, \"w\", encoding=\"utf-8\") as f:\r\n",
    "            print(\"Saving to: \", out_name)\r\n",
    "            json.dump(tem, f, ensure_ascii=False, indent=4)\r\n",
    "    out_name = res_path\r\n",
    "    with open(out_name, \"w\", encoding=\"utf-8\") as f:\r\n",
    "        print(\"Saving to: \", out_name)\r\n",
    "        json.dump(mp, f, ensure_ascii=False, indent=4)\r\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concurrent.futures\r\n",
    "\r\n",
    "# Load each model's cached per-epoch metrics (res_{0|1}.json) and draw the paper/appendix figures.\r\n",
    "plot_model_names_all = [\"mobilenetv3_small_100PT\",\"resnet10tPT\",\"resnet18PT\",\"resnet34PT\",\"resnet50PT\",\"resnetv2_50x1_bitPT\"]\r\n",
    "\r\n",
    "run_plots = True  # set to False to skip figure generation\r\n",
    "\r\n",
    "if run_plots:\r\n",
    "    plot_model_names = plot_model_names_all  # narrow this list to plot a subset of models\r\n",
    "    plot_dataset_name = dataset_name\r\n",
    "    plot_sch_choice = sch_choice\r\n",
    "    plot_regularizer_name = regularizer_name\r\n",
    "    plot_experimental = experimental\r\n",
    "    plot_tmpadd = \"\"\r\n",
    "    plot_use_tem = 1  # selects res_{(plot_use_tem+1)//2}.json: 0 -> res_0 (BeforeTS), 1 or 2 -> res_1 (AfterTS)\r\n",
    "\r\n",
    "    outs = {}\r\n",
    "    for plot_model_name in plot_model_names:\r\n",
    "        plot_checkpoint_dir = os.path.join(\r\n",
    "            checkpoint_dir,\r\n",
    "            plot_dataset_name,\r\n",
    "            plot_model_name + \"_\" + plot_sch_choice + \"_\" + plot_regularizer_name,\r\n",
    "            f\"variant{plot_experimental + plot_tmpadd}\",\r\n",
    "        )\r\n",
    "        in_name = os.path.join(plot_checkpoint_dir, f'res_{(plot_use_tem+1)//2}.json')\r\n",
    "        print(f\"Loading results from: {in_name}\")\r\n",
    "        with open(in_name, \"r\", encoding=\"utf-8\") as f:\r\n",
    "            mp = json.load(f)\r\n",
    "        outs[plot_model_name] = mp\r\n",
    "\r\n",
    "    # Highlight the dominant (LIS) points only when several models are actually being\r\n",
    "    # compared; derive this from the plotted list, not the full catalogue.\r\n",
    "    paint_dom = len(plot_model_names) > 1\r\n",
    "    for plot_key_setting in [\"paper_loss\",\"paper_conf\",\"paper_class\",\"appendix_conf\",\"appendix_class\"]:\r\n",
    "        title = plot_dataset_name +\"__\" + plot_tmpadd + \"__\" + plot_sch_choice + \"lr__\" + (\"AfterTS\" if plot_use_tem else \"BeforeTS\") + \"__\" + (plot_model_names[0] if len(plot_model_names)==1 else \"all\") + \"__\" + str(plot_key_setting)\r\n",
    "        plot_multiple_scatter(outs, title = title, prefix=10, large=False, s=30, x_axis_option = \"log_loss\", font_scale = 2.5, plot_key_setting=plot_key_setting, paint_dom=paint_dom)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
