{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download checkpoints and build models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import torch, torchvision\n",
    "import random\n",
    "import numpy as np\n",
    "import PIL.Image as PImage, PIL.ImageDraw as PImageDraw\n",
    "# Skip PyTorch's default weight initialisation for these layer types to speed up model construction;\n",
    "# the VAE/VAR weights are overwritten afterwards by the strict (strict=True) checkpoint loads.\n",
    "# NOTE(review): the patch is on the classes, so it also affects any Linear/LayerNorm created later in this kernel.\n",
    "for _layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):\n",
    "    setattr(_layer_cls, 'reset_parameters', lambda self: None)\n",
    "del _layer_cls\n",
    "from models import VQVAE, build_vae_var\n",
    "from attacks import apply_attack\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "import torch\n",
    "gpu_id = 0\n",
    "# torch.cuda.set_device fails on machines without CUDA, which would make the CPU fallback\n",
    "# in `device` below unreachable, so only pin the GPU when one is available.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(gpu_id)\n",
    "    print(torch.cuda.current_device())\n",
    "device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "# directory holding the (fine-tuned) VQVAE checkpoint\n",
    "data_dir_ae = \"[DATA_SAVE_DIR]/var/finetune\"\n",
    "\n",
    "# directory holding the VAR transformer checkpoint\n",
    "data_dir_ar = \"[VAR_MODEL_PATH]\"\n",
    "\n",
    "MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====\n",
    "assert MODEL_DEPTH in {16, 20, 24, 30}, f\"unsupported MODEL_DEPTH={MODEL_DEPTH}\"\n",
    "\n",
    "\n",
    "# download checkpoint\n",
    "hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'\n",
    "# vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'\n",
    "vae_ckpt, var_ckpt = 'vqvae_finetuned_fmap_lpips0.0_mse_img0.0_mse_feat1.0_steps20000_encoder_lr5e-05_bs16_with_dataset.pth', f'var_d{MODEL_DEPTH}.pth'\n",
    "# vae_ckpt, var_ckpt = 'vqvae_finetuned_fmap_lpips0.0_mse_img0.0_mse_feat1.0_keep_real10.0_steps20000_encoder_lr5e-05_bs8_with_dataset.pth', f'var_d{MODEL_DEPTH}.pth'\n",
    "\n",
    "vae_ckpt_path, var_ckpt_path = osp.join(data_dir_ae, vae_ckpt), osp.join(data_dir_ar, var_ckpt)\n",
    "\n",
    "def _ensure_checkpoint(ckpt_dir, ckpt_name, ckpt_path):\n",
    "    \"\"\"Download `ckpt_name` from `hf_home` into `ckpt_dir` unless `ckpt_path` already exists.\n",
    "\n",
    "    Success is judged by whether the file exists afterwards, so a failed download (a locally\n",
    "    fine-tuned VQVAE checkpoint is presumably not hosted on the hub) raises here with the\n",
    "    missing path instead of surfacing later as an unclear torch.load error.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the checkpoint is still missing after the download attempt.\n",
    "    \"\"\"\n",
    "    if osp.exists(ckpt_path):\n",
    "        return\n",
    "    exit_code = os.system(f'wget -P {ckpt_dir} {hf_home}/{ckpt_name}')\n",
    "    if not osp.exists(ckpt_path):\n",
    "        raise FileNotFoundError(f'{ckpt_path} not found and could not be downloaded from {hf_home}/{ckpt_name} (wget exit status {exit_code})')\n",
    "\n",
    "_ensure_checkpoint(data_dir_ae, vae_ckpt, vae_ckpt_path)\n",
    "_ensure_checkpoint(data_dir_ar, var_ckpt, var_ckpt_path)\n",
    "\n",
    "# build vae, var\n",
    "# patch_nums: one entry per autoregressive scale (presumably the h = w of each token map)\n",
    "patch_nums = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]\n",
    "# Construction is skipped when both models already exist in the kernel, so re-running this\n",
    "# cell only reloads the weights. NOTE(review): a changed MODEL_DEPTH therefore needs a restart.\n",
    "if not all(name in globals() for name in ('vae', 'var')):\n",
    "    vae, var = build_vae_var(\n",
    "        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters\n",
    "        device=device, patch_nums=patch_nums,\n",
    "        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,\n",
    "    )\n",
    "\n",
    "# load checkpoints (strict=True: any missing or unexpected key is an error)\n",
    "for net, ckpt_path in ((vae, vae_ckpt_path), (var, var_ckpt_path)):\n",
    "    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=True)\n",
    "# inference only: eval mode and frozen weights\n",
    "for net in (vae, var):\n",
    "    net.eval()\n",
    "for p in (*vae.parameters(), *var.parameters()):\n",
    "    p.requires_grad_(False)\n",
    "del net, ckpt_path\n",
    "print('prepare finished.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: show the quantizer's `using_znorm` flag (presumably whether features are normalised before the codebook lookup - confirm in models/)\n",
    "print(vae.quantize.using_znorm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configs for attack\n",
    "n_img = 1000          # number of images (also handed to apply_attack as args.num_samples)\n",
    "batch_size = 64\n",
    "optim_iters = 1       # iterations of the optimized token search\n",
    "attack_type = \"none\"  # must be a key of range_map / args_map\n",
    "\n",
    "# attack strengths per attack type; only the first entry of each list is used below\n",
    "# (the commented lists are the full sweeps)\n",
    "range_map = dict(\n",
    "    none=[0.0],        # no attack\n",
    "    noise=[0.1],       # [0.0, 0.05, 0.1, 0.15, 0.2]\n",
    "    gauss=[7],         # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]\n",
    "    crop=[0.5],        # [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5]\n",
    "    jpeg=[50],         # [100, 90, 80, 70, 60, 50, 40, 30, 20, 10]\n",
    "    rotate=[(-5, 5)],\n",
    "    CtrlRegen=[0.1],   # [0, 0.1, 0.2, 0.3, 0.4, 0.5]\n",
    ")\n",
    "\n",
    "# attribute of `args` through which each attack receives its strength\n",
    "args_map = dict(\n",
    "    none=\"none\",\n",
    "    noise=\"variance\",\n",
    "    gauss=\"kernel_size\",\n",
    "    crop=\"crop_ratio\",\n",
    "    jpeg=\"final_quality\",\n",
    "    rotate=\"degrees\",\n",
    "    CtrlRegen=\"ctrl_regen_steps\",\n",
    ")\n",
    "\n",
    "import argparse\n",
    "# no arguments are declared; parse_known_args just tolerates whatever argv the kernel was started with\n",
    "parser = argparse.ArgumentParser()\n",
    "args, unknown = parser.parse_known_args()\n",
    "args.num_samples = n_img\n",
    "args.variance = 0.1  # default noise variance\n",
    "attack_strength = range_map[attack_type][0]  # default attack strength\n",
    "setattr(args, args_map[attack_type], attack_strength)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample tokens and images with classifier-free guidance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "regenerate = False  # set True to (re)sample n_img images from VAR; skipped by default\n",
    "if regenerate:\n",
    "    import random\n",
    "    from tqdm import tqdm\n",
    "    # set args\n",
    "    seed = 42 #@param {type:\"number\"}\n",
    "    num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
    "    cfg = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
    "    num_classes = 1000  # must match num_classes passed to build_vae_var\n",
    "    more_smooth = False # True for more smooth output\n",
    "    save = False\n",
    "\n",
    "    # seed every RNG used below\n",
    "    torch.manual_seed(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # run faster\n",
    "    tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = bool(tf32)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n",
    "    torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n",
    "\n",
    "    # sample in batches; the last batch is truncated so exactly n_img images are produced\n",
    "    recon_batches, idxBl_batches = [], []\n",
    "    for i_batch in tqdm(range(0, n_img, batch_size)):\n",
    "        B = min(batch_size, n_img - i_batch)\n",
    "        # random.randint includes both endpoints, so the largest valid label is num_classes - 1\n",
    "        class_labels = tuple(random.randint(0, num_classes - 1) for _ in range(B))\n",
    "        label_B: torch.LongTensor = torch.tensor(class_labels, device=device)\n",
    "        with torch.no_grad():\n",
    "            with torch.autocast('cuda', enabled=True, cache_enabled=True):    # using bfloat16 can be faster\n",
    "                recon_B3HW_batch, original_idxBl_batch, token_map = var.autoregressive_infer_cfg_with_token_map(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)\n",
    "        recon_batches.append(recon_B3HW_batch)\n",
    "        idxBl_batches.append(original_idxBl_batch)\n",
    "        print(len(original_idxBl_batch), original_idxBl_batch[0].shape)\n",
    "        chw = torchvision.utils.make_grid(recon_B3HW_batch, nrow=8, padding=0, pad_value=1.0)\n",
    "        chw = chw.permute(1, 2, 0).mul(255).cpu().numpy()\n",
    "        PImage.fromarray(chw.astype(np.uint8)).show()\n",
    "\n",
    "    # concatenate once after the loop instead of re-growing tensors every batch\n",
    "    recon_B3HW = torch.cat(recon_batches, dim=0)\n",
    "    # original_idxBl[k] stacks scale k of every batch along dim 0\n",
    "    original_idxBl = [torch.cat(per_scale, dim=0) for per_scale in zip(*idxBl_batches)]\n",
    "\n",
    "    recon_B3HW_save = recon_B3HW.clone()\n",
    "    print(f'Generated {len(recon_B3HW_save)} images with classifier-free guidance.')\n",
    "    # save all the generated images to png format in output_dir\n",
    "    output_dir = '[DATA_SAVE_DIR]/var/var_generated/'\n",
    "    if save:\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "        for i, img_3HW in enumerate(recon_B3HW_save):\n",
    "            # out-of-place mul so recon_B3HW_save is not rescaled in place\n",
    "            img = img_3HW.permute(1, 2, 0).mul(255).cpu().numpy().astype(np.uint8)\n",
    "            PImage.fromarray(img).save(osp.join(output_dir, f'{i:03d}.png'))\n",
    "        print(f'Saved to {output_dir}.')\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing reconstruction losses and token-overlap metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compare how many tokens match for the two token maps on each scale\n",
    "import torch.nn.functional as F\n",
    "def compare_token_maps(map1, map2, return_per_scale=False):\n",
    "    \"\"\"Count matching token indices between two multi-scale token maps.\n",
    "\n",
    "    Args:\n",
    "        map1, map2: sequences of per-scale index tensors; entries at the same scale must have\n",
    "            the same shape, with the batch on dim 0 (e.g. B x l).\n",
    "        return_per_scale: if True, return per-sample tensors for every scale; otherwise reduce\n",
    "            each scale over the batch (sum for the counts, mean for the ratios).\n",
    "\n",
    "    Returns:\n",
    "        (total_equal, total_elements, overlapping_ratio), each a list with one entry per scale.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two maps disagree in shape at some scale.\n",
    "    \"\"\"\n",
    "    total_equal = []\n",
    "    total_elements = []\n",
    "    overlapping_ratio = []\n",
    "\n",
    "    for scale_idx, (tokens1, tokens2) in enumerate(zip(map1, map2)):\n",
    "        if tokens1.shape != tokens2.shape:\n",
    "            raise ValueError(f\"Tensors have different shapes at scale {scale_idx}: {tuple(tokens1.shape)} vs {tuple(tokens2.shape)}\")\n",
    "\n",
    "        # per-sample boolean match mask, flattened to B x (tokens per sample)\n",
    "        equal_mask = (tokens1.cpu() == tokens2.cpu()).flatten(1)\n",
    "        equal_count = equal_mask.sum(1)  # B\n",
    "        element_count = torch.Tensor([float(equal_mask.shape[1])] * equal_mask.shape[0])  # B\n",
    "\n",
    "        # record the per-scale counts so both return modes actually report them\n",
    "        total_equal.append(equal_count)\n",
    "        total_elements.append(element_count)\n",
    "        overlapping_ratio.append(equal_count / element_count)\n",
    "\n",
    "    if return_per_scale:\n",
    "        return total_equal, total_elements, overlapping_ratio\n",
    "    return (\n",
    "        [equal.sum() for equal in total_equal],\n",
    "        [element.sum() for element in total_elements],\n",
    "        [ratio.mean() for ratio in overlapping_ratio],\n",
    "    )\n",
    "\n",
    "def compare_embeddings(embed_scale1, embed_scale2): #List[BhwC]\n",
    "    \"\"\"Per-sample MSE between two lists of per-scale embeddings.\n",
    "\n",
    "    Each pair is compared element-wise, averaged over dims 1-3 and moved to CPU,\n",
    "    giving one length-B tensor per scale.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        F.mse_loss(lhs, rhs, reduction='none').mean(dim=(1, 2, 3)).cpu()\n",
    "        for lhs, rhs in zip(embed_scale1, embed_scale2)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Define the losses without the optimized token search\n",
    "'''\n",
    "# the losses with the quantization\n",
    "def calculate_loss_batch(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):\n",
    "    \"\"\"Double-reconstruction losses using the VQVAE's standard quantization.\n",
    "\n",
    "    `original_B3HW` (assumed in [0, 1]) is mapped to [-1, 1] and reconstructed; that\n",
    "    reconstruction is reconstructed again. For both rounds (\"1st\", \"2nd\") and their quotient\n",
    "    (\"ratio\" = 1st / 2nd) the per-sample feature-map MSE (fhat vs f), image MSE and per-scale\n",
    "    embedding MSEs are appended to all_dataset_results[metric][round][dataset_name].\n",
    "    If `original_idxBl` is given, the per-scale token overlap between it and the 1st-round\n",
    "    tokens is accumulated in all_dataset_results[\"overlapping\"][\"all\"][dataset_name].\n",
    "\n",
    "    Returns the same (mutated) all_dataset_results dict.\n",
    "    \"\"\"\n",
    "    # round 1: image -> reconstruction\n",
    "    recon_img, recon_idxBl, f, fhat, embeddings = vae.img_to_reconstructed_img_with_intermediates(original_B3HW.clone().mul_(2).add_(-1).float(), last_one=True)\n",
    "    # calculate overlapping ratio only when the original tokens are provided\n",
    "    if original_idxBl is not None:\n",
    "        _, _, overlapping_ratio = compare_token_maps(original_idxBl, recon_idxBl, return_per_scale=True)\n",
    "        print([ratio.mean() for ratio in overlapping_ratio])\n",
    "    else:\n",
    "        overlapping_ratio = []\n",
    "\n",
    "    feature_map_mse = F.mse_loss(fhat, f, reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    recon_img_mse = F.mse_loss(recon_img, original_B3HW.clone().mul_(2).add_(-1).float(), reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    embed_mse_current = compare_embeddings(embeddings[\"current_resolution\"][\"original\"], embeddings[\"current_resolution\"][\"quantized\"])\n",
    "    embed_mse_full = compare_embeddings(embeddings[\"full_resolution\"][\"original\"], embeddings[\"full_resolution\"][\"quantized\"])\n",
    "\n",
    "    # round 2: reconstruction -> reconstruction of the reconstruction\n",
    "    recon_img_2nd, _, f_2nd, fhat_2nd, embeddings_2nd = vae.img_to_reconstructed_img_with_intermediates(recon_img, last_one=True)\n",
    "    feature_map_mse_2nd = F.mse_loss(fhat_2nd, f_2nd, reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    recon_img_mse_2nd = F.mse_loss(recon_img_2nd, recon_img.clone(), reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    embed_mse_current_2nd = compare_embeddings(embeddings_2nd[\"current_resolution\"][\"original\"], embeddings_2nd[\"current_resolution\"][\"quantized\"])\n",
    "    embed_mse_full_2nd = compare_embeddings(embeddings_2nd[\"full_resolution\"][\"original\"], embeddings_2nd[\"full_resolution\"][\"quantized\"])\n",
    "\n",
    "    # per-round values: (feature map, image, per-scale current-res embeddings, per-scale full-res embeddings)\n",
    "    per_round = {\n",
    "        \"1st\": (feature_map_mse, recon_img_mse, embed_mse_current, embed_mse_full),\n",
    "        \"2nd\": (feature_map_mse_2nd, recon_img_mse_2nd, embed_mse_current_2nd, embed_mse_full_2nd),\n",
    "        \"ratio\": (\n",
    "            feature_map_mse / feature_map_mse_2nd,\n",
    "            recon_img_mse / recon_img_mse_2nd,\n",
    "            [embed_mse_current[k] / embed_mse_current_2nd[k] for k in range(len(embed_mse_current))],\n",
    "            [embed_mse_full[k] / embed_mse_full_2nd[k] for k in range(len(embed_mse_full))],\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    # update the results\n",
    "    overlap_store = all_dataset_results[\"overlapping\"][\"all\"]\n",
    "    if len(overlap_store[dataset_name]) == 0:\n",
    "        overlap_store[dataset_name] = overlapping_ratio\n",
    "    else:\n",
    "        for scale, ratio in enumerate(overlapping_ratio):\n",
    "            overlap_store[dataset_name][scale] = torch.cat((overlap_store[dataset_name][scale], ratio))\n",
    "\n",
    "    for round_name, (fmap_vals, rec_vals, current_vals, full_vals) in per_round.items():\n",
    "        all_dataset_results[\"feature_map\"][round_name][dataset_name].extend(fmap_vals)\n",
    "        all_dataset_results[\"rec\"][round_name][dataset_name].extend(rec_vals)\n",
    "        for i in range(len(patch_nums)):\n",
    "            all_dataset_results[f\"embedding_current_{i}\"][round_name][dataset_name].extend(current_vals[i])\n",
    "            all_dataset_results[f\"embedding_full_{i}\"][round_name][dataset_name].extend(full_vals[i])\n",
    "\n",
    "    # display the images\n",
    "    if display_img:\n",
    "        for img_B3HW in (original_B3HW, recon_img.clone().add_(1).mul_(0.5)):\n",
    "            # out-of-place mul: for a one-image batch make_grid may return a view of its input,\n",
    "            # and mul_ would then rescale original_B3HW, which the caller reuses afterwards\n",
    "            chw = torchvision.utils.make_grid(img_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
    "            chw = chw.permute(1, 2, 0).mul(255).cpu().numpy()\n",
    "            PImage.fromarray(chw.astype(np.uint8)).show()\n",
    "\n",
    "    return all_dataset_results\n",
    "\n",
    "# the losses without the quantization\n",
    "def calculate_loss_batch_no_quant(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):\n",
    "    \"\"\"Double-reconstruction image MSE with the quantizer bypassed.\n",
    "\n",
    "    `original_B3HW` (assumed in [0, 1]) is mapped to [-1, 1] and passed through\n",
    "    vae.img_to_reconstructed_img_without_quant; the result is passed through it again.\n",
    "    Per-sample image MSE for both rounds (\"1st\", \"2nd\") and their quotient (\"ratio\")\n",
    "    is appended under all_dataset_results[\"rec_no_quant\"]. If `original_idxBl` is given,\n",
    "    the per-scale token overlap with the 1st-round tokens is accumulated in\n",
    "    all_dataset_results[\"overlapping_no_quant\"][\"all\"][dataset_name].\n",
    "\n",
    "    Returns the same (mutated) all_dataset_results dict.\n",
    "    \"\"\"\n",
    "    # round 1: image -> reconstruction without quantization\n",
    "    recon_img, recon_idxBl, _, _, _ = vae.img_to_reconstructed_img_without_quant(original_B3HW.clone().mul_(2).add_(-1).float(), last_one=True)\n",
    "    # calculate overlapping ratio only when the original tokens are provided\n",
    "    if original_idxBl is not None:\n",
    "        _, _, overlapping_ratio = compare_token_maps(original_idxBl, recon_idxBl, return_per_scale=True)\n",
    "        # per-scale means, like the other loss functions, rather than every per-sample ratio\n",
    "        print([ratio.mean() for ratio in overlapping_ratio])\n",
    "    else:\n",
    "        overlapping_ratio = []\n",
    "\n",
    "    recon_img_mse = F.mse_loss(recon_img.clone(), original_B3HW.clone().mul_(2).add_(-1).float(), reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "\n",
    "    # round 2: reconstruction -> reconstruction of the reconstruction\n",
    "    recon_img_2nd, _, _, _, _ = vae.img_to_reconstructed_img_without_quant(recon_img.clone(), last_one=True)\n",
    "    recon_img_mse_2nd = F.mse_loss(recon_img_2nd.clone(), recon_img.clone(), reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    recon_img_mse_ratio = recon_img_mse / recon_img_mse_2nd\n",
    "\n",
    "    # update the results\n",
    "    overlap_store = all_dataset_results[\"overlapping_no_quant\"][\"all\"]\n",
    "    if len(overlap_store[dataset_name]) == 0:\n",
    "        overlap_store[dataset_name] = overlapping_ratio\n",
    "    else:\n",
    "        for scale, ratio in enumerate(overlapping_ratio):\n",
    "            overlap_store[dataset_name][scale] = torch.cat((overlap_store[dataset_name][scale], ratio))\n",
    "\n",
    "    all_dataset_results[\"rec_no_quant\"][\"1st\"][dataset_name].extend(recon_img_mse)\n",
    "    all_dataset_results[\"rec_no_quant\"][\"2nd\"][dataset_name].extend(recon_img_mse_2nd)\n",
    "    all_dataset_results[\"rec_no_quant\"][\"ratio\"][dataset_name].extend(recon_img_mse_ratio)\n",
    "\n",
    "    # display the images\n",
    "    if display_img:\n",
    "        for img_B3HW in (original_B3HW, recon_img.clone().add_(1).mul_(0.5)):\n",
    "            # out-of-place mul: for a one-image batch make_grid may return a view of its input,\n",
    "            # and mul_ would then rescale original_B3HW, which the caller reuses afterwards\n",
    "            chw = torchvision.utils.make_grid(img_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
    "            chw = chw.permute(1, 2, 0).mul(255).cpu().numpy()\n",
    "            PImage.fromarray(chw.astype(np.uint8)).show()\n",
    "\n",
    "    return all_dataset_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimized token search\n",
    "def f_to_idxBl_and_fhat_optimized(f_hat_target, iters=None, fix_scales=None):\n",
    "    \"\"\"Quantize `f_hat_target` to multi-scale tokens, then refine them with soft assignment.\n",
    "\n",
    "    Starts from the quantizer's own tokens, runs one refinement over all scales, then one\n",
    "    more refinement per entry of `fix_scales` (passed as `fix_scale`).\n",
    "\n",
    "    Args:\n",
    "        f_hat_target: feature map the tokens should reproduce.\n",
    "        iters: refinement iterations per call; None means the current global `optim_iters`,\n",
    "            read at call time so re-running the config cell takes effect without redefining this.\n",
    "        fix_scales: scale indices for the layer-by-layer pass; defaults to len(patch_nums)-1\n",
    "            down to 5 (9..5 for the 10-entry patch_nums).\n",
    "\n",
    "    Returns:\n",
    "        (idxBl, f_refined): refined token indices and the corresponding feature map.\n",
    "    \"\"\"\n",
    "    if iters is None:\n",
    "        iters = optim_iters\n",
    "    if fix_scales is None:\n",
    "        fix_scales = range(len(patch_nums) - 1, 4, -1)\n",
    "    idxBl_rec_first, _, _, _ = vae.quantize.f_to_idxBl_or_fhat_with_f_rest(f_hat_target.clone(), to_fhat=False)\n",
    "    idxBl_rec, f_refine, _ = vae.quantize.refine_soft_assign(f_hat_target, init_idx_Bl=idxBl_rec_first, iters=iters, lr=0.1, entropy_weight=0.0)\n",
    "    # start from the first refinement so an empty fix_scales returns it instead of raising\n",
    "    idxBl_rec_topk_re, f_refine_re = idxBl_rec, f_refine\n",
    "    # freeze the tokens layer by layer\n",
    "    for i in fix_scales:\n",
    "        idxBl_rec_topk_re, f_refine_re, _ = vae.quantize.refine_soft_assign(f_hat_target, init_idx_Bl=idxBl_rec_topk_re, iters=iters, lr=0.1, entropy_weight=0.0, v_patch_nums=patch_nums, fix_scale=i)\n",
    "\n",
    "    return idxBl_rec_topk_re, f_refine_re\n",
    "\n",
    "# embed_BChw = vae.idxBl_to_embedhat(original_idxBl) # original_idxBl is the original token\n",
    "# f_hat_target = vae.quantize.embedhat_to_fhat(\n",
    "#     embed_BChw, all_to_max_scale=True, last_one=True\n",
    "# )\n",
    "\n",
    "def img_to_reconstructed_img_with_optim(original_B3HW):\n",
    "    \"\"\"Encode an image batch, choose tokens with the optimized search, and decode them.\n",
    "\n",
    "    `original_B3HW` goes to the encoder unchanged (a copy of it), so callers pass it already\n",
    "    mapped to the VAE's input range.\n",
    "\n",
    "    Returns (reconstruction clamped to [-1, 1], refined tokens, encoder features, refined features).\n",
    "    \"\"\"\n",
    "    encoded = vae.encoder(original_B3HW.clone())\n",
    "    f_gen = vae.quant_conv(encoded)\n",
    "    idxBl_refined, f_refined = f_to_idxBl_and_fhat_optimized(f_gen)\n",
    "    decoded = vae.decoder(vae.post_quant_conv(f_refined))\n",
    "    return decoded.clamp_(-1, 1), idxBl_refined, f_gen, f_refined\n",
    "\n",
    "\n",
    "# chw = torchvision.utils.make_grid(rec_gen_img.clone().detach().add_(1).mul_(0.5), nrow=8, padding=0, pad_value=1.0)\n",
    "# chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()\n",
    "# chw = PImage.fromarray(chw.astype(np.uint8))\n",
    "# chw.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Define the losses with the optimized token search\n",
    "'''\n",
    "def calculate_loss_batch_with_optim(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):\n",
    "    \"\"\"Double-reconstruction losses using the optimized token search.\n",
    "\n",
    "    `original_B3HW` (assumed in [0, 1]) is mapped to [-1, 1] and passed through\n",
    "    img_to_reconstructed_img_with_optim; the result is passed through it again. For both\n",
    "    rounds (\"1st\", \"2nd\") and their quotient (\"ratio\" = 1st / 2nd) the per-sample\n",
    "    feature-map MSE (refined vs. encoder features) and image MSE are appended under the\n",
    "    \"feature_map_optim\" / \"rec_optim\" keys. If `original_idxBl` is given, the per-scale\n",
    "    token overlap with the 1st-round tokens is accumulated in\n",
    "    all_dataset_results[\"overlapping_optim\"][\"all\"][dataset_name].\n",
    "\n",
    "    Returns the same (mutated) all_dataset_results dict.\n",
    "    \"\"\"\n",
    "    # round 1: image -> reconstruction\n",
    "    recon_img, recon_idxBl, f, fhat = img_to_reconstructed_img_with_optim(original_B3HW.clone().mul_(2).add_(-1).float())\n",
    "    # calculate overlapping ratio only when the original tokens are provided\n",
    "    if original_idxBl is not None:\n",
    "        _, _, overlapping_ratio = compare_token_maps(original_idxBl, recon_idxBl, return_per_scale=True)\n",
    "        print([ratio.mean() for ratio in overlapping_ratio])\n",
    "    else:\n",
    "        overlapping_ratio = []\n",
    "\n",
    "    feature_map_mse = F.mse_loss(fhat, f, reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    recon_img_mse = F.mse_loss(recon_img, original_B3HW.clone().mul_(2).add_(-1).float(), reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "\n",
    "    # round 2: reconstruction -> reconstruction of the reconstruction\n",
    "    recon_img_2nd, _, f_2nd, fhat_2nd = img_to_reconstructed_img_with_optim(recon_img)\n",
    "    feature_map_mse_2nd = F.mse_loss(fhat_2nd, f_2nd, reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "    recon_img_mse_2nd = F.mse_loss(recon_img_2nd, recon_img.clone(), reduction='none').mean(dim=(1,2,3)).cpu()\n",
    "\n",
    "    # per-round values: (feature map, image)\n",
    "    per_round = {\n",
    "        \"1st\": (feature_map_mse, recon_img_mse),\n",
    "        \"2nd\": (feature_map_mse_2nd, recon_img_mse_2nd),\n",
    "        \"ratio\": (feature_map_mse / feature_map_mse_2nd, recon_img_mse / recon_img_mse_2nd),\n",
    "    }\n",
    "\n",
    "    # update the results\n",
    "    overlap_store = all_dataset_results[\"overlapping_optim\"][\"all\"]\n",
    "    if len(overlap_store[dataset_name]) == 0:\n",
    "        overlap_store[dataset_name] = overlapping_ratio\n",
    "    else:\n",
    "        for scale, ratio in enumerate(overlapping_ratio):\n",
    "            overlap_store[dataset_name][scale] = torch.cat((overlap_store[dataset_name][scale], ratio))\n",
    "\n",
    "    for round_name, (fmap_vals, rec_vals) in per_round.items():\n",
    "        all_dataset_results[\"feature_map_optim\"][round_name][dataset_name].extend(fmap_vals)\n",
    "        all_dataset_results[\"rec_optim\"][round_name][dataset_name].extend(rec_vals)\n",
    "\n",
    "    # display the images\n",
    "    if display_img:\n",
    "        for img_B3HW in (original_B3HW, recon_img.clone().add_(1).mul_(0.5)):\n",
    "            # out-of-place mul: for a one-image batch make_grid may return a view of its input,\n",
    "            # and mul_ would then rescale original_B3HW in place\n",
    "            chw = torchvision.utils.make_grid(img_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
    "            chw = chw.permute(1, 2, 0).mul(255).cpu().numpy()\n",
    "            PImage.fromarray(chw.astype(np.uint8)).show()\n",
    "\n",
    "    return all_dataset_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize the loss results:\n",
    "#   all_dataset_results[metric][round][dataset_name] -> per-sample values\n",
    "#   (round is \"1st\" / \"2nd\" / \"ratio\", or \"all\" for the token-overlap entries;\n",
    "#    the per-dataset lists are created by the evaluation loop)\n",
    "all_dataset_results = {\n",
    "    # losses with the standard quantization\n",
    "    \"feature_map\": {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")},\n",
    "    \"rec\": {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")},\n",
    "    # reconstruction loss with the quantizer bypassed\n",
    "    \"rec_no_quant\": {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")},\n",
    "    # per-scale token overlap with the original VAR tokens\n",
    "    \"overlapping\": {\"all\": {}},\n",
    "    \"overlapping_no_quant\": {\"all\": {}},\n",
    "    # losses with the optimized token search\n",
    "    \"feature_map_optim\": {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")},\n",
    "    \"rec_optim\": {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")},\n",
    "    \"overlapping_optim\": {\"all\": {}},\n",
    "}\n",
    "# per-scale embedding MSEs at the current and at the full resolution\n",
    "for i in range(len(patch_nums)):\n",
    "    all_dataset_results[f\"embedding_current_{i}\"] = {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")}\n",
    "    all_dataset_results[f\"embedding_full_{i}\"] = {stage: {} for stage in (\"1st\", \"2nd\", \"ratio\")}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate all the losses for all given datasets\n",
    "dataset_name_image_path = {\n",
    "    \"LAION\": \"PATH_TO_LAION_SUBSET\",   # TODO: =====> please specify the path to your LAION subset <=====\n",
    "    \"MS-COCO\": \"PATH_TO_MSCOCO_SUBSET\", # TODO: =====> please specify the path to your MS-COCO subset <=====\n",
    "    \"ImageNet (val)\": \"PATH_TO_IMAGENET_VAL_SUBSET\", # TODO: =====> please specify the path to your ImageNet validation subset <=====\n",
    "    \"ImageNet (train)\": \"PATH_TO_IMAGENET_TRAIN_SUBSET\", # TODO: =====> please specify the path to your ImageNet training subset <=====\n",
    "    \"RAR Generated\": \"PATH_TO_RAR_GENERATED_SUBSET\", # TODO: =====> please specify the path to your RAR generated images subset <=====\n",
    "    \"VAR Generated\": \"PATH_TO_VAR_GENERATED_SUBSET\", # TODO: =====> please specify the path to your VAR generated images subset <=====\n",
    "    \"LlamaGen Generated\": \"PATH_TO_LLAMAGEN_GENERATED_SUBSET\", # TODO: =====> please specify the path to your LlamaGen generated images subset <=====\n",
    "    \"Taming Generated\": \"PATH_TO_TAMING_GENERATED_SUBSET\", # TODO: =====> please specify the path to your Taming generated images subset <=====\n",
    "    \"Infinity Generated\": \"PATH_TO_INFINITY_GENERATED_SUBSET\", # TODO: =====> please specify the path to your Infinity generated images subset <=====\n",
    "}\n",
    "\n",
    "var_dataset_name = \"VAR Generated\"  # the only dataset loaded together with its original token maps\n",
    "\n",
    "for dataset_name, image_path in dataset_name_image_path.items():\n",
    "    # fresh per-dataset result lists\n",
    "    for loss_type in all_dataset_results.keys():\n",
    "        for loss_round in all_dataset_results[loss_type].keys():\n",
    "            all_dataset_results[loss_type][loss_round][dataset_name] = []\n",
    "\n",
    "    print(f'reading from {image_path}')\n",
    "    is_var_dataset = (dataset_name == var_dataset_name)\n",
    "    dataset = apply_attack(img_path=image_path, attack=attack_type, load_token_map=is_var_dataset, args=args)\n",
    "    print(f'dataset length: {len(dataset)}')\n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "    display_img = False\n",
    "    for i, (image_batch, token_batch) in enumerate(dataloader):\n",
    "        # honour the configured `device` rather than the implicit current CUDA device\n",
    "        original_B3HW = image_batch.to(device)\n",
    "        # token maps are only loaded for the VAR-generated set; move them to the device once per batch\n",
    "        original_idxBl = [idxBl.to(device) for idxBl in token_batch] if is_var_dataset else None\n",
    "        # the main losses, the rec loss without quant, and the losses with the optimized search\n",
    "        all_dataset_results = calculate_loss_batch(all_dataset_results, dataset_name, original_B3HW, display_img=display_img, original_idxBl=original_idxBl)\n",
    "        all_dataset_results = calculate_loss_batch_no_quant(all_dataset_results, dataset_name, original_B3HW, display_img=display_img, original_idxBl=original_idxBl)\n",
    "        all_dataset_results = calculate_loss_batch_with_optim(all_dataset_results, dataset_name, original_B3HW, display_img=display_img, original_idxBl=original_idxBl)\n",
    "        print(i)\n",
    "        if i < 2:\n",
    "            # sanity check on the first batches of each dataset\n",
    "            print(all_dataset_results[\"overlapping\"][\"all\"][dataset_name])\n",
    "            print(len(all_dataset_results[\"rec\"][\"1st\"][dataset_name]))\n",
    "            print(len(all_dataset_results[\"rec_no_quant\"][\"1st\"][dataset_name]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize the probability distribution function of codebook losses for both real and generated images with seaborn\n",
    "from utils.plot import plot_multi_pdf\n",
    "from sklearn import metrics\n",
    "import pandas as pd\n",
    "\n",
    "def evaluate(rar_scores, other_scores, target_fpr=0.01):\n",
    "    \"\"\"Measure how well a per-image score separates the positive set from another set.\n",
    "\n",
    "    ``rar_scores`` are the positives (the caller passes the VAR-generated scores,\n",
    "    despite the parameter name) and ``other_scores`` are the negatives. Lower scores are\n",
    "    treated as more likely positive, so the scores are negated before calling\n",
    "    ``sklearn.metrics.roc_curve`` (which treats larger scores as more positive).\n",
    "\n",
    "    Args:\n",
    "        rar_scores: 1-D array of scores for the positive class.\n",
    "        other_scores: 1-D array of scores for the negative class.\n",
    "        target_fpr: false-positive-rate budget for the reported operating point\n",
    "            (default 0.01, i.e. the TPR@1%FPR column). Must be > 0.\n",
    "\n",
    "    Returns:\n",
    "        (threshold, auc, acc, tpr): ``threshold`` is in the original score units\n",
    "        (a sample is flagged positive when its score <= threshold; it may be -inf\n",
    "        when only the trivial ROC point satisfies the budget), ``auc`` is the ROC AUC,\n",
    "        ``acc`` is the best balanced accuracy over all ROC thresholds, and ``tpr`` is\n",
    "        the TPR at the last ROC point whose FPR is strictly below ``target_fpr``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``target_fpr`` is not positive.\n",
    "    \"\"\"\n",
    "    if target_fpr <= 0:\n",
    "        raise ValueError(f\"target_fpr must be > 0, got {target_fpr}\")\n",
    "    labels = np.concatenate([np.zeros(len(other_scores)), np.ones(len(rar_scores))])\n",
    "    scores = np.concatenate([other_scores, rar_scores])\n",
    "    # Negate so that a low loss counts as evidence for the positive class.\n",
    "    fpr, tpr, thresholds_negated = metrics.roc_curve(labels, -scores)\n",
    "    auc = metrics.auc(fpr, tpr)\n",
    "    # Balanced accuracy 1 - (FPR + FNR) / 2, maximised over every ROC threshold.\n",
    "    acc = np.max(1 - (fpr + (1 - tpr)) / 2)\n",
    "\n",
    "    # roc_curve always emits an fpr == 0 point first, so this index exists for target_fpr > 0.\n",
    "    idx = np.where(fpr < target_fpr)[0][-1]\n",
    "    threshold_at_target = -thresholds_negated[idx]\n",
    "    tpr_at_target = tpr[idx]\n",
    "\n",
    "    return threshold_at_target, auc, acc, tpr_at_target\n",
    "\n",
    "\n",
    "# Attack-configuration tag shared by the printed header and the results directory name.\n",
    "attack_arg_name = args_map[attack_type]\n",
    "attack_setting = f'{attack_type}({attack_arg_name}={args.__dict__[attack_arg_name]})'\n",
    "print(f'Results for setting: VAR-d[{MODEL_DEPTH}], {attack_setting}, {n_img}imgs, {optim_iters}iters')\n",
    "result_dir = f'results/VAR-d{MODEL_DEPTH}/{attack_setting}_{n_img}imgs_{optim_iters}iters'\n",
    "os.makedirs(result_dir, exist_ok=True)\n",
    "\n",
    "# Entries of all_dataset_results whose key contains one of these words are not\n",
    "# evaluated here and are skipped.\n",
    "skipped_loss_keywords = (\"overlapping\", \"full\", \"current\")\n",
    "\n",
    "results_table = []  # one row per (loss type, loss round, comparison); saved to results.csv\n",
    "\n",
    "for loss_type, loss_type_results in all_dataset_results.items():  # codebook, rec, ...\n",
    "    if any(keyword in loss_type for keyword in skipped_loss_keywords):\n",
    "        continue\n",
    "    for loss_round, loss_round_results in loss_type_results.items():  # 1st, 2nd, ratio, ...\n",
    "        print(f\"[{loss_type}][{loss_round}]\")\n",
    "\n",
    "        # Plot. Labels come from the same dict as the data so each curve keeps its own label\n",
    "        # even if dataset order/membership differs from dataset_name_image_path.\n",
    "        data_list, label_list = [], []\n",
    "        for dataset_name, dataset_results in loss_round_results.items():\n",
    "            print(f\"{dataset_name} dataset size {len(dataset_results)}\")\n",
    "            data_list.append(np.array(dataset_results))\n",
    "            label_list.append(f\"{dataset_name}\")\n",
    "        plot_multi_pdf(data_list, label_list, f'{loss_type}_{loss_round}_mse', 'Loss', 'PDF', save_dir=result_dir)\n",
    "\n",
    "        # Quantitative: VAR-generated scores (positives) vs. every other dataset (negatives).\n",
    "        other_datasets = {name: res for name, res in loss_round_results.items() if name != var_dataset_name}\n",
    "        # Converted once per loss round instead of once per comparison.\n",
    "        var_scores = np.array(loss_round_results[var_dataset_name]) if other_datasets else None\n",
    "        results_table_single_losses = []\n",
    "        for dataset_name, dataset_results in other_datasets.items():\n",
    "            threshold, auc, acc, tpr1 = evaluate(var_scores, np.array(dataset_results))\n",
    "            comparison_row = {\n",
    "                \"Comparison\": f\"VAR Generated vs {dataset_name}\",\n",
    "                \"Threshold\": round(threshold, 4),\n",
    "                \"AUC\": round(auc, 4),\n",
    "                \"ACC\": round(acc, 4),\n",
    "                \"TPR@1%FPR\": round(tpr1, 4),\n",
    "            }\n",
    "            results_table.append({\"Loss Type\": loss_type, \"Loss Round\": loss_round, **comparison_row})\n",
    "            results_table_single_losses.append(comparison_row)\n",
    "        df_single = pd.DataFrame(results_table_single_losses)\n",
    "        print(df_single.to_string(index=False), flush=True)\n",
    "        print('\\n')\n",
    "\n",
    "df = pd.DataFrame(results_table)\n",
    "print(df.to_string(index=False))\n",
    "# save the results\n",
    "df.to_csv(os.path.join(result_dir, 'results.csv'), index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the full results table again (every loss type, round and comparison).\n",
    "results_text = df.to_string(index=False)\n",
    "print(results_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "\n",
    "# --- your input DataFrame ---\n",
    "# df = ...  # must contain columns: [\"Loss Type\",\"Loss Round\",\"Comparison\",\"Threshold\",\"AUC\",\"ACC\",\"TPR@1%FPR\"]\n",
    "\n",
    "# Metric shown in the LaTeX table; one of the metric columns of df:\n",
    "# \"AUC\", \"ACC\", \"TPR@1%FPR\" or \"Threshold\".\n",
    "metric = \"TPR@1%FPR\"\n",
    "\n",
    "# Desired column order; datasets not listed here are dropped from the table.\n",
    "col_order = [\n",
    "    \"ImageNet (train)\",\n",
    "    \"ImageNet (val)\",\n",
    "    \"LAION\",\n",
    "    \"MS-COCO\",\n",
    "    \"LlamaGen\",\n",
    "    \"RAR\",\n",
    "    \"Taming\",\n",
    "    \"VAR\",\n",
    "    \"Infinity\",\n",
    "]\n",
    "\n",
    "# Maps a lower-cased dataset name (as parsed from the Comparison column) to its\n",
    "# display name in col_order.\n",
    "aliases = {\n",
    "    \"imagenet (train)\": \"ImageNet (train)\",\n",
    "    \"imagenet (val)\": \"ImageNet (val)\",\n",
    "    \"imagenet\": \"ImageNet\",  # not in col_order, so rows mapped here are filtered out\n",
    "    \"laion\": \"LAION\",\n",
    "    \"ms-coco\": \"MS-COCO\",\n",
    "    \"llamagen\": \"LlamaGen\",\n",
    "    \"rar\": \"RAR\",\n",
    "    \"taming\": \"Taming\",\n",
    "    \"var\": \"VAR\",\n",
    "    \"infinity\": \"Infinity\",\n",
    "}\n",
    "\n",
    "def extract_dataset(x: str, alias_map=None) -> str:\n",
    "    \"\"\"Extract the dataset name from a Comparison string, e.g. 'VAR Generated vs RAR Generated' -> 'RAR'.\n",
    "\n",
    "    Takes the text after the word 'vs' (case-insensitive), removes the word 'Generated',\n",
    "    collapses repeated whitespace, and looks the lower-cased result up in ``alias_map``\n",
    "    (defaults to the notebook-level ``aliases``). Names without an alias are returned\n",
    "    as cleaned; strings without 'vs' yield ''.\n",
    "    \"\"\"\n",
    "    if alias_map is None:\n",
    "        alias_map = aliases\n",
    "    match = re.search(r\"\\bvs\\b\\s+(.*)$\", str(x), flags=re.IGNORECASE)\n",
    "    if not match:\n",
    "        return \"\"\n",
    "    right = re.sub(r\"\\bGenerated\\b\", \"\", match.group(1), flags=re.IGNORECASE)\n",
    "    # Removing 'Generated' from the middle of a name leaves a double space; normalise it\n",
    "    # (this also strips both ends) so multi-word names still hit their alias key.\n",
    "    right = \" \".join(right.split())\n",
    "    return alias_map.get(right.lower(), right)\n",
    "\n",
    "# Long -> wide: one row per (Loss Type, Loss Round), one column per compared dataset.\n",
    "comparisons = df.copy()\n",
    "comparisons[\"Dataset\"] = comparisons[\"Comparison\"].apply(extract_dataset)\n",
    "\n",
    "# keep only rows for the datasets we care about\n",
    "comparisons = comparisons[comparisons[\"Dataset\"].isin(col_order)]\n",
    "\n",
    "# remove duplicates if any (the last occurrence wins)\n",
    "comparisons = comparisons.drop_duplicates(subset=[\"Loss Type\", \"Loss Round\", \"Dataset\"], keep=\"last\")\n",
    "\n",
    "wide = (\n",
    "    comparisons\n",
    "    .pivot(index=[\"Loss Type\", \"Loss Round\"], columns=\"Dataset\", values=metric)\n",
    "    .reindex(columns=col_order)  # fixed column order; datasets without rows become NaN\n",
    "    .reset_index()  # Loss Type / Loss Round back to regular columns\n",
    ")\n",
    "\n",
    "# AUC / ACC / TPR are fractions in [0, 1] and are shown as percentages with one decimal.\n",
    "# Threshold is in loss units, so scaling it by 100 would be wrong; keep it with three decimals.\n",
    "# (Vectorised ops also avoid DataFrame.applymap, deprecated since pandas 2.1.)\n",
    "if metric in {\"AUC\", \"ACC\", \"TPR@1%FPR\"}:\n",
    "    wide[col_order] = (wide[col_order].round(3) * 100).round(1)\n",
    "    decimals = 1\n",
    "else:\n",
    "    wide[col_order] = wide[col_order].round(3)\n",
    "    decimals = 3\n",
    "\n",
    "# Convert to LaTeX\n",
    "latex = wide.to_latex(\n",
    "    index=False,\n",
    "    escape=True,\n",
    "    na_rep=\"-\",\n",
    "    float_format=lambda v: f\"{v:.{decimals}f}\",\n",
    "    column_format=\"ll\" + \"c\" * len(col_order),\n",
    "    bold_rows=False,\n",
    "    longtable=False,\n",
    "    multicolumn=False,\n",
    "    multicolumn_format=\"c\",\n",
    ")\n",
    "print(latex)\n",
    "# save the latex\n",
    "with open(os.path.join(result_dir, 'results.tex'), 'w') as f:\n",
    "    f.write(latex)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
