{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160bfa06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from torchvision.models.feature_extraction import create_feature_extractor\n",
    "from skimage.metrics import structural_similarity as ssim\n",
    "from skimage.color import rgb2gray\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import scipy.spatial as sp\n",
    "import clip\n",
    "from PIL import Image\n",
    "from scipy.spatial.distance import correlation\n",
    "from torchmetrics.functional import accuracy\n",
    "from torchvision.models import ViT_H_14_Weights, vit_h_14\n",
    "warnings.filterwarnings('ignore')\n",
    "# Preferred GPU index. Fall back to GPU 0 when the machine exposes fewer GPUs,\n",
    "# and to the CPU when CUDA is not available at all.\n",
    "GPU_INDEX = 4\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f\"cuda:{GPU_INDEX}\" if torch.cuda.device_count() > GPU_INDEX else \"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ff440c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True, device='cuda:4', batch_size=None):\n",
    "    \"\"\"Two-way identification score between reconstructions and ground-truth images.\n",
    "\n",
    "    Features are extracted with ``model`` for both sets and correlated. For each\n",
    "    reconstruction j, the function counts how many ground-truth images i have a\n",
    "    correlation ``r[i, j]`` lower than the matched pair's correlation ``r[j, j]``.\n",
    "\n",
    "    Args:\n",
    "        all_brain_recons: sequence of reconstructed images (one tensor per image).\n",
    "        all_images: sequence of ground-truth images, same length and order.\n",
    "        model: callable mapping an image batch to features, or to a dict of\n",
    "            features when ``feature_layer`` is given.\n",
    "        preprocess: callable applied to every single image before stacking.\n",
    "        feature_layer: key selecting one entry of the model output; ``None`` uses\n",
    "            the model output directly.\n",
    "        return_avg: if True return the mean success count divided by ``n - 1``;\n",
    "            otherwise return ``(success_cnt, n - 1)``.\n",
    "        device: device the stacked batches are moved to before the forward pass.\n",
    "        batch_size: optional chunk size for the forward passes; ``None`` (default)\n",
    "            runs each image set in a single batch, as before.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the inputs are empty or differ in length.\n",
    "    \"\"\"\n",
    "    n_images = len(all_images)\n",
    "    if n_images == 0 or len(all_brain_recons) != n_images:\n",
    "        raise ValueError(\n",
    "            f\"expected the same non-zero number of reconstructions and images, got {len(all_brain_recons)} and {n_images}\"\n",
    "        )\n",
    "\n",
    "    def _extract_features(images):\n",
    "        # Preprocess image by image, run the model chunk by chunk, gather features on the CPU.\n",
    "        step = len(images) if batch_size is None else max(1, int(batch_size))\n",
    "        chunks = []\n",
    "        for start in range(0, len(images), step):\n",
    "            batch = torch.stack([preprocess(img) for img in images[start:start + step]], dim=0).to(device)\n",
    "            out = model(batch)\n",
    "            if feature_layer is not None:\n",
    "                out = out[feature_layer]\n",
    "            chunks.append(out.float().flatten(1).cpu())\n",
    "        return torch.cat(chunks, dim=0).numpy()\n",
    "\n",
    "    preds = _extract_features(all_brain_recons)\n",
    "    reals = _extract_features(all_images)\n",
    "\n",
    "    # corrcoef stacks reals above preds; keep only the (real, recon) cross-correlation block.\n",
    "    r = np.corrcoef(reals, preds)\n",
    "    r = r[:n_images, n_images:]\n",
    "    congruents = np.diag(r)\n",
    "\n",
    "    # Broadcasting compares every column j against the matched correlation r[j, j].\n",
    "    success = r < congruents\n",
    "    success_cnt = np.sum(success, 0)\n",
    "\n",
    "    if return_avg:\n",
    "        return np.mean(success_cnt) / (n_images - 1)\n",
    "    return success_cnt, n_images - 1\n",
    "\n",
    "def cal_metrics(all_images, all_brain_recons, device):\n",
    "    \"\"\"Compute pixel-level and feature-level metrics between ground-truth images and reconstructions.\n",
    "\n",
    "    Args:\n",
    "        all_images: ground-truth image batch, indexed as (N, C, H, W) below\n",
    "            (``size(2)``/``size(3)`` give the spatial size). Presumably scaled to\n",
    "            [0, 1] like the clamped reconstructions -- TODO confirm.\n",
    "        all_brain_recons: sequence of reconstructions; stacked, cast to the dtype of\n",
    "            ``all_images``, clamped to [0, 1] and resized to the ground-truth size.\n",
    "        device: torch device (or device string) used for all tensors and models.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each metric name ('PixCorr', 'SSIM', 'MSE', 'Cosine Similarity',\n",
    "        'AlexNet(2)', 'AlexNet(5)', 'InceptionV3', 'CLIP', 'EffNet Distance',\n",
    "        'SwAV Distance') to a one-element list holding its value.\n",
    "    \"\"\"\n",
    "    all_images = all_images[:].to(device)\n",
    "    all_brain_recons = torch.stack([img for img in all_brain_recons[:]]).to(device).to(all_images.dtype).clamp(0,1)\n",
    "    single_recon = all_brain_recons.shape[0] == 1\n",
    "    all_brain_recons = all_brain_recons.squeeze()\n",
    "    if single_recon and all_brain_recons.dim() < all_images.dim():\n",
    "        # squeeze() also removed the batch axis of a one-image batch; put it back.\n",
    "        all_brain_recons = all_brain_recons.unsqueeze(0)\n",
    "\n",
    "    print(\"Images shape:\", all_images.shape)\n",
    "    print(\"Recons shape:\", all_brain_recons.shape)\n",
    "\n",
    "    # Ensure both tensors are the same size for MSE\n",
    "    resize = transforms.Resize((all_images.size(2), all_images.size(3)), interpolation=transforms.InterpolationMode.BILINEAR)\n",
    "    all_brain_recons = resize(all_brain_recons)\n",
    "\n",
    "    print(\"Images shape after resize:\", all_images.shape)\n",
    "    print(\"Recons shape after resize:\", all_brain_recons.shape)\n",
    "\n",
    "    # Preprocess\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    ])\n",
    "\n",
    "    # Flatten images while keeping the batch dimension (CPU copies for the NumPy-based metrics)\n",
    "    all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()\n",
    "    all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()\n",
    "\n",
    "    print(all_images_flattened.shape)\n",
    "    print(all_brain_recons_flattened.shape)\n",
    "\n",
    "    # PixCorr: mean per-image Pearson correlation of the flattened pixels\n",
    "    print(\"\\n------calculating pixcorr------\")\n",
    "    corrsum = 0\n",
    "    for i in tqdm(range(len(all_images))):\n",
    "        corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]\n",
    "    pixcorr = corrsum / len(all_images)\n",
    "    print(\"PixCorr:\", pixcorr)\n",
    "\n",
    "    # SSIM\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(625, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    ])\n",
    "\n",
    "    img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu().numpy())\n",
    "    recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu().numpy())\n",
    "    print(\"converted, now calculating ssim...\")\n",
    "\n",
    "    # NOTE(review): the inputs are 2-D grayscale images and `multichannel` is deprecated in\n",
    "    # recent scikit-image in favour of `channel_axis`; the call is kept unchanged so scores\n",
    "    # stay comparable with earlier runs -- confirm behaviour against the installed version.\n",
    "    ssim_score=[]\n",
    "    for im, rec in tqdm(zip(img_gray, recon_gray), total=len(all_images)):\n",
    "        ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))\n",
    "\n",
    "    ssim_mean = np.mean(ssim_score)\n",
    "    print(\"SSIM:\", ssim_mean)\n",
    "\n",
    "    # MSE\n",
    "    mse = torch.nn.functional.mse_loss(all_brain_recons, all_images).item()\n",
    "    print(\"MSE:\", mse)\n",
    "\n",
    "    # Cosine Similarity\n",
    "    cosine_sim = torch.nn.functional.cosine_similarity(all_brain_recons_flattened, all_images_flattened).mean().item()\n",
    "    print(\"Cosine Similarity:\", cosine_sim)\n",
    "\n",
    "    # Feature-based evaluations using different models\n",
    "    def evaluate_model(model, preprocess, feature_layers, layer_names):\n",
    "        \"\"\"Run two-way identification once per feature layer and collect the scores by layer name.\"\"\"\n",
    "        results = {}\n",
    "        for feature_layer, layer_name in zip(feature_layers, layer_names):\n",
    "            print(f\"\\n---{layer_name}---\")\n",
    "            all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images, \n",
    "                                                     model, preprocess, feature_layer, device=device)\n",
    "            results[layer_name] = np.mean(all_per_correct)\n",
    "            print(f\"2-way Percent Correct: {results[layer_name]:.4f}\")\n",
    "        return results\n",
    "\n",
    "    def feature_distance(extractor, preprocess, node='avgpool'):\n",
    "        \"\"\"Mean correlation distance between ground-truth and reconstruction features at `node`.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            gt = extractor(preprocess(all_images))[node]\n",
    "            fake = extractor(preprocess(all_brain_recons))[node]\n",
    "        gt = gt.reshape(len(gt), -1).cpu().numpy()\n",
    "        fake = fake.reshape(len(fake), -1).cpu().numpy()\n",
    "        return np.array([correlation(gt[i], fake[i]) for i in range(len(gt))]).mean()\n",
    "\n",
    "    # AlexNet\n",
    "    from torchvision.models import alexnet, AlexNet_Weights\n",
    "    alex_weights = AlexNet_Weights.IMAGENET1K_V1\n",
    "    alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4', 'features.11']).to(device)\n",
    "    alex_model.eval().requires_grad_(False)\n",
    "\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "    alexnet_results = evaluate_model(alex_model, preprocess, ['features.4', 'features.11'], ['AlexNet(2)', 'AlexNet(5)'])\n",
    "    del alex_model\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # InceptionV3\n",
    "    from torchvision.models import inception_v3, Inception_V3_Weights\n",
    "    inception_weights = Inception_V3_Weights.DEFAULT\n",
    "    inception_model = create_feature_extractor(inception_v3(weights=inception_weights), return_nodes=['avgpool']).to(device)\n",
    "    inception_model.eval().requires_grad_(False)\n",
    "\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "    inception_results = evaluate_model(inception_model, preprocess, ['avgpool'], ['InceptionV3'])\n",
    "    del inception_model\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # CLIP (the preprocessing returned by clip.load is not used; tensors are normalised manually)\n",
    "    clip_model, _ = clip.load(\"ViT-L/14\", device=device)\n",
    "\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
    "                            std=[0.26862954, 0.26130258, 0.27577711]),\n",
    "    ])\n",
    "\n",
    "    # Pass `device` explicitly: the callee's default is a fixed GPU id that may differ from ours.\n",
    "    all_per_correct = two_way_identification(all_brain_recons, all_images,\n",
    "                                            clip_model.encode_image, preprocess, None, device=device) # final layer\n",
    "    clip_results = np.mean(all_per_correct)\n",
    "    print(\"CLIP:\", clip_results)\n",
    "    del clip_model\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # EfficientNet\n",
    "    from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights\n",
    "    eff_weights = EfficientNet_B1_Weights.DEFAULT\n",
    "    eff_model = create_feature_extractor(efficientnet_b1(weights=eff_weights), return_nodes=['avgpool']).to(device)\n",
    "    eff_model.eval().requires_grad_(False)\n",
    "\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "    effnet_distance = feature_distance(eff_model, preprocess)\n",
    "    print(\"EffNet Distance:\", effnet_distance)\n",
    "    del eff_model\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # SwAV\n",
    "    swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')\n",
    "    swav_model = create_feature_extractor(swav_model, return_nodes=['avgpool']).to(device)\n",
    "    swav_model.eval().requires_grad_(False)\n",
    "\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "    swav_distance = feature_distance(swav_model, preprocess)\n",
    "    print(\"SwAV Distance:\", swav_distance)\n",
    "    del swav_model\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # Collect the results (one-element lists, so the dict maps directly onto a one-row table)\n",
    "    metrics = {\n",
    "        'PixCorr': [pixcorr],\n",
    "        'SSIM': [ssim_mean],\n",
    "        'MSE': [mse],\n",
    "        'Cosine Similarity': [cosine_sim],\n",
    "        'AlexNet(2)': [alexnet_results[\"AlexNet(2)\"]],\n",
    "        'AlexNet(5)': [alexnet_results[\"AlexNet(5)\"]],\n",
    "        'InceptionV3': [inception_results[\"InceptionV3\"]],\n",
    "        'CLIP': [clip_results],\n",
    "        'EffNet Distance': [effnet_distance],\n",
    "        'SwAV Distance': [swav_distance]\n",
    "    }\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def calculate_metrics(all_images, all_brain_recons, device):\n",
    "    \"\"\"Thin entry point: forward the arguments to ``cal_metrics`` and return its metrics dict.\"\"\"\n",
    "    return cal_metrics(all_images, all_brain_recons, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dece16a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([982, 3, 425, 425]), torch.Size([982, 3, 512, 512]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fill with paths to test stimuli and reconstructed stimuli\n",
    "test_images_path = 'TODO'\n",
    "reconstructed_images_path = 'TODO'\n",
    "NSD_test_stimulus = torch.load(test_images_path)\n",
    "recontructed_images = torch.load(reconstructed_images_path)\n",
    "# A bare expression in the middle of a cell is never displayed, so print the shapes explicitly.\n",
    "print(NSD_test_stimulus.shape, recontructed_images.shape)\n",
    "\n",
    "# Reuse the notebook-wide `device` from the setup cell instead of a second hard-coded GPU id.\n",
    "metrics = calculate_metrics(NSD_test_stimulus, recontructed_images, device)\n",
    "# Every metric is a one-element list, so the dict renders as a one-row table.\n",
    "pd.DataFrame(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1144365",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring both image sets to the same resolution (shorter side = 425 px, bilinear) before classification.\n",
    "preprocess = transforms.Compose(\n",
    "    [transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR)]\n",
    ")\n",
    "\n",
    "NSD_test_stimulus, NSD_recontructed_images = (\n",
    "    preprocess(NSD_test_stimulus),\n",
    "    preprocess(recontructed_images),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "280f3a68",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing images: 100%|██████████| 982/982 [07:25<00:00,  2.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall mean acc: 0.4167, Overall mean std: 0.2342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Load ViT model and preprocessing\n",
    "# `device` is defined once in the setup cell at the top of the notebook; it is reused here\n",
    "# rather than redefined so the two definitions cannot drift apart.\n",
    "weights = ViT_H_14_Weights.DEFAULT\n",
    "vit_model = vit_h_14(weights=weights).to(device)\n",
    "# Inference only: switch to eval mode and freeze the parameters.\n",
    "vit_model.eval().requires_grad_(False)\n",
    "preprocess = weights.transforms()\n",
    "\n",
    "def n_way_top_k_acc(pred, class_id, n_way, num_trials=40, top_k=1, rng=None):\n",
    "    \"\"\"Estimate n-way top-k accuracy of one prediction vector by random sub-sampling.\n",
    "\n",
    "    Each trial draws ``n_way - 1`` distinct classes other than ``class_id``, puts the\n",
    "    score of ``class_id`` at position 0 and checks whether position 0 is within the\n",
    "    top ``top_k`` scores of that ``n_way``-sized subset.\n",
    "\n",
    "    Args:\n",
    "        pred: 1-D tensor of per-class scores.\n",
    "        class_id: index of the target class in ``pred``.\n",
    "        n_way: number of classes compared per trial (target included).\n",
    "        num_trials: number of random trials.\n",
    "        top_k: the target counts as correct when ranked within the top ``top_k``.\n",
    "        rng: optional object with a NumPy-style ``choice`` method (e.g. the result of\n",
    "            ``np.random.default_rng(seed)``); ``None`` uses the global ``np.random`` state.\n",
    "\n",
    "    Returns:\n",
    "        (mean, std) of the per-trial accuracies.\n",
    "    \"\"\"\n",
    "    sampler = np.random if rng is None else rng\n",
    "    # Both the candidate pool and the target tensor are loop-invariant: build them once.\n",
    "    candidates = np.arange(len(pred))\n",
    "    candidates = candidates[candidates != class_id]\n",
    "    target = torch.tensor([0], device=pred.device)\n",
    "    acc_list = []\n",
    "    for _ in range(num_trials):\n",
    "        idxs_picked = sampler.choice(candidates, n_way - 1, replace=False)\n",
    "        pred_picked = torch.cat([pred[class_id].unsqueeze(0), pred[idxs_picked]])\n",
    "        acc = accuracy(pred_picked.unsqueeze(0), target,\n",
    "                       task='multiclass', num_classes=n_way, top_k=top_k)\n",
    "        acc_list.append(acc.item())\n",
    "    return np.mean(acc_list), np.std(acc_list)\n",
    "\n",
    "# Per-image results (one entry per reconstructed test image)\n",
    "all_acc = []\n",
    "all_std = []\n",
    "\n",
    "# Seed the global NumPy RNG used for the random n-way sampling so the reported numbers are reproducible.\n",
    "np.random.seed(42)\n",
    "\n",
    "for i in tqdm(range(len(NSD_recontructed_images)), desc=\"Processing images\"):\n",
    "    # Preprocess images\n",
    "    image = preprocess(NSD_test_stimulus[i].unsqueeze(0)).to(device)\n",
    "    recon_image = preprocess(NSD_recontructed_images[i].unsqueeze(0)).to(device)\n",
    "\n",
    "    # Get model outputs; inference only, so skip autograd bookkeeping for the forward passes\n",
    "    with torch.no_grad():\n",
    "        recon_image_out = vit_model(recon_image).squeeze(0).softmax(0).detach()\n",
    "        gt_class_id = vit_model(image).squeeze(0).softmax(0).argmax().item()\n",
    "\n",
    "    # Compute accuracy: 50-way, 1000 trials, top-1\n",
    "    acc, std = n_way_top_k_acc(recon_image_out, gt_class_id, 50, 1000, 1)\n",
    "    all_acc.append(acc)\n",
    "    all_std.append(std)\n",
    "\n",
    "print(\"Overall mean acc: {:.4f}, Overall mean std: {:.4f}\".format(np.mean(all_acc), np.mean(all_std)))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "intdb",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
