{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5ea42d92",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Basic eval imports ready.\n",
      "Results will be saved under: data/cmnist/results_empirical\n"
     ]
    }
   ],
   "source": [
    "# === CMNIST EVALUATION NOTEBOOK: SETUP ===\n",
    "import os\n",
    "import gc\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.transforms as T\n",
    "\n",
    "print(\"✅ Basic eval imports ready.\")\n",
    "\n",
    "# You should already have loaded:\n",
    "# - all_results\n",
    "# - cv_folds\n",
    "# - omega\n",
    "# - Dll_samples, Dhl_samples\n",
    "# - det_ll_dict, det_ll_dict_opt, det_hl_dict_opt\n",
    "# - U_ll_hat_fixed, U_hl_hat_fixed, U_ll_hat_opt, U_hl_hat_opt\n",
    "\n",
    "# Optional: set default output dir for eval results\n",
    "output_dir = 'data/cmnist/results_empirical'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "print(f\"Results will be saved under: {output_dir}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eb583abe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading CMNIST eval state...\n",
      "✗ Could not find all_results at data/cmnist/results_empirical/cmnist_all_results.pkl. Please update the path.\n",
      "✗ Could not find cv_folds at data/cmnist/results_empirical/cmnist_cv_folds.pkl. Please update the path.\n",
      "✗ Could not find omega at data/cmnist/results_empirical/cmnist_omega.pkl. Please update the path.\n",
      "✗ det_ll_dict_opt not found at data/cmnist/results_empirical/cmnist_det_ll_dict_opt.pt\n",
      "✗ det_hl_dict_opt not found at data/cmnist/results_empirical/cmnist_det_hl_dict_opt.pt\n",
      "✗ U_ll_hat_opt not found at data/cmnist/results_empirical/cmnist_U_ll_hat_opt.pt\n",
      "✗ U_hl_hat_opt not found at data/cmnist/results_empirical/cmnist_U_hl_hat_opt.pt\n",
      "✗ Dll_samples not found at data/cmnist/results_empirical/cmnist_Dll_samples.pkl\n",
      "✗ Dhl_samples not found at data/cmnist/results_empirical/cmnist_Dhl_samples.pkl\n",
      "Done loading. Adjust any paths above as needed.\n"
     ]
    }
   ],
   "source": [
    "# === Cell 0: Load training results and CMNIST structures ===\n",
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "print(\"Loading CMNIST eval state...\")\n",
    "\n",
    "# TODO: adapt these paths to your actual saved files\n",
    "base_dir = \"data/cmnist/results_empirical\"\n",
    "\n",
    "# Example: merged optimization results (all_results)\n",
    "all_results_path = os.path.join(base_dir, \"cmnist_all_results.pkl\")\n",
    "if os.path.exists(all_results_path):\n",
    "    with open(all_results_path, \"rb\") as f:\n",
    "        all_results = pickle.load(f)\n",
    "    print(f\"✓ Loaded all_results from {all_results_path}\")\n",
    "else:\n",
    "    print(f\"✗ Could not find all_results at {all_results_path}. Please update the path.\")\n",
    "\n",
    "# Example: cross-validation folds\n",
    "cv_folds_path = os.path.join(base_dir, \"cmnist_cv_folds.pkl\")\n",
    "if os.path.exists(cv_folds_path):\n",
    "    with open(cv_folds_path, \"rb\") as f:\n",
    "        cv_folds = pickle.load(f)\n",
    "    print(f\"✓ Loaded cv_folds from {cv_folds_path}\")\n",
    "else:\n",
    "    print(f\"✗ Could not find cv_folds at {cv_folds_path}. Please update the path.\")\n",
    "\n",
    "# Example: omega (environment pair mapping)\n",
    "omega_path = os.path.join(base_dir, \"cmnist_omega.pkl\")\n",
    "if os.path.exists(omega_path):\n",
    "    with open(omega_path, \"rb\") as f:\n",
    "        omega = pickle.load(f)\n",
    "    print(f\"✓ Loaded omega from {omega_path}\")\n",
    "else:\n",
    "    print(f\"✗ Could not find omega at {omega_path}. Please update the path.\")\n",
    "\n",
    "# Example: deterministic dicts and noise (opt view)\n",
    "det_ll_opt_path = os.path.join(base_dir, \"cmnist_det_ll_dict_opt.pt\")\n",
    "det_hl_opt_path = os.path.join(base_dir, \"cmnist_det_hl_dict_opt.pt\")\n",
    "U_ll_opt_path   = os.path.join(base_dir, \"cmnist_U_ll_hat_opt.pt\")\n",
    "U_hl_opt_path   = os.path.join(base_dir, \"cmnist_U_hl_hat_opt.pt\")\n",
    "\n",
    "if os.path.exists(det_ll_opt_path):\n",
    "    det_ll_dict_opt = torch.load(det_ll_opt_path)\n",
    "    print(f\"✓ Loaded det_ll_dict_opt from {det_ll_opt_path}\")\n",
    "else:\n",
    "    print(f\"✗ det_ll_dict_opt not found at {det_ll_opt_path}\")\n",
    "\n",
    "if os.path.exists(det_hl_opt_path):\n",
    "    det_hl_dict_opt = torch.load(det_hl_opt_path)\n",
    "    print(f\"✓ Loaded det_hl_dict_opt from {det_hl_opt_path}\")\n",
    "else:\n",
    "    print(f\"✗ det_hl_dict_opt not found at {det_hl_opt_path}\")\n",
    "\n",
    "if os.path.exists(U_ll_opt_path):\n",
    "    U_ll_hat_opt = torch.load(U_ll_opt_path)\n",
    "    print(f\"✓ Loaded U_ll_hat_opt from {U_ll_opt_path}\")\n",
    "else:\n",
    "    print(f\"✗ U_ll_hat_opt not found at {U_ll_opt_path}\")\n",
    "\n",
    "if os.path.exists(U_hl_opt_path):\n",
    "    U_hl_hat_opt = torch.load(U_hl_opt_path)\n",
    "    print(f\"✓ Loaded U_hl_hat_opt from {U_hl_opt_path}\")\n",
    "else:\n",
    "    print(f\"✗ U_hl_hat_opt not found at {U_hl_opt_path}\")\n",
    "\n",
    "# Example: Dll_samples / Dhl_samples for evaluation\n",
    "dll_samples_path = os.path.join(base_dir, \"cmnist_Dll_samples.pkl\")\n",
    "dhl_samples_path = os.path.join(base_dir, \"cmnist_Dhl_samples.pkl\")\n",
    "\n",
    "if os.path.exists(dll_samples_path):\n",
    "    with open(dll_samples_path, \"rb\") as f:\n",
    "        Dll_samples = pickle.load(f)\n",
    "    print(f\"✓ Loaded Dll_samples from {dll_samples_path}\")\n",
    "else:\n",
    "    print(f\"✗ Dll_samples not found at {dll_samples_path}\")\n",
    "\n",
    "if os.path.exists(dhl_samples_path):\n",
    "    with open(dhl_samples_path, \"rb\") as f:\n",
    "        Dhl_samples = pickle.load(f)\n",
    "    print(f\"✓ Loaded Dhl_samples from {dhl_samples_path}\")\n",
    "else:\n",
    "    print(f\"✗ Dhl_samples not found at {dhl_samples_path}\")\n",
    "\n",
    "print(\"Done loading. Adjust any paths above as needed.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f7403973",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==================== Visualizing Adversarial Perturbations (UPDATED) ====================\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'all_results' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[16]\u001b[39m\u001b[32m, line 29\u001b[39m\n\u001b[32m     24\u001b[39m \u001b[38;5;66;03m# ------------------------------------------------------------\u001b[39;00m\n\u001b[32m     25\u001b[39m \u001b[38;5;66;03m# 1) Select DiRoCA run and fold\u001b[39;00m\n\u001b[32m     26\u001b[39m \u001b[38;5;66;03m# ------------------------------------------------------------\u001b[39;00m\n\u001b[32m     27\u001b[39m fold_key_to_show = \u001b[33m\"\u001b[39m\u001b[33mfold_0\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m diroca_key_toplevel = \u001b[38;5;28mnext\u001b[39m((k \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[43mall_results\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m k.startswith(\u001b[33m\"\u001b[39m\u001b[33mdiroca_eps_\u001b[39m\u001b[33m\"\u001b[39m)), \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m     30\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m diroca_key_toplevel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m     31\u001b[39m     \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mError: No DiRoCA results found in all_results.\u001b[39m\u001b[33m\"\u001b[39m)\n",
      "\u001b[31mNameError\u001b[39m: name 'all_results' is not defined"
     ]
    }
   ],
   "source": [
    "# === Visualization: adversarial perturbations (pixels-only, DiRoCA) ===\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "print(\"\\n\" + \"=\"*20 + \" Visualizing Adversarial Perturbations (UPDATED) \" + \"=\"*20)\n",
    "\n",
    "# ------------------------------------------------------------\n",
    "# 0) Helper: find observational LL key in det_ll_dict\n",
    "# ------------------------------------------------------------\n",
    "def find_obs_ll_key(det_ll_dict):\n",
    "    keys = list(det_ll_dict.keys())\n",
    "    if not keys:\n",
    "        raise ValueError(\"det_ll_dict is empty; cannot find observational key.\")\n",
    "    for k in keys:\n",
    "        if isinstance(k, str) and k.lower() in [\"obs\", \"observational\", \"none\", \"null\"]:\n",
    "            return k\n",
    "        label = getattr(k, \"label\", None)\n",
    "        if isinstance(label, str) and label.lower() in [\"obs\", \"observational\"]:\n",
    "            return k\n",
    "    print(\"[Warning] No explicit obs key found. Falling back to first key:\", keys[0])\n",
    "    return keys[0]\n",
    "\n",
    "# ------------------------------------------------------------\n",
    "# 1) Select DiRoCA run and fold\n",
    "# ------------------------------------------------------------\n",
    "fold_key_to_show = \"fold_0\"\n",
    "\n",
    "diroca_key_toplevel = next((k for k in all_results if k.startswith(\"diroca_eps_\")), None)\n",
    "if diroca_key_toplevel is None:\n",
    "    print(\"Error: No DiRoCA results found in all_results.\")\n",
    "    run_result = None\n",
    "else:\n",
    "    print(f\"Using DiRoCA results: {diroca_key_toplevel}\")\n",
    "    fold_block = all_results[diroca_key_toplevel].get(fold_key_to_show, None)\n",
    "\n",
    "    if fold_block is None:\n",
    "        print(f\"Error: Fold {fold_key_to_show} not found in {diroca_key_toplevel}\")\n",
    "        run_result = None\n",
    "    else:\n",
    "        run_keys = [k for k in fold_block.keys() if k.startswith(\"eps_\")]\n",
    "        if not run_keys:\n",
    "            print(f\"Error: no inner keys 'eps_*' in fold data. Keys: {list(fold_block.keys())}\")\n",
    "            run_result = None\n",
    "        else:\n",
    "            run_key = run_keys[0]\n",
    "            run_result = fold_block[run_key]\n",
    "            if isinstance(run_result, dict) and \"error\" in run_result:\n",
    "                print(f\"Error in run result: {run_result['error']}\")\n",
    "                run_result = None\n",
    "            else:\n",
    "                print(f\"Using run key: {run_key}\")\n",
    "\n",
    "if run_result is None:\n",
    "    print(\"✗ Could not access DiRoCA run_result for visualization.\")\n",
    "else:\n",
    "    # ------------------------------------------------------------\n",
    "    # 2) Extract Theta and fold/train indices\n",
    "    # ------------------------------------------------------------\n",
    "    final_Theta_ll = run_result.get(\"final_Theta_ll\", None)\n",
    "    if final_Theta_ll is None:\n",
    "        raise ValueError(\"Theta (final_Theta_ll) not found in run_result.\")\n",
    "\n",
    "    if not isinstance(final_Theta_ll, torch.Tensor):\n",
    "        final_Theta_ll = torch.tensor(final_Theta_ll, dtype=torch.float32)\n",
    "\n",
    "    epsilon_run = run_result.get(\"epsilon\", \"Unknown\")\n",
    "    fold_index = int(fold_key_to_show.split(\"_\")[-1])\n",
    "    train_indices = cv_folds[fold_index][\"train\"]\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 3) Get the *exact* U_ll used in training (pixels only)\n",
    "    # ------------------------------------------------------------\n",
    "    if \"U_ll_hat_fixed\" in globals():\n",
    "        U_ll_base = torch.as_tensor(U_ll_hat_fixed, dtype=torch.float32)\n",
    "    else:\n",
    "        U_ll_base = torch.as_tensor(U_ll_hat, dtype=torch.float32)\n",
    "        if U_ll_base.ndim > 2:\n",
    "            U_ll_base = U_ll_base.view(U_ll_base.shape[0], -1)\n",
    "\n",
    "    U_ll_train = U_ll_base[train_indices]  # (N_train, 3072)\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 4) Deterministic LL part for obs key\n",
    "    # ------------------------------------------------------------\n",
    "    obs_ll_key = find_obs_ll_key(det_ll_dict)\n",
    "    det_ll_train_obs = det_ll_dict[obs_ll_key][train_indices]\n",
    "\n",
    "    if not isinstance(det_ll_train_obs, torch.Tensor):\n",
    "        det_ll_train_obs = torch.as_tensor(det_ll_train_obs, dtype=torch.float32)\n",
    "\n",
    "    if det_ll_train_obs.shape[1] == 3092:\n",
    "        det_pixels_train = det_ll_train_obs[:, :3072]   # (N_train, 3072)\n",
    "    elif det_ll_train_obs.shape[1] == 3072:\n",
    "        det_pixels_train = det_ll_train_obs             # (N_train, 3072)\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"Unexpected det_ll_train_obs width {det_ll_train_obs.shape[1]} \"\n",
    "            \"expected 3072 or 3092.\"\n",
    "        )\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 5) Align Theta to train split + pixels-only view\n",
    "    # ------------------------------------------------------------\n",
    "    if final_Theta_ll.shape[0] != len(train_indices):\n",
    "        if final_Theta_ll.shape[0] == U_ll_base.shape[0]:\n",
    "            final_Theta_ll = final_Theta_ll[train_indices]\n",
    "        else:\n",
    "            raise ValueError(\n",
    "                f\"Theta rows {final_Theta_ll.shape[0]} do not match \"\n",
    "                f\"train size {len(train_indices)} or full N {U_ll_base.shape[0]}.\"\n",
    "            )\n",
    "\n",
    "    if final_Theta_ll.shape[1] != 3072:\n",
    "        raise ValueError(f\"Theta has {final_Theta_ll.shape[1]} cols, expected 3072 (pixels-only).\")\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 6) Reconstruct clean / worst-case pixels\n",
    "    # ------------------------------------------------------------\n",
    "    clean_recon_pixels = det_pixels_train + U_ll_train\n",
    "    worst_case_pixels = clean_recon_pixels + final_Theta_ll\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 7) Select samples\n",
    "    # ------------------------------------------------------------\n",
    "    num_samples_to_show = 4\n",
    "    N_train = U_ll_train.shape[0]\n",
    "    num_samples_to_show = min(num_samples_to_show, N_train)\n",
    "\n",
    "    np.random.seed(fold_index + 42)\n",
    "    sample_indices = np.random.choice(N_train, num_samples_to_show, replace=False)\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 8) Reshape helpers\n",
    "    # ------------------------------------------------------------\n",
    "    def reshape_and_rescale_for_plot(pixel_vector_neg1_1):\n",
    "        \"\"\"3072 -> (32,32,3), rescale [-1,1] -> [0,1].\"\"\"\n",
    "        if not isinstance(pixel_vector_neg1_1, torch.Tensor):\n",
    "            pixel_vector_neg1_1 = torch.tensor(pixel_vector_neg1_1)\n",
    "        pixel_vector_neg1_1 = pixel_vector_neg1_1.detach().cpu()\n",
    "\n",
    "        if pixel_vector_neg1_1.numel() != 3072:\n",
    "            return np.zeros((32, 32, 3))\n",
    "\n",
    "        img_chw = pixel_vector_neg1_1.view(3, 32, 32)\n",
    "        img_hwc = img_chw.permute(1, 2, 0).numpy()\n",
    "        img_rescaled = (img_hwc + 1.0) / 2.0\n",
    "        return np.clip(img_rescaled, 0.0, 1.0)\n",
    "\n",
    "    def reshape_for_plot(pixel_vector):\n",
    "        \"\"\"3072 -> (32,32,3), keep raw values (for U / Θ heatmaps).\"\"\"\n",
    "        if not isinstance(pixel_vector, torch.Tensor):\n",
    "            pixel_vector = torch.tensor(pixel_vector)\n",
    "        pixel_vector = pixel_vector.detach().cpu()\n",
    "        if pixel_vector.numel() != 3072:\n",
    "            return np.zeros((32, 32, 3))\n",
    "        img_chw = pixel_vector.view(3, 32, 32)\n",
    "        return img_chw.permute(1, 2, 0).numpy()\n",
    "\n",
    "    # ------------------------------------------------------------\n",
    "    # 9) Plot\n",
    "    # ------------------------------------------------------------\n",
    "    fig, axes = plt.subplots(num_samples_to_show, 5, figsize=(20, 4 * num_samples_to_show))\n",
    "    if num_samples_to_show == 1:\n",
    "        axes = np.array([axes])\n",
    "\n",
    "    fig.suptitle(\n",
    "        f\"Visualizing Adversarial Perturbations (pixels-only)\\n\"\n",
    "        f\"{diroca_key_toplevel}, {fold_key_to_show}, ε={epsilon_run}, δ=0\",\n",
    "        fontsize=16, y=1.02\n",
    "    )\n",
    "\n",
    "    for row_i, idx in enumerate(sample_indices):\n",
    "        # Col 1: Deterministic D\n",
    "        ax = axes[row_i, 0]\n",
    "        ax.imshow(reshape_and_rescale_for_plot(det_pixels_train[idx]))\n",
    "        ax.set_title(f\"Sample {idx}: D\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "        # Col 2: U noise\n",
    "        ax = axes[row_i, 1]\n",
    "        img_u = reshape_for_plot(U_ll_train[idx])\n",
    "        norm = np.max(np.abs(img_u)) if img_u.size > 0 else 0.1\n",
    "        ax.imshow(img_u, cmap=\"RdBu_r\", vmin=-norm-1e-6, vmax=norm+1e-6)\n",
    "        ax.set_title(f\"Sample {idx}: U\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "        # Col 3: Θ adversary\n",
    "        ax = axes[row_i, 2]\n",
    "        img_th = reshape_for_plot(final_Theta_ll[idx])\n",
    "        norm = np.max(np.abs(img_th)) if img_th.size > 0 else 0.1\n",
    "        ax.imshow(img_th, cmap=\"RdBu_r\", vmin=-norm-1e-6, vmax=norm+1e-6)\n",
    "        ax.set_title(f\"Sample {idx}: Θ\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "        # Col 4: Clean D+U\n",
    "        ax = axes[row_i, 3]\n",
    "        ax.imshow(reshape_and_rescale_for_plot(clean_recon_pixels[idx]))\n",
    "        ax.set_title(f\"Sample {idx}: Clean (D+U)\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "        # Col 5: Worst D+U+Θ\n",
    "        ax = axes[row_i, 4]\n",
    "        ax.imshow(reshape_and_rescale_for_plot(worst_case_pixels[idx]))\n",
    "        ax.set_title(f\"Sample {idx}: Worst (D+U+Θ)\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    plt.tight_layout(rect=[0, 0.03, 1, 0.98])\n",
    "    plt.show()\n",
    "\n",
    "    print(\"✓ Visualization completed successfully!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "20449dba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ apply_huber_contamination_cmnist ready (pixels-only eval usage).\n"
     ]
    }
   ],
   "source": [
    "# === Helper: Huber-style pixel contamination (for evaluation) ===\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def apply_huber_contamination_cmnist(clean_data, alpha, noise_scale, noise_dims, seed=None, loc=0.0):\n",
    "    \"\"\"\n",
    "    Contaminate ONLY the specified dimensions (columns) of clean_data with Gaussian noise.\n",
    "\n",
    "    New setting reminder:\n",
    "      - We will call this ONLY on LL pixel blocks during evaluation.\n",
    "      - Digit/Color one-hots are NEVER passed here and never contaminated.\n",
    "\n",
    "    Args:\n",
    "        clean_data : (N, D) tensor or array\n",
    "        alpha      : fraction of samples to contaminate in [0,1]\n",
    "        noise_scale: std of Gaussian noise\n",
    "        noise_dims : slice / list / np.ndarray / torch.Tensor of column indices\n",
    "        seed       : RNG seed\n",
    "        loc        : mean shift of noise (0.0 = zero-mean)\n",
    "    \"\"\"\n",
    "    # --- Early exit (numerically safe) ---\n",
    "    if alpha is None or noise_scale is None or np.isclose(alpha, 0.0) or np.isclose(noise_scale, 0.0):\n",
    "        if isinstance(clean_data, torch.Tensor):\n",
    "            return clean_data.clone().to(torch.float32)\n",
    "        return torch.tensor(clean_data, dtype=torch.float32)\n",
    "\n",
    "    # --- Convert input to tensor ---\n",
    "    if isinstance(clean_data, torch.Tensor):\n",
    "        data_cont = clean_data.to(torch.float32).clone()\n",
    "    else:\n",
    "        data_cont = torch.tensor(clean_data, dtype=torch.float32).clone()\n",
    "    device = data_cont.device\n",
    "    N, D = data_cont.shape\n",
    "\n",
    "    # --- Build index tensor for noise_dims ---\n",
    "    if isinstance(noise_dims, slice):\n",
    "        start = 0 if noise_dims.start is None else noise_dims.start\n",
    "        stop  = D if noise_dims.stop is None else noise_dims.stop\n",
    "        step  = 1 if noise_dims.step is None else noise_dims.step\n",
    "        noise_idx = torch.arange(start, stop, step, device=device)\n",
    "    elif isinstance(noise_dims, (list, tuple, np.ndarray, torch.Tensor)):\n",
    "        noise_idx = torch.as_tensor(noise_dims, dtype=torch.long, device=device)\n",
    "    else:\n",
    "        raise TypeError(f\"Unsupported type for noise_dims: {type(noise_dims)}\")\n",
    "\n",
    "    # --- Keep only valid indices ---\n",
    "    noise_idx = noise_idx[(noise_idx >= 0) & (noise_idx < D)]\n",
    "    if noise_idx.numel() == 0:\n",
    "        return data_cont\n",
    "\n",
    "    # --- Extract sub-matrix to contaminate ---\n",
    "    data_to_noise = data_cont.index_select(dim=1, index=noise_idx)  # (N, |noise_idx|)\n",
    "\n",
    "    # --- Sample Gaussian noise ---\n",
    "    rng = np.random.default_rng(seed)\n",
    "    noise = rng.normal(loc=loc, scale=noise_scale, size=tuple(data_to_noise.shape)).astype(np.float32)\n",
    "    noise_tensor = torch.tensor(noise, dtype=torch.float32, device=device)\n",
    "\n",
    "    noisy_slice = data_to_noise + noise_tensor\n",
    "\n",
    "    # --- Full contamination ---\n",
    "    if alpha >= 1.0:\n",
    "        data_cont[:, noise_idx] = noisy_slice\n",
    "        return data_cont\n",
    "\n",
    "    # --- Partial contamination ---\n",
    "    n_contaminate = int(alpha * N)\n",
    "    if n_contaminate <= 0:\n",
    "        return data_cont\n",
    "\n",
    "    idx_to_contaminate_np = rng.choice(N, size=n_contaminate, replace=False)\n",
    "    idx_to_contaminate = torch.as_tensor(idx_to_contaminate_np, dtype=torch.long, device=device)\n",
    "\n",
    "    # Replace only selected rows and selected columns\n",
    "    data_cont.index_copy_(\n",
    "        dim=0,\n",
    "        index=idx_to_contaminate,\n",
    "        source=data_cont.index_select(0, idx_to_contaminate).scatter(\n",
    "            1,\n",
    "            noise_idx.view(1, -1).expand(n_contaminate, -1),\n",
    "            noisy_slice.index_select(0, idx_to_contaminate)\n",
    "        )\n",
    "    )\n",
    "    return data_cont\n",
    "\n",
    "print(\"✅ apply_huber_contamination_cmnist ready (pixels-only eval usage).\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6080eeeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Helper: empirical abstraction error (pixels -> z & full) ===\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def calculate_empirical_error_flat(T_matrix, Dll_test_flat, Dhl_test):\n",
    "    \"\"\"\n",
    "    Abstraction error (Frobenius MSE) in the NEW setting.\n",
    "\n",
    "    NEW default:\n",
    "      - T maps pixels -> z only.\n",
    "      - LL labels (digit/color one-hots) are copied 1-1 to HL labels.\n",
    "      - We evaluate against full HL vector [labels(20) | z(d_z)].\n",
    "\n",
    "    Fallback:\n",
    "      - If T is full (84 x 3092), use old full-linear evaluation.\n",
    "\n",
    "    Args:\n",
    "        T_matrix      : torch.Tensor or np.ndarray\n",
    "                       Either (d_z, 3072) or (84, 3092)\n",
    "        Dll_test_flat : (N, 3072) OR (N, 3092)  [pixels | LL labels]\n",
    "        Dhl_test      : (N, 64)   OR (N, 84)    [HL labels | z]\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # --- to tensors ---\n",
    "        T_matrix = T_matrix if isinstance(T_matrix, torch.Tensor) else torch.tensor(T_matrix, dtype=torch.float32)\n",
    "        Dll_test_flat = Dll_test_flat if isinstance(Dll_test_flat, torch.Tensor) else torch.tensor(Dll_test_flat, dtype=torch.float32)\n",
    "        Dhl_test = Dhl_test if isinstance(Dhl_test, torch.Tensor) else torch.tensor(Dhl_test, dtype=torch.float32)\n",
    "\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        T_matrix = T_matrix.to(device)\n",
    "        Dll_test_flat = Dll_test_flat.to(device)\n",
    "        Dhl_test = Dhl_test.to(device)\n",
    "\n",
    "        N = Dll_test_flat.shape[0]\n",
    "\n",
    "        # ============================================================\n",
    "        # Case A (NEW): T is z-only: (d_z, 3072)\n",
    "        # ============================================================\n",
    "        if T_matrix.shape[1] == 3072:\n",
    "            d_z = T_matrix.shape[0]\n",
    "\n",
    "            # Dll may be pixels-only or pixels+labels\n",
    "            if Dll_test_flat.shape[1] == 3072:\n",
    "                ll_pixels = Dll_test_flat\n",
    "                ll_labels = None\n",
    "            elif Dll_test_flat.shape[1] == 3092:\n",
    "                ll_pixels = Dll_test_flat[:, :3072]\n",
    "                ll_labels = Dll_test_flat[:, 3072:]  # (N,20)\n",
    "            else:\n",
    "                return float(\"inf\")\n",
    "\n",
    "            # Dhl may be z-only or labels+z\n",
    "            if Dhl_test.shape[1] == d_z:\n",
    "                hl_labels = ll_labels  # if present, else None\n",
    "                hl_z = Dhl_test\n",
    "            elif Dhl_test.shape[1] == 20 + d_z:\n",
    "                hl_labels = Dhl_test[:, :20]\n",
    "                hl_z = Dhl_test[:, 20:]\n",
    "            else:\n",
    "                return float(\"inf\")\n",
    "\n",
    "            with torch.no_grad():\n",
    "                z_pred = ll_pixels @ T_matrix.T  # (N, d_z)\n",
    "\n",
    "                if hl_labels is None:\n",
    "                    # compare only z\n",
    "                    diff = z_pred - hl_z\n",
    "                else:\n",
    "                    pred_full = torch.cat([hl_labels, z_pred], dim=1)\n",
    "                    true_full = torch.cat([hl_labels, hl_z], dim=1)\n",
    "                    diff = pred_full - true_full\n",
    "\n",
    "                err = torch.norm(diff, p=\"fro\") ** 2 / max(1, N)\n",
    "            return float(err.item())\n",
    "\n",
    "        # ============================================================\n",
    "        # Case B (OLD/FULL): T is full map (84, 3092)\n",
    "        # ============================================================\n",
    "        if T_matrix.shape[1] == Dll_test_flat.shape[1] and T_matrix.shape[0] == Dhl_test.shape[1]:\n",
    "            with torch.no_grad():\n",
    "                Dhl_pred = Dll_test_flat @ T_matrix.T\n",
    "                diff = Dhl_pred - Dhl_test\n",
    "                err = torch.norm(diff, p=\"fro\") ** 2 / max(1, N)\n",
    "            return float(err.item())\n",
    "\n",
    "        # If nothing matched, shape mismatch\n",
    "        return float(\"inf\")\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"Error in calculate_empirical_error_flat: {e}\")\n",
    "        return float(\"inf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af997a36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--- Starting Evaluation: Clean & Huber-Noise (Cases 1 & 2) ---\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'all_results' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 18\u001b[39m\n\u001b[32m     16\u001b[39m \u001b[38;5;66;03m# Count total configs for progress bar\u001b[39;00m\n\u001b[32m     17\u001b[39m total_methods_trained = \u001b[32m0\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m method_group_key, method_data_inner \u001b[38;5;129;01min\u001b[39;00m \u001b[43mall_results\u001b[49m.items():\n\u001b[32m     19\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(method_data_inner, \u001b[38;5;28mdict\u001b[39m):\n\u001b[32m     20\u001b[39m         \u001b[38;5;28;01mfor\u001b[39;00m fold_key, fold_data \u001b[38;5;129;01min\u001b[39;00m method_data_inner.items():\n",
      "\u001b[31mNameError\u001b[39m: name 'all_results' is not defined"
     ]
    }
   ],
   "source": [
    "# === Evaluation: Clean & Huber noise (Cases 1 & 2) ===\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import gc\n",
    "\n",
    "print(\"\\n--- Starting Evaluation: Clean & Huber-Noise (Cases 1 & 2) ---\")\n",
    "\n",
    "N_TRIALS = 5\n",
    "NOISE_SCALE_FOR_ALPHA1 = 0.5   # noise level for alpha=1\n",
    "ALPHA_VALUES_TO_TEST = [0.0, 1.0]  # 0 -> clean, 1 -> fully noisy\n",
    "\n",
    "evaluation_records = []\n",
    "\n",
    "# Count total configs for progress bar\n",
    "total_methods_trained = 0\n",
    "for method_group_key, method_data_inner in all_results.items():\n",
    "    if isinstance(method_data_inner, dict):\n",
    "        for fold_key, fold_data in method_data_inner.items():\n",
    "            if isinstance(fold_data, dict) and 'error' not in fold_data:\n",
    "                total_methods_trained += len(fold_data)\n",
    "\n",
    "total_configs = total_methods_trained * len(ALPHA_VALUES_TO_TEST) * N_TRIALS\n",
    "\n",
    "if total_configs == 0:\n",
    "    print(\"Error: No valid training results found in 'all_results'.\")\n",
    "else:\n",
    "    pbar_eval = tqdm(total=total_configs, desc=\"Evaluating (clean + huber)\")\n",
    "\n",
    "    LL_PIXEL_DIMS = slice(0, 3072)  # flattened RGB pixels\n",
    "\n",
    "    for method_group_key, method_results_inner in all_results.items():\n",
    "        if not isinstance(method_results_inner, dict):\n",
    "            continue\n",
    "\n",
    "        for fold_key, fold_data in method_results_inner.items():\n",
    "            if not fold_key.startswith('fold_'):\n",
    "                continue\n",
    "            if 'error' in fold_data:\n",
    "                continue\n",
    "\n",
    "            fold_idx = int(fold_key.split('_')[-1])\n",
    "\n",
    "            for run_key, run_result in fold_data.items():\n",
    "                if 'error' in run_result or run_result.get('T_matrix') is None:\n",
    "                    pbar_eval.update(len(ALPHA_VALUES_TO_TEST) * N_TRIALS)\n",
    "                    continue\n",
    "\n",
    "                T_matrix = run_result['T_matrix']\n",
    "                test_idx = run_result['test_indices']\n",
    "                if test_idx is None:\n",
    "                    pbar_eval.update(len(ALPHA_VALUES_TO_TEST) * N_TRIALS)\n",
    "                    continue\n",
    "\n",
    "                # Human-readable method name\n",
    "                if method_group_key.startswith('diroca_'):\n",
    "                    eval_method_name = f\"DiRoCA ({run_key})\"\n",
    "                elif method_group_key == 'gradca':\n",
    "                    eval_method_name = \"GradCA\"\n",
    "                elif method_group_key == 'baryca':\n",
    "                    eval_method_name = \"BaryCA\"\n",
    "                elif method_group_key == 'abslingam':\n",
    "                    eval_method_name = run_key\n",
    "                else:\n",
    "                    eval_method_name = f\"{method_group_key}_{run_key}\"\n",
    "\n",
    "                for alpha in ALPHA_VALUES_TO_TEST:\n",
    "                    noise_scale = 0.0 if np.isclose(alpha, 0.0) else NOISE_SCALE_FOR_ALPHA1\n",
    "                    loc_ll = 0.0  # label-preserving pixel noise (mean 0)\n",
    "\n",
    "                    for trial in range(N_TRIALS):\n",
    "                        trial_errors = []\n",
    "\n",
    "                        for iota, eta in list(omega.items()):\n",
    "                            try:\n",
    "                                if iota not in Dll_samples or eta not in Dhl_samples:\n",
    "                                    continue\n",
    "\n",
    "                                ll_images_01, _, ll_digits, ll_colors = Dll_samples[iota]\n",
    "                                max_idx = max(test_idx) if len(test_idx) > 0 else -1\n",
    "                                if max_idx >= len(ll_images_01):\n",
    "                                    continue\n",
    "\n",
    "                                # --- Test split ---\n",
    "                                ll_images_test_01 = ll_images_01[test_idx]      # (N_test,3,32,32) in [0,1]\n",
    "                                ll_digits_test    = ll_digits[test_idx]\n",
    "                                ll_colors_test    = ll_colors[test_idx]\n",
    "                                Dhl_test_clean    = Dhl_samples[eta][test_idx]  # (N_test,84) clean HL\n",
    "\n",
    "                                seed = hash((fold_idx, run_key, float(alpha),\n",
    "                                             float(noise_scale), trial, str(iota))) % (2**32)\n",
    "\n",
    "                                # Rescale pixels to tanh space [-1,1]\n",
    "                                ll_images_test_tanh = ll_images_test_01 * 2.0 - 1.0\n",
    "\n",
    "                                # flatten images and contaminate only pixels\n",
    "                                ll_images_test_flat = ll_images_test_tanh.view(ll_images_test_tanh.shape[0], -1)\n",
    "                                ll_images_cont_flat = apply_huber_contamination_cmnist(\n",
    "                                    ll_images_test_flat, alpha, noise_scale,\n",
    "                                    noise_dims=LL_PIXEL_DIMS, seed=seed, loc=loc_ll\n",
    "                                )\n",
    "\n",
    "                                # one-hot digits & colors (never contaminated)\n",
    "                                ll_digits_onehot = F.one_hot(ll_digits_test, num_classes=10).float()\n",
    "                                ll_colors_onehot = F.one_hot(ll_colors_test, num_classes=10).float()\n",
    "\n",
    "                                device = ll_images_cont_flat.device\n",
    "                                Dll_test_cont_flat_full = torch.cat(\n",
    "                                    [ll_images_cont_flat,\n",
    "                                     ll_digits_onehot.to(device),\n",
    "                                     ll_colors_onehot.to(device)],\n",
    "                                    dim=1\n",
    "                                )  # (N_test, 3092)\n",
    "\n",
    "                                # High-level stays clean (labels + z)\n",
    "                                Dhl_test = Dhl_test_clean.to(device)\n",
    "\n",
    "                                error = calculate_empirical_error_flat(\n",
    "                                    T_matrix, Dll_test_cont_flat_full, Dhl_test\n",
    "                                )\n",
    "                                if not np.isnan(error) and error != float('inf'):\n",
    "                                    trial_errors.append(error)\n",
    "\n",
    "                            except Exception as e:\n",
    "                                print(f\"ERROR inner loop: {e} | \"\n",
    "                                      f\"Context: M{eval_method_name}, F{fold_idx}, \"\n",
    "                                      f\"R{run_key}, A{alpha}, N{noise_scale}, \"\n",
    "                                      f\"T{trial}, Iota{iota}\")\n",
    "                                trial_errors.append(np.nan)\n",
    "\n",
    "                        record = {\n",
    "                            'shift_type': 'huber_noise',\n",
    "                            'method': eval_method_name,\n",
    "                            'fold': fold_idx,\n",
    "                            'alpha': float(alpha),\n",
    "                            'noise_scale': float(noise_scale),\n",
    "                            'trial': trial,\n",
    "                            'error': float(np.nanmean(trial_errors)) if trial_errors else np.nan,\n",
    "                        }\n",
    "                        if method_group_key.startswith('diroca_'):\n",
    "                            record['train_epsilon'] = run_result.get('epsilon', np.nan)\n",
    "                            record['train_delta']   = run_result.get('delta', np.nan)\n",
    "\n",
    "                        evaluation_records.append(record)\n",
    "                        pbar_eval.update(1)\n",
    "\n",
    "                        # cleanup\n",
    "                        del (ll_images_test_01, ll_images_test_tanh, ll_digits_test, ll_colors_test,\n",
    "                             Dhl_test_clean, ll_images_test_flat, ll_images_cont_flat,\n",
    "                             ll_digits_onehot, ll_colors_onehot, Dll_test_cont_flat_full)\n",
    "                        if 'error' in locals():\n",
    "                            del error\n",
    "                        if 'trial_errors' in locals():\n",
    "                            trial_errors[:] = []\n",
    "                            del trial_errors\n",
    "\n",
    "    pbar_eval.close()\n",
    "\n",
    "    full_results_df = pd.DataFrame(evaluation_records)\n",
    "    eval_output_path = os.path.join(output_dir, \"cmnist_eval_clean_and_huber.pkl\")\n",
    "    full_results_df.to_pickle(eval_output_path)\n",
    "    print(f\"\\nEvaluation results saved to {eval_output_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8062928f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ apply_camera_shifts_cmnist & camera_transform ready.\n"
     ]
    }
   ],
   "source": [
    "# === Helper: camera/augmentation shifts ===\n",
    "import numpy as np\n",
    "import torchvision.transforms as T\n",
    "import torch\n",
    "\n",
    "def apply_camera_shifts_cmnist(images_01, alpha, transform, seed=None):\n",
    "    \"\"\"\n",
    "    Apply a camera/augmentation transform to a fraction alpha of images.\n",
    "\n",
    "    INPUT:\n",
    "        images_01 : (N, C, H, W) tensor in [0,1]  (as stored in Dll_samples)\n",
    "    OUTPUT:\n",
    "        images_tanh_aug : (N, C, H, W) in tanh space [-1,1]\n",
    "                          matching the optimization pipeline.\n",
    "    \"\"\"\n",
    "    # alpha = 0 → no augmentation, but STILL convert to tanh\n",
    "    if alpha <= 0.0:\n",
    "        return images_01 * 2.0 - 1.0\n",
    "\n",
    "    images_aug = images_01.clone()\n",
    "    N = images_01.shape[0]\n",
    "    rng = np.random.default_rng(seed)\n",
    "    n_aug = int(alpha * N)\n",
    "\n",
    "    if n_aug == 0:\n",
    "        return images_01 * 2.0 - 1.0\n",
    "\n",
    "    idx_all = np.arange(N)\n",
    "    idx_aug = rng.choice(idx_all, size=n_aug, replace=False)\n",
    "\n",
    "    for j in idx_aug:\n",
    "        # torchvision transforms expect input in [0,1]\n",
    "        images_aug[j] = transform(images_aug[j])\n",
    "\n",
    "    # Convert augmented images to tanh space [-1,1]\n",
    "    images_tanh_aug = images_aug * 2.0 - 1.0\n",
    "    return images_tanh_aug\n",
    "\n",
    "camera_transform = T.Compose([\n",
    "    T.RandomAffine(\n",
    "        degrees=10,              # small rotation\n",
    "        translate=(0.1, 0.1),    # small translation\n",
    "        scale=(0.9, 1.1)         # mild zoom\n",
    "    ),\n",
    "    T.ColorJitter(\n",
    "        brightness=0.2,\n",
    "        contrast=0.2,\n",
    "        saturation=0.1\n",
    "    ),\n",
    "    # Optionally: T.GaussianBlur(3, sigma=(0.1, 1.0)),\n",
    "])\n",
    "\n",
    "print(\"✅ apply_camera_shifts_cmnist & camera_transform ready.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11dbb1fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--- Starting Evaluation: Camera/Augmentation Shifts (Case 3) ---\n",
      "Error: No valid training results found in 'all_results'.\n"
     ]
    }
   ],
   "source": [
    "# === Evaluation: Camera/Augmentation Shifts (Case 3) ===\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "print(\"\\n--- Starting Evaluation: Camera/Augmentation Shifts (Case 3) ---\")\n",
    "\n",
    "N_TRIALS = 5\n",
    "ALPHA_VALUES_CAMERA = [0.0, 1.0]  # 0 -> clean, 1 -> all images augmented\n",
    "\n",
    "eval_records_camera = []\n",
    "\n",
    "if total_methods_trained == 0:\n",
    "    print(\"Error: No valid training results found in 'all_results'.\")\n",
    "else:\n",
    "    pbar_cam = tqdm(\n",
    "        total=total_methods_trained * len(ALPHA_VALUES_CAMERA) * N_TRIALS,\n",
    "        desc=\"Evaluating (camera shifts)\"\n",
    "    )\n",
    "\n",
    "    for method_group_key, method_results_inner in all_results.items():\n",
    "        if not isinstance(method_results_inner, dict):\n",
    "            continue\n",
    "\n",
    "        for fold_key, fold_data in method_results_inner.items():\n",
    "            if not fold_key.startswith('fold_'):\n",
    "                continue\n",
    "            if 'error' in fold_data:\n",
    "                continue\n",
    "\n",
    "            fold_idx = int(fold_key.split('_')[-1])\n",
    "\n",
    "            for run_key, run_result in fold_data.items():\n",
    "                if 'error' in run_result or run_result.get('T_matrix') is None:\n",
    "                    pbar_cam.update(len(ALPHA_VALUES_CAMERA) * N_TRIALS)\n",
    "                    continue\n",
    "\n",
    "                T_matrix = run_result['T_matrix']\n",
    "                test_idx = run_result['test_indices']\n",
    "                if test_idx is None:\n",
    "                    pbar_cam.update(len(ALPHA_VALUES_CAMERA) * N_TRIALS)\n",
    "                    continue\n",
    "\n",
    "                # Human-readable method name\n",
    "                if method_group_key.startswith('diroca_'):\n",
    "                    eval_method_name = f\"DiRoCA ({run_key})\"\n",
    "                elif method_group_key == 'gradca':\n",
    "                    eval_method_name = \"GradCA\"\n",
    "                elif method_group_key == 'baryca':\n",
    "                    eval_method_name = \"BaryCA\"\n",
    "                elif method_group_key == 'abslingam':\n",
    "                    eval_method_name = run_key\n",
    "                else:\n",
    "                    eval_method_name = f\"{method_group_key}_{run_key}\"\n",
    "\n",
    "                # Determine view\n",
    "                if isinstance(T_matrix, torch.Tensor):\n",
    "                    T_shape = tuple(T_matrix.shape)\n",
    "                else:\n",
    "                    T_shape = tuple(np.asarray(T_matrix).shape)\n",
    "\n",
    "                is_opt_view  = (T_shape == (64, 3072))\n",
    "                is_full_view = (T_shape == (84, 3092))\n",
    "\n",
    "                if not (is_opt_view or is_full_view):\n",
    "                    print(f\"[Warning] Unexpected T shape {T_shape} for {eval_method_name}. Skipping.\")\n",
    "                    pbar_cam.update(len(ALPHA_VALUES_CAMERA) * N_TRIALS)\n",
    "                    continue\n",
    "\n",
    "                for alpha_cam in ALPHA_VALUES_CAMERA:\n",
    "                    for trial in range(N_TRIALS):\n",
    "                        trial_errors = []\n",
    "\n",
    "                        for iota, eta in list(omega.items()):\n",
    "                            try:\n",
    "                                if iota not in Dll_samples or eta not in Dhl_samples:\n",
    "                                    continue\n",
    "\n",
    "                                ll_images, _, ll_digits, ll_colors = Dll_samples[iota]\n",
    "                                max_idx = max(test_idx) if len(test_idx) > 0 else -1\n",
    "                                if max_idx >= len(ll_images):\n",
    "                                    continue\n",
    "\n",
    "                                # LL images are in [0,1]\n",
    "                                ll_images_test = ll_images[test_idx]          # (N,C,H,W) in [0,1]\n",
    "                                ll_digits_test = ll_digits[test_idx]\n",
    "                                ll_colors_test = ll_colors[test_idx]\n",
    "\n",
    "                                # HL full vector, but we only want z-block\n",
    "                                Dhl_test_full = Dhl_samples[eta][test_idx]   # (N,84)\n",
    "                                Dhl_test_z    = Dhl_test_full[:, 20:]        # (N,64)\n",
    "\n",
    "                                seed = hash((fold_idx, run_key, float(alpha_cam),\n",
    "                                             trial, str(iota))) % (2**32)\n",
    "\n",
    "                                # 1) apply camera shifts in [0,1] and convert to tanh\n",
    "                                ll_images_shifted_tanh = apply_camera_shifts_cmnist(\n",
    "                                    ll_images_test, alpha_cam, camera_transform, seed=seed\n",
    "                                )  # (N,C,H,W) in [-1,1]\n",
    "\n",
    "                                # 2) flatten pixels only (tanh space)\n",
    "                                ll_pixels_shifted_flat = ll_images_shifted_tanh.view(\n",
    "                                    ll_images_shifted_tanh.shape[0], -1\n",
    "                                )  # (N,3072)\n",
    "\n",
    "                                device = ll_pixels_shifted_flat.device\n",
    "\n",
    "                                # Build LL input and T block depending on view\n",
    "                                if is_opt_view:\n",
    "                                    Dll_input = ll_pixels_shifted_flat        # (N,3072)\n",
    "                                    T_use     = T_matrix                      # (64,3072)\n",
    "                                    Dhl_target = Dhl_test_z.to(device)        # (N,64)\n",
    "                                else:\n",
    "                                    if not isinstance(T_matrix, torch.Tensor):\n",
    "                                        T_matrix = torch.tensor(T_matrix, dtype=torch.float32)\n",
    "                                    T_use = T_matrix[20:, :3072]             # (64,3072)\n",
    "\n",
    "                                    Dll_input  = ll_pixels_shifted_flat      # (N,3072)\n",
    "                                    Dhl_target = Dhl_test_z.to(device)       # (N,64)\n",
    "\n",
    "                                # Compute error on z only\n",
    "                                error = calculate_empirical_error_flat(\n",
    "                                    T_use, Dll_input, Dhl_target\n",
    "                                )\n",
    "                                if not np.isnan(error) and error != float('inf'):\n",
    "                                    trial_errors.append(error)\n",
    "\n",
    "                            except Exception as e:\n",
    "                                print(f\"ERROR (camera) inner loop: {e} | \"\n",
    "                                      f\"Context: M{eval_method_name}, F{fold_idx}, \"\n",
    "                                      f\"R{run_key}, A{alpha_cam}, T{trial}, Iota{iota}\")\n",
    "                                trial_errors.append(np.nan)\n",
    "\n",
    "                        record = {\n",
    "                            'shift_type': 'camera_aug',\n",
    "                            'method': eval_method_name,\n",
    "                            'fold': fold_idx,\n",
    "                            'alpha': float(alpha_cam),\n",
    "                            'trial': trial,\n",
    "                            'error': float(np.nanmean(trial_errors)) if trial_errors else np.nan,\n",
    "                        }\n",
    "                        if method_group_key.startswith('diroca_'):\n",
    "                            record['train_epsilon'] = run_result.get('epsilon', np.nan)\n",
    "                            record['train_delta']   = run_result.get('delta', np.nan)\n",
    "\n",
    "                        eval_records_camera.append(record)\n",
    "                        pbar_cam.update(1)\n",
    "\n",
    "                        # cleanup\n",
    "                        del (ll_images_test, ll_digits_test, ll_colors_test,\n",
    "                             Dhl_test_full, Dhl_test_z,\n",
    "                             ll_images_shifted_tanh, ll_pixels_shifted_flat)\n",
    "                        if 'error' in locals():\n",
    "                            del error\n",
    "                        if 'trial_errors' in locals():\n",
    "                            trial_errors[:] = []\n",
    "                            del trial_errors\n",
    "\n",
    "    pbar_cam.close()\n",
    "\n",
    "    df_cam = pd.DataFrame(eval_records_camera)\n",
    "    cam_output_path = os.path.join(output_dir, \"cmnist_eval_camera_shifts.pkl\")\n",
    "    df_cam.to_pickle(cam_output_path)\n",
    "    print(f\"\\nCamera-shift evaluation results saved to {cam_output_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dbc1ab94",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "df_cam not found. Run the camera evaluation cell first.",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mValueError\u001b[39m                                Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m      2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m      4\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m'\u001b[39m\u001b[33mdf_cam\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mglobals\u001b[39m():\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mdf_cam not found. Run the camera evaluation cell first.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m      7\u001b[39m \u001b[38;5;66;03m# Filter rows for clean (alpha=0) and fully shifted (alpha=1)\u001b[39;00m\n\u001b[32m      8\u001b[39m df_cam_clean = df_cam[\n\u001b[32m      9\u001b[39m     (df_cam[\u001b[33m'\u001b[39m\u001b[33mshift_type\u001b[39m\u001b[33m'\u001b[39m] == \u001b[33m'\u001b[39m\u001b[33mcamera_aug\u001b[39m\u001b[33m'\u001b[39m) &\n\u001b[32m     10\u001b[39m     (df_cam[\u001b[33m'\u001b[39m\u001b[33malpha\u001b[39m\u001b[33m'\u001b[39m].abs() < \u001b[32m1e-8\u001b[39m)\n\u001b[32m     11\u001b[39m ]\n",
      "\u001b[31mValueError\u001b[39m: df_cam not found. Run the camera evaluation cell first."
     ]
    }
   ],
   "source": [
    "# === Summary: camera-shift evaluation ===\n",
    "import numpy as np\n",
    "\n",
    "if 'df_cam' not in globals():\n",
    "    raise ValueError(\"df_cam not found. Run the camera evaluation cell first.\")\n",
    "\n",
    "# Filter rows for clean (alpha=0) and fully shifted (alpha=1)\n",
    "df_cam_clean = df_cam[\n",
    "    (df_cam['shift_type'] == 'camera_aug') &\n",
    "    (df_cam['alpha'].abs() < 1e-8)\n",
    "]\n",
    "\n",
    "df_cam_shift = df_cam[\n",
    "    (df_cam['shift_type'] == 'camera_aug') &\n",
    "    (df_cam['alpha'].round(3) == 1.0)\n",
    "]\n",
    "\n",
    "# Compute summaries\n",
    "summary_cam_clean = (\n",
    "    df_cam_clean.groupby('method')['error']\n",
    "    .agg(['mean', 'std'])\n",
    "    .sort_values('mean')\n",
    ")\n",
    "\n",
    "summary_cam_shift = (\n",
    "    df_cam_shift.groupby('method')['error']\n",
    "    .agg(['mean', 'std'])\n",
    "    .sort_values('mean')\n",
    ")\n",
    "\n",
    "print(\"\\n================ CAMERA SHIFT EVALUATION SUMMARY ================\\n\")\n",
    "\n",
    "print(\"--- Camera: Clean (alpha = 0) ---\")\n",
    "if len(summary_cam_clean) == 0:\n",
    "    print(\"No clean (alpha=0) results found.\")\n",
    "else:\n",
    "    print(summary_cam_clean)\n",
    "\n",
    "print(\"\\n--- Camera: Fully Shifted (alpha = 1) ---\")\n",
    "if len(summary_cam_shift) == 0:\n",
    "    print(\"No shifted (alpha=1) results found.\")\n",
    "else:\n",
    "    print(summary_cam_shift)\n",
    "\n",
    "print(\"\\n===============================================================\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "65e16fe4",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'Dll_samples' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m      3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n\u001b[32m      5\u001b[39m \u001b[38;5;66;03m# pick observational environment explicitly\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m iota_obs = \u001b[33m\"\u001b[39m\u001b[33mobs\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mobs\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[43mDll_samples\u001b[49m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(Dll_samples.keys())[\u001b[32m0\u001b[39m]\n\u001b[32m      7\u001b[39m ll_images, _, ll_digits, ll_colors = Dll_samples[iota_obs]\n\u001b[32m      9\u001b[39m \u001b[38;5;66;03m# pick some indices\u001b[39;00m\n",
      "\u001b[31mNameError\u001b[39m: name 'Dll_samples' is not defined"
     ]
    }
   ],
   "source": [
    "# === Visual sanity check: Huber contamination on pixels ===\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# pick observational environment explicitly\n",
    "iota_obs = \"obs\" if \"obs\" in Dll_samples else list(Dll_samples.keys())[0]\n",
    "ll_images, _, ll_digits, ll_colors = Dll_samples[iota_obs]\n",
    "\n",
    "# pick some indices\n",
    "idx = torch.arange(16)  # first 16 examples\n",
    "imgs = ll_images[idx]   # (16, C, H, W)\n",
    "\n",
    "# flatten for Huber (pixels only)\n",
    "imgs_flat = imgs.view(imgs.shape[0], -1)\n",
    "d = imgs_flat.shape[1]\n",
    "\n",
    "imgs_huber_flat = apply_huber_contamination_cmnist(\n",
    "    imgs_flat,\n",
    "    alpha=1.0,\n",
    "    noise_scale=0.5,\n",
    "    noise_dims=slice(0, d),   # contaminate ONLY pixels\n",
    "    seed=0,\n",
    "    loc=0.0\n",
    ")\n",
    "\n",
    "imgs_huber = imgs_huber_flat.view_as(imgs)  # back to (N,C,H,W)\n",
    "\n",
    "# plot original vs huber\n",
    "n_show = 8\n",
    "fig, axes = plt.subplots(2, n_show, figsize=(2*n_show, 4))\n",
    "\n",
    "for j in range(n_show):\n",
    "    # original\n",
    "    ax = axes[0, j]\n",
    "    img = imgs[j].detach().cpu()\n",
    "    if img.shape[0] == 1:\n",
    "        ax.imshow(img.squeeze(0).clamp(0, 1), cmap=\"gray\")\n",
    "    else:\n",
    "        ax.imshow(img.permute(1, 2, 0).clamp(0, 1))\n",
    "    ax.axis('off')\n",
    "    ax.set_title(f\"orig #{j}\")\n",
    "\n",
    "    # huber\n",
    "    ax = axes[1, j]\n",
    "    img_h = imgs_huber[j].detach().cpu()\n",
    "    if img_h.shape[0] == 1:\n",
    "        ax.imshow(img_h.squeeze(0).clamp(0, 1), cmap=\"gray\")\n",
    "    else:\n",
    "        ax.imshow(img_h.permute(1, 2, 0).clamp(0, 1))\n",
    "    ax.axis('off')\n",
    "    ax.set_title(f\"huber #{j}\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3eaeeec6",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'imgs' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 7\u001b[39m\n\u001b[32m      3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m      5\u001b[39m \u001b[38;5;66;03m# reuse imgs from previous cell (ll_images[idx])\u001b[39;00m\n\u001b[32m      6\u001b[39m imgs_cam = apply_camera_shifts_cmnist(\n\u001b[32m----> \u001b[39m\u001b[32m7\u001b[39m     \u001b[43mimgs\u001b[49m,\n\u001b[32m      8\u001b[39m     alpha=\u001b[32m1.0\u001b[39m,              \u001b[38;5;66;03m# apply to all 16\u001b[39;00m\n\u001b[32m      9\u001b[39m     transform=camera_transform,\n\u001b[32m     10\u001b[39m     seed=\u001b[32m0\u001b[39m\n\u001b[32m     11\u001b[39m )\n\u001b[32m     13\u001b[39m n_show = \u001b[32m8\u001b[39m\n\u001b[32m     14\u001b[39m fig, axes = plt.subplots(\u001b[32m2\u001b[39m, n_show, figsize=(\u001b[32m2\u001b[39m*n_show, \u001b[32m4\u001b[39m))\n",
      "\u001b[31mNameError\u001b[39m: name 'imgs' is not defined"
     ]
    }
   ],
   "source": [
    "# === Visual sanity check: camera/augmentation shifts ===\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# reuse imgs from previous cell (ll_images[idx])\n",
    "imgs_cam = apply_camera_shifts_cmnist(\n",
    "    imgs,\n",
    "    alpha=1.0,              # apply to all 16\n",
    "    transform=camera_transform,\n",
    "    seed=0\n",
    ")\n",
    "\n",
    "n_show = 8\n",
    "fig, axes = plt.subplots(2, n_show, figsize=(2*n_show, 4))\n",
    "\n",
    "for j in range(n_show):\n",
    "    # original\n",
    "    ax = axes[0, j]\n",
    "    img = imgs[j].detach().cpu()\n",
    "    if img.shape[0] == 1:\n",
    "        ax.imshow(img.squeeze(0).clamp(0, 1), cmap=\"gray\")\n",
    "    else:\n",
    "        ax.imshow(img.permute(1, 2, 0).clamp(0, 1))\n",
    "    ax.axis('off')\n",
    "    ax.set_title(f\"orig #{j}\")\n",
    "\n",
    "    # camera-shifted (note: in tanh space [-1,1])\n",
    "    ax = axes[1, j]\n",
    "    img_c = imgs_cam[j].detach().cpu()\n",
    "    # for display, bring back to [0,1]\n",
    "    img_c_disp = (img_c + 1.0) / 2.0\n",
    "    if img_c_disp.shape[0] == 1:\n",
    "        ax.imshow(img_c_disp.squeeze(0).clamp(0, 1), cmap=\"gray\")\n",
    "    else:\n",
    "        ax.imshow(img_c_disp.permute(1, 2, 0).clamp(0, 1))\n",
    "    ax.axis('off')\n",
    "    ax.set_title(f\"cam #{j}\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5082f7f8",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'all_results' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[11]\u001b[39m\u001b[32m, line 33\u001b[39m\n\u001b[32m     29\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m T[:, :\u001b[32m3072\u001b[39m]\n\u001b[32m     31\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m T  \u001b[38;5;66;03m# last resort\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m method_group_key, folds \u001b[38;5;129;01min\u001b[39;00m \u001b[43mall_results\u001b[49m.items():\n\u001b[32m     34\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(folds, \u001b[38;5;28mdict\u001b[39m):\n\u001b[32m     35\u001b[39m         \u001b[38;5;28;01mcontinue\u001b[39;00m\n",
      "\u001b[31mNameError\u001b[39m: name 'all_results' is not defined"
     ]
    }
   ],
   "source": [
    "# === Norm statistics for each learned T (pixel->z) ===\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "records = []\n",
    "\n",
    "def extract_T_pix_to_z(T):\n",
    "    \"\"\"\n",
    "    New setting expects T to be (d_z, 3072).\n",
    "    If a full T (84, 3092) is passed, slice to z<-pixels part:\n",
    "        rows 20: (z block), cols :3072 (pixels block).\n",
    "    \"\"\"\n",
    "    T = T if isinstance(T, torch.Tensor) else torch.tensor(T, dtype=torch.float32)\n",
    "\n",
    "    if T.ndim != 2:\n",
    "        raise ValueError(f\"T must be 2D, got shape {tuple(T.shape)}\")\n",
    "\n",
    "    # already pixel->z\n",
    "    if T.shape[1] == 3072:\n",
    "        return T\n",
    "\n",
    "    # full old-style T (84,3092)\n",
    "    if T.shape == (84, 3092):\n",
    "        return T[20:, :3072]  # z rows, pixel cols\n",
    "\n",
    "    # fallback: try to infer by matching pixel dim\n",
    "    if T.shape[1] > 3072:\n",
    "        return T[:, :3072]\n",
    "\n",
    "    return T  # last resort\n",
    "\n",
    "# Walk all_results[method_group][fold_<n>][run] and record norm statistics of every learned map.\n",
    "for method_group_key, folds in all_results.items():\n",
    "    if not isinstance(folds, dict):\n",
    "        continue\n",
    "\n",
    "    for fold_key, fold_data in folds.items():\n",
    "        # only string keys of the form \"fold_<n>\" hold per-fold results\n",
    "        if not isinstance(fold_key, str) or not fold_key.startswith(\"fold_\"):\n",
    "            continue\n",
    "        if not isinstance(fold_data, dict) or \"error\" in fold_data:\n",
    "            continue  # skip failed folds\n",
    "\n",
    "        fold_idx = int(fold_key.split(\"_\")[-1])\n",
    "\n",
    "        for run_key, run_result in fold_data.items():\n",
    "            if not isinstance(run_result, dict):\n",
    "                continue\n",
    "            if \"error\" in run_result or run_result.get(\"T_matrix\") is None:\n",
    "                continue\n",
    "\n",
    "            T_raw = run_result[\"T_matrix\"]\n",
    "            try:\n",
    "                # Deliberately not named `T`: at notebook top level that would overwrite\n",
    "                # the `torchvision.transforms as T` alias imported in the setup cell.\n",
    "                T_pix2z = extract_T_pix_to_z(T_raw)\n",
    "            except Exception as e:\n",
    "                print(f\"[skip] could not parse T for {method_group_key}/{fold_key}/{run_key}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # method name logic (same as eval)\n",
    "            if method_group_key.startswith(\"diroca_\"):\n",
    "                method_name = f\"DiRoCA ({run_key})\"\n",
    "            elif method_group_key == \"gradca\":\n",
    "                method_name = \"GradCA\"\n",
    "            elif method_group_key == \"baryca\":\n",
    "                method_name = \"BaryCA\"\n",
    "            elif method_group_key == \"abslingam\":\n",
    "                method_name = run_key\n",
    "            else:\n",
    "                method_name = f\"{method_group_key}_{run_key}\"\n",
    "\n",
    "            # spectral & Frobenius norms on pixel->z map\n",
    "            with torch.no_grad():\n",
    "                fro_norm = torch.linalg.norm(T_pix2z, ord='fro').item()\n",
    "                try:\n",
    "                    spec_norm = torch.linalg.norm(T_pix2z, ord=2).item()\n",
    "                except RuntimeError:\n",
    "                    # the spectral norm is best-effort; keep the row and mark it missing\n",
    "                    spec_norm = float(\"nan\")\n",
    "\n",
    "            # radii used in training (now stored at top-level)\n",
    "            eps_train   = run_result.get(\"epsilon\", np.nan)\n",
    "            delta_train = run_result.get(\"delta\", np.nan)\n",
    "\n",
    "            records.append({\n",
    "                \"method_group\": method_group_key,\n",
    "                \"run_key\": run_key,\n",
    "                \"fold\": fold_idx,\n",
    "                \"method\": method_name,\n",
    "                \"T_shape\": tuple(T_pix2z.shape),\n",
    "                \"fro_norm\": fro_norm,\n",
    "                \"spec_norm\": spec_norm,\n",
    "                \"epsilon_train\": float(eps_train) if eps_train is not None else np.nan,\n",
    "                \"delta_train\": float(delta_train) if delta_train is not None else np.nan,\n",
    "            })\n",
    "\n",
    "norm_df = pd.DataFrame(records)\n",
    "\n",
    "print(\"\\n=== Per-run norms and training radii (first 50 rows) ===\")\n",
    "print(norm_df.head(50))\n",
    "\n",
    "print(\"\\n=== Aggregated by method (mean ± std) ===\")\n",
    "if norm_df.empty:\n",
    "    # No usable run was found: the frame has no \"method\" column, so groupby would raise KeyError.\n",
    "    summary = pd.DataFrame()\n",
    "    print(\"No runs with a usable T_matrix were found; nothing to aggregate.\")\n",
    "else:\n",
    "    summary = norm_df.groupby(\"method\")[[\"fro_norm\", \"spec_norm\", \"epsilon_train\", \"delta_train\"]].agg([\"mean\", \"std\"])\n",
    "    print(summary)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a6420d4e",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'all_results' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 58\u001b[39m\n\u001b[32m     55\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m cheat_index\n\u001b[32m     57\u001b[39m \u001b[38;5;66;03m# -------- Example usage --------\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m58\u001b[39m T_example = \u001b[43mall_results\u001b[49m[\u001b[33m'\u001b[39m\u001b[33mgradca\u001b[39m\u001b[33m'\u001b[39m][\u001b[33m'\u001b[39m\u001b[33mfold_0\u001b[39m\u001b[33m'\u001b[39m][\u001b[33m'\u001b[39m\u001b[33mgradca_run\u001b[39m\u001b[33m'\u001b[39m][\u001b[33m'\u001b[39m\u001b[33mT_matrix\u001b[39m\u001b[33m'\u001b[39m]\n\u001b[32m     59\u001b[39m cheat_val = plot_importance_maps(T_example, title=\u001b[33m\"\u001b[39m\u001b[33mGradCA\u001b[39m\u001b[33m\"\u001b[39m, n_show=\u001b[32m8\u001b[39m)\n\u001b[32m     60\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mCheat index:\u001b[39m\u001b[33m\"\u001b[39m, cheat_val)\n",
      "\u001b[31mNameError\u001b[39m: name 'all_results' is not defined"
     ]
    }
   ],
   "source": [
    "# === Pixel-importance heatmaps + cheat index ===\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def _as_tensor(x):\n",
    "    return x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float32)\n",
    "\n",
    "def compute_pixel_importance(T_matrix, img_shape=(3, 32, 32)):\n",
    "    \"\"\"\n",
    "    T_matrix: (z_dim, 3072) mapping pixels -> z\n",
    "    Returns:\n",
    "      imp_maps: (z_dim, H, W) per-z heatmaps based on sum over RGB |weights|\n",
    "      cheat_index: scalar concentration score (higher = more 'cheaty'/sparse)\n",
    "    \"\"\"\n",
    "    T = _as_tensor(T_matrix).detach().cpu()\n",
    "    z_dim, d_ll = T.shape\n",
    "    assert d_ll == np.prod(img_shape), f\"Expected d_ll={np.prod(img_shape)}, got {d_ll}\"\n",
    "\n",
    "    C, H, W = img_shape\n",
    "    T_img = T.view(z_dim, C, H, W)\n",
    "    imp_maps = T_img.abs().sum(dim=1)  # (z_dim, H, W)\n",
    "\n",
    "    # cheat index: top-k / mean concentration averaged over z\n",
    "    flat = imp_maps.view(z_dim, -1)  # (z_dim, H*W)\n",
    "    k = min(50, flat.shape[1])\n",
    "    topk_mean = flat.topk(k, dim=1).values.mean(dim=1)\n",
    "    all_mean = flat.mean(dim=1) + 1e-12\n",
    "    cheat_per_z = (topk_mean / all_mean)\n",
    "    cheat_index = cheat_per_z.mean().item()\n",
    "\n",
    "    return imp_maps.numpy(), cheat_index, cheat_per_z.numpy()\n",
    "\n",
    "def plot_importance_maps(T_matrix, title=\"\", n_show=8):\n",
    "    \"\"\"\n",
    "    Draw the first `n_show` per-z pixel-importance heatmaps in a single row.\n",
    "\n",
    "    Every panel title shows that z's cheat score; the figure title shows the\n",
    "    averaged cheat index, which is also the return value.\n",
    "    \"\"\"\n",
    "    heatmaps, cheat_index, per_z_scores = compute_pixel_importance(T_matrix)\n",
    "    n_latents, _, _ = heatmaps.shape\n",
    "    n_panels = min(n_show, n_latents)\n",
    "\n",
    "    # squeeze=False always yields a 2D grid of axes, so one panel needs no special case\n",
    "    fig, axes_grid = plt.subplots(1, n_panels, figsize=(3*n_panels, 3), squeeze=False)\n",
    "\n",
    "    for idx, ax in enumerate(axes_grid[0]):\n",
    "        heat = heatmaps[idx]\n",
    "        peak = np.max(heat)\n",
    "        # an all-zero map still gets a valid colour range\n",
    "        ax.imshow(heat, cmap=\"magma\", vmin=0.0, vmax=peak if peak > 0 else 1.0)\n",
    "        ax.set_title(f\"z[{idx}]  cheat={per_z_scores[idx]:.1f}\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    plt.suptitle(f\"{title}\\nCheat index (top50/mean avg): {cheat_index:.2f}\", y=1.05)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return cheat_index\n",
    "\n",
    "# -------- Example usage --------\n",
    "# Heatmaps for the GradCA map stored under fold_0 / gradca_run.\n",
    "T_example = all_results[\"gradca\"][\"fold_0\"][\"gradca_run\"][\"T_matrix\"]\n",
    "cheat_val = plot_importance_maps(\n",
    "    T_matrix=T_example,\n",
    "    title=\"GradCA\",\n",
    "    n_show=8,\n",
    ")\n",
    "print(\"Cheat index:\", cheat_val)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5be5458a",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'U_ll_hat_opt' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[13]\u001b[39m\u001b[32m, line 139\u001b[39m\n\u001b[32m    135\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m pd.DataFrame(records)\n\u001b[32m    137\u001b[39m \u001b[38;5;66;03m# -------- Example usage --------\u001b[39;00m\n\u001b[32m    138\u001b[39m \u001b[38;5;66;03m# In this eval notebook, just use the opt-view noises:\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m139\u001b[39m U_ll_hat_run = \u001b[43mU_ll_hat_opt\u001b[49m\n\u001b[32m    140\u001b[39m U_hl_hat_run = U_hl_hat_opt\n\u001b[32m    142\u001b[39m sem_df = run_semantic_eval_all_methods(\n\u001b[32m    143\u001b[39m     all_results,\n\u001b[32m    144\u001b[39m     det_ll_dict_opt, det_hl_dict_opt,\n\u001b[32m    145\u001b[39m     U_ll_hat_run, U_hl_hat_run,\n\u001b[32m    146\u001b[39m     omega\n\u001b[32m    147\u001b[39m )\n",
      "\u001b[31mNameError\u001b[39m: name 'U_ll_hat_opt' is not defined"
     ]
    }
   ],
   "source": [
    "# === Semantic evaluation: interventional consistency (pixels -> z) ===\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "def _as_tensor_dev(x, device=None):\n",
    "    t = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float32)\n",
    "    return t.to(device) if device is not None else t\n",
    "\n",
    "@torch.no_grad()\n",
    "def semantic_errors_pixels_to_z(\n",
    "    T_matrix,\n",
    "    det_ll_dict_opt,   # each (N, 3072) deterministic pixels only\n",
    "    det_hl_dict_opt,   # each (N, 64) deterministic z only\n",
    "    U_ll_hat_opt,      # (N, 3072) LL noise (test-time, maybe contaminated)\n",
    "    U_hl_hat_opt,      # (N, 64)   HL noise (usually clean)\n",
    "    omega,\n",
    "    test_idx=None,     # optional index list/tensor for test split\n",
    "    device=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Computes:\n",
    "      e_{iota,eta}(T) = E || T (D_ll[iota]+U_ll) - (D_hl[eta]+U_hl) ||^2\n",
    "    over all (iota, eta) in omega.\n",
    "\n",
    "    Returns dict with per-pair errors plus:\n",
    "      mean_iv_error, max_iv_error, var_iv_error\n",
    "    \"\"\"\n",
    "    device = device or torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    T = _as_tensor_dev(T_matrix, device=device)\n",
    "    U_ll = _as_tensor_dev(U_ll_hat_opt, device=device)\n",
    "    U_hl = _as_tensor_dev(U_hl_hat_opt, device=device)\n",
    "\n",
    "    if test_idx is not None:\n",
    "        test_idx = torch.as_tensor(test_idx, dtype=torch.long, device=device)\n",
    "        U_ll = U_ll.index_select(0, test_idx)\n",
    "        U_hl = U_hl.index_select(0, test_idx)\n",
    "\n",
    "    pair_errors = {}\n",
    "\n",
    "    for iota, eta in omega.items():\n",
    "        if iota not in det_ll_dict_opt or eta not in det_hl_dict_opt:\n",
    "            continue\n",
    "\n",
    "        Dll = _as_tensor_dev(det_ll_dict_opt[iota], device=device)\n",
    "        Dhl = _as_tensor_dev(det_hl_dict_opt[eta], device=device)\n",
    "\n",
    "        if test_idx is not None:\n",
    "            Dll = Dll.index_select(0, test_idx)\n",
    "            Dhl = Dhl.index_select(0, test_idx)\n",
    "\n",
    "        # endo LL/HL (pixels->z setting)\n",
    "        X_ll = Dll + U_ll           # (N_test, 3072)\n",
    "        Z_hl = Dhl + U_hl           # (N_test, 64)\n",
    "\n",
    "        Z_pred = X_ll @ T.T         # (N_test, 64)\n",
    "        diff = Z_pred - Z_hl\n",
    "        err = (diff.norm(p=\"fro\")**2 / max(1, diff.shape[0])).item()\n",
    "\n",
    "        pair_errors[(iota, eta)] = err\n",
    "\n",
    "    errs = np.array(list(pair_errors.values()), dtype=np.float32)\n",
    "    out = {\n",
    "        \"pair_errors\": pair_errors,\n",
    "        \"mean_iv_error\": float(np.mean(errs)) if len(errs) else np.nan,\n",
    "        \"max_iv_error\": float(np.max(errs)) if len(errs) else np.nan,\n",
    "        \"var_iv_error\": float(np.var(errs)) if len(errs) else np.nan,\n",
    "    }\n",
    "    return out\n",
    "\n",
    "\n",
    "def run_semantic_eval_all_methods(\n",
    "    all_results,\n",
    "    det_ll_dict_opt, det_hl_dict_opt,\n",
    "    U_ll_hat_opt, U_hl_hat_opt,\n",
    "    omega\n",
    "):\n",
    "    \"\"\"\n",
    "    Drop-in evaluator over your all_results structure.\n",
    "    Returns a DataFrame with mean/max/var interventional errors per run.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    for method_group_key, folds in all_results.items():\n",
    "        if not isinstance(folds, dict):\n",
    "            continue\n",
    "\n",
    "        for fold_key, fold_data in folds.items():\n",
    "            if not fold_key.startswith(\"fold_\"):\n",
    "                continue\n",
    "            if not isinstance(fold_data, dict) or \"error\" in fold_data:\n",
    "                continue\n",
    "            fold_idx = int(fold_key.split(\"_\")[-1])\n",
    "\n",
    "            for run_key, run_result in fold_data.items():\n",
    "                if not isinstance(run_result, dict) or \"error\" in run_result:\n",
    "                    continue\n",
    "                if run_result.get(\"T_matrix\") is None:\n",
    "                    continue\n",
    "\n",
    "                T = run_result[\"T_matrix\"]\n",
    "                test_idx = run_result.get(\"test_indices\", None)\n",
    "\n",
    "                # match your naming conventions\n",
    "                if method_group_key.startswith(\"diroca_\"):\n",
    "                    method_name = f\"DiRoCA ({run_key})\"\n",
    "                elif method_group_key == \"gradca\":\n",
    "                    method_name = \"GradCA\"\n",
    "                elif method_group_key == \"baryca\":\n",
    "                    method_name = \"BaryCA\"\n",
    "                elif method_group_key == \"abslingam\":\n",
    "                    method_name = run_key\n",
    "                else:\n",
    "                    method_name = f\"{method_group_key}_{run_key}\"\n",
    "\n",
    "                sem = semantic_errors_pixels_to_z(\n",
    "                    T, det_ll_dict_opt, det_hl_dict_opt,\n",
    "                    U_ll_hat_opt, U_hl_hat_opt,\n",
    "                    omega, test_idx=test_idx, device=device\n",
    "                )\n",
    "\n",
    "                records.append({\n",
    "                    \"method_group\": method_group_key,\n",
    "                    \"run_key\": run_key,\n",
    "                    \"fold\": fold_idx,\n",
    "                    \"method\": method_name,\n",
    "                    \"mean_iv_error\": sem[\"mean_iv_error\"],\n",
    "                    \"max_iv_error\": sem[\"max_iv_error\"],\n",
    "                    \"var_iv_error\": sem[\"var_iv_error\"],\n",
    "                })\n",
    "\n",
    "    return pd.DataFrame(records)\n",
    "\n",
    "# -------- Example usage --------\n",
    "# This eval notebook only has the opt-view noises, so those are the ones scored.\n",
    "U_ll_hat_run = U_ll_hat_opt\n",
    "U_hl_hat_run = U_hl_hat_opt\n",
    "\n",
    "sem_df = run_semantic_eval_all_methods(\n",
    "    all_results=all_results,\n",
    "    det_ll_dict_opt=det_ll_dict_opt,\n",
    "    det_hl_dict_opt=det_hl_dict_opt,\n",
    "    U_ll_hat_opt=U_ll_hat_run,\n",
    "    U_hl_hat_opt=U_hl_hat_run,\n",
    "    omega=omega,\n",
    ")\n",
    "\n",
    "# 20 best runs by mean interventional error, then per-method mean/std of each statistic\n",
    "print(sem_df.sort_values(\"mean_iv_error\").head(20))\n",
    "print(sem_df.groupby(\"method\")[[\"mean_iv_error\",\"max_iv_error\",\"var_iv_error\"]].agg([\"mean\",\"std\"]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cb0a27d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "erica",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
