{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab08bd8-f1d7-4f20-94f5-d70af95d8329",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import gc\n",
    "import torch.nn.functional as F\n",
    "from dataloaders import ImageVoxelDataset,ImageVoxelAdapterDataset\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "from torch.utils.data import Subset\n",
    "\n",
    "import utils\n",
    "from models import Clipper, BrainNetwork, BrainDiffusionPrior, BrainDiffusionPriorOld, VersatileDiffusionPriorNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a90635af-e6b3-4014-bda8-c5c2832d8671",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from IPython.display import clear_output # function to clear print outputs in cell\n",
    "# Loading the extension alone does nothing; mode 2 reloads edited modules (models, utils, dataloaders)\n",
    "# before each cell executes\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f627ab-5c2d-4211-b4cc-a7ac7f7800b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precision and (single-process) distributed settings\n",
    "data_type = torch.float32  # dtype handed to autocast; change depending on your mixed precision\n",
    "local_rank, world_size = 0, 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8fcc416-6608-4e8c-8b83-f8c3b3fb07a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Whether a CUDA device is visible to this kernel\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12caad3b-83e0-499f-9bd3-73b5fb217834",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First GPU when available, otherwise CPU\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "714a5333-40ab-48b2-9aa8-5678436c7f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index files available for `index_name` (directory = the checkpoint_name set in the next cell)\n",
    "!ls indices/subj1_nl_sclip_basictest2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f4b588-bac0-47cf-a2da-204953c66284",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "model_name='subj1_nl_sclip_basictest2_finetune_s5_aamax_iselect20_100'\n",
    "data_path='data/new_dl/'\n",
    "checkpoint_name = 'subj1_nl_sclip_basictest2'  # pretrained run on the reference subject\n",
    "subj=5  # new subject whose adapter is trained\n",
    "ref_sub=1  # reference subject of the pretrained checkpoint\n",
    "batch_size=16\n",
    "hidden=True\n",
    "mixup_pct=0.0\n",
    "norm_embs=True\n",
    "adaptalign=True\n",
    "use_image_aug=False\n",
    "num_epochs=201\n",
    "prior=True\n",
    "v2c=True\n",
    "lr_scheduler_type='cycle'\n",
    "ckpt_saving=True\n",
    "ckpt_interval=200\n",
    "run_common=True\n",
    "run_train=False #Set this to false if you are training with limited data/on common images only\n",
    "seed=42\n",
    "max_lr=6e-5\n",
    "use_projector=True\n",
    "cache_dir='data/fmri/cache'\n",
    "hidden_dim=4096\n",
    "checkpoint_dir = os.path.abspath(f'train_logs/{checkpoint_name}')\n",
    "checkpoint_tag = 'mid_200'\n",
    "load_indices=True\n",
    "clip_seq_dim = 257\n",
    "clip_emb_dim = 768\n",
    "clip_size = clip_emb_dim\n",
    "index_name = 'subj1_indices_20d_50w_99.npy'  # file in indices/<checkpoint_name>/ selecting common images\n",
    "clip_variant='ViT-L/14'\n",
    "\n",
    "# Apply the seed: numpy and torch (CPU and all CUDA devices), so DataLoader shuffling\n",
    "# and other random draws are reproducible across runs\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de9dbddc-f17d-4e07-97a2-dd8060555d30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel count of each subject (1-8)\n",
    "NUM_VOXELS = {1: 15724, 2: 14278, 3: 15226, 4: 13153,\n",
    "              5: 13039, 6: 17907, 7: 12682, 8: 14386}\n",
    "if subj not in NUM_VOXELS or ref_sub not in NUM_VOXELS:\n",
    "    raise ValueError(f'subj={subj} and ref_sub={ref_sub} must both be in {sorted(NUM_VOXELS)}')\n",
    "num_voxels = NUM_VOXELS[subj]\n",
    "# Adapter input sizes: index 0 = reference subject (pretrained), index 1 = new subject.\n",
    "# Derived from ref_sub, so changing the reference subject needs no edit here.\n",
    "num_voxels_list = [NUM_VOXELS[ref_sub], num_voxels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cca2060-ba09-457d-acf6-4b992dcf48c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "outdir = os.path.abspath(f'train_logs/{model_name}')\n",
    "if not os.path.exists(outdir):\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "if use_image_aug:\n",
    "    # kornia is only imported when augmentation is enabled\n",
    "    import kornia\n",
    "    from kornia.augmentation.container import AugmentationSequential\n",
    "    kaug = kornia.augmentation\n",
    "    augmentations = [\n",
    "        kaug.RandomResizedCrop((224, 224), (0.6, 1), p=0.3),\n",
    "        kaug.Resize((224, 224)),\n",
    "        kaug.RandomHorizontalFlip(p=0.5),\n",
    "        kaug.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
    "        kaug.RandomGrayscale(p=0.3),\n",
    "    ]\n",
    "    # Applied to the stimulus images right before CLIP embedding in the training loop\n",
    "    img_augment = AugmentationSequential(*augmentations, data_keys=[\"input\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbe144bc-7ada-4ed5-8b6a-4b4b2477ae5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel count of the new subject (input size of its adapter)\n",
    "num_voxels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26ee2a20-607a-4502-bc8a-7fc6ed25aa23",
   "metadata": {},
   "source": [
    "## Build Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f42c0c8-bb7b-45d4-b1a6-ee4d15e4f94c",
   "metadata": {},
   "source": [
    "### Adapter Ground Truth from the Reference Subject (`ref_sub`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "618b0425-4151-4a46-867d-16df6a5ff9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    \"\"\"One adapter per subject, mapping that subject's voxels into a shared space.\n",
    "\n",
    "    Despite the name, each adapter is Linear -> LayerNorm -> GELU trained by gradient\n",
    "    descent, not a closed-form ridge solution.\n",
    "\n",
    "    Args:\n",
    "        input_sizes: voxel count per subject; adapter i takes input_sizes[i] features.\n",
    "        out_features: width of the shared output space.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        # The layer order inside each Sequential is part of the state_dict keys\n",
    "        # ('linears.<i>.0.weight', ...), so it must match the pretrained checkpoint.\n",
    "        adapters = []\n",
    "        for n_in in input_sizes:\n",
    "            adapters.append(nn.Sequential(\n",
    "                torch.nn.Linear(n_in, out_features),\n",
    "                nn.LayerNorm(out_features),\n",
    "                nn.GELU(),\n",
    "            ))\n",
    "        self.linears = torch.nn.ModuleList(adapters)\n",
    "\n",
    "    def forward(self, x, subj_idx):\n",
    "        \"\"\"Apply adapter `subj_idx`; x of shape (batch, n_voxels) gives (batch, 1, out_features).\"\"\"\n",
    "        return self.linears[subj_idx](x).unsqueeze(1)\n",
    "\n",
    "# Reference-subject adapter only; its weights are loaded from the pretrained checkpoint below\n",
    "ridge = RidgeRegression([num_voxels_list[0]], out_features=hidden_dim).to(device)\n",
    "ridge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb295cc-4eaf-4620-b312-ec54f1f0053c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contents of the pretrained run's log directory\n",
    "!ls train_logs/$checkpoint_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35a23ef2-57d5-4742-9485-7befb4179ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Loading {checkpoint_dir}/{checkpoint_tag}.pth\")\n",
    "checkpoint = torch.load(f\"{checkpoint_dir}/{checkpoint_tag}.pth\", map_location='cpu')\n",
    "# Keep only the adapter weights and strip the leading 'ridge.' prefix (str.replace would\n",
    "# also rewrite a 'ridge.' occurring later inside a key)\n",
    "ridge_prefix = 'ridge.'\n",
    "ridge_state_dict = {k[len(ridge_prefix):]: v for k, v in checkpoint['model_state_dict'].items() if k.startswith(ridge_prefix)}\n",
    "for k, v in ridge_state_dict.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a3fab9-bad0-4409-9487-8b346ee361e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# strict=True: the extracted keys must match the single-adapter `ridge` exactly\n",
    "ridge.load_state_dict(ridge_state_dict, strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a405d23e-4d29-4ec6-a109-f170471cd30e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the CPU copy of the checkpoint (it is reloaded later for the full model)\n",
    "del checkpoint\n",
    "del ridge_state_dict\n",
    "gc.collect();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40a17b3d-bf8f-437d-88a0-35c38926a91f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference subject's common images (test folder), used to compute the adapter targets\n",
    "feature_file = os.path.join(data_path, f'subj{ref_sub:02d}', 'test', 'betas.pt')\n",
    "image_file = os.path.join(data_path, f'subj{ref_sub:02d}', 'test', 'images.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7091f68-e4d0-48b8-a338-f6c9e6a53384",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "# shuffle=False is required: adapter_outputs rows follow the dataset order and are later\n",
    "# matched to the new subject by image id (presumably image_ids.npy uses the same row order)\n",
    "val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2828b01-657c-4217-a8ae-440b25f6b9bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Targets for the new subject's adapter: the reference adapter's outputs on the reference\n",
    "# subject's common images, in val_dl order\n",
    "ridge.eval()\n",
    "adapter_batches = []\n",
    "with torch.no_grad():\n",
    "    for voxel, _image in val_dl:\n",
    "        # mean over axis 1 -> (batch, n_voxels); presumably averages repeated presentations (TODO confirm)\n",
    "        voxel = torch.mean(voxel.to(device), dim=1).float()\n",
    "        adapter_batches.append(ridge(voxel, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc41b8dc-c9e3-4e7f-8bb2-e1baae58aa57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Built from adapter_batches rather than overwriting its own input, so this cell can be re-run\n",
    "adapter_outputs = torch.cat(adapter_batches)\n",
    "assert len(adapter_outputs) == len(val_dataset), (len(adapter_outputs), len(val_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d0d43a-7a21-4634-9191-4768d17703d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (n_common_images, 1, hidden_dim)\n",
    "adapter_outputs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15468e38-5b04-4d2d-8dc4-4c227c4f43c2",
   "metadata": {},
   "source": [
    "### Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8da947c-2d7a-4ad1-ae5b-882919b16c80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In case the mapping between subjects is not clean: row of each image id in the reference\n",
    "# subject's common-image data\n",
    "ref_image_ids = np.load(os.path.join(data_path, f'subj{ref_sub:02d}', 'test', 'image_ids.npy'))\n",
    "refid_to_index = {img_id: idx for idx, img_id in enumerate(ref_image_ids)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1727f871-b2e6-438d-bd1a-6b832039bcde",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_image_ids = np.load(os.path.join(data_path, f'subj{subj:02d}', 'test', 'image_ids.npy'))\n",
    "missing_ids = [img_id for img_id in new_image_ids if img_id not in refid_to_index]\n",
    "if missing_ids:\n",
    "    raise KeyError(f'{len(missing_ids)} image ids of subj{subj:02d} are not in subj{ref_sub:02d}, e.g. {missing_ids[:5]}')\n",
    "# new_to_ref_indices[i] = reference row showing the same image as the new subject's row i\n",
    "new_to_ref_indices = [refid_to_index[img_id] for img_id in new_image_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa12a039-f443-4a61-9813-e771a917d8be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of images already at the same position for both subjects (only comparable at equal length)\n",
    "n_same_position = np.count_nonzero(new_image_ids == ref_image_ids) if len(new_image_ids) == len(ref_image_ids) else 0\n",
    "n_same_position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ffd5a34-ea77-4fee-a674-470cbb716359",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Common DL: the new subject's common images paired with the reference adapter targets\n",
    "if run_common:\n",
    "    feature_file = os.path.join(data_path, f'subj{subj:02d}', 'test', 'betas.pt')\n",
    "    image_file = os.path.join(data_path, f'subj{subj:02d}', 'test', 'images.pt')\n",
    "\n",
    "    # Reorder the targets into this subject's image order; needed with or without indices, since the\n",
    "    # dataset presumably pairs row i of betas with row i of the targets. A new name keeps\n",
    "    # adapter_outputs in reference order, so re-running this cell does not permute it twice.\n",
    "    adapter_outputs_subj = adapter_outputs[new_to_ref_indices]\n",
    "\n",
    "    if load_indices:\n",
    "        # Keep the first 100 entries of the saved index list\n",
    "        indices = np.load(os.path.join('indices', checkpoint_name, index_name))[:100]\n",
    "        common_dataset = ImageVoxelAdapterDataset(feature_file, image_file, adapter_outputs_subj, indices=indices)\n",
    "    else:\n",
    "        #No indices here as you load all 1k common images\n",
    "        common_dataset = ImageVoxelAdapterDataset(feature_file, image_file, adapter_outputs_subj)\n",
    "\n",
    "    common_dl = DataLoader(common_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "    print(len(common_dataset))\n",
    "    voxel,image,adapter = next(iter(common_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())\n",
    "    print(adapter.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8914a9f4-5ba8-47b3-b887-ca520f56e9e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test DL: custom test split carved out of the new subject's train folder\n",
    "\n",
    "feature_file = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'test', 'betas.pt')\n",
    "image_file = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'test', 'images.pt')\n",
    "\n",
    "test_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "print(len(test_dataset))\n",
    "voxel,image = next(iter(test_dl))\n",
    "print(voxel.shape, image.shape)\n",
    "print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3b3b3b2-c45b-4a02-9e53-066f239c8c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Train DL (optional): first N samples of the custom train split\n",
    "if run_train:\n",
    "    split_dir = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'train')\n",
    "    feature_file = os.path.join(split_dir, 'betas.pt')\n",
    "    image_file = os.path.join(split_dir, 'images.pt')\n",
    "\n",
    "    train_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "    N = 3000  # cap on the number of training samples\n",
    "    # Clamp so a split smaller than N does not produce out-of-range Subset indices\n",
    "    train_subset = Subset(train_dataset, list(range(min(N, len(train_dataset)))))\n",
    "    train_dl = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "    print(len(train_subset))\n",
    "    voxel,image = next(iter(train_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68b4992a-2d5f-467c-88e7-a9ecf5ef92c6",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36aad93f-96f4-4915-a3e5-9057161e6582",
   "metadata": {},
   "source": [
    "### Clipper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eef37bcc-99c5-4515-b582-faae6d092e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Creating Clipper...')\n",
    "# hidden=True: per-token hidden-state CLIP space (257 x clip_size); otherwise the final-layer embedding\n",
    "if hidden:\n",
    "    print(\"Using hidden layer CLIP space (Versatile Diffusion)\")\n",
    "    if not norm_embs:\n",
    "        print(\"WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!\")\n",
    "    out_dim = 257 * clip_size\n",
    "else:\n",
    "    print(\"Using final layer CLIP space (Stable Diffusion Img Variations)\")\n",
    "    if norm_embs:\n",
    "        print(\"WARNING: YOU WANT UN-NORMED EMBEDDINGS FOR IMG VARIATIONS!\")\n",
    "    out_dim = clip_size\n",
    "clip_extractor = Clipper(clip_variant, device=device, hidden_state=bool(hidden), norm_embs=norm_embs)\n",
    "print(\"out_dim:\", out_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cac1419-ac8e-437b-ae70-a351b4ede1db",
   "metadata": {},
   "source": [
    "### High-Level Model (ridge adapters, voxel2clip, diffusion prior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c36d54-12a0-4381-85d9-7d70b56da2fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class fMRIModule(nn.Module):\n",
    "    \"\"\"Container for the pipeline; ridge, voxel2clip and diffusion_prior are attached as attributes below.\n",
    "\n",
    "    forward simply returns its input.\n",
    "    \"\"\"\n",
    "    def forward(self, x):\n",
    "        return x\n",
    "\n",
    "model = fMRIModule()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27d7e574-a759-425e-8a01-2c60cd297900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuses the RidgeRegression class defined in the adapter ground-truth section; redefining it here\n",
    "# would silently shadow that definition. Adapter 0 = reference subject, adapter 1 = new subject.\n",
    "model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d858118-8b15-41c6-ac2f-d599753a6f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity test: a random batch through the reference adapter gives (batch, 1, hidden_dim)\n",
    "b = torch.randn((2, num_voxels_list[0])).to(device)\n",
    "b_ridge = model.ridge(b, 0)\n",
    "assert b_ridge.shape == (2, 1, hidden_dim), b_ridge.shape\n",
    "print(b.shape, b_ridge.shape, b[:, 0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9dc66e8-1952-4a7b-b10d-a4b139bab8fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# in_dim only matters for the lin0 module; any value works if you don't use it\n",
    "voxel2clip_kwargs = dict(\n",
    "    in_dim=hidden_dim,\n",
    "    out_dim=clip_emb_dim * clip_seq_dim,\n",
    "    clip_size=clip_emb_dim,\n",
    "    use_projector=use_projector,\n",
    "    ext_ridge=True,\n",
    ")\n",
    "voxel2clip = BrainNetwork(**voxel2clip_kwargs)\n",
    "model.voxel2clip = voxel2clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75e1155e-c333-41ad-bb1d-67630531d362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior network over the CLIP hidden tokens (the non-hidden ME1 variant is never used here)\n",
    "out_dim = clip_emb_dim\n",
    "depth = 6\n",
    "dim_head = 64\n",
    "heads = clip_emb_dim // dim_head  # 768 // 64 = 12 heads\n",
    "guidance_scale = 3.5\n",
    "timesteps = 100\n",
    "\n",
    "prior_network = VersatileDiffusionPriorNetwork(\n",
    "    dim=out_dim, depth=depth, dim_head=dim_head, heads=heads,\n",
    "    causal=False, num_tokens=clip_seq_dim, learned_query_mode=\"pos_emb\",\n",
    ").to(device)\n",
    "print(\"prior_network loaded\")\n",
    "\n",
    "# custom version that can fix seeds\n",
    "diffusion_prior = BrainDiffusionPrior(\n",
    "    net=prior_network,\n",
    "    image_embed_dim=out_dim,\n",
    "    condition_on_text_encodings=False,\n",
    "    timesteps=timesteps,\n",
    "    cond_drop_prob=0.2,\n",
    "    image_embed_scale=None,\n",
    ").to(device)\n",
    "model.diffusion_prior = diffusion_prior\n",
    "\n",
    "print(\"params of diffusion prior:\")\n",
    "if local_rank == 0:\n",
    "    utils.count_params(model.diffusion_prior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e7ee58c-40cf-4de7-abdb-2f6ac4f03b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter count of the whole container (ridge + voxel2clip + diffusion prior)\n",
    "utils.count_params(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759bb9e4-e125-478f-a3de-4435f6401b44",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Full module tree\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57d9c17a-86f5-45bd-857f-96f9dbc29df4",
   "metadata": {},
   "source": [
    "## Training Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51068dff-9dd7-4a95-8b50-451baddd15ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer steps per epoch: one per batch of every enabled loader\n",
    "if not (run_common or run_train):\n",
    "    raise ValueError('Enable at least one of run_common / run_train')\n",
    "num_iterations_per_epoch = (len(common_dl) if run_common else 0) + (len(train_dl) if run_train else 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7bd521e-264d-4eab-af3b-07ac9ada8b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iterations_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e70bb3-7bd2-4cfa-9afa-89f60a18bc75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AdamW param groups: weight decay 1e-2 except for parameters whose name contains a no_decay entry.\n",
    "# NOTE(review): the ridge LayerNorm params are named like 'linears.0.1.weight', so the 'LayerNorm.*'\n",
    "# patterns (a Hugging Face naming convention) likely never match them and they ARE decayed. Confirm intent.\n",
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "\n",
    "def decay_param_groups(module, weight_decay=1e-2):\n",
    "    \"\"\"Return [decayed group, no-decay group] for `module`'s parameters, split by `no_decay`.\"\"\"\n",
    "    named = list(module.named_parameters())\n",
    "    return [\n",
    "        {'params': [p for n, p in named if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},\n",
    "        {'params': [p for n, p in named if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n",
    "    ]\n",
    "\n",
    "opt_grouped_parameters = decay_param_groups(model.ridge)\n",
    "if v2c:\n",
    "    opt_grouped_parameters += decay_param_groups(model.voxel2clip)\n",
    "if prior:\n",
    "    opt_grouped_parameters += decay_param_groups(model.diffusion_prior)\n",
    "\n",
    "print(len(opt_grouped_parameters), lr_scheduler_type)\n",
    "\n",
    "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n",
    "\n",
    "total_steps = int(np.floor(num_epochs * num_iterations_per_epoch))\n",
    "if lr_scheduler_type == 'linear':\n",
    "    # NOTE(review): LinearLR defaults (start_factor=1/3, end_factor=1.0) ramp the LR *up* to max_lr;\n",
    "    # pass end_factor explicitly if a decay was intended.\n",
    "    lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "        optimizer,\n",
    "        total_iters=total_steps,\n",
    "        last_epoch=-1\n",
    "    )\n",
    "elif lr_scheduler_type == 'cycle':\n",
    "    print(\"total_steps\", total_steps)\n",
    "    # pct_start=2/num_epochs: warm up over the first two epochs' worth of steps, then anneal\n",
    "    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "        optimizer,\n",
    "        max_lr=max_lr,\n",
    "        total_steps=total_steps,\n",
    "        final_div_factor=1000,\n",
    "        last_epoch=-1, pct_start=2/num_epochs\n",
    "    )\n",
    "else:\n",
    "    raise ValueError(f\"Unknown lr_scheduler_type {lr_scheduler_type!r}; expected 'linear' or 'cycle'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "242827a9-156e-417e-a712-d51a89f2258a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_ckpt(tag):\n",
    "    \"\"\"Best-effort save of model/optimizer/scheduler state and loss histories to `<outdir>/<tag>.pth`.\n",
    "\n",
    "    Reads the globals epoch, model, optimizer, lr_scheduler, losses, val_losses and lrs.\n",
    "    Any error is printed rather than raised so a long training run keeps going.\n",
    "    \"\"\"\n",
    "    ckpt_path = f'{outdir}/{tag}.pth'\n",
    "    print(f'saving {ckpt_path}', flush=True)\n",
    "    try:\n",
    "        state = {\n",
    "            'epoch': epoch,\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'lr_scheduler_state_dict': lr_scheduler.state_dict(),\n",
    "            'train_losses': losses,\n",
    "            'val_losses': val_losses,\n",
    "            'lrs': lrs,\n",
    "        }\n",
    "        torch.save(state, ckpt_path)\n",
    "    except Exception as e:\n",
    "        print(f\"Couldn't save due to {e}... moving on to prevent crashing.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f36256f4-aec2-4198-9daa-6c72aad4d74d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the reference checkpoint (freed earlier) to initialise the whole model\n",
    "print(f\"Loading {checkpoint_dir}/{checkpoint_tag}.pth\")\n",
    "checkpoint = torch.load(f'{checkpoint_dir}/{checkpoint_tag}.pth', map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e4c054-d475-426b-b5f4-eb0d599d910a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# strict=False because the checkpoint has no weights for the new subject's adapter (ridge.linears.1).\n",
    "# Any other mismatch is reported instead of being silently ignored.\n",
    "state_dict = checkpoint['model_state_dict']\n",
    "load_result = model.load_state_dict(state_dict, strict=False)\n",
    "other_missing = [k for k in load_result.missing_keys if not k.startswith('ridge.linears.1.')]\n",
    "if other_missing or load_result.unexpected_keys:\n",
    "    print('WARNING: checkpoint mismatch beyond the new adapter -',\n",
    "          'missing:', other_missing, 'unexpected:', load_result.unexpected_keys)\n",
    "load_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28546fe5-90c5-48ae-b7a4-dd6758b57fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity check again: model.ridge adapter 0 (from the full checkpoint) should reproduce adapter_outputs\n",
    "feature_file = os.path.join(data_path, f'subj{ref_sub:02d}', 'test', 'betas.pt')\n",
    "image_file = os.path.join(data_path, f'subj{ref_sub:02d}', 'test', 'images.pt')\n",
    "\n",
    "ref_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "ref_dl = DataLoader(ref_dataset, batch_size=16, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f25e423-59a1-469d-bdc2-919e1b0899ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.ridge.linears[0].eval()\n",
    "\n",
    "# First batch only, with the same preprocessing used for adapter_outputs (mean over axis 1, float32)\n",
    "with torch.no_grad():\n",
    "    voxel, image = next(iter(ref_dl))\n",
    "    voxel = torch.mean(voxel.to(device), dim=1).float()\n",
    "    voxel_ridge = model.ridge(voxel, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c1e757-4276-4ad1-b1e5-ac5bef9c7648",
   "metadata": {},
   "outputs": [],
   "source": [
    "voxel_ridge[0:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c914b13c-c85a-4009-a578-eb2e8098ee28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Max |difference| to the precomputed targets: ~0 as long as adapter_outputs is still in the\n",
    "# reference subject's image order\n",
    "print('max abs diff:', (voxel_ridge - adapter_outputs[:len(voxel_ridge)]).abs().max().item())\n",
    "adapter_outputs[0:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6483df6f-81b5-4138-a8ab-de2eb8d8d821",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the reference subject's adapter so training only updates the other optimized modules\n",
    "model.ridge.linears[0].requires_grad_(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5990a997-612b-4929-a0bd-42d9f35d462d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-state bookkeeping\n",
    "epoch = 0\n",
    "losses, val_losses, lrs = [], [], []\n",
    "nce_losses, val_nce_losses = [], []\n",
    "sim_losses, val_sim_losses = [], []\n",
    "best_val_loss = 1e9\n",
    "n_soft_epochs = num_epochs - int(mixup_pct * num_epochs)  # epochs after the mixup phase\n",
    "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, n_soft_epochs)\n",
    "prior_mult = 30 if hidden else .03\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c89983f1-5f41-40dd-9b5e-3d385654fba6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Move every sub-module to the device (voxel2clip was created on CPU)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2f53828-782f-4d4f-ae45-f1a611606e68",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e3e388-77c4-4755-bdc9-800e6453823e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identify the run being trained and the checkpoint it starts from\n",
    "for label, value in (('Model', model_name), ('Reference', f'{checkpoint_name}/{checkpoint_tag}')):\n",
    "    print(f'{label}: {value}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29b81ab3-9107-4fc8-b640-cbae123ffb8c",
   "metadata": {},
   "source": [
    "### Adapter Alignment, Step 1: Fit the New Subject's Adapter to the Reference Adapter Outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3db197d1-beaf-4f6d-97be-b9c6e5b2983a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Step 1: train only the new subject's adapter (model.ridge index 1; index 0 is frozen) so its\n",
    "# output matches the reference adapter's output on the same common images (MSE loss)\n",
    "if adaptalign:\n",
    "    opt_grouped_parameters2 = [\n",
    "        {'params': [p for n, p in model.ridge.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
    "        {'params': [p for n, p in model.ridge.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "    ]\n",
    "    opt_ridge = torch.optim.AdamW(opt_grouped_parameters2, lr=3e-3)\n",
    "\n",
    "    # batch_size=1000: the whole common set is one batch when it has <= 1000 images\n",
    "    common_dl2 = DataLoader(common_dataset, batch_size=1000, shuffle=False, num_workers=0)\n",
    "    epochs = 400\n",
    "    mse = nn.MSELoss()\n",
    "    mse_scale = 1\n",
    "    # Separate history: `losses` is what save_ckpt stores as the end-to-end training loss.\n",
    "    # The loop variable is aa_epoch so the global `epoch` read by save_ckpt is not overwritten.\n",
    "    adapter_losses = []\n",
    "    model.ridge.train()\n",
    "    model.voxel2clip.eval()\n",
    "    model.diffusion_prior.eval()\n",
    "    for aa_epoch in range(epochs):\n",
    "        loss_adapter_total = 0.\n",
    "        n_batches = 0\n",
    "        for voxel, image, adapter_gt in common_dl2:\n",
    "            opt_ridge.zero_grad()\n",
    "            voxel = torch.mean(voxel.to(device), dim=1).float()\n",
    "            adapter_gt = adapter_gt.to(device)\n",
    "            voxel_ridge = model.ridge(voxel, 1)\n",
    "            loss_adapter = mse(voxel_ridge, adapter_gt)\n",
    "            loss_adapter_total += loss_adapter.item()\n",
    "            loss = mse_scale * loss_adapter\n",
    "            loss.backward()\n",
    "            opt_ridge.step()\n",
    "            adapter_losses.append(loss.item())\n",
    "            n_batches += 1\n",
    "        # Log the first epoch, every 50th and the last instead of all 400\n",
    "        if aa_epoch == 0 or (aa_epoch + 1) % 50 == 0 or aa_epoch == epochs - 1:\n",
    "            print(f'Adapter epoch {aa_epoch + 1}/{epochs} - Adapter Loss:', loss_adapter_total / max(n_batches, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83f0bcb8-b785-47d3-9b6f-d70b04a8a7bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Targets of the last alignment batch (defined only when adaptalign ran)\n",
    "adapter_gt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dc63ffa-dba3-4a8e-88c4-1d9e0096bb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# #Optional save this checkpoint if you want to run ablations\n",
    "#save_ckpt(f'aaonly_epoch0')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3783e63c-2989-4d64-8800-c4ba2eff5558",
   "metadata": {},
   "source": [
    "### End-to-End Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281b507a-afc5-4025-8b24-01dfaf0b36b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end training state\n",
    "epoch = 0\n",
    "print(epoch, ckpt_interval, mixup_pct)\n",
    "max_retries = 5  # attempts at the prior forward pass when its loss comes out NaN\n",
    "counter = 0  # running count of NaN prior losses\n",
    "best_epoch = 0\n",
    "adapt_epochs = 0\n",
    "mse_scale = 10\n",
    "mse, l1 = nn.MSELoss(), nn.L1Loss()\n",
    "\n",
    "def scale_tau(tau_ref: float, B: int, ref_B: int = 16, alpha: float = 0.5) -> float:\n",
    "    \"\"\"Rescale a temperature tuned at batch size `ref_B` to batch size `B`.\n",
    "\n",
    "    Returns tau_ref * (B / ref_B) ** alpha. alpha=0.5 (sqrt) is a good default; alpha=1.0 is stronger.\n",
    "    \"\"\"\n",
    "    batch_ratio = B / ref_B\n",
    "    return tau_ref * batch_ratio ** alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d8fd54f-3ec2-4c37-a9a8-ebb60f7d5444",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
    "# ---- Helpers shared by the common, train and validation loops ----\n",
    "# They read notebook globals (model, epoch, prior, v2c, ...) at call time.\n",
    "\n",
    "def _prior_with_nan_retry(clip_voxels, clip_target):\n",
    "    \"\"\"Run model.diffusion_prior, re-running it while its loss is NaN.\n",
    "\n",
    "    Each NaN attempt increments the global `counter`; between attempts the\n",
    "    CUDA cache is emptied and garbage is collected.\n",
    "\n",
    "    Returns:\n",
    "        (loss_prior, aligned_clip_voxels) from the first non-NaN attempt.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if all `max_retries` attempts produced a NaN loss.\n",
    "    \"\"\"\n",
    "    global counter\n",
    "    for attempt in range(max_retries):\n",
    "        loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "        if not torch.isnan(loss_prior).any():\n",
    "            return loss_prior, aligned_clip_voxels\n",
    "        counter += 1\n",
    "        if attempt < max_retries - 1:\n",
    "            print(f\"Encountered NaN in loss_prior on attempt {attempt+1}, clearing cache and retrying...\")\n",
    "            if torch.cuda.is_available():\n",
    "                torch.cuda.empty_cache()\n",
    "            gc.collect()\n",
    "    raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\")\n",
    "\n",
    "\n",
    "def _nce_loss(voxels_norm, target_norm, perm=None, betas=None, select=None):\n",
    "    \"\"\"Contrastive loss for the current global `epoch`.\n",
    "\n",
    "    Before int(mixup_pct * num_epochs) epochs: utils.mixco_nce with temp .006\n",
    "    (perm/betas/select describe the utils.mixco mixing; None = unmixed batch).\n",
    "    Afterwards: utils.soft_clip_loss with the scheduled `soft_loss_temps`\n",
    "    temperature, rescaled by scale_tau for THIS batch's size.\n",
    "    \"\"\"\n",
    "    mixup_epochs = int(mixup_pct * num_epochs)\n",
    "    if epoch < mixup_epochs:\n",
    "        return utils.mixco_nce(voxels_norm, target_norm, temp=.006,\n",
    "                               perm=perm, betas=betas, select=select)\n",
    "    tau_ref = float(soft_loss_temps[epoch - mixup_epochs].item())\n",
    "    batch_temp = scale_tau(tau_ref, voxels_norm.size(0), ref_B=16, alpha=0.5)\n",
    "    return utils.soft_clip_loss(voxels_norm, target_norm, temp=batch_temp)\n",
    "\n",
    "\n",
    "def _combine_losses(loss_nce, loss_prior):\n",
    "    \"\"\"Combine the NCE and prior terms according to the `v2c`/`prior` flags.\n",
    "\n",
    "    Returns:\n",
    "        (total_loss, nce_value, prior_value); the two values are floats for\n",
    "        logging and are 0. for a disabled term.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if neither `v2c` nor `prior` is enabled (nothing to train).\n",
    "    \"\"\"\n",
    "    if prior and v2c:\n",
    "        return loss_nce + (prior_mult * loss_prior), loss_nce.item(), loss_prior.item()\n",
    "    if v2c:\n",
    "        return loss_nce, loss_nce.item(), 0.\n",
    "    if prior:\n",
    "        return prior_mult * loss_prior, 0., loss_prior.item()\n",
    "    raise ValueError(\"At least one of `v2c` or `prior` must be enabled.\")\n",
    "\n",
    "\n",
    "def _retrieval_metrics(voxels_norm, target_norm):\n",
    "    \"\"\"Mean paired cosine similarity and top-1 forward/backward retrieval accuracy.\"\"\"\n",
    "    labels = torch.arange(len(target_norm)).to(device)\n",
    "    sim = nn.functional.cosine_similarity(target_norm, voxels_norm).mean().item()\n",
    "    fwd = utils.topk(utils.batchwise_cosine_similarity(voxels_norm, target_norm), labels, k=1)\n",
    "    bwd = utils.topk(utils.batchwise_cosine_similarity(target_norm, voxels_norm), labels, k=1)\n",
    "    return sim, fwd, bwd\n",
    "\n",
    "\n",
    "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
    "for epoch in progress_bar:\n",
    "    model.train()\n",
    "\n",
    "    # Per-epoch running sums; divided by the step counters when logging.\n",
    "    sims_base = 0.\n",
    "    val_sims_base = 0.\n",
    "    recon_cossim=0.\n",
    "    test_recon_cossim=0.\n",
    "    fwd_percent_correct = 0.\n",
    "    bwd_percent_correct = 0.\n",
    "    val_fwd_percent_correct = 0.\n",
    "    val_bwd_percent_correct = 0.\n",
    "    test_loss_cossim_total = 0.\n",
    "    loss_cossim_total = 0.\n",
    "    loss_nce_sum = 0.\n",
    "    loss_prior_sum = 0.\n",
    "    val_loss_nce_sum = 0.\n",
    "    val_loss_prior_sum = 0.\n",
    "    loss_adapter_total = 0.\n",
    "    test_loss_adapter_total = 0.\n",
    "\n",
    "    # Explicit step counters. Summing the zero-based enumerate indices of two\n",
    "    # loaders undercounts by one, and an empty loader would count as one step.\n",
    "    num_common_steps = 0\n",
    "    num_train_steps = 0\n",
    "    num_val_steps = 0\n",
    "\n",
    "    if run_common:\n",
    "        for common_i, (voxel, image, adapter_gt) in enumerate(common_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                optimizer.zero_grad()\n",
    "                # Average over axis 1 (presumably the trial repeats -- TODO confirm).\n",
    "                voxel = torch.mean(voxel.to(device), axis=1).float()\n",
    "                adapter_gt = adapter_gt.to(device)\n",
    "                # Reset every batch: mixco is skipped during the adapter warm-up\n",
    "                # epochs, and mixco_nce must never see a stale permutation.\n",
    "                perm = betas = select = None\n",
    "                if (epoch>=adapt_epochs) and epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "\n",
    "                clip_target = clip_extractor.embed_image(image).float()\n",
    "                voxel_ridge = model.ridge(voxel, 1)\n",
    "                clip_voxels, clip_voxels_proj = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "\n",
    "                loss_prior = None\n",
    "                if prior:\n",
    "                    loss_prior, aligned_clip_voxels = _prior_with_nan_retry(clip_voxels, clip_target)\n",
    "                    recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                loss_nce = _nce_loss(clip_voxels_norm, clip_target_norm, perm, betas, select)\n",
    "                loss, nce_value, prior_value = _combine_losses(loss_nce, loss_prior)\n",
    "                loss_nce_sum += nce_value\n",
    "                loss_prior_sum += prior_value\n",
    "                utils.check_loss(loss)\n",
    "\n",
    "                # The adapter MSE is always logged (unscaled) but only trained\n",
    "                # on when adaptalign is set.\n",
    "                loss_adapter = mse(voxel_ridge, adapter_gt)\n",
    "                loss_adapter_total += loss_adapter.item()\n",
    "                if adaptalign:\n",
    "                    loss = loss + mse_scale * loss_adapter\n",
    "\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "\n",
    "                sim, fwd, bwd = _retrieval_metrics(clip_voxels_norm, clip_target_norm)\n",
    "                sims_base += sim\n",
    "                fwd_percent_correct += fwd\n",
    "                bwd_percent_correct += bwd\n",
    "                num_common_steps += 1\n",
    "\n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "\n",
    "    if run_train:\n",
    "        for train_i, (voxel, image) in enumerate(train_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                optimizer.zero_grad()\n",
    "                # Use a single repeat per batch, cycling indices 0,1,2 across\n",
    "                # batches (assumes axis 1 holds >= 3 repeats -- TODO confirm).\n",
    "                voxel = voxel.to(device)[:, train_i % 3].float()\n",
    "                perm = betas = select = None\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "\n",
    "                clip_target = clip_extractor.embed_image(image).float()\n",
    "                voxel_ridge = model.ridge(voxel, 1)\n",
    "                clip_voxels, clip_voxels_proj = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "\n",
    "                loss_prior = None\n",
    "                if prior:\n",
    "                    loss_prior, aligned_clip_voxels = _prior_with_nan_retry(clip_voxels, clip_target)\n",
    "                    recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                loss_nce = _nce_loss(clip_voxels_norm, clip_target_norm, perm, betas, select)\n",
    "                loss, nce_value, prior_value = _combine_losses(loss_nce, loss_prior)\n",
    "                loss_nce_sum += nce_value\n",
    "                loss_prior_sum += prior_value\n",
    "                utils.check_loss(loss)\n",
    "\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "\n",
    "                sim, fwd, bwd = _retrieval_metrics(clip_voxels_norm, clip_target_norm)\n",
    "                sims_base += sim\n",
    "                fwd_percent_correct += fwd\n",
    "                bwd_percent_correct += bwd\n",
    "                num_train_steps += 1\n",
    "\n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):\n",
    "        for val_i, (voxel, image) in enumerate(test_dl):\n",
    "            # No image augmentation here: validation metrics (and the\n",
    "            # best-epoch selection below) must not depend on random augments.\n",
    "            voxel = torch.mean(voxel,axis=1).float().to(device)\n",
    "            clip_target = clip_extractor.embed_image(image).float()\n",
    "            voxel_ridge = model.ridge(voxel, 1)\n",
    "            clip_voxels, clip_voxels_proj = model.voxel2clip(voxel_ridge)\n",
    "            if hidden:\n",
    "                clip_voxels = clip_voxels.view(len(voxel),-1,clip_emb_dim)\n",
    "\n",
    "            val_loss_prior = None\n",
    "            if prior:\n",
    "                val_loss_prior, aligned_clip_voxels = _prior_with_nan_retry(clip_voxels, clip_target)\n",
    "                test_recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "            else:\n",
    "                aligned_clip_voxels = clip_voxels\n",
    "\n",
    "            clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "            # Validation batches are never mixed, so perm/betas/select stay None.\n",
    "            val_loss_nce = _nce_loss(clip_voxels_norm, clip_target_norm)\n",
    "            val_loss, nce_value, prior_value = _combine_losses(val_loss_nce, val_loss_prior)\n",
    "            val_loss_nce_sum += nce_value\n",
    "            val_loss_prior_sum += prior_value\n",
    "            utils.check_loss(val_loss)\n",
    "            val_losses.append(val_loss.item())\n",
    "\n",
    "            sim, fwd, bwd = _retrieval_metrics(clip_voxels_norm, clip_target_norm)\n",
    "            val_sims_base += sim\n",
    "            val_fwd_percent_correct += fwd\n",
    "            val_bwd_percent_correct += bwd\n",
    "            num_val_steps += 1\n",
    "\n",
    "    if local_rank==0:\n",
    "        num_steps = num_common_steps + num_train_steps\n",
    "        # max(..., 1) keeps the averages finite when a loader was skipped or empty.\n",
    "        n_steps = max(num_steps, 1)\n",
    "        n_common = max(num_common_steps, 1)\n",
    "        n_val = max(num_val_steps, 1)\n",
    "        # NaN (never < best_val_loss) when no validation batch ran this epoch.\n",
    "        val_loss = np.mean(val_losses[-num_val_steps:]) if num_val_steps else float('nan')\n",
    "        if val_loss < best_val_loss:\n",
    "            best_epoch = epoch\n",
    "            best_val_loss = val_loss\n",
    "        if utils.is_interactive():\n",
    "            clear_output(wait=True)\n",
    "\n",
    "        logs = {\"train/loss\": np.mean(losses[-num_steps:]) if num_steps else float('nan'),\n",
    "            \"val/loss\": val_loss,\n",
    "            \"train/lr\": lrs[-1] if lrs else float('nan'),\n",
    "            \"train/num_steps\": len(losses),\n",
    "            \"val/num_steps\": len(val_losses),\n",
    "            \"train/cosine_sim_base\": sims_base / n_steps,\n",
    "            \"val/cosine_sim_base\": val_sims_base / n_val,\n",
    "            \"train/cosine_sim_prior\": recon_cossim / n_steps,\n",
    "            \"val/cosine_sim_prior\": test_recon_cossim / n_val,\n",
    "            \"train/fwd_pct_correct\": fwd_percent_correct / n_steps,\n",
    "            \"train/bwd_pct_correct\": bwd_percent_correct / n_steps,\n",
    "            \"val/val_fwd_pct_correct\": val_fwd_percent_correct / n_val,\n",
    "            \"val/val_bwd_pct_correct\": val_bwd_percent_correct / n_val,\n",
    "            \"train/adapter_mse\": loss_adapter_total / n_common,\n",
    "            \"train/loss_nce\": loss_nce_sum / n_steps,\n",
    "            \"train/loss_prior\": loss_prior_sum / n_steps,\n",
    "            \"val/loss_nce\": val_loss_nce_sum / n_val,\n",
    "            \"val/loss_prior\": val_loss_prior_sum / n_val}\n",
    "        progress_bar.set_postfix(**logs)\n",
    "\n",
    "        # Save a mid-training checkpoint every ckpt_interval epochs.\n",
    "        if epoch % ckpt_interval == 0:\n",
    "            save_ckpt(f'mid_{epoch}')\n",
    "\n",
    "print(\"\\n===Finished!===\\n\")\n",
    "print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f} at epoch: {best_epoch}')\n",
    "if not utils.is_interactive():\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce24c28-cc06-4849-bcad-cacc252c1d96",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "693188a1-3a10-4fa6-ba74-9b871ab4c481",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b577340-f1f0-4cb7-a244-06533bf66921",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd404259-1810-414a-b7de-3e35cb2babe8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee20d01-e5e0-493b-8e1c-6419cfd42d52",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f88b803-964c-4830-ab9a-869d62d5ce6a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f96bf6fc-7c8f-4f5d-86f9-9c91176072ef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "502275d3-e1cd-48fd-803e-39420d0dad8c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f070d9a4-5617-4e4a-b4d5-2547acfde327",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0807f49d-d72a-4852-a64a-ef4528312d5f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11223c35-d3fa-4947-9660-75e5f1e64183",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24bc5445-f211-4f8a-ace1-6c76b725b3ee",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eed2cbb-977b-48a0-80a4-679931948aca",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdeb1266-d4be-4c81-8e3e-35515ea0d1a7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
