{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"make variations of input image\"\"\"\n",
    "\n",
    "import argparse, os, sys, glob\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "import PIL\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from tqdm import tqdm, trange\n",
    "from itertools import islice\n",
    "from einops import rearrange, repeat\n",
    "from torchvision.utils import make_grid\n",
    "from torch import autocast\n",
    "from contextlib import nullcontext\n",
    "import time\n",
    "from pytorch_lightning import seed_everything\n",
    "import torch.nn.functional as F\n",
    "sys.path.append(os.path.dirname(sys.path[0]))\n",
    "from ldm.util import instantiate_from_config\n",
    "from ldm.models.diffusion.ddim import DDIMSampler\n",
    "from ldm.models.diffusion.plms import PLMSSampler\n",
    "from models.encoders.model_irse import Backbone\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "from torchvision.transforms import Compose, ToTensor, Normalize\n",
    "from facenet_pytorch import InceptionResnetV1,MTCNN\n",
    "# Image preprocessing: tensor conversion, then per-channel normalisation with mean 0.5 and std 0.5.\n",
    "transfroms = Compose([\n",
    "    ToTensor(),\n",
    "    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n",
    "])\n",
    "# Use the GPU whenever CUDA is usable; otherwise run on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "def chunk(it, size):\n",
    "    \"\"\"Return an iterator over consecutive tuples of at most ``size`` items from ``it``.\n",
    "\n",
    "    The last tuple may be shorter than ``size``. Iteration stops as soon as the\n",
    "    next slice of the source would be the empty tuple.\n",
    "    \"\"\"\n",
    "    source = iter(it)\n",
    "\n",
    "    def next_batch():\n",
    "        return tuple(islice(source, size))\n",
    "\n",
    "    # Two-argument iter(): keep calling next_batch until it returns the sentinel ().\n",
    "    return iter(next_batch, ())\n",
    "\n",
    "\n",
    "def load_model_from_config(config, ckpt, verbose=False):\n",
    "    \"\"\"Instantiate the model described by ``config.model`` and load weights from ``ckpt``.\n",
    "\n",
    "    The checkpoint is read onto the CPU and its ``state_dict`` entry is loaded\n",
    "    non-strictly. The model is then moved to the module-level ``device`` and put\n",
    "    in eval mode. With ``verbose=True`` any missing or unexpected keys are printed.\n",
    "    \"\"\"\n",
    "    print(f\"Loading model from {ckpt}\")\n",
    "    # torch.load unpickles the file, so only point this at trusted checkpoints.\n",
    "    checkpoint = torch.load(ckpt, map_location=\"cpu\")\n",
    "    if \"global_step\" in checkpoint:\n",
    "        print(f\"Global Step: {checkpoint['global_step']}\")\n",
    "    weights = checkpoint[\"state_dict\"]\n",
    "\n",
    "    model = instantiate_from_config(config.model)\n",
    "    missing, unexpected = model.load_state_dict(weights, strict=False)\n",
    "    if verbose:\n",
    "        for heading, keys in ((\"missing keys:\", missing), (\"unexpected keys:\", unexpected)):\n",
    "            if len(keys) > 0:\n",
    "                print(heading)\n",
    "                print(keys)\n",
    "\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_img(path):\n",
    "    \"\"\"Load an image file as a float32 tensor of shape (1, 3, 512, 512) scaled to [-1, 1].\n",
    "\n",
    "    The image is converted to RGB and always resized to 512x512 with Lanczos\n",
    "    resampling, whatever its original size or aspect ratio. The original size is\n",
    "    only reported in the log line.\n",
    "    \"\"\"\n",
    "    # convert() returns an independent copy, so the file can be closed right away.\n",
    "    with Image.open(path) as raw:\n",
    "        image = raw.convert(\"RGB\")\n",
    "    w, h = image.size\n",
    "    print(f\"loaded input image of size ({w}, {h}) from {path}\")\n",
    "    image = image.resize((512, 512), resample=PIL.Image.LANCZOS)\n",
    "    pixels = np.array(image).astype(np.float32) / 255.0\n",
    "    pixels = pixels[None].transpose(0, 3, 1, 2)  # HWC -> 1CHW\n",
    "    return 2. * torch.from_numpy(pixels) - 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file and inference config used to build the model.\n",
    "ckpt = \"./models/sd/sd-v1-4.ckpt\"\n",
    "config = OmegaConf.load(\"configs/stable-diffusion/v1-inference.yaml\")\n",
    "\n",
    "# Build the model from the config/checkpoint pair and wrap it in a DDIM sampler.\n",
    "model = load_model_from_config(config, ckpt)\n",
    "sampler = DDIMSampler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_reverse(S,\n",
    "            batch_size,\n",
    "            shape,\n",
    "            conditioning=None,\n",
    "            callback=None,\n",
    "            normals_sequence=None,\n",
    "            img_callback=None,\n",
    "            quantize_x0=False,\n",
    "            eta=0.,\n",
    "            mask=None,\n",
    "            x0=None,\n",
    "            temperature=1.,\n",
    "            noise_dropout=0.,\n",
    "            score_corrector=None,\n",
    "            corrector_kwargs=None,\n",
    "            verbose=True,\n",
    "            x_T=None,\n",
    "            log_every_t=100,\n",
    "            unconditional_guidance_scale=1.,\n",
    "            unconditional_conditioning=None,\n",
    "            # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n",
    "            **kwargs\n",
    "            ):\n",
    "    \"\"\"Build the DDIM schedule on the global ``sampler`` and run ``ddim_reverse_sampling``.\n",
    "\n",
    "    Args:\n",
    "        S: number of DDIM steps handed to ``sampler.make_schedule``.\n",
    "        batch_size: batch dimension prepended to ``shape``.\n",
    "        shape: ``(C, H, W)`` of a single sample.\n",
    "        conditioning: dict of tensors, list/tuple of tensors, or a single tensor.\n",
    "            Only its batch dimension is inspected here (a mismatch prints a warning);\n",
    "            the object itself is forwarded unchanged.\n",
    "        eta: DDIM eta handed to ``sampler.make_schedule``.\n",
    "        x_T: optional starting tensor forwarded to ``ddim_reverse_sampling``.\n",
    "        normals_sequence, **kwargs: accepted but not used by this function.\n",
    "        Remaining keyword arguments are forwarded to ``ddim_reverse_sampling``.\n",
    "\n",
    "    Returns:\n",
    "        The ``(samples, intermediates)`` pair produced by ``ddim_reverse_sampling``.\n",
    "    \"\"\"\n",
    "    if conditioning is not None:\n",
    "        if isinstance(conditioning, dict):\n",
    "            cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
    "        elif isinstance(conditioning, (list, tuple)):\n",
    "            # Sequence of per-stage conditionings: read the batch size off the first one.\n",
    "            cbs = conditioning[0].shape[0]\n",
    "        else:\n",
    "            cbs = conditioning.shape[0]\n",
    "        # The message uses the value that was actually compared, so the warning\n",
    "        # itself can no longer raise on a list input.\n",
    "        if cbs != batch_size:\n",
    "            print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
    "    # sampling\n",
    "    C, H, W = shape\n",
    "    size = (batch_size, C, H, W)\n",
    "    print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
    "\n",
    "    samples, intermediates = ddim_reverse_sampling(conditioning, size,\n",
    "                                                callback=callback,\n",
    "                                                img_callback=img_callback,\n",
    "                                                quantize_denoised=quantize_x0,\n",
    "                                                mask=mask, x0=x0,\n",
    "                                                ddim_use_original_steps=False,\n",
    "                                                noise_dropout=noise_dropout,\n",
    "                                                temperature=temperature,\n",
    "                                                score_corrector=score_corrector,\n",
    "                                                corrector_kwargs=corrector_kwargs,\n",
    "                                                x_T=x_T,\n",
    "                                                log_every_t=log_every_t,\n",
    "                                                unconditional_guidance_scale=unconditional_guidance_scale,\n",
    "                                                unconditional_conditioning=unconditional_conditioning,\n",
    "                                                )\n",
    "    return samples, intermediates\n",
    "\n",
    "@torch.no_grad()\n",
    "def ddim_reverse_sampling(cond, shape,\n",
    "                    x_T=None, ddim_use_original_steps=False,\n",
    "                    callback=None, timesteps=None, quantize_denoised=False,\n",
    "                    mask=None, x0=None, img_callback=None, log_every_t=100,\n",
    "                    temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n",
    "                    unconditional_guidance_scale=1., unconditional_conditioning=None,):\n",
    "    \"\"\"Step through the DDIM timesteps in increasing order with ``sampler.p_sample_ddim_reverse``.\n",
    "\n",
    "    Args:\n",
    "        cond: indexable collection of conditionings. At each step the entry at\n",
    "            index ``step * 10 // 1000`` is used.\n",
    "        shape: full tensor shape; ``shape[0]`` is the batch size.\n",
    "        x_T: starting tensor; drawn from a standard normal when ``None``.\n",
    "        timesteps: optional limit on the number of DDIM steps to run.\n",
    "        mask, x0: when ``mask`` is given, ``x0`` is required and is blended in each step.\n",
    "        callback, img_callback: optional per-step hooks.\n",
    "        log_every_t: interval (in loop iterations) at which intermediates are stored.\n",
    "        unconditional_guidance_scale, unconditional_conditioning: accepted for\n",
    "            signature compatibility but NOT forwarded. The sampler call always uses\n",
    "            a scale of 1.0 and ``cond[norm_t]`` as the unconditional conditioning.\n",
    "\n",
    "    Returns:\n",
    "        ``(img, intermediates)``: the final tensor and a dict with the lists\n",
    "        ``'x_inter'`` and ``'pred_x0'``.\n",
    "    \"\"\"\n",
    "    device = sampler.model.betas.device\n",
    "    b = shape[0]\n",
    "    img = torch.randn(shape, device=device) if x_T is None else x_T\n",
    "\n",
    "    if timesteps is None:\n",
    "        timesteps = sampler.ddpm_num_timesteps if ddim_use_original_steps else sampler.ddim_timesteps\n",
    "    elif not ddim_use_original_steps:\n",
    "        subset_end = int(min(timesteps / sampler.ddim_timesteps.shape[0], 1) * sampler.ddim_timesteps.shape[0]) - 1\n",
    "        timesteps = sampler.ddim_timesteps[:subset_end]\n",
    "\n",
    "    intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
    "    time_range = (range(0,timesteps)) if ddim_use_original_steps else (timesteps)\n",
    "    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
    "    print(f\"Running DDIM reverse Sampling with {total_steps} timesteps\")\n",
    "\n",
    "    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)\n",
    "    for i, step in enumerate(iterator):\n",
    "        index = i\n",
    "        ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
    "        # Derive the conditioning index from the scalar step. Calling int() on the\n",
    "        # (b,)-shaped ``ts`` tensor only works when the batch size is exactly 1.\n",
    "        norm_t = int(step) * 10 // 1000\n",
    "        if mask is not None:\n",
    "            assert x0 is not None\n",
    "            img_orig = sampler.model.q_sample(x0, ts)  # TODO: deterministic forward pass?\n",
    "            img = img_orig * mask + (1. - mask) * img\n",
    "        # NOTE(review): guidance scale and unconditional conditioning are fixed here\n",
    "        # and the function arguments of the same name are ignored - confirm intended.\n",
    "        outs = sampler.p_sample_ddim_reverse(img, cond[norm_t], ts, index=index, use_original_steps=ddim_use_original_steps,\n",
    "                                    quantize_denoised=quantize_denoised, temperature=temperature,\n",
    "                                    noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
    "                                    corrector_kwargs=corrector_kwargs,\n",
    "                                    unconditional_guidance_scale=1.0,\n",
    "                                    unconditional_conditioning=cond[norm_t])\n",
    "        img, pred_x0 = outs\n",
    "        if callback: callback(i)\n",
    "        if img_callback: img_callback(pred_x0, i)\n",
    "\n",
    "        if index % log_every_t == 0 or index == total_steps - 1:\n",
    "            intermediates['x_inter'].append(img)\n",
    "            intermediates['pred_x0'].append(pred_x0)\n",
    "\n",
    "    return img, intermediates\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode( x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,\n",
    "            use_original_steps=False, input_noise = None,initial_img = None,resnet=None,guidance=None,loss_guidance_scale=0):\n",
    "    \"\"\"Denoise ``x_latent`` over the first ``t_start`` schedule steps, from highest to lowest.\n",
    "\n",
    "    Args:\n",
    "        x_latent: starting latent; its first dimension is the batch size.\n",
    "        cond: indexable collection of conditionings, indexed by ``step * 10 // 1000``.\n",
    "        t_start: number of schedule entries (counted from the start) to run.\n",
    "        unconditional_guidance_scale: forwarded to ``sampler.p_sample_ddim2``.\n",
    "        unconditional_conditioning: used as the conditioning for steps <= 600.\n",
    "        use_original_steps: use ``np.arange(1000)`` and not ``sampler.ddim_timesteps``.\n",
    "        input_noise, initial_img, resnet, guidance: forwarded to ``sampler.p_sample_ddim2``.\n",
    "        loss_guidance_scale: NOTE(review) - overwritten on every step (1.5 above\n",
    "            step 600, otherwise 2.5), so the caller's value is never used.\n",
    "\n",
    "    Returns:\n",
    "        The latent after the last step.\n",
    "    \"\"\"\n",
    "    timesteps = np.arange(1000) if use_original_steps else sampler.ddim_timesteps\n",
    "    timesteps = timesteps[:t_start]\n",
    "\n",
    "    time_range = np.flip(timesteps)\n",
    "    total_steps = timesteps.shape[0]\n",
    "    print(f\"Running DDIM Sampling with {total_steps} timesteps\")\n",
    "\n",
    "    iterator = tqdm(time_range, desc='Decoding image', total=total_steps)\n",
    "    x_dec = x_latent\n",
    "    for i, step in enumerate(iterator):\n",
    "        index = total_steps - i - 1\n",
    "        # Work with the scalar step. Calling int() on the (batch,)-shaped ``ts``\n",
    "        # tensor only works when the batch size is exactly 1.\n",
    "        step = int(step)\n",
    "        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)\n",
    "        norm_t = step * 10 // 1000\n",
    "        if step > 600:\n",
    "            cond2 = cond[norm_t]\n",
    "            loss_guidance_scale = 1.5\n",
    "        else:\n",
    "            cond2 = unconditional_conditioning\n",
    "            loss_guidance_scale = 2.5\n",
    "\n",
    "        # Trace of the conditioning index used at this step.\n",
    "        print(norm_t)\n",
    "        # The same ``cond2`` is passed as both the conditional and the unconditional input.\n",
    "        x_dec, _ = sampler.p_sample_ddim2(x_dec, cond2, ts, index=index, use_original_steps=use_original_steps,\n",
    "                                        unconditional_guidance_scale=unconditional_guidance_scale,\n",
    "                                        unconditional_conditioning=cond2,\n",
    "                                        input_noise = input_noise,\n",
    "                                        initial_img = initial_img,\n",
    "                                        resnet = resnet,\n",
    "                                        guidance=guidance,\n",
    "                                        loss_guidance_scale=loss_guidance_scale)\n",
    "    return x_dec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(prompt = '', content_dir = '',ddim_steps = 50,strength = 0.5, model = None, resnet = None,seed=42):\n",
    "    \"\"\"Run inversion plus guided decoding on one image and save the resulting grid.\n",
    "\n",
    "    Files written under ``./out`` (``<name>`` is the image file name up to its first dot):\n",
    "        ``<name>_zt.pt``         result of ``sample_reverse`` started from the image latent\n",
    "        ``<name>_embedding.pt``  the list of 10 prompt conditionings\n",
    "        ``<name>-NNNN.png``      grid of the decoded samples\n",
    "\n",
    "    Args:\n",
    "        prompt: text prompt passed to ``model.get_learned_conditioning``.\n",
    "        content_dir: path of the input image file.\n",
    "        ddim_steps: number of DDIM steps for the schedule.\n",
    "        strength: value in [0, 1]; ``int(strength * ddim_steps)`` steps are encoded/decoded.\n",
    "        model: the diffusion model; uses the global ``sampler`` for all sampling calls.\n",
    "        resnet: network forwarded to ``decode`` as ``resnet``.\n",
    "        seed: seed handed to ``seed_everything``.\n",
    "\n",
    "    Returns:\n",
    "        The PIL image of the saved grid.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if MTCNN returns no face crop for the input image.\n",
    "        AssertionError: if ``strength`` is outside [0, 1].\n",
    "    \"\"\"\n",
    "    ddim_eta = 0.0\n",
    "    n_iter = 1\n",
    "    n_samples = 1\n",
    "    n_rows = 0\n",
    "    scale = 10.0\n",
    "\n",
    "    precision = \"autocast\"\n",
    "    outdir = \"./out\"\n",
    "    seed_everything(seed)\n",
    "\n",
    "    mtcnn = MTCNN(image_size=160)\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "    outpath = outdir\n",
    "\n",
    "    batch_size = n_samples\n",
    "    n_rows = n_rows if n_rows > 0 else batch_size\n",
    "    data = [batch_size * [prompt]]\n",
    "\n",
    "    sample_path = os.path.join(outpath, \"samples\")\n",
    "    os.makedirs(sample_path, exist_ok=True)\n",
    "    grid_count = len(os.listdir(outpath)) + 10\n",
    "\n",
    "    content_name = content_dir.split('/')[-1].split('.')[0]\n",
    "    content_image = load_img(content_dir).to(device)\n",
    "    content_image = repeat(content_image, '1 ... -> b ...', b=batch_size)\n",
    "\n",
    "    # Face crop handed to decode() as ``initial_img``.\n",
    "    with Image.open(content_dir) as face_source:\n",
    "        align_img = mtcnn(face_source)\n",
    "    if align_img is None:\n",
    "        # Fail with a clear message instead of an AttributeError on ``.unsqueeze``.\n",
    "        raise ValueError(f\"MTCNN did not detect a face in {content_dir}\")\n",
    "    initial_image = align_img.unsqueeze(0).to(device)\n",
    "\n",
    "    content_latent = model.get_first_stage_encoding(model.encode_first_stage(content_image))  # move to latent space\n",
    "    init_latent = content_latent\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)\n",
    "\n",
    "    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'\n",
    "    t_enc = int(strength * ddim_steps)\n",
    "    print(f\"target t_enc is {t_enc} steps\")\n",
    "\n",
    "    precision_scope = autocast if precision == \"autocast\" else nullcontext\n",
    "    with torch.no_grad():\n",
    "        with precision_scope(\"cuda\"):\n",
    "            with model.ema_scope():\n",
    "                all_samples = list()\n",
    "                for n in trange(n_iter, desc=\"Sampling\"):\n",
    "                    for prompts in tqdm(data, desc=\"data\"):\n",
    "                        uc = None\n",
    "                        if scale != 1.0:\n",
    "                            uc = [model.get_learned_conditioning(batch_size * [\"\"], content_image, i)\n",
    "                                  for i in range(10)]\n",
    "                        if isinstance(prompts, tuple):\n",
    "                            prompts = list(prompts)\n",
    "                        c = [model.get_learned_conditioning(prompts, content_image, i)\n",
    "                             for i in range(10)]\n",
    "                        # Only the first unconditional embedding is used downstream.\n",
    "                        uc_first = uc[0] if uc is not None else None\n",
    "\n",
    "                        seed_everything(seed)\n",
    "                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
    "\n",
    "                        # Report the real latent shape; x_T is supplied, so the shape\n",
    "                        # only fixes the batch size and the logged message.\n",
    "                        x_inversion, _ = sample_reverse(ddim_steps, batch_size, tuple(init_latent.shape[1:]), c,\n",
    "                                                        verbose=False, eta=0., x_T=init_latent,\n",
    "                                                        unconditional_guidance_scale=scale,\n",
    "                                                        unconditional_conditioning=uc_first)\n",
    "                        torch.save(x_inversion, os.path.join(outpath, content_name+'_zt.pt'))\n",
    "                        torch.save(c, os.path.join(outpath, content_name+'_embedding.pt'))\n",
    "                        del x_inversion\n",
    "\n",
    "                        samples = decode(z_enc, c, t_enc,\n",
    "                                         unconditional_guidance_scale=scale,\n",
    "                                         unconditional_conditioning=uc_first,\n",
    "                                         initial_img=initial_image, resnet=resnet,\n",
    "                                         guidance=True, loss_guidance_scale=1)\n",
    "                        print(z_enc.shape, None if uc_first is None else uc_first.shape, t_enc)\n",
    "\n",
    "                        x_samples = model.decode_first_stage(samples)\n",
    "                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
    "                        all_samples.append(x_samples)\n",
    "\n",
    "                # save all samples as one grid\n",
    "                grid = torch.stack(all_samples, 0)\n",
    "                grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
    "                grid = make_grid(grid, nrow=n_rows)\n",
    "\n",
    "                # to image\n",
    "                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
    "                output = Image.fromarray(grid.astype(np.uint8))\n",
    "                output.save(os.path.join(outpath, content_name+f'-{grid_count:04}.png'))\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved embeddings checkpoint into the model's embedding manager.\n",
    "model.embedding_manager.load('./logs/Alejandro_Toledo2023-07-23T15-38-22_Alejandro_Toledo/checkpoints/embeddings.pt')\n",
    "\n",
    "# Face network with the 'vggface2' pretrained weights, in eval mode on the active device.\n",
    "# It is passed to main() as `resnet`. An IR-SE50 `Backbone` alternative was left disabled here.\n",
    "facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "content_root = './dataset/Alejandro_Toledo_0037.jpg'\n",
    "\n",
    "# Run the pipeline on a single image; the returned grid image is the cell's output.\n",
    "main(\n",
    "    prompt='*',\n",
    "    content_dir=content_root,\n",
    "    ddim_steps=100,\n",
    "    strength=0.80,\n",
    "    seed=100162,\n",
    "    model=model,\n",
    "    resnet=facenet,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
