{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e7fff997-d86e-42d0-9c89-ab4c58587d76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ntk_utils import NTKHelper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9986fecf-2b83-49cf-be91-4b1fabd94a16",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Network\n",
    "import resnet\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "\n",
    "class ResNet18ImageNetPretrainedFactory():\n",
    "    \"\"\"Builds an ImageNet-pretrained ResNet-18 split into a backbone and its fc head.\"\"\"\n",
    "\n",
    "    def __init__(self, factory_args = None):\n",
    "        # The class has no explicit base, so forwarding `factory_args` to\n",
    "        # super().__init__ (i.e. object.__init__) raised TypeError. Keep the\n",
    "        # cooperative call without arguments and just store the args.\n",
    "        super().__init__()\n",
    "        self.args = factory_args\n",
    "\n",
    "    def getNets(self, input_shape, output_shape):\n",
    "        \"\"\"Return `(net, head)`.\n",
    "\n",
    "        `net` is the pretrained ResNet-18 with its `fc` layer replaced by\n",
    "        `nn.Identity()`; `head` is that original `fc` layer. Only the channel\n",
    "        count (`input_shape[0] == 3`) is checked; `output_shape` is unused.\n",
    "        \"\"\"\n",
    "        assert input_shape[0] == 3\n",
    "\n",
    "        whole_net = resnet.resnet18(\n",
    "            pretrained = True,\n",
    "            # NOTE(review): presumably a non-inplace ReLU so activations are\n",
    "            # not overwritten during gradient/NTK computation - confirm.\n",
    "            act = lambda: nn.ReLU(inplace=False)\n",
    "        )\n",
    "\n",
    "        head = whole_net.fc\n",
    "        whole_net.fc = nn.Identity()\n",
    "        return whole_net, head\n",
    "\n",
    "# Re-assemble backbone + head into a single end-to-end network.\n",
    "net = nn.Sequential(*ResNet18ImageNetPretrainedFactory().getNets(\n",
    "    [3, 224, 224],\n",
    "    [1000,],\n",
    "))\n",
    "net.eval()\n",
    "net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f5cd32ed-56d8-436d-a4a3-2bc50d4aceda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the second argument presumably selects which of the 1000\n",
    "# output classes enter the NTK - confirm against ntk_utils.NTKHelper.\n",
    "# A dedicated seeded Generator makes the choice reproducible (the global\n",
    "# seed is only set in a later cell), and replace=False avoids selecting\n",
    "# the same class twice.\n",
    "NTK_CLASS_SEED = 0\n",
    "ntk_classes = np.random.default_rng(NTK_CLASS_SEED).choice(1000, 10, replace=False)\n",
    "ntk_helper = NTKHelper(net, torch.as_tensor(ntk_classes, dtype=torch.long))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "124a35d6-bf2d-4651-9a61-ac392b654dd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ID subset: N_train random images from the OpenOOD ImageNet validation list.\n",
    "import torchvision\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "from openood.datasets.imglist_dataset import ImglistDataset\n",
    "from openood.evaluation_api.preprocessor import get_default_preprocessor\n",
    "\n",
    "prepro = get_default_preprocessor('imagenet')\n",
    "\n",
    "id_dataset = ImglistDataset(\n",
    "    name = 'test',\n",
    "    imglist_pth = '.../.../.../data/benchmark_imglist/imagenet/val_imagenet.txt', # Paths are anonymized; this is the OpenOOD data path\n",
    "    data_dir = '.../.../.../data/images_largescale',\n",
    "    num_classes = 1000,\n",
    "    preprocessor = prepro,\n",
    "    data_aux_preprocessor = prepro\n",
    ")\n",
    "\n",
    "N_train = 256\n",
    "# Draw distinct indices: np.random.choice samples *with* replacement by\n",
    "# default, which can duplicate an image in the ID subset. Clamp to the\n",
    "# dataset size so replace=False cannot fail on a small image list.\n",
    "train_idx = np.random.choice(len(id_dataset), min(N_train, len(id_dataset)), replace=False)\n",
    "X_dataset = torch.utils.data.Subset(id_dataset, train_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe3e118d-cc3e-4940-8c6b-5f90c7295e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Fixed-order loader over the ID subset (batches of 4, no shuffling).\n",
    "X = DataLoader(X_dataset, shuffle = False, batch_size = 4)\n",
    "\n",
    "# '1to1' NTK of the ID subset against itself; cached to IDxx.npy so the\n",
    "# reload cell further down can skip this computation.\n",
    "_Oxx_diag = (\n",
    "    ntk_helper\n",
    "    .compute_ntk(\n",
    "        X, X,\n",
    "        batch_mode = '1to1',\n",
    "        x1_map = lambda batch: batch['data'],\n",
    "        x2_map = lambda batch: batch['data'],\n",
    "    )\n",
    "    .detach()\n",
    "    .cpu()\n",
    "    .numpy()\n",
    ")\n",
    "np.save(\"IDxx.npy\", _Oxx_diag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e605f125-f115-4a43-8849-0a751e50f456",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openood.datasets.imglist_dataset import ImglistDataset\n",
    "from openood.evaluation_api.preprocessor import get_default_preprocessor\n",
    "import os\n",
    "\n",
    "# Evaluation image lists, relative to the OpenOOD data root. \"ImageNet\" is\n",
    "# the in-distribution test split; the other entries are plotted as OOD.\n",
    "# Keys double as the .npy cache file names, so keep them stable (this is\n",
    "# why the \"iNaturalList\" spelling is left as-is).\n",
    "ood_datasets = dict([\n",
    "    (\"ImageNet\", \"benchmark_imglist/imagenet/test_imagenet.txt\"),\n",
    "    (\"ImageNet-C\", \"benchmark_imglist/imagenet/test_imagenet_c.txt\"),\n",
    "    (\"ImageNet-R\", \"benchmark_imglist/imagenet/test_imagenet_r.txt\"),\n",
    "    (\"ssb_hard\", \"benchmark_imglist/imagenet/test_ssb_hard.txt\"),\n",
    "    (\"iNaturalList\", \"benchmark_imglist/imagenet/test_inaturalist.txt\"),\n",
    "    (\"Textures\", \"benchmark_imglist/imagenet/test_textures.txt\"),\n",
    "])\n",
    "\n",
    "# Cross-NTK arrays keyed like ood_datasets; filled by the following cells.\n",
    "_Ozxs = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdfe83e9-8c59-4871-a50b-d615c9a536c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-NTK between N_test random images of each evaluation set and the ID\n",
    "# subset X; each result is cached to <key>.npy. Index draws continue the\n",
    "# global RNG stream seeded in the ID-subset cell, so re-running this cell\n",
    "# alone selects different images.\n",
    "N_test = 128\n",
    "for key in ood_datasets:\n",
    "    print(key)\n",
    "    ood_dataset = ImglistDataset(\n",
    "        name = 'test',\n",
    "        imglist_pth = os.path.join(\".../.../.../data/\", ood_datasets[key]),\n",
    "        # Textures is stored under images_classic; every other list under images_largescale.\n",
    "        data_dir = \".../.../.../data/images_largescale\" if key != \"Textures\" else \".../.../.../data/images_classic\",\n",
    "        num_classes = 1000,\n",
    "        preprocessor = prepro,\n",
    "        data_aux_preprocessor = prepro\n",
    "    )\n",
    "    # Distinct indices (the default replace=True can repeat an image), clamped\n",
    "    # to the list size so replace=False cannot fail on a short image list.\n",
    "    test_idx = np.random.choice(len(ood_dataset), min(N_test, len(ood_dataset)), replace=False)\n",
    "    Z_dataset = torch.utils.data.Subset(ood_dataset, test_idx)\n",
    "    Z = DataLoader(Z_dataset, batch_size = 4, shuffle = False)\n",
    "    # NOTE(review): this call passes `mode=` while the ID cell passes\n",
    "    # `batch_mode=` - confirm NTKHelper.compute_ntk honours `mode` rather\n",
    "    # than silently ignoring it.\n",
    "    _Ozxs[key] = ntk_helper.compute_ntk(Z, X, mode = 'full', x1_map = lambda b : b['data'], x2_map = lambda b : b['data']).detach().cpu().numpy()\n",
    "    np.save(\"%s.npy\" % key, _Ozxs[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7e6de9cd-fe59-4b89-80cf-d670b20258ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached NTK arrays written by the compute cells (no recomputation).\n",
    "_Oxx_diag = np.load(\"IDxx.npy\")\n",
    "for name in ood_datasets.keys():\n",
    "    cache_file = \"%s.npy\" % name\n",
    "    _Ozxs[name] = np.load(cache_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23675032-d48f-4860-970c-d8d28d4f2a61",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def visualize(Ozx, Oxx):\n",
    "    \"\"\"Compute the left- and right-hand sides of Eq. (8) for one evaluation set.\n",
    "\n",
    "    Despite the name, nothing is plotted here.\n",
    "\n",
    "    Ozx: cross-NTK from compute_ntk(Z, X). Its last two axes are reduced by\n",
    "        trace (LHS) and Frobenius norm (RHS); axis 1 is then reduced by min\n",
    "        (LHS) and mean (RHS). Presumably axis 0 indexes evaluation samples\n",
    "        and axis 1 ID samples - confirm against NTKHelper.compute_ntk.\n",
    "    Oxx: cached ID NTK from IDxx.npy, broadcast against the traces with a\n",
    "        new leading axis. NOTE(review): confirm its shape broadcasts as\n",
    "        intended.\n",
    "\n",
    "    Returns (lhs, rhs); rhs keeps a trailing singleton axis (keepdims=True).\n",
    "    \"\"\"\n",
    "    lhs = (Oxx[None, :] - 2 * np.trace(Ozx, axis1 = -1, axis2 = -2)).min(axis = 1)\n",
    "    rhs = (Oxx.mean() - 2 * np.linalg.norm(Ozx, 'fro', axis = (-1, -2)).mean(axis = 1, keepdims = True))\n",
    "    return lhs, rhs\n",
    "\n",
    "# Evaluate each set once and unpack both sides.\n",
    "ls, rs = {}, {}\n",
    "for k, Ozx in _Ozxs.items():\n",
    "    ls[k], rs[k] = visualize(Ozx, _Oxx_diag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "977d7535-1db1-4a4c-bc85-5960718f7182",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use('./style.mplstyle')\n",
    "\n",
    "# Margins are trimmed at save time via bbox_inches='tight'.\n",
    "fig, ax = plt.subplots(figsize = (1.2, 1.2))\n",
    "\n",
    "# OOD sets first (one colour per set from the style's colour cycle), ID on top.\n",
    "for name in ls:\n",
    "    if name != \"ImageNet\":\n",
    "        ax.scatter(ls[name], rs[name], s = 3, linewidths = 0)\n",
    "\n",
    "ax.scatter(ls[\"ImageNet\"], rs[\"ImageNet\"], marker = '+', label = \"ID\", s = 8, linewidths = 0.5)\n",
    "\n",
    "# Legend-only proxy for the OOD sets. An empty scatter adds no data, so unlike\n",
    "# a dummy point at (0, 0) it cannot stretch the x or y limits.\n",
    "ax.scatter([], [], label = \"OOD\", color = 'grey', s = 3, linewidths = 0)\n",
    "\n",
    "ax.legend(\n",
    "    loc = 'upper left', framealpha = 0.8, bbox_to_anchor = (-0.03, 1.03),\n",
    "    handlelength = 0.5,\n",
    "    handletextpad = 0.3,\n",
    ")\n",
    "\n",
    "ax.set_xlabel('LHS of Eq.~(8)')\n",
    "ax.set_ylabel('RHS of Eq.~(8)')\n",
    "\n",
    "# Panel label just outside the top-left corner (axes coordinates).\n",
    "ax.text(-.1,1,'\\\\textbf{d)}',\n",
    "        horizontalalignment='center',\n",
    "        transform=ax.transAxes)\n",
    "\n",
    "fig.savefig(\"Eq8-visualization.pdf\", format=\"pdf\", bbox_inches = 'tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
