{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9348c9df-ccc8-49e7-8bc1-83e72fe76628",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from tqdm import tqdm\n",
    "from datetime import datetime\n",
    "import webdataset as wds\n",
    "import PIL\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from dataloaders import ImageVoxelDataset,ImageVoxelAdapterDataset\n",
    "\n",
    "import argparse\n",
    "\n",
    "# Prefer the first CUDA device; otherwise run on CPU.\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
    "print(\"device:\",device)\n",
    "\n",
    "# Rank of this process; a later cell only prints the voxel2clip parameter count when it is 0.\n",
    "local_rank = 0\n",
    "\n",
    "# Project-local helpers and model definitions.\n",
    "import utils\n",
    "from models import Clipper, OpenClipper, BrainNetwork, BrainDiffusionPrior, VersatileDiffusionPriorNetwork\n",
    "\n",
    "# Enable autoreload only when utils.is_interactive() reports an interactive session.\n",
    "if utils.is_interactive():\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2\n",
    "\n",
    "# Seed through the project helper; `seed` is also passed to the reconstruction call later on.\n",
    "seed = 42\n",
    "utils.seed_everything(seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a29ab7e-be72-4947-8100-2c50f6a72ec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name='subj1_nl_sclip_basictest2_finetune_s5_aamax_100'\n",
    "save_name = 'subj1_nl_sclip_basictest2_finetune_s5_aamax_100_hl'\n",
    "data_path='data/new_dl/'\n",
    "subj=5\n",
    "ref_sub=1\n",
    "checkpoint_tag = 'mid_200'\n",
    "batch_size=1\n",
    "hidden=True\n",
    "norm_embs=True\n",
    "prior=True\n",
    "v2c=True\n",
    "seed=42\n",
    "use_projector=True\n",
    "hidden_dim=4096\n",
    "vd_cache_dir = 'data/fmri/cache'\n",
    "recons_per_sample=8\n",
    "img2img_strength=1\n",
    "split_id=1\n",
    "img_variations = False\n",
    "clip_seq_dim = 257\n",
    "clip_emb_dim = 768\n",
    "hidden_dim = 4096\n",
    "out_dim = clip_seq_dim * clip_emb_dim\n",
    "outdir = f'train_logs/{model_name}'\n",
    "retrieve = False\n",
    "plotting = False\n",
    "saving = True\n",
    "verbose = False\n",
    "imsize = 512\n",
    "og=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f45c44d1-b5f5-41c7-a7b7-27a62fb24918",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suffix appended to the 'custom_split' directory name: split 1 has no suffix,\n",
    "# every other split uses its id.\n",
    "split_val = \"\" if split_id==1 else split_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76451af6-eb18-462d-b92f-e62fd604d553",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel count for each supported subject id.\n",
    "NUM_VOXELS_BY_SUBJ = {\n",
    "    1: 15724,\n",
    "    2: 14278,\n",
    "    3: 15226,\n",
    "    4: 13153,\n",
    "    5: 13039,\n",
    "    6: 17907,\n",
    "    7: 12682,\n",
    "    8: 14386,\n",
    "}\n",
    "\n",
    "# Fail loudly on an unknown subject instead of leaving num_voxels undefined\n",
    "# (or silently stale from an earlier run of this cell).\n",
    "if subj not in NUM_VOXELS_BY_SUBJ:\n",
    "    raise ValueError(f'Unknown subj {subj!r}; expected one of {sorted(NUM_VOXELS_BY_SUBJ)}')\n",
    "\n",
    "num_voxels = NUM_VOXELS_BY_SUBJ[subj]\n",
    "print(\"subj\",subj,\"num_voxels\",num_voxels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0063117-c2a1-43f3-9fa6-ea2ee14a588d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input sizes for the ridge heads, one entry per head:\n",
    "# 15724 is the subject-1 voxel count (matches ref_sub=1), followed by the evaluated subject.\n",
    "num_voxels_list = [\n",
    "    15724,\n",
    "    num_voxels,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9acb28e2-152d-42cd-b208-09a8b0b16774",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test DL\n",
    "batch_size=1\n",
    "\n",
    "# Directory of the held-out test split, built once from the configured data root.\n",
    "test_dir = os.path.join(data_path, f'subj0{subj}', 'train', f'custom_split{split_val}', 'test')\n",
    "#Uncomment this line if you want to run on things data instead\n",
    "# test_dir = f'data/clean_things_data/dl/subj0{subj}/test'\n",
    "feature_file = os.path.join(test_dir, 'betas.pt')\n",
    "image_file = os.path.join(test_dir, 'images.pt')\n",
    "\n",
    "test_dataset = ImageVoxelDataset(feature_file, image_file)\n",
    "# shuffle=False keeps samples in dataset order.\n",
    "val_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "num_val=len(test_dataset)\n",
    "\n",
    "# Sanity check: pull one batch and show its shapes and the image value range.\n",
    "voxel,image = next(iter(val_dl))\n",
    "print(voxel.shape, image.shape)\n",
    "print(image.max(), image.min())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f8bd9fd-6fa3-4cac-8500-6d26362147d1",
   "metadata": {},
   "source": [
    "## Load the Versatile Diffusion (VD) pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b866017-cab0-48d7-ac20-56c14b4668a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Creating versatile diffusion reconstruction pipeline...')\n",
    "from diffusers import VersatileDiffusionDualGuidedPipeline, UniPCMultistepScheduler\n",
    "from diffusers.models import DualTransformer2DModel\n",
    "\n",
    "# from_pretrained takes the repo id positionally and every option by keyword, so the cache\n",
    "# directory must be passed as cache_dir=...; one call both reuses an existing cache and\n",
    "# downloads into it when needed. Real failures now surface instead of being swallowed.\n",
    "vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(\n",
    "    \"shi-labs/versatile-diffusion\",\n",
    "    cache_dir=vd_cache_dir,\n",
    ").to(device).to(torch.float16)\n",
    "\n",
    "# Inference only: put the UNet and VAE in eval mode and turn off their gradients.\n",
    "vd_pipe.image_unet.eval()\n",
    "vd_pipe.vae.eval()\n",
    "vd_pipe.image_unet.requires_grad_(False)\n",
    "vd_pipe.vae.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16adc4bf-de1c-4d7f-9f91-b89ea1c989cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot folder of the cached Versatile Diffusion repo; the scheduler config is read from it.\n",
    "# Derived from vd_cache_dir so the two paths cannot drift apart.\n",
    "scheduler_dir = os.path.join(\n",
    "    vd_cache_dir,\n",
    "    'models--shi-labs--versatile-diffusion',\n",
    "    'snapshots',\n",
    "    '2926f8e11ea526b562cd592b099fcf9c2985d0b7',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c14244-1717-4083-ab78-42679f899064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the pipeline scheduler with UniPC, configured from the local snapshot folder.\n",
    "vd_pipe.scheduler = UniPCMultistepScheduler.from_pretrained(scheduler_dir, subfolder=\"scheduler\")\n",
    "num_inference_steps = 20\n",
    "\n",
    "# Set weighting of Dual-Guidance\n",
    "text_image_ratio = .0 # .5 means equally weight text and image, 0 means use only image\n",
    "\n",
    "# (condition length, transformer index) for condition slot 0 (text) and slot 1 (image):\n",
    "# text uses the second transformer, image uses the first.\n",
    "condition_setup = ((77, 1), (257, 0))\n",
    "\n",
    "# Loop variables are named so that no builtin (e.g. `type`) gets shadowed in the notebook namespace.\n",
    "for unet_module in vd_pipe.image_unet.modules():\n",
    "    if not isinstance(unet_module, DualTransformer2DModel):\n",
    "        continue\n",
    "    unet_module.mix_ratio = text_image_ratio\n",
    "    for cond_idx, (cond_len, transformer_idx) in enumerate(condition_setup):\n",
    "        unet_module.condition_lengths[cond_idx] = cond_len\n",
    "        unet_module.transformer_index_for_condition[cond_idx] = transformer_idx\n",
    "\n",
    "# Short aliases passed to the reconstruction call.\n",
    "unet = vd_pipe.image_unet\n",
    "vae = vd_pipe.vae\n",
    "noise_scheduler = vd_pipe.scheduler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94c6ba35-68fd-4ccd-9b28-a2cad89523d3",
   "metadata": {},
   "source": [
    "## Build the model (ridge → voxel2clip → diffusion prior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36d37dad-5ba3-4aa7-af34-ea74d60fc15a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class fMRIModule(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(fMRIModule, self).__init__()\n",
    "    def forward(self, x):\n",
    "        return x\n",
    "        \n",
    "model = fMRIModule()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9b2f43c-d2b0-4c99-9aed-2e1a069d8444",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RidgeRegression(torch.nn.Module):\n",
    "    \"\"\"One input head per entry of `input_sizes`, each mapping to `out_features`.\n",
    "\n",
    "    Every head is Linear -> LayerNorm -> GELU. Regularization is not applied here:\n",
    "    make sure to add weight_decay when initializing the optimizer to enable it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_sizes, out_features):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        heads = []\n",
    "        for input_size in input_sizes:\n",
    "            heads.append(\n",
    "                nn.Sequential(\n",
    "                    torch.nn.Linear(input_size, out_features),\n",
    "                    nn.LayerNorm(out_features),\n",
    "                    nn.GELU(),\n",
    "                    #nn.Dropout(0.5)\n",
    "                )\n",
    "            )\n",
    "        self.linears = torch.nn.ModuleList(heads)\n",
    "\n",
    "    def forward(self, x, subj_idx):\n",
    "        \"\"\"Run `x` through head number `subj_idx` and insert a singleton dimension at position 1.\"\"\"\n",
    "        head = self.linears[subj_idx]\n",
    "        return head(x).unsqueeze(1)\n",
    "\n",
    "\n",
    "model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a203e5f-b3d8-4615-82d3-3f53beb7fe95",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Set in_dim as anything if you dont care about the lin0 module\n",
    "voxel2clip_kwargs = {\n",
    "    'in_dim': hidden_dim,\n",
    "    'out_dim': clip_emb_dim*clip_seq_dim,\n",
    "    'clip_size': clip_emb_dim,\n",
    "    'n_blocks': 4,\n",
    "    'ext_ridge': True,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272a1910-95e5-4071-b8b6-af2236ed2b5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the voxel-to-CLIP network from the keyword arguments assembled above.\n",
    "voxel2clip = BrainNetwork(**voxel2clip_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb88c86-cc2a-449b-860e-a9d404e6815e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the network under the attribute name `voxel2clip`; the name matters because\n",
    "# the checkpoint is later loaded into `model` with strict=True.\n",
    "model.voxel2clip = voxel2clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a7bcc8d-395f-4998-bcd2-1e077c461244",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"params of voxel2clip:\")\n",
    "if local_rank==0:\n",
    "    utils.count_params(voxel2clip)\n",
    "\n",
    "# setup prior network\n",
    "# out_dim is reassigned here: from now on it is the per-token CLIP width.\n",
    "out_dim = clip_emb_dim\n",
    "depth = 6\n",
    "dim_head = 64\n",
    "heads = clip_emb_dim // dim_head # heads * dim_head = 12 * 64 = 768\n",
    "\n",
    "guidance_scale = 3.5\n",
    "timesteps = 100\n",
    "\n",
    "prior_kwargs = dict(\n",
    "    dim=out_dim,\n",
    "    depth=depth,\n",
    "    dim_head=dim_head,\n",
    "    heads=heads,\n",
    "    causal=False,\n",
    "    num_tokens=257,\n",
    "    learned_query_mode=\"pos_emb\",\n",
    ")\n",
    "prior_network = VersatileDiffusionPriorNetwork(**prior_kwargs).to(device)\n",
    "print(\"prior_network loaded\")\n",
    "\n",
    "# custom version that can fix seeds\n",
    "diffusion_prior = BrainDiffusionPrior(\n",
    "    net=prior_network,\n",
    "    image_embed_dim=out_dim,\n",
    "    condition_on_text_encodings=False,\n",
    "    timesteps=timesteps,\n",
    "    cond_drop_prob=0.2,\n",
    "    image_embed_scale=None,\n",
    ")\n",
    "diffusion_prior = diffusion_prior.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf99d79-93fa-4fca-acd0-3f2383e2d2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the prior under the attribute name `diffusion_prior`, alongside ridge and voxel2clip.\n",
    "model.diffusion_prior = diffusion_prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c2f2cb-5d20-4acf-bbf7-b5ee637a8f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter count of the assembled model (ridge + voxel2clip + diffusion prior).\n",
    "utils.count_params(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acba4521-694e-4748-8e8b-9a496376e4b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint file is selected by `checkpoint_tag` from the config cell.\n",
    "ckpt_path = os.path.join(outdir, f'{checkpoint_tag}.pth')\n",
    "print(\"ckpt_path\",ckpt_path)\n",
    "\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(ckpt_path, map_location=device)\n",
    "state_dict = checkpoint['model_state_dict']\n",
    "\n",
    "# strict=True: the checkpoint keys must match the sub-modules attached to `model` exactly.\n",
    "model.load_state_dict(state_dict,strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f976dd3f-8142-4cb3-b29d-e45aeb684eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to eval mode and make sure everything lives on the target device.\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# The reconstruction helper receives the model wrapped in a list.\n",
    "models = [model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e60b6a6-9677-4b41-a18d-547e93d236d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CLIP ViT-L/14 wrapper handed to the reconstruction call.\n",
    "clip_extractor = Clipper(\n",
    "    \"ViT-L/14\",\n",
    "    hidden_state=True,\n",
    "    norm_embs=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4bc39f7-7d85-4cc6-b042-deeb3f7773d3",
   "metadata": {},
   "source": [
    "## Reconstruct test images one at a time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60ea17c9-efcc-4e98-9666-241fd95ac213",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))\n",
    "\n",
    "# Image-variation mode uses a stronger guidance scale.\n",
    "guidance_scale = 7.5 if img_variations else 3.5\n",
    "\n",
    "# Indices of the test samples to reconstruct (all of them) and the result accumulator.\n",
    "ind_include = np.arange(num_val)\n",
    "all_brain_recons = None\n",
    "\n",
    "# img2img_strength == 1 -> no img2img at all\n",
    "# img2img_strength == 0 -> img2img with only the low-level image used as output\n",
    "# anything else         -> img2img\n",
    "img2img = img2img_strength != 1\n",
    "only_lowlevel = img2img_strength == 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09fa5cd4-1ee2-4239-9f3c-5042e6e13def",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: load precomputed low-level images. They are not used in our results.\n",
    "# The loop below reads `low_level_recons` only when img2img_strength != 1 and `og` is falsy;\n",
    "# uncomment the following lines in that case.\n",
    "# from torchvision import transforms\n",
    "# low_level_recons = torch.load('mapped_ll_bc_images_512.pt')\n",
    "# imsize = 512\n",
    "# low_level_recons = transforms.Resize((imsize,imsize))(low_level_recons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "785c113e-b6a2-4f9d-9362-bc6c9ad60b1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruct every selected test sample, stack the results and optionally save them.\n",
    "if len(ind_include) == 0:\n",
    "    raise RuntimeError('ind_include is empty: there are no test samples to reconstruct.')\n",
    "\n",
    "# Loop-invariant bounds of the index range to process.\n",
    "first_idx = int(np.min(ind_include))\n",
    "last_idx = int(np.max(ind_include))\n",
    "\n",
    "# Per-sample results are collected in lists owned by this cell, so re-running the cell\n",
    "# starts from scratch and the tensors are concatenated once instead of on every step.\n",
    "recon_batches = []\n",
    "image_batches = []\n",
    "\n",
    "for val_i, (voxel, img) in enumerate(tqdm(val_dl,total=len(ind_include))):\n",
    "    if val_i < first_idx:\n",
    "        continue\n",
    "\n",
    "    # Average over dim 1 (presumably repeated presentations of the same image -- TODO confirm).\n",
    "    voxel = torch.mean(voxel,axis=1).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        if img2img:\n",
    "            if og:\n",
    "                # NOTE(review): `voxel2sd` is not defined anywhere in this notebook.\n",
    "                ae_preds = voxel2sd(voxel.float())\n",
    "                blurry_recons = vd_pipe.vae.decode(ae_preds.to(device).half()/0.18215).sample / 2 + 0.5\n",
    "            else:\n",
    "                # NOTE(review): `low_level_recons` only exists if the optional loading cell is uncommented.\n",
    "                blurry_recons = low_level_recons[val_i].unsqueeze(0).to(device)\n",
    "            # Show the first low-level image as a quick visual check.\n",
    "            if val_i==0:\n",
    "                plt.imshow(utils.torch_to_Image(blurry_recons))\n",
    "                plt.show()\n",
    "        else:\n",
    "            #print(\"No Low Level\")\n",
    "            blurry_recons = None\n",
    "\n",
    "        if only_lowlevel:\n",
    "            brain_recons = blurry_recons\n",
    "        else:\n",
    "            grid, brain_recons, laion_best_picks, recon_img = utils.reconstruction2(\n",
    "                img, voxel,\n",
    "                clip_extractor, unet, vae, noise_scheduler,\n",
    "                voxel2clip_cls = None,\n",
    "                models = models,\n",
    "                text_token = None,\n",
    "                img_lowlevel = blurry_recons,\n",
    "                num_inference_steps = num_inference_steps,\n",
    "                n_samples_save = batch_size,\n",
    "                recons_per_sample = recons_per_sample,\n",
    "                guidance_scale = guidance_scale,\n",
    "                img2img_strength = img2img_strength, # 0=fully rely on img_lowlevel, 1=not doing img2img\n",
    "                timesteps_prior = 100,\n",
    "                seed = seed,\n",
    "                retrieve = retrieve,\n",
    "                plotting = plotting,\n",
    "                img_variations = img_variations,\n",
    "                verbose = verbose,\n",
    "                # NOTE(review): hard-coded and unrelated to the `subj` config value; its meaning\n",
    "                # inside reconstruction2 is not visible here -- confirm before changing.\n",
    "                subj=2\n",
    "            )\n",
    "\n",
    "            if plotting:\n",
    "                plt.show()\n",
    "                # grid.savefig(f'evals/{model_name}_{val_i}.png')\n",
    "\n",
    "            # Keep only the reconstruction(s) picked by reconstruction2 along dim 1.\n",
    "            brain_recons = brain_recons[:,laion_best_picks.astype(np.int8)]\n",
    "\n",
    "        recon_batches.append(brain_recons)\n",
    "        image_batches.append(img)\n",
    "\n",
    "    if val_i >= last_idx:\n",
    "        break\n",
    "\n",
    "if not recon_batches:\n",
    "    raise RuntimeError('No samples were reconstructed; check val_dl and ind_include.')\n",
    "\n",
    "all_brain_recons = torch.cat(recon_batches, dim=0).view(-1,3,imsize,imsize)\n",
    "all_images = torch.cat(image_batches, dim=0)\n",
    "print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))\n",
    "\n",
    "recon_path = f'recons/{save_name}_recons_img2img{img2img_strength}_{recons_per_sample}samples'\n",
    "if saving:\n",
    "    # Create the output folder first so a long run cannot fail at the very last step.\n",
    "    os.makedirs('recons', exist_ok=True)\n",
    "    torch.save(all_images,f'recons/all_images_subj{subj}_nsd_split{split_id}.pt')\n",
    "    torch.save(all_brain_recons,f'{recon_path}.pt')\n",
    "print(f'recon_path: {recon_path}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "980cb0fc-513c-4fd4-90d8-241118b3a186",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
