{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b680743-fe42-4239-af54-0a2f32db4174",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import gc\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "from einops import rearrange\n",
    "import time\n",
    "import random\n",
    "import string\n",
    "import h5py\n",
    "from tqdm import tqdm\n",
    "import webdataset as wds\n",
    "\n",
    "\n",
    "# tf32 data type is faster than standard float32\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "sys.path.append('generative_models/')\n",
    "import sgm\n",
    "from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder\n",
    "\n",
    "# custom models and functions #\n",
    "import v2_utils as utils\n",
    "\n",
    "from dataloaders import ImageVoxelDataset,ImageVoxelAdapterDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c38ef3d9-6405-4f39-bf8b-802ad1410878",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e43cf0a-da43-45ec-94c0-053ff43ccd5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output # function to clear print outputs in cell\n",
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3eb7cbd-f774-4d3a-bd7c-9426ad29b5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_type = torch.float16 # change depending on your mixed_precision\n",
    "local_rank = 0\n",
    "world_size = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d89ea63-67fc-41fe-b51b-5bb3077a7581",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abae72a2-db11-41f0-9d31-0c55f895921a",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36544d40-5816-46ca-9d40-26dceeacc893",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name='subj1_l_bclip_basictest_wll_5_finetune_s7_250_aamax'\n",
    "data_path='data/new_dl/'\n",
    "checkpoint_name = 'subj1_l_bclip_basictest_wll_5'\n",
    "index_name = 'subj1_indices_30d_105.0w_256.npy'\n",
    "adaptalign=True\n",
    "subj=7\n",
    "ref_sub=1\n",
    "max_lr=3e-4\n",
    "checkpoint_tag = 'mid_200'\n",
    "batch_size=16\n",
    "hidden=True\n",
    "#resume_from_ckpt=False\n",
    "mixup_pct=0.0\n",
    "norm_embs=True\n",
    "use_image_aug=False\n",
    "blurry_recon=True\n",
    "blur_scale=0.5\n",
    "num_epochs=201\n",
    "prior=True\n",
    "v2c=True\n",
    "lr_scheduler_type='cycle'\n",
    "ckpt_saving=True\n",
    "ckpt_interval=100\n",
    "use_cos_loss=False\n",
    "run_common=True\n",
    "run_train=False #Set this to false if you are training with limited data/on common images only\n",
    "seed=42\n",
    "use_projector=True\n",
    "cache_dir='checkpoints'\n",
    "hidden_dim=4096\n",
    "checkpoint_dir = os.path.abspath(f'train_logs/{checkpoint_name}')\n",
    "load_indices=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da03490a-3855-4657-8796-193c6e7a2146",
   "metadata": {},
   "outputs": [],
   "source": [
    "if subj == 1:\n",
    "    num_voxels = 15724\n",
    "elif subj == 2:\n",
    "    num_voxels = 14278\n",
    "elif subj == 3:\n",
    "    num_voxels = 15226\n",
    "elif subj == 4:\n",
    "    num_voxels = 13153\n",
    "elif subj == 5:\n",
    "    num_voxels = 13039\n",
    "elif subj == 6:\n",
    "    num_voxels = 17907\n",
    "elif subj == 7:\n",
    "    num_voxels = 12682\n",
    "elif subj == 8:\n",
    "    num_voxels = 14386"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aea907fd-7047-4245-92cc-28348ec442fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_voxels_list = [15724,num_voxels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee3456e2-d0c8-47c5-b46b-a94dadb6cda3",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_voxels_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc795ad3-95a4-438f-90af-17dcbe4ebe0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "outdir = os.path.abspath(f'train_logs/{model_name}')\n",
    "if not os.path.exists(outdir):\n",
    "    os.makedirs(outdir,exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78876aa2-3f52-40c6-aa96-11263d1747c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_image_aug:\n",
    "    import kornia\n",
    "    from kornia.augmentation.container import AugmentationSequential\n",
    "    img_augment = AugmentationSequential(\n",
    "        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n",
    "        kornia.augmentation.Resize((224, 224)),\n",
    "        kornia.augmentation.RandomHorizontalFlip(p=0.5),\n",
    "        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
    "        kornia.augmentation.RandomGrayscale(p=0.3),\n",
    "        data_keys=[\"input\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e37a1dc8-8e7d-48e5-bfa1-ea42b9880d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import kornia\n",
    "from kornia.augmentation.container import AugmentationSequential\n",
    "img_augment2 = AugmentationSequential(\n",
    "    kornia.augmentation.Resize((224, 224)),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a943532f-b728-4b8b-89e1-6f6197ceca9f",
   "metadata": {},
   "source": [
    "## Build Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a74c808f-b265-4698-9746-463b0a38abe7",
   "metadata": {},
   "source": [
    "### Adapter GT for Subject of Your Choice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6fa594f-1a3f-486a-bdf8-127ca3720c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super(RidgeRegression, self).__init__()\n",
    "        self.out_features = out_features\n",
    "        self.linears = torch.nn.ModuleList([\n",
    "                nn.Sequential(\n",
    "                    torch.nn.Linear(input_size, out_features),\n",
    "                    # nn.LayerNorm(out_features),\n",
    "                    # nn.GELU(),\n",
    "                    # nn.Dropout(0.5)\n",
    "                ) for input_size in input_sizes\n",
    "            ])\n",
    "    def forward(self, x, subj_idx):\n",
    "        out = self.linears[subj_idx](x).unsqueeze(1)\n",
    "        return out\n",
    "        \n",
    "ridge = RidgeRegression([num_voxels_list[0]], out_features=hidden_dim).to(device)\n",
    "ridge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14ae682a-7464-47c2-b396-bf523fb50aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls train_logs/$checkpoint_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e74617-c4f6-4d32-9286-6d83eeae70d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Loading {checkpoint_dir}/{checkpoint_tag}.pth\")\n",
    "checkpoint = torch.load(f\"{checkpoint_dir}/{checkpoint_tag}.pth\", map_location='cpu')\n",
    "ridge_state_dict = {k.replace('ridge.', ''): v for k, v in checkpoint['model_state_dict'].items() if k.startswith('ridge.')}\n",
    "for k,v in ridge_state_dict.items():\n",
    "    print(k,v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1741e71d-d1ab-4b76-a41b-6967dcce80eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "ridge.load_state_dict(ridge_state_dict, strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19950c1d-1ded-468d-94fe-f817dc1bac40",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Save some space\n",
    "del checkpoint\n",
    "del ridge_state_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde5f4ac-5d8b-4cfb-a27b-b99ee3eb6f09",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load the reference subject common images\n",
    "feature_file = f'data/new_dl/subj{ref_sub:02d}/test/betas.pt'\n",
    "image_file = f'data/new_dl/subj{ref_sub:02d}/test/images.pt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5774f129-645d-4f27-98a4-10c464885ca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "# Create the DataLoader to handle batching and optional shuffling\n",
    "val_dl = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30bb3534-8344-4a85-a5c2-5731a48b9c4e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ridge.eval()\n",
    "adapter_outputs = []\n",
    "with torch.no_grad():\n",
    "    for test_i, (voxel, image) in enumerate(val_dl):\n",
    "        voxel = voxel.to(device)\n",
    "        voxel = torch.mean(voxel,axis=1).float()\n",
    "        #print(voxel.shape)\n",
    "        voxel_ridge = ridge(voxel,0)\n",
    "        adapter_outputs.append(voxel_ridge)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d16caa1-7d61-4b0b-a66d-4f4487f3fb2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter_outputs = torch.concatenate(adapter_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "697e696b-4eef-43ef-8e8f-564763643f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter_outputs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fd45aed-b16d-4fa9-a333-ec50046d651e",
   "metadata": {},
   "source": [
    "### Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebea4476-5f0f-4291-ac58-5aeebd8757b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In case the mapping between subjects is not clean\n",
    "ref_image_ids = np.load(f'data/new_dl/subj{ref_sub:02d}/test/image_ids.npy')\n",
    "refid_to_index = {img_id: idx for idx, img_id in enumerate(ref_image_ids)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef76d803-ea0d-41a5-b0f0-98a2dee8cf60",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_image_ids = np.load(f'data/new_dl/subj0{subj}/test/image_ids.npy')\n",
    "new_to_ref_indices = [refid_to_index[img_id] for img_id in new_image_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e592b3-e497-478a-a5e5-009f44910d96",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.count_nonzero(new_image_ids == ref_image_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab142e9-397c-44c1-aadf-b9e77cd6e694",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Common DL\n",
    "if run_common:\n",
    "    #Load indices here if chosen:\n",
    "    feature_file = f'data/new_dl/subj{subj:02d}/test/betas.pt'\n",
    "    image_file = f'data/new_dl/subj{subj:02d}/test/images.pt'\n",
    "\n",
    "    if load_indices:\n",
    "        #indices = np.load(f'indices/{checkpoint_name}/{index_name}.npy')\n",
    "        indices = np.load('indices/subj2_common_250_indices.npy')\n",
    "        indices = indices[0:250]\n",
    "        adapter_outputs = adapter_outputs[new_to_ref_indices]\n",
    "        common_dataset = ImageVoxelAdapterDataset(feature_file, image_file, adapter_outputs, indices=indices)\n",
    "    else:\n",
    "        common_dataset = ImageVoxelAdapterDataset(feature_file, image_file, adapter_outputs) #No indices here as you load all 1k common images\n",
    "    \n",
    "    common_dl = DataLoader(common_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "    print(len(common_dataset))\n",
    "    voxel,image,adapter = next(iter(common_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())\n",
    "    print(adapter.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d592c27a-00b7-4dcb-af58-d1482336e3a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test DL\n",
    "\n",
    "feature_file = f'data/new_dl/subj{subj:02d}/train/custom_split/test/betas.pt'\n",
    "image_file = f'data/new_dl/subj{subj:02d}/train/custom_split/test/images.pt'\n",
    "\n",
    "test_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "print(len(test_dataset))\n",
    "voxel,image = next(iter(test_dl))\n",
    "print(voxel.shape, image.shape)\n",
    "print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d64d51-5ff1-4a72-9c7a-6f554b13ab7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Train DL\n",
    "if run_train:\n",
    "    feature_file = f'data/new_dl/subj{subj:02d}/train/custom_split/train/betas.pt'\n",
    "    image_file = f'data/new_dl/subj{subj:02d}/train/custom_split/train/images.pt'\n",
    "    \n",
    "    train_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "    print(len(train_dataset))\n",
    "    voxel,image = next(iter(train_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f76176d-50fb-45bf-bb12-dbec6685c90c",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41d71a72-8d39-40dd-9262-937217e8f344",
   "metadata": {},
   "source": [
    "### Low Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fbbd33e-b16e-49d5-a424-bd7a66f4a375",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(blur_scale,blurry_recon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5cf8c29-3ff2-45f4-89d0-cff52fa4c252",
   "metadata": {},
   "outputs": [],
   "source": [
    "if blurry_recon:\n",
    "    from diffusers import AutoencoderKL    \n",
    "    autoenc = AutoencoderKL(\n",
    "        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],\n",
    "        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],\n",
    "        block_out_channels=[128, 256, 512, 512],\n",
    "        layers_per_block=2,\n",
    "        sample_size=256,\n",
    "    )\n",
    "    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth')\n",
    "    autoenc.load_state_dict(ckpt)\n",
    "    \n",
    "    autoenc.eval()\n",
    "    autoenc.requires_grad_(False)\n",
    "    autoenc.to(device)\n",
    "    utils.count_params(autoenc)\n",
    "    \n",
    "    from convnext import ConvnextXL\n",
    "    cnx = ConvnextXL(f'{cache_dir}/convnext_xlarge_alpha0.75_fullckpt.pth')\n",
    "    cnx.requires_grad_(False)\n",
    "    cnx.eval()\n",
    "    cnx.to(device)\n",
    "    \n",
    "    mean = torch.tensor([0.485, 0.456, 0.406]).to(device).reshape(1,3,1,1)\n",
    "    std = torch.tensor([0.228, 0.224, 0.225]).to(device).reshape(1,3,1,1)\n",
    "    \n",
    "    blur_augs = AugmentationSequential(\n",
    "        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),\n",
    "        kornia.augmentation.RandomGrayscale(p=0.1),\n",
    "        kornia.augmentation.RandomSolarize(p=0.1),\n",
    "        kornia.augmentation.RandomResizedCrop((224,224), scale=(.9,.9), ratio=(1,1), p=0.3),\n",
    "        data_keys=[\"input\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42db82a0-47f1-4f57-8bed-2464b1247d65",
   "metadata": {},
   "source": [
    "### Clipper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f12643-78ba-4013-9fbc-eb8f2eb952e1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "clip_img_embedder = FrozenOpenCLIPImageEmbedder(\n",
    "    arch=\"ViT-bigG-14\",\n",
    "    version=\"laion2b_s39b_b160k\",\n",
    "    output_tokens=True,\n",
    "    only_tokens=True,\n",
    ")\n",
    "clip_img_embedder.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eac920c-477b-4f5a-8668-5d0262ae93f8",
   "metadata": {},
   "source": [
    "### High Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831cd9ec-8ead-4dae-bb6f-88c36969bfb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#num_voxels_list = [15724, 14278]\n",
    "clip_seq_dim = 256\n",
    "clip_emb_dim = 1664\n",
    "#hidden_dim = 4096"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1131e0f4-d025-4cc0-a014-fa0e2a85842d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class fMRIModule(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(fMRIModule, self).__init__()\n",
    "    def forward(self, x):\n",
    "        return x\n",
    "        \n",
    "model = fMRIModule()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f67e306c-7972-484d-b9d4-b6fecdc06441",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super(RidgeRegression, self).__init__()\n",
    "        self.out_features = out_features\n",
    "        self.linears = torch.nn.ModuleList([\n",
    "                nn.Sequential(\n",
    "                    torch.nn.Linear(input_size, out_features),\n",
    "                    # nn.LayerNorm(out_features),\n",
    "                    # nn.GELU(),\n",
    "                    # nn.Dropout(0.5)\n",
    "                ) for input_size in input_sizes\n",
    "            ])\n",
    "    def forward(self, x, subj_idx):\n",
    "        out = self.linears[subj_idx](x).unsqueeze(1)\n",
    "        return out\n",
    "        \n",
    "model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a0e60c-a044-4354-a439-46e6e359a401",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity check\n",
    "b = torch.randn((2,num_voxels_list[0])).to(device)\n",
    "print(b.shape)\n",
    "print(b.shape, model.ridge(b,0).shape, b[:,0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "959e1d53-7c81-4de0-9396-c21aaa21b04f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from v2_models import BrainNetwork\n",
    "#Set in_dim as anything if you dont care about the lin0 module\n",
    "voxel2clip_kwargs = dict(in_dim=hidden_dim, seq_len=1, h=hidden_dim, out_dim=clip_emb_dim*clip_seq_dim,\\\n",
    "                         clip_size=clip_emb_dim, clip_scale=1, blurry_recon=blurry_recon, n_blocks=4)\n",
    "voxel2clip = BrainNetwork(**voxel2clip_kwargs)\n",
    "model.voxel2clip = voxel2clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "721fbc5e-f3ed-48bc-9b77-cf0b5b719ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.count_params(model.voxel2clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb057575-3212-4143-8ca4-27c1c3361950",
   "metadata": {},
   "outputs": [],
   "source": [
    "from v2_models import *\n",
    "\n",
    "out_dim = clip_emb_dim\n",
    "depth = 6\n",
    "dim_head = 52\n",
    "heads = clip_emb_dim//52 # heads * dim_head = clip_emb_dim\n",
    "timesteps = 100\n",
    "\n",
    "prior_network = PriorNetwork(\n",
    "        dim=out_dim,\n",
    "        depth=depth,\n",
    "        dim_head=dim_head,\n",
    "        heads=heads,\n",
    "        causal=False,\n",
    "        num_tokens = clip_seq_dim,\n",
    "        learned_query_mode=\"pos_emb\"\n",
    "    )\n",
    "\n",
    "model.diffusion_prior = BrainDiffusionPrior(\n",
    "    net=prior_network,\n",
    "    image_embed_dim=out_dim,\n",
    "    condition_on_text_encodings=False,\n",
    "    timesteps=timesteps,\n",
    "    cond_drop_prob=0.2,\n",
    "    image_embed_scale=None,\n",
    ")\n",
    "\n",
    "utils.count_params(model.diffusion_prior)\n",
    "utils.count_params(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9bb66ae-e816-45a6-b2bd-72fe637b0512",
   "metadata": {},
   "source": [
    "## Training Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e985f33a-88e4-45e9-9850-05fd5a1015cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "if run_common and run_train:\n",
    "    num_iterations_per_epoch = len(common_dl)+len(train_dl)\n",
    "elif run_common:\n",
    "    num_iterations_per_epoch = len(common_dl)\n",
    "else:\n",
    "    num_iterations_per_epoch = len(train_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23829007-ea2a-4b25-828f-43754a94975a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iterations_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3cfe563-81c2-4168-95ce-3ac8ae07612b",
   "metadata": {},
   "outputs": [],
   "source": [
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "\n",
    "opt_grouped_parameters = [\n",
    "    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n",
    "]\n",
    "if v2c:\n",
    "    opt_grouped_parameters.extend([   \n",
    "        {'params': [p for n, p in model.voxel2clip.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
    "        {'params': [p for n, p in model.voxel2clip.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n",
    "    ])\n",
    "if prior:\n",
    "    opt_grouped_parameters.extend([\n",
    "        {'params': [p for n, p in model.diffusion_prior.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
    "        {'params': [p for n, p in model.diffusion_prior.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "    ])\n",
    "\n",
    "print(len(opt_grouped_parameters), lr_scheduler_type)\n",
    "\n",
    "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n",
    "\n",
    "if lr_scheduler_type == 'linear':\n",
    "    lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "        optimizer,\n",
    "        total_iters=int(np.floor(num_epochs*num_iterations_per_epoch)),\n",
    "        last_epoch=-1\n",
    "    )\n",
    "elif lr_scheduler_type == 'cycle':\n",
    "    total_steps=int(np.floor(num_epochs*num_iterations_per_epoch))\n",
    "    print(\"total_steps\", total_steps)\n",
    "    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "        optimizer, \n",
    "        max_lr=max_lr,\n",
    "        total_steps=total_steps,\n",
    "        final_div_factor=1000,\n",
    "        last_epoch=-1, pct_start=2/num_epochs\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "412d595d-5501-4471-8196-a7e81e59bf61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_ckpt(tag):\n",
    "    ckpt_path = outdir + f'/{tag}.pth'\n",
    "    print(f'saving {ckpt_path}', flush=True)\n",
    "    try:\n",
    "        torch.save({\n",
    "            'epoch': epoch,\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'lr_scheduler_state_dict': lr_scheduler.state_dict(),  \n",
    "            'train_losses': losses,\n",
    "            'val_losses': val_losses,\n",
    "            'lrs': lrs,\n",
    "            }, ckpt_path)\n",
    "    except Exception as e:  # It's a good practice to catch specific exceptions and possibly log them\n",
    "        print(f\"Couldn't save due to {e}... moving on to prevent crashing.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ece4edf-3029-4cc5-9eff-9035bbe74b23",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Loading {checkpoint_dir}/{checkpoint_tag}.pth\")\n",
    "checkpoint = torch.load(f'{checkpoint_dir}/{checkpoint_tag}.pth', map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75abaf6c-1826-4ede-af96-5d760d39b079",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Set strict to false as you wont have any weights for the new adapter\n",
    "state_dict = checkpoint['model_state_dict']\n",
    "model.load_state_dict(state_dict, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a857ad4-8ee2-4700-8614-c81c330bc2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity check again\n",
    "feature_file = f'data/new_dl/subj{ref_sub:02d}/test/betas.pt'\n",
    "image_file = f'data/new_dl/subj{ref_sub:02d}/test/images.pt'\n",
    "\n",
    "ref_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "ref_dl = DataLoader(ref_dataset, batch_size=16, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffcbeb8c-15fe-4c3b-a7f9-84dd558de366",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.ridge.linears[0].eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for test_i, (voxel, image) in enumerate(ref_dl):\n",
    "        voxel = voxel.to(device)\n",
    "        voxel = torch.mean(voxel, dim=1)\n",
    "        voxel_ridge = model.ridge(voxel,0)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf30d89c-099a-4baf-a73d-4938bba36f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "voxel_ridge[0:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c043dc8-5431-4c16-a664-691a0a59feca",
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter_outputs[0:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdd50e56-1646-4ea4-a963-4e439c904e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for param in model.ridge.linears[0].parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c6eaef-e8ee-4fea-9a4a-20798db4d2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 0\n",
    "losses, val_losses, lrs = [], [], []\n",
    "nce_losses, val_nce_losses = [], []\n",
    "sim_losses, val_sim_losses = [], []\n",
    "best_val_loss = 1e9\n",
    "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
    "if hidden:\n",
    "    prior_mult = 30\n",
    "else:\n",
    "    prior_mult = .03\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40dea10c-064d-4f80-bce3-91d2e03b7cde",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87599d45-402e-49b2-afe9-980cf8109ecb",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24da58c2-1b63-45ca-b798-f7ed1de8cf80",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Model: {model_name}\")\n",
    "print(f\"Reference: {checkpoint_name}/{checkpoint_tag}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "639fe552-c0d8-4ef5-aa9b-a2251ea4b839",
   "metadata": {},
   "source": [
    "### Adapter Alignment Step 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04fa6f74-814a-4ec5-bb0c-ba280ef434c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if adaptalign:\n",
    "    opt_grouped_parameters2 = [\n",
    "        {'params': [p for n, p in model.ridge.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
    "        {'params': [p for n, p in model.ridge.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "    ]\n",
    "    opt_ridge = torch.optim.AdamW(opt_grouped_parameters2, lr=3e-4)\n",
    "    \n",
    "    common_dl2 = DataLoader(common_dataset, batch_size=1000, shuffle=False, num_workers=0)\n",
    "    epoch=0\n",
    "    epochs = 200\n",
    "    mse = nn.MSELoss()\n",
    "    mse_scale=1\n",
    "    for epoch in range(epochs):\n",
    "        loss_adapter_total=0\n",
    "        model.ridge.train()\n",
    "        model.voxel2clip.eval()\n",
    "        model.diffusion_prior.eval()\n",
    "        for common_i, (voxel, image, adapter_gt) in enumerate(common_dl2):\n",
    "            opt_ridge.zero_grad()\n",
    "            loss=0.\n",
    "            voxel = voxel.to(device)\n",
    "            voxel = torch.mean(voxel,axis=1).float()\n",
    "            #print(voxel.shape)\n",
    "            voxel = voxel.to(device)\n",
    "            adapter_gt = adapter_gt.to(device)\n",
    "            voxel_ridge = model.ridge(voxel, 1)\n",
    "            if adaptalign:\n",
    "                loss_adapter = mse(voxel_ridge, adapter_gt)\n",
    "                loss_adapter_total += loss_adapter.item()\n",
    "                loss_adapter *= mse_scale\n",
    "                loss += loss_adapter\n",
    "            \n",
    "            loss.backward()\n",
    "            opt_ridge.step()\n",
    "        \n",
    "            losses.append(loss.item())\n",
    "        print(\"Adapter Loss:\",loss_adapter_total/(common_i+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e195366c-d326-4c5b-afc0-e4127f5562ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# #Optional save this checkpoint if you want to run ablations\n",
    "# save_ckpt(f'aaonly_epoch0')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c50354de-7b2a-4374-b5de-ce70c59c70ad",
   "metadata": {},
   "source": [
    "### End-to-End Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1727c9cc-ee96-4644-8823-ba420c6b2e56",
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch=0\n",
    "print(epoch,ckpt_interval,mixup_pct)\n",
    "max_retries = 5  # Define maximum number of retries\n",
    "counter=0\n",
    "best_epoch=0\n",
    "adapt_epochs=0\n",
    "mse_scale=10\n",
    "mse = nn.MSELoss()\n",
    "l1 = nn.L1Loss()\n",
    "def scale_tau(tau_ref: float, B: int, ref_B: int = 16, alpha: float = 0.5):\n",
    "    # alpha=0.5 (sqrt) is a good default; alpha=1.0 is stronger\n",
    "    return tau_ref * (B / ref_B) ** alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7e8886-c417-43f0-b13e-fd2b6a14bbe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
    "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
    "for epoch in progress_bar:\n",
    "    model.train()\n",
    "\n",
    "    sims_base = 0.\n",
    "    val_sims_base = 0.\n",
    "    recon_cossim=0.\n",
    "    test_recon_cossim=0.\n",
    "    fwd_percent_correct = 0.\n",
    "    bwd_percent_correct = 0.\n",
    "    val_fwd_percent_correct = 0.\n",
    "    val_bwd_percent_correct = 0.\n",
    "    test_loss_cossim_total = 0.\n",
    "    loss_cossim_total = 0.\n",
    "    loss_nce_sum = 0.\n",
    "    loss_prior_sum = 0.\n",
    "    val_loss_nce_sum = 0.\n",
    "    val_loss_prior_sum = 0.\n",
    "    loss_adapter_total = 0.\n",
    "    test_loss_adapter_total = 0.\n",
    "    loss_blurry_total = 0.\n",
    "    loss_blurry_cont_total = 0.\n",
    "\n",
    "    common_i=0\n",
    "    train_i=0\n",
    "    if run_common:\n",
    "        for common_i, (voxel, image, adapter_gt) in enumerate(common_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                if common_i==3:\n",
    "                    print(\"Inside common dl\")\n",
    "                optimizer.zero_grad()\n",
    "                loss=0.\n",
    "                voxel = voxel.to(device)\n",
    "                voxel = torch.mean(voxel,axis=1).float()\n",
    "                #print(voxel.shape)\n",
    "                voxel = voxel.to(device)\n",
    "                adapter_gt = adapter_gt.to(device)                \n",
    "                if (epoch>=adapt_epochs) and epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                else:\n",
    "                    #Make sure this is resize only\n",
    "                    image = img_augment2(image)\n",
    "                image = torch.tensor(image, dtype=data_type).to(device)\n",
    "                clip_target = clip_img_embedder(image).to(device)\n",
    "        \n",
    "                voxel_ridge = model.ridge(voxel, 1)\n",
    "                #print(voxel_ridge.shape, adapter_gt.shape)\n",
    "                clip_voxels, clip_voxels_proj, blurry_image_enc_ = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            # If loss_prior is NaN, clean up before retry\n",
    "                            if torch.isnan(loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if optimizer:\n",
    "                                    optimizer.zero_grad()\n",
    "                                # Clear the memory cache if possible\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "                    \n",
    "                if blurry_recon:     \n",
    "                    image_enc_pred, transformer_feats = blurry_image_enc_\n",
    "    \n",
    "                    image_enc = autoenc.encode(2*image-1).latent_dist.mode() * 0.18215\n",
    "                    image_enc = image_enc.float()\n",
    "                    loss_blurry = l1(image_enc_pred, image_enc)\n",
    "                    loss_blurry_total += loss_blurry.item()\n",
    "    \n",
    "                    if epoch < int(mixup_pct * num_epochs):\n",
    "                        image_enc_shuf = image_enc[perm]\n",
    "                        betas_shape = [-1] + [1]*(len(image_enc.shape)-1)\n",
    "                        image_enc[select] = image_enc[select] * betas[select].reshape(*betas_shape) + \\\n",
    "                            image_enc_shuf[select] * (1 - betas[select]).reshape(*betas_shape)\n",
    "    \n",
    "                    image_norm = (image - mean)/std\n",
    "                    with torch.cuda.amp.autocast(enabled=False):\n",
    "                        image_aug = blur_augs(image.float())\n",
    "                    image_aug = image_aug.clamp_(0, 1)\n",
    "                    image_aug = (image_aug - mean)/std                    \n",
    "                    \n",
    "                    _, cnx_embeds = cnx(image_norm)\n",
    "                    _, cnx_aug_embeds = cnx(image_aug)\n",
    "    \n",
    "                    cont_loss = utils.soft_cont_loss(\n",
    "                        nn.functional.normalize(transformer_feats.reshape(-1, transformer_feats.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_aug_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        temp=0.2)\n",
    "                    loss_blurry_cont_total += cont_loss.item()\n",
    "    \n",
    "                    loss += (loss_blurry + 0.0*cont_loss) * blur_scale #/.18215\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                \n",
    "                \n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=perm, betas=betas, select=select)\n",
    "                else:\n",
    "                    B = clip_voxels_norm.size(0)\n",
    "                    tau_ref = float(soft_loss_temps[epoch -int(mixup_pct*num_epochs)].item())\n",
    "                    epoch_temp = scale_tau(tau_ref, B, ref_B=16, alpha=0.5)\n",
    "                    loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = loss_nce + (prior_mult * loss_prior)\n",
    "                elif v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss = loss_nce\n",
    "                elif prior:\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = prior_mult * loss_prior\n",
    "                utils.check_loss(loss)\n",
    "                loss_adapter = mse(voxel_ridge, adapter_gt)\n",
    "                loss_adapter_total += loss_adapter.item()\n",
    "                if adaptalign:\n",
    "                    loss_adapter *= mse_scale\n",
    "                    loss += loss_adapter\n",
    "\n",
    "                #accelerator.backward(loss)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "        \n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "        \n",
    "                # gather batches across multi-gpu if there's multiple\n",
    "                # clip_voxel_gather = accelerator.gather(clip_voxels_norm.view(len(voxel),-1).contiguous())\n",
    "                # clip_target_gather = accelerator.gather(clip_target_norm.view(len(voxel),-1).contiguous())\n",
    "        \n",
    "                sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                # forward and backward top 1 accuracy\n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "        \n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "    if run_train:\n",
    "        for train_i, (voxel, image) in enumerate(train_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                print(train_i)\n",
    "                optimizer.zero_grad()\n",
    "                loss=0.\n",
    "                voxel = voxel.to(device)\n",
    "                repeat_index = train_i % 3\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                else:\n",
    "                    image = img_augment2(image)\n",
    "                voxel = voxel[:,repeat_index].float()\n",
    "                #print(voxel.shape)\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                image = torch.tensor(image, dtype=data_type).to(device)\n",
    "                clip_target = clip_img_embedder(image).to(device)\n",
    "                \n",
    "                #voxel_ridge = model.ridge(voxel, subj-1)\n",
    "                voxel_ridge = model.ridge(voxel, 1)\n",
    "                clip_voxels, clip_voxels_proj, blurry_image_enc_ = model.voxel2clip(voxel_ridge)\n",
    "    \n",
    "                #clip_voxels, clip_voxels_proj = diffusion_prior.module.voxel2clip(voxel) if distributed else diffusion_prior.voxel2clip(voxel)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "                \n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            #aligned_clip_voxels /= diffusion_prior.module.image_embed_scale if distributed else diffusion_prior.image_embed_scale\n",
    "                            \n",
    "                            # If loss_prior is NaN, clean up before retry\n",
    "                            if torch.isnan(loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if optimizer:\n",
    "                                    optimizer.zero_grad()\n",
    "                                # Clear the memory cache if possible\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "                    \n",
    "                if blurry_recon:     \n",
    "                    image_enc_pred, transformer_feats = blurry_image_enc_\n",
    "    \n",
    "                    image_enc = autoenc.encode(2*image-1).latent_dist.mode() * 0.18215\n",
    "                    image_enc = image_enc.float()\n",
    "                    loss_blurry = l1(image_enc_pred, image_enc)\n",
    "                    loss_blurry_total += loss_blurry.item()\n",
    "    \n",
    "                    if epoch < int(mixup_pct * num_epochs):\n",
    "                        image_enc_shuf = image_enc[perm]\n",
    "                        betas_shape = [-1] + [1]*(len(image_enc.shape)-1)\n",
    "                        image_enc[select] = image_enc[select] * betas[select].reshape(*betas_shape) + \\\n",
    "                            image_enc_shuf[select] * (1 - betas[select]).reshape(*betas_shape)\n",
    "    \n",
    "                    image_norm = (image - mean)/std\n",
    "                    with torch.cuda.amp.autocast(enabled=False):\n",
    "                        image_aug = blur_augs(image.float())\n",
    "                    image_aug = image_aug.clamp_(0, 1)\n",
    "                    image_aug = (image_aug - mean)/std\n",
    "                    \n",
    "                    _, cnx_embeds = cnx(image_norm)\n",
    "                    _, cnx_aug_embeds = cnx(image_aug)\n",
    "    \n",
    "                    cont_loss = utils.soft_cont_loss(\n",
    "                        nn.functional.normalize(transformer_feats.reshape(-1, transformer_feats.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_aug_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        temp=0.2)\n",
    "                    loss_blurry_cont_total += cont_loss.item()\n",
    "    \n",
    "                    loss += (loss_blurry + 0.0*cont_loss) * blur_scale #/.18215\n",
    "    \n",
    "                \n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "        \n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=perm, betas=betas, select=select)\n",
    "                else:\n",
    "                    B = clip_voxels_norm.size(0)\n",
    "                    tau_ref = float(soft_loss_temps[epoch -int(mixup_pct*num_epochs)].item())\n",
    "                    epoch_temp = scale_tau(tau_ref, B, ref_B=16, alpha=0.5)\n",
    "                    loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = loss_nce + (prior_mult * loss_prior)\n",
    "                elif v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss = loss_nce\n",
    "                elif prior:\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = prior_mult * loss_prior\n",
    "                utils.check_loss(loss)\n",
    "                \n",
    "                #accelerator.backward(loss)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "        \n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "                \n",
    "                sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                # forward and backward top 1 accuracy\n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "        \n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "    model.eval()\n",
    "    for val_i, (voxel, image) in enumerate(test_dl):\n",
    "        with torch.no_grad():\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "\n",
    "                voxel = torch.mean(voxel,axis=1).float()\n",
    "                voxel = voxel.to(device)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                else:\n",
    "                    image = img_augment2(image)\n",
    "\n",
    "                image = torch.tensor(image, dtype=data_type).to(device)\n",
    "                clip_target = clip_img_embedder(image).to(device)\n",
    "            \n",
    "                #voxel_ridge = model.ridge(voxel, subj-1)\n",
    "                voxel_ridge = model.ridge(voxel, 1)\n",
    "                \n",
    "                clip_voxels, clip_voxels_proj, blurry_image_enc_ = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel),-1,clip_emb_dim)\n",
    "                \n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            val_loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            if torch.isnan(val_loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            test_recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    val_loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=None, betas=None, select=None)\n",
    "                else:\n",
    "                    val_loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    val_loss_nce_sum += val_loss_nce.item()\n",
    "                    val_loss_prior_sum += val_loss_prior.item()\n",
    "                    val_loss = val_loss_nce + (prior_mult * val_loss_prior)\n",
    "                elif v2c:\n",
    "                    val_loss_nce_sum += val_loss_nce.item()\n",
    "                    val_loss = val_loss_nce\n",
    "                elif prior:\n",
    "                    val_loss_prior_sum += val_loss_prior.item()\n",
    "                    val_loss = prior_mult * val_loss_prior\n",
    "                utils.check_loss(val_loss)\n",
    "                \n",
    "                val_losses.append(val_loss.item())\n",
    "\n",
    "                val_sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                \n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                val_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                val_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "\n",
    "    if local_rank==0:\n",
    "        val_loss = np.mean(val_losses[-(val_i+1):])\n",
    "        if val_loss < best_val_loss:\n",
    "            best_epoch = epoch\n",
    "            best_val_loss = val_loss\n",
    "        if utils.is_interactive():\n",
    "            clear_output(wait=True)\n",
    "            \n",
    "        logs = {\"train/loss\": np.mean(losses[-(train_i+common_i+1):]),\n",
    "            \"val/loss\": np.mean(val_losses[-(val_i+1):]),\n",
    "            \"train/lr\": lrs[-1],\n",
    "            \"train/num_steps\": len(losses),\n",
    "            \"val/num_steps\": len(val_losses),\n",
    "            \"train/cosine_sim_base\": sims_base / (train_i + common_i+1),\n",
    "            \"val/cosine_sim_base\": val_sims_base / (val_i + 1),\n",
    "            \"train/cosine_sim_prior\": recon_cossim / (train_i + common_i+1),\n",
    "            \"val/cosine_sim_prior\": test_recon_cossim / (val_i + 1),\n",
    "            \"train/loss_blurry_total\": loss_blurry_total / (train_i + common_i + 1),\n",
    "            \"train/loss_blurry_cont_total\": loss_blurry_cont_total / (train_i + common_i + 1),\n",
    "            \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + common_i+1),\n",
    "            \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + common_i+1),\n",
    "            \"val/val_fwd_pct_correct\": val_fwd_percent_correct / (val_i + 1),\n",
    "            \"val/val_bwd_pct_correct\": val_bwd_percent_correct / (val_i + 1),\n",
    "            \"train/adapter_mse\": loss_adapter_total / (common_i + 1),\n",
    "            \"train/loss_nce\": loss_nce_sum / (train_i + common_i+1),\n",
    "            \"train/loss_prior\": loss_prior_sum / (train_i + common_i+ 1),\n",
    "            \"val/loss_nce\": val_loss_nce_sum / (val_i + 1),\n",
    "            \"val/loss_prior\": val_loss_prior_sum / (val_i + 1)}\n",
    "        progress_bar.set_postfix(**logs)\n",
    "\n",
    "        # Save model checkpoint and reconstruct\n",
    "        #save_ckpt(f'last')\n",
    "        if epoch % ckpt_interval == 0:\n",
    "            save_ckpt(f'mid_{epoch}')\n",
    "\n",
    "print(\"\\n===Finished!===\\n\")\n",
    "print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f} at epoch: {best_epoch}')\n",
    "if not utils.is_interactive():\n",
    "    sys.exit(0)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3399f6c6-b974-48fd-ab64-610d95b5710c",
   "metadata": {},
   "outputs": [],
   "source": [
    "logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1b52e84-79f2-4247-a249-a417a75c702e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
