{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"make variations of input image\"\"\"\n",
    "\n",
    "import argparse, os, sys, glob\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "import PIL\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from tqdm import tqdm, trange\n",
    "from itertools import islice\n",
    "from einops import rearrange, repeat\n",
    "from torchvision.utils import make_grid\n",
    "from torch import autocast\n",
    "from contextlib import nullcontext\n",
    "import time\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "sys.path.append(os.path.dirname(sys.path[0]))\n",
    "from ldm.util import instantiate_from_config\n",
    "from ldm.models.diffusion.ddim import DDIMSampler\n",
    "from ldm.models.diffusion.plms import PLMSSampler\n",
    "\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "def chunk(it, size):\n",
    "    it = iter(it)\n",
    "    return iter(lambda: tuple(islice(it, size)), ())\n",
    "\n",
    "\n",
    "def load_model_from_config(config, ckpt, verbose=False):\n",
    "    print(f\"Loading model from {ckpt}\")\n",
    "    pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
    "    if \"global_step\" in pl_sd:\n",
    "        print(f\"Global Step: {pl_sd['global_step']}\")\n",
    "    sd = pl_sd[\"state_dict\"]\n",
    "    model = instantiate_from_config(config.model)\n",
    "    m, u = model.load_state_dict(sd, strict=False)\n",
    "    if len(m) > 0 and verbose:\n",
    "        print(\"missing keys:\")\n",
    "        print(m)\n",
    "    if len(u) > 0 and verbose:\n",
    "        print(\"unexpected keys:\")\n",
    "        print(u)\n",
    "\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_img(path):\n",
    "    image = Image.open(path).convert(\"RGB\")\n",
    "    w, h = image.size\n",
    "    print(f\"loaded input image of size ({w}, {h}) from {path}\")\n",
    "    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32\n",
    "    image = image.resize((512, 512), resample=PIL.Image.LANCZOS)\n",
    "    image = np.array(image).astype(np.float32) / 255.0\n",
    "    image = image[None].transpose(0, 3, 1, 2)\n",
    "    image = torch.from_numpy(image)\n",
    "    return 2.*image - 1."
   ]
  },
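  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick usage sketch for `load_img` (the input path below is a hypothetical placeholder): the helper returns a `(1, 3, 512, 512)` float tensor scaled to `[-1, 1]`, the range the first-stage encoder expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# usage sketch for load_img; the path is a hypothetical placeholder\n",
    "example_path = \"./inputs/example.png\"\n",
    "if os.path.exists(example_path):\n",
    "    init_image = load_img(example_path).to(device)\n",
    "    # expected: torch.Size([1, 3, 512, 512]) with values in [-1.0, 1.0]\n",
    "    print(init_image.shape, init_image.min().item(), init_image.max().item())"
   ]
  },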
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config=\"configs/stable-diffusion/v1-inference.yaml\"\n",
    "ckpt=\"./models/sd/sd-v1-4.ckpt\"\n",
    "config = OmegaConf.load(f\"{config}\")\n",
    "model = load_model_from_config(config, f\"{ckpt}\")\n",
    "sampler = DDIMSampler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(S,\n",
    "            batch_size,\n",
    "            shape,\n",
    "            conditioning=None,\n",
    "            callback=None,\n",
    "            normals_sequence=None,\n",
    "            img_callback=None,\n",
    "            quantize_x0=False,\n",
    "            eta=0.,\n",
    "            mask=None,\n",
    "            x0=None,\n",
    "            temperature=1.,\n",
    "            noise_dropout=0.,\n",
    "            score_corrector=None,\n",
    "            corrector_kwargs=None,\n",
    "            verbose=True,\n",
    "            x_T=None,\n",
    "            log_every_t=100,\n",
    "            unconditional_guidance_scale=1.,\n",
    "            unconditional_conditioning=None,\n",
    "            # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n",
    "            **kwargs\n",
    "            ):\n",
    "    if conditioning is not None:\n",
    "        if isinstance(conditioning, dict):\n",
    "            cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
    "            if cbs != batch_size:\n",
    "                print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
    "        else:\n",
    "            if conditioning[0].shape[0] != batch_size:\n",
    "                print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
    "    # sampling\n",
    "    C, H, W = shape\n",
    "    size = (batch_size, C, H, W)\n",
    "    print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
    "\n",
    "    samples, intermediates = ddim_sampling(conditioning, size,\n",
    "                                                callback=callback,\n",
    "                                                img_callback=img_callback,\n",
    "                                                quantize_denoised=quantize_x0,\n",
    "                                                mask=mask, x0=x0,\n",
    "                                                ddim_use_original_steps=False,\n",
    "                                                noise_dropout=noise_dropout,\n",
    "                                                temperature=temperature,\n",
    "                                                score_corrector=score_corrector,\n",
    "                                                corrector_kwargs=corrector_kwargs,\n",
    "                                                x_T=x_T,\n",
    "                                                log_every_t=log_every_t,\n",
    "                                                unconditional_guidance_scale=unconditional_guidance_scale,\n",
    "                                                unconditional_conditioning=unconditional_conditioning,\n",
    "                                                )\n",
    "    return samples, intermediates\n",
    "\n",
    "@torch.no_grad()\n",
    "def ddim_sampling(cond, shape,\n",
    "                    x_T=None, ddim_use_original_steps=False,\n",
    "                    callback=None, timesteps=None, quantize_denoised=False,\n",
    "                    mask=None, x0=None, img_callback=None, log_every_t=100,\n",
    "                    temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n",
    "                    unconditional_guidance_scale=1., unconditional_conditioning=None,):\n",
    "    device = sampler.model.betas.device\n",
    "    b = shape[0]\n",
    "    if x_T is None:\n",
    "        img = torch.randn(shape, device=device)\n",
    "    else:\n",
    "        img = x_T\n",
    "\n",
    "    if timesteps is None:\n",
    "        timesteps = sampler.ddpm_num_timesteps if ddim_use_original_steps else sampler.ddim_timesteps\n",
    "    elif timesteps is not None and not ddim_use_original_steps:\n",
    "        subset_end = int(min(timesteps / sampler.ddim_timesteps.shape[0], 1) * sampler.ddim_timesteps.shape[0]) - 1\n",
    "        timesteps = sampler.ddim_timesteps[:subset_end]\n",
    "\n",
    "    intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
    "    time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
    "    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
    "    print(f\"Running DDIM Sampling with {total_steps} timesteps\")\n",
    "\n",
    "    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)\n",
    "    for i, step in enumerate(iterator):\n",
    "        index = total_steps - i - 1\n",
    "        ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
    "        norm_t = int(ts*10/1000)\n",
    "        \n",
    "        cond2 = cond[norm_t]\n",
    "        #norm_t = 0\n",
    "        if mask is not None:\n",
    "            assert x0 is not None\n",
    "            img_orig = sampler.model.q_sample(x0, ts)  # TODO: deterministic forward pass?\n",
    "            img = img_orig * mask + (1. - mask) * img\n",
    "\n",
    "        outs = sampler.p_sample_ddim(img,cond2, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
    "                                    quantize_denoised=quantize_denoised, temperature=temperature,\n",
    "                                    noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
    "                                    corrector_kwargs=corrector_kwargs,\n",
    "                                    unconditional_guidance_scale=unconditional_guidance_scale,\n",
    "                                    unconditional_conditioning=cond2)\n",
    "        img, pred_x0 = outs\n",
    "        if callback: callback(i)\n",
    "        if img_callback: img_callback(pred_x0, i)\n",
    "\n",
    "        if index % log_every_t == 0 or index == total_steps - 1:\n",
    "            intermediates['x_inter'].append(img)\n",
    "            intermediates['pred_x0'].append(pred_x0)\n",
    "\n",
    "    return img, intermediates"
   ]
  },
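  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above departs from the stock `DDIMSampler` in one way: `cond` is treated as a sequence of embeddings indexed by a coarse timestep, `norm_t = t * 10 // 1000`, so each tenth of the diffusion trajectory gets its own conditioning. A small sanity check of the mapping (assuming the usual 1000 DDPM timesteps and 10 stored embeddings):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity-check the timestep -> embedding-index mapping used in ddim_sampling\n",
    "for t in [999, 500, 99, 0]:\n",
    "    print(f\"timestep {t:3d} -> embedding index {t * 10 // 1000}\")\n",
    "# timestep 999 -> 9, 500 -> 5, 99 -> 0, 0 -> 0"
   ]
  },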
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(prompt = '', Zt_dir = '', embedding_dir='',ddim_steps = 50,strength = 0.5, model = None, seed=42):\n",
    "    ddim_eta=0.0\n",
    "    n_iter=1\n",
    "    C=4\n",
    "    f=8\n",
    "    n_samples=1\n",
    "    n_rows=0\n",
    "    scale=10.0\n",
    "    \n",
    "    precision=\"autocast\"\n",
    "    outdir=\"./out\"\n",
    "    seed_everything(seed)\n",
    "\n",
    "\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "    outpath = outdir\n",
    "\n",
    "    batch_size = n_samples\n",
    "    n_rows = n_rows if n_rows > 0 else batch_size\n",
    "    data = [batch_size * [prompt]]\n",
    "\n",
    "    sample_path = os.path.join(outpath, \"samples\")\n",
    "    os.makedirs(sample_path, exist_ok=True)\n",
    "    base_count = len(os.listdir(sample_path))\n",
    "    grid_count = len(os.listdir(outpath)) + 10\n",
    "\n",
    "    x_inversion = torch.load(Zt_dir,map_location=device)\n",
    "    c = torch.load(embedding_dir,map_location=device)\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)\n",
    "\n",
    "    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'\n",
    "    t_enc = int(strength * ddim_steps)\n",
    "    print(f\"target t_enc is {t_enc} steps\")\n",
    "\n",
    "    precision_scope = autocast if precision == \"autocast\" else nullcontext\n",
    "    with torch.no_grad():\n",
    "        with precision_scope(\"cuda\"):\n",
    "            with model.ema_scope():\n",
    "                tic = time.time()\n",
    "                all_samples = list()\n",
    "                for n in trange(n_iter, desc=\"Sampling\"):\n",
    "                    for prompts in tqdm(data, desc=\"data\"):\n",
    "                        \n",
    "                        t_enc = ddim_steps\n",
    "                        #x_rand = torch.randn_like(x_inversion)\n",
    "                        samples, intermediates = sample(ddim_steps,1,(4,512,512),c,verbose=False, eta=0.,x_T = x_inversion,\n",
    "                unconditional_guidance_scale=scale,\n",
    "                unconditional_conditioning=c[0],)\n",
    "\n",
    "                        x_samples = model.decode_first_stage(samples)\n",
    "\n",
    "                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
    "\n",
    "                        for x_sample in x_samples:\n",
    "                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
    "                            base_count += 1\n",
    "                        all_samples.append(x_samples)\n",
    "\n",
    "                # additionally, save as grid\n",
    "                grid = torch.stack(all_samples, 0)\n",
    "                grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
    "                grid = make_grid(grid, nrow=n_rows)\n",
    "\n",
    "                # to image\n",
    "                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
    "                output = Image.fromarray(grid.astype(np.uint8))\n",
    "                output.save(os.path.join(outpath, 'recover'+f'-{grid_count:04}.png'))\n",
    "                # Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))\n",
    "                grid_count += 1\n",
    "\n",
    "                toc = time.time()\n",
    "    return output"
   ]
  },
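  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running `main`, it can help to sanity-check the saved inversion artifacts. The paths match the call below; the expected shapes are assumptions inferred from how the tensors are used (`x_T` must be a latent, and `c` must be indexable by `norm_t` in `[0, 9]`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional sanity check of the inversion artifacts used in the next cell\n",
    "zt_path = './out/Alejandro_Toledo_0037_zt.pt'\n",
    "emb_path = './out/Alejandro_Toledo_0037_embedding.pt'\n",
    "if os.path.exists(zt_path) and os.path.exists(emb_path):\n",
    "    zt = torch.load(zt_path, map_location=device)\n",
    "    emb = torch.load(emb_path, map_location=device)\n",
    "    print('x_T shape:', tuple(zt.shape))  # assumption: (1, 4, 64, 64) latent\n",
    "    print('num embeddings:', len(emb))    # assumption: 10 per-interval embeddings"
   ]
  },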
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "main(prompt = '*', \\\n",
    "     Zt_dir = './out/Alejandro_Toledo_0037_zt.pt', \\\n",
    "     embedding_dir = './out/Alejandro_Toledo_0037_embedding.pt', \\\n",
    "     ddim_steps = 100, \\\n",
    "     strength = 0.99, \\\n",
    "     seed=50, \\\n",
    "     model = model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
