{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5f265e-407a-40bd-92fb-a652091fd7ea",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import math\n",
    "from einops import rearrange\n",
    "import time\n",
    "import random\n",
    "import string\n",
    "import h5py\n",
    "from tqdm import tqdm\n",
    "import webdataset as wds\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from accelerate import Accelerator, DeepSpeedPlugin\n",
    "\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "from transformers import CLIPModel, AutoTokenizer, AutoProcessor\n",
    "import evaluate\n",
    "import pandas as pd\n",
    "\n",
    "from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder\n",
    "# from v2_models import GNet8_Encoder\n",
    "\n",
    "# Let CUDA matmuls use TF32 tensor cores: faster than strict float32, slightly less precise\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "# project-local helper functions (utils.py)\n",
    "import utils\n",
    "\n",
    "# Accelerator with fp16 mixed precision; this notebook only uses it to pick the device\n",
    "accelerator = Accelerator(mixed_precision=\"fp16\", split_batches=False)\n",
    "device = accelerator.device\n",
    "print(\"device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04da4892-e576-44f9-8e20-07f5c35b62af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation configuration\n",
    "model_name = \"subj1_l_bclip_basictest_wll_5_finetune_s7_250_aamax\"\n",
    "subj = 7  # subject id, used in the `subj0{subj}` ground-truth file name\n",
    "data_path = \"data\"\n",
    "cache_dir = f\"{data_path}/fmri/cache\"\n",
    "all_recons_path = f\"evals/{model_name}/{model_name}_all_enhancedrecons.pt\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4ccb9d-7e49-4d09-9fa3-de6c33e81353",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load ground truths, you can find these files on huggingface: https://huggingface.co/datasets/pscotti/mindeyev2/tree/main/evals\n",
    "all_images = torch.load(f\"evals/all_images_subj0{subj}.pt\")\n",
    "# Ground-truth captions used as `references` by every caption metric. They must not come from\n",
    "# recons/<model_name>_all_predcaptions.pt (the model's predictions, loaded separately as\n",
    "# all_predcaptions), or the \"brain vs captions\" scores compare the predictions with themselves.\n",
    "# NOTE(review): file name follows the HF evals folder above - confirm it matches this subject's image order.\n",
    "all_captions = torch.load(\"evals/all_captions.pt\")\n",
    "assert len(all_captions) == len(all_images), \"ground-truth captions and images are not aligned\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04395680-a31c-436a-9bfb-a4a44e8d0b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"all_recons_path:\", all_recons_path)\n",
    "all_recons = torch.load(all_recons_path)\n",
    "\n",
    "# the other outputs of the reconstruction run share the recons/<model_name>_ prefix\n",
    "recon_prefix = f\"recons/{model_name}\"\n",
    "all_blurryrecons = torch.load(f\"{recon_prefix}_all_blurryrecons.pt\")  # low-level submodule output\n",
    "all_predcaptions = torch.load(f\"{recon_prefix}_all_predcaptions.pt\")  # GIT captions predicted from brain\n",
    "\n",
    "# name used for the saved metric tables\n",
    "model_name_plus_suffix = f\"{model_name}_all_enhancedrecons\"\n",
    "print(model_name_plus_suffix)\n",
    "print(all_images.shape, all_recons.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ffb659a-8154-4536-ab27-2d976da1bf4e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# if running this interactively, build jupyter_args (for argparse-style invocation) from the\n",
    "# configuration cell above. The values are deliberately NOT re-assigned here: re-assigning them\n",
    "# silently overrode edits to that cell (e.g. `subj`, later used for GNet) after the data was loaded.\n",
    "if utils.is_interactive():\n",
    "    print(\"model_name:\", model_name)\n",
    "\n",
    "    jupyter_args = f\"--model_name={model_name} --subj={subj} --data_path={data_path} --cache_dir={cache_dir} --all_recons_path={all_recons_path}\"\n",
    "    print(jupyter_args)\n",
    "    # NOTE(review): jupyter_args is not passed to an argparse parser anywhere in this notebook\n",
    "    jupyter_args = jupyter_args.split()\n",
    "\n",
    "    from IPython.display import clear_output # function to clear print outputs in cell\n",
    "    %load_ext autoreload\n",
    "    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n",
    "    %autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95d66b33-b327-4895-a861-ecc6ccc51296",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Evals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f42009e9-f910-4f02-8db6-d46778aa6595",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "imsize = 256\n",
    "# weight of the low-level (blurry) reconstruction in the \"enhanced\" blend below\n",
    "blurry_weight = .25\n",
    "\n",
    "def to_eval_size(x):\n",
    "    \"\"\"Resize an image batch to imsize x imsize when needed, then cast it to float32.\"\"\"\n",
    "    if x.shape[-1] != imsize:\n",
    "        x = transforms.Resize((imsize,imsize))(x)\n",
    "    # cast unconditionally: tensors already at imsize used to keep whatever dtype they were saved\n",
    "    # with (e.g. fp16, which scipy.ndimage-based metrics such as SSIM do not accept)\n",
    "    return x.float()\n",
    "\n",
    "all_images = to_eval_size(all_images)\n",
    "all_recons = to_eval_size(all_recons)\n",
    "all_blurryrecons = to_eval_size(all_blurryrecons)\n",
    "\n",
    "# NOTE: not idempotent - re-running this cell without reloading all_recons blends a second time\n",
    "if \"enhanced\" in model_name_plus_suffix:\n",
    "    print(\"weighted averaging to improve low-level evals\")\n",
    "    all_recons = all_recons*(1 - blurry_weight) + all_blurryrecons*blurry_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "434b33b5-c799-4054-889c-ac74663d31ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import textwrap\n",
    "\n",
    "def wrap_title(title, wrap_width):\n",
    "    \"\"\"Wrap `title` into lines of at most `wrap_width` characters, joined with newlines.\"\"\"\n",
    "    return \"\\n\".join(textwrap.wrap(title, wrap_width))\n",
    "\n",
    "# test-set indices to show; each index fills a (ground truth, reconstruction) pair of panels\n",
    "example_indices = [2, 165, 119, 200, 231, 210]\n",
    "pairs_per_row = 2\n",
    "n_rows = math.ceil(len(example_indices) / pairs_per_row)\n",
    "fig, axes = plt.subplots(n_rows, 2*pairs_per_row, figsize=(10, 8), squeeze=False)\n",
    "for i, j in enumerate(example_indices):\n",
    "    row, pair = divmod(i, pairs_per_row)\n",
    "    ax_img, ax_rec = axes[row][2*pair], axes[row][2*pair + 1]\n",
    "    ax_img.imshow(utils.torch_to_Image(all_images[j]))\n",
    "    ax_img.axis('off')\n",
    "    ax_rec.imshow(utils.torch_to_Image(all_recons[j]))\n",
    "    ax_rec.axis('off')\n",
    "    # scalar index: `all_predcaptions[[j]]` titled the panel with the repr of a 1-element array\n",
    "    # (\"['...']\") and raises TypeError when the captions are a plain list\n",
    "    ax_rec.set_title(wrap_title(str(all_predcaptions[j]), wrap_width=30), fontsize=8)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a26e124-2444-434d-a399-d03c2c90cc08",
   "metadata": {},
   "source": [
    "## 2-way identification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e1778ff-5d6a-4087-b59f-0f44b9e0eada",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names\n",
    "\n",
    "@torch.no_grad()\n",
    "def two_way_identification(all_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):\n",
    "    \"\"\"Two-way identification between reconstructions and their ground-truth images.\n",
    "\n",
    "    Each set is preprocessed image-by-image, stacked, moved to `device` and embedded with `model`\n",
    "    (taking `feature_layer` from its output when given); features are flattened per image.\n",
    "    For every reconstruction, counts how many of the other ground-truth images correlate less\n",
    "    with it than its own ground truth does.\n",
    "\n",
    "    Returns the mean fraction of such wins if `return_avg`, else `(per-recon win counts, N - 1)`.\n",
    "    \"\"\"\n",
    "    def embed(images):\n",
    "        batch = torch.stack([preprocess(img) for img in images], dim=0).to(device)\n",
    "        return model(batch)\n",
    "\n",
    "    def flat(features):\n",
    "        if feature_layer is not None:\n",
    "            features = features[feature_layer]\n",
    "        return features.float().flatten(1).cpu().numpy()\n",
    "\n",
    "    pred_out = embed(all_recons)\n",
    "    real_out = embed(all_images)\n",
    "    pred_feats, real_feats = flat(pred_out), flat(real_out)\n",
    "\n",
    "    n_images = len(all_images)\n",
    "    # cross-correlation block: rows index real images, columns index reconstructions\n",
    "    corr = np.corrcoef(real_feats, pred_feats)[:n_images, n_images:]\n",
    "    own = np.diag(corr)\n",
    "    # own broadcasts along rows, so column c compares every real image against recon c's own match\n",
    "    wins = (corr < own).sum(axis=0)\n",
    "\n",
    "    n_comparisons = n_images - 1\n",
    "    if not return_avg:\n",
    "        return wins, n_comparisons\n",
    "    return np.mean(wins) / n_comparisons"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df6be966-52ef-4cf6-8078-8d2d9617564b",
   "metadata": {},
   "source": [
    "## PixCorr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e17ea38-a254-4e90-a910-711734fdd8eb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# resize both sets (shorter side -> 425px, bilinear) before comparing pixels\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "])\n",
    "\n",
    "# one flattened row of pixels per image\n",
    "n_imgs = len(all_images)\n",
    "all_images_flattened = preprocess(all_images).reshape(n_imgs, -1).cpu()\n",
    "all_recons_flattened = preprocess(all_recons).view(len(all_recons), -1).cpu()\n",
    "print(all_images_flattened.shape)\n",
    "print(all_recons_flattened.shape)\n",
    "\n",
    "# mean over images of the Pearson r between an image's pixels and its reconstruction's\n",
    "corrsum = sum(\n",
    "    np.corrcoef(all_images_flattened[i], all_recons_flattened[i])[0][1]\n",
    "    for i in tqdm(range(n_imgs))\n",
    ")\n",
    "pixcorr = corrsum / n_imgs\n",
    "corrmean = pixcorr\n",
    "print(pixcorr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a556d5b-33a2-44aa-b48d-4b168316bbdd",
   "metadata": {
    "tags": []
   },
   "source": [
    "## SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2326fc4c-1248-4d0f-9176-218c6460f285",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# see https://github.com/zijin-gu/meshconv-decoding/issues/3\n",
    "from skimage.color import rgb2gray\n",
    "from skimage.metrics import structural_similarity\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), \n",
    "])\n",
    "\n",
    "def to_gray(batch):\n",
    "    \"\"\"Resize a 4-D batch, move channels last (permute 0,2,3,1) and convert it with rgb2gray.\"\"\"\n",
    "    return rgb2gray(preprocess(batch).permute((0,2,3,1)).cpu())\n",
    "\n",
    "img_gray = to_gray(all_images)\n",
    "recon_gray = to_gray(all_recons)\n",
    "print(\"converted, now calculating ssim...\")\n",
    "\n",
    "# NOTE(review): multichannel=True on these 2-D grayscale images makes skimage treat the last axis\n",
    "# as channels; kept unchanged so scores stay comparable with the evaluation code this follows.\n",
    "ssim_score = [\n",
    "    structural_similarity(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0)\n",
    "    for im, rec in tqdm(zip(img_gray, recon_gray), total=len(all_images))\n",
    "]\n",
    "\n",
    "ssim = np.mean(ssim_score)\n",
    "print(ssim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35138520-ec00-48a6-90dc-249a32a783d2",
   "metadata": {},
   "source": [
    "## AlexNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b45cc6c-ab80-43e2-b446-c8fcb4fc54e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision.models import alexnet, AlexNet_Weights\n",
    "alex_weights = AlexNet_Weights.IMAGENET1K_V1\n",
    "\n",
    "# expose two intermediate AlexNet feature maps: early (features.4) and mid (features.11)\n",
    "alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4','features.11']).to(device)\n",
    "alex_model.eval().requires_grad_(False)\n",
    "\n",
    "# see alex_weights.transforms()\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "alexnet_scores = {}\n",
    "for layer, node in [('early, AlexNet(2)', 'features.4'), ('mid, AlexNet(5)', 'features.11')]:\n",
    "    print(f\"\\n---{layer}---\")\n",
    "    all_per_correct = two_way_identification(all_recons.to(device).float(), all_images,\n",
    "                                             alex_model, preprocess, node)\n",
    "    alexnet_scores[node] = np.mean(all_per_correct)\n",
    "    print(f\"2-way Percent Correct: {alexnet_scores[node]:.4f}\")\n",
    "\n",
    "alexnet2 = alexnet_scores['features.4']\n",
    "alexnet5 = alexnet_scores['features.11']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c296bab2-d106-469e-b997-b32d21a2cf01",
   "metadata": {},
   "source": [
    "## InceptionV3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9c1b2b-af2a-476d-a1ac-32ee915ac2ec",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision.models import inception_v3, Inception_V3_Weights\n",
    "weights = Inception_V3_Weights.DEFAULT\n",
    "\n",
    "# global-average-pooled ('avgpool') features of the pretrained InceptionV3\n",
    "inception_model = create_feature_extractor(inception_v3(weights=weights), return_nodes=['avgpool'])\n",
    "inception_model = inception_model.to(device)\n",
    "inception_model.eval().requires_grad_(False)\n",
    "\n",
    "# see weights.transforms() (resize + normalize only; no center crop here)\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "all_per_correct = two_way_identification(all_recons, all_images,\n",
    "                                         inception_model, preprocess, 'avgpool')\n",
    "inception = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {inception:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7a25f7f-8298-4413-b512-8a1173413e07",
   "metadata": {},
   "source": [
    "## CLIP (image embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6afbf7ce-8793-4988-a328-a632acd88aa9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import clip\n",
    "# only the model is used; clip's bundled preprocess is replaced by a tensor pipeline below\n",
    "clip_model, clip_pil_preprocess = clip.load(\"ViT-L/14\", device=device)\n",
    "\n",
    "# tensor-friendly stand-in for CLIP preprocessing: resize + CLIP normalization statistics\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
    "                         std=[0.26862954, 0.26130258, 0.27577711]),\n",
    "])\n",
    "\n",
    "# final layer: compare the embeddings returned by encode_image (no feature_layer lookup)\n",
    "all_per_correct = two_way_identification(all_recons, all_images,\n",
    "                                         clip_model.encode_image, preprocess, None)\n",
    "clip_ = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {clip_:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4fed9f8-ef1a-4c6d-a83f-2a934b6e87fd",
   "metadata": {},
   "source": [
    "## EfficientNet-B1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14143c0f-1b32-43ef-98d8-8ed458df4551",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import scipy as sp\n",
    "# plain `import scipy` does not load scipy.spatial on older SciPy releases, which made\n",
    "# `sp.spatial.distance` below depend on some other package having imported it first\n",
    "import scipy.spatial.distance\n",
    "from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights\n",
    "weights = EfficientNet_B1_Weights.DEFAULT\n",
    "eff_model = create_feature_extractor(efficientnet_b1(weights=weights), \n",
    "                                    return_nodes=['avgpool'])\n",
    "eff_model.eval().requires_grad_(False)\n",
    "\n",
    "# see weights.transforms()\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# pooled features, flattened to one row per image\n",
    "gt = eff_model(preprocess(all_images))['avgpool']\n",
    "gt = gt.reshape(len(gt),-1).cpu().numpy()\n",
    "fake = eff_model(preprocess(all_recons))['avgpool']\n",
    "fake = fake.reshape(len(fake),-1).cpu().numpy()\n",
    "\n",
    "# mean correlation distance (1 - Pearson r) between each image and its reconstruction; lower is better\n",
    "effnet = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()\n",
    "print(\"Distance:\",effnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "405f669d-cab7-4c75-90cd-651283f65a9e",
   "metadata": {},
   "source": [
    "## SwAV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c60b0c4-79fe-4cff-95e9-99733c821e67",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# import `sp` here too so this cell does not depend on the EfficientNet cell having run,\n",
    "# and load scipy.spatial explicitly (plain `import scipy` does not load it on older SciPy releases)\n",
    "import scipy as sp\n",
    "import scipy.spatial.distance\n",
    "\n",
    "swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')\n",
    "swav_model = create_feature_extractor(swav_model, \n",
    "                                    return_nodes=['avgpool'])\n",
    "swav_model.eval().requires_grad_(False)\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# pooled features, flattened to one row per image\n",
    "gt = swav_model(preprocess(all_images))['avgpool']\n",
    "gt = gt.reshape(len(gt),-1).cpu().numpy()\n",
    "fake = swav_model(preprocess(all_recons))['avgpool']\n",
    "fake = fake.reshape(len(fake),-1).cpu().numpy()\n",
    "\n",
    "# mean correlation distance (1 - Pearson r) between each image and its reconstruction; lower is better\n",
    "swav = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()\n",
    "print(\"Distance:\",swav)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "528db098-6977-4dbd-9bc6-dfd6a96e0fc7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchmetrics import PearsonCorrCoef\n",
    "\n",
    "# Brain correlation: beta_primes = GNet.predict(recon_list) is correlated, region by region, with\n",
    "# test_voxels_averaged. None of GNet8_Encoder (its import is commented out in the first cell),\n",
    "# recon_list, subject_masks or test_voxels_averaged is defined in this notebook, which stopped\n",
    "# Run-All with a NameError; report what is missing and leave region_brain_correlations empty.\n",
    "try:\n",
    "    from v2_models import GNet8_Encoder\n",
    "except ImportError:\n",
    "    pass  # reported by the missing-name check below\n",
    "\n",
    "region_brain_correlations = {}\n",
    "missing = [name for name in (\"GNet8_Encoder\", \"recon_list\", \"subject_masks\", \"test_voxels_averaged\")\n",
    "           if name not in globals()]\n",
    "if missing:\n",
    "    print(f\"Skipping brain correlations, undefined: {missing}\")\n",
    "else:\n",
    "    GNet = GNet8_Encoder(device=device,subject=subj,model_path=f\"{cache_dir}/gnet_multisubject.pt\")\n",
    "    PeC = PearsonCorrCoef(num_outputs=len(recon_list))\n",
    "    beta_primes = GNet.predict(recon_list)\n",
    "\n",
    "    for region, mask in subject_masks.items():\n",
    "        # assumes rows are samples and columns voxels (TODO confirm); moveaxis(0,1) makes PeC return\n",
    "        # one correlation per sample, computed across that region's voxels\n",
    "        score = PeC(test_voxels_averaged[:,mask].moveaxis(0,1), beta_primes[:,mask].moveaxis(0,1))\n",
    "        region_brain_correlations[region] = float(torch.mean(score))\n",
    "print(region_brain_correlations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b794c2d7-ebba-4993-a09d-ffb314cb30e8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Collect the image and brain metrics into one table (one row per metric)\n",
    "import pandas as pd\n",
    "\n",
    "# percent_correct_fwd/bwd (retrieval) are not computed anywhere in this notebook and the brain\n",
    "# correlations may have been skipped, so look them up defensively and report NaN when missing\n",
    "# instead of stopping Run-All with a NameError/KeyError.\n",
    "brain_regions = [\"nsd_general\", \"V1\", \"V2\", \"V3\", \"V4\", \"higher_vis\"]\n",
    "brain_corrs = globals().get(\"region_brain_correlations\", {})\n",
    "data = {\n",
    "    \"Metric\": [\"PixCorr\", \"SSIM\", \"AlexNet(2)\", \"AlexNet(5)\", \"InceptionV3\", \"CLIP\", \"EffNet-B\", \"SwAV\", \"FwdRetrieval\", \"BwdRetrieval\"]\n",
    "              + [f\"Brain Corr. {region}\" for region in brain_regions],\n",
    "    \"Value\": [pixcorr, ssim, alexnet2, alexnet5, inception, clip_, effnet, swav,\n",
    "              globals().get(\"percent_correct_fwd\", np.nan), globals().get(\"percent_correct_bwd\", np.nan)]\n",
    "             + [brain_corrs.get(region, np.nan) for region in brain_regions]}\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "not_computed = df.loc[df[\"Value\"].isna(), \"Metric\"].tolist()\n",
    "if not_computed:\n",
    "    print(\"Not computed (NaN):\", not_computed)\n",
    "print(model_name_plus_suffix)\n",
    "print(df.to_string(index=False))\n",
    "print(df[\"Value\"].to_string(index=False))\n",
    "\n",
    "# save only the Value column (tab-separated despite the .csv extension)\n",
    "os.makedirs('tables/',exist_ok=True)\n",
    "df[\"Value\"].to_csv(f'tables/{model_name_plus_suffix}.csv', sep='\\t', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20999416-35c5-4604-a69d-b89687d44f00",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# GIT captions of the ground-truth images (the \"GIT from images\" baseline in the caption metrics)\n",
    "git_captions_path = \"evals/all_git_generated_captions.pt\"\n",
    "all_git_generated_captions = torch.load(git_captions_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44e6d2de-e303-4197-86fc-a4328fc7169b",
   "metadata": {},
   "source": [
    "## Meteor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7159fa8-234d-4809-9c44-69fc86d7d668",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "meteor = evaluate.load('meteor')\n",
    "# ref = ground-truth captions, img = GIT captions of the images, brain = captions predicted from brain\n",
    "meteor_img_ref = meteor.compute(predictions=all_git_generated_captions, references=all_captions)\n",
    "meteor_brain_ref = meteor.compute(predictions=all_predcaptions, references=all_captions)\n",
    "meteor_brain_img = meteor.compute(predictions=all_predcaptions, references=all_git_generated_captions)\n",
    "\n",
    "# brain-vs-image score as a fraction of the image-vs-reference score\n",
    "relative_brain_image_meteor = meteor_brain_img[\"meteor\"] / meteor_img_ref[\"meteor\"]\n",
    "\n",
    "for tag, label, result in [(\"GROUND\", \"images vs captions\", meteor_img_ref),\n",
    "                           (\"ABSOLUTE\", \"brain vs captions\", meteor_brain_ref),\n",
    "                           (\"ABSOLUTE\", \"brain vs images\", meteor_brain_img)]:\n",
    "    print(f\"[{tag}] METEOR GIT from {label}: {result['meteor']}\")\n",
    "print(f\"[RELATIVE] METEOR  {relative_brain_image_meteor}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6186a6ef-39e4-44db-8bbc-54d5448c782c",
   "metadata": {},
   "source": [
    "## Rouge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30ff56b5-597c-4341-80d6-0005a9de95f2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rouge = evaluate.load('rouge')\n",
    "rouge_img_ref = rouge.compute(predictions=all_git_generated_captions, references=all_captions)\n",
    "rouge_brain_ref = rouge.compute(predictions=all_predcaptions, references=all_captions)\n",
    "rouge_brain_img = rouge.compute(predictions=all_predcaptions, references=all_git_generated_captions)\n",
    "\n",
    "# brain-vs-image score relative to the image-vs-reference score, per ROUGE variant\n",
    "relative_brain_image_rouge1 = rouge_brain_img['rouge1'] / rouge_img_ref['rouge1']\n",
    "relative_brain_image_rougeL = rouge_brain_img['rougeL'] / rouge_img_ref['rougeL']\n",
    "\n",
    "for key, name, relative in [('rouge1', 'ROUGE-1', relative_brain_image_rouge1),\n",
    "                            ('rougeL', 'ROUGE-L', relative_brain_image_rougeL)]:\n",
    "    print(f\"[GROUND] {name} GIT from images vs captions: {rouge_img_ref[key]}\")\n",
    "    print(f\"[ABSOLUTE] {name} GIT from brain vs captions: {rouge_brain_ref[key]}\")\n",
    "    print(f\"[ABSOLUTE] {name} GIT from brain vs images: {rouge_brain_img[key]}\")\n",
    "    print(f\"[RELATIVE] {name}  {relative}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416035d6-3975-43e2-959b-a8ead420e288",
   "metadata": {},
   "source": [
    "## Sentence transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98959bd0-d365-4cc7-8872-f4b134e9d924",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')\n",
    "\n",
    "with torch.no_grad():\n",
    "    # embed the three caption sets: predicted from brain, ground truth, GIT-from-image\n",
    "    embedding_brain, embedding_captions, embedding_images = (\n",
    "        sentence_model.encode(texts, convert_to_tensor=True)\n",
    "        for texts in (all_predcaptions, all_captions, all_git_generated_captions)\n",
    "    )\n",
    "\n",
    "    # pairwise cosine-similarity matrices; the diagonal pairs caption i with caption i\n",
    "    ss_sim_brain_img = util.pytorch_cos_sim(embedding_brain, embedding_images).cpu()\n",
    "    ss_sim_brain_cap = util.pytorch_cos_sim(embedding_brain, embedding_captions).cpu()\n",
    "    ss_sim_img_cap = util.pytorch_cos_sim(embedding_images, embedding_captions).cpu()\n",
    "\n",
    "    relative_brain_image_ss = ss_sim_brain_img.diag().mean() / ss_sim_img_cap.diag().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5388e49f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# mean of each matrix's diagonal = average similarity of matching caption pairs\n",
    "for label, sims in [(\"[GROUND] Sentence Transformer Similarity GIT from images vs captions\", ss_sim_img_cap),\n",
    "                    (\"[ABSOLUTE] Sentence Transformer Similarity GIT from brain vs captions\", ss_sim_brain_cap),\n",
    "                    (\"[ABSOLUTE] Sentence Transformer Similarity GIT from brain vs images\", ss_sim_brain_img)]:\n",
    "    print(f\"{label}: {sims.diag().mean()}\")\n",
    "print(f\"[RELATIVE] Sentence Transformer Similarity   {relative_brain_image_ss.mean()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8b936fb-2a76-4ddd-9d5b-0a5b5d4e8542",
   "metadata": {},
   "source": [
    "## CLIP (caption text embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ebe5827",
   "metadata": {},
   "source": [
    "#### CLIP-B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a7a624-b753-4605-a722-ad2963496723",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_clip = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "processor_clip = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "tokenizer =  AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "\n",
    "# Embed the three caption sets with the CLIP-B text tower.\n",
    "# list(): HF tokenizers accept a str or a list of str, not a numpy array (the git captions were the\n",
    "# only input passed unconverted); truncation=True: CLIP's text encoder is limited to 77 tokens.\n",
    "with torch.no_grad():\n",
    "    input_ids=tokenizer(list(all_predcaptions),return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_brain= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "    input_ids=tokenizer(list(all_captions),return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_captions= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "    input_ids=tokenizer(list(all_git_generated_captions),return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_images= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "# cosine-similarity matrices; the diagonal pairs caption i of one set with caption i of the other\n",
    "clip_B_sim_brain_img=util.pytorch_cos_sim(embedding_brain, embedding_images).cpu()\n",
    "clip_B_sim_brain_cap=util.pytorch_cos_sim(embedding_brain, embedding_captions).cpu()\n",
    "clip_B_sim_img_cap=util.pytorch_cos_sim(embedding_images, embedding_captions).cpu()\n",
    "\n",
    "relative_brain_image_clip_B=clip_B_sim_brain_img.diag().mean()/clip_B_sim_img_cap.diag().mean()\n",
    "\n",
    "print(f\"[GROUND] CLIP Similarity GIT from images vs captions: {clip_B_sim_img_cap.diag().mean()}\")\n",
    "print(f\"[ABSOLUTE] CLIP Similarity GIT from brain vs captions: {clip_B_sim_brain_cap.diag().mean()}\")\n",
    "print(f\"[ABSOLUTE] CLIP Similarity GIT from brain vs images: {clip_B_sim_brain_img.diag().mean()}\")\n",
    "print(f\"[RELATIVE] CLIP Similarity   {relative_brain_image_clip_B.mean()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a04c1d5",
   "metadata": {},
   "source": [
    "#### CLIP-L"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c219bd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_clip = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
    "processor_clip = AutoProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
    "tokenizer =  AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
    "\n",
    "# Embed the three caption sets with the CLIP-L text tower.\n",
    "# list(): HF tokenizers accept a str or a list of str, not a numpy array (the git captions were the\n",
    "# only input passed unconverted); truncation=True: CLIP's text encoder is limited to 77 tokens.\n",
    "with torch.no_grad():\n",
    "    input_ids=tokenizer(list(all_predcaptions),return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_brain= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "    input_ids=tokenizer(list(all_captions),return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_captions= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "    input_ids=tokenizer(list(all_git_generated_captions),return_tensors=\"pt\",padding=True,truncation=True)\n",
    "    embedding_images= model_clip.get_text_features(**input_ids)\n",
    "\n",
    "# cosine-similarity matrices; the diagonal pairs caption i of one set with caption i of the other\n",
    "clip_L_sim_brain_img=util.pytorch_cos_sim(embedding_brain, embedding_images).cpu()\n",
    "clip_L_sim_brain_cap=util.pytorch_cos_sim(embedding_brain, embedding_captions).cpu()\n",
    "clip_L_sim_img_cap=util.pytorch_cos_sim(embedding_images, embedding_captions).cpu()\n",
    "\n",
    "relative_brain_image_clip_L=clip_L_sim_brain_img.diag().mean()/clip_L_sim_img_cap.diag().mean()\n",
    "\n",
    "print(f\"[GROUND] CLIP Similarity GIT from images vs captions: {clip_L_sim_img_cap.diag().mean()}\")\n",
    "print(f\"[ABSOLUTE] CLIP Similarity GIT from brain vs captions: {clip_L_sim_brain_cap.diag().mean()}\")\n",
    "print(f\"[ABSOLUTE] CLIP Similarity GIT from brain vs images: {clip_L_sim_brain_img.diag().mean()}\")\n",
    "print(f\"[RELATIVE] CLIP Similarity   {relative_brain_image_clip_L.mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bbfd73f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Gather all caption metrics in one table. float() accepts Python floats, NumPy scalars and 0-d\n",
    "# torch tensors alike; it replaces `.mean()`/`.item()` calls on values that are already scalars\n",
    "# (`.mean()` raises AttributeError if a ROUGE ratio comes back as a plain Python float).\n",
    "def diag_mean(sim):\n",
    "    \"\"\"Mean of a similarity matrix's diagonal (matching caption pairs) as a Python float.\"\"\"\n",
    "    return float(sim.diag().mean())\n",
    "\n",
    "caption_metrics = {\n",
    "    \"Rouge1_img_ref\": rouge_img_ref['rouge1'],\n",
    "    \"Rouge1_brain_ref\": rouge_brain_ref['rouge1'],\n",
    "    \"Rouge1_brain_img\": rouge_brain_img['rouge1'],\n",
    "    \"Rouge1_relative\": float(relative_brain_image_rouge1),\n",
    "    \"RougeL_img_ref\": rouge_img_ref['rougeL'],\n",
    "    \"RougeL_brain_ref\": rouge_brain_ref['rougeL'],\n",
    "    \"RougeL_brain_img\": rouge_brain_img['rougeL'],\n",
    "    \"RougeL_relative\": float(relative_brain_image_rougeL),\n",
    "    \"Meteor_img_ref\": meteor_img_ref['meteor'],\n",
    "    \"Meteor_brain_ref\": meteor_brain_ref['meteor'],\n",
    "    \"Meteor_brain_img\": meteor_brain_img['meteor'],\n",
    "    \"Meteor_relative\": relative_brain_image_meteor,\n",
    "    \"Sentence_img_ref\": diag_mean(ss_sim_img_cap),\n",
    "    \"Sentence_brain_ref\": diag_mean(ss_sim_brain_cap),\n",
    "    \"Sentence_brain_img\": diag_mean(ss_sim_brain_img),\n",
    "    \"Sentence_relative\": float(relative_brain_image_ss),\n",
    "    \"CLIP-B_img_ref\": diag_mean(clip_B_sim_img_cap),\n",
    "    \"CLIP-B_brain_ref\": diag_mean(clip_B_sim_brain_cap),\n",
    "    \"CLIP-B_brain_img\": diag_mean(clip_B_sim_brain_img),\n",
    "    \"CLIP-B_relative\": float(relative_brain_image_clip_B),\n",
    "    \"CLIP-L_img_ref\": diag_mean(clip_L_sim_img_cap),\n",
    "    \"CLIP-L_brain_ref\": diag_mean(clip_L_sim_brain_cap),\n",
    "    \"CLIP-L_brain_img\": diag_mean(clip_L_sim_brain_img),\n",
    "    \"CLIP-L_relative\": float(relative_brain_image_clip_L),\n",
    "}\n",
    "\n",
    "# index=False writes only the Value column (metric names live in the index), like the image table\n",
    "os.makedirs('tables/',exist_ok=True)\n",
    "df=pd.DataFrame.from_dict(caption_metrics,orient='index',columns=[\"Value\"])\n",
    "df.to_csv(f'tables/{model_name_plus_suffix}_caption_metrics.csv', sep='\\t', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
