{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "866a36e2",
   "metadata": {},
   "source": [
    "## Training ResNet-34 on CIFAR-100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af1a63ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.models import resnet34\n",
    "import os\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Config\n",
    "EPOCHS = 200\n",
    "BATCH_SIZE = 128\n",
    "LR = 0.1\n",
    "WEIGHT_DECAY = 5e-4\n",
    "PRINT_EVERY = 20\n",
    "SEED = 42\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "DATA_DIR = '../data'\n",
    "SAVE_DIR = os.path.join(DATA_DIR, 'checkpoints_cifar100')\n",
    "\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "\n",
    "# Seed torch's CPU and CUDA generators so weight init, data shuffling and augmentation\n",
    "# are repeatable across runs (cuDNN kernel selection can still add small nondeterminism).\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Per-channel normalization constants shared by the train and test pipelines\n",
    "CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)\n",
    "CIFAR100_STD = (0.2675, 0.2565, 0.2761)\n",
    "\n",
    "# Data augmentation (training only). torchvision has no CIFAR-100 AutoAugment policy,\n",
    "# so the CIFAR-10 policy is reused.\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),\n",
    "])\n",
    "\n",
    "test_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),\n",
    "])\n",
    "\n",
    "# Data loaders\n",
    "CIFAR100_ROOT = os.path.join(DATA_DIR, 'cifar100')\n",
    "train_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=True, download=True, transform=train_transform)\n",
    "test_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=False, download=True, transform=test_transform)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)\n",
    "\n",
    "# Model - ResNet34 adapted to 32x32 CIFAR inputs\n",
    "def resnet34_cifar100():\n",
    "    \"\"\"Return an untrained torchvision ResNet-34 reshaped for 32x32, 100-class inputs.\n",
    "\n",
    "    The ImageNet stem conv is replaced by a 3x3 / stride-1 convolution and the max-pool\n",
    "    is dropped, so the stem keeps the full 32x32 resolution; the head emits 100 logits.\n",
    "    \"\"\"\n",
    "    net = resnet34(weights=None)\n",
    "    # Stem: gentle 3x3 stride-1 conv instead of aggressive downsampling\n",
    "    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "    net.maxpool = nn.Identity()\n",
    "    # Head: 512-d pooled features -> 100 class logits\n",
    "    net.fc = nn.Linear(512, 100)\n",
    "    return net\n",
    "\n",
    "# Instantiate the network directly on the target device\n",
    "model = resnet34_cifar100().to(DEVICE)\n",
    "\n",
    "# Training setup: label-smoothed cross-entropy, momentum SGD with weight decay, and a\n",
    "# cosine-annealed LR with T_max=EPOCHS (the loop below steps it once per epoch)\n",
    "criterion = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
    "optimizer = optim.SGD(\n",
    "    model.parameters(),\n",
    "    lr=LR,\n",
    "    momentum=0.9,\n",
    "    weight_decay=WEIGHT_DECAY,\n",
    ")\n",
    "scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)\n",
    "\n",
    "# Per-epoch metric history, one list per tracked quantity\n",
    "history = {key: [] for key in ('train_loss', 'train_acc', 'test_loss', 'test_acc', 'lr')}\n",
    "\n",
    "def train_epoch(model, loader, criterion, optimizer):\n",
    "    \"\"\"Run one optimization pass over `loader`; return (mean loss, top-1 accuracy in %).\n",
    "\n",
    "    Both metrics are averaged over every sample seen this epoch, measured on the batches\n",
    "    exactly as the loader yields them (i.e. with whatever augmentation it applies).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    loss_sum, n_correct, n_seen = 0, 0, 0\n",
    "    for batch_x, batch_y in loader:\n",
    "        batch_x = batch_x.to(DEVICE)\n",
    "        batch_y = batch_y.to(DEVICE)\n",
    "\n",
    "        # clear grads -> forward -> loss -> backward -> parameter update\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(batch_x)\n",
    "        batch_loss = criterion(logits, batch_y)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # assumes criterion averages over the batch; rescale to a per-sample sum\n",
    "        loss_sum += batch_loss.item() * batch_x.size(0)\n",
    "        _, predicted = logits.max(1)\n",
    "        n_correct += predicted.eq(batch_y).sum().item()\n",
    "        n_seen += batch_y.size(0)\n",
    "\n",
    "    return loss_sum / n_seen, 100. * n_correct / n_seen\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate(model, loader, criterion):\n",
    "    \"\"\"Score `model` on `loader` without updating it; return (mean loss, top-1 accuracy in %).\"\"\"\n",
    "    model.eval()\n",
    "    loss_sum, n_correct, n_seen = 0, 0, 0\n",
    "    for batch_x, batch_y in loader:\n",
    "        batch_x = batch_x.to(DEVICE)\n",
    "        batch_y = batch_y.to(DEVICE)\n",
    "        logits = model(batch_x)\n",
    "        # assumes criterion averages over the batch; rescale to a per-sample sum\n",
    "        loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)\n",
    "        _, predicted = logits.max(1)\n",
    "        n_correct += predicted.eq(batch_y).sum().item()\n",
    "        n_seen += batch_y.size(0)\n",
    "    return loss_sum / n_seen, 100. * n_correct / n_seen\n",
    "\n",
    "# Main training\n",
    "HISTORY_PATH = os.path.join(SAVE_DIR, 'history.json')\n",
    "\n",
    "\n",
    "def save_history(history, path):\n",
    "    \"\"\"Write the metrics dict to `path` as JSON.\n",
    "\n",
    "    Writes a sibling temp file and renames it into place, so an interrupted run never\n",
    "    leaves a truncated history.json behind.\n",
    "    \"\"\"\n",
    "    tmp_path = path + '.tmp'\n",
    "    with open(tmp_path, 'w') as f:\n",
    "        json.dump(history, f)\n",
    "    os.replace(tmp_path, path)\n",
    "\n",
    "\n",
    "best_acc = 0\n",
    "for epoch in range(EPOCHS):\n",
    "    epoch_num = epoch + 1\n",
    "    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)\n",
    "    test_loss, test_acc = evaluate(model, test_loader, criterion)\n",
    "\n",
    "    # Record the LR used for this epoch before advancing the per-epoch schedule\n",
    "    current_lr = scheduler.get_last_lr()[0]\n",
    "    scheduler.step()\n",
    "\n",
    "    # Log, and persist after every epoch so the metrics survive an interrupted run\n",
    "    history['train_loss'].append(train_loss)\n",
    "    history['train_acc'].append(train_acc)\n",
    "    history['test_loss'].append(test_loss)\n",
    "    history['test_acc'].append(test_acc)\n",
    "    history['lr'].append(current_lr)\n",
    "    save_history(history, HISTORY_PATH)\n",
    "\n",
    "    # Print on the first epoch and every PRINT_EVERY epochs\n",
    "    verbose = epoch == 0 or epoch_num % PRINT_EVERY == 0\n",
    "    if verbose:\n",
    "        print(f'Epoch {epoch_num}/{EPOCHS} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | LR: {current_lr:.6f}')\n",
    "\n",
    "    # Save checkpoint every epoch (the learning-dynamics experiment reads all of them).\n",
    "    # The scheduler state is included so a run can be resumed with the correct LR.\n",
    "    checkpoint = {\n",
    "        'epoch': epoch_num,\n",
    "        'model_state_dict': model.state_dict(),\n",
    "        'optimizer_state_dict': optimizer.state_dict(),\n",
    "        'scheduler_state_dict': scheduler.state_dict(),\n",
    "        'train_acc': train_acc,\n",
    "        'test_acc': test_acc,\n",
    "    }\n",
    "    torch.save(checkpoint, os.path.join(SAVE_DIR, f'checkpoint_epoch_{epoch_num:03d}.pt'))\n",
    "\n",
    "    # Save best model. NOTE: chosen on the test set (there is no validation split),\n",
    "    # so best_acc is an optimistic estimate of generalization.\n",
    "    if test_acc > best_acc:\n",
    "        best_acc = test_acc\n",
    "        torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'best_model.pt'))\n",
    "        if verbose:\n",
    "            print(f'New best accuracy: {best_acc:.2f}%')\n",
    "\n",
    "# Final write of the training history (also covers EPOCHS == 0)\n",
    "save_history(history, HISTORY_PATH)\n",
    "\n",
    "print(f'\\nTraining complete. Best test accuracy: {best_acc:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae0a3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project-local helpers in ../src importable; the guard keeps repeated runs\n",
    "# of this cell from appending duplicate sys.path entries.\n",
    "if '../src' not in sys.path:\n",
    "    sys.path.append('../src')\n",
    "\n",
    "from data_utils import get_tree_and_V\n",
    "\n",
    "tree, root, V = get_tree_and_V('cifar100', '../data/cifar100')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f745f96a",
   "metadata": {},
   "source": [
    "## Experiments: gradient entropy across CIFAR-100 tree depths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b26b2120",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.models import resnet34\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import gc\n",
    "\n",
    "# ============================================================\n",
    "# SETUP\n",
    "# ============================================================\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# Must match the training cell's DATA_DIR: the checkpoints and the CIFAR-100 download live there.\n",
    "DATA_DIR = '../data'\n",
    "CHECKPOINT_DIR = os.path.join(DATA_DIR, 'checkpoints_cifar100')\n",
    "\n",
    "# Load tree and V (project helper in ../src; the guard keeps re-runs from growing sys.path)\n",
    "import sys\n",
    "if '../src' not in sys.path:\n",
    "    sys.path.append('../src')\n",
    "from data_utils import get_tree_and_V\n",
    "\n",
    "tree, root, V = get_tree_and_V('cifar100', os.path.join(DATA_DIR, 'cifar100'))\n",
    "V_np = V if isinstance(V, np.ndarray) else V.numpy()\n",
    "# Rows of V are multiplied against the model's 100 class logits (grad_z @ V_tensor below).\n",
    "assert V_np.shape[0] == 100, f\"expected one row of V per CIFAR-100 class, got shape {V_np.shape}\"\n",
    "V_tensor = torch.tensor(V_np, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "# Assign each tree coordinate (column of V) a depth from its support size, i.e. how many\n",
    "# classes it touches: narrow columns sit inside one superclass, wider ones act at the root.\n",
    "# Covering every column (not a hard-coded 99) keeps the masks aligned with V's width.\n",
    "SUPPORT_TOL = 1e-10\n",
    "MAX_SUPERCLASS_SUPPORT = 5  # CIFAR-100 groups its 100 classes into 20 superclasses of 5\n",
    "support_sizes = np.sum(np.abs(V_np) > SUPPORT_TOL, axis=0)\n",
    "depth_by_col = np.where(support_sizes <= MAX_SUPERCLASS_SUPPORT, 1, 0)  # 1 = within superclass, 0 = root level\n",
    "\n",
    "print(f\"Depth 0 (root): {np.sum(depth_by_col == 0)} coords\")\n",
    "print(f\"Depth 1 (superclass): {np.sum(depth_by_col == 1)} coords\")\n",
    "\n",
    "# Precompute masks on device\n",
    "d0_mask = torch.tensor(depth_by_col == 0, device=DEVICE)\n",
    "d1_mask = torch.tensor(depth_by_col == 1, device=DEVICE)\n",
    "\n",
    "# ============================================================\n",
    "# MODEL\n",
    "# ============================================================\n",
    "def resnet34_cifar100():\n",
    "    \"\"\"Untrained ResNet-34 with a 3x3/stride-1 stem, no max-pool and a 100-way head.\n",
    "\n",
    "    Must stay identical to the training cell's definition so saved state_dicts load.\n",
    "    \"\"\"\n",
    "    net = resnet34(weights=None)\n",
    "    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "    net.maxpool = nn.Identity()\n",
    "    net.fc = nn.Linear(512, 100)\n",
    "    return net\n",
    "\n",
    "# ============================================================\n",
    "# DATA\n",
    "# ============================================================\n",
    "# Evaluation pipeline: no augmentation, only tensor conversion + per-channel normalization\n",
    "test_transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),\n",
    "    ]\n",
    ")\n",
    "test_dataset = datasets.CIFAR100(\n",
    "    root=os.path.join(DATA_DIR, 'cifar100'),\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=test_transform,\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=False,  # fixed sample order on every pass\n",
    "    num_workers=4,\n",
    "    pin_memory=True,\n",
    ")\n",
    "\n",
    "# ============================================================\n",
    "# COMPUTE GRADIENT ENTROPY (MEMORY EFFICIENT)\n",
    "# ============================================================\n",
    "def compute_gradient_entropy_with_V(model, loader, V_tensor, d0_mask, d1_mask, device):\n",
    "    \"\"\"\n",
    "    Entropy of how the logit-gradient norm splits between depth-0 and depth-1 tree coordinates.\n",
    "\n",
    "    For every sample, the cross-entropy gradient w.r.t. the logits, softmax(z) - onehot(y),\n",
    "    is projected onto the columns of V_tensor. The L2 norm of the coordinates selected by\n",
    "    d0_mask and by d1_mask is averaged over the loader; the two means are normalized into a\n",
    "    distribution p and its Shannon entropy (natural log) is returned.\n",
    "\n",
    "    Memory-efficient: batches run under no_grad and only scalar running sums are kept, so\n",
    "    memory does not grow with the number of samples.\n",
    "\n",
    "    Returns:\n",
    "        (entropy, p, (mean_norm_d0, mean_norm_d1))\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    norm_totals = [0.0, 0.0]  # running sums for (depth 0, depth 1)\n",
    "    seen = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "\n",
    "            probs = torch.softmax(model(x), dim=1)\n",
    "            target_onehot = torch.zeros_like(probs).scatter_(1, y.unsqueeze(1), 1)\n",
    "            coords = (probs - target_onehot) @ V_tensor\n",
    "\n",
    "            for slot, mask in enumerate((d0_mask, d1_mask)):\n",
    "                norm_totals[slot] += torch.norm(coords[:, mask], dim=1).sum().item()\n",
    "            seen += x.size(0)\n",
    "\n",
    "    mean_d0, mean_d1 = norm_totals[0] / seen, norm_totals[1] / seen\n",
    "    p = np.array([mean_d0, mean_d1]) / (mean_d0 + mean_d1)\n",
    "    entropy = -np.sum(p * np.log(p + 1e-10))  # 1e-10 keeps the log finite if a share is 0\n",
    "\n",
    "    return entropy, p, (mean_d0, mean_d1)\n",
    "\n",
    "# ============================================================\n",
    "# EXPERIMENT (i): TRUE TREE VS SHUFFLED TREES\n",
    "# ============================================================\n",
    "N_SHUFFLES = 1000\n",
    "SHUFFLE_SEED = 42  # seeds both the tree shuffles and the random-init control model\n",
    "FINAL_CHECKPOINT = 'checkpoint_epoch_200.pt'\n",
    "\n",
    "print(\"\\n\" + \"=\" * 50)\n",
    "print(\"EXPERIMENT (i): True tree vs shuffled trees\")\n",
    "print(f\"Number of shuffled trees: {N_SHUFFLES}\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def collect_logit_grads(model, loader, device):\n",
    "    \"\"\"Stack dL/dz = softmax(z) - onehot(y) for every sample in `loader` into one (N, C) tensor.\n",
    "\n",
    "    The model's outputs do not depend on the tree, so a single forward pass per model is\n",
    "    enough: the true tree and every shuffled tree are scored against this cached tensor\n",
    "    instead of re-running the network over the test set for each of the N_SHUFFLES trees.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    grads = []\n",
    "    for inputs, targets in loader:\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        probs = torch.softmax(model(inputs), dim=1)\n",
    "        one_hot = torch.zeros_like(probs).scatter_(1, targets.unsqueeze(1), 1)\n",
    "        grads.append(probs - one_hot)\n",
    "    return torch.cat(grads)\n",
    "\n",
    "\n",
    "def entropy_from_grads(grad_z, V, d0_mask, d1_mask):\n",
    "    \"\"\"Gradient entropy of cached logit gradients `grad_z` projected onto tree basis `V`.\n",
    "\n",
    "    Same statistic as compute_gradient_entropy_with_V; norms are averaged in float64.\n",
    "    Returns (entropy, p, (mean_norm_d0, mean_norm_d1)).\n",
    "    \"\"\"\n",
    "    grad_a = grad_z @ V\n",
    "    mean_norm_d0 = torch.norm(grad_a[:, d0_mask], dim=1).double().mean().item()\n",
    "    mean_norm_d1 = torch.norm(grad_a[:, d1_mask], dim=1).double().mean().item()\n",
    "    p = np.array([mean_norm_d0, mean_norm_d1]) / (mean_norm_d0 + mean_norm_d1)\n",
    "    entropy = -np.sum(p * np.log(p + 1e-10))\n",
    "    return entropy, p, (mean_norm_d0, mean_norm_d1)\n",
    "\n",
    "\n",
    "def shuffled_tree_entropies(grad_z, desc):\n",
    "    \"\"\"Entropies for N_SHUFFLES random relabelings of the classes (row permutations of V).\n",
    "\n",
    "    A fresh RandomState(SHUFFLE_SEED) per call means every model is compared against the\n",
    "    same sequence of shuffled trees (a paired comparison).\n",
    "    \"\"\"\n",
    "    rng = np.random.RandomState(SHUFFLE_SEED)\n",
    "    n_classes = V_tensor.shape[0]\n",
    "    entropies = np.empty(N_SHUFFLES)\n",
    "    for i in tqdm(range(N_SHUFFLES), desc=desc):\n",
    "        perm = torch.as_tensor(rng.permutation(n_classes), device=V_tensor.device)\n",
    "        entropies[i] = entropy_from_grads(grad_z, V_tensor[perm], d0_mask, d1_mask)[0]\n",
    "    return entropies\n",
    "\n",
    "\n",
    "def permutation_p_value(null_entropies, observed):\n",
    "    \"\"\"One-sided permutation p-value for 'the true tree's entropy is unusually low'.\n",
    "\n",
    "    Uses (b + 1) / (m + 1), counting the observed tree as one of the permutations, so the\n",
    "    result is never exactly 0 and is conservative at the test's 1 / (m + 1) resolution.\n",
    "    \"\"\"\n",
    "    n_as_low = np.sum(null_entropies <= observed)\n",
    "    return (n_as_low + 1) / (len(null_entropies) + 1)\n",
    "\n",
    "\n",
    "# ============================================================\n",
    "# RANDOM INIT MODEL (CONTROL)\n",
    "# ============================================================\n",
    "print(\"\\n--- Random Init Model (Control) ---\")\n",
    "\n",
    "torch.manual_seed(SHUFFLE_SEED)  # reproducible random-init weights for the control\n",
    "model_init = resnet34_cifar100().to(DEVICE)\n",
    "grad_z_init = collect_logit_grads(model_init, test_loader, DEVICE)\n",
    "del model_init\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()\n",
    "\n",
    "init_entropy, init_p, init_norms = entropy_from_grads(grad_z_init, V_tensor, d0_mask, d1_mask)\n",
    "print(f\"True tree: entropy = {init_entropy:.4f}, p = {init_p}\")\n",
    "\n",
    "print(f\"Running {N_SHUFFLES} shuffled trees...\")\n",
    "shuffled_entropies_init = shuffled_tree_entropies(grad_z_init, \"Shuffled trees (init)\")\n",
    "p_value_init = permutation_p_value(shuffled_entropies_init, init_entropy)\n",
    "\n",
    "print(f\"Shuffled: mean = {np.mean(shuffled_entropies_init):.4f} ± {np.std(shuffled_entropies_init):.4f}\")\n",
    "print(f\"p-value: {p_value_init:.4f}\")\n",
    "\n",
    "# ============================================================\n",
    "# TRAINED MODEL\n",
    "# ============================================================\n",
    "print(\"\\n--- Trained Model ---\")\n",
    "\n",
    "model = resnet34_cifar100().to(DEVICE)\n",
    "ckpt = torch.load(os.path.join(CHECKPOINT_DIR, FINAL_CHECKPOINT), map_location=DEVICE)\n",
    "model.load_state_dict(ckpt['model_state_dict'])\n",
    "print(f\"Loaded epoch {ckpt['epoch']}, test acc: {ckpt['test_acc']:.2f}%\")\n",
    "\n",
    "grad_z = collect_logit_grads(model, test_loader, DEVICE)\n",
    "true_entropy, true_p, true_norms = entropy_from_grads(grad_z, V_tensor, d0_mask, d1_mask)\n",
    "print(f\"True tree: entropy = {true_entropy:.4f}, p = {true_p}, norms = {true_norms}\")\n",
    "\n",
    "print(f\"Running {N_SHUFFLES} shuffled trees...\")\n",
    "shuffled_entropies = shuffled_tree_entropies(grad_z, \"Shuffled trees (trained)\")\n",
    "p_value = permutation_p_value(shuffled_entropies, true_entropy)\n",
    "\n",
    "print(f\"Shuffled: mean = {np.mean(shuffled_entropies):.4f} ± {np.std(shuffled_entropies):.4f}\")\n",
    "print(f\"True tree entropy: {true_entropy:.4f}\")\n",
    "print(f\"p-value: {p_value:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d21bc627",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_p_value(p_value, n_perm):\n",
    "    \"\"\"Format a permutation p-value for a plot title.\n",
    "\n",
    "    Values at or below the smallest resolvable p-value, 1 / (n_perm + 1), are shown as an\n",
    "    upper bound instead of a misleading 'p = 0'.\n",
    "    \"\"\"\n",
    "    floor = 1.0 / (n_perm + 1)\n",
    "    if p_value <= floor:\n",
    "        return f'p ≤ {floor:.2g}'\n",
    "    return f'p = {p_value:.2g}'\n",
    "\n",
    "\n",
    "def plot_null_distribution(ax, null_entropies, observed, title):\n",
    "    \"\"\"Histogram of shuffled-tree entropies with the true-tree entropy marked in crimson.\"\"\"\n",
    "    ax.hist(null_entropies, bins=50, alpha=0.7, color='steelblue', edgecolor='white', linewidth=0.5)\n",
    "    ax.axvline(observed, color='crimson', linewidth=2.5, label='True tree')\n",
    "    ax.set_xlabel('Gradient Entropy', fontsize=11)\n",
    "    ax.set_ylabel('Count', fontsize=11)\n",
    "    ax.set_title(title, fontsize=12, fontweight='bold')\n",
    "    ax.legend(fontsize=10)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.xaxis.set_major_locator(plt.MaxNLocator(5))\n",
    "    ax.ticklabel_format(useOffset=False, style='plain')\n",
    "\n",
    "\n",
    "# Titles are built from the p-values computed in the experiment cell rather than\n",
    "# hard-coded numbers, so the figure stays correct when the experiment is re-run.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "\n",
    "# Random init (LEFT)\n",
    "plot_null_distribution(axes[0], shuffled_entropies_init, init_entropy,\n",
    "                       f'Random Init ({format_p_value(p_value_init, N_SHUFFLES)})')\n",
    "\n",
    "# Trained model (RIGHT)\n",
    "plot_null_distribution(axes[1], shuffled_entropies, true_entropy,\n",
    "                       f'Trained Model ({format_p_value(p_value, N_SHUFFLES)})')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('experiment_i_entropy.png', dpi=150, bbox_inches='tight')\n",
    "plt.savefig('experiment_i_entropy.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f6ed75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.models import resnet34\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "\n",
    "# ============================================================\n",
    "# SETUP\n",
    "# ============================================================\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "DATA_DIR = '../data'\n",
    "CHECKPOINT_DIR = os.path.join(DATA_DIR, 'checkpoints_cifar100')\n",
    "\n",
    "# Load tree and V (project helper in ../src; the guard keeps re-runs from growing sys.path)\n",
    "import sys\n",
    "if '../src' not in sys.path:\n",
    "    sys.path.append('../src')\n",
    "from data_utils import get_tree_and_V\n",
    "\n",
    "tree, root, V = get_tree_and_V('cifar100', os.path.join(DATA_DIR, 'cifar100'))\n",
    "V_np = V if isinstance(V, np.ndarray) else V.numpy()\n",
    "# Rows of V are multiplied against the model's 100 class logits (grad_z @ V_tensor below).\n",
    "assert V_np.shape[0] == 100, f\"expected one row of V per CIFAR-100 class, got shape {V_np.shape}\"\n",
    "V_tensor = torch.tensor(V_np, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "# Assign each tree coordinate (column of V) a depth from its support size, i.e. how many\n",
    "# classes it touches: narrow columns sit inside one superclass, wider ones act at the root.\n",
    "# Covering every column (not a hard-coded 99) keeps the masks aligned with V's width.\n",
    "SUPPORT_TOL = 1e-10\n",
    "MAX_SUPERCLASS_SUPPORT = 5  # CIFAR-100 groups its 100 classes into 20 superclasses of 5\n",
    "support_sizes = np.sum(np.abs(V_np) > SUPPORT_TOL, axis=0)\n",
    "depth_by_col = np.where(support_sizes <= MAX_SUPERCLASS_SUPPORT, 1, 0)  # 1 = within superclass, 0 = root level\n",
    "\n",
    "# Precompute masks on device\n",
    "d0_mask = torch.tensor(depth_by_col == 0, device=DEVICE)\n",
    "d1_mask = torch.tensor(depth_by_col == 1, device=DEVICE)\n",
    "\n",
    "# ============================================================\n",
    "# MODEL\n",
    "# ============================================================\n",
    "def resnet34_cifar100():\n",
    "    \"\"\"Untrained ResNet-34 with a 3x3/stride-1 stem, no max-pool and a 100-way head.\n",
    "\n",
    "    Must stay identical to the training cell's definition so saved state_dicts load.\n",
    "    \"\"\"\n",
    "    net = resnet34(weights=None)\n",
    "    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "    net.maxpool = nn.Identity()\n",
    "    net.fc = nn.Linear(512, 100)\n",
    "    return net\n",
    "\n",
    "# ============================================================\n",
    "# DATA\n",
    "# ============================================================\n",
    "# Evaluation pipeline: no augmentation, only tensor conversion + per-channel normalization\n",
    "test_transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),\n",
    "    ]\n",
    ")\n",
    "test_dataset = datasets.CIFAR100(\n",
    "    root=os.path.join(DATA_DIR, 'cifar100'),\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=test_transform,\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=False,  # fixed sample order on every pass\n",
    "    num_workers=4,\n",
    "    pin_memory=True,\n",
    ")\n",
    "\n",
    "# ============================================================\n",
    "# COMPUTE GRADIENT STATS\n",
    "# ============================================================\n",
    "def compute_gradient_stats(model, loader, V_tensor, d0_mask, d1_mask, device):\n",
    "    \"\"\"\n",
    "    Per-depth logit-gradient norms for one model, plus their normalized split and its entropy.\n",
    "\n",
    "    Each sample's cross-entropy gradient w.r.t. the logits (softmax minus one-hot target)\n",
    "    is projected onto V_tensor; the L2 norm of the coordinates selected by d0_mask and by\n",
    "    d1_mask is averaged over every sample in `loader`.\n",
    "\n",
    "    Returns a dict with 'norm_d0' / 'norm_d1' (mean norms), 'frac_d0' / 'frac_d1' (their\n",
    "    shares of the sum) and 'entropy' (entropy of the two shares, natural log).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    norm_sum_d0, norm_sum_d1, seen = 0.0, 0.0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "\n",
    "            probs = torch.softmax(model(x), dim=1)\n",
    "            grad_logits = probs - torch.zeros_like(probs).scatter_(1, y.unsqueeze(1), 1)\n",
    "            coords = grad_logits @ V_tensor\n",
    "\n",
    "            norm_sum_d0 += coords[:, d0_mask].norm(dim=1).sum().item()\n",
    "            norm_sum_d1 += coords[:, d1_mask].norm(dim=1).sum().item()\n",
    "            seen += x.size(0)\n",
    "\n",
    "    norm_d0 = norm_sum_d0 / seen\n",
    "    norm_d1 = norm_sum_d1 / seen\n",
    "    denom = norm_d0 + norm_d1\n",
    "    frac_d0, frac_d1 = norm_d0 / denom, norm_d1 / denom\n",
    "    # 1e-10 keeps the log finite if one share is exactly 0\n",
    "    entropy = -frac_d0 * np.log(frac_d0 + 1e-10) - frac_d1 * np.log(frac_d1 + 1e-10)\n",
    "\n",
    "    return dict(norm_d0=norm_d0, norm_d1=norm_d1, frac_d0=frac_d0, frac_d1=frac_d1, entropy=entropy)\n",
    "\n",
    "# ============================================================\n",
    "# EXPERIMENT (ii): LEARNING DYNAMICS\n",
    "# ============================================================\n",
    "print(\"=\" * 50)\n",
    "print(\"EXPERIMENT (ii): Learning dynamics across training\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "# Get all per-epoch checkpoints written by the training cell\n",
    "checkpoints = sorted([f for f in os.listdir(CHECKPOINT_DIR) if f.startswith('checkpoint_epoch')])\n",
    "print(f\"Found {len(checkpoints)} checkpoints\")\n",
    "if not checkpoints:\n",
    "    raise FileNotFoundError(f\"No checkpoint_epoch_* files found in {CHECKPOINT_DIR}\")\n",
    "\n",
    "# Test accuracy is read from each checkpoint itself, so history.json is not needed here.\n",
    "epochs = []\n",
    "frac_d0_list = []\n",
    "frac_d1_list = []\n",
    "entropy_list = []\n",
    "test_acc_list = []\n",
    "\n",
    "model = resnet34_cifar100().to(DEVICE)\n",
    "\n",
    "for ckpt_name in tqdm(checkpoints, desc=\"Analyzing checkpoints\"):\n",
    "    ckpt_path = os.path.join(CHECKPOINT_DIR, ckpt_name)\n",
    "    # Load on CPU: load_state_dict copies the weights onto the model's device, and the\n",
    "    # optimizer state stored alongside them never has to occupy GPU memory.\n",
    "    ckpt = torch.load(ckpt_path, map_location='cpu')\n",
    "\n",
    "    model.load_state_dict(ckpt['model_state_dict'])\n",
    "    epoch = ckpt['epoch']\n",
    "\n",
    "    stats = compute_gradient_stats(model, test_loader, V_tensor, d0_mask, d1_mask, DEVICE)\n",
    "\n",
    "    epochs.append(epoch)\n",
    "    frac_d0_list.append(stats['frac_d0'])\n",
    "    frac_d1_list.append(stats['frac_d1'])\n",
    "    entropy_list.append(stats['entropy'])\n",
    "    test_acc_list.append(ckpt['test_acc'])\n",
    "\n",
    "# Order by the epoch stored inside each checkpoint, so the series is chronological even\n",
    "# if file names stop sorting lexically (e.g. past epoch 999 with 3-digit padding).\n",
    "order = np.argsort(epochs, kind='stable')\n",
    "epochs = np.array(epochs)[order]\n",
    "frac_d0_list = np.array(frac_d0_list)[order]\n",
    "frac_d1_list = np.array(frac_d1_list)[order]\n",
    "entropy_list = np.array(entropy_list)[order]\n",
    "test_acc_list = np.array(test_acc_list)[order]\n",
    "\n",
    "print(f\"\\nEpoch {epochs[0]}: frac_d0 = {frac_d0_list[0]:.3f}, frac_d1 = {frac_d1_list[0]:.3f}\")\n",
    "print(f\"Epoch {epochs[-1]}: frac_d0 = {frac_d0_list[-1]:.3f}, frac_d1 = {frac_d1_list[-1]:.3f}\")\n",
    "\n",
    "# ============================================================\n",
    "# PLOT\n",
    "# ============================================================\n",
    "fig, axes = plt.subplots(1, 2, figsize=(11, 4))\n",
    "# x-range spans the epochs actually analyzed (widened by 1 if only one checkpoint exists)\n",
    "x_min = epochs.min()\n",
    "x_max = max(epochs.max(), x_min + 1)\n",
    "\n",
    "# Left: Stacked area plot of gradient fraction by depth (depth-1 band fills the rest up to 1)\n",
    "ax = axes[0]\n",
    "ax.fill_between(epochs, 0, frac_d0_list, alpha=0.8, label='Depth 0 (superclass)', color='#2ecc71')\n",
    "ax.fill_between(epochs, frac_d0_list, 1, alpha=0.8, label='Depth 1 (fine class)', color='#3498db')\n",
    "ax.set_xlabel('Epoch', fontsize=11)\n",
    "ax.set_ylabel('Fraction of Gradient Norm', fontsize=11)\n",
    "ax.set_title('Gradient Distribution by Tree Depth', fontsize=12, fontweight='bold')\n",
    "ax.set_xlim(x_min, x_max)\n",
    "ax.set_ylim(0, 1)\n",
    "ax.legend(loc='right', fontsize=10)\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Right: Entropy and accuracy over training (twin y-axes)\n",
    "ax1 = axes[1]\n",
    "color1 = '#e74c3c'\n",
    "ax1.plot(epochs, entropy_list, color=color1, linewidth=2, label='Gradient Entropy')\n",
    "ax1.set_xlabel('Epoch', fontsize=11)\n",
    "ax1.set_ylabel('Gradient Entropy', fontsize=11, color=color1)\n",
    "ax1.tick_params(axis='y', labelcolor=color1)\n",
    "ax1.set_xlim(x_min, x_max)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "color2 = '#9b59b6'\n",
    "ax2.plot(epochs, test_acc_list, color=color2, linewidth=2, linestyle='--', label='Test Accuracy')\n",
    "ax2.set_ylabel('Test Accuracy (%)', fontsize=11, color=color2)\n",
    "ax2.tick_params(axis='y', labelcolor=color2)\n",
    "\n",
    "ax1.set_title('Entropy & Accuracy During Training', fontsize=12, fontweight='bold')\n",
    "ax1.spines['top'].set_visible(False)\n",
    "\n",
    "# Combined legend (twinx axes keep separate handle lists)\n",
    "lines1, labels1 = ax1.get_legend_handles_labels()\n",
    "lines2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax1.legend(lines1 + lines2, labels1 + labels2, loc='center right', fontsize=10)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('experiment_ii_dynamics.png', dpi=150, bbox_inches='tight')\n",
    "plt.savefig('experiment_ii_dynamics.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(\"\\nSaved: experiment_ii_dynamics.png, experiment_ii_dynamics.pdf\")\n",
    "\n",
    "# Persist the per-checkpoint series (converted to plain lists so json can serialize them)\n",
    "results_ii = dict(\n",
    "    epochs=epochs.tolist(),\n",
    "    frac_d0=frac_d0_list.tolist(),\n",
    "    frac_d1=frac_d1_list.tolist(),\n",
    "    entropy=entropy_list.tolist(),\n",
    "    test_acc=test_acc_list.tolist(),\n",
    ")\n",
    "with open('experiment_ii_results.json', 'w') as out_file:\n",
    "    json.dump(results_ii, out_file, indent=2)\n",
    "print(\"Saved: experiment_ii_results.json\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
