{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3220b43-ba1c-4803-9810-78c7bf1f4f88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import datasets, models, transforms\n",
    "from torchvision import datasets\n",
    "import PIL\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import datasets, models, transforms\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn\n",
    "import numpy as np\n",
    "import einops\n",
    "import wandb\n",
    "import PIL\n",
    "import sys\n",
    "sys.path.append(\"./vicreg/\")\n",
    "import vicreg\n",
    "from main import *\n",
    "import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbd9da6d-8e97-4a0e-9f9f-b29c0dd9c43f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pil_loader(path: str) -> Image.Image:\n",
    "    \"\"\"Load the image stored at `path` and return it as an RGB PIL image.\n",
    "\n",
    "    The file is opened in binary mode inside a context manager, and convert() runs\n",
    "    while the handle is still open, so the pixel data is read before the file closes.\n",
    "    \"\"\"\n",
    "    with open(path, 'rb') as handle:\n",
    "        return Image.open(handle).convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfe91da4-0df1-4820-a82c-dd99d215764c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _label_from_filename(filename):\n",
    "    \"\"\"Return the label encoded in an image filename: the text before the first '_'.\"\"\"\n",
    "    return filename.split(\"_\")[0]\n",
    "\n",
    "\n",
    "class DirImageDataset(Dataset):\n",
    "    \"\"\"Dataset over every entry of a single directory.\n",
    "\n",
    "    Items are ``(image, label)``; the label is the filename prefix before the first\n",
    "    underscore. Entries are sorted because ``os.listdir`` returns them in arbitrary\n",
    "    order, and sorting keeps index ``i`` pointing at the same file on every run.\n",
    "    \"\"\"\n",
    "    def __init__(self, img_dir, transform=None):\n",
    "        self.img_list = sorted(os.listdir(img_dir))\n",
    "        self.img_dir = img_dir\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        filename = self.img_list[idx]\n",
    "        image = pil_loader(os.path.join(self.img_dir, filename))\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image, _label_from_filename(filename)\n",
    "\n",
    "\n",
    "class LinkImageDataset(Dataset):\n",
    "    \"\"\"Dataset over an explicit list of image file paths.\n",
    "\n",
    "    Items are ``(image, label, path)``; the label is the prefix before the first\n",
    "    underscore of the path's final component. Order follows ``img_list`` exactly.\n",
    "    \"\"\"\n",
    "    def __init__(self, img_list, transform=None):\n",
    "        self.img_list = img_list\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_path = self.img_list[idx]\n",
    "        image = pil_loader(img_path)\n",
    "        # os.path.basename honours the platform's path separator instead of assuming '/'.\n",
    "        label = _label_from_filename(os.path.basename(img_path))\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image, label, img_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "172a54da-b607-4161-83e2-690f5b146e29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal stand-in for the command-line namespace that main.VICReg is built from.\n",
    "# NOTE(review): presumably VICReg only reads `arch` and `mlp` here, and both must\n",
    "# describe the same network as the checkpoint that is loaded — confirm in vicreg/main.py.\n",
    "class Args:\n",
    "    arch = \"resnet50\"       # architecture name handed to VICReg\n",
    "    mlp = \"8192-8192-8192\"  # projector spec; dash-separated sizes (presumably layer widths)\n",
    "\n",
    "\n",
    "args = Args()\n",
    "model = main.VICReg(args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41427c6c-911b-45f7-8936-0aa630e9a07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "# map_location=\"cpu\" stops tensors from being restored onto whatever GPU they were\n",
    "# saved from; the model is moved to its target device explicitly below.\n",
    "checkpoint = torch.load(\"./resnet50_fullckpt.pth\", map_location=\"cpu\")\n",
    "\n",
    "# Keys are expected to carry the 7-character \"module.\" prefix (presumably added by a\n",
    "# DataParallel/DDP wrapper when the checkpoint was saved — confirm). Strip it only where\n",
    "# present instead of cutting the first 7 characters off every key.\n",
    "PREFIX = \"module.\"\n",
    "weights = {\n",
    "    (key[len(PREFIX):] if key.startswith(PREFIX) else key): value\n",
    "    for key, value in checkpoint[\"model\"].items()\n",
    "}\n",
    "model.load_state_dict(weights)\n",
    "# Assignment as the final statement keeps the cell from echoing the model repr.\n",
    "model = model.to(\"cuda:1\").eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed5e451f-9f3c-4616-9a7b-7ec29fecd8b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.listdir returns entries in arbitrary order; sort so the row order of the\n",
    "# extracted arrays (and the saved path_list) is reproducible across runs.\n",
    "IMAGE_DIR = \"../SSALD/Data/imagenet/\"\n",
    "img_list = [os.path.join(IMAGE_DIR, name) for name in sorted(os.listdir(IMAGE_DIR))]\n",
    "\n",
    "# NOTE(review): VICReg's reference training/evaluation pipeline normalizes inputs with\n",
    "# the ImageNet channel mean/std, so the backbone should see the same input distribution\n",
    "# here — confirm against vicreg/augmentations.py and vicreg/evaluate.py.\n",
    "IMAGENET_MEAN = [0.485, 0.456, 0.406]\n",
    "IMAGENET_STD = [0.229, 0.224, 0.225]\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
    "])\n",
    "\n",
    "dataset = LinkImageDataset(img_list, transform=transform)\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=512,\n",
    "    shuffle=False,    # deterministic order: output rows follow img_list\n",
    "    drop_last=False,  # keep the final partial batch so every image gets a row\n",
    "    num_workers=8,\n",
    "    pin_memory=True,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2364e2b2-dde6-4a1c-b735-cc1e382a59e7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run every image through the frozen VICReg model, keeping the backbone output and\n",
    "# the projector output, plus the label and path for each row.\n",
    "# Take the device from the model itself so this cell cannot drift from the cell that\n",
    "# moved the model onto the GPU.\n",
    "device = next(model.parameters()).device\n",
    "feature_list = []    # projector outputs, one row per image\n",
    "embedding_list = []  # backbone outputs, one row per image\n",
    "label_list = []\n",
    "path_list = []\n",
    "with torch.no_grad():\n",
    "    for image, label, img_path in tqdm(dataloader):\n",
    "        # The DataLoader pins host memory, which lets this copy run asynchronously.\n",
    "        image = image.to(device, non_blocking=True)\n",
    "        embedding = model.backbone(image)\n",
    "        projection = model.projector(embedding)\n",
    "        feature_list.extend(projection.cpu().numpy())\n",
    "        embedding_list.extend(embedding.cpu().numpy())\n",
    "        label_list.extend(label)\n",
    "        path_list.extend(img_path)\n",
    "feature_list = torch.Tensor(np.array(feature_list))\n",
    "embedding_list = torch.Tensor(np.array(embedding_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86362d37-3bb0-41ae-a466-6b2c2b9e3dc6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Turn the collected results into NumPy arrays so np.save can write them.\n",
    "# np.array also accepts values that are already NumPy arrays, so re-running is harmless.\n",
    "feature_list, label_list, path_list, embedding_list = (\n",
    "    np.array(feature_list),\n",
    "    np.array(label_list),\n",
    "    np.array(path_list),\n",
    "    np.array(embedding_list),\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d3b932-0b70-461a-b1f0-6731528e0fb3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the extracted arrays as ./space/<name>.npy. Create the output directory\n",
    "# first: np.save does not create missing parent directories.\n",
    "OUTPUT_DIR = \"./space\"\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "for arr_name, arr in [(\"feature_list\", feature_list),\n",
    "                      (\"label_list\", label_list),\n",
    "                      (\"path_list\", path_list),\n",
    "                      (\"embedding_list\", embedding_list)]:\n",
    "    np.save(os.path.join(OUTPUT_DIR, arr_name + \".npy\"), arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e269d2a9-b8ba-401d-9ca0-4e2b5bc8c835",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
