{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc61e750-ad7f-40e4-b8d6-e53b6cc04c02",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import main\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import os\n",
    "import argparse\n",
    "from tqdm.auto import tqdm\n",
    "import PIL as pil\n",
    "from PIL import Image\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch.utils.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d413b15e-0ec6-483e-8ef7-7ae7ebf82895",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Args:\n",
    "  projector = \"8192-8192-8192\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "696b9a94-527a-452e-b4f9-c453d727f840",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5af411ca-2280-4085-b951-9c56027fa605",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../VICReg/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "916f965a-3a77-42f5-8381-4a4268b6ea04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import vicreg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969c85b2-7c66-4e14-9a5d-c7c9f97ec200",
   "metadata": {},
   "outputs": [],
   "source": [
    "vicreg."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d29d08-1d4f-4e7f-950d-1b9df6ef66c5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = main.BarlowTwins(Args)\n",
    "checkpoint = torch.load(\"./checkpoint.pth\")\n",
    "state_dict = checkpoint['model']\n",
    "old_keys = state_dict.keys()\n",
    "new_keys = list(map(lambda x: x.split(\"module.\")[-1], old_keys))\n",
    "state_dict = dict(zip(new_keys, list(state_dict.values())))\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "528b29bb-a4c1-411b-8e25-d89be26bfc14",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = model.to(\"cuda:0\")\n",
    "model.eval()\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "182a1133-a600-4831-a8c2-40bbcf83fe97",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pil_loader(path: str) -> Image.Image:\n",
    "    with open(path, 'rb') as f:\n",
    "        img = Image.open(f)\n",
    "        return img.convert('RGB')\n",
    "\n",
    "class DirImageDataset(Dataset):\n",
    "    def __init__(self, img_dir, transform=None):\n",
    "        self.img_list = os.listdir(img_dir)\n",
    "        self.img_dir = img_dir\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_path = os.path.join(self.img_dir, self.img_list[idx])\n",
    "        image = pil_loader(img_path)\n",
    "        label = self.img_list[idx].split(\"_\")[0]\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "        return image, label\n",
    "\n",
    "class LinkImageDataset(Dataset):\n",
    "    \"Given a list of Image Links all the images will be loaded along with their associated labels\"\n",
    "    def __init__(self, img_list, transform=None):\n",
    "        self.img_list = img_list\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_path = self.img_list[idx]\n",
    "        image = pil_loader(img_path)\n",
    "        label = (self.img_list[idx].split(\"/\")[-1]).split(\"_\")[0]\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "        return image, label, img_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ca67edb-3f85-4c7b-ba4c-2f50b36fee7c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "img_list = os.listdir(\"../SSALD/Data/imagenet/\")\n",
    "img_list = list(map(lambda x: \"../SSALD/Data/imagenet/\"+x, img_list))\n",
    "transform = transforms.Compose([transforms.Resize((224,224)),\n",
    "                            transforms.ToTensor()\n",
    "                           ])\n",
    "\n",
    "\n",
    "dataset = LinkImageDataset(img_list,\n",
    "                           transform=transform)\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(dataset,\n",
    "                                         batch_size=512,\n",
    "                                         shuffle=False,\n",
    "                                         drop_last = False,\n",
    "                                         num_workers=8,\n",
    "                                         pin_memory = True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6dfa9b-cbed-4d20-95e5-9b156f3d1fee",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "feature_list = []\n",
    "label_list = []\n",
    "path_list = []\n",
    "embedding_list = []\n",
    "with torch.no_grad():\n",
    "    for image, label, img_path in tqdm(dataloader):\n",
    "        image = image.to(\"cuda:0\")\n",
    "        embedding = model.backbone(image)\n",
    "        projection = model.projector(embedding)\n",
    "        projection = projection.cpu().numpy()\n",
    "        embedding = embedding.cpu().numpy()\n",
    "        feature_list.extend(projection)\n",
    "        embedding_list.extend(embedding)\n",
    "        label_list.extend(label)\n",
    "        path_list.extend(img_path)\n",
    "feature_list = torch.Tensor(np.array(feature_list))\n",
    "embedding_list = torch.Tensor(np.array(embedding_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bafd3008-e834-444e-a70e-10c20c6c46e7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "1+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a03c1ea-1491-4f09-b372-39e892b8fcaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_list = np.array(feature_list)\n",
    "label_list = np.array(label_list)\n",
    "path_list = np.array(path_list)\n",
    "embedding_list = np.array(embedding_list)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9239ccc1-9579-4b7d-acd4-dd0d47657137",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.save(\"./space/feature_list.npy\",feature_list)\n",
    "np.save(\"./space/label_list.npy\",label_list)\n",
    "np.save(\"./space/path_list.npy\",path_list)\n",
    "np.save(\"./space/embedding_list.npy\",embedding_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20c3f35d-5747-4852-9b85-29b91c752082",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "path_list = np.load(\"./space/path_list.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b14d3eeb-ed30-4174-ae4f-d27a01281955",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "path_list = np.array(list(map(lambda x: \"./\"+x.split(\"../SSALD/\")[-1] , path_list.tolist())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc3c8585-ea34-439d-86f7-ea6a9b5153bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "path_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95ded859-7dca-489c-a64d-0902e10f8e62",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.save(\"./space/path_list.npy\",path_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d436b2dd-a5fd-416b-9f56-098e59cfe17a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
