{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e31ce66-0b2a-4d87-9e35-c042422e0d73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import random\n",
    "import json\n",
    "import gc\n",
    "from typing import Tuple, Optional, Dict\n",
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import tqdm\n",
    "import nibabel as nib\n",
    "from einops import rearrange\n",
    "from scipy import ndimage\n",
    "import wandb\n",
    "\n",
    "dir2 = os.path.abspath('../..')\n",
    "dir1 = os.path.dirname(dir2)\n",
    "if not dir1 in sys.path: \n",
    "    sys.path.append(dir1)\n",
    "\n",
    "from research.data.natural_scenes import (\n",
    "    NaturalScenesDataset,\n",
    "    StimulusDataset,\n",
    "    KeyDataset\n",
    ")\n",
    "from research.models.fmri_decoders import Decoder\n",
    "\n",
    "from research.experiments.nsd.nsd_experiment import NSDExperiment\n",
    "from research.metrics.metrics import (\n",
    "    cosine_similarity, \n",
    "    r2_score,\n",
    "    pearsonr,\n",
    "    embedding_distance,\n",
    "    cosine_distance,\n",
    "    squared_euclidean_distance,\n",
    "    contrastive_score,\n",
    "    two_versus_two,\n",
    "    smooth_euclidean_distance,\n",
    "    top_knn_test\n",
    ")\n",
    "from pipeline.utils import product, index_unsorted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9dae6e06-bab9-4a2d-896f-2ef049fd3b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "nsd_path = Path('D:\\\\Datasets\\\\NSD\\\\')\n",
    "nsd = NaturalScenesDataset(nsd_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a7cfa449-7c5b-4ebc-a5a2-2186f3cc2c45",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from research.experiments.nsd.nsd_run_decoding import main\n",
    "\n",
    "model_name = 'ViT-B=32'\n",
    "stimulus_key = 'embedding'\n",
    "voxel_selection = 'nc'\n",
    "threshold = 8\n",
    "num_voxels = None\n",
    "\n",
    "voxel_selection_path = 'derivatives/noise-ceiling.hdf5'\n",
    "if threshold is None:\n",
    "    voxel_selection_key = f'split-01/sorted_indices'\n",
    "else:\n",
    "    voxel_selection_key = f'split-01/value'\n",
    "\n",
    "loss = 'contrastive'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a78c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a single decoder for each participant\n",
    "\n",
    "group_name = 'group-22'\n",
    "\n",
    "seed = 0\n",
    "for subject_id in range(8):\n",
    "    subject_name = f'subj0{subject_id + 1}'\n",
    "    experiment = main(\n",
    "        nsd_path,\n",
    "        subject_name,\n",
    "        model_name,\n",
    "        stimulus_key,\n",
    "        voxel_selection_path=voxel_selection_path,\n",
    "        voxel_selection_key=voxel_selection_key,\n",
    "        group=group_name,\n",
    "        permutation_test=False,\n",
    "        max_iterations=5001,\n",
    "        loss=loss,\n",
    "        batch_size=128,\n",
    "        num_voxels=num_voxels,\n",
    "        threshold=threshold,\n",
    "        temperature=0.03,\n",
    "        seed=seed,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a071bce-c246-4346-bdee-796e4cac4aee",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train 50 decoders for each participant\n",
    "\n",
    "\n",
    "group_name = 'group-22_reruns'\n",
    "\n",
    "for run_id in range(50):\n",
    "    seed = run_id + 1\n",
    "    for subject_id in range(8):\n",
    "        subject_name = f'subj0{subject_id + 1}'\n",
    "        experiment = main(\n",
    "            nsd_path,\n",
    "            subject_name,\n",
    "            model_name,\n",
    "            stimulus_key,\n",
    "            voxel_selection_path=voxel_selection_path,\n",
    "            voxel_selection_key=voxel_selection_key,\n",
    "            group=group_name,\n",
    "            permutation_test=False,\n",
    "            max_iterations=5001,\n",
    "            loss=loss,\n",
    "            batch_size=128,\n",
    "            num_voxels=num_voxels,\n",
    "            threshold=threshold,\n",
    "            temperature=0.03,\n",
    "            seed=seed,\n",
    "            result_key=f'run_{run_id}',\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ef15d3-65a4-4586-9527-af40496cd015",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "group_name = 'threshold'\n",
    "temperature = 0.03\n",
    "\n",
    "results = {}\n",
    "for subject_id in range(8):\n",
    "    subject_name = f'subj0{subject_id + 1}'\n",
    "    results[subject_name] = {}\n",
    "    for threshold in (8, 9, 10):\n",
    "        experiment = main(\n",
    "            nsd_path,\n",
    "            subject_name,\n",
    "            model_name,\n",
    "            stimulus_key,\n",
    "            voxel_selection_path=voxel_selection_path,\n",
    "            voxel_selection_key=voxel_selection_key,\n",
    "            group=group_name,\n",
    "            permutation_test=False,\n",
    "            #permutation_fraction=permutation_fraction,\n",
    "            max_iterations=1501,\n",
    "            #loss='contrastive',\n",
    "            loss=loss,\n",
    "            batch_size=128,\n",
    "            num_voxels=None,\n",
    "            threshold=threshold,\n",
    "            temperature=temperature,\n",
    "            result_key='wandb_run_name',\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "00202c35-150b-43af-95e3-6a0c2ffa896e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subj01\n",
      "X_train.shape=(21750, 17883)\n",
      "1000 [0.07133333333333333]\n",
      "10000 [0.086]\n",
      "subj02\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[9], line 100\u001b[0m\n\u001b[0;32m     98\u001b[0m results \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m     99\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m subject_id \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m8\u001b[39m):\n\u001b[1;32m--> 100\u001b[0m     \u001b[43mfit_ridge\u001b[49m\u001b[43m(\u001b[49m\u001b[43msubject_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresults\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[1;32mIn[9], line 39\u001b[0m, in \u001b[0;36mfit_ridge\u001b[1;34m(subject_id, f, results)\u001b[0m\n\u001b[0;32m     28\u001b[0m train_mask, val_mask, test_mask \u001b[38;5;241m=\u001b[39m nsd\u001b[38;5;241m.\u001b[39mget_split(subject_name, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msplit-01\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m     30\u001b[0m betas_params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\n\u001b[0;32m     31\u001b[0m     subject_name\u001b[38;5;241m=\u001b[39msubject_name,\n\u001b[0;32m     32\u001b[0m     voxel_selection_path\u001b[38;5;241m=\u001b[39mvoxel_selection_path,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     37\u001b[0m     return_tensor_dataset\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m     38\u001b[0m )\n\u001b[1;32m---> 39\u001b[0m betas, betas_indices \u001b[38;5;241m=\u001b[39m \u001b[43mnsd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_betas\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbetas_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     40\u001b[0m X_train, X_val, X_test \u001b[38;5;241m=\u001b[39m betas[train_mask], betas[val_mask], betas[test_mask]\n\u001b[0;32m     41\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mX_train\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\research\\data\\natural_scenes.py:155\u001b[0m, in \u001b[0;36mNaturalScenesDataset.load_betas\u001b[1;34m(self, subject_name, betas_indices, voxel_selection_path, voxel_selection_key, num_voxels, threshold, return_volume_indices, return_tensor_dataset, session_normalize, scale_betas)\u001b[0m\n\u001b[0;32m    151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(betas_indices\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m    152\u001b[0m     betas_indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mflatten_indices(subject_name, betas_indices)\n\u001b[0;32m    154\u001b[0m betas \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mstack([\n\u001b[1;32m--> 155\u001b[0m     \u001b[43msubject_betas\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mbetas\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\n\u001b[0;32m    156\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m betas_indices\n\u001b[0;32m    157\u001b[0m ], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m    158\u001b[0m betas \u001b[38;5;241m=\u001b[39m betas\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m    159\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m scale_betas:\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\h5py\\_hl\\dataset.py:758\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[1;34m(self, args, new_dtype)\u001b[0m\n\u001b[0;32m    756\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fast_read_ok \u001b[38;5;129;01mand\u001b[39;00m (new_dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m    757\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 758\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fast_reader\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    759\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m    760\u001b[0m         \u001b[38;5;28;01mpass\u001b[39;00m  \u001b[38;5;66;03m# Fall back to Python read pathway below\u001b[39;00m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train ridge regression baseline\n",
    "\n",
    "from sklearn.linear_model import RidgeCV, Ridge\n",
    "from research.models.fmri_decoders import Decoder\n",
    "\n",
    "model_name = 'ViT-B=32'\n",
    "stimulus_key = 'embedding'\n",
    "\n",
    "def fit_ridge(subject_id, f, results):\n",
    "    gc.collect()\n",
    "    subject_name = f'subj0{subject_id+1}'\n",
    "    print(subject_name)\n",
    "    voxel_selection = 'nc'\n",
    "    threshold = 8.\n",
    "    num_voxels = None\n",
    "\n",
    "    if voxel_selection == 'nc':\n",
    "        voxel_selection_path = 'derivatives/noise-ceiling.hdf5'\n",
    "        if threshold is None:\n",
    "            voxel_selection_key = f'split-01/sorted_indices'\n",
    "        else:\n",
    "            voxel_selection_key = f'split-01/value'\n",
    "    elif voxel_selection == 'fracridge':\n",
    "        voxel_selection_path = f'derivatives/encoded_betas/{model_name}/fracridge.hdf5'\n",
    "        if threshold is None:\n",
    "            voxel_selection_key = f'{stimulus_key}/volume_indices'\n",
    "        else:\n",
    "            voxel_selection_key = f'{stimulus_key}/value'\n",
    "\n",
    "    train_mask, val_mask, test_mask = nsd.get_split(subject_name, 'split-01')\n",
    "\n",
    "    betas_params = dict(\n",
    "        subject_name=subject_name,\n",
    "        voxel_selection_path=voxel_selection_path,\n",
    "        voxel_selection_key=voxel_selection_key,\n",
    "        num_voxels=num_voxels,\n",
    "        threshold=threshold,\n",
    "        return_volume_indices=True,\n",
    "        return_tensor_dataset=False,\n",
    "    )\n",
    "    betas, betas_indices = nsd.load_betas(**betas_params)\n",
    "    X_train, X_val, X_test = betas[train_mask], betas[val_mask], betas[test_mask]\n",
    "    print(f'{X_train.shape=}')\n",
    "\n",
    "    stimulus_params = dict(\n",
    "        subject_name=subject_name,\n",
    "        stimulus_path=f'derivatives/stimulus_embeddings/{model_name}.hdf5',\n",
    "        stimulus_key=stimulus_key,\n",
    "        delay_loading=False,\n",
    "        return_tensor_dataset=False,\n",
    "        return_stimulus_ids=True,\n",
    "    )\n",
    "    stimulus, stimulus_ids = nsd.load_stimulus(**stimulus_params)\n",
    "    stimulus = stimulus.astype(np.float32)\n",
    "    Y_train, Y_val, Y_test = stimulus[train_mask], stimulus[val_mask], stimulus[test_mask]\n",
    "\n",
    "    best_model = None\n",
    "    best_top_1 = 0\n",
    "\n",
    "    #alpha_grid = (0.1, 1, 10, 100, 1000, 10000, 100000)\n",
    "    alpha_grid = (1000, 10000)\n",
    "    alpha_results = results[subject_name] = []\n",
    "    for alpha in alpha_grid:\n",
    "        gc.collect()\n",
    "        model = Ridge(alpha=alpha)\n",
    "        model.fit(X_train, Y_train)\n",
    "        Y_val_pred = model.predict(X_val)\n",
    "\n",
    "        metric = 'cosine'\n",
    "        top_k_values = [1]\n",
    "\n",
    "        unique_stimulus_ids, unique_index, unique_inverse = np.unique(stimulus_ids[val_mask], return_index=True, return_inverse=True)\n",
    "        top_knn_accuracy = top_knn_test(Y_val[unique_index], Y_val_pred, unique_inverse, k=top_k_values, metric=metric)\n",
    "        print(alpha, top_knn_accuracy)\n",
    "        alpha_results.append(top_knn_accuracy[0])\n",
    "\n",
    "        if top_knn_accuracy[0] > best_top_1:\n",
    "            best_alpha = alpha\n",
    "            best_top_1 = top_knn_accuracy[0]\n",
    "            best_model = model\n",
    "\n",
    "    Y_val_pred = best_model.predict(X_val)\n",
    "    Y_test_pred = best_model.predict(X_test)\n",
    "\n",
    "    group = f.create_group(f'{subject_name}/{stimulus_key}')\n",
    "    group.attrs['alpha'] = best_alpha\n",
    "    \n",
    "    group[f'model/weight'] = model.coef_\n",
    "    group[f'model/bias'] = model.intercept_\n",
    "    group[f'volume_indices'] = betas_indices\n",
    "    group[f'val/Y_pred'] = Y_val_pred\n",
    "    group[f'val/stimulus_ids'] = stimulus_ids[val_mask]\n",
    "    group[f'test/Y_pred'] = Y_test_pred\n",
    "    group[f'test/stimulus_ids'] = stimulus_ids[test_mask]\n",
    "\n",
    "out_path = nsd_path / f'derivatives/decoded_features/{model_name}'\n",
    "out_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "with h5py.File(f'{out_path}/ridge-6.hdf5', 'a') as f:\n",
    "    results = {}\n",
    "    for subject_id in range(8):\n",
    "        fit_ridge(subject_id, f, results)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "591b6847-0f8b-4f72-94c7-7e2b717d2ffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha_grid = (1000, 10000)\n",
    "x = np.arange(len(alpha_grid))\n",
    "plt.xlabel('alpha')\n",
    "plt.ylabel('top 1 accuracy')\n",
    "plt.title('top 1 accuracy vs ridge regression accuracy')\n",
    "for subject_name, top_1 in results.items():\n",
    "    plt.plot(x, np.array(top_1) * 100, label=subject_name)\n",
    "    plt.xticks(x, alpha_grid)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0190a6f0-ec4f-457a-a950-c7e4252180b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e83a484-d173-44e1-aef5-7bce7a1fce75",
   "metadata": {},
   "outputs": [],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cdae0b5-955a-49db-8b3e-d7e1f45191fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "with h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/ridge-1.hdf5', 'w') as f:\n",
    "    group = f.create_group(f'{subject_name}/{stimulus_key}')\n",
    "    group[f'model/weight'] = model.coef_\n",
    "    group[f'model/bias'] = model.intercept_\n",
    "    group[f'volume_indices'] = betas_indices\n",
    "    group[f'val/Y_pred'] = Y_val_pred\n",
    "    group[f'val/stimulus_ids'] = stimulus_ids[val_mask]\n",
    "    group[f'test/Y_pred'] = Y_test_pred\n",
    "    group[f'test/stimulus_ids'] = stimulus_ids[test_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4976b555-26e8-4415-b329-002af6d89306",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.intercept_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee2e381-a3e6-4761-bce0-95cf72ac4439",
   "metadata": {},
   "outputs": [],
   "source": [
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f135006c-2db1-49b2-8f12-6ee9010d1773",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'cosine'\n",
    "top_k_values = [1, 5, 10, 50, 100, 500]\n",
    "\n",
    "unique_stimulus_ids, unique_index, unique_inverse = np.unique(stimulus_ids[val_mask], return_index=True, return_inverse=True)\n",
    "top_knn_accuracy = top_knn_test(Y_val[unique_index], Y_val_pred, unique_inverse, k=top_k_values, metric=metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c096bd-d41a-4540-b687-2bb6666b1e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_knn_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdfe2cdd-c313-4a7d-acf6-354f8768b0b3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "wandb_logging = False,\n",
    "\n",
    "run_models = [\n",
    "    ('ViT-B=32', 'embedding'),\n",
    "    #('ViT-B=32', 'transformer.resblocks.3'),\n",
    "    #('bigbigan-resnet50', 'z_mean'),\n",
    "    #('DPT_Large', 'scratch.refinenet4'),\n",
    "]\n",
    "\n",
    "run_models += [('ViT-B=32', f'transformer.resblocks.{i}') for i in range(12)]\n",
    "\n",
    "#subjects = nsd.subjects.keys()\n",
    "#subjects = [f'subj0{i}' for i in range(1, 9)]\n",
    "#subjects = [f'subj06' for i in range(1, 9)]\n",
    "\n",
    "for model_name, stimulus_key in run_models:\n",
    "    for subject_name in subjects:\n",
    "        def run():\n",
    "            notes = None\n",
    "\n",
    "            experiment_params = dict(\n",
    "                batch_size=128 if stimulus_key == 'embedding' else 64,\n",
    "                distance_metric='cosine' if stimulus_key == 'embedding' else 'euclidean',\n",
    "                group='group-4',\n",
    "                max_iterations = 10001,\n",
    "                evaluation_interval = 2500,\n",
    "                channels_last=False, # (model_name == 'ViT-B=32' and stimulus_key != 'embedding'),\n",
    "                wandb_logging=True,\n",
    "            )\n",
    "\n",
    "            betas_params = dict(\n",
    "                subject_name=subject_name,\n",
    "                voxel_selection_path='derivatives/voxel-selection.hdf5',\n",
    "                voxel_selection_key='nc/sorted_indices_flat',\n",
    "                num_voxels=2500,\n",
    "                return_volume_indices=True\n",
    "            )\n",
    "            betas, betas_indices = nsd.load_betas(**betas_params)\n",
    "\n",
    "            stimulus_params = dict(\n",
    "                subject_name=subject_name,\n",
    "                #stimulus_path='nsddata_stimuli/stimuli/nsd/nsd_stimuli.hdf5',\n",
    "                #stimulus_key='imgBrick',\n",
    "                stimulus_path=f'derivatives/stimulus_embeddings/{model_name}.hdf5',\n",
    "                stimulus_key=stimulus_key,\n",
    "                delay_loading=True\n",
    "            )\n",
    "            stimulus = nsd.load_stimulus(**stimulus_params)\n",
    "\n",
    "            dataset = KeyDataset({'betas': betas, 'stimulus': stimulus})\n",
    "            train_dataset, val_dataset, test_dataset = nsd.apply_subject_split(dataset, subject_name, 'split-01')\n",
    "\n",
    "            config = {'model_name': model_name, **betas_params, **stimulus_params}\n",
    "            experiment = run_experiment(\n",
    "                train_dataset,\n",
    "                val_dataset,\n",
    "                config=config,\n",
    "                **experiment_params,\n",
    "            )\n",
    "\n",
    "\n",
    "            with torch.no_grad():\n",
    "                _, Y_val_pred, Y_val_ids = experiment.run_all(val_dataset)\n",
    "                _, Y_test_pred, Y_test_ids = experiment.run_all(test_dataset)\n",
    "                _ = None\n",
    "\n",
    "\n",
    "            def require_dataset(group, key, tensor):\n",
    "                if key in group:\n",
    "                    group[key][:] = tensor\n",
    "                else:\n",
    "                    group[key] = tensor\n",
    "            results_path = nsd_path / 'derivatives/decoded_features'\n",
    "\n",
    "            key_name = wandb.run.group if wandb.run.group else wandb.run.name\n",
    "            save_file_path = results_path / wandb.config['model_name'] / f'{key_name}.hdf5'\n",
    "            save_file_path.parent.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "            h5_key = (wandb.config['subject_name'], wandb.config['stimulus_key'])\n",
    "\n",
    "            attributes = dict(wandb.config)\n",
    "            attributes['wandb_run_name'] = wandb.run.name\n",
    "            attributes['wandb_run_url'] = wandb.run.url\n",
    "            attributes['wandb_group'] = wandb.run.group\n",
    "            attributes['wandb_notes'] = wandb.run.notes\n",
    "\n",
    "            with h5py.File(save_file_path, 'a') as f:\n",
    "                key = '/'.join(h5_key)\n",
    "                group = f.require_group(key)\n",
    "                for k, v in attributes.items():\n",
    "                    group.attrs[k] = v\n",
    "                group.attrs['iteration'] = experiment.iteration\n",
    "                require_dataset(group, 'volume_indices', betas_indices)\n",
    "                require_dataset(group, 'test/Y_pred', Y_test_pred.detach().cpu())\n",
    "                require_dataset(group, 'test/stimulus_ids', Y_test_ids)\n",
    "                require_dataset(group, 'val/Y_pred', Y_val_pred.detach().cpu())\n",
    "                require_dataset(group, 'val/stimulus_ids', Y_val_ids)\n",
    "\n",
    "                model_group = group.require_group('model')\n",
    "                for param_name, weights in experiment.model.state_dict().items():\n",
    "                    weights = weights.cpu()\n",
    "                    require_dataset(model_group, param_name, weights)\n",
    "\n",
    "            experiment = None\n",
    "            wandb.finish()\n",
    "        run()\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52738a0e-56f9-41db-898c-cf866c7f068f",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_path = nsd_path / 'derivatives/decoded_features'\n",
    "\n",
    "group_name = 'group-5'\n",
    "fold = 'val'\n",
    "\n",
    "eval_models = [\n",
    "    ('group-10', 'ViT-B=32', 'embedding'),\n",
    "    #('group-3', 'bigbigan-resnet50', 'z_mean'),\n",
    "    #('DPT_Large', 'scratch.refinenet4'),\n",
    "    #('group-5', 'clip-vit-large-patch14-text', 'embedding_unpooled')\n",
    "]\n",
    "#eval_models += [('ViT-B=32', f'transformer.resblocks.{i}') for i in range(12)]\n",
    "\n",
    "subjects = [f'subj0{i}' for i in range(1, 9)]\n",
    "top_k_values = [1, 5, 10, 50, 100, 500]\n",
    "\n",
    "results = {}\n",
    "for group_name, model_name, stimulus_key in eval_models:\n",
    "    results[(model_name, stimulus_key)] = stimulus_results = {}\n",
    "    \n",
    "    stimulus_file = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r')\n",
    "    stimulus = stimulus_file[stimulus_key]\n",
    "    result_file = h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/{group_name}.hdf5', 'r')\n",
    "    \n",
    "    for subject_name in subjects:\n",
    "        stimulus_results[subject_name] = subject_results = {}\n",
    "        print(model_name, stimulus_key, subject_name)\n",
    "        result = result_file[subject_name][stimulus_key][fold]\n",
    "        stimulus_ids = result['stimulus_ids'][:]\n",
    "        \n",
    "        print('load')\n",
    "        Y_pred = result['Y_pred'][:]\n",
    "        Y = index_unsorted(stimulus, stimulus_ids)\n",
    "        N = Y.shape[0]\n",
    "        Y_pred = Y_pred.reshape(N, -1)\n",
    "        Y = Y.reshape(N, -1)\n",
    "        \n",
    "        print('knn')\n",
    "        metric = 'cosine' if stimulus_key == 'embedding' else 'euclidean' \n",
    "        \n",
    "        unique_stimulus_ids, unique_index, unique_inverse = np.unique(stimulus_ids, return_index=True, return_inverse=True)\n",
    "        top_knn_accuracy = top_knn_test(Y[unique_index], Y_pred, unique_inverse, k=top_k_values, metric=metric)\n",
    "        subject_results['top_knn_accuracy'] = top_knn_accuracy\n",
    "        print(subject_results)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c63f6b5c-78c6-4109-af7e-099f2f02ffa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000\n",
    "chance_accuracy = [k / N for k in top_k_values]\n",
    "\n",
    "model_name = 'ViT-B=32'\n",
    "stimulus_key = 'embedding'\n",
    "\n",
    "plt.figure(figsize=(12, 8))\n",
    "plt.xticks(ticks=range(len(top_k_values)), labels=top_k_values)\n",
    "plt.title(f'Top knn accuracy (n={N})\\n{model_name=}, {stimulus_key=}')\n",
    "plt.xlabel('k')\n",
    "plt.ylabel('accuracy')\n",
    "plt.plot(range(len(top_k_values)), chance_accuracy, label='chance (k/n)', color='gray')\n",
    "for subject, subject_results in results[(model_name, stimulus_key)].items():\n",
    "    top_knn_accuracy = subject_results['top_knn_accuracy']\n",
    "    plt.plot(range(len(top_k_values)), top_knn_accuracy, label=subject)\n",
    "\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "783a0016-844d-4602-b37e-fa0c9a1724a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000\n",
    "chance_accuracy = [k / N for k in top_k_values]\n",
    "\n",
    "model_name = 'clip-vit-large-patch14-text'\n",
    "stimulus_key = 'embedding_unpooled'\n",
    "\n",
    "plt.figure(figsize=(12, 8))\n",
    "plt.xticks(ticks=range(len(top_k_values)), labels=top_k_values)\n",
    "plt.title(f'Top knn accuracy (n={N})\\n{model_name=}, {stimulus_key=}')\n",
    "plt.xlabel('k')\n",
    "plt.ylabel('accuracy')\n",
    "plt.plot(range(len(top_k_values)), chance_accuracy, label='chance (k/n)', color='gray')\n",
    "for subject, subject_results in results[(model_name, stimulus_key)].items():\n",
    "    top_knn_accuracy = subject_results['top_knn_accuracy']\n",
    "    plt.plot(range(len(top_k_values)), top_knn_accuracy, label=subject)\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c894204-0270-4f31-964f-151054f404b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_stimulus_ids, unique_indices = np.unique(stimulus_ids, return_index=True)\n",
    "\n",
    "np.stack([Y_pred[i == stimulus_ids].mean(axis=0) for i in unique_stimulus_ids]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac42247f-11fb-4856-a871-23cb98fe866c",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_pred.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
