{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clone the Github repo https://github.com/aw31/empirical-ntks to the local filesystem\n",
    "import os\n",
    "os.chdir('empirical-ntks-main')\n",
    "\n",
    "import copy\n",
    "import logging\n",
    "import pathlib\n",
    "import threading\n",
    "import time\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.multiprocessing import Process, Queue\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from multiqueue_worker import multiqueue_worker\n",
    "from utils import init_torch, humanize_units\n",
    "\n",
    "local = threading.local()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import pprint\n",
    "import sys\n",
    "from torch.multiprocessing import set_start_method, set_sharing_strategy\n",
    "from utils import init_logging, load_model, load_dataset\n",
    "\n",
    "# Set up\n",
    "set_start_method(\"spawn\")\n",
    "set_sharing_strategy(\"file_system\")\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"dataset\", type=str)\n",
    "parser.add_argument(\"model\", type=str)\n",
    "parser.add_argument(\"--datadir\", type=str, default=\"./datasets\")\n",
    "parser.add_argument(\"--savedir\", type=str, default=\"./ntks\")\n",
    "parser.add_argument(\"--logdir\", type=str)\n",
    "parser.add_argument(\"--workers-per-device\", type=int, default=1)\n",
    "parser.add_argument(\"--grad-chunksize\", type=int, default=1900000)\n",
    "parser.add_argument(\"--mm-col-chunksize\", type=int, default=20000)\n",
    "parser.add_argument(\"--ntk-dtype\", type=str, default=\"float32\")\n",
    "parser.add_argument(\"--loader-batch-size\", type=int, default=512)\n",
    "parser.add_argument(\"--loader-num-workers\", type=int, default=12)\n",
    "parser.add_argument(\"--no-pinned-memory\", dest=\"pin_memory\", action=\"store_false\")\n",
    "parser.add_argument(\"--allow-tf32\", action=\"store_true\")\n",
    "parser.add_argument(\"--benchmark\", action=\"store_true\")\n",
    "parser.add_argument(\n",
    "    \"--non-deterministic\", dest=\"deterministic\", action=\"store_false\"\n",
    ")\n",
    "args = parser.parse_args(args=['CIFAR-100', 'resnet-50_pretrained'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize torch\n",
    "init_torch_kwargs = {\n",
    "    \"allow_tf32\": args.allow_tf32,\n",
    "    \"benchmark\": args.benchmark,\n",
    "    \"deterministic\": args.deterministic,\n",
    "}\n",
    "init_torch(**init_torch_kwargs, verbose=True)\n",
    "\n",
    "# Initialize model\n",
    "model = load_model(args.model)\n",
    "\n",
    "param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "param_batches = (param_count - 1) // args.grad_chunksize + 1\n",
    "logging.info(f\"Splitting {param_count} parameters into {param_batches} batches\")\n",
    "\n",
    "# Initialize datasets\n",
    "datadir = pathlib.Path(args.datadir)\n",
    "train_set = load_dataset(datadir, args.dataset, \"train\")\n",
    "if args.dataset=='Flowers-102':\n",
    "    val_set = load_dataset(datadir, args.dataset, \"val\")\n",
    "\n",
    "test_set = load_dataset(datadir, args.dataset, \"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeatureExtractor(torch.nn.Module):\n",
    "    def __init__(self, pretrained_model):\n",
    "        super(FeatureExtractor, self).__init__()\n",
    "        self.features = torch.nn.Sequential(*list(pretrained_model.children())[:-1])\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.features(x)\n",
    "        features = features.view(features.size(0), -1)\n",
    "        return features\n",
    "\n",
    "# Initialize model\n",
    "pretrain_model = load_model(args.model)\n",
    "# Create an instance of the feature extractor\n",
    "model = FeatureExtractor(pretrain_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_features(model, test_set, batch_size=128):\n",
    "    # Set the model to evaluation mode\n",
    "    model.eval()\n",
    "\n",
    "    # Disable gradient calculation\n",
    "    with torch.no_grad():\n",
    "        # Move the model to the GPU if available\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        model.to(device)\n",
    "\n",
    "        # Create a DataLoader for the test dataset\n",
    "        test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=12)\n",
    "\n",
    "        # Create an empty list to store the features\n",
    "        X = []\n",
    "        Y = []\n",
    "        # Iterate over the test dataset\n",
    "        for images, labels in tqdm(test_loader):\n",
    "            # Move the images to the GPU if available\n",
    "            images = images.to(device)\n",
    "\n",
    "            # Forward pass through the model\n",
    "            output = model.features(images)\n",
    "\n",
    "            # Append the features to the list\n",
    "            X.append(output.cpu())\n",
    "            Y.append(labels.cpu())\n",
    "\n",
    "        # Concatenate the features into a single tensor\n",
    "        X = torch.cat(X)\n",
    "        X = X.reshape(X.shape[0],-1).numpy()\n",
    "        Y = torch.nn.functional.one_hot(torch.cat(Y)).numpy()\n",
    "    return X, Y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 391/391 [29:04<00:00,  4.46s/it] \n",
      "100%|██████████| 79/79 [07:36<00:00,  5.78s/it] \n"
     ]
    }
   ],
   "source": [
    "X_train, Y_train = get_features(model, train_set)\n",
    "if args.dataset=='Flowers-102':\n",
    "    X_val, Y_val = get_features(model, val_set)\n",
    "X_test, Y_test = get_features(model, test_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "res = {'X_train':X_train, 'Y_train':Y_train,\n",
    "        'X_test':X_test, 'Y_test':Y_test}\n",
    "if args.dataset=='Flowers-102':\n",
    "    res['X_val'] = X_val\n",
    "    res['Y_val'] = Y_val\n",
    "with open(\"../data_{}_{}.npz\".format(args.dataset,args.model), \"wb\") as f:\n",
    "    np.savez(f, **res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Subset the CIFAR-100 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data, model = 'CIFAR-100', 'resnet-50_pretrained'\n",
    "with open('../data_{}_{}_full.npz'.format(data, model), 'rb') as f:\n",
    "    dat = np.load(f)\n",
    "    X = dat['X_train']\n",
    "    Ys = dat['Y_train']\n",
    "    X_test = dat['X_test']    \n",
    "    Ys_test = dat['Y_test']\n",
    "import numpy as np\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "\n",
    "sss = StratifiedShuffleSplit(n_splits=10, test_size=0.9, random_state=0)\n",
    "idx = next(sss.split(X_train, np.argmax(Y_train, axis=1)))[0]\n",
    "\n",
    "res = {'X_train':X_train[idx,:], 'Y_train':Y_train[idx,:],\n",
    "        'X_test':X_test, 'Y_test':Y_test}\n",
    "\n",
    "with open(\"../data_{}_{}.npz\".format(args.dataset,args.model), \"wb\") as f:\n",
    "    np.savez(f, **res)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vitae",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
