{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f981f39-0a23-4508-89ee-616e32aa8c76",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torchextractor in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (0.3.0)\n",
      "Requirement already satisfied: numpy in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torchextractor) (1.23.5)\n",
      "Requirement already satisfied: torch>=1.4.0 in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torchextractor) (2.0.0)\n",
      "Requirement already satisfied: filelock in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torch>=1.4.0->torchextractor) (3.9.0)\n",
      "Requirement already satisfied: typing-extensions in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torch>=1.4.0->torchextractor) (4.4.0)\n",
      "Requirement already satisfied: sympy in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torch>=1.4.0->torchextractor) (1.11.1)\n",
      "Requirement already satisfied: networkx in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torch>=1.4.0->torchextractor) (2.8.4)\n",
      "Requirement already satisfied: jinja2 in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from torch>=1.4.0->torchextractor) (3.1.2)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from jinja2->torch>=1.4.0->torchextractor) (2.1.1)\n",
      "Requirement already satisfied: mpmath>=0.19 in /home/matteo/anaconda3/envs/ai/lib/python3.9/site-packages (from sympy->torch>=1.4.0->torchextractor) (1.2.1)\n"
     ]
    }
   ],
   "source": [
    "!pip install torchextractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4ac9855b-f9a5-4cec-8e25-c13bbb866c25",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %pip install torchmetrics[image]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "160dda51-5017-4689-8146-17586a5203ae",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/braindiff/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import wandb\n",
    "import torchmetrics\n",
    "from torchvision.models import alexnet, inception_v3\n",
    "import json\n",
    "import os\n",
    "from os.path import join as opj\n",
    "from torchvision.transforms import *\n",
    "import numpy as np\n",
    "from scipy import signal\n",
    "import tqdm\n",
    "from torchsummary import summary\n",
    "import torchextractor as tx\n",
    "from torchvision.models.feature_extraction import get_graph_node_names\n",
    "from torchvision.models.feature_extraction import create_feature_extractor\n",
    "import clip\n",
    "from torchmetrics.image.fid import FrechetInceptionDistance\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "52d763ec-70e6-4718-8ca9-219d0a70c35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "indices=[25,31,68,121,126,318,384,492,531,606,702,860]\n",
    "\n",
    "indices2=[70,116,165,261,278,363,451,774]\n",
    "indices3=[41,205,230,411,428,446,502,777]\n",
    "extra=[95,905]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d7c779f5-369d-4887-8e51-0b12635f9eb3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# target_run=\"rural-blaze-28\"\n",
    "# target_run=\"warm-mountain-29\"\n",
    "\n",
    "target_run=\"resilient-pyramid-1\"  #sub01\n",
    "# target_run=\"legendary-violet-2\"  #sub02\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "288ffee7-6f5f-4125-8a5e-08ac5477a068",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "api=wandb.Api()\n",
    "project=api.project(\"BrainCaptioning\")\n",
    "\n",
    "entity, project = \"matteoferrante\", \"BrainCaptioning\"  # set to your entity and project\n",
    "runs = api.runs(entity + \"/\" + project)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "884c7f93-e744-472e-8dfa-484afbe9ea8e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found resilient-pyramid-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m:   27 of 27 files downloaded.  \n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact run-s51yzgh9-table:v0, 3790.83MB. 12700 files... \n",
      "\u001b[34m\u001b[1mwandb\u001b[0m:   12700 of 12700 files downloaded.  \n",
      "Done. 0:0:9.4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./artifacts/run-s51yzgh9-table:v0\n"
     ]
    }
   ],
   "source": [
    "active_run=None\n",
    "for run in api.runs(project):\n",
    "    if run.name==target_run:\n",
    "        print(f\"found {target_run}\")\n",
    "        active_run=target_run\n",
    "        \n",
    "        for artifact in run.logged_artifacts():\n",
    "            artifact_path = artifact.download()\n",
    "            \n",
    "            if \"-table:v0\" in artifact_path:\n",
    "                decoded_artifact=artifact_path\n",
    "                print(artifact_path)\n",
    "                f=open(f\"{artifact_path}/table.table.json\")\n",
    "                meta = json.load(f)\n",
    "                data=meta[\"data\"]\n",
    "                \n",
    "                image_captions=[]\n",
    "                brain_captions=[]\n",
    "                init_guess=[]\n",
    "                depth_guess=[]\n",
    "                shown=[]\n",
    "                generated=[]\n",
    "                generated_control=[]\n",
    "                for i in range(len(data)):\n",
    "                    image_captions.append(data[i][0])\n",
    "                    \n",
    "                    brain_captions.append(data[i][1])\n",
    "                    \n",
    "                    shown.append(data[i][2][\"path\"])\n",
    "                    \n",
    "                    init_guess.append(data[i][3][\"path\"])\n",
    "                    depth_guess.append(data[i][4][\"path\"])\n",
    "                    \n",
    "                    generated.append(data[i][5][\"path\"])\n",
    "                    generated_control.append(data[i][7][\"path\"])\n",
    "                    \n",
    "                \n",
    "            \n",
    "        break\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06bb7848-c950-400e-82c0-59abb8ce0a58",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Decoded images\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "77012bff-0d82-42d1-a444-05ba1cbc8743",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "shown_imgs=[Image.open(opj(decoded_artifact,i)) for i in shown]\n",
    "generated_imgs=[Image.open(opj(decoded_artifact,i)) for i in generated]\n",
    "generated_control_imgs=[Image.open(opj(decoded_artifact,i)) for i in generated_control]\n",
    "# depth_guess_imgs=[Image.open(opj(decoded_artifact,i))for i in depth_guess]\n",
    "\n",
    "# init_guess_imgs=[Image.open(opj(decoded_artifact,i))for i in init_guess]\n",
    "# \n",
    "                    \n",
    "# depth_guess_1_imgs=np.array(depth_guess)[indices]\n",
    "# depth_guess_1_imgs=[Image.open(opj(decoded_artifact,i)) for i in depth_guess_1_imgs]\n",
    "\n",
    "\n",
    "# init_guess_1_imgs=np.array(init_guess)[indices]\n",
    "# init_guess_1_imgs=[Image.open(opj(decoded_artifact,i)) for i in init_guess_1_imgs]\n",
    "\n",
    "# compare_brain_1_captions=np.array(brain_captions)[indices]\n",
    "# compare_image_1_captions=np.array(image_captions)[indices]\n",
    "\n",
    "\n",
    "# depth_guess_2_imgs=np.array(depth_guess)[indices2]\n",
    "# depth_guess_2_imgs=[Image.open(opj(decoded_artifact,i)) for i in depth_guess_2_imgs]\n",
    "\n",
    "\n",
    "# init_guess_2_imgs=np.array(init_guess)[indices2]\n",
    "# init_guess_2_imgs=[Image.open(opj(decoded_artifact,i)) for i in init_guess_2_imgs]\n",
    "\n",
    "# compare_brain_2_captions=np.array(brain_captions)[indices2]\n",
    "# compare_image_2_captions=np.array(image_captions)[indices2]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d180270-8cd9-456d-8469-5c6e08d37ded",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pixel Correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "77044517-3ed9-4891-97d6-a4753cbbe6ea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "base_transform=Compose([Resize(425),ToTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "88c8f121-7004-487f-a89a-86df7606c0d8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "first_candidates=[i for i in generated_imgs]\n",
    "#first_candidates=[i for i in generated_control_imgs]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3635526b-087c-4c8e-997e-d78bf8c982c2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 982/982 [00:31<00:00, 30.78it/s]\n"
     ]
    }
   ],
   "source": [
    "pix_corr=[]\n",
    "for x,y in tqdm.tqdm(list(zip(shown_imgs,first_candidates))):\n",
    "    x=base_transform(x).numpy()\n",
    "    y=base_transform(y).numpy()\n",
    "    pc= np.abs(np.corrcoef(x.flatten(), y.flatten())[0][1])\n",
    "    pix_corr.append(pc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbdb74a3-83ca-4568-afd8-1df9682e05a0",
   "metadata": {
    "tags": []
   },
   "source": [
    "## SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9223dae7-7a11-4346-b4eb-51822c4248b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ssim_metric=torchmetrics.StructuralSimilarityIndexMeasure()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "07977f26-bf25-4a04-b632-5f0625e8532a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x_torch=[base_transform(i) for i in shown_imgs]\n",
    "y_torch=[base_transform(i) for i in first_candidates]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3562dbb4-dde0-4b27-ac80-46e74b13b24f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x_torch=torch.stack(x_torch)\n",
    "y_torch=torch.stack(y_torch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f5ec7498-2ac4-4048-b312-0c40a7322c33",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ssim=ssim_metric(x_torch,y_torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7eedaa0-ec3c-40c7-ab8e-a053f5dad658",
   "metadata": {},
   "source": [
    "## N-way "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c5324bd1-b8d2-4718-b6e5-2439ffb82648",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def n_way_acc(x,y,model,n=2,device=\"cpu\"):\n",
    "    \n",
    "    model.to(device)\n",
    "    cosine=torch.nn.CosineSimilarity()\n",
    "    acc=0\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm.tqdm(range(len(x))):\n",
    "            original_features=model(x[i].unsqueeze(0).to(device)).cpu()\n",
    "            generated_features=model(y[i].unsqueeze(0).to(device)).cpu()\n",
    "\n",
    "            rnd_idx=torch.randint(len(x_torch),(1,)).item()\n",
    "\n",
    "            random_features=model(y[rnd_idx].unsqueeze(0).to(device)).cpu()\n",
    "\n",
    "            gen_dist=cosine(original_features.reshape(len(original_features),-1),generated_features.reshape(len(generated_features),-1))\n",
    "            rnd_dist=cosine(original_features.reshape(len(original_features),-1),random_features.reshape(len(random_features),-1))\n",
    "            if gen_dist>=rnd_dist:\n",
    "                acc+=1\n",
    "    return acc/len(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d400c99b-436d-477b-a388-9378280bf558",
   "metadata": {
    "tags": []
   },
   "source": [
    "### AlexNet (2) and (5), InceptionNet_V3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e5bf5fc1-3084-4cd7-8dc9-73dd156dbc96",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/matteo/.cache/torch/hub/pytorch_vision_v0.10.0\n",
      "/home/matteo/anaconda3/envs/braindiff/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/home/matteo/anaconda3/envs/braindiff/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AlexNet(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
       "    (1): ReLU(inplace=True)\n",
       "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "    (4): ReLU(inplace=True)\n",
       "    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (7): ReLU(inplace=True)\n",
       "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (9): ReLU(inplace=True)\n",
       "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (11): ReLU(inplace=True)\n",
       "    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
       "  (classifier): Sequential(\n",
       "    (0): Dropout(p=0.5, inplace=False)\n",
       "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): Dropout(p=0.5, inplace=False)\n",
       "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "    (5): ReLU(inplace=True)\n",
       "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alexnet_model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)\n",
    "alexnet_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cbecb487-6837-4f32-9613-4512309bf3dd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "alexnet_transforms=Compose([transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n",
    "            \n",
    "    \n",
    "inception_transforms=Compose([transforms.Resize(342),\n",
    "    transforms.CenterCrop(299),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n",
    "            \n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "802efa82-2144-4f2b-9aca-ddd1e208abac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x_torch_alexnet=torch.stack([alexnet_transforms(i) for i in shown_imgs])\n",
    "y_torch_alexnet=torch.stack([alexnet_transforms(i) for i in first_candidates])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "78f48fe0-8538-4467-9177-2ef5d86171ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x_torch_inception=torch.stack([inception_transforms(i) for i in shown_imgs])\n",
    "y_torch_inception=torch.stack([inception_transforms(i) for i in first_candidates])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9dff927c-3d8e-4619-a453-bd058d3b12ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "modules_2=list(alexnet_model.children())[0][:3]\n",
    "modules_5=list(alexnet_model.children())[0][:10]\n",
    "alexnet_model_2=torch.nn.Sequential(*modules_2)\n",
    "alexnet_model_5=torch.nn.Sequential(*modules_5)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4afd04f5-8a1c-4612-b7ed-138b3c1ac471",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/braindiff/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=Inception_V3_Weights.IMAGENET1K_V1`. You can also use `weights=Inception_V3_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "inception_model=inception_v3(pretrained=True)\n",
    "\n",
    "inception_model.fc = torch.nn.Identity()\n",
    "inception_model.eval()\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "f0b8ade4-2563-40de-97e8-37bf6f026fb5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# modules_in=list(inception_model.children())[:-3]\n",
    "\n",
    "# inception_model=torch.nn.Sequential(*modules_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5da0d1f1-5f28-4a2f-869d-7cd8fc8f2efb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████| 982/982 [00:04<00:00, 229.33it/s]\n"
     ]
    }
   ],
   "source": [
    "alex_2=n_way_acc(x_torch_alexnet,y_torch_alexnet,alexnet_model_2,device=\"cuda:2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7f544ed2-ae57-43cc-b3d3-c8ea63835586",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████| 982/982 [00:04<00:00, 244.16it/s]\n"
     ]
    }
   ],
   "source": [
    "alex_5=n_way_acc(x_torch_alexnet,y_torch_alexnet,alexnet_model_5,device=\"cuda:2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "00d9dfb8-06c6-4357-a77e-de8843977838",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9134419551934827, 0.9755600814663951)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alex_2,alex_5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "30683cf2-bb73-4bc1-8660-eedc32e8709e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 982/982 [00:32<00:00, 30.57it/s]\n"
     ]
    }
   ],
   "source": [
    "inception=n_way_acc(x_torch_inception,y_torch_inception,inception_model,device=\"cuda:2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "948469c7-6130-42d9-8895-1ba71105231f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8492871690427699"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inception"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "649ee01b-9da2-413a-a703-c95846ef52ab",
   "metadata": {
    "tags": []
   },
   "source": [
    "### CLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f5a41be8-5283-4a32-a242-32e4bee1c45b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def n_way_acc_clip(x,y,model,n=2,device=\"cpu\",preprocess=None):\n",
    "    \n",
    "    cosine=torch.nn.CosineSimilarity()\n",
    "    acc=0\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm.tqdm(range(len(x))):\n",
    "            \n",
    "            orig_image = preprocess(x[i]).unsqueeze(0).to(device)\n",
    "            \n",
    "            gen_image = preprocess(y[i]).unsqueeze(0).to(device)\n",
    "            \n",
    "            rnd_idx=torch.randint(len(x_torch),(1,)).item()\n",
    "            rnd_image = preprocess(y[rnd_idx]).unsqueeze(0).to(device)\n",
    "            \n",
    "            \n",
    "            \n",
    "            original_features=model.encode_image(orig_image).cpu().double()\n",
    "            generated_features=model.encode_image(gen_image).cpu().double()\n",
    "\n",
    "\n",
    "            random_features=model.encode_image(rnd_image).cpu().double()\n",
    "\n",
    "\n",
    "            gen_dist=cosine(original_features.reshape(len(original_features),-1),generated_features.reshape(len(generated_features),-1))\n",
    "            rnd_dist=cosine(original_features.reshape(len(original_features),-1),random_features.reshape(len(random_features),-1))\n",
    "            if gen_dist>=rnd_dist:\n",
    "                acc+=1\n",
    "    return acc/len(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "c4646aa3-18bd-4760-b228-2d70f8083326",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "clip_model, preprocess = clip.load(\"ViT-B/32\", device=\"cuda:2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d4103108-7222-4738-bb50-21c1bc64c2aa",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 982/982 [00:35<00:00, 27.69it/s]\n"
     ]
    }
   ],
   "source": [
    "clip_acc=n_way_acc_clip(shown_imgs,first_candidates,clip_model,device=\"cuda:2\",preprocess=preprocess)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b57e4c-283c-40e4-957b-8cf1cb53cca8",
   "metadata": {
    "tags": []
   },
   "source": [
    "## FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "a620bc1a-c37f-4187-a609-22bcd25f0aed",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %pip install torchmetrics[image]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f1b95e45-4f07-4b11-a8aa-15e6dfc55cef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "fid = FrechetInceptionDistance(feature=64,normalize=True)\n",
    "\n",
    "fid.update(x_torch_inception, real=True)\n",
    "fid.update(y_torch_inception, real=False)\n",
    "f64=fid.compute()\n",
    "\n",
    "# fid_2048 = FrechetInceptionDistance(feature=2048,normalize=True)\n",
    "\n",
    "# fid_2048.update(x_torch_inception, real=True)\n",
    "# fid_2048.update(y_torch_inception, real=False)\n",
    "# f2048=fid_2048.compute()\n",
    "\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "f3ecb577-7f3d-4d4e-84bc-8c178d40a73a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(16.0055)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c43339d6-1098-4d5c-9290-b5f732bba3a9",
   "metadata": {},
   "source": [
    "## Final metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "77ed5f54-a79a-40f3-9839-7a652f71ff41",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PixCorr: \t 0.3503\n",
      "SSIM: \t\t 0.3267\n",
      "AlexNet (2): \t 0.9134\n",
      "AlexNet (5): \t 0.9756\n",
      "Incpeption: \t 0.8493\n",
      "CLIP: \t\t 0.9155\n",
      "FID: \t\t 16.0055\n"
     ]
    }
   ],
   "source": [
    "print(f\"PixCorr: \\t {np.mean(pix_corr):.4f}\")\n",
    "print(f\"SSIM: \\t\\t {ssim:.4f}\")\n",
    "print(f\"AlexNet (2): \\t {alex_2:.4f}\")\n",
    "print(f\"AlexNet (5): \\t {alex_5:.4f}\")\n",
    "print(f\"Incpeption: \\t {inception:.4f}\")\n",
    "print(f\"CLIP: \\t\\t {clip_acc:.4f}\")\n",
    "print(f\"FID: \\t\\t {f64.item():.4f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "d5d9b64c-6fab-47ba-9f08-3b8d372652ed",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df = pd.DataFrame({\n",
    "    \"Metric\": [\"PixCorr\", \"SSIM\", \"AlexNet (2)\", \"AlexNet (5)\", \"Inception\", \"CLIP\", \"FID\"],\n",
    "    \"Value\": [np.mean(pix_corr), ssim, alex_2, alex_5, inception, clip_acc, f64.item()]\n",
    "})\n",
    "\n",
    "df[\"Value\"] = df[\"Value\"].apply(lambda x: f\"{x:.4f}\")\n",
    "\n",
    "df.to_csv(f\"image_metrics_{target_run}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1639cdda-3267-44f5-b1f4-52b9e26e1599",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "braindiff",
   "language": "python",
   "name": "braindiff"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
