{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1db00436",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "def sample_toy_data(N):\n",
    "    \"\"\"Draw N points from each of three isotropic Gaussian blobs in 2-D.\n",
    "\n",
    "    Blob centres are (-.5, -.5), (-.5, .5) and (.5, -.5); every blob has\n",
    "    standard deviation 0.1.\n",
    "\n",
    "    Returns:\n",
    "        X: float tensor of shape (3*N, 2), blob 0 first, then 1, then 2.\n",
    "        y: int64 class labels of shape (3*N,), aligned with the rows of X.\n",
    "    \"\"\"\n",
    "    centres = torch.tensor([[-.5, -.5], [-.5, .5], [.5, -.5]])\n",
    "    offsets = torch.randn(3 * N, 2)\n",
    "    X = offsets * 0.1 + centres.repeat_interleave(N, dim=0)\n",
    "    y = torch.arange(3).repeat_interleave(N)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa58a7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset over paired feature / label tensors.\n",
    "\n",
    "    Labels are cloned on construction, so later in-place edits to the\n",
    "    caller's y do not leak into the dataset; features are kept by reference.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X, y):\n",
    "        self.X = X\n",
    "        self.y = y.clone()\n",
    "\n",
    "    def __len__(self):\n",
    "        # Dataset size is the number of labels.\n",
    "        return len(self.y)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.X[i], self.y[i]\n",
    "\n",
    "# Seed torch's global RNG once, before the first random draw, so the data,\n",
    "# the weight initialisation, DataLoader shuffling and all later torch random\n",
    "# draws are reproducible under Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "train_X, train_y = sample_toy_data(20)\n",
    "# batch_size == dataset size: every epoch is a single full batch (shuffle\n",
    "# only permutes the rows inside it).\n",
    "loader = torch.utils.data.DataLoader(ToyDataset(train_X, train_y),\n",
    "                                     batch_size=len(train_X), shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aaed888",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully connected ReLU classifier: 2 inputs -> h -> h -> 3 logits.\n",
    "\n",
    "    Args:\n",
    "        h: width of both hidden layers (default 16).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, h=16):\n",
    "        super().__init__()\n",
    "        # Layer order fixes both the order of parameter initialisation and the\n",
    "        # state_dict keys (net.0 / net.2 / net.4 for the Linear layers).\n",
    "        layers = [\n",
    "            nn.Linear(2, h), nn.ReLU(),\n",
    "            nn.Linear(h, h), nn.ReLU(),\n",
    "            nn.Linear(h, 3),\n",
    "        ]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map (..., 2) inputs to (..., 3) unnormalised class scores.\"\"\"\n",
    "        return self.net(x)\n",
    "\n",
    "# Prefer the GPU when one is available.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "model = MLP(h=70).to(device)\n",
    "opt = torch.optim.SGD(model.parameters(), lr=5e-2)\n",
    "crit = nn.CrossEntropyLoss()\n",
    "\n",
    "# 1,000 epochs of plain SGD; loader has batch_size == len(train_X), so each\n",
    "# epoch is a single full-batch step.\n",
    "for _ in range(1_000):\n",
    "    for X, y in loader:\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        opt.zero_grad()\n",
    "        crit(model(X), y).backward()\n",
    "        opt.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc1eb24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch.nn.functional as F\n",
    "def v_1(model, inputs, rho=0.1, noise_scale=0.05, K=1):\n",
    "    \"\"\"\n",
    "    Compute the (normalised) ascent direction of the prediction-inconsistency\n",
    "    objective. The result is detached, so it can't be used as a loss.\n",
    "\n",
    "    Procedure (model in eval mode):\n",
    "      1. p_orig = softmax(model(inputs)) at the current weights theta.\n",
    "      2. Move the weights to theta + noise, noise ~ N(0, noise_scale^2 / k * I),\n",
    "         k = number of parameters, so ||noise|| is roughly noise_scale.\n",
    "      3. Backpropagate KL(p_orig || p_noisy) (batch-mean) at theta + noise.\n",
    "      4. Return rho * g / (||g|| + 1e-12), g = flattened gradient.\n",
    "\n",
    "    Side effects: parameters, requires_grad flags and train/eval mode are\n",
    "    restored before returning, also when an exception is raised. Any .grad\n",
    "    stored on the model before the call is discarded, and the probe's own\n",
    "    gradients are cleared on exit.\n",
    "\n",
    "    Args:\n",
    "        model: torch module; assumes model(inputs) returns logits of shape\n",
    "            (batch, classes) -- softmax is taken over dim 1.\n",
    "        inputs: batch fed to model (assumed to be on the model's device).\n",
    "        rho: norm of the returned vector.\n",
    "        noise_scale: approximate norm of the random weight perturbation.\n",
    "        K: unused; kept for signature compatibility.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor with one entry per parameter (parameters_to_vector order),\n",
    "        or float('nan') if no gradient could be collected (a message is printed).\n",
    "    \"\"\"\n",
    "    criterion_kl = nn.KLDivLoss(reduction='batchmean', log_target=False)\n",
    "    original_mode = model.training\n",
    "    original_requires_grad = {name: p.requires_grad for name, p in model.named_parameters()}\n",
    "    params_structure = list(model.parameters())\n",
    "    original_params_flat = torch.nn.utils.parameters_to_vector(params_structure).detach().clone()\n",
    "\n",
    "    model.eval()\n",
    "    try:\n",
    "        # 1) reference predictions at theta; no graph needed.\n",
    "        with torch.no_grad():\n",
    "            logits_orig = model(inputs)\n",
    "            p_orig = F.softmax(logits_orig, dim=1).clamp(torch.finfo(logits_orig.dtype).tiny)\n",
    "\n",
    "        # 2) random isotropic perturbation; 1/sqrt(k) makes its norm ~noise_scale.\n",
    "        k = original_params_flat.numel()\n",
    "        noise = noise_scale / math.sqrt(k) * torch.randn_like(original_params_flat)\n",
    "        torch.nn.utils.vector_to_parameters(original_params_flat + noise, params_structure)\n",
    "        for param in params_structure:\n",
    "            param.requires_grad_(True)\n",
    "\n",
    "        # 3) KL(p_orig || p_noisy). Clamping to the smallest positive float\n",
    "        # keeps log() finite when a probability underflows to 0.\n",
    "        logits_noisy = model(inputs)\n",
    "        p_noisy = F.softmax(logits_noisy, dim=1).clamp(torch.finfo(logits_noisy.dtype).tiny)\n",
    "        loss = criterion_kl(p_noisy.log(), p_orig)\n",
    "        model.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # Parameters without a gradient path contribute zeros, keeping the\n",
    "        # vector aligned with parameters_to_vector order.\n",
    "        grad_vector_list = [\n",
    "            (param.grad if param.grad is not None else torch.zeros_like(param)).detach().flatten()\n",
    "            for param in params_structure\n",
    "            if param.requires_grad\n",
    "        ]\n",
    "        if not grad_vector_list:\n",
    "            print(\"[Error] No parameters require gradients.\")\n",
    "            return float('nan')\n",
    "\n",
    "        grad_vector_flat = torch.cat(grad_vector_list)\n",
    "        # 4) ascent direction of length rho; 1e-12 guards a zero gradient.\n",
    "        return rho * (grad_vector_flat / (grad_vector_flat.norm() + 1e-12))\n",
    "    finally:\n",
    "        # Undo every change to the model, even if anything above raised.\n",
    "        torch.nn.utils.vector_to_parameters(original_params_flat, params_structure)\n",
    "        for name, p in model.named_parameters():\n",
    "            if name in original_requires_grad:\n",
    "                p.requires_grad_(original_requires_grad[name])\n",
    "        model.zero_grad()\n",
    "        model.train(original_mode)\n",
    "\n",
    "def perturb(model, delta_scale=0.05):\n",
    "    \"\"\"Return a randomly perturbed copy of model; model itself is untouched.\n",
    "\n",
    "    Every parameter tensor p of the copy receives i.i.d. Gaussian noise with\n",
    "    standard deviation delta_scale * ||p|| (norm of that whole tensor), so the\n",
    "    noise level is relative to each tensor's size.\n",
    "\n",
    "    The copy is made with deepcopy, so it works for any architecture and keeps\n",
    "    the source model's device and dtype.\n",
    "    \"\"\"\n",
    "    from copy import deepcopy\n",
    "\n",
    "    perturbed = deepcopy(model)     # starts at θ\n",
    "    with torch.no_grad():\n",
    "        for p in perturbed.parameters():\n",
    "            p.add_(delta_scale * p.norm() * torch.randn_like(p))\n",
    "    return perturbed\n",
    "\n",
    "# Copy of the trained model with relative Gaussian weight noise (see perturb).\n",
    "delta_model = perturb(model=model, delta_scale=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d18ff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "@torch.no_grad()\n",
    "def plot_boundaries(model_a, model_b, model_c,\n",
    "                    X_train, y_train,\n",
    "                    ax=None,\n",
    "                    colors=('black', 'tab:blue', 'tab:orange'),\n",
    "                    linestyles=('-', '--'),\n",
    "                    title='Decision boundaries',\n",
    "                    resolution=400, lim=1, show_title=False):\n",
    "    \"\"\"Overlay the decision boundaries of three classifiers on the data.\n",
    "\n",
    "    Each model is evaluated on a resolution x resolution grid over\n",
    "    [-lim, lim]^2 and its arg-max class map is contoured at levels 0.5 and\n",
    "    1.5, i.e. where the predicted label changes between 0/1 and 1/2.\n",
    "    model_a is drawn with colors[0] / linestyles[0]; model_b and model_c use\n",
    "    colors[1] and colors[2] and share linestyles[1].\n",
    "\n",
    "    The models are switched to eval mode for the forward passes and their\n",
    "    previous train/eval mode is restored afterwards. The grid is built on\n",
    "    model_a's device; model_b and model_c are assumed to live there too.\n",
    "\n",
    "    Args:\n",
    "        model_a, model_b, model_c: modules mapping (n, 2) points to class logits.\n",
    "        X_train, y_train: training points (indexable as X_train[:, 0/1]) and\n",
    "            labels, scattered on top of the boundaries.\n",
    "        ax: matplotlib Axes to draw on; a new 6x6 figure is created if None.\n",
    "        colors, linestyles: per-model line styling, see above.\n",
    "        title: axis title; only applied when show_title is True.\n",
    "        resolution: grid points per axis (default 400).\n",
    "        lim: half-width of the square window and of the +-lim ticks (default 1).\n",
    "        show_title: also set the title and the x₁ / x₂ axis labels.\n",
    "\n",
    "    Returns:\n",
    "        The Axes that was drawn on.\n",
    "    \"\"\"\n",
    "    models = (model_a, model_b, model_c)\n",
    "    was_training = [m.training for m in models]\n",
    "    device = next(model_a.parameters()).device\n",
    "\n",
    "    axis = torch.linspace(-lim, lim, resolution, device=device)\n",
    "    grid_x, grid_y = torch.meshgrid(axis, axis, indexing='ij')\n",
    "    grid = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)\n",
    "\n",
    "    try:\n",
    "        for m in models:\n",
    "            m.eval()\n",
    "        preds = [m(grid).argmax(1).cpu().numpy().reshape(resolution, resolution)\n",
    "                 for m in models]\n",
    "    finally:\n",
    "        for m, mode in zip(models, was_training):\n",
    "            m.train(mode)\n",
    "\n",
    "    if ax is None:\n",
    "        # Local import so the function works even if matplotlib has not been\n",
    "        # imported at module level yet (e.g. on a fresh kernel).\n",
    "        import matplotlib.pyplot as plt\n",
    "        _, ax = plt.subplots(figsize=(6, 6))\n",
    "\n",
    "    gx, gy = grid_x.cpu().numpy(), grid_y.cpu().numpy()\n",
    "    model_styles = (linestyles[0], linestyles[1], linestyles[1])\n",
    "    for pred, color, style in zip(preds, colors, model_styles):\n",
    "        ax.contour(gx, gy, pred, levels=[0.5, 1.5], colors=color,\n",
    "                   linestyles=style, linewidths=1.2)\n",
    "\n",
    "    ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train,\n",
    "               cmap='viridis', edgecolor='k', s=30)\n",
    "\n",
    "    ax.set_xlim(-lim, lim); ax.set_ylim(-lim, lim)\n",
    "    if show_title:\n",
    "        ax.set_title(title); ax.set_xlabel('x₁'); ax.set_ylabel('x₂')\n",
    "    ax.set_xticks([-lim, 0, lim]);  ax.set_yticks([-lim, 0, lim])\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b4d25a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from torch.nn.utils import parameters_to_vector, vector_to_parameters\n",
    "# One training batch (the loader yields the full set), then the rho=0.5\n",
    "# ascent direction of the inconsistency objective at the trained weights.\n",
    "X, y = next(iter(loader))\n",
    "delta = v_1(model, X.to(device), rho=0.5, noise_scale=0.001)\n",
    "\n",
    "# Two independent copies of the trained model: one moved to θ+δ, one to θ−δ.\n",
    "model_pert1, model_pert2 = deepcopy(model).to(device), deepcopy(model).to(device)\n",
    "theta_flat = parameters_to_vector(model_pert1.parameters()).detach()\n",
    "theta_plus_delta, theta_minus_delta = (\n",
    "    theta_flat + delta.to(theta_flat.device),\n",
    "    theta_flat - delta.to(theta_flat.device),\n",
    ")\n",
    "\n",
    "vector_to_parameters(theta_plus_delta, model_pert1.parameters())\n",
    "vector_to_parameters(theta_minus_delta, model_pert2.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ee9f019",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here so this cell also runs on a fresh kernel.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "ax = plot_boundaries(model, model_pert1, model_pert2,\n",
    "                     train_X.numpy(), train_y.numpy(),\n",
    "                     title='Original vs θ±δ decision boundaries')\n",
    "ax.figure.savefig('toy_perturbation.pdf', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62edd771",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from torch.nn.utils import parameters_to_vector, vector_to_parameters\n",
    "# Baseline: a random isotropic direction instead of the ascent direction.\n",
    "# X, y are drawn but not used below.\n",
    "X, y = next(iter(loader))\n",
    "params_structure = list(model.parameters())\n",
    "original_params_flat = parameters_to_vector(params_structure).detach().clone()\n",
    "k = original_params_flat.numel()\n",
    "\n",
    "# δ has i.i.d. N(0, 0.25/k) entries, so ||δ|| ≈ 0.5, the same length as the\n",
    "# rho=0.5 ascent direction of the v_1 experiment.\n",
    "delta = 0.5 / math.sqrt(k) * torch.randn_like(original_params_flat)\n",
    "\n",
    "model_pert1, model_pert2 = deepcopy(model).to(device), deepcopy(model).to(device)\n",
    "theta_flat = parameters_to_vector(model_pert1.parameters()).detach()\n",
    "theta_plus_delta, theta_minus_delta = (\n",
    "    theta_flat + delta.to(theta_flat.device),\n",
    "    theta_flat - delta.to(theta_flat.device),\n",
    ")\n",
    "\n",
    "vector_to_parameters(theta_plus_delta, model_pert1.parameters())\n",
    "vector_to_parameters(theta_minus_delta, model_pert2.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d753c7c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here so this cell also runs on a fresh kernel.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "ax = plot_boundaries(model, model_pert1, model_pert2,\n",
    "                     train_X.numpy(), train_y.numpy(),\n",
    "                     title='Original vs θ±δ (random δ) decision boundaries')\n",
    "ax.figure.savefig('toy_perturbation_w_noise.pdf', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3068642",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn.utils import parameters_to_vector, vector_to_parameters\n",
    "\n",
    "def empirical_fisher(model, data_loader):\n",
    "    \"\"\"\n",
    "    Dense empirical Fisher matrix  F = (1/N) * sum_i g_i g_i^T.\n",
    "\n",
    "    g_i is the gradient of sample i's cross-entropy loss w.r.t. all model\n",
    "    parameters (flattened in model.parameters() order; parameters without a\n",
    "    gradient contribute zeros), and N is the number of samples actually\n",
    "    drawn from data_loader.\n",
    "\n",
    "    Per-sample gradients of one batch are stacked into a (B, n_params)\n",
    "    float64 matrix G and accumulated as G^T G -- one matrix product per batch\n",
    "    instead of B separate rank-1 updates of the full matrix.\n",
    "\n",
    "    The model runs in eval mode; its previous train/eval mode is restored\n",
    "    and its .grad buffers are cleared on exit (also on error).\n",
    "\n",
    "    Args:\n",
    "        model: classifier; assumes model(x_batch) returns (B, n_classes) logits.\n",
    "        data_loader: iterable of (x_batch, y_batch) with class-index targets.\n",
    "\n",
    "    Return:\n",
    "        F : (n_params, n_params)  empirical Fisher matrix (torch.Tensor,\n",
    "            float64, CPU); all zeros if the loader yields no samples.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    params = list(model.parameters())\n",
    "    n_params = sum(p.numel() for p in params)\n",
    "\n",
    "    Fim = torch.zeros(n_params, n_params, dtype=torch.float64)\n",
    "    n_seen = 0\n",
    "\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        for x_batch, y_batch in data_loader:\n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
    "\n",
    "            # per-sample negative log-likelihood, shape (B,)\n",
    "            loss_vec = torch.nn.functional.cross_entropy(\n",
    "                model(x_batch), y_batch, reduction='none')\n",
    "\n",
    "            rows = []\n",
    "            for loss in loss_vec:\n",
    "                model.zero_grad()\n",
    "                # keep the batch graph alive for the remaining samples\n",
    "                loss.backward(retain_graph=True)\n",
    "                # torch.cat copies, so later zero_grad() cannot alter this row\n",
    "                rows.append(torch.cat([\n",
    "                    (p.grad if p.grad is not None else torch.zeros_like(p)).detach().reshape(-1)\n",
    "                    for p in params\n",
    "                ]))\n",
    "\n",
    "            if rows:\n",
    "                G = torch.stack(rows).cpu().double()    # (B, n_params): row i = g_i\n",
    "                Fim += G.T @ G                           # sum_i g_i g_i^T\n",
    "                n_seen += G.shape[0]\n",
    "    finally:\n",
    "        model.zero_grad()\n",
    "        model.train(was_training)\n",
    "\n",
    "    if n_seen:\n",
    "        Fim /= n_seen                                    # (1/N) sum\n",
    "    return Fim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1631925c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "valid_X, valid_y = sample_toy_data(1000)\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    ToyDataset(valid_X, valid_y), batch_size=64, shuffle=False\n",
    ")\n",
    "\n",
    "# Dense (d x d) float64 matrix on CPU, d = number of model parameters.\n",
    "FIM = empirical_fisher(model, val_loader)\n",
    "print(\"FIM shape:\", FIM.shape)\n",
    "\n",
    "# eigh (symmetric solver) returns eigenvalues in ascending order with the\n",
    "# matching eigenvectors as columns, so the largest eigenvalue is the last one.\n",
    "eigvals, eigvecs = torch.linalg.eigh(FIM)\n",
    "print(\"λ_max:\", eigvals[-1].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f57da9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"λ_max: {eigvals[-1].item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26e5ba34",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "# Reuse the validation set the Fisher matrix was computed on (valid_X,\n",
    "# valid_y from the FIM cell) instead of drawing a fresh one, so delta and the\n",
    "# eigenvectors refer to the same data.\n",
    "val_loader = torch.utils.data.DataLoader(ToyDataset(valid_X, valid_y),\n",
    "                                         batch_size=len(valid_X))\n",
    "X, y = next(iter(val_loader))\n",
    "\n",
    "delta = v_1(model, X.to(device), rho=1, noise_scale=0.001)\n",
    "print(\"delta shape:\", delta.shape)\n",
    "print(\"delta norm:\", delta.norm().item())\n",
    "delta = delta.cpu().double()\n",
    "\n",
    "# torch.linalg.eigh returns eigenvectors as COLUMNS, so the top eigenvector\n",
    "# is eigvecs[:, -1]; eigvecs[-1] would be the last row of the matrix.\n",
    "v_max = eigvecs[:, -1]\n",
    "print(\"eigenvecs norm:\", v_max.norm().item())\n",
    "print(\"cos with v_max, :\", (v_max.dot(delta) / (v_max.norm() * delta.norm())).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d46820",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same comparison for the 2nd and 3rd largest eigenvalues. Eigenvectors are\n",
    "# the COLUMNS of eigvecs; the cosine is normalised by both vector norms.\n",
    "delta_cpu = delta.cpu().double()\n",
    "for j in (2, 3):\n",
    "    v_j = eigvecs[:, -j]\n",
    "    print(\"eigenvecs norm:\", v_j.norm().item())\n",
    "    print(f\"cos with v_{j}, :\", (v_j.dot(delta_cpu) / (v_j.norm() * delta_cpu.norm())).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793344ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "tuple(eigvals[-j].item() for j in (1, 2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31b55ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_eig = eigvals.sum().item()\n",
    "for j in (1, 2, 3):\n",
    "    print(f\"ratio {j}\", eigvals[-j].item() / total_eig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35c9608e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "eigvecs = eigvecs.double()\n",
    "\n",
    "# |cos(δ, v_j)| for the top-3 Fisher eigenvectors (columns of eigvecs),\n",
    "# over 10,000 independent draws of the noisy ascent direction δ.\n",
    "cosines = {j: [] for j in (1, 2, 3)}\n",
    "for _ in range(10000):\n",
    "    delta = v_1(model, X.to(device), rho=1.0, noise_scale=0.1).cpu().double()\n",
    "    delta_norm = delta.norm().item()\n",
    "    if delta_norm == 0:\n",
    "        continue\n",
    "    for j, bucket in cosines.items():\n",
    "        v = eigvecs[:, -j]\n",
    "        cos_sim = v.dot(delta) / (v.norm().item() * delta_norm)\n",
    "        bucket.append(torch.abs(cos_sim).item())\n",
    "\n",
    "for j, values in cosines.items():\n",
    "    arr = np.array(values)\n",
    "    print(f\"Top{j} eigenvector: mean={arr.mean():.3f}, std={arr.std():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07bdbc64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of δ's projection onto span{v_1, v_2, v_3} relative to ||δ||, i.e.\n",
    "# sqrt(sum_i cos^2(δ, v_i)); 1 means δ lies entirely in the top-3 eigenspace.\n",
    "sum_3 = np.sqrt(np.array(cosines[1])**2 + np.array(cosines[2])**2 + np.array(cosines[3])**2)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.histplot(sum_3, bins=30, stat='density', ax=ax)\n",
    "ax.set_xlabel(r'$\\sqrt{\\sum_{i=1}^{3} (\\delta_1^\\top v_i)^2}\\,/\\,\\|\\delta_1\\|$')\n",
    "fig.savefig('cosine_sum3.pdf', format=\"pdf\", dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f60fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leftover placeholder cell: prints a single empty line.\n",
    "print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86329e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlaid |cos| distributions for the three leading eigenvectors.\n",
    "for cos_values, color, label in ((cosines[1], 'blue', 'v_max'),\n",
    "                                 (cosines[2], 'orange', 'v_2'),\n",
    "                                 (cosines[3], 'green', 'v_3')):\n",
    "    sns.histplot(cos_values, bins=50, color=color, label=label, stat=\"density\")\n",
    "plt.axvline(x=0, color='black', linestyle='--')\n",
    "plt.legend()\n",
    "plt.title('absolute Cosine similarity with eigenvectors')\n",
    "plt.savefig('cosine_similarity.pdf', format=\"pdf\", dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8569755",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# The N_TOP_EIGS largest eigenvalues. eigh sorts ascending, so they are the\n",
    "# last N_TOP_EIGS entries and λ_max is the final one (a stop index of -1\n",
    "# would drop it). 193 starts the window at index 5200 for the\n",
    "# 5393-parameter h=70 MLP.\n",
    "N_TOP_EIGS = 193\n",
    "spectrum = eigvals.cpu().numpy()\n",
    "start = max(spectrum.size - N_TOP_EIGS, 0)\n",
    "index = np.arange(start, spectrum.size)\n",
    "top = spectrum[start:]\n",
    "# A log axis cannot show non-positive values (numerical zeros of the\n",
    "# positive semi-definite Fisher), so those points are skipped.\n",
    "keep = top > 0\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.plot(index[keep], top[keep], marker='o', linestyle='-', label='Eigenvalues')\n",
    "ax.set_title('Eigen Spectra')\n",
    "ax.set_xlabel('Index')\n",
    "ax.set_ylabel('Eigenvalue')\n",
    "ax.set_yscale('log')\n",
    "ax.grid(True)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hstest",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
