{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53146de1-97c4-4bf9-9569-a6a400d377fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy as sp\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import make_grid\n",
    "from tqdm import tqdm\n",
    "from datetime import datetime\n",
    "import argparse\n",
    "\n",
    "# Run on the GPU when one is visible, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "local_rank = 0\n",
    "print(\"device:\", device)\n",
    "\n",
    "# Project-local helper module.\n",
    "import utils\n",
    "\n",
    "# Fixed seed handed to the project's seeding helper.\n",
    "seed = 42\n",
    "utils.seed_everything(seed=seed)\n",
    "\n",
    "# Enable module auto-reloading only when utils reports an interactive session.\n",
    "if utils.is_interactive():\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2\n",
    "\n",
    "# Initial image side length; the display cell assigns its own value later.\n",
    "imsize = 512"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf904985-a9ed-4d12-a965-a46db5473d1b",
   "metadata": {},
   "source": [
    "# Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26b33767-16b1-447b-afbf-1cfa6118815f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subject number; in this cell it only feeds the image-file path below.\n",
    "subj = 2\n",
    "\n",
    "# File holding the reconstructions to evaluate.\n",
    "recon_path = \"recons/subj1_nl_sclip_basictest2_finetune_s2_aamax_100_hl_recons_img2img1_8samples.pt\"\n",
    "\n",
    "# File holding the reference images for the chosen subject.\n",
    "all_images_path = f\"recons/all_images_subj{subj}_nsd_split1.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adb10ae2-0188-454a-b227-3a90783bdaa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deserialize onto the CPU first so files written from a GPU session still load\n",
    "# on a CPU-only machine; both tensors are moved to `device` explicitly below.\n",
    "all_brain_recons = torch.load(recon_path, map_location='cpu')\n",
    "all_images = torch.load(all_images_path, map_location='cpu')\n",
    "\n",
    "print(all_images.shape)\n",
    "print(all_brain_recons.shape)\n",
    "\n",
    "all_images = all_images.to(device)\n",
    "# Match the reference images' dtype and clip reconstructions to the [0, 1] range.\n",
    "all_brain_recons = all_brain_recons.to(device).to(all_images.dtype).clamp(0,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cf403a1-f03c-4346-879f-7b8719ca654a",
   "metadata": {},
   "source": [
    "### Main Display code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db613b6a-33f1-459c-b494-52328297b7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side length applied to both image sets; later cells reuse the resized tensors.\n",
    "imsize = 256\n",
    "resize_to_square = transforms.Resize((imsize, imsize))\n",
    "all_images = resize_to_square(all_images)\n",
    "all_brain_recons = resize_to_square(all_brain_recons)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Hand-picked sample indices, visited in reverse of the listed order.\n",
    "ind = np.flip(np.array([ 1, 37, 18, 75, 78,  9,  4, 66,  5, 83]))\n",
    "\n",
    "# Alternate reference image and reconstruction: slots 2k and 2k+1 hold one pair.\n",
    "all_interleaved = torch.zeros(2 * len(ind), 3, imsize, imsize)\n",
    "for pair, t in enumerate(ind):\n",
    "    all_interleaved[2 * pair] = all_images[t]\n",
    "    all_interleaved[2 * pair + 1] = all_brain_recons[t]\n",
    "\n",
    "plt.rcParams[\"savefig.bbox\"] = 'tight'\n",
    "\n",
    "def show(imgs,figsize):\n",
    "    \"\"\"Draw one image tensor, or a list of them, in a single row without axis ticks.\"\"\"\n",
    "    images = imgs if isinstance(imgs, list) else [imgs]\n",
    "    fig, axs = plt.subplots(ncols=len(images), squeeze=False, figsize=figsize)\n",
    "    for col, tensor in enumerate(images):\n",
    "        pil_img = transforms.ToPILImage()(tensor.detach())\n",
    "        ax = axs[0, col]\n",
    "        ax.imshow(np.asarray(pil_img))\n",
    "        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
    "\n",
    "# Tile the interleaved images ten per row and render the result as one picture.\n",
    "grid = make_grid(all_interleaved, nrow=10, padding=2)\n",
    "show(grid,figsize=(20,16))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06f9be1-8726-42f9-94fa-e2a74e5db8ff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eb853f6-650a-4e6d-b507-84eacbb5bc1e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "be7ba39c-cd18-4942-a86a-ee640865bcdc",
   "metadata": {},
   "source": [
    "# 2-Way Identification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49f8056e-a80c-4abc-8c2d-3193bb43f945",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names\n",
    "\n",
    "@torch.no_grad()\n",
    "def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):\n",
    "    \"\"\"Score how often each reconstruction correlates best with its own image.\n",
    "\n",
    "    Features of reconstructions and images are extracted with `model` after\n",
    "    `preprocess` is applied to each item, then compared by Pearson correlation.\n",
    "    For every reconstruction, the images whose correlation with it is lower\n",
    "    than that of the same-index image are counted.\n",
    "\n",
    "    Args:\n",
    "        all_brain_recons: iterable of reconstructions, index-aligned with `all_images`.\n",
    "        all_images: iterable of reference images.\n",
    "        model: callable returning a tensor, or a mapping of tensors when\n",
    "            `feature_layer` is given.\n",
    "        preprocess: callable applied to every single item before stacking.\n",
    "        feature_layer: key selecting one entry of the model output; None uses\n",
    "            the output as is.\n",
    "        return_avg: when True return one averaged score, otherwise the raw counts.\n",
    "\n",
    "    Returns:\n",
    "        The mean count divided by `len(all_images) - 1` when `return_avg` is True,\n",
    "        else a tuple `(per-reconstruction counts, len(all_images) - 1)`.\n",
    "    \"\"\"\n",
    "    def _features(batch):\n",
    "        # Preprocess item by item, run the model once, flatten to (n_items, n_features).\n",
    "        inputs = torch.stack([preprocess(item) for item in batch], dim=0).to(device)\n",
    "        out = model(inputs)\n",
    "        if feature_layer is not None:\n",
    "            out = out[feature_layer]\n",
    "        return out.float().flatten(1).cpu().numpy()\n",
    "\n",
    "    preds = _features(all_brain_recons)\n",
    "    reals = _features(all_images)\n",
    "\n",
    "    n_images = len(all_images)\n",
    "    # corrcoef stacks reals above preds; keep only the reals-by-preds quadrant.\n",
    "    r = np.corrcoef(reals, preds)[:n_images, n_images:]\n",
    "    congruents = np.diag(r)\n",
    "\n",
    "    # Column j counts the images that correlate with reconstruction j less than image j does.\n",
    "    success_cnt = (r < congruents).sum(axis=0)\n",
    "    n_distractors = n_images - 1\n",
    "\n",
    "    if not return_avg:\n",
    "        return success_cnt, n_distractors\n",
    "    return np.mean(success_cnt) / n_distractors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63801703-0f78-47ca-89c6-54dcf71156a4",
   "metadata": {},
   "source": [
    "## PixCorr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d89db923-16c8-4ec8-9bf9-cc9749031188",
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "])\n",
    "\n",
    "# Flatten images while keeping the batch dimension.\n",
    "# reshape (not view) on both tensors, so a non-contiguous resize output cannot raise.\n",
    "all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()\n",
    "all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()\n",
    "\n",
    "print(all_images_flattened.shape)\n",
    "print(all_brain_recons_flattened.shape)\n",
    "\n",
    "# Correlate every image/reconstruction pair. The pair count is taken from the data\n",
    "# instead of a fixed 1000, so evaluation sets of any size are averaged correctly.\n",
    "n_pairs = len(all_images_flattened)\n",
    "corr_list = []\n",
    "for i in tqdm(range(n_pairs)):\n",
    "    # The off-diagonal entry of the 2x2 matrix is the Pearson correlation of the pair.\n",
    "    corr_list.append(np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1])\n",
    "corrmean = np.mean(corr_list)\n",
    "\n",
    "pixcorr = corrmean\n",
    "print(pixcorr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2740d15b-1a02-4d50-9fe4-9dbfae38800e",
   "metadata": {},
   "source": [
    "## SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a712ab76-1c8f-4232-995c-d06f8191d9b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# see https://github.com/zijin-gu/meshconv-decoding/issues/3\n",
    "from skimage.color import rgb2gray\n",
    "# Imported under its real name: `ssim` is reserved for the resulting score below,\n",
    "# so the function is no longer overwritten by a float once the cell has run.\n",
    "from skimage.metrics import structural_similarity\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "])\n",
    "\n",
    "# convert images to grayscale with rgb2gray (channels moved last first)\n",
    "img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu())\n",
    "recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu())\n",
    "print(\"converted, now calculating ssim...\")\n",
    "\n",
    "# NOTE(review): the inputs are already grayscale; `multichannel=True` is kept exactly\n",
    "# as before so scores stay comparable -- confirm how the installed scikit-image\n",
    "# version interprets it.\n",
    "ssim_score = []\n",
    "for im, rec in tqdm(zip(img_gray, recon_gray), total=len(img_gray)):\n",
    "    ssim_score.append(structural_similarity(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))\n",
    "\n",
    "# Mean score over all pairs; the summary table reads this variable.\n",
    "ssim = np.mean(ssim_score)\n",
    "print(ssim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4554b22-7faa-4e59-a6db-83bf775dd9e4",
   "metadata": {
    "tags": []
   },
   "source": [
    "### AlexNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc906db-bc5f-4cce-87f4-878706c14d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import alexnet, AlexNet_Weights\n",
    "alex_weights = AlexNet_Weights.IMAGENET1K_V1\n",
    "\n",
    "# Expose the two intermediate feature nodes that are scored below.\n",
    "alex_backbone = alexnet(weights=alex_weights)\n",
    "alex_model = create_feature_extractor(alex_backbone, return_nodes=['features.4','features.11']).to(device)\n",
    "alex_model.eval().requires_grad_(False)\n",
    "\n",
    "# see alex_weights.transforms()\n",
    "alex_resize = transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR)\n",
    "alex_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                      std=[0.229, 0.224, 0.225])\n",
    "preprocess = transforms.Compose([alex_resize, alex_normalize])\n",
    "\n",
    "# Both evaluations use the same float copy of the reconstructions.\n",
    "alex_recons = all_brain_recons.to(device).float()\n",
    "\n",
    "layer = 'early, AlexNet(2)'\n",
    "print(f\"\\n---{layer}---\")\n",
    "all_per_correct = two_way_identification(alex_recons, all_images, alex_model, preprocess,\n",
    "                                         feature_layer='features.4')\n",
    "alexnet2 = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {alexnet2:.4f}\")\n",
    "\n",
    "layer = 'mid, AlexNet(5)'\n",
    "print(f\"\\n---{layer}---\")\n",
    "all_per_correct = two_way_identification(alex_recons, all_images, alex_model, preprocess,\n",
    "                                         feature_layer='features.11')\n",
    "alexnet5 = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {alexnet5:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3a8ac02-a024-442f-a0a9-638a8afdeb0d",
   "metadata": {},
   "source": [
    "### InceptionV3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "555b088c-07b4-4164-9e38-bdb5d43d58ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import inception_v3, Inception_V3_Weights\n",
    "weights = Inception_V3_Weights.DEFAULT\n",
    "\n",
    "# Only the 'avgpool' node of the network is exposed for scoring.\n",
    "inception_backbone = inception_v3(weights=weights)\n",
    "inception_model = create_feature_extractor(inception_backbone, return_nodes=['avgpool']).to(device)\n",
    "inception_model.eval().requires_grad_(False)\n",
    "\n",
    "# see weights.transforms()\n",
    "inception_resize = transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR)\n",
    "inception_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                           std=[0.229, 0.224, 0.225])\n",
    "preprocess = transforms.Compose([inception_resize, inception_normalize])\n",
    "\n",
    "all_per_correct = two_way_identification(all_brain_recons, all_images, inception_model,\n",
    "                                         preprocess, feature_layer='avgpool')\n",
    "\n",
    "inception = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {inception:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12537959-3045-46f2-833f-80ff6327d40a",
   "metadata": {},
   "source": [
    "### CLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc928ed-952c-4b2f-ae05-2cd8b8c770c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import clip\n",
    "# Only the model half of clip.load's return value is used; the transform it\n",
    "# returns is replaced by the tensor-based pipeline defined just below.\n",
    "clip_model = clip.load(\"ViT-L/14\", device=device)[0]\n",
    "\n",
    "clip_resize = transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR)\n",
    "clip_normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
    "                                      std=[0.26862954, 0.26130258, 0.27577711])\n",
    "preprocess = transforms.Compose([clip_resize, clip_normalize])\n",
    "\n",
    "# feature_layer=None: the output of encode_image is compared directly (final layer).\n",
    "all_per_correct = two_way_identification(all_brain_recons, all_images, clip_model.encode_image,\n",
    "                                         preprocess, feature_layer=None, return_avg=True)\n",
    "clip_ = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {clip_:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e00a5a08-8acb-4fd9-b581-697df7b845db",
   "metadata": {},
   "source": [
    "### EfficientNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d103b5af-ae7c-4f07-94e0-6b35d951ce78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights\n",
    "weights = EfficientNet_B1_Weights.DEFAULT\n",
    "\n",
    "# Only the 'avgpool' node of the network is exposed.\n",
    "eff_backbone = efficientnet_b1(weights=weights)\n",
    "eff_model = create_feature_extractor(eff_backbone, return_nodes=['avgpool']).to(device)\n",
    "eff_model.eval().requires_grad_(False)\n",
    "\n",
    "# see weights.transforms()\n",
    "eff_resize = transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR)\n",
    "eff_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                     std=[0.229, 0.224, 0.225])\n",
    "preprocess = transforms.Compose([eff_resize, eff_normalize])\n",
    "\n",
    "# One flattened feature row per reference image and per reconstruction.\n",
    "gt = eff_model(preprocess(all_images))['avgpool']\n",
    "gt = gt.reshape(gt.shape[0], -1).cpu().numpy()\n",
    "fake = eff_model(preprocess(all_brain_recons))['avgpool']\n",
    "fake = fake.reshape(fake.shape[0], -1).cpu().numpy()\n",
    "\n",
    "# Correlation distance of each same-index pair, averaged over all pairs.\n",
    "eff_distances = [sp.spatial.distance.correlation(gt[i], fake[i]) for i in range(len(gt))]\n",
    "effnet = np.array(eff_distances).mean()\n",
    "print(\"Distance:\",effnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81c764bd-4c46-4102-9342-5d63bc47c4e4",
   "metadata": {},
   "source": [
    "### SwAV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eacb33d-06b7-4847-b106-3a8ae15798f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the ResNet-50 entry point from the SwAV hub repository and expose its 'avgpool' node.\n",
    "swav_backbone = torch.hub.load('facebookresearch/swav:main', 'resnet50')\n",
    "swav_model = create_feature_extractor(swav_backbone, return_nodes=['avgpool']).to(device)\n",
    "swav_model.eval().requires_grad_(False)\n",
    "\n",
    "swav_resize = transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR)\n",
    "swav_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                      std=[0.229, 0.224, 0.225])\n",
    "preprocess = transforms.Compose([swav_resize, swav_normalize])\n",
    "\n",
    "# One flattened feature row per reference image and per reconstruction.\n",
    "gt = swav_model(preprocess(all_images))['avgpool']\n",
    "gt = gt.reshape(gt.shape[0], -1).cpu().numpy()\n",
    "fake = swav_model(preprocess(all_brain_recons))['avgpool']\n",
    "fake = fake.reshape(fake.shape[0], -1).cpu().numpy()\n",
    "\n",
    "# Correlation distance of each same-index pair, averaged over all pairs.\n",
    "swav_distances = [sp.spatial.distance.correlation(gt[i], fake[i]) for i in range(len(gt))]\n",
    "swav = np.array(swav_distances).mean()\n",
    "print(\"Distance:\",swav)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02dbee6a-56fc-4bd7-9d52-4413c7a01a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every metric computed above into a two-column table.\n",
    "data = {\n",
    "    \"Metric\": [\"PixCorr\", \"SSIM\", \"AlexNet(2)\", \"AlexNet(5)\", \"InceptionV3\", \"CLIP\", \"EffNet-B\", \"SwAV\"],\n",
    "    \"Value\": [pixcorr, ssim, alexnet2, alexnet5, inception, clip_, effnet, swav],\n",
    "}\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "print(df.to_string(index=False))\n",
    "\n",
    "if not utils.is_interactive():\n",
    "    # Save the table next to the reconstructions (tab-separated despite the .csv name).\n",
    "    # splitext removes whatever extension recon_path carries, not just a 3-character one.\n",
    "    df.to_csv(f'{os.path.splitext(recon_path)[0]}.csv', sep='\\t', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
