{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ce198bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload \n",
    "%autoreload 2\n",
    "\n",
    "import torch \n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision\n",
    "import timm\n",
    "\n",
    "from tqdm import tqdm\n",
    "from src.dp_sgd.config import shots, num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a04714a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mps\n"
     ]
    }
   ],
   "source": [
    "# Prefer the active accelerator (e.g. CUDA or MPS); fall back to the CPU when none is usable at runtime.\n",
    "if torch.accelerator.is_available():\n",
    "    device = torch.accelerator.current_accelerator()\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d48acc3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short tag for the backbone; it is embedded in the filenames of the cached feature tensors.\n",
    "feature_extractor_name = \"vit-b-16-imagenet-21K\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1001ce88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# num_classes=0 asks timm for a headless model, so a forward pass yields features rather than logits.\n",
    "feature_extractor = timm.create_model(\"vit_base_patch16_224_in21k\", pretrained=True, num_classes=0).to(device)\n",
    "# The model is only used for inference: disable train-time layers such as dropout.\n",
    "feature_extractor.eval()\n",
    "\n",
    "# Build the preprocessing pipeline from the checkpoint's own pretrained config.\n",
    "data_cfg = timm.data.resolve_data_config(feature_extractor.pretrained_cfg)\n",
    "image_transform = timm.data.create_transform(**data_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4ca3c054",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'url': 'https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',\n",
       " 'hf_hub_id': 'timm/vit_base_patch16_224.augreg_in21k',\n",
       " 'architecture': 'vit_base_patch16_224',\n",
       " 'tag': 'augreg_in21k',\n",
       " 'custom_load': True,\n",
       " 'input_size': (3, 224, 224),\n",
       " 'fixed_input_size': True,\n",
       " 'interpolation': 'bicubic',\n",
       " 'crop_pct': 0.9,\n",
       " 'crop_mode': 'center',\n",
       " 'mean': (0.5, 0.5, 0.5),\n",
       " 'std': (0.5, 0.5, 0.5),\n",
       " 'num_classes': 21843,\n",
       " 'pool_size': None,\n",
       " 'first_conv': 'patch_embed.proj',\n",
       " 'classifier': 'head',\n",
       " 'license': 'apache-2.0'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the pretrained config that the preprocessing transform is resolved from.\n",
    "feature_extractor.pretrained_cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "99ccc8a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 train/test splits; every image goes through the backbone's preprocessing transform.\n",
    "train_data = torchvision.datasets.CIFAR10(\n",
    "    root=\"datasets/cifar10\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=image_transform,\n",
    ")\n",
    "test_data = torchvision.datasets.CIFAR10(\n",
    "    root=\"datasets/cifar10\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=image_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d127ec74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False keeps batches in dataset order, so extracted feature rows line up with the saved labels.\n",
    "train_loader, test_loader = (\n",
    "    DataLoader(split, batch_size=64, shuffle=False)\n",
    "    for split in (train_data, test_data)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a91ac837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the labels of both splits as tensors, in dataset order.\n",
    "for split_targets, labels_path in (\n",
    "    (train_data.targets, \"datasets/cifar10/train_labels.pt\"),\n",
    "    (test_data.targets, \"datasets/cifar10/test_labels.pt\"),\n",
    "):\n",
    "    torch.save(torch.tensor(split_targets), labels_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "376bb038",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 3, 224, 224])\n",
      "torch.Size([64, 768])\n"
     ]
    }
   ],
   "source": [
    "# Push a single batch through the extractor to discover the feature dimension.\n",
    "images, labels = next(iter(train_loader))\n",
    "print(images.shape)\n",
    "# Only the output shape is needed, so skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    outputs = feature_extractor(images.to(device))\n",
    "feature_dim = outputs.shape[1]\n",
    "print(outputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "63d75500",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Python(24653) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.\n",
      "100%|██████████| 782/782 [29:00<00:00,  2.23s/it]\n",
      "100%|██████████| 157/157 [05:47<00:00,  2.21s/it]\n"
     ]
    }
   ],
   "source": [
    "# Preallocate one row of `feature_dim` values per sample; extract_features fills these in place.\n",
    "train_features = torch.zeros(len(train_data), feature_dim)\n",
    "test_features = torch.zeros(len(test_data), feature_dim)\n",
    "\n",
    "def extract_features(data_loader, features_tensor):\n",
    "    \"\"\"Fill `features_tensor` in place with the extractor's outputs for `data_loader`.\n",
    "\n",
    "    Batches are written to consecutive rows in the order the loader yields them.\n",
    "    Relies on the notebook-level `feature_extractor` and `device`.\n",
    "\n",
    "    Args:\n",
    "        data_loader: iterable yielding `(images, labels)` batches; the labels are ignored.\n",
    "        features_tensor: preallocated tensor that receives one row per sample.\n",
    "    \"\"\"\n",
    "    was_training = feature_extractor.training\n",
    "    # Extraction is inference: disable train-time layers (e.g. dropout) so features are deterministic.\n",
    "    feature_extractor.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            start_idx = 0\n",
    "            for images, _ in tqdm(data_loader):\n",
    "                batch_size = images.shape[0]\n",
    "                outputs = feature_extractor(images.to(device))\n",
    "                features_tensor[start_idx:start_idx + batch_size] = outputs.cpu()\n",
    "                start_idx += batch_size\n",
    "    finally:\n",
    "        # Leave the model in whatever mode the caller had it in.\n",
    "        feature_extractor.train(was_training)\n",
    "\n",
    "\n",
    "# Training split: fill the preallocated buffer, then cache it on disk.\n",
    "extract_features(train_loader, train_features)\n",
    "torch.save(obj=train_features, f=f\"datasets/cifar10/train_{feature_extractor_name}_features.pt\")\n",
    "\n",
    "# Test split: same procedure.\n",
    "extract_features(test_loader, test_features)\n",
    "torch.save(obj=test_features, f=f\"datasets/cifar10/test_{feature_extractor_name}_features.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8fc1322a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class counts in training set: [50, 50, 50, 50, 50, 50, 50, 50, 50, 50]\n",
      "Class counts in validation set: [50, 50, 50, 50, 50, 50, 50, 50, 50, 50]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Reload the cached full-split features and labels (same naming scheme as the extraction cell).\n",
    "train_features = torch.load(f\"datasets/cifar10/train_{feature_extractor_name}_features.pt\")\n",
    "train_labels = torch.load(\"datasets/cifar10/train_labels.pt\")\n",
    "\n",
    "# Fixed seed so the few-shot split is reproducible.\n",
    "torch.manual_seed(4237426)\n",
    "indices_train = []\n",
    "indices_validation = []\n",
    "for c in range(num_classes):\n",
    "    class_indices = (train_labels == c).nonzero(as_tuple=True)[0]\n",
    "    # Train and validation take disjoint slices of one permutation, so each class needs 2 * shots samples.\n",
    "    assert len(class_indices) >= 2 * shots, (\n",
    "        f\"class {c} has only {len(class_indices)} samples; need at least {2 * shots}\"\n",
    "    )\n",
    "    random_permutation = torch.randperm(len(class_indices))\n",
    "    selected_indices_train = class_indices[random_permutation[:shots]]\n",
    "    selected_indices_validation = class_indices[random_permutation[shots:shots * 2]]\n",
    "    indices_train.append(selected_indices_train)\n",
    "    indices_validation.append(selected_indices_validation)\n",
    "\n",
    "indices_train = torch.cat(indices_train)\n",
    "indices_validation = torch.cat(indices_validation)\n",
    "\n",
    "# Sanity check: report how many samples of each class ended up in each split.\n",
    "train_dataset = torch.utils.data.TensorDataset(train_features[indices_train], train_labels[indices_train])\n",
    "class_counts = torch.bincount(train_labels[indices_train], minlength=num_classes)\n",
    "print(\"Class counts in training set:\", class_counts.tolist())\n",
    "\n",
    "validation_dataset = torch.utils.data.TensorDataset(train_features[indices_validation], train_labels[indices_validation])\n",
    "class_counts = torch.bincount(train_labels[indices_validation], minlength=num_classes)\n",
    "print(\"Class counts in validation set:\", class_counts.tolist())\n",
    "\n",
    "# Cache both few-shot splits (features and labels) for later use.\n",
    "torch.save(train_features[indices_train], f\"datasets/cifar10/few_shot_{shots}_train_{feature_extractor_name}_features.pt\")\n",
    "torch.save(train_labels[indices_train], f\"datasets/cifar10/few_shot_{shots}_train_labels.pt\")\n",
    "\n",
    "torch.save(train_features[indices_validation], f\"datasets/cifar10/few_shot_{shots}_validation_{feature_extractor_name}_features.pt\")\n",
    "torch.save(train_labels[indices_validation], f\"datasets/cifar10/few_shot_{shots}_validation_labels.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae4fb2aa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "weakening-dp-bounds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
