{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6eb7d71-9eda-4370-ac5f-195114c18522",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "import math\n",
    "from einops import rearrange\n",
    "import time\n",
    "import random\n",
    "import string\n",
    "import h5py\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from diffusers.models.autoencoders.vae import Decoder\n",
    "\n",
    "# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main\n",
    "sys.path.append('generative_models/')\n",
    "import sgm\n",
    "from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder2\n",
    "from generative_models.sgm.models.diffusion import DiffusionEngine\n",
    "from generative_models.sgm.util import append_dims\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Permit TF32 for CUDA matmuls (trades a little precision for speed on GPUs that support it).\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "import v2_utils as utils\n",
    "\n",
    "# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(\"device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf364c1-c508-41b8-b018-8f9e5d9dbcab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Run configuration ---\n",
    "data_path='data/new_dl'\n",
    "cache_dir='checkpoints'\n",
    "# NOTE(review): model_name mentions subj1 while subj=7 below; presumably a subj1 model fine-tuned on subject 7 (finetune_s7) -- confirm.\n",
    "model_name='subj1_l_bclip_basictest_wll_5_finetune_s7_250_aamax'\n",
    "checkpoint_name = model_name\n",
    "checkpoint_dir = os.path.abspath(f'train_logs/{checkpoint_name}')\n",
    "checkpoint_tag = 'mid_200'\n",
    "subj=7\n",
    "blurry_recon=True  # also load (and resize) the blurry reconstructions\n",
    "seed=42\n",
    "batch_size=1\n",
    "plotting=True  # per-image diagnostics; later cells may override this\n",
    "\n",
    "# Actually apply the seed: the enhancement loop draws its img2img noise with torch.randn_like,\n",
    "# so without seeding every run yields different enhanced reconstructions.\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a16b1eef-6817-42e2-b981-991c434ab7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the output directory (makedirs also creates the parent evals/ directory).\n",
    "os.makedirs(f\"evals/{model_name}\",exist_ok=True)\n",
    "\n",
    "# Ground-truth images and the saved outputs of the reconstruction step for this model.\n",
    "all_images = torch.load(f\"evals/all_images_subj{subj:02d}.pt\")\n",
    "all_recons = torch.load(f\"recons/{model_name}_all_recons.pt\")\n",
    "if blurry_recon:\n",
    "    all_blurryrecons = torch.load(f\"recons/{model_name}_all_blurryrecons.pt\")\n",
    "    all_blurryrecons = transforms.Resize((768,768))(all_blurryrecons).float()\n",
    "all_predcaptions = torch.load(f\"recons/{model_name}_all_predcaptions.pt\")\n",
    "\n",
    "# 768x768 is required by the enhancement loop, which asserts this width before VAE encoding.\n",
    "all_recons = transforms.Resize((768,768))(all_recons).float()\n",
    "\n",
    "# The enhancement loop pairs recon i with caption i; fail here rather than part-way through sampling.\n",
    "assert len(all_predcaptions) == len(all_recons), (len(all_predcaptions), len(all_recons))\n",
    "\n",
    "print(model_name)\n",
    "if blurry_recon:\n",
    "    print(all_images.shape,all_recons.shape, all_predcaptions.shape, all_blurryrecons.shape)\n",
    "else:\n",
    "    print(all_images.shape,all_recons.shape, all_predcaptions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "725ff9d7-ae58-466a-b7cb-f7e6e48f0163",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The sampler settings come from the unCLIP config; everything else comes from the SDXL base config.\n",
    "config = OmegaConf.to_container(OmegaConf.load(\"generative_models/configs/unclip6.yaml\"), resolve=True)\n",
    "unclip_params = config[\"model\"][\"params\"]\n",
    "sampler_config = unclip_params[\"sampler_config\"]\n",
    "sampler_config['params']['num_steps'] = 38\n",
    "\n",
    "# NOTE: despite the name, refiner_params holds the parameters from sd_xl_base.yaml.\n",
    "config = OmegaConf.to_container(OmegaConf.load(\"generative_models/configs/inference/sd_xl_base.yaml\"), resolve=True)\n",
    "refiner_params = config[\"model\"][\"params\"]\n",
    "\n",
    "(network_config, denoiser_config, first_stage_config,\n",
    " conditioner_config, scale_factor, disable_first_stage_autocast) = (\n",
    "    refiner_params[key] for key in (\n",
    "        \"network_config\", \"denoiser_config\", \"first_stage_config\",\n",
    "        \"conditioner_config\", \"scale_factor\", \"disable_first_stage_autocast\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66412b47-7db2-4a28-91eb-85811bffff37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SDXL base architecture (sd_xl_base.yaml) with weights from base_ckpt_path and the unCLIP sampler config.\n",
    "base_ckpt_path = 'checkpoints/zavychromaxl_v30.safetensors'\n",
    "base_engine = DiffusionEngine(network_config=network_config,\n",
    "                       denoiser_config=denoiser_config,\n",
    "                       first_stage_config=first_stage_config,\n",
    "                       conditioner_config=conditioner_config,\n",
    "                       sampler_config=sampler_config, # using the one defined by the unclip\n",
    "                       scale_factor=scale_factor,\n",
    "                       disable_first_stage_autocast=disable_first_stage_autocast,\n",
    "                       ckpt_path=base_ckpt_path)\n",
    "base_engine.eval().requires_grad_(False)\n",
    "base_engine.to(device)\n",
    "\n",
    "# NOTE(review): drops the attention mask on the second conditioner embedder's model; presumably a compatibility workaround -- confirm it is still needed.\n",
    "base_engine.conditioner.embedders[1].model.attn_mask = None\n",
    "\n",
    "# Stand-alone copies of the conditioner's first two text embedders, used to encode each predicted caption.\n",
    "text_emb1_params = conditioner_config['params']['emb_models'][0]['params']\n",
    "text_emb2_params = conditioner_config['params']['emb_models'][1]['params']\n",
    "\n",
    "base_text_embedder1 = FrozenCLIPEmbedder(\n",
    "    layer=text_emb1_params['layer'],\n",
    "    layer_idx=text_emb1_params['layer_idx'],\n",
    ")\n",
    "base_text_embedder1.to(device)\n",
    "\n",
    "# The enhancement loop unpacks two outputs from this embedder; presumably that relies on always_return_pooled being set in the config.\n",
    "base_text_embedder2 = FrozenOpenCLIPEmbedder2(\n",
    "    arch=text_emb2_params['arch'],\n",
    "    version=text_emb2_params['version'],\n",
    "    freeze=text_emb2_params['freeze'],\n",
    "    layer=text_emb2_params['layer'],\n",
    "    always_return_pooled=text_emb2_params['always_return_pooled'],\n",
    "    legacy=text_emb2_params['legacy'],\n",
    ")\n",
    "# Trailing semicolon keeps Jupyter from dumping the whole module repr as cell output.\n",
    "base_text_embedder2.to(device);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c9d4564-d515-41c4-9ab9-03c23b785d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conditioning for an empty prompt (768px original size, no crop, 1024px target size).\n",
    "# The next cell keeps the tail of its \"vector\" output for reuse in the enhancement loop.\n",
    "batch = {\n",
    "    \"txt\": \"\",\n",
    "    \"original_size_as_tuple\": torch.full((1, 2), 768.0, device=device),\n",
    "    \"crop_coords_top_left\": torch.zeros(1, 2, device=device),\n",
    "    \"target_size_as_tuple\": torch.full((1, 2), 1024.0, device=device),\n",
    "}\n",
    "out = base_engine.conditioner(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "679f07f0-2f7e-44db-8abe-e5d40d78fae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the last 1536 entries of the empty-prompt vector; the loop appends them to each caption's embedding.\n",
    "crossattn = out[\"crossattn\"].to(device)\n",
    "vector_suffix = out[\"vector\"][:, -1536:].to(device)\n",
    "for label, tensor in ((\"crossattn\", crossattn), (\"vector_suffix\", vector_suffix)):\n",
    "    print(label, tensor.shape)\n",
    "print(\"---\")\n",
    "\n",
    "# Negative prompt; its outputs are passed as the unconditional (uc) conditioning during sampling.\n",
    "negative_prompt = \"painting, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs, anime\"\n",
    "batch_uc = {\n",
    "    \"txt\": negative_prompt,\n",
    "    \"original_size_as_tuple\": torch.full((1, 2), 768.0, device=device),\n",
    "    \"crop_coords_top_left\": torch.zeros(1, 2, device=device),\n",
    "    \"target_size_as_tuple\": torch.full((1, 2), 1024.0, device=device),\n",
    "}\n",
    "out = base_engine.conditioner(batch_uc)\n",
    "crossattn_uc = out[\"crossattn\"].to(device)\n",
    "vector_uc = out[\"vector\"].to(device)\n",
    "print(f\"crossattn_uc {crossattn_uc.shape}\")\n",
    "print(f\"vector_uc {vector_uc.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae07cf22-a134-46df-8a90-138b90782e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting only makes sense with a human watching: turn it off for non-interactive (script) runs.\n",
    "# This also avoids loading the large CLIP image embedder below when nothing will use it.\n",
    "if not utils.is_interactive():\n",
    "    plotting = False\n",
    "\n",
    "num_samples = 1  # img2img samples per recon; >1 keeps the one with highest CLIP similarity to the ground truth\n",
    "img2img_timepoint = 9  # higher number means more reliance on prompt, less reliance on matching the conditioning image\n",
    "base_engine.sampler.guider.scale = 5  # guidance scale of the sampler's guider (presumably CFG strength)\n",
    "\n",
    "def denoiser(x, sigma, c):\n",
    "    \"\"\"Denoise latents x at noise level sigma under conditioning c with base_engine's model.\"\"\"\n",
    "    return base_engine.denoiser(base_engine.model, x, sigma, c)\n",
    "\n",
    "# The CLIP image embedder is only needed to score samples (plotting diagnostics or best-of-N selection).\n",
    "if plotting or num_samples>1:\n",
    "    clip_img_embedder = FrozenOpenCLIPImageEmbedder(\n",
    "        arch=\"ViT-bigG-14\",\n",
    "        version=\"laion2b_s39b_b160k\",\n",
    "        output_tokens=True,\n",
    "        only_tokens=True,\n",
    "    )\n",
    "    clip_img_embedder.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b767530-7190-4ae8-bf8b-dab3d789191d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enhance every reconstruction with SDXL img2img, conditioned on its predicted caption.\n",
    "enhanced_chunks = []  # one (1, C, H, W) CPU tensor per image; concatenated once after the loop\n",
    "plotting=False  # forced off so the full set is processed (plotting mode stops after the first image)\n",
    "for img_idx in tqdm(range(len(all_recons))):\n",
    "    with torch.no_grad(), torch.autocast(device_type=\"cuda\", dtype=torch.float16), base_engine.ema_scope():\n",
    "        # num_steps is shrunk to the img2img tail further down, so restore the full schedule length per image.\n",
    "        base_engine.sampler.num_steps = 25\n",
    "\n",
    "        image = all_recons[[img_idx]]\n",
    "\n",
    "        if plotting:\n",
    "            if blurry_recon:  # all_blurryrecons only exists when blurry_recon is set\n",
    "                print(\"blur pixcorr:\",utils.pixcorr(all_blurryrecons[[img_idx]].float(), all_images[[img_idx]].float()))\n",
    "                # 224 to match every other CLIP comparison in this cell\n",
    "                print(\"blur cossim:\",nn.functional.cosine_similarity(clip_img_embedder(utils.resize(all_blurryrecons[[img_idx]].float(),224).to(device)).flatten(1), \n",
    "                                                             clip_img_embedder(utils.resize(all_images[[img_idx]].float(),224).to(device)).flatten(1)))\n",
    "\n",
    "            print(\"recon pixcorr:\",utils.pixcorr(image,all_images[[img_idx]].float()))\n",
    "            print(\"recon cossim:\",nn.functional.cosine_similarity(clip_img_embedder(utils.resize(image,224).to(device)).flatten(1), \n",
    "                                                         clip_img_embedder(utils.resize(all_images[[img_idx]].float(),224).to(device)).flatten(1)))\n",
    "\n",
    "        image = image.to(device)\n",
    "        prompt = all_predcaptions[[img_idx]][0]\n",
    "        if plotting:\n",
    "            print(\"prompt:\",prompt)\n",
    "            if blurry_recon:\n",
    "                plt.imshow(transforms.ToPILImage()(all_blurryrecons[img_idx].float()))\n",
    "                plt.show()\n",
    "            plt.imshow(transforms.ToPILImage()(all_recons[img_idx].float()))\n",
    "            plt.show()\n",
    "            plt.imshow(transforms.ToPILImage()(image[0]))\n",
    "            plt.show()\n",
    "\n",
    "        # VAE-encode the recon (presumably in [0, 1], mapped to [-1, 1]) as the img2img starting latent.\n",
    "        assert image.shape[-1]==768\n",
    "        z = base_engine.encode_first_stage(image*2-1).repeat(num_samples,1,1,1)\n",
    "\n",
    "        # Caption conditioning: both embedders' token outputs concatenated on the last dim for cross-attention,\n",
    "        # and the second embedder's vector output extended with vector_suffix.\n",
    "        openai_clip_text = base_text_embedder1(prompt)\n",
    "        clip_text_tokenized, clip_text_emb  = base_text_embedder2(prompt)\n",
    "        clip_text_emb = torch.hstack((clip_text_emb, vector_suffix))\n",
    "        clip_text_tokenized = torch.cat((openai_clip_text, clip_text_tokenized),dim=-1)\n",
    "        c = {\"crossattn\": clip_text_tokenized.repeat(num_samples,1,1), \"vector\": clip_text_emb.repeat(num_samples,1)}\n",
    "        uc = {\"crossattn\": crossattn_uc.repeat(num_samples,1,1), \"vector\": vector_uc.repeat(num_samples,1)}\n",
    "\n",
    "        # img2img: noise z to the sigma img2img_timepoint steps from the end, then run only those final steps.\n",
    "        noise = torch.randn_like(z)\n",
    "        sigmas = base_engine.sampler.discretization(base_engine.sampler.num_steps).to(device)\n",
    "        # NOTE(review): the division presumably pre-cancels the sqrt(1 + sigma_0**2) scaling that\n",
    "        # prepare_sampling_loop applies -- re-check against the sgm sampler if either side changes.\n",
    "        init_z = (z + noise * append_dims(sigmas[-img2img_timepoint], z.ndim)) / torch.sqrt(1.0 + sigmas[0] ** 2.0)\n",
    "        sigmas = sigmas[-img2img_timepoint:].repeat(num_samples,1)\n",
    "\n",
    "        base_engine.sampler.num_steps = sigmas.shape[-1] - 1\n",
    "        noised_z, _, _, _, c, uc = base_engine.sampler.prepare_sampling_loop(init_z, cond=c, uc=uc, \n",
    "                                                            num_steps=base_engine.sampler.num_steps)\n",
    "        for timestep in range(base_engine.sampler.num_steps):\n",
    "            noised_z = base_engine.sampler.sampler_step(sigmas[:,timestep],\n",
    "                                                        sigmas[:,timestep+1],\n",
    "                                                        denoiser, noised_z, cond=c, uc=uc, gamma=0)\n",
    "        samples_z_base = noised_z\n",
    "        samples_x = base_engine.decode_first_stage(samples_z_base)\n",
    "        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n",
    "\n",
    "        # find best sample (CLIP scoring is only needed when plotting or drawing several samples)\n",
    "        if not plotting and num_samples==1:\n",
    "            samples = samples[0]\n",
    "        else:\n",
    "            sample_cossim = nn.functional.cosine_similarity(clip_img_embedder(utils.resize(samples,224).to(device)).flatten(1), \n",
    "                                clip_img_embedder(utils.resize(all_images[[img_idx]].float(),224).to(device)).flatten(1))\n",
    "            which_sample = torch.argmax(sample_cossim)\n",
    "            best_cossim = torch.max(sample_cossim)\n",
    "\n",
    "            if plotting:\n",
    "                print(\"samples\", samples.shape)\n",
    "                for n in range(num_samples):\n",
    "                    recon = transforms.ToPILImage()(samples[n])\n",
    "                    plt.imshow(recon)\n",
    "                    plt.show()\n",
    "                    if (n==which_sample).item(): print(\"CHOSEN ABOVE\")\n",
    "                    print(\"upsampled pixcorr:\",utils.pixcorr(samples[[n]].cpu(),all_images[[img_idx]].float()))\n",
    "                    print(\"upsampled cossim:\",nn.functional.cosine_similarity(clip_img_embedder(utils.resize(samples[[n]],224).to(device)).flatten(1), \n",
    "                                                         clip_img_embedder(utils.resize(all_images[[img_idx]].float(),224).to(device)).flatten(1)))\n",
    "                # Deliberate stop: plotting mode inspects a single image rather than enhancing the full set.\n",
    "                raise RuntimeError(\"plotting=True stops after the first image; set plotting=False to enhance every recon\")\n",
    "\n",
    "            samples = samples[which_sample]\n",
    "\n",
    "        enhanced_chunks.append(samples.cpu()[None])\n",
    "\n",
    "assert enhanced_chunks, \"no reconstructions to enhance\"\n",
    "all_enhancedrecons = torch.cat(enhanced_chunks, dim=0)\n",
    "all_enhancedrecons = transforms.Resize((256,256))(all_enhancedrecons).float()\n",
    "print(\"all_enhancedrecons\", all_enhancedrecons.shape)\n",
    "torch.save(all_enhancedrecons,f\"evals/{model_name}/{model_name}_all_enhancedrecons.pt\")\n",
    "print(f\"saved evals/{model_name}/{model_name}_all_enhancedrecons.pt\")\n",
    "\n",
    "# When run as a script, stop here.\n",
    "if not utils.is_interactive():\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09e81467-420e-45b3-b323-da60970ecd72",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c60d28-3f03-4354-839d-c888d76d33b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f47ef1f-e71a-43a8-ae9c-b4f84f08a206",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da41aa25-834b-4c01-b9f4-52dd2d869629",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "207630c3-fa2a-4e65-9963-f0719f094031",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed07160f-2f8b-446f-9065-f61037aa7be7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6db496e-a9a5-41d9-802e-ed6595974e61",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984fdb7f-e668-4d5d-b8cd-31c74f441c49",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
