{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "44e9ab55",
   "metadata": {},
   "source": [
    "## DOC (Direct Output Control) sample code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d78cc5c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================\n",
    "# 1. Environment Setup\n",
    "# ================================\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# pandas, matplotlib, seaborn and OpenCV (cv2) are imported by this notebook too,\n",
    "# so they are installed here alongside the modelling packages.\n",
    "%pip install -q torch torchvision timm torchdiffeq lpips pandas matplotlib seaborn opencv-python\n",
    "\n",
    "import torch, torch.nn as nn, torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import torchvision.transforms as T\n",
    "import torchvision.datasets as datasets\n",
    "import timm, time, warnings\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from torchdiffeq import odeint\n",
    "import lpips\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3fca8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device once; running on a GPU is strongly recommended.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ff8a6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================\n",
    "# 2. CIFAR-100 Dataset\n",
    "# ================================\n",
    "# Preprocessing shared by both splits.\n",
    "transform = T.Compose(\n",
    "    [\n",
    "        T.Resize(224),  # resize so the ViT receives 224-pixel inputs\n",
    "        T.ToTensor(),\n",
    "        T.Normalize(\n",
    "            mean=[0.5071, 0.4865, 0.4409],\n",
    "            std=[0.2673, 0.2564, 0.2762],\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Train and test splits, stored under ./data (download=True is requested for both).\n",
    "trainset = datasets.CIFAR100(\n",
    "    root=\"./data\", train=True, download=True, transform=transform\n",
    ")\n",
    "testset = datasets.CIFAR100(\n",
    "    root=\"./data\", train=False, download=True, transform=transform\n",
    ")\n",
    "\n",
    "# Shuffled batches of 512 for fine-tuning; one image per batch, in order, for evaluation.\n",
    "trainloader = DataLoader(trainset, batch_size=512, shuffle=True, num_workers=4)\n",
    "testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89b1ccb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================\n",
    "# 3. Model (ViT Base)\n",
    "# ================================\n",
    "# Pretrained ViT-B/16 from timm with a fresh 100-way classification head.\n",
    "model = timm.create_model(\"vit_base_patch16_224\", pretrained=True)\n",
    "model.head = nn.Linear(model.head.in_features, 100)  # Adjust for CIFAR-100\n",
    "model = model.to(device)\n",
    "\n",
    "# Count GPUs\n",
    "n_gpus = torch.cuda.device_count()\n",
    "print(f\"Detected GPUs: {n_gpus}\")\n",
    "\n",
    "# The model already lives on `device`, so wrapping it needs no second transfer.\n",
    "if n_gpus > 1:\n",
    "    print(\"Using DataParallel across\", n_gpus, \"GPUs\")\n",
    "    model = nn.DataParallel(model, device_ids=list(range(n_gpus)))\n",
    "\n",
    "# --- Light fine-tuning (few epochs just for demo)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "for epoch in range(2):  # For real use, train longer\n",
    "    running_loss, n_batches = 0.0, 0\n",
    "    for imgs, labels in trainloader:\n",
    "        imgs, labels = imgs.to(device), labels.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(imgs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        n_batches += 1\n",
    "    # Report the epoch mean: the last mini-batch alone is a noisy estimate, and\n",
    "    # this also stays defined when the loader yields no batches.\n",
    "    print(f\"Epoch {epoch}: loss={running_loss / max(n_batches, 1):.4f}\")\n",
    "\n",
    "# Switch to inference mode for the repair experiments. The trailing semicolon\n",
    "# keeps the notebook from printing the full module repr as the cell output.\n",
    "model.eval();\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "828bde01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================\n",
    "# 4. DOC/GCRC Core Functions\n",
    "# ================================\n",
    "def error_map(y, y_star, mode=\"fisher\", eps=1e-12):\n",
    "    \"\"\"Error vector pointing from the distribution ``y`` toward ``y_star``.\n",
    "\n",
    "    Both inputs are first clamped from below by ``eps``.\n",
    "\n",
    "    - ``mode=\"l2\"``: the raw direction is the difference ``y_star - y``.\n",
    "    - ``mode=\"fisher\"``: the raw direction is built from the square roots of\n",
    "      both vectors and the angle between them (arccos of their inner product).\n",
    "\n",
    "    In either mode the raw direction ``v`` is returned as ``v - v.sum() * y``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mode`` is neither ``\"fisher\"`` nor ``\"l2\"``.\n",
    "    \"\"\"\n",
    "    p = y.clamp_min(eps)\n",
    "    q = y_star.clamp_min(eps)\n",
    "\n",
    "    if mode == \"l2\":\n",
    "        diff = q - p\n",
    "        return diff - diff.sum() * p\n",
    "\n",
    "    if mode != \"fisher\":\n",
    "        raise ValueError(\"Unsupported mode\")\n",
    "\n",
    "    root_p = p.sqrt()\n",
    "    root_q = q.sqrt()\n",
    "    # Inner product of the square-root vectors, clamped slightly inside [-1, 1]\n",
    "    # before it is passed to arccos.\n",
    "    cos_angle = (root_p * root_q).sum().clamp(-1 + 1e-7, 1 - 1e-7)\n",
    "    angle = torch.arccos(cos_angle)\n",
    "    sin_angle = torch.sin(angle) + 1e-12  # small offset guards the division below\n",
    "    tangent = (2 * angle / sin_angle) * (root_q - cos_angle * root_p)\n",
    "    direction = 2 * root_p * tangent\n",
    "    return direction - direction.sum() * p\n",
    "\n",
    "def flatten_state(x):\n",
    "    \"\"\"Return ``x`` collapsed into a one-dimensional tensor.\"\"\"\n",
    "    flat = x.flatten()\n",
    "    return flat\n",
    "def unflatten_state(flat, shape):\n",
    "    \"\"\"Return ``flat`` viewed with the given ``shape`` (via ``Tensor.view``).\"\"\"\n",
    "    restored = flat.view(shape)\n",
    "    return restored\n",
    "\n",
    "def pullback_rhs_mode(t, flat, model, x_shape, y_star,\n",
    "                      lam=1e-2, mode=\"fisher\", k=10):\n",
    "    \"\"\"Right-hand side of the repair ODE: an output-space error pulled back to input space.\n",
    "\n",
    "    The error from ``error_map`` is restricted to the ``k`` largest-logit classes\n",
    "    (plus the target class) and mapped to input space with a damped solve,\n",
    "    ``dx = J^T (J J^T + lam I)^{-1} v``, where ``J`` is the Jacobian of the selected\n",
    "    softmax outputs with respect to the input.\n",
    "\n",
    "    Args:\n",
    "        t: Integration time; unused, present because the ODE solver passes it.\n",
    "        flat: Flattened current input.\n",
    "        model: Classifier, called on a batch holding one input of shape ``x_shape``.\n",
    "        x_shape: Shape of a single (unbatched) input.\n",
    "        y_star: Target probability vector; its argmax is used as the target class.\n",
    "        lam: Damping added to the diagonal of ``J J^T`` (floored at 1e-6).\n",
    "        mode: Error mode forwarded to ``error_map``.\n",
    "        k: Number of top logits to keep.\n",
    "\n",
    "    Returns:\n",
    "        Flattened update direction, detached from autograd.\n",
    "    \"\"\"\n",
    "    x = unflatten_state(flat, x_shape).detach()\n",
    "\n",
    "    # The error vector and the class selection are only needed as values, so no\n",
    "    # graph is built here; the Jacobian below runs its own forward pass.\n",
    "    with torch.no_grad():\n",
    "        logits = model(x.unsqueeze(0)).squeeze(0)\n",
    "        y = torch.softmax(logits, dim=0)\n",
    "        v = error_map(y, y_star, mode=mode)\n",
    "        topk = logits.topk(k=min(k, logits.numel())).indices\n",
    "        tgt = y_star.argmax()\n",
    "        # Always keep the target class in the working set.\n",
    "        if tgt not in topk:\n",
    "            topk = torch.unique(torch.cat([topk, tgt.unsqueeze(0)]))\n",
    "        vA = v[topk]\n",
    "\n",
    "    def f_subset(z):\n",
    "        return torch.softmax(model(z.view(1, *x_shape)).squeeze(0), dim=0)[topk]\n",
    "\n",
    "    J = torch.autograd.functional.jacobian(f_subset, x).view(len(topk), -1)\n",
    "    JJt = J @ J.T\n",
    "    lam_eff = max(lam, 1e-6)\n",
    "    # Build the damping term on J's own device/dtype instead of a notebook global.\n",
    "    A = JJt + lam_eff * torch.eye(JJt.size(0), device=J.device, dtype=J.dtype)\n",
    "    dx = J.T @ torch.linalg.solve(A, vA.unsqueeze(1)).squeeze(1)\n",
    "    # Detach so the ODE solver does not keep any graph alive across steps.\n",
    "    return dx.flatten().detach()\n",
    "\n",
    "def pullback_mode(model, x0, y_star, steps=8, lam=1e-2, mode=\"fisher\", k=10):\n",
    "    \"\"\"Repair ``x0`` by integrating the pullback ODE from t=0 to t=1 with Euler steps.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier passed through to ``pullback_rhs_mode``.\n",
    "        x0: Single (unbatched) input to repair.\n",
    "        y_star: Target probability vector.\n",
    "        steps: Number of Euler steps; the time grid has ``steps + 1`` points.\n",
    "        lam, mode, k: Forwarded to ``pullback_rhs_mode``.\n",
    "\n",
    "    Returns:\n",
    "        Detached tensor with the shape of ``x0`` holding the state at the final time.\n",
    "    \"\"\"\n",
    "    # Put the time grid on the input's device instead of relying on a notebook global.\n",
    "    t = torch.linspace(0, 1, steps + 1, device=x0.device)\n",
    "    flat0 = flatten_state(x0)\n",
    "\n",
    "    def rhs(t_cur, flat):\n",
    "        return pullback_rhs_mode(t_cur, flat, model, x0.shape, y_star,\n",
    "                                 lam=lam, mode=mode, k=k)\n",
    "\n",
    "    sol = odeint(rhs, flat0, t, method=\"euler\")\n",
    "    return unflatten_state(sol[-1], x0.shape).detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54639fdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================\n",
    "# 5. Evaluation\n",
    "# ================================\n",
    "# LPIPS metric with the 'alex' network, moved to the compute device.\n",
    "lpips_fn = lpips.LPIPS(net='alex')\n",
    "lpips_fn = lpips_fn.to(device)\n",
    "\n",
    "def evaluate_pullback_mode(model, dataloader, max_samples=20,\n",
    "                           steps=8, lam=1e-2, mode=\"fisher\", k=10,\n",
    "                           mean=(0.5071, 0.4865, 0.4409),\n",
    "                           std=(0.2673, 0.2564, 0.2762)):\n",
    "    \"\"\"Repair misclassified samples and tabulate success, distortion and runtime.\n",
    "\n",
    "    Up to ``max_samples`` misclassified examples are collected from ``dataloader``\n",
    "    (batches of any size), each one is repaired with ``pullback_mode`` toward its\n",
    "    true label, and one record per sample is returned.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier returning logits of shape (batch, classes).\n",
    "        dataloader: Iterable of ``(images, labels)`` batches.\n",
    "        max_samples: Maximum number of misclassified samples to repair.\n",
    "        steps, lam, mode, k: Forwarded to ``pullback_mode``.\n",
    "        mean, std: Per-channel statistics the images were normalized with. The\n",
    "            defaults are the values of the ``T.Normalize`` call used for this\n",
    "            notebook's dataset. They are used to map images back to [0, 1]\n",
    "            before LPIPS, which expects inputs scaled to [-1, 1].\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame with columns ``idx``, ``success`` (1 if the repaired input\n",
    "        is classified as the true label), ``l2`` (RMS difference in the normalized\n",
    "        input space), ``lpips`` and ``overhead`` (repair wall time in seconds).\n",
    "    \"\"\"\n",
    "    def _to_lpips_range(x):\n",
    "        # Undo the dataset normalization (-> [0, 1]), rescale to [-1, 1] and add\n",
    "        # the batch dimension.\n",
    "        m = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)\n",
    "        s = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)\n",
    "        return ((x * s + m).clamp(0, 1) * 2 - 1).unsqueeze(0)\n",
    "\n",
    "    mis = []\n",
    "    with torch.no_grad():\n",
    "        for imgs, lbls in dataloader:\n",
    "            imgs, lbls = imgs.to(device), lbls.to(device)\n",
    "            preds = model(imgs).argmax(dim=1)\n",
    "            # Walk over every wrong prediction in the batch (works for any batch size).\n",
    "            for i in (preds != lbls).nonzero(as_tuple=True)[0].tolist():\n",
    "                if len(mis) >= max_samples:\n",
    "                    break\n",
    "                mis.append((imgs[i], lbls[i]))\n",
    "            if len(mis) >= max_samples:\n",
    "                break\n",
    "    print(\"Collected misclassified:\", len(mis))\n",
    "\n",
    "    rec = []\n",
    "    if not mis:\n",
    "        return pd.DataFrame(rec)\n",
    "\n",
    "    # The number of classes is loop-invariant: query it once, without autograd.\n",
    "    with torch.no_grad():\n",
    "        num_classes = model(mis[0][0].unsqueeze(0)).shape[1]\n",
    "\n",
    "    for idx, (img, lbl) in enumerate(mis):\n",
    "        y_star = F.one_hot(lbl, num_classes=num_classes).float().to(img.device)\n",
    "        # Repair. CUDA kernels run asynchronously, so synchronize around the timed\n",
    "        # region to measure the actual work.\n",
    "        if img.is_cuda:\n",
    "            torch.cuda.synchronize()\n",
    "        start = time.perf_counter()\n",
    "        x_fix = pullback_mode(model, img, y_star, steps=steps, lam=lam, mode=mode, k=k)\n",
    "        if img.is_cuda:\n",
    "            torch.cuda.synchronize()\n",
    "        overhead = time.perf_counter() - start\n",
    "        with torch.no_grad():\n",
    "            y_after = torch.softmax(model(x_fix.unsqueeze(0)).squeeze(0), dim=0)\n",
    "            succ = int(y_after.argmax().item() == lbl.item())\n",
    "            l2 = (img - x_fix).pow(2).mean().sqrt().item()\n",
    "            lp = lpips_fn(_to_lpips_range(img), _to_lpips_range(x_fix)).item()\n",
    "        rec.append({\"idx\": idx, \"success\": succ, \"l2\": l2, \"lpips\": lp,\n",
    "                    \"overhead\": overhead})\n",
    "        print(f\"[{idx}] success={succ} l2={l2:.4f} lpips={lp:.4f}\")\n",
    "    return pd.DataFrame(rec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f6bc86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================\n",
    "# 6. Run Experiment\n",
    "# ================================\n",
    "df = evaluate_pullback_mode(model, testloader, max_samples=5)\n",
    "df.head()  # last expression of the cell: rendered as a rich table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "076f9936",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "\n",
    "# --- Visualization utility ---\n",
    "def visualize_repair_triplet(img, x_fix, disp, idx=0, success=True,\n",
    "                             mean=(0.5071, 0.4865, 0.4409),\n",
    "                             std=(0.2673, 0.2564, 0.2762)):\n",
    "    \"\"\"Show the input, a displacement heat map and the repaired output side by side.\n",
    "\n",
    "    Args:\n",
    "        img: Original input tensor, channel-first (C, H, W).\n",
    "        x_fix: Repaired input tensor, same layout as ``img``.\n",
    "        disp: Displacement tensor (e.g. ``x_fix - img``); its norm over dim 0 is plotted.\n",
    "        idx: Sample index shown in the figure title.\n",
    "        success: Whether the repair succeeded; selects \"Success\" or \"Fail\" in the title.\n",
    "        mean, std: Per-channel statistics the images were normalized with. The defaults\n",
    "            are the values of the ``T.Normalize`` call used for this notebook's dataset.\n",
    "            Pass ``None`` for either to display the tensors without de-normalizing.\n",
    "    \"\"\"\n",
    "    def _to_rgb(x):\n",
    "        # imshow expects float RGB values in [0, 1]; normalized tensors would be\n",
    "        # clipped and shown with wrong colors, so undo the normalization first.\n",
    "        x = x.detach().cpu()\n",
    "        if mean is not None and std is not None:\n",
    "            m = torch.as_tensor(mean, dtype=x.dtype).view(-1, 1, 1)\n",
    "            s = torch.as_tensor(std, dtype=x.dtype).view(-1, 1, 1)\n",
    "            x = x * s + m\n",
    "        return x.clamp(0, 1).permute(1, 2, 0).numpy()\n",
    "\n",
    "    disp_mag = disp.detach().norm(dim=0)\n",
    "    gamma, scale = 0.5, 5.0\n",
    "    # gamma < 1 lifts small displacements. Note that `scale` cancels out in the\n",
    "    # max-normalization on the following line (up to the 1e-8 guard).\n",
    "    disp_mag = (disp_mag ** gamma) * scale\n",
    "    disp_mag = disp_mag / (disp_mag.max() + 1e-8)\n",
    "    disp_img = (disp_mag.cpu().numpy() * 255).astype(\"uint8\")\n",
    "    disp_img = cv2.applyColorMap(disp_img, cv2.COLORMAP_JET)\n",
    "    disp_img = cv2.cvtColor(disp_img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(9, 3))\n",
    "    panels = [(_to_rgb(img), \"Input\"),\n",
    "              (disp_img, \"Displacement\"),\n",
    "              (_to_rgb(x_fix), \"Output\")]\n",
    "    for ax, (image, title) in zip(axs, panels):\n",
    "        ax.imshow(image)\n",
    "        ax.set_title(title)\n",
    "        ax.axis(\"off\")\n",
    "    fig.suptitle(f\"Sample {idx} - {'Success' if success else 'Fail'}\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# --- Collect misclassified samples from dataloader ---\n",
    "def collect_misclassified(model, dataloader, max_samples=4):\n",
    "    \"\"\"Collect up to ``max_samples`` misclassified ``(image, label)`` pairs.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier returning logits of shape (batch, classes).\n",
    "        dataloader: Iterable of ``(images, labels)`` batches of any batch size.\n",
    "        max_samples: Maximum number of pairs to return.\n",
    "\n",
    "    Returns:\n",
    "        List of ``(image, label)`` tensor pairs, in the order they were encountered.\n",
    "    \"\"\"\n",
    "    # Run on the device the model's parameters live on rather than a notebook\n",
    "    # global; a parameter-free model leaves the batches where they are.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    run_device = first_param.device if first_param is not None else None\n",
    "\n",
    "    mis = []\n",
    "    with torch.no_grad():\n",
    "        for imgs, lbls in dataloader:\n",
    "            if run_device is not None:\n",
    "                imgs, lbls = imgs.to(run_device), lbls.to(run_device)\n",
    "            preds = model(imgs).argmax(dim=1)\n",
    "            # Walk over every wrong prediction in the batch (works for any batch size).\n",
    "            for i in (preds != lbls).nonzero(as_tuple=True)[0].tolist():\n",
    "                if len(mis) >= max_samples:\n",
    "                    break\n",
    "                mis.append((imgs[i], lbls[i]))\n",
    "            if len(mis) >= max_samples:\n",
    "                break\n",
    "    print(f\"Collected misclassified: {len(mis)}\")\n",
    "    return mis\n",
    "\n",
    "\n",
    "# --- Run repair and visualize ---\n",
    "def run_and_visualize(model, dataloader, steps=8, lam=1e-2, mode=\"fisher\", k=10, max_samples=4):\n",
    "    \"\"\"Repair misclassified samples and plot input / displacement / output for each.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier returning logits of shape (batch, classes).\n",
    "        dataloader: Source of ``(images, labels)`` batches, passed to ``collect_misclassified``.\n",
    "        steps, lam, mode, k: Forwarded to ``pullback_mode``.\n",
    "        max_samples: Maximum number of misclassified samples to repair and plot.\n",
    "    \"\"\"\n",
    "    mis = collect_misclassified(model, dataloader, max_samples=max_samples)\n",
    "    if not mis:\n",
    "        return\n",
    "\n",
    "    # The number of classes is loop-invariant: query it once, without autograd.\n",
    "    with torch.no_grad():\n",
    "        num_classes = model(mis[0][0].unsqueeze(0)).shape[1]\n",
    "\n",
    "    for idx, (img, lbl) in enumerate(mis):\n",
    "        y_star = F.one_hot(lbl, num_classes=num_classes).float().to(img.device)\n",
    "        # Repair toward the true label.\n",
    "        x_fix = pullback_mode(model, img, y_star, steps=steps, lam=lam, mode=mode, k=k)\n",
    "        with torch.no_grad():\n",
    "            succ = (model(x_fix.unsqueeze(0)).argmax().item() == lbl.item())\n",
    "        disp = x_fix - img\n",
    "        visualize_repair_triplet(img, x_fix, disp, idx=idx, success=succ)\n",
    "\n",
    "\n",
    "# --- Example usage: repair and plot up to four misclassified test samples ---\n",
    "run_and_visualize(model=model, dataloader=testloader, max_samples=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb5a1b98",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
