{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f5a116-750f-446a-973e-698a2c4f563a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import gc\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "from einops import rearrange\n",
    "import time\n",
    "import random\n",
    "import string\n",
    "import h5py\n",
    "from tqdm import tqdm\n",
    "import webdataset as wds\n",
    "\n",
    "\n",
    "# tf32 data type is faster than standard float32\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "sys.path.append('generative_models/')\n",
    "import sgm\n",
    "from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder\n",
    "\n",
    "# custom models and functions #\n",
    "import v2_utils as utils\n",
    "\n",
    "from dataloaders import ImageVoxelDataset,ImageVoxelAdapterDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6dcbce-d5ff-4541-898d-dec0640b72f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e6d33c-6305-47d5-8872-14beaf3808ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output # function to clear print outputs in cell\n",
    "# Loading the extension alone leaves automatic reloading switched off;\n",
    "# mode 2 re-imports edited modules (e.g. the local helper files) before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "062d4039-10f0-44be-81e5-1c2e167dca00",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_type = torch.float16 # change depending on your mixed_precision\n",
    "local_rank = 0\n",
    "world_size = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc6b938-2fc9-4fc9-816c-eb2072fc8521",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c576e69e-2205-48f2-8798-89baeb7ec46e",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "666f3b7f-38e4-468b-8c1e-47bc613f9a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Run identity and data location ---\n",
    "# model_name also names the output folder train_logs/<model_name>.\n",
    "model_name='subj1_l_bclip_basictest_wll_5'\n",
    "data_path='data/new_dl/'\n",
    "checkpoint_name = model_name\n",
    "# subj selects the voxel count and the subjNN data folder.\n",
    "subj=1\n",
    "batch_size=16\n",
    "# hidden: reshape the voxel2clip output to (batch, -1, clip_emb_dim) and use the larger prior loss weight.\n",
    "hidden=True\n",
    "#resume_from_ckpt=False\n",
    "# mixup_pct: fraction of num_epochs (counted from the start) during which utils.mixco is applied to the voxels.\n",
    "mixup_pct=0.15\n",
    "norm_embs=True\n",
    "# use_image_aug: apply img_augment to the images; otherwise only the resize pipeline is applied.\n",
    "use_image_aug=True\n",
    "# blurry_recon enables the autoencoder/ConvNeXt branch; blur_scale weights its loss.\n",
    "blurry_recon=True\n",
    "blur_scale=0.5\n",
    "num_epochs=301\n",
    "# prior / v2c choose which sub-networks are optimised and which loss terms are summed.\n",
    "prior=True\n",
    "v2c=True\n",
    "# lr_scheduler_type: 'cycle' (OneCycleLR) or 'linear' (LinearLR).\n",
    "lr_scheduler_type='cycle'\n",
    "ckpt_saving=True\n",
    "ckpt_interval=50\n",
    "# run_common / run_train: which of the two dataloaders are iterated in every epoch.\n",
    "run_common=True\n",
    "run_train=True #Set this to false if you are training with limited data/on common images only\n",
    "# NOTE(review): no seeding call is visible near the config/model setup — confirm seed is actually applied.\n",
    "seed=42\n",
    "# max_lr is both the AdamW learning rate and the OneCycleLR peak.\n",
    "max_lr=6e-5\n",
    "use_projector=True\n",
    "# cache_dir holds the pretrained autoencoder and ConvNeXt weight files.\n",
    "cache_dir='checkpoints'\n",
    "# hidden_dim: output width of the ridge layer and the width passed to BrainNetwork.\n",
    "hidden_dim=4096\n",
    "# --- Resuming: <checkpoint_dir>/<checkpoint_tag>.pth is loaded before training when resume_from_ckpt is True ---\n",
    "checkpoint_dir = os.path.abspath(f'train_logs/{checkpoint_name}')\n",
    "checkpoint_tag = 'mid_200'\n",
    "resume_from_ckpt=True\n",
    "# load_indices=True\n",
    "# index_name = 'subj1_indices_30d_105.0w_256.npy'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25f23c5d-dd03-41c8-adfc-640540599fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel count for each supported subject; it sets the input width of the ridge layer.\n",
    "NUM_VOXELS_BY_SUBJ = {\n",
    "    1: 15724,\n",
    "    2: 14278,\n",
    "    3: 15226,\n",
    "    4: 13153,\n",
    "    5: 13039,\n",
    "    6: 17907,\n",
    "    7: 12682,\n",
    "    8: 14386,\n",
    "}\n",
    "# Fail loudly on an unknown subject instead of leaving num_voxels undefined\n",
    "# (or silently stale from an earlier run of this cell).\n",
    "if subj not in NUM_VOXELS_BY_SUBJ:\n",
    "    raise ValueError(f'Unsupported subj={subj!r}; expected one of {sorted(NUM_VOXELS_BY_SUBJ)}')\n",
    "num_voxels = NUM_VOXELS_BY_SUBJ[subj]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a8b1fc4-9956-41a2-afc9-b15a9e906fa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_voxels_list = [num_voxels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b2bde0-83a5-4729-988e-075bad391b13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run output folder; save_ckpt writes its `<tag>.pth` files here.\n",
    "outdir = os.path.abspath(f'train_logs/{model_name}')\n",
    "# Create the folder on first use; an existing path is left untouched.\n",
    "if not os.path.exists(outdir):\n",
    "    os.makedirs(outdir,exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69182c6d-cc43-4dca-94e2-e3edd876b7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-time image augmentation; only built when use_image_aug is set.\n",
    "if use_image_aug:\n",
    "    import kornia\n",
    "    from kornia.augmentation.container import AugmentationSequential\n",
    "    img_augment = AugmentationSequential(\n",
    "        # Random crop (applied with p=0.3) followed by an unconditional resize, so every output is 224x224.\n",
    "        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n",
    "        kornia.augmentation.Resize((224, 224)),\n",
    "        kornia.augmentation.RandomHorizontalFlip(p=0.5),\n",
    "        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
    "        kornia.augmentation.RandomGrayscale(p=0.3),\n",
    "        data_keys=[\"input\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7e9bc9-69ce-4738-a327-fac58e233b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize-only pipeline; the training loop applies it in place of img_augment when use_image_aug is False.\n",
    "# The imports are repeated here because the cell above only imports kornia when use_image_aug is set.\n",
    "import kornia\n",
    "from kornia.augmentation.container import AugmentationSequential\n",
    "img_augment2 = AugmentationSequential(\n",
    "    kornia.augmentation.Resize((224, 224)),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60215f05-aaf6-46d8-b3dd-bd4d6f3a6df2",
   "metadata": {},
   "source": [
    "## Build Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "757f38f9-b725-4205-ad90-93da6b147232",
   "metadata": {},
   "source": [
    "### Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d49f5dc-2449-4f50-89ef-f77b0a93b9e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Common DL\n",
    "if run_common:\n",
    "    #Load indices here if chosen:\n",
    "    # Paths are built from the configured data_path so that changing it in the config cell takes effect here.\n",
    "    common_dir = os.path.join(data_path, f'subj{subj:02d}', 'test')\n",
    "    feature_file = os.path.join(common_dir, 'betas.pt')\n",
    "    image_file = os.path.join(common_dir, 'images.pt')\n",
    "    \n",
    "    common_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "    \n",
    "    # Fixed order (shuffle=False); batches are loaded in the main process (num_workers=0).\n",
    "    common_dl = DataLoader(common_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "    print(len(common_dataset))\n",
    "    # Peek at one batch to confirm the tensor shapes and the image value range.\n",
    "    voxel,image = next(iter(common_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc7d8f7e-a504-4bb4-9d3a-4aaff335cb7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test DL\n",
    "\n",
    "# Paths are built from the configured data_path so that changing it in the config cell takes effect here.\n",
    "test_dir = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'test')\n",
    "feature_file = os.path.join(test_dir, 'betas.pt')\n",
    "image_file = os.path.join(test_dir, 'images.pt')\n",
    "\n",
    "test_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "# Fixed order (shuffle=False); batches are loaded in the main process (num_workers=0).\n",
    "test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "print(len(test_dataset))\n",
    "# Peek at one batch to confirm the tensor shapes and the image value range.\n",
    "voxel,image = next(iter(test_dl))\n",
    "print(voxel.shape, image.shape)\n",
    "print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "725b1a11-aeab-4081-802e-fccdc6930a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Train DL\n",
    "if run_train:\n",
    "    # Paths are built from the configured data_path so that changing it in the config cell takes effect here.\n",
    "    train_dir = os.path.join(data_path, f'subj{subj:02d}', 'train', 'custom_split', 'train')\n",
    "    feature_file = os.path.join(train_dir, 'betas.pt')\n",
    "    image_file = os.path.join(train_dir, 'images.pt')\n",
    "    \n",
    "    train_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "    # Reshuffled every epoch (shuffle=True); batches are loaded in the main process (num_workers=0).\n",
    "    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "    print(len(train_dataset))\n",
    "    # Peek at one batch to confirm the tensor shapes and the image value range.\n",
    "    voxel,image = next(iter(train_dl))\n",
    "    print(voxel.shape, image.shape)\n",
    "    print(image.min(),image.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "346141d0-a123-40fa-bfef-bf892bf24de3",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c61bf49-0a33-4cd2-b14f-10167bd6ec34",
   "metadata": {},
   "source": [
    "### Low Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8bff747-fada-4ec8-a24e-05c1b0ccca95",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(blur_scale,blurry_recon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bd3517c-0cc4-4113-a34a-1d3a41518ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Low-level (blurry reconstruction) branch: a VAE provides the target latents and a\n",
    "# ConvNeXt provides the target features used by the training loop. Both are kept fixed.\n",
    "if blurry_recon:\n",
    "    from diffusers import AutoencoderKL\n",
    "    autoenc = AutoencoderKL(\n",
    "        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],\n",
    "        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],\n",
    "        block_out_channels=[128, 256, 512, 512],\n",
    "        layers_per_block=2,\n",
    "        sample_size=256,\n",
    "    )\n",
    "    # Load onto the CPU first so the file can be read on hosts without a GPU\n",
    "    # (device falls back to 'cpu'); the module is moved to `device` below.\n",
    "    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth', map_location='cpu')\n",
    "    autoenc.load_state_dict(ckpt)\n",
    "    \n",
    "    # Inference only: eval mode and no gradients.\n",
    "    autoenc.eval()\n",
    "    autoenc.requires_grad_(False)\n",
    "    autoenc.to(device)\n",
    "    utils.count_params(autoenc)\n",
    "    \n",
    "    from convnext import ConvnextXL\n",
    "    cnx = ConvnextXL(f'{cache_dir}/convnext_xlarge_alpha0.75_fullckpt.pth')\n",
    "    cnx.requires_grad_(False)\n",
    "    cnx.eval()\n",
    "    cnx.to(device)\n",
    "    \n",
    "    # Standard ImageNet per-channel mean/std, shaped (1,3,1,1) to broadcast over image batches\n",
    "    # before they are passed to the ConvNeXt.\n",
    "    mean = torch.tensor([0.485, 0.456, 0.406]).to(device).reshape(1,3,1,1)\n",
    "    std = torch.tensor([0.229, 0.224, 0.225]).to(device).reshape(1,3,1,1)\n",
    "    \n",
    "    # Augmentations applied to the images that produce the second (augmented) ConvNeXt view.\n",
    "    blur_augs = AugmentationSequential(\n",
    "        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),\n",
    "        kornia.augmentation.RandomGrayscale(p=0.1),\n",
    "        kornia.augmentation.RandomSolarize(p=0.1),\n",
    "        kornia.augmentation.RandomResizedCrop((224,224), scale=(.9,.9), ratio=(1,1), p=0.3),\n",
    "        data_keys=[\"input\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72cfc521-d9fb-4fed-ad4f-10c158f0df96",
   "metadata": {},
   "source": [
    "### Clipper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d8dc1f-89ed-4e65-aefc-fec500435e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image embedder whose outputs are used as clip_target in the training loop.\n",
    "# NOTE(review): the output is presumably (batch, clip_seq_dim, clip_emb_dim) = (B, 256, 1664) token embeddings — confirm.\n",
    "clip_img_embedder = FrozenOpenCLIPImageEmbedder(\n",
    "    arch=\"ViT-bigG-14\",\n",
    "    version=\"laion2b_s39b_b160k\",\n",
    "    output_tokens=True,\n",
    "    only_tokens=True,\n",
    ")\n",
    "clip_img_embedder.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6ece342-34ba-47c9-94d2-44ef949bec47",
   "metadata": {},
   "source": [
    "### High Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2098ba43-4ac5-4528-8ecf-64f24749898e",
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_seq_dim = 256\n",
    "clip_emb_dim = 1664"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f07d2ae-5694-422d-addd-50e0260e0323",
   "metadata": {},
   "outputs": [],
   "source": [
    "class fMRIModule(nn.Module):\n",
    "    \"\"\"Parameter-free container; later cells attach ridge, voxel2clip and diffusion_prior to the instance.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the input unchanged.\"\"\"\n",
    "        return x\n",
    "\n",
    "\n",
    "model = fMRIModule()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0054f5c4-6aa6-420a-bdb9-363599491dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    \"\"\"One linear map to `out_features` per entry of `input_sizes`.\n",
    "\n",
    "    `forward(x, subj_idx)` applies the map at position `subj_idx` and inserts\n",
    "    a length-1 axis at dim 1 of the result.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        # Each Linear stays wrapped in a Sequential so parameter names remain `linears.<i>.0.*`.\n",
    "        per_input = []\n",
    "        for n_in in input_sizes:\n",
    "            per_input.append(nn.Sequential(torch.nn.Linear(n_in, out_features)))\n",
    "        self.linears = torch.nn.ModuleList(per_input)\n",
    "\n",
    "    def forward(self, x, subj_idx):\n",
    "        projected = self.linears[subj_idx](x)\n",
    "        return projected.unsqueeze(1)\n",
    "\n",
    "\n",
    "model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b0d8e66-5f0b-4171-9fd5-12e821d6f6d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity check\n",
    "b = torch.randn((2,num_voxels_list[0])).to(device)\n",
    "print(b.shape)\n",
    "print(b.shape, model.ridge(b,0).shape, b[:,0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "374c6525-b652-4fed-a709-5eeb3829bc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from v2_models import BrainNetwork\n",
    "\n",
    "# in_dim only matters for the lin0 module, so any value works when lin0 is not needed.\n",
    "voxel2clip_kwargs = {\n",
    "    'in_dim': hidden_dim,\n",
    "    'seq_len': 1,\n",
    "    'h': hidden_dim,\n",
    "    'out_dim': clip_emb_dim * clip_seq_dim,\n",
    "    'clip_size': clip_emb_dim,\n",
    "    'clip_scale': 1,\n",
    "    'blurry_recon': blurry_recon,\n",
    "    'n_blocks': 4,\n",
    "}\n",
    "voxel2clip = BrainNetwork(**voxel2clip_kwargs)\n",
    "# Attach the network to the shared container module.\n",
    "model.voxel2clip = voxel2clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d62c2f-c06c-434a-b116-90ccb42d5979",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.count_params(model.voxel2clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3642f470-54f2-4258-950a-de0db2e930dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from v2_models import *\n",
    "\n",
    "# Prior transformer hyper-parameters; heads * dim_head must equal clip_emb_dim.\n",
    "out_dim = clip_emb_dim\n",
    "depth = 6\n",
    "dim_head = 52\n",
    "heads = clip_emb_dim // dim_head\n",
    "timesteps = 100\n",
    "\n",
    "prior_network = PriorNetwork(\n",
    "    dim=out_dim, depth=depth, dim_head=dim_head, heads=heads,\n",
    "    causal=False, num_tokens=clip_seq_dim, learned_query_mode=\"pos_emb\",\n",
    ")\n",
    "\n",
    "# Wrap the transformer in the diffusion prior and attach it to the shared container module.\n",
    "model.diffusion_prior = BrainDiffusionPrior(\n",
    "    net=prior_network, image_embed_dim=out_dim, condition_on_text_encodings=False,\n",
    "    timesteps=timesteps, cond_drop_prob=0.2, image_embed_scale=None,\n",
    ")\n",
    "\n",
    "# Parameter counts for the prior alone, then for the whole model.\n",
    "utils.count_params(model.diffusion_prior)\n",
    "utils.count_params(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63efa5ce-2ecc-4a60-8929-4d5a0d98d1b5",
   "metadata": {},
   "source": [
    "## Training Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d814d012-232e-4c71-a227-6c07b365ccf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimiser steps per epoch: batches of the common loader plus, when enabled, those of the train loader.\n",
    "if run_common:\n",
    "    num_iterations_per_epoch = len(common_dl) + (len(train_dl) if run_train else 0)\n",
    "else:\n",
    "    num_iterations_per_epoch = len(train_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c444770e-c9bf-4cde-9671-db9206af21d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iterations_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "417fa2d8-c864-43a3-96bc-f9daa8ef4dbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "\n",
    "def _is_no_decay(param_name):\n",
    "    \"\"\"True when the parameter name contains any of the no_decay markers.\"\"\"\n",
    "    return any(nd in param_name for nd in no_decay)\n",
    "\n",
    "def _decay_groups(module):\n",
    "    \"\"\"Split a module's parameters into a weight-decayed group and an undecayed group.\"\"\"\n",
    "    decayed = [p for n, p in module.named_parameters() if not _is_no_decay(n)]\n",
    "    undecayed = [p for n, p in module.named_parameters() if _is_no_decay(n)]\n",
    "    return [\n",
    "        {'params': decayed, 'weight_decay': 1e-2},\n",
    "        {'params': undecayed, 'weight_decay': 0.0},\n",
    "    ]\n",
    "\n",
    "# The ridge layer is decayed as a whole, biases included.\n",
    "opt_grouped_parameters = [\n",
    "    {'params': list(model.ridge.parameters()), 'weight_decay': 1e-2},\n",
    "]\n",
    "if v2c:\n",
    "    opt_grouped_parameters += _decay_groups(model.voxel2clip)\n",
    "if prior:\n",
    "    opt_grouped_parameters += _decay_groups(model.diffusion_prior)\n",
    "\n",
    "print(len(opt_grouped_parameters), lr_scheduler_type)\n",
    "\n",
    "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n",
    "\n",
    "# The scheduler is stepped once per batch, so its horizon is epochs * batches-per-epoch.\n",
    "if lr_scheduler_type == 'linear':\n",
    "    lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "        optimizer,\n",
    "        total_iters=int(np.floor(num_epochs*num_iterations_per_epoch)),\n",
    "        last_epoch=-1,\n",
    "    )\n",
    "elif lr_scheduler_type == 'cycle':\n",
    "    total_steps = int(np.floor(num_epochs*num_iterations_per_epoch))\n",
    "    print(\"total_steps\", total_steps)\n",
    "    # Warm-up covers the first 2 epochs' worth of steps (pct_start = 2/num_epochs).\n",
    "    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
    "        optimizer,\n",
    "        max_lr=max_lr,\n",
    "        total_steps=total_steps,\n",
    "        final_div_factor=1000,\n",
    "        last_epoch=-1,\n",
    "        pct_start=2/num_epochs,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b38598-a99d-4d5b-a338-1e7f598e97f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_ckpt(tag):\n",
    "    \"\"\"Write the current training state to `<outdir>/<tag>.pth`.\n",
    "\n",
    "    Reads the notebook globals (epoch, model, optimizer, lr_scheduler and the\n",
    "    loss / learning-rate histories). Any failure is printed instead of raised,\n",
    "    so a failed save does not abort training.\n",
    "    \"\"\"\n",
    "    ckpt_path = f'{outdir}/{tag}.pth'\n",
    "    print(f'saving {ckpt_path}', flush=True)\n",
    "    try:\n",
    "        snapshot = dict(\n",
    "            epoch=epoch,\n",
    "            model_state_dict=model.state_dict(),\n",
    "            optimizer_state_dict=optimizer.state_dict(),\n",
    "            lr_scheduler_state_dict=lr_scheduler.state_dict(),\n",
    "            train_losses=losses,\n",
    "            val_losses=val_losses,\n",
    "            lrs=lrs,\n",
    "        )\n",
    "        torch.save(snapshot, ckpt_path)\n",
    "    except Exception as e:\n",
    "        print(f\"Couldn't save due to {e}... moving on to prevent crashing.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bac5cb8d-1bd5-40be-8f39-a3ce2d1a6278",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_ckpt(tag,load_lr=True,load_optimizer=True,load_epoch=False,strict=True,outdir=outdir,multisubj_loading=False): \n",
    "    \"\"\"Restore training state from `<outdir>/<tag>.pth` into the notebook globals.\n",
    "\n",
    "    Args:\n",
    "        tag: checkpoint file stem.\n",
    "        load_lr: also restore the `lr_scheduler` state.\n",
    "        load_optimizer: also restore the `optimizer` state.\n",
    "        load_epoch: set the global `epoch` to the stored epoch.\n",
    "        strict: forwarded to `model.load_state_dict`.\n",
    "        outdir: folder holding the checkpoint (defaults to the run folder at definition time).\n",
    "        multisubj_loading: drop the stored weight of the first ridge layer before loading.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a requested entry is missing from the checkpoint.\n",
    "    \"\"\"\n",
    "    print(f\"\\n---loading {outdir}/{tag}.pth ckpt---\\n\")\n",
    "    # Load to CPU; load_state_dict copies the values into the live modules.\n",
    "    checkpoint = torch.load(f'{outdir}/{tag}.pth', map_location='cpu')\n",
    "    state_dict = checkpoint['model_state_dict']\n",
    "    if multisubj_loading: # remove incompatible ridge layer that will otherwise error\n",
    "        state_dict.pop('ridge.linears.0.weight',None)\n",
    "    model.load_state_dict(state_dict, strict=strict)\n",
    "    if load_epoch:\n",
    "        globals()[\"epoch\"] = checkpoint['epoch']\n",
    "        print(\"Epoch\",epoch)\n",
    "    if load_optimizer:\n",
    "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "    if load_lr:\n",
    "        # save_ckpt stores the scheduler under 'lr_scheduler_state_dict'; the shorter\n",
    "        # 'lr_scheduler' key is still accepted for files that were written with it.\n",
    "        if 'lr_scheduler_state_dict' in checkpoint:\n",
    "            sched_key = 'lr_scheduler_state_dict'\n",
    "        else:\n",
    "            sched_key = 'lr_scheduler'\n",
    "        lr_scheduler.load_state_dict(checkpoint[sched_key])\n",
    "    # Release the loaded copy as soon as it has been consumed.\n",
    "    del checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "171cc26d-ccc9-45ac-8839-1d7688121f0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh training state: epoch counter plus the histories that save_ckpt stores (losses, val_losses, lrs).\n",
    "epoch = 0\n",
    "losses, val_losses, lrs = [], [], []\n",
    "nce_losses, val_nce_losses = [], []\n",
    "sim_losses, val_sim_losses = [], []\n",
    "best_val_loss = 1e9\n",
    "# Per-epoch temperatures for the soft CLIP loss, one entry per epoch after the mixup phase;\n",
    "# the training loop indexes it with (epoch - int(mixup_pct*num_epochs)).\n",
    "# NOTE(review): presumably anneals from 0.004 to 0.0075 — confirm against utils.cosine_anneal.\n",
    "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
    "# Weight applied to the diffusion-prior loss when it is added to the total loss.\n",
    "if hidden:\n",
    "    prior_mult = 30\n",
    "else:\n",
    "    prior_mult = .03\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0463692-7d99-4b69-885b-b864739d180f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1326d968-252f-47e4-ac03-2c32027e27d1",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eccba97-3991-4e08-b86c-9eb0ea6f729e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(use_image_aug,mixup_pct, ckpt_interval,prior_mult)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c949f5da-940c-4f92-a04f-dedb6de6a2f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Model: {model_name}\")\n",
    "# Counters start from zero here; the resume cell overrides epoch when a checkpoint is loaded.\n",
    "epoch = best_epoch = 0\n",
    "print(epoch,ckpt_interval,mixup_pct)\n",
    "# Upper bound on re-tries of the prior forward pass when its loss comes back NaN.\n",
    "max_retries = 5\n",
    "# Running count of NaN prior losses encountered.\n",
    "counter = 0\n",
    "mse, l1 = nn.MSELoss(), nn.L1Loss()\n",
    "\n",
    "def scale_tau(tau_ref: float, B: int, ref_B: int = 16, alpha: float = 0.5):\n",
    "    \"\"\"Scale the reference temperature `tau_ref` by `(B / ref_B) ** alpha`.\n",
    "\n",
    "    alpha=0.5 (sqrt) is a good default; alpha=1.0 is stronger.\n",
    "    \"\"\"\n",
    "    batch_ratio = B / ref_B\n",
    "    return tau_ref * batch_ratio ** alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d2e1eb-eac9-4e2a-b285-0c50c4041bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally restore model / optimizer / scheduler state and the epoch counter from a saved checkpoint.\n",
    "if resume_from_ckpt:\n",
    "    print(f\"Loading {checkpoint_dir}/{checkpoint_tag}.pth\")\n",
    "    checkpoint = torch.load(f'{checkpoint_dir}/{checkpoint_tag}.pth', map_location='cpu')\n",
    "    model.load_state_dict(checkpoint['model_state_dict'], strict=True)\n",
    "    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])\n",
    "    # NOTE(review): training restarts at the stored epoch. If the checkpoint is written after\n",
    "    # that epoch has finished, one epoch is repeated and a OneCycleLR schedule may run past\n",
    "    # its total_steps near the end — confirm against where save_ckpt is called.\n",
    "    epoch = checkpoint['epoch']\n",
    "    # Continue the histories stored by save_ckpt so later checkpoints keep the full curves.\n",
    "    losses = checkpoint.get('train_losses', losses)\n",
    "    val_losses = checkpoint.get('val_losses', val_losses)\n",
    "    lrs = checkpoint.get('lrs', lrs)\n",
    "    # Drop the CPU copy of the weights and optimizer state; it is no longer needed once loaded.\n",
    "    del checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ed79918-d139-4db6-a280-5430de1ed629",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
    "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
    "for epoch in progress_bar:\n",
    "    model.train()\n",
    "\n",
    "    sims_base = 0.\n",
    "    val_sims_base = 0.\n",
    "    recon_cossim=0.\n",
    "    test_recon_cossim=0.\n",
    "    fwd_percent_correct = 0.\n",
    "    bwd_percent_correct = 0.\n",
    "    val_fwd_percent_correct = 0.\n",
    "    val_bwd_percent_correct = 0.\n",
    "    test_loss_cossim_total = 0.\n",
    "    loss_cossim_total = 0.\n",
    "    loss_nce_sum = 0.\n",
    "    loss_prior_sum = 0.\n",
    "    val_loss_nce_sum = 0.\n",
    "    val_loss_prior_sum = 0.\n",
    "    loss_blurry_total = 0.\n",
    "    loss_blurry_cont_total = 0.\n",
    "\n",
    "    common_i=0\n",
    "    train_i=0\n",
    "    if run_common:\n",
    "        for common_i, (voxel, image) in enumerate(common_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                if common_i==3:\n",
    "                    print(\"Inside common dl\")\n",
    "                optimizer.zero_grad()\n",
    "                loss=0.\n",
    "                voxel = voxel.to(device)\n",
    "                voxel = torch.mean(voxel,axis=1).float()\n",
    "                #print(voxel.shape)\n",
    "                voxel = voxel.to(device)\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                else:\n",
    "                    #Make sure this is resize only\n",
    "                    image = img_augment2(image)\n",
    "                image = torch.tensor(image, dtype=data_type).to(device)\n",
    "                clip_target = clip_img_embedder(image).to(device)\n",
    "        \n",
    "                voxel_ridge = model.ridge(voxel, 0)\n",
    "                #print(voxel_ridge.shape, adapter_gt.shape)\n",
    "                clip_voxels, clip_voxels_proj, blurry_image_enc_ = model.voxel2clip(voxel_ridge)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "                if prior:\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            # If loss_prior is NaN, clean up before retry\n",
    "                            if torch.isnan(loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if optimizer:\n",
    "                                    optimizer.zero_grad()\n",
    "                                # Clear the memory cache if possible\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "                    \n",
    "                if blurry_recon:     \n",
    "                    image_enc_pred, transformer_feats = blurry_image_enc_\n",
    "    \n",
    "                    image_enc = autoenc.encode(2*image-1).latent_dist.mode() * 0.18215\n",
    "                    image_enc = image_enc.float()\n",
    "                    loss_blurry = l1(image_enc_pred, image_enc)\n",
    "                    loss_blurry_total += loss_blurry.item()\n",
    "    \n",
    "                    if epoch < int(mixup_pct * num_epochs):\n",
    "                        image_enc_shuf = image_enc[perm]\n",
    "                        betas_shape = [-1] + [1]*(len(image_enc.shape)-1)\n",
    "                        image_enc[select] = image_enc[select] * betas[select].reshape(*betas_shape) + \\\n",
    "                            image_enc_shuf[select] * (1 - betas[select]).reshape(*betas_shape)\n",
    "    \n",
    "                    image_norm = (image - mean)/std\n",
    "                    with torch.cuda.amp.autocast(enabled=False):\n",
    "                        image_aug = blur_augs(image.float())\n",
    "                    image_aug = image_aug.clamp_(0, 1)\n",
    "                    image_aug = (image_aug - mean)/std                    \n",
    "                    \n",
    "                    _, cnx_embeds = cnx(image_norm)\n",
    "                    _, cnx_aug_embeds = cnx(image_aug)\n",
    "    \n",
    "                    cont_loss = utils.soft_cont_loss(\n",
    "                        nn.functional.normalize(transformer_feats.reshape(-1, transformer_feats.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_aug_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        temp=0.2)\n",
    "                    loss_blurry_cont_total += cont_loss.item()\n",
    "    \n",
    "                    loss += (loss_blurry + 0.0*cont_loss) * blur_scale #/.18215\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=perm, betas=betas, select=select)\n",
    "                else:\n",
    "                    B = clip_voxels_norm.size(0)\n",
    "                    tau_ref = float(soft_loss_temps[epoch -int(mixup_pct*num_epochs)].item())\n",
    "                    epoch_temp = scale_tau(tau_ref, B, ref_B=16, alpha=0.5)\n",
    "                    loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "                    \n",
    "                if prior and v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = loss_nce + (prior_mult * loss_prior)\n",
    "                elif v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss = loss_nce\n",
    "                elif prior:\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = prior_mult * loss_prior\n",
    "                utils.check_loss(loss)\n",
    "\n",
    "                #accelerator.backward(loss)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "        \n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "        \n",
    "                # gather batches across multi-gpu if there's multiple\n",
    "                # clip_voxel_gather = accelerator.gather(clip_voxels_norm.view(len(voxel),-1).contiguous())\n",
    "                # clip_target_gather = accelerator.gather(clip_target_norm.view(len(voxel),-1).contiguous())\n",
    "        \n",
    "                sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                # forward and backward top 1 accuracy\n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "        \n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "    # Training pass over `train_dl` for the current epoch; skipped unless `run_train` is set.\n",
    "    # Per step: take one slice of `voxel` along dim 1, optionally mix samples via `utils.mixco`,\n",
    "    # embed the image with `clip_img_embedder`, run `model.ridge` + `model.voxel2clip`, add up the\n",
    "    # enabled loss terms (blurry, contrastive, prior), then backprop and step the optimizer.\n",
    "    if run_train:\n",
    "        for train_i, (voxel, image) in enumerate(train_dl):\n",
    "            with torch.cuda.amp.autocast(dtype=data_type):\n",
    "                optimizer.zero_grad()\n",
    "                # Per-step loss accumulator; every enabled term below is added onto it.\n",
    "                loss=0.\n",
    "                voxel = voxel.to(device)\n",
    "                # Cycles 0, 1, 2 with the step index; used below to index dim 1 of `voxel`.\n",
    "                repeat_index = train_i % 3\n",
    "                if use_image_aug:\n",
    "                    image = img_augment(image)\n",
    "                else:\n",
    "                    image = img_augment2(image)\n",
    "                voxel = voxel[:,repeat_index].float()\n",
    "                # Sample mixing is only active for the first `mixup_pct` fraction of the epochs.\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    voxel, perm, betas, select = utils.mixco(voxel)\n",
    "                image = torch.tensor(image, dtype=data_type).to(device)\n",
    "                clip_target = clip_img_embedder(image).to(device)\n",
    "\n",
    "                #voxel_ridge = model.ridge(voxel, subj-1)\n",
    "                # The ridge index is fixed to 0 (the subject-indexed call above is disabled).\n",
    "                voxel_ridge = model.ridge(voxel, 0)\n",
    "                clip_voxels, clip_voxels_proj, blurry_image_enc_ = model.voxel2clip(voxel_ridge)\n",
    "\n",
    "                #clip_voxels, clip_voxels_proj = diffusion_prior.module.voxel2clip(voxel) if distributed else diffusion_prior.voxel2clip(voxel)\n",
    "                if hidden:\n",
    "                    clip_voxels = clip_voxels.view(len(voxel_ridge),-1,clip_emb_dim)\n",
    "\n",
    "                if prior:\n",
    "                    # Re-run the prior forward pass up to `max_retries` times when it returns NaN.\n",
    "                    for attempt in range(max_retries):\n",
    "                        try:\n",
    "                            # Forward pass\n",
    "                            loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                            #aligned_clip_voxels /= diffusion_prior.module.image_embed_scale if distributed else diffusion_prior.image_embed_scale\n",
    "\n",
    "                            # If loss_prior is NaN, clean up before retry\n",
    "                            if torch.isnan(loss_prior).any():\n",
    "                                counter+=1\n",
    "                                raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                            recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                            # If loss_prior is not NaN, break the loop and continue\n",
    "                            break\n",
    "                        except ValueError as e:\n",
    "                            if attempt < max_retries - 1:\n",
    "                                print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                                # Clear gradients if they have been calculated\n",
    "                                if optimizer:\n",
    "                                    optimizer.zero_grad()\n",
    "                                # Clear the memory cache if possible\n",
    "                                if torch.cuda.is_available():\n",
    "                                    torch.cuda.empty_cache()\n",
    "                                # Explicitly collect garbage to remove intermediate tensors\n",
    "                                gc.collect()\n",
    "                            else:\n",
    "                                raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                else:\n",
    "                    aligned_clip_voxels = clip_voxels\n",
    "\n",
    "                if blurry_recon:\n",
    "                    image_enc_pred, transformer_feats = blurry_image_enc_\n",
    "\n",
    "                    image_enc = autoenc.encode(2*image-1).latent_dist.mode() * 0.18215\n",
    "                    image_enc = image_enc.float()\n",
    "                    loss_blurry = l1(image_enc_pred, image_enc)\n",
    "                    loss_blurry_total += loss_blurry.item()\n",
    "\n",
    "                    # NOTE(review): `loss_blurry` is computed above, before this mixing, and\n",
    "                    # `image_enc` is not read again in this loop -- confirm the ordering is intended.\n",
    "                    if epoch < int(mixup_pct * num_epochs):\n",
    "                        image_enc_shuf = image_enc[perm]\n",
    "                        betas_shape = [-1] + [1]*(len(image_enc.shape)-1)\n",
    "                        image_enc[select] = image_enc[select] * betas[select].reshape(*betas_shape) + \\\n",
    "                            image_enc_shuf[select] * (1 - betas[select]).reshape(*betas_shape)\n",
    "\n",
    "                    image_norm = (image - mean)/std\n",
    "                    # The blur augmentation is run with autocast disabled, on a float32 copy.\n",
    "                    with torch.cuda.amp.autocast(enabled=False):\n",
    "                        image_aug = blur_augs(image.float())\n",
    "                    image_aug = image_aug.clamp_(0, 1)\n",
    "                    image_aug = (image_aug - mean)/std\n",
    "\n",
    "                    _, cnx_embeds = cnx(image_norm)\n",
    "                    _, cnx_aug_embeds = cnx(image_aug)\n",
    "\n",
    "                    cont_loss = utils.soft_cont_loss(\n",
    "                        nn.functional.normalize(transformer_feats.reshape(-1, transformer_feats.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        nn.functional.normalize(cnx_aug_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),\n",
    "                        temp=0.2)\n",
    "                    loss_blurry_cont_total += cont_loss.item()\n",
    "\n",
    "                    # `cont_loss` is logged above but weighted by 0.0 here, so only `loss_blurry`\n",
    "                    # contributes to the optimised loss.\n",
    "                    loss += (loss_blurry + 0.0*cont_loss) * blur_scale #/.18215\n",
    "\n",
    "                clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "                if epoch < int(mixup_pct * num_epochs):\n",
    "                    loss_nce = utils.mixco_nce(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=.006,\n",
    "                        perm=perm, betas=betas, select=select)\n",
    "                else:\n",
    "                    # Temperature for this step: the per-epoch schedule value rescaled by\n",
    "                    # `scale_tau` using the current batch size.\n",
    "                    B = clip_voxels_norm.size(0)\n",
    "                    tau_ref = float(soft_loss_temps[epoch -int(mixup_pct*num_epochs)].item())\n",
    "                    epoch_temp = scale_tau(tau_ref, B, ref_B=16, alpha=0.5)\n",
    "                    loss_nce = utils.soft_clip_loss(\n",
    "                        clip_voxels_norm,\n",
    "                        clip_target_norm,\n",
    "                        temp=epoch_temp)\n",
    "\n",
    "                # Add the contrastive and prior terms ON TOP of what `loss` already holds (the\n",
    "                # blurry-reconstruction term, when enabled). Assigning `loss = ...` here would\n",
    "                # overwrite that term and silently drop it from the gradient.\n",
    "                # NOTE(review): the loop preceding this one ends with the plain-assignment\n",
    "                # pattern -- check whether it needs the same change.\n",
    "                if v2c:\n",
    "                    loss_nce_sum += loss_nce.item()\n",
    "                    loss = loss + loss_nce\n",
    "                if prior:\n",
    "                    loss_prior_sum += loss_prior.item()\n",
    "                    loss = loss + (prior_mult * loss_prior)\n",
    "                utils.check_loss(loss)\n",
    "\n",
    "                # NOTE(review): backward/step run inside the autocast context and without a\n",
    "                # GradScaler. PyTorch's AMP docs recommend calling backward outside autocast and\n",
    "                # using gradient scaling when the autocast dtype is float16 -- confirm intent.\n",
    "                #accelerator.backward(loss)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                losses.append(loss.item())\n",
    "                lrs.append(optimizer.param_groups[0]['lr'])\n",
    "\n",
    "                sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "                # forward and backward top 1 accuracy\n",
    "                labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "                fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "                bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "\n",
    "                # The scheduler is stepped once per batch, not once per epoch.\n",
    "                if lr_scheduler_type is not None:\n",
    "                    lr_scheduler.step()\n",
    "    # Validation pass over `test_dl`: same forward path as training, but after `model.eval()`\n",
    "    # and without gradient tracking; accumulates the `val_*` running metrics for this epoch.\n",
    "    model.eval()\n",
    "    for val_i, (voxel, image) in enumerate(test_dl):\n",
    "        with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):\n",
    "            # Collapse dim 1 of the voxel batch by averaging, then move it to the device.\n",
    "            voxel = torch.mean(voxel,axis=1).float().to(device)\n",
    "            image = img_augment(image) if use_image_aug else img_augment2(image)\n",
    "            image = torch.tensor(image, dtype=data_type).to(device)\n",
    "            clip_target = clip_img_embedder(image).to(device)\n",
    "\n",
    "            # Ridge index is fixed to 0 (a subject-indexed variant was previously commented out here).\n",
    "            voxel_ridge = model.ridge(voxel, 0)\n",
    "            clip_voxels, clip_voxels_proj, blurry_image_enc_ = model.voxel2clip(voxel_ridge)\n",
    "            if hidden:\n",
    "                clip_voxels = clip_voxels.view(len(voxel),-1,clip_emb_dim)\n",
    "\n",
    "            if not prior:\n",
    "                aligned_clip_voxels = clip_voxels\n",
    "            else:\n",
    "                # Repeat the prior forward pass (at most `max_retries` times) while it yields NaN.\n",
    "                for attempt in range(max_retries):\n",
    "                    try:\n",
    "                        val_loss_prior, aligned_clip_voxels = model.diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)\n",
    "                        if torch.isnan(val_loss_prior).any():\n",
    "                            counter+=1\n",
    "                            raise ValueError(\"Encountered NaN in loss_prior\")\n",
    "                        test_recon_cossim += nn.functional.cosine_similarity(aligned_clip_voxels, clip_target).mean().item()\n",
    "                        # A finite loss ends the retry loop.\n",
    "                        break\n",
    "                    except ValueError as e:\n",
    "                        # Out of attempts: surface the failure instead of retrying.\n",
    "                        if attempt >= max_retries - 1:\n",
    "                            raise RuntimeError(f\"loss_prior remains NaN after {max_retries} attempts.\") from e\n",
    "                        print(f\"{e} on attempt {attempt+1}, clearing gradients and retrying...\")\n",
    "                        # Free cached GPU memory and collect garbage before the next attempt.\n",
    "                        if torch.cuda.is_available():\n",
    "                            torch.cuda.empty_cache()\n",
    "                        gc.collect()\n",
    "\n",
    "            clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1)\n",
    "            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
    "\n",
    "            if epoch >= int(mixup_pct * num_epochs):\n",
    "                # NOTE(review): `epoch_temp` is not assigned in this validation loop; it is\n",
    "                # whatever an earlier training step left behind -- confirm that is intended.\n",
    "                val_loss_nce = utils.soft_clip_loss(\n",
    "                    clip_voxels_norm,\n",
    "                    clip_target_norm,\n",
    "                    temp=epoch_temp)\n",
    "            else:\n",
    "                val_loss_nce = utils.mixco_nce(\n",
    "                    clip_voxels_norm,\n",
    "                    clip_target_norm,\n",
    "                    temp=.006,\n",
    "                    perm=None, betas=None, select=None)\n",
    "\n",
    "            # Running sums for the enabled loss terms.\n",
    "            if v2c:\n",
    "                val_loss_nce_sum += val_loss_nce.item()\n",
    "            if prior:\n",
    "                val_loss_prior_sum += val_loss_prior.item()\n",
    "\n",
    "            # Total validation loss for this batch, built from the enabled terms.\n",
    "            if prior and v2c:\n",
    "                val_loss = val_loss_nce + (prior_mult * val_loss_prior)\n",
    "            elif v2c:\n",
    "                val_loss = val_loss_nce\n",
    "            elif prior:\n",
    "                val_loss = prior_mult * val_loss_prior\n",
    "            utils.check_loss(val_loss)\n",
    "\n",
    "            val_losses.append(val_loss.item())\n",
    "\n",
    "            val_sims_base += nn.functional.cosine_similarity(clip_target_norm,clip_voxels_norm).mean().item()\n",
    "\n",
    "            # Top-1 accuracy in both directions (voxel->target and target->voxel).\n",
    "            labels = torch.arange(len(clip_target_norm)).to(device)\n",
    "            val_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm,clip_target_norm), labels, k=1)\n",
    "            val_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
    "\n",
    "    # Epoch summary (rank 0 only): track the best validation loss, refresh the progress-bar\n",
    "    # metrics, and checkpoint every `ckpt_interval` epochs.\n",
    "    if local_rank==0:\n",
    "        # Number of validation batches seen this epoch.\n",
    "        _n_val = val_i + 1\n",
    "        val_loss = np.mean(val_losses[-_n_val:])\n",
    "        if val_loss < best_val_loss:\n",
    "            best_epoch, best_val_loss = epoch, val_loss\n",
    "        if utils.is_interactive():\n",
    "            clear_output(wait=True)\n",
    "\n",
    "        # Divisor used to average the training running sums.\n",
    "        # NOTE(review): `common_i` is set outside this section -- presumably the index of the\n",
    "        # loop that precedes the `train_dl` loop; confirm `train_i + common_i + 1` matches the\n",
    "        # real number of training steps taken this epoch.\n",
    "        _n_train = train_i + common_i + 1\n",
    "\n",
    "        logs = {\n",
    "            \"train/loss\": np.mean(losses[-_n_train:]),\n",
    "            \"val/loss\": val_loss,\n",
    "            \"train/lr\": lrs[-1],\n",
    "            \"train/num_steps\": len(losses),\n",
    "            \"val/num_steps\": len(val_losses),\n",
    "            \"train/cosine_sim_base\": sims_base / _n_train,\n",
    "            \"val/cosine_sim_base\": val_sims_base / _n_val,\n",
    "            \"train/cosine_sim_prior\": recon_cossim / _n_train,\n",
    "            \"val/cosine_sim_prior\": test_recon_cossim / _n_val,\n",
    "            \"train/loss_blurry_total\": loss_blurry_total / _n_train,\n",
    "            \"train/loss_blurry_cont_total\": loss_blurry_cont_total / _n_train,\n",
    "            \"train/fwd_pct_correct\": fwd_percent_correct / _n_train,\n",
    "            \"train/bwd_pct_correct\": bwd_percent_correct / _n_train,\n",
    "            \"val/val_fwd_pct_correct\": val_fwd_percent_correct / _n_val,\n",
    "            \"val/val_bwd_pct_correct\": val_bwd_percent_correct / _n_val,\n",
    "            \"train/loss_nce\": loss_nce_sum / _n_train,\n",
    "            \"train/loss_prior\": loss_prior_sum / _n_train,\n",
    "            \"val/loss_nce\": val_loss_nce_sum / _n_val,\n",
    "            \"val/loss_prior\": val_loss_prior_sum / _n_val,\n",
    "        }\n",
    "        progress_bar.set_postfix(**logs)\n",
    "\n",
    "        # Periodic checkpoint, tagged with the epoch number (a 'last' checkpoint is not saved).\n",
    "        if epoch % ckpt_interval == 0:\n",
    "            save_ckpt(f'mid_{epoch}')\n",
    "\n",
    "# End of the run: print a banner, then the most recently computed `val_loss` alongside the\n",
    "# best validation loss and the epoch at which it was recorded.\n",
    "print(\"\\n===Finished!===\\n\")\n",
    "print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f} at epoch: {best_epoch}')\n",
    "# When `utils.is_interactive()` is false, terminate the process here with exit status 0.\n",
    "if not utils.is_interactive():\n",
    "    sys.exit(0)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd878354-1a6e-4129-8e49-203c16b98ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shell out to `ls` on the path held in the Python variable `outdir`.\n",
    "!ls $outdir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d1b750-4efc-4a9e-bee3-8f850b1a9330",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual call to the same `save_ckpt` helper the training loop uses, with the hard-coded\n",
    "# tag 'mid_300'. NOTE(review): the tag is not derived from `epoch` -- confirm it is intended.\n",
    "save_ckpt('mid_300')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbdb4874-364b-44c8-8662-d77ead80f1e6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
