{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be44a99b-2fca-4e8b-9430-955c8f6b3128",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import gc\n",
    "import torch.nn.functional as F\n",
    "from dataloaders import ImageVoxelDataset,ImageVoxelAdapterDataset\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "import utils\n",
    "from models import Clipper, BrainNetwork, BrainDiffusionPrior, BrainDiffusionPriorOld, VersatileDiffusionPriorNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ba88b4e-b53a-4f49-8144-f07b097d16b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Suppress every Python warning for the rest of the session.\n",
    "warnings.filterwarnings('ignore')\n",
    "from IPython.display import clear_output # function to clear print outputs in cell\n",
    "# %load_ext only registers the extension; %autoreload 2 switches on reloading of\n",
    "# edited modules (e.g. utils, models, dataloaders) before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7f1b19-f17d-4663-ad9b-8e82e70edd71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-process defaults: rank 0 of a world of size 1.\n",
    "local_rank, world_size = 0, 1\n",
    "# dtype handed to autocast; change depending on your mixed_precision\n",
    "data_type = torch.float32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fe2eb0-5b94-4595-86d7-731e2b4d1e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report whether PyTorch can see a CUDA device; the boolean is shown as the cell output.\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e348f03-d8c6-4143-b689-c98939be9f1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is visible, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2504d23f-9b55-4141-a29b-a81535f545a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Run identity and data locations ---\n",
    "model_name='subj1_nl_sclip_basictest'\n",
    "data_path='data/new_dl/'\n",
    "checkpoint_name = model_name\n",
    "subj=1\n",
    "# --- Model / objective switches ---\n",
    "batch_size=16\n",
    "hidden=True\n",
    "clip_variant='ViT-L/14'\n",
    "mixup_pct=0.15  # fraction of the epochs that use the mixco branch of the loss\n",
    "norm_embs=True\n",
    "use_image_aug=True\n",
    "num_epochs=300\n",
    "prior=True\n",
    "v2c=True\n",
    "lr_scheduler_type='cycle'  # 'linear', 'cycle' or None\n",
    "ckpt_saving=True\n",
    "ckpt_interval=50\n",
    "run_common=True\n",
    "run_train=True #Set this to false if you are training with limited data/on common images only\n",
    "seed=42\n",
    "max_lr=6e-5\n",
    "use_projector=True\n",
    "cache_dir='data/fmri/cache'\n",
    "hidden_dim=4096\n",
    "checkpoint_dir = os.path.abspath(f'train_logs/{checkpoint_name}')\n",
    "#checkpoint_tag = 'mid_100'\n",
    "load_indices=False\n",
    "# --- CLIP embedding geometry ---\n",
    "clip_seq_dim = 257\n",
    "clip_emb_dim = 768\n",
    "clip_size = clip_emb_dim\n",
    "#index_name = 'subj1_indices_30d_105.0w_256.npy'\n",
    "\n",
    "# Apply the seed here, before any model is built or any loader is shuffled,\n",
    "# so random draws from NumPy and PyTorch repeat across Restart & Run All.\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cbada9f-738a-4dbf-ae88-d7a9e4dcdb19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel count for each supported subject id.\n",
    "NUM_VOXELS_BY_SUBJ = {\n",
    "    1: 15724,\n",
    "    2: 14278,\n",
    "    3: 15226,\n",
    "    4: 13153,\n",
    "    5: 13039,\n",
    "    6: 17907,\n",
    "    7: 12682,\n",
    "    8: 14386,\n",
    "}\n",
    "if subj not in NUM_VOXELS_BY_SUBJ:\n",
    "    # Fail here with a clear message instead of a NameError below, or a stale\n",
    "    # num_voxels left in the kernel by an earlier run.\n",
    "    raise ValueError(f\"Unsupported subj={subj!r}; expected one of {sorted(NUM_VOXELS_BY_SUBJ)}\")\n",
    "num_voxels = NUM_VOXELS_BY_SUBJ[subj]\n",
    "# One entry per ridge head; this notebook trains a single subject.\n",
    "num_voxels_list = [num_voxels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20db04cc-5c98-4b4a-b52e-557b00a20bc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives this run's checkpoints; created on first use.\n",
    "outdir = os.path.abspath(f'train_logs/{model_name}')\n",
    "if not os.path.exists(outdir):\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "if use_image_aug:\n",
    "    # kornia is only needed (and only imported) when augmentation is enabled.\n",
    "    import kornia\n",
    "    from kornia.augmentation.container import AugmentationSequential\n",
    "\n",
    "    img_augment = AugmentationSequential(\n",
    "        # random crop (applied 30% of the time), then force a 224x224 output\n",
    "        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n",
    "        kornia.augmentation.Resize((224, 224)),\n",
    "        # geometric and photometric jitter\n",
    "        kornia.augmentation.RandomHorizontalFlip(p=0.5),\n",
    "        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
    "        kornia.augmentation.RandomGrayscale(p=0.3),\n",
    "        data_keys=[\"input\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8370c2eb-da6d-4077-b09d-df26b8d07289",
   "metadata": {},
   "source": [
    "## Build Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb5d1a90-1c2d-4e47-afdc-3a41e53a10da",
   "metadata": {},
   "source": [
    "### Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56fba626-a158-4714-be06-4675e1299d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Common DL\n",
    "if run_common:\n",
    "    # (Index-based subsetting would be loaded here if chosen.)\n",
    "    # Files are resolved under the configured data_path rather than a hard-coded root.\n",
    "    split_dir = os.path.join(data_path, f'subj{subj:02d}', 'test')\n",
    "    feature_file = os.path.join(split_dir, 'betas.pt')\n",
    "    image_file = os.path.join(split_dir, 'images.pt')\n",
    "\n",
    "    common_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "    common_dl = DataLoader(common_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "\n",
    "    # Quick look at one batch: dataset size, tensor shapes and image value range.\n",
    "    print(len(common_dataset))\n",
    "    voxel,image = next(iter(common_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a3671c1-583f-4fb5-b7ad-8445ec6ca4d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test DL\n",
    "\n",
    "# Files are resolved under the configured data_path rather than a hard-coded root.\n",
    "split_dir = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'test')\n",
    "feature_file = os.path.join(split_dir, 'betas.pt')\n",
    "image_file = os.path.join(split_dir, 'images.pt')\n",
    "\n",
    "test_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "\n",
    "# Quick look at one batch: dataset size, tensor shapes and image value range.\n",
    "print(len(test_dataset))\n",
    "voxel,image = next(iter(test_dl))\n",
    "print(voxel.shape, image.shape)\n",
    "print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6025a8b-4f8b-4fdc-a74d-d970ba330871",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Train DL\n",
    "if run_train:\n",
    "    # Files are resolved under the configured data_path rather than a hard-coded root.\n",
    "    split_dir = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'train')\n",
    "    feature_file = os.path.join(split_dir, 'betas.pt')\n",
    "    image_file = os.path.join(split_dir, 'images.pt')\n",
    "\n",
    "    train_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "    # Only the training loader is shuffled.\n",
    "    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "\n",
    "    # Quick look at one batch: dataset size, tensor shapes and image value range.\n",
    "    print(len(train_dataset))\n",
    "    voxel,image = next(iter(train_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2233a142-3fd8-48b1-85b6-3a1af8312ad2",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39b951f1-a063-4f4d-8319-280fd7785fbc",
   "metadata": {},
   "source": [
    "### Clipper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa60baee-6e61-400a-b017-fdab38d2a8ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Creating Clipper...')\n",
    "if hidden:\n",
    "    print(\"Using hidden layer CLIP space (Versatile Diffusion)\")\n",
    "    if not norm_embs:\n",
    "        print(\"WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!\")\n",
    "    # Flattened size of the hidden-state target: sequence length times embedding\n",
    "    # width, taken from the config cell instead of a repeated literal.\n",
    "    out_dim = clip_seq_dim * clip_size\n",
    "else:\n",
    "    print(\"Using final layer CLIP space (Stable Diffusion Img Variations)\")\n",
    "    if norm_embs:\n",
    "        print(\"WARNING: YOU WANT UN-NORMED EMBEDDINGS FOR IMG VARIATIONS!\")\n",
    "    out_dim = clip_size\n",
    "# The two modes differ only in the hidden_state flag, so build the extractor once.\n",
    "clip_extractor = Clipper(clip_variant, device=device, hidden_state=bool(hidden), norm_embs=norm_embs)\n",
    "print(\"out_dim:\",out_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0c43a1a-e264-4a98-9c6d-d2adbd7ae0c0",
   "metadata": {},
   "source": [
    "### High Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "127e17f6-e90f-47c4-a27c-70f0f25cb530",
   "metadata": {},
   "outputs": [],
   "source": [
    "class fMRIModule(nn.Module):\n",
    "    \"\"\"Parameter-free container; sub-networks are attached to it as attributes.\n",
    "\n",
    "    Its own forward pass is the identity.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``x`` unchanged.\"\"\"\n",
    "        return x\n",
    "\n",
    "model = fMRIModule()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69be1193-9028-481d-80f7-7e943dc85e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    \"\"\"One Linear -> LayerNorm -> GELU head per entry of ``input_sizes``.\n",
    "\n",
    "    Every head maps its own input width to the shared ``out_features`` width.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        heads = []\n",
    "        for input_size in input_sizes:\n",
    "            heads.append(nn.Sequential(\n",
    "                nn.Linear(input_size, out_features),\n",
    "                nn.LayerNorm(out_features),\n",
    "                nn.GELU(),\n",
    "            ))\n",
    "        self.linears = nn.ModuleList(heads)\n",
    "\n",
    "    def forward(self, x, subj_idx):\n",
    "        \"\"\"Apply head ``subj_idx`` to ``x`` and insert a singleton dimension at position 1.\"\"\"\n",
    "        head = self.linears[subj_idx]\n",
    "        return head(x).unsqueeze(1)\n",
    "\n",
    "model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e41fdb3-6b20-41c8-a44c-e0973f0049ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity test\n",
    "# Push a random 2-row batch through the first ridge head and print the shapes involved.\n",
    "b = torch.randn(2, num_voxels_list[0]).to(device)\n",
    "print(b.shape)\n",
    "print(\n",
    "    b.shape,\n",
    "    model.ridge(b, 0).shape,\n",
    "    b[:, 0].shape,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f87f9ee7-0c32-464c-ae8d-1b2cab6be7a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Set in_dim as anything if you don't care about the lin0 module\n",
    "voxel2clip_kwargs = {\n",
    "    'in_dim': hidden_dim,\n",
    "    'out_dim': clip_emb_dim*clip_seq_dim,\n",
    "    'clip_size': clip_emb_dim,\n",
    "    'use_projector': use_projector,\n",
    "    'ext_ridge': True,\n",
    "}\n",
    "voxel2clip = BrainNetwork(**voxel2clip_kwargs)\n",
    "# Register the backbone on the container so it moves and saves with the rest of the model.\n",
    "model.voxel2clip = voxel2clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca774adb-c68a-4f9a-b9b7-0fcbc3672b7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup prior network\n",
    "# NOTE: out_dim is rebound here to the per-token embedding width.\n",
    "out_dim = clip_emb_dim\n",
    "depth = 6\n",
    "dim_head = 64\n",
    "heads = clip_emb_dim//dim_head # heads * dim_head = 12 * 64 = 768\n",
    "\n",
    "#There is a non hidden version here in ME1 but we never use it\n",
    "guidance_scale = 3.5\n",
    "timesteps = 100\n",
    "\n",
    "prior_network = VersatileDiffusionPriorNetwork(\n",
    "    dim=out_dim,\n",
    "    depth=depth,\n",
    "    dim_head=dim_head,\n",
    "    heads=heads,\n",
    "    causal=False,\n",
    "    num_tokens=clip_seq_dim,\n",
    "    learned_query_mode=\"pos_emb\",\n",
    ")\n",
    "prior_network = prior_network.to(device)\n",
    "print(\"prior_network loaded\")\n",
    "\n",
    "# custom version that can fix seeds\n",
    "diffusion_prior = BrainDiffusionPrior(\n",
    "    net=prior_network,\n",
    "    image_embed_dim=out_dim,\n",
    "    condition_on_text_encodings=False,\n",
    "    timesteps=timesteps,\n",
    "    cond_drop_prob=0.2,\n",
    "    image_embed_scale=None,\n",
    ")\n",
    "diffusion_prior = diffusion_prior.to(device)\n",
    "\n",
    "# Register the prior on the container next to the ridge and voxel2clip modules.\n",
    "model.diffusion_prior = diffusion_prior\n",
    "print(\"params of diffusion prior:\")\n",
    "if local_rank==0:\n",
    "    utils.count_params(model.diffusion_prior)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde06f23-4b38-4a69-a8f2-a31f5c6eb1ac",
   "metadata": {},
   "source": [
    "## Training Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6e6820-d121-417c-87a4-a4b786e23e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer steps per epoch: every batch of each enabled loader counts as one step.\n",
    "if run_common:\n",
    "    num_iterations_per_epoch = len(common_dl)\n",
    "    if run_train:\n",
    "        num_iterations_per_epoch = num_iterations_per_epoch + len(train_dl)\n",
    "else:\n",
    "    num_iterations_per_epoch = len(train_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fada699a-cb60-41c9-9a2c-d0b604721ac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the number of optimizer steps per epoch computed in the previous cell.\n",
    "num_iterations_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62e56e58-49da-446e-894f-a4d6497d23a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "\n",
    "def _decay_param_groups(module, weight_decay=1e-2):\n",
    "    \"\"\"Split ``module``'s parameters into a decayed and an undecayed optimizer group.\n",
    "\n",
    "    A parameter is exempt from weight decay when any entry of ``no_decay`` is a\n",
    "    substring of its name. Returns ``[decayed_group, undecayed_group]``.\n",
    "    \"\"\"\n",
    "    decayed, undecayed = [], []\n",
    "    for name, param in module.named_parameters():\n",
    "        if any(nd in name for nd in no_decay):\n",
    "            undecayed.append(param)\n",
    "        else:\n",
    "            decayed.append(param)\n",
    "    return [\n",
    "        {'params': decayed, 'weight_decay': weight_decay},\n",
    "        {'params': undecayed, 'weight_decay': 0.0},\n",
    "    ]\n",
    "\n",
    "# The ridge heads are always optimized; the other modules follow their switches.\n",
    "opt_grouped_parameters = _decay_param_groups(model.ridge)\n",
    "if v2c:\n",
    "    opt_grouped_parameters.extend(_decay_param_groups(model.voxel2clip))\n",
    "if prior:\n",
    "    opt_grouped_parameters.extend(_decay_param_groups(model.diffusion_prior))\n",
    "\n",
    "print(len(opt_grouped_parameters), lr_scheduler_type)\n",
    "\n",
    "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n",
    "\n",
    "# The scheduler is stepped once per batch, so its horizon is counted in batches.\n",
    "total_steps = int(np.floor(num_epochs*num_iterations_per_epoch))\n",
    "if lr_scheduler_type == 'linear':\n",
    "    lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "        optimizer,\n",
    "        total_iters=total_steps,\n",
    "        last_epoch=-1\n",
    "    )\n",
    "elif lr_scheduler_type == 'cycle':\n",
    "    print(\"total_steps\", total_steps)\n",
    "    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "        optimizer,\n",
    "        max_lr=max_lr,\n",
    "        total_steps=total_steps,\n",
    "        final_div_factor=1000,\n",
    "        last_epoch=-1, pct_start=2/num_epochs\n",
    "    )\n",
    "elif lr_scheduler_type is None:\n",
    "    # Bind the name explicitly so a scheduler from an earlier run cannot leak in.\n",
    "    lr_scheduler = None\n",
    "else:\n",
    "    # Fail now rather than on the first lr_scheduler.step() inside the training loop.\n",
    "    raise ValueError(f\"Unsupported lr_scheduler_type={lr_scheduler_type!r}; expected 'linear', 'cycle' or None\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc2b0467-0667-4824-845d-74e98208de8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_ckpt(tag):\n",
    "    \"\"\"Write the current training state to ``<outdir>/<tag>.pth``.\n",
    "\n",
    "    Reads ``epoch``, ``model``, ``optimizer``, ``lr_scheduler``, ``losses``,\n",
    "    ``val_losses`` and ``lrs`` from the notebook's global namespace. Saving is\n",
    "    best-effort: any exception is printed and swallowed so a failed save does\n",
    "    not stop training.\n",
    "    \"\"\"\n",
    "    ckpt_path = f'{outdir}/{tag}.pth'\n",
    "    print(f'saving {ckpt_path}', flush=True)\n",
    "    try:\n",
    "        # Built inside the try so a missing global is reported, not raised.\n",
    "        state = {\n",
    "            'epoch': epoch,\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'lr_scheduler_state_dict': lr_scheduler.state_dict(),\n",
    "            'train_losses': losses,\n",
    "            'val_losses': val_losses,\n",
    "            'lrs': lrs,\n",
    "        }\n",
    "        torch.save(state, ckpt_path)\n",
    "    except Exception as e:\n",
    "        print(f\"Couldn't save due to {e}... moving on to prevent crashing.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdae65f5-551a-43d1-8a75-0b14ad0f2f93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_ckpt(tag,load_lr=True,load_optimizer=True,load_epoch=False,strict=True,outdir=outdir,multisubj_loading=False): \n",
    "    \"\"\"Restore training state from ``<outdir>/<tag>.pth`` into the notebook globals.\n",
    "\n",
    "    Args:\n",
    "        tag: checkpoint file name without the ``.pth`` suffix.\n",
    "        load_lr: restore the global ``lr_scheduler`` state.\n",
    "        load_optimizer: restore the global ``optimizer`` state.\n",
    "        load_epoch: overwrite the global ``epoch`` with the saved one.\n",
    "        strict: forwarded to ``model.load_state_dict``.\n",
    "        outdir: directory holding the checkpoint (defaults to this run's outdir).\n",
    "        multisubj_loading: drop the voxel-sized ridge weight before loading; the\n",
    "            dropped key is then missing, so pass ``strict=False`` together with it.\n",
    "    \"\"\"\n",
    "    ckpt_path = f'{outdir}/{tag}.pth'\n",
    "    print(f\"\\n---loading {ckpt_path} ckpt---\\n\")\n",
    "    checkpoint = torch.load(ckpt_path, map_location='cpu')\n",
    "    state_dict = checkpoint['model_state_dict']\n",
    "    if multisubj_loading: # remove incompatible ridge layer that will otherwise error\n",
    "        # Each ridge head is Sequential(Linear, LayerNorm, GELU), so the Linear weight\n",
    "        # is stored as 'ridge.linears.0.0.weight'. The flat key is dropped too in case\n",
    "        # the checkpoint stores the head as a bare Linear.\n",
    "        for key in ('ridge.linears.0.0.weight', 'ridge.linears.0.weight'):\n",
    "            state_dict.pop(key, None)\n",
    "    model.load_state_dict(state_dict, strict=strict)\n",
    "    if load_epoch:\n",
    "        globals()[\"epoch\"] = checkpoint['epoch']\n",
    "        print(\"Epoch\",epoch)\n",
    "    if load_optimizer:\n",
    "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "    if load_lr:\n",
    "        # save_ckpt writes the scheduler under 'lr_scheduler_state_dict'; fall back\n",
    "        # to 'lr_scheduler' only when that key is absent from the checkpoint.\n",
    "        if 'lr_scheduler_state_dict' in checkpoint:\n",
    "            sched_key = 'lr_scheduler_state_dict'\n",
    "        else:\n",
    "            sched_key = 'lr_scheduler'\n",
    "        lr_scheduler.load_state_dict(checkpoint[sched_key])\n",
    "    del checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15d11c66-18ce-40b8-a62a-5cb9d5f34334",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh bookkeeping for a new run.\n",
    "epoch = 0\n",
    "best_val_loss = 1e9\n",
    "losses, val_losses, lrs = [], [], []\n",
    "nce_losses, val_nce_losses = [], []\n",
    "sim_losses, val_sim_losses = [], []\n",
    "\n",
    "# Per-epoch temperatures for the soft CLIP loss, one for each epoch after the mixup phase.\n",
    "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
    "\n",
    "# Weight of the prior loss in the combined objective.\n",
    "prior_mult = 30 if hidden else .03\n",
    "\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "725e3124-b3ca-4d31-aa44-67b8488e11e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the container and every module attached to it onto the selected device;\n",
    "# the returned module's repr is shown as the cell output.\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6429778d-46a3-4dfd-af0c-a7ab7afb0be3",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d60e2e7e-1887-44a3-93d5-b772c6f00b8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the settings that shape the run before starting the training loop.\n",
    "print(use_image_aug,mixup_pct, ckpt_interval,prior_mult)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de9d0a38-39b1-4fbe-a0b4-2a5fada578b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Model: {model_name}\")\n",
    "# Start (or restart) from the first epoch.\n",
    "epoch=0\n",
    "best_epoch=0\n",
    "print(epoch,ckpt_interval,mixup_pct)\n",
    "max_retries = 5  # Define maximum number of retries\n",
    "counter=0  # running count of NaN prior losses seen by the training loop\n",
    "\n",
    "def scale_tau(tau_ref: float, B: int, ref_B: int = 16, alpha: float = 0.5):\n",
    "    \"\"\"Scale a reference temperature to the actual batch size.\n",
    "\n",
    "    Returns ``tau_ref * (B / ref_B) ** alpha``. alpha=0.5 (sqrt) is a good\n",
    "    default; alpha=1.0 is stronger.\n",
    "    \"\"\"\n",
    "    batch_ratio = B / ref_B\n",
    "    return tau_ref * batch_ratio ** alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74079045-4022-40db-9c43-6bea3ab45eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
    "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
    "for epoch in progress_bar:\n",
    "    model.train()\n",
    "\n",
    "    sims_base = 0.\n",
    "    val_sims_base = 0.\n",
    "    recon_cossim=0.\n",
    "    test_recon_cossim=0.\n",
    "    fwd_percent_correct = 0.\n",
    "    bwd_percent_correct = 0.\n",
    "    val_fwd_percent_correct = 0.\n",
    "    val_bwd_percent_correct = 0.\n",
    "    test_loss_cossim_total = 0.\n",
    "    loss_cossim_total = 0.\n",
    "    loss_nce_sum = 0.\n",
    "    loss_prior_sum = 0.\n",
    "    val_loss_nce_sum = 0.\n",
    "    val_loss_prior_sum = 0.\n",
    "\n",
    "    common_i=0\n",
    "    train_i=0\n",
    "    if run_common:\n",
    "        for common_i, (voxel, image) in enumerate(common_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                if common_i==3:\n",
    "                    print(\"Inside common dl\")\n",
    "                optimizer.zero_grad()\n",
    "                loss=0.\n",
    "                voxel = voxel.to(device)\n",
    "                voxel = torch.mean(voxel,axis=1).float()\n",
    "                #print(voxel.shape)\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                \n",
    "                clip_target = clip_extractor.embed_image(image).float()\n",
    "        \n",
    "                voxel_ridge = model.ridge(voxel, 0)\n",
    "                #print(voxel_ridge.shape, adapter_gt.shape)\n",
    "                clip_voxels, clip_voxels_proj = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            # If loss_prior is NaN, clean up before retry\n",
    "                            if torch.isnan(loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if optimizer:\n",
    "                                    optimizer.zero_grad()\n",
    "                                # Clear the memory cache if possible\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "                    \n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "                \n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=perm, betas=betas, select=select)\n",
    "                else:\n",
    "                    B = clip_voxels_norm.size(0)\n",
    "                    tau_ref = float(soft_loss_temps[epoch -int(mixup_pct*num_epochs)].item())\n",
    "                    epoch_temp = scale_tau(tau_ref, B, ref_B=16, alpha=0.5)\n",
    "                    loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = loss_nce + (prior_mult * loss_prior)\n",
    "                elif v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss = loss_nce\n",
    "                elif prior:\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = prior_mult * loss_prior\n",
    "                utils.check_loss(loss)\n",
    "                \n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "        \n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "        \n",
    "                # gather batches across multi-gpu if there's multiple\n",
    "                # clip_voxel_gather = accelerator.gather(clip_voxels_norm.view(len(voxel),-1).contiguous())\n",
    "                # clip_target_gather = accelerator.gather(clip_target_norm.view(len(voxel),-1).contiguous())\n",
    "        \n",
    "                sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                # forward and backward top 1 accuracy\n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "        \n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "    if run_train:\n",
    "        for train_i, (voxel, image) in enumerate(train_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                print(train_i)\n",
    "                optimizer.zero_grad()\n",
    "                loss=0.\n",
    "                voxel = voxel.to(device)\n",
    "                repeat_index = train_i % 3\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                voxel = voxel[:,repeat_index].float()\n",
    "                #print(voxel.shape)\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                \n",
    "                clip_target = clip_extractor.embed_image(image).float()\n",
    "                \n",
    "                voxel_ridge = model.ridge(voxel, 0)\n",
    "                clip_voxels, clip_voxels_proj = model.voxel2clip(voxel_ridge)\n",
    "    \n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "                \n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            #aligned_clip_voxels /= diffusion_prior.module.image_embed_scale if distributed else diffusion_prior.image_embed_scale\n",
    "                            \n",
    "                            # If loss_prior is NaN, clean up before retry\n",
    "                            if torch.isnan(loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if optimizer:\n",
    "                                    optimizer.zero_grad()\n",
    "                                # Clear the memory cache if possible\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "                                \n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "        \n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=perm, betas=betas, select=select)\n",
    "                else:\n",
    "                    B = clip_voxels_norm.size(0)\n",
    "                    tau_ref = float(soft_loss_temps[epoch -int(mixup_pct*num_epochs)].item())\n",
    "                    epoch_temp = scale_tau(tau_ref, B, ref_B=16, alpha=0.5)\n",
    "                    loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = loss_nce + (prior_mult * loss_prior)\n",
    "                elif v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss = loss_nce\n",
    "                elif prior:\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = prior_mult * loss_prior\n",
    "                utils.check_loss(loss)\n",
    "                \n",
    "                #accelerator.backward(loss)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "        \n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "                \n",
    "                sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                # forward and backward top 1 accuracy\n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "        \n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "    model.eval()\n",
    "    for val_i, (voxel, image) in enumerate(test_dl):\n",
    "        with torch.no_grad():\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "\n",
    "                voxel = torch.mean(voxel,axis=1).float()\n",
    "                voxel = voxel.to(device)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "\n",
    "                clip_target = clip_extractor.embed_image(image).float()\n",
    "                \n",
    "                #voxel_ridge = model.ridge(voxel, subj-1)\n",
    "                voxel_ridge = model.ridge(voxel, 0)\n",
    "                \n",
    "                clip_voxels, clip_voxels_proj = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel),-1,clip_emb_dim)\n",
    "                \n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            val_loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            if torch.isnan(val_loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            test_recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    val_loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=None, betas=None, select=None)\n",
    "                else:\n",
    "                    val_loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    val_loss_nce_sum += val_loss_nce.item()\n",
    "                    val_loss_prior_sum += val_loss_prior.item()\n",
    "                    val_loss = val_loss_nce + (prior_mult * val_loss_prior)\n",
    "                elif v2c:\n",
    "                    val_loss_nce_sum += val_loss_nce.item()\n",
    "                    val_loss = val_loss_nce\n",
    "                elif prior:\n",
    "                    val_loss_prior_sum += val_loss_prior.item()\n",
    "                    val_loss = prior_mult * val_loss_prior\n",
    "                utils.check_loss(val_loss)\n",
    "                \n",
    "                val_losses.append(val_loss.item())\n",
    "\n",
    "                val_sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                \n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                val_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                val_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "\n",
    "    if local_rank==0:\n",
    "        val_loss = np.mean(val_losses[-(val_i+1):])\n",
    "        if val_loss < best_val_loss:\n",
    "            best_epoch = epoch\n",
    "            best_val_loss = val_loss\n",
    "        if utils.is_interactive():\n",
    "            clear_output(wait=True)\n",
    "            \n",
    "        logs = {\"train/loss\": np.mean(losses[-(train_i+common_i+1):]),\n",
    "            \"val/loss\": np.mean(val_losses[-(val_i+1):]),\n",
    "            \"train/lr\": lrs[-1],\n",
    "            \"train/num_steps\": len(losses),\n",
    "            \"val/num_steps\": len(val_losses),\n",
    "            \"train/cosine_sim_base\": sims_base / (train_i + common_i+1),\n",
    "            \"val/cosine_sim_base\": val_sims_base / (val_i + 1),\n",
    "            \"train/cosine_sim_prior\": recon_cossim / (train_i + common_i+1),\n",
    "            \"val/cosine_sim_prior\": test_recon_cossim / (val_i + 1),\n",
    "            \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + common_i+1),\n",
    "            \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + common_i+1),\n",
    "            \"val/val_fwd_pct_correct\": val_fwd_percent_correct / (val_i + 1),\n",
    "            \"val/val_bwd_pct_correct\": val_bwd_percent_correct / (val_i + 1),\n",
    "            \"train/loss_nce\": loss_nce_sum / (train_i + common_i+1),\n",
    "            \"train/loss_prior\": loss_prior_sum / (train_i + common_i+ 1),\n",
    "            \"val/loss_nce\": val_loss_nce_sum / (val_i + 1),\n",
    "            \"val/loss_prior\": val_loss_prior_sum / (val_i + 1)}\n",
    "        progress_bar.set_postfix(**logs)\n",
    "\n",
    "        # Save model checkpoint and reconstruct\n",
    "        #save_ckpt(f'last')\n",
    "        if epoch % ckpt_interval == 0:\n",
    "            save_ckpt(f'mid_{epoch}')\n",
    "\n",
    "print(\"\\n===Finished!===\\n\")\n",
    "print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f} at epoch: {best_epoch}')\n",
    "if not utils.is_interactive():\n",
    "    sys.exit(0)   "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
