{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e0f38025-c211-495d-9b5e-87134834f9c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import h5py\n",
    "import h5py\n",
    "import pandas as pd\n",
    "import nibabel as nib\n",
    "from pathlib import Path\n",
    "from einops import rearrange\n",
    "\n",
    "# Make the directory three levels above the working directory importable\n",
    "# (the `research` package imported below is expected to live there).\n",
    "dir2 = os.path.abspath('../..')\n",
    "dir1 = os.path.dirname(dir2)\n",
    "if dir1 not in sys.path:\n",
    "    sys.path.append(dir1)\n",
    "    \n",
    "from research.data.natural_scenes import NaturalScenesDataset\n",
    "from research.metrics.metrics import compute_ncsnr_fast, compute_nc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0730d6f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the local Natural Scenes Dataset (NSD) copy.\n",
    "dataset_path = Path('D:\\\\Datasets\\\\NSD\\\\')\n",
    "\n",
    "# Sub-directories of the dataset root that the cells below read from / write to.\n",
    "derivatives_path = dataset_path.joinpath('derivatives')\n",
    "betas_path = dataset_path.joinpath('nsddata_betas', 'ppdata')\n",
    "ppdata_path = dataset_path.joinpath('nsddata', 'ppdata')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6cb9171b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (initially empty) record per subject: subj01 ... subj08.\n",
    "subjects = {f'subj0{i}': {} for i in range(1, 9)}\n",
    "\n",
    "for subject_name, subject_data in subjects.items():\n",
    "    behav_dir = ppdata_path / subject_name / 'behav'\n",
    "    all_responses = pd.read_csv(behav_dir / 'responses.tsv', sep='\\t')\n",
    "\n",
    "    # The last 3 sessions are currently held-out for the algonauts challenge,\n",
    "    # so drop every response that belongs to one of them.\n",
    "    last_kept_session = np.max(all_responses['SESSION']) - 3\n",
    "    is_held_out = all_responses['SESSION'] > last_kept_session\n",
    "    kept_responses = all_responses[~is_held_out]\n",
    "    subject_data['responses'] = kept_responses\n",
    "\n",
    "    # After filtering, the largest remaining session id is the number of session files to open.\n",
    "    num_sessions = np.max(kept_responses['SESSION'])\n",
    "    subject_sessions_path = betas_path / subject_name / 'func1pt8mm' / 'betas_fithrf_GLMdenoise_RR'\n",
    "\n",
    "    # The files are left open on purpose: the concatenation cell reads from these handles.\n",
    "    session_files = []\n",
    "    for session_index in range(1, num_sessions + 1):\n",
    "        session_file_path = subject_sessions_path / f'betas_session{session_index:02}.hdf5'\n",
    "        session_files.append(h5py.File(session_file_path, 'r'))\n",
    "    subject_data['sessions'] = session_files\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "90f6d920",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subj01\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 20\u001b[0m\n\u001b[0;32m     15\u001b[0m Y \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate([\n\u001b[0;32m     16\u001b[0m     session[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m'\u001b[39m][:, i]\n\u001b[0;32m     17\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m session \u001b[38;5;129;01min\u001b[39;00m sessions\n\u001b[0;32m     18\u001b[0m ])\n\u001b[0;32m     19\u001b[0m slice_size \u001b[38;5;241m=\u001b[39m H \u001b[38;5;241m*\u001b[39m D\n\u001b[1;32m---> 20\u001b[0m \u001b[43mf\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mbetas\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mslice_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m:\u001b[49m\u001b[43mslice_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mi\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m rearrange(Y, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt ... -> t (...)\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\h5py\\_hl\\dataset.py:999\u001b[0m, in \u001b[0;36mDataset.__setitem__\u001b[1;34m(self, args, val)\u001b[0m\n\u001b[0;32m    997\u001b[0m mspace \u001b[38;5;241m=\u001b[39m h5s\u001b[38;5;241m.\u001b[39mcreate_simple(selection\u001b[38;5;241m.\u001b[39mexpand_shape(mshape))\n\u001b[0;32m    998\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fspace \u001b[38;5;129;01min\u001b[39;00m selection\u001b[38;5;241m.\u001b[39mbroadcast(mshape):\n\u001b[1;32m--> 999\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdxpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dxpl\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Concatenate all of the betas_sessions together into a single file (this will take a while)\n",
    "\n",
    "for subject_name, subject_data in subjects.items():\n",
    "    print(subject_name)\n",
    "    path = derivatives_path / 'betas' / subject_name / 'func1pt8mm' / 'betas_fithrf_GLMdenoise_RR'\n",
    "    path.mkdir(parents=True, exist_ok=True)\n",
    "    with h5py.File(path / 'betas_sessions.hdf5', 'a') as f:\n",
    "\n",
    "        sessions = subject_data['sessions']\n",
    "        # Each session's 'betas' dataset is 4-D: the first axis is concatenated across\n",
    "        # sessions, the remaining three are flattened into a single column axis.\n",
    "        T, W, H, D = sessions[0]['betas'].shape\n",
    "        # Sum the per-session lengths instead of assuming every session matches the first one.\n",
    "        T_full = sum(session['betas'].shape[0] for session in sessions)\n",
    "        # Number of flattened columns produced by one index along the second axis.\n",
    "        slice_size = H * D\n",
    "\n",
    "        # One chunk per column, so a whole column can be read back in a single chunk.\n",
    "        betas_out = f.require_dataset('betas', shape=(T_full, W * H * D), dtype=np.int16, chunks=(T_full, 1))\n",
    "        for i in range(W):\n",
    "            # Process one slab of the second axis at a time to bound memory use.\n",
    "            Y = np.concatenate([session['betas'][:, i] for session in sessions])\n",
    "            betas_out[:, slice_size * i:slice_size * (i + 1)] = rearrange(Y, 't ... -> t (...)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0c8e1f1b-f56c-447e-ba90-3949e2033c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the NaturalScenesDataset wrapper rooted at the NSD directory;\n",
    "# the cells below access subjects, splits and betas through it.\n",
    "\n",
    "nsd = NaturalScenesDataset(dataset_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01775854-0e7a-45ca-97f4-54c11789f163",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image ids of the 1000 images shared across all participants,\n",
    "# taken from the first column of the header-less shared1000.tsv.\n",
    "shared_1000_path = dataset_path.joinpath('nsddata', 'stimuli', 'nsd', 'shared1000.tsv')\n",
    "shared_1000_table = pd.read_csv(shared_1000_path, sep='\\t', header=None)\n",
    "shared_1000 = set(shared_1000_table[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "76ec8364-7ada-4441-b5ad-2457ca1162d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subj01 image_ids.shape=(27750,), len(three_repetition_ids)=8420\n",
      "subj02 image_ids.shape=(27750,), len(three_repetition_ids)=8420\n",
      "subj03 image_ids.shape=(21750,), len(three_repetition_ids)=5081\n",
      "subj04 image_ids.shape=(20250,), len(three_repetition_ids)=4424\n",
      "subj05 image_ids.shape=(27750,), len(three_repetition_ids)=8420\n",
      "subj06 image_ids.shape=(21750,), len(three_repetition_ids)=5081\n",
      "subj07 image_ids.shape=(27750,), len(three_repetition_ids)=8420\n",
      "subj08 image_ids.shape=(20250,), len(three_repetition_ids)=4424\n",
      "len(shared_1000_three_repetitions)=413\n"
     ]
    }
   ],
   "source": [
    "# Generate a train-test-validation split\n",
    "split_name = 'split-01'\n",
    "N_test = 1000\n",
    "N_validation = 1000\n",
    "\n",
    "seed = 0\n",
    "\n",
    "# Per subject: ids of the images that occur exactly three times in the responses.\n",
    "for subject_name, subject_data in nsd.subjects.items():\n",
    "    responses = subject_data['responses']\n",
    "\n",
    "    image_ids = responses['73KID'].to_numpy()\n",
    "    unique_image_ids, unique_counts = np.unique(image_ids, return_counts=True)\n",
    "    three_repetition_ids = unique_image_ids[unique_counts == 3]\n",
    "    subject_data['three_repetition_ids'] = set(three_repetition_ids)\n",
    "    print(f'{subject_name} {image_ids.shape=}, {len(three_repetition_ids)=}')\n",
    "\n",
    "# Shared-1000 images that every subject saw exactly three times;\n",
    "# these are always placed in the test set.\n",
    "shared_1000_three_repetitions = set.intersection(\n",
    "    shared_1000,\n",
    "    *[subject_data['three_repetition_ids']\n",
    "    for subject_data in nsd.subjects.values()]\n",
    ")\n",
    "print(f'{len(shared_1000_three_repetitions)=}')\n",
    "# Number of subject-specific images needed to fill the test set up to N_test.\n",
    "N_non_shared = N_test - len(shared_1000_three_repetitions)\n",
    "\n",
    "\n",
    "# Choose the test / validation images once per subject and keep them on subject_data,\n",
    "# so the file written below is guaranteed to describe the same split.\n",
    "for subject_name, subject_data in nsd.subjects.items():\n",
    "    three_repetition_ids = subject_data['three_repetition_ids']\n",
    "    non_shared_three_repetition_ids = list(three_repetition_ids - shared_1000_three_repetitions)\n",
    "    # A fresh Random(seed) per subject makes each subject's shuffle independent of loop order.\n",
    "    random.Random(seed).shuffle(non_shared_three_repetition_ids)\n",
    "\n",
    "    test_image_ids = list(shared_1000_three_repetitions) + non_shared_three_repetition_ids[:N_non_shared]\n",
    "    validation_image_ids = non_shared_three_repetition_ids[N_non_shared:(N_non_shared + N_validation)]\n",
    "    subject_data['test_image_ids'] = np.array(test_image_ids)\n",
    "    # Store the validation ids (not a copy of the test ids) under the validation key.\n",
    "    subject_data['validation_image_ids'] = np.array(validation_image_ids)\n",
    "\n",
    "    # Row positions in the responses table whose image belongs to each held-out set.\n",
    "    image_ids = subject_data['responses']['73KID'].to_numpy()\n",
    "    subject_data['test_response_ids'] = np.flatnonzero(np.isin(image_ids, subject_data['test_image_ids']))\n",
    "    subject_data['validation_response_ids'] = np.flatnonzero(np.isin(image_ids, subject_data['validation_image_ids']))\n",
    "\n",
    "# Persist the split: per subject, the chosen image ids and one boolean mask over responses per fold.\n",
    "(derivatives_path / 'data_splits').mkdir(exist_ok=True, parents=True)\n",
    "with h5py.File(derivatives_path / 'data_splits' / f'{split_name}.hdf5', 'w') as f:\n",
    "    for subject_name, subject_data in nsd.subjects.items():\n",
    "        subject = f.require_group(subject_name)\n",
    "        image_ids = subject_data['responses']['73KID'].to_numpy()\n",
    "\n",
    "        for fold in ('test', 'validation'):\n",
    "            fold_image_ids = subject_data[f'{fold}_image_ids']\n",
    "            subject[f'{fold}_image_ids'] = fold_image_ids\n",
    "            subject[f'{fold}_response_mask'] = np.isin(image_ids, fold_image_ids)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4cce7e03",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subj01\n",
      "10133/699192, 1.4%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "g:\\Github Repositories\\contrastive-decoding-neurips24\\research\\metrics\\metrics.py:291: RuntimeWarning: divide by zero encountered in divide\n",
      "  ncsnr = std_signal / std_noise\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20267/699192, 2.9%\n",
      "30401/699192, 4.3%\n",
      "40535/699192, 5.8%\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[20], line 33\u001b[0m\n\u001b[0;32m     31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m betas_indices \u001b[38;5;129;01min\u001b[39;00m indices_batches:\n\u001b[0;32m     32\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbetas_indices[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_voxels\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbetas_indices[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;250m \u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;250m \u001b[39mnum_voxels\u001b[38;5;250m \u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m100\u001b[39m\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m---> 33\u001b[0m     betas \u001b[38;5;241m=\u001b[39m \u001b[43mnsd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_betas\u001b[49m\u001b[43m(\u001b[49m\u001b[43msubject_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbetas_indices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbetas_indices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_tensor_dataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m     34\u001b[0m     betas \u001b[38;5;241m=\u001b[39m betas[train_mask]\n\u001b[0;32m     35\u001b[0m     ncsnr\u001b[38;5;241m.\u001b[39mappend(compute_ncsnr_fast(betas, repetition_ids))\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\research\\data\\natural_scenes.py:155\u001b[0m, in \u001b[0;36mNaturalScenesDataset.load_betas\u001b[1;34m(self, subject_name, betas_indices, voxel_selection_path, voxel_selection_key, num_voxels, threshold, return_volume_indices, return_tensor_dataset, session_normalize, scale_betas)\u001b[0m\n\u001b[0;32m    151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(betas_indices\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m    152\u001b[0m     betas_indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mflatten_indices(subject_name, betas_indices)\n\u001b[0;32m    154\u001b[0m betas \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mstack([\n\u001b[1;32m--> 155\u001b[0m     \u001b[43msubject_betas\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mbetas\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\n\u001b[0;32m    156\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m betas_indices\n\u001b[0;32m    157\u001b[0m ], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m    158\u001b[0m betas \u001b[38;5;241m=\u001b[39m betas\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m    159\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m scale_betas:\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\h5py\\_hl\\dataset.py:758\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[1;34m(self, args, new_dtype)\u001b[0m\n\u001b[0;32m    756\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fast_read_ok \u001b[38;5;129;01mand\u001b[39;00m (new_dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m    757\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 758\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fast_reader\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    759\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m    760\u001b[0m         \u001b[38;5;28;01mpass\u001b[39;00m  \u001b[38;5;66;03m# Fall back to Python read pathway below\u001b[39;00m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def require_dataset(group, name, data):\n",
    "    \"\"\"Write `data` into the dataset `name` of `group`.\n",
    "\n",
    "    The dataset is requested with the shape and dtype of `data` via\n",
    "    `group.require_dataset` and then overwritten in full.\n",
    "    \"\"\"\n",
    "    dataset = group.require_dataset(name, shape=data.shape, dtype=data.dtype)\n",
    "    dataset[:] = data\n",
    "\n",
    "with h5py.File(nsd.dataset_path / 'derivatives/noise-ceiling.hdf5', 'a') as f:\n",
    "    for subject_id in range(8):\n",
    "        subject_name = f'subj0{subject_id + 1}'\n",
    "        print(subject_name)\n",
    "        \n",
    "        subject = nsd.subjects[subject_name]\n",
    "        train_mask = nsd.get_split(subject_name, 'split-01')[0]\n",
    "\n",
    "        betas_h5 = subject['betas']\n",
    "        responses = subject['responses']\n",
    "        # NOTE(review): the -1 presumably converts 1-based 73KID values to 0-based ids - confirm.\n",
    "        stimulus_ids = np.array(responses['73KID']) - 1\n",
    "        stimulus_ids = stimulus_ids[train_mask]\n",
    "\n",
    "        # For every stimulus with at least n training presentations, keep the positions\n",
    "        # of its first n presentations: one row per stimulus, n columns.\n",
    "        n = 3\n",
    "        unique_ids, unique_counts = np.unique(stimulus_ids, return_counts=True)\n",
    "        atleast_n_ids = unique_ids[unique_counts >= n]\n",
    "        repetition_ids = np.stack([\n",
    "            np.where(stimulus_ids == i)[0][:n]\n",
    "            for i in atleast_n_ids\n",
    "        ])\n",
    "\n",
    "        num_betas, num_voxels = betas_h5['betas'].shape\n",
    "        voxel_batch_size = 10000\n",
    "        # At least one batch: array_split raises when asked for zero sections,\n",
    "        # which would happen with fewer voxels than voxel_batch_size.\n",
    "        num_batches = max(1, num_voxels // voxel_batch_size)\n",
    "        indices_batches = np.array_split(np.arange(num_voxels), num_batches)\n",
    "        ncsnr = []\n",
    "\n",
    "        # Voxels are processed in batches so only one batch of betas is in memory at a time.\n",
    "        for betas_indices in indices_batches:\n",
    "            print(f'{betas_indices[-1]}/{num_voxels}, {betas_indices[-1] / num_voxels * 100:.1f}%')\n",
    "            betas = nsd.load_betas(subject_name, betas_indices=betas_indices, return_tensor_dataset=False)[0]\n",
    "            betas = betas[train_mask]\n",
    "            ncsnr.append(compute_ncsnr_fast(betas, repetition_ids))\n",
    "        ncsnr = np.concatenate(ncsnr)\n",
    "\n",
    "        nc = compute_nc(ncsnr, num_averages=1)\n",
    "\n",
    "        voxel_selection_path = 'derivatives/voxel-selection.hdf5'\n",
    "        voxel_selection_key = 'nc/value'\n",
    "\n",
    "        # The existing noise-ceiling volume is only needed for its shape; close the file right after reading.\n",
    "        key = f'{subject_name}/{voxel_selection_key}'\n",
    "        with h5py.File(nsd.dataset_path / voxel_selection_path, 'r') as voxel_selection_file:\n",
    "            nc_original = voxel_selection_file[key][:]\n",
    "        \n",
    "        nc = nc.reshape(nc_original.shape)\n",
    "        nc[np.isnan(nc)] = 0.\n",
    "        # grid[k] is the N-d index of flat position k, so indexing it with the flat\n",
    "        # sort order gives the voxel coordinates from highest to lowest nc.\n",
    "        grid = np.argwhere(np.ones_like(nc, dtype=bool))\n",
    "        nc_sorted_indices_flat = nc.argsort(axis=None)[::-1].astype(int)\n",
    "        nc_sorted_indices = grid[nc_sorted_indices_flat].astype(int)\n",
    "        \n",
    "        require_dataset(f, f'{subject_name}/split-01/value', nc)\n",
    "        require_dataset(f, f'{subject_name}/split-01/sorted_indices_flat', nc_sorted_indices_flat)\n",
    "        require_dataset(f, f'{subject_name}/split-01/sorted_indices', nc_sorted_indices)\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "97f93265",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when PyTorch can see one.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# The stimulus file is kept open so that `stimulus_images` stays readable.\n",
    "nsd_path = Path('D:\\\\Datasets\\\\NSD')\n",
    "stimuli_path = nsd_path.joinpath('nsddata_stimuli', 'stimuli', 'nsd', 'nsd_stimuli.hdf5')\n",
    "stimuli_file = h5py.File(stimuli_path, 'r')\n",
    "stimulus_images = stimuli_file['imgBrick']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "32a5851f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']\n"
     ]
    }
   ],
   "source": [
    "# Load a CLIP model and keep a handle on its visual encoder.\n",
    "import clip\n",
    "\n",
    "print(clip.available_models())\n",
    "model_name = 'ViT-B/32'\n",
    "full_model, preprocess = clip.load(model_name, device=device)\n",
    "model = full_model.visual\n",
    "\n",
    "# Module name -> label under which its output is saved.\n",
    "# NOTE(review): '' presumably refers to the model itself - confirm where save_modules is consumed.\n",
    "save_modules = {'': 'embedding'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "475ece85",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[27], line 36\u001b[0m\n\u001b[0;32m     34\u001b[0m         hook_handles\u001b[38;5;241m.\u001b[39mappend(hook_handle)\n\u001b[0;32m     35\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m---> 36\u001b[0m     \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     37\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m hook_handle \u001b[38;5;129;01min\u001b[39;00m hook_handles:\n\u001b[0;32m     38\u001b[0m     hook_handle\u001b[38;5;241m.\u001b[39mremove()\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1603\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1600\u001b[0m     bw_hook \u001b[38;5;241m=\u001b[39m hooks\u001b[38;5;241m.\u001b[39mBackwardHook(\u001b[38;5;28mself\u001b[39m, full_backward_hooks, backward_pre_hooks)\n\u001b[0;32m   1601\u001b[0m     args \u001b[38;5;241m=\u001b[39m bw_hook\u001b[38;5;241m.\u001b[39msetup_input_hook(args)\n\u001b[1;32m-> 1603\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1604\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks:\n\u001b[0;32m   1605\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m hook_id, hook \u001b[38;5;129;01min\u001b[39;00m (\n\u001b[0;32m   1606\u001b[0m         \u001b[38;5;241m*\u001b[39m_global_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[0;32m   1607\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[0;32m   1608\u001b[0m     ):\n\u001b[0;32m   1609\u001b[0m         \u001b[38;5;66;03m# mark that always called hook is run\u001b[39;00m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\clip\\model.py:233\u001b[0m, in \u001b[0;36mVisionTransformer.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m    230\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mln_pre(x)\n\u001b[0;32m    232\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m)  \u001b[38;5;66;03m# NLD -> LND\u001b[39;00m\n\u001b[1;32m--> 233\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    234\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m)  \u001b[38;5;66;03m# LND -> NLD\u001b[39;00m\n\u001b[0;32m    236\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mln_post(x[:, \u001b[38;5;241m0\u001b[39m, :])\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\clip\\model.py:204\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m    203\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[1;32m--> 204\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresblocks\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:219\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    217\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m    218\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 219\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m    220\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\clip\\model.py:191\u001b[0m, in \u001b[0;36mResidualAttentionBlock.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m    190\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[1;32m--> 191\u001b[0m     x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mln_1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    192\u001b[0m     x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mln_2(x))\n\u001b[0;32m    193\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\clip\\model.py:188\u001b[0m, in \u001b[0;36mResidualAttentionBlock.attention\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m    186\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mattention\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[0;32m    187\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattn_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattn_mask\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mdtype, device\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mdevice) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattn_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m--> 188\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\activation.py:1275\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[1;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m   1261\u001b[0m     attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmulti_head_attention_forward(\n\u001b[0;32m   1262\u001b[0m         query, key, value, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_dim, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads,\n\u001b[0;32m   1263\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_bias,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1272\u001b[0m         average_attn_weights\u001b[38;5;241m=\u001b[39maverage_attn_weights,\n\u001b[0;32m   1273\u001b[0m         is_causal\u001b[38;5;241m=\u001b[39mis_causal)\n\u001b[0;32m   1274\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1275\u001b[0m     attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulti_head_attention_forward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1276\u001b[0m \u001b[43m        \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1277\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1278\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_v\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_zero_attn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1279\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1280\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtraining\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1281\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1282\u001b[0m \u001b[43m        \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mneed_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1283\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1284\u001b[0m \u001b[43m        \u001b[49m\u001b[43maverage_attn_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage_attn_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1285\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1286\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_first \u001b[38;5;129;01mand\u001b[39;00m is_batched:\n\u001b[0;32m   1287\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m attn_output\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m), attn_output_weights\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\functional.py:5420\u001b[0m, in \u001b[0;36mmulti_head_attention_forward\u001b[1;34m(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m   5418\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m use_separate_proj_weight:\n\u001b[0;32m   5419\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m in_proj_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse_separate_proj_weight is False but in_proj_weight is None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m-> 5420\u001b[0m     q, k, v \u001b[38;5;241m=\u001b[39m \u001b[43m_in_projection_packed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_proj_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_proj_bias\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   5421\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   5422\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m q_proj_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse_separate_proj_weight is True but q_proj_weight is None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "File \u001b[1;32mg:\\Github Repositories\\contrastive-decoding-neurips24\\.venv\\Lib\\site-packages\\torch\\nn\\functional.py:4920\u001b[0m, in \u001b[0;36m_in_projection_packed\u001b[1;34m(q, k, v, w, b)\u001b[0m\n\u001b[0;32m   4917\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mis\u001b[39;00m v:\n\u001b[0;32m   4918\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m q \u001b[38;5;129;01mis\u001b[39;00m k:\n\u001b[0;32m   4919\u001b[0m         \u001b[38;5;66;03m# self-attention\u001b[39;00m\n\u001b[1;32m-> 4920\u001b[0m         proj \u001b[38;5;241m=\u001b[39m \u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   4921\u001b[0m         \u001b[38;5;66;03m# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()\u001b[39;00m\n\u001b[0;32m   4922\u001b[0m         proj \u001b[38;5;241m=\u001b[39m proj\u001b[38;5;241m.\u001b[39munflatten(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, (\u001b[38;5;241m3\u001b[39m, E))\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from functools import partial\n",
    "from typing import Sequence, Dict\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "# One HDF5 file per model; each hooked module output becomes a dataset inside it.\n",
    "out_path = dataset_path / 'derivatives' / 'stimulus_embeddings'\n",
    "out_path.mkdir(exist_ok=True, parents=True)\n",
    "modules = dict(model.named_modules())\n",
    "\n",
    "# Resolve save_modules into (module_name, feature_name) pairs: a sequence stores each\n",
    "# output under the module's own name, a dict stores it under the mapped feature name.\n",
    "if isinstance(save_modules, Sequence):\n",
    "    hook_targets = [(module_name, module_name) for module_name in save_modules]\n",
    "elif isinstance(save_modules, Dict):\n",
    "    hook_targets = list(save_modules.items())\n",
    "else:\n",
    "    hook_targets = []\n",
    "\n",
    "# Filled by the hooks during each forward pass and cleared before the next one.\n",
    "features = {}\n",
    "\n",
    "\n",
    "def forward_hook(feature_name, module, x_in, x_out):\n",
    "    \"\"\"Forward hook: store the module output as a float NumPy array in features[feature_name].\"\"\"\n",
    "    # Drop the leading axis when it has length 1 (the single-image batch fed below).\n",
    "    if x_out.shape[0] == 1:\n",
    "        x_out = x_out[0]\n",
    "    features[feature_name] = x_out.clone().cpu().float().numpy()\n",
    "\n",
    "\n",
    "# Register the hooks once for the whole run instead of once per image. The try/finally\n",
    "# guarantees they are removed even if the run fails or is interrupted; otherwise they\n",
    "# would stay attached to the model and pile up on the next execution of this cell.\n",
    "hook_handles = []\n",
    "try:\n",
    "    for module_name, feature_name in hook_targets:\n",
    "        module = modules[module_name]\n",
    "        hook_handles.append(module.register_forward_hook(partial(forward_hook, feature_name)))\n",
    "\n",
    "    # Append mode: datasets already present in the file are reused by require_dataset.\n",
    "    with h5py.File(out_path / f\"{model_name.replace('/', '=').replace('@', '-')}.hdf5\", \"a\") as f:\n",
    "        N = stimulus_images.shape[0]\n",
    "        for stimulus_id in range(N):\n",
    "            image_data = stimulus_images[stimulus_id]\n",
    "\n",
    "            image = Image.fromarray(image_data)\n",
    "            x = preprocess(image).unsqueeze(0).to(device)\n",
    "\n",
    "            # Start from an empty dict so only this image's activations get written.\n",
    "            features.clear()\n",
    "            with torch.no_grad():\n",
    "                model(x)\n",
    "\n",
    "            for feature_name, feature in features.items():\n",
    "                # One row per stimulus; created on first use with the feature's shape and dtype.\n",
    "                f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)\n",
    "                f[feature_name][stimulus_id] = feature\n",
    "finally:\n",
    "    for hook_handle in hook_handles:\n",
    "        hook_handle.remove()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
