{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7abb4368",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"make variations of input image\"\"\"\n",
    "\n",
    "import argparse, os, sys, glob\n",
    "# Expose only the GPU with index 1 to CUDA. The variable is set before `import torch` below.\n",
    "# NOTE(review): presumably it has to stay ahead of any CUDA initialisation -- confirm before reordering.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "import PIL\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from tqdm import tqdm, trange\n",
    "from itertools import islice\n",
    "from einops import rearrange, repeat\n",
    "from torchvision.utils import make_grid\n",
    "from torch import autocast\n",
    "from contextlib import nullcontext\n",
    "import time\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# Add the parent directory of sys.path[0] to the import search path.\n",
    "# NOTE(review): presumably this is what makes the `ldm` package below importable -- confirm.\n",
    "sys.path.append(os.path.dirname(sys.path[0]))\n",
    "from ldm.util import instantiate_from_config\n",
    "from ldm.models.diffusion.ddim import DDIMSampler\n",
    "from ldm.models.diffusion.plms import PLMSSampler\n",
    "\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "from models.encoders.model_irse import Backbone\n",
    "from generate_mask import evaluate_mask\n",
    "# Module-level device used by the helpers in this notebook: CUDA when available, otherwise CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "def chunk(it, size):\n",
    "    \"\"\"Lazily split `it` into consecutive tuples holding at most `size` items.\n",
    "\n",
    "    The last tuple may be shorter than `size`; iteration ends as soon as an\n",
    "    empty tuple would be produced.\n",
    "    \"\"\"\n",
    "    source = iter(it)\n",
    "\n",
    "    def take_next():\n",
    "        # Pull the next group of up to `size` items from the shared iterator.\n",
    "        return tuple(islice(source, size))\n",
    "\n",
    "    # Two-argument iter(): keep calling take_next until it returns the sentinel ().\n",
    "    return iter(take_next, ())\n",
    "\n",
    "\n",
    "def load_model_from_config(config, ckpt, verbose=False):\n",
    "    \"\"\"Instantiate the model described by `config.model` and load weights from `ckpt`.\n",
    "\n",
    "    The checkpoint is read onto the CPU, its 'state_dict' entry is applied\n",
    "    non-strictly, and the model is then moved to the module-level `device`\n",
    "    and switched to eval mode.\n",
    "\n",
    "    Args:\n",
    "        config: Config object whose `model` attribute is passed to `instantiate_from_config`.\n",
    "        ckpt: Path of the checkpoint file handed to `torch.load`.\n",
    "        verbose: When true, print any missing / unexpected state-dict keys.\n",
    "\n",
    "    Returns:\n",
    "        The loaded model.\n",
    "    \"\"\"\n",
    "    print(f\"Loading model from {ckpt}\")\n",
    "    checkpoint = torch.load(ckpt, map_location=\"cpu\")\n",
    "    if \"global_step\" in checkpoint:\n",
    "        print(f\"Global Step: {checkpoint['global_step']}\")\n",
    "    weights = checkpoint[\"state_dict\"]\n",
    "\n",
    "    model = instantiate_from_config(config.model)\n",
    "    missing, unexpected = model.load_state_dict(weights, strict=False)\n",
    "\n",
    "    if verbose:\n",
    "        # Report key mismatches, missing keys first, skipping empty lists.\n",
    "        for heading, keys in ((\"missing keys:\", missing), (\"unexpected keys:\", unexpected)):\n",
    "            if len(keys) > 0:\n",
    "                print(heading)\n",
    "                print(keys)\n",
    "\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_img(path, size=(512, 512)):\n",
    "    \"\"\"Load an image file as a batched tensor scaled to [-1, 1].\n",
    "\n",
    "    Args:\n",
    "        path: Location of the image; anything `PIL.Image.open` accepts.\n",
    "        size: Target `(width, height)` the image is resized to. Defaults to\n",
    "            (512, 512), the value that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        A float32 `torch.Tensor` of shape (1, 3, height, width) whose values\n",
    "        are mapped from [0, 255] to [-1, 1].\n",
    "    \"\"\"\n",
    "    # Open through a context manager so the file handle is released once the\n",
    "    # pixels have been converted, instead of lingering until garbage collection.\n",
    "    with Image.open(path) as raw:\n",
    "        image = raw.convert(\"RGB\")\n",
    "    w, h = image.size\n",
    "    print(f\"loaded input image of size ({w}, {h}) from {path}\")\n",
    "    # The original size is only reported; the image is always resized to `size`.\n",
    "    image = image.resize(tuple(size), resample=PIL.Image.LANCZOS)\n",
    "    image = np.array(image).astype(np.float32) / 255.0\n",
    "    # HWC -> 1CHW: add a leading batch axis, then move channels first.\n",
    "    image = image[None].transpose(0, 3, 1, 2)\n",
    "    image = torch.from_numpy(image)\n",
    "    return 2.*image - 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5498b96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the selected torch device, then the number of CUDA devices torch can see.\n",
    "print(device, torch.cuda.device_count(), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0abda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the inference config, load the checkpoint, and wrap the model in a DDIM sampler.\n",
    "ckpt = \"./models/sd/sd-v1-4.ckpt\"\n",
    "config = OmegaConf.load(\"configs/stable-diffusion/v1-inference.yaml\")\n",
    "model = load_model_from_config(config, ckpt)\n",
    "sampler = DDIMSampler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b345ba02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_reverse(S,\n",
    "            batch_size,\n",
    "            shape,\n",
    "            conditioning=None,\n",
    "            callback=None,\n",
    "            normals_sequence=None,\n",
    "            img_callback=None,\n",
    "            quantize_x0=False,\n",
    "            eta=0.,\n",
    "            mask=None,\n",
    "            x0=None,\n",
    "            temperature=1.,\n",
    "            noise_dropout=0.,\n",
    "            score_corrector=None,\n",
    "            corrector_kwargs=None,\n",
    "            verbose=True,\n",
    "            x_T=None,\n",
    "            log_every_t=100,\n",
    "            unconditional_guidance_scale=1.,\n",
    "            unconditional_conditioning=None,\n",
    "            # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n",
    "            **kwargs\n",
    "            ):\n",
    "    \"\"\"Build the DDIM schedule on the global `sampler` and run `ddim_reverse_sampling`.\n",
    "\n",
    "    Args:\n",
    "        S: Number of DDIM steps passed to `sampler.make_schedule`.\n",
    "        batch_size: Leading dimension of the sampling shape.\n",
    "        shape: `(C, H, W)` of a single sample.\n",
    "        conditioning: Either a dict of tensors or an indexable collection of\n",
    "            tensors. Only its batch dimension is checked here; it is handed\n",
    "            to `ddim_reverse_sampling` unchanged.\n",
    "        eta: DDIM eta used for the schedule.\n",
    "        verbose: Forwarded to `sampler.make_schedule`.\n",
    "        normals_sequence, **kwargs: Accepted but not used by this function.\n",
    "        All remaining arguments are forwarded to `ddim_reverse_sampling`\n",
    "        (`quantize_x0` is passed on as `quantize_denoised`).\n",
    "\n",
    "    Returns:\n",
    "        The `(samples, intermediates)` pair produced by `ddim_reverse_sampling`.\n",
    "    \"\"\"\n",
    "    if conditioning is not None:\n",
    "        if isinstance(conditioning, dict):\n",
    "            cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
    "        else:\n",
    "            # Non-dict conditionings are indexed, so the batch size is read from\n",
    "            # the first entry. The warning previously formatted\n",
    "            # `conditioning.shape[0]`, which raises AttributeError for a list.\n",
    "            cbs = conditioning[0].shape[0]\n",
    "        if cbs != batch_size:\n",
    "            print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
    "    # sampling\n",
    "    C, H, W = shape\n",
    "    size = (batch_size, C, H, W)\n",
    "    print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
    "\n",
    "    samples, intermediates = ddim_reverse_sampling(conditioning, size,\n",
    "                                                callback=callback,\n",
    "                                                img_callback=img_callback,\n",
    "                                                quantize_denoised=quantize_x0,\n",
    "                                                mask=mask, x0=x0,\n",
    "                                                ddim_use_original_steps=False,\n",
    "                                                noise_dropout=noise_dropout,\n",
    "                                                temperature=temperature,\n",
    "                                                score_corrector=score_corrector,\n",
    "                                                corrector_kwargs=corrector_kwargs,\n",
    "                                                x_T=x_T,\n",
    "                                                log_every_t=log_every_t,\n",
    "                                                unconditional_guidance_scale=unconditional_guidance_scale,\n",
    "                                                unconditional_conditioning=unconditional_conditioning,\n",
    "                                                )\n",
    "    return samples, intermediates\n",
    "\n",
    "@torch.no_grad()\n",
    "def ddim_reverse_sampling(cond, shape,\n",
    "                    x_T=None, ddim_use_original_steps=False,\n",
    "                    callback=None, timesteps=None, quantize_denoised=False,\n",
    "                    mask=None, x0=None, img_callback=None, log_every_t=100,\n",
    "                    temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n",
    "                    unconditional_guidance_scale=1., unconditional_conditioning=None,):\n",
    "    \"\"\"Step the global `sampler` through its schedule with `p_sample_ddim_reverse`.\n",
    "\n",
    "    The schedule is walked in its stored order (it is not flipped).\n",
    "\n",
    "    Args:\n",
    "        cond: Indexable collection of conditionings; entry\n",
    "            `int(t * 10 / 1000)` is used at timestep `t`.\n",
    "        shape: Shape of the working tensor; `shape[0]` is the batch size. Used\n",
    "            to draw the start tensor only when `x_T` is None.\n",
    "        x_T: Optional start tensor; random normal noise is drawn when omitted.\n",
    "        ddim_use_original_steps: Iterate `range(timesteps)` instead of\n",
    "            `sampler.ddim_timesteps`.\n",
    "        timesteps: Optional cap on the number of steps.\n",
    "        mask, x0: When `mask` is given, `x0` is required and the masked region\n",
    "            is overwritten with `q_sample(x0, ts)` before every step.\n",
    "        callback, img_callback: Optional per-step hooks, called with `i` and\n",
    "            `(pred_x0, i)` respectively.\n",
    "        log_every_t: Step interval at which intermediates are recorded.\n",
    "        unconditional_guidance_scale, unconditional_conditioning: Accepted for\n",
    "            signature compatibility but not forwarded; every step runs with a\n",
    "            guidance scale of 1.0 and `cond[norm_t]` as both conditionings.\n",
    "        quantize_denoised, temperature, noise_dropout, score_corrector,\n",
    "        corrector_kwargs: Forwarded to `sampler.p_sample_ddim_reverse`.\n",
    "\n",
    "    Returns:\n",
    "        `(img, intermediates)`: the final tensor and a dict with the recorded\n",
    "        'x_inter' and 'pred_x0' lists.\n",
    "    \"\"\"\n",
    "    device = sampler.model.betas.device\n",
    "    b = shape[0]\n",
    "    img = torch.randn(shape, device=device) if x_T is None else x_T\n",
    "\n",
    "    if timesteps is None:\n",
    "        timesteps = sampler.ddpm_num_timesteps if ddim_use_original_steps else sampler.ddim_timesteps\n",
    "    elif not ddim_use_original_steps:\n",
    "        subset_end = int(min(timesteps / sampler.ddim_timesteps.shape[0], 1) * sampler.ddim_timesteps.shape[0]) - 1\n",
    "        timesteps = sampler.ddim_timesteps[:subset_end]\n",
    "\n",
    "    intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
    "    time_range = (range(0,timesteps)) if ddim_use_original_steps else (timesteps)\n",
    "    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
    "    print(f\"Running DDIM reverse Sampling with {total_steps} timesteps\")\n",
    "\n",
    "    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)\n",
    "    for i, step in enumerate(iterator):\n",
    "        index = i\n",
    "        ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
    "        # Every element of `ts` equals `step`, so read one element. Converting\n",
    "        # the whole (b,) tensor with int() fails as soon as the batch size exceeds 1.\n",
    "        norm_t = int(ts[0]*10/1000)\n",
    "        if mask is not None:\n",
    "            assert x0 is not None\n",
    "            img_orig = sampler.model.q_sample(x0, ts)  # TODO: deterministic forward pass?\n",
    "            img = img_orig * mask + (1. - mask) * img\n",
    "        outs = sampler.p_sample_ddim_reverse(img, cond[norm_t], ts, index=index, use_original_steps=ddim_use_original_steps,\n",
    "                                    quantize_denoised=quantize_denoised, temperature=temperature,\n",
    "                                    noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
    "                                    corrector_kwargs=corrector_kwargs,\n",
    "                                    unconditional_guidance_scale=1.0,\n",
    "                                    unconditional_conditioning=cond[norm_t])\n",
    "        img, pred_x0 = outs\n",
    "        if callback: callback(i)\n",
    "        if img_callback: img_callback(pred_x0, i)\n",
    "\n",
    "        if index % log_every_t == 0 or index == total_steps - 1:\n",
    "            intermediates['x_inter'].append(img)\n",
    "            intermediates['pred_x0'].append(pred_x0)\n",
    "\n",
    "    return img, intermediates\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe8b3992",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode( x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,\n",
    "            use_original_steps=False, input_noise = None,initial_img = None,resnet=None,guidance=None,loss_guidance_scale=0,\n",
    "            mask=0,initial_latent=None):\n",
    "    \"\"\"Denoise `x_latent` over the first `t_start` schedule steps, walked in reverse.\n",
    "\n",
    "    For timesteps above 400 the conditioning is `cond[int(t * 10 / 1000)]`;\n",
    "    otherwise `unconditional_conditioning` is used. The selected conditioning\n",
    "    is passed to `sampler.p_sample_ddim` as both the conditional and the\n",
    "    unconditional input with a guidance scale of 1, so the\n",
    "    `unconditional_guidance_scale` argument is not forwarded.\n",
    "\n",
    "    Args:\n",
    "        x_latent: Start tensor; its first dimension sizes the timestep batch.\n",
    "        cond: Indexable collection of conditionings.\n",
    "        t_start: Number of leading schedule entries to use.\n",
    "        use_original_steps: Use `np.arange(1000)` instead of `sampler.ddim_timesteps`.\n",
    "        input_noise, initial_img, resnet, guidance, loss_guidance_scale, mask,\n",
    "        initial_latent: Forwarded unchanged to `sampler.p_sample_ddim`.\n",
    "\n",
    "    Returns:\n",
    "        The tensor produced by the last sampling step.\n",
    "    \"\"\"\n",
    "    if use_original_steps:\n",
    "        schedule = np.arange(1000)\n",
    "    else:\n",
    "        schedule = sampler.ddim_timesteps\n",
    "    schedule = schedule[:t_start]\n",
    "\n",
    "    num_steps = schedule.shape[0]\n",
    "    print(f\"Running DDIM Sampling with {num_steps} timesteps\")\n",
    "\n",
    "    # Keyword arguments that stay the same for every sampling step.\n",
    "    step_kwargs = dict(\n",
    "        use_original_steps=use_original_steps,\n",
    "        unconditional_guidance_scale=1,\n",
    "        input_noise=input_noise,\n",
    "        initial_img=initial_img,\n",
    "        resnet=resnet,\n",
    "        guidance=guidance,\n",
    "        loss_guidance_scale=loss_guidance_scale,\n",
    "        mask=mask,\n",
    "        initial_latent=initial_latent,\n",
    "    )\n",
    "\n",
    "    batch = x_latent.shape[0]\n",
    "    x_dec = x_latent\n",
    "    progress = tqdm(np.flip(schedule), desc='Decoding image', total=num_steps)\n",
    "    for i, step in enumerate(progress):\n",
    "        ts = torch.full((batch,), step, device=x_latent.device, dtype=torch.long)\n",
    "        norm_t = int(ts[0]*10/1000)\n",
    "        cond2 = cond[norm_t] if int(ts[0]) > 400 else unconditional_conditioning\n",
    "\n",
    "        print(norm_t)\n",
    "        x_dec, _ = sampler.p_sample_ddim(\n",
    "            x_dec, cond2, ts,\n",
    "            index=num_steps - i - 1,\n",
    "            unconditional_conditioning=cond2,\n",
    "            **step_kwargs,\n",
    "        )\n",
    "    return x_dec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ff8a8f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(prompt = '', content_dir = '',ddim_steps = 50,strength = 0.5, model = None, resnet = None,seed=42):\n",
    "    \"\"\"Invert the content image with reverse DDIM, then decode four noised copies of it.\n",
    "\n",
    "    Side effects: creates './out' and './out/samples', saves the inverted\n",
    "    tensor ('<name>_zt.pt') and the conditioning list ('<name>_embedding.pt'),\n",
    "    and writes one PNG per decoded sample into './out'.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text prompt passed to `model.get_learned_conditioning`.\n",
    "        content_dir: Path of the content image file.\n",
    "        ddim_steps: Number of DDIM steps for the sampler schedule.\n",
    "        strength: Value in [0, 1]; `int(strength * ddim_steps)` steps are encoded/decoded.\n",
    "        model: Diffusion model used for encoding, conditioning and decoding.\n",
    "        resnet: Network forwarded to `decode` as `resnet`.\n",
    "        seed: Seed handed to `seed_everything`.\n",
    "\n",
    "    Returns:\n",
    "        The last PIL image that was saved.\n",
    "    \"\"\"\n",
    "    # Fixed run settings.\n",
    "    ddim_eta = 0.0\n",
    "    n_iter = 1\n",
    "    n_samples = 1\n",
    "    scale = 10.0\n",
    "    precision = \"autocast\"\n",
    "    outdir = \"./out\"\n",
    "\n",
    "    seed_everything(seed)\n",
    "\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "    outpath = outdir\n",
    "\n",
    "    batch_size = n_samples\n",
    "    data = [batch_size * [prompt]]\n",
    "\n",
    "    os.makedirs(os.path.join(outpath, \"samples\"), exist_ok=True)\n",
    "    # Running index used in the output file names.\n",
    "    grid_count = len(os.listdir(outpath)) + 10\n",
    "\n",
    "    content_name = content_dir.split('/')[-1].split('.')[0]\n",
    "    content_image = repeat(load_img(content_dir).to(device), '1 ... -> b ...', b=batch_size)\n",
    "    # move to latent space\n",
    "    init_latent = model.get_first_stage_encoding(model.encode_first_stage(content_image))\n",
    "\n",
    "    mask = evaluate_mask(dspth=content_dir,device = device).unsqueeze(0).to(device)\n",
    "\n",
    "    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)\n",
    "\n",
    "    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'\n",
    "    t_enc = int(strength * ddim_steps)\n",
    "    print(f\"target t_enc is {t_enc} steps\")\n",
    "\n",
    "    precision_scope = autocast if precision == \"autocast\" else nullcontext\n",
    "    with torch.no_grad(), precision_scope(\"cuda\"), model.ema_scope():\n",
    "        for _iteration in trange(n_iter, desc=\"Sampling\"):\n",
    "            for prompts in tqdm(data, desc=\"data\"):\n",
    "                # One conditioning per index 0..9, for the empty prompt and for the real prompt.\n",
    "                uc = None\n",
    "                if scale != 1.0:\n",
    "                    uc = [model.get_learned_conditioning(batch_size * [\"\"], content_image, i)\n",
    "                          for i in range(10)]\n",
    "                if isinstance(prompts, tuple):\n",
    "                    prompts = list(prompts)\n",
    "                c = [model.get_learned_conditioning(prompts, content_image, i)\n",
    "                     for i in range(10)]\n",
    "\n",
    "                # Reverse DDIM pass from the image latent; the result is only saved to disk.\n",
    "                x_inversion, _ = sample_reverse(\n",
    "                    ddim_steps, 1, (4, 512, 512), c,\n",
    "                    verbose=False, eta=0., x_T=init_latent,\n",
    "                    unconditional_guidance_scale=scale,\n",
    "                    unconditional_conditioning=uc[0],\n",
    "                )\n",
    "                torch.save(x_inversion, os.path.join(outpath, content_name+'_zt.pt'))\n",
    "                torch.save(c, os.path.join(outpath, content_name+'_embedding.pt'))\n",
    "                del x_inversion\n",
    "\n",
    "                # Four independent stochastic encodings, stacked along the batch axis.\n",
    "                noised = [\n",
    "                    sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
    "                    for _ in range(4)\n",
    "                ]\n",
    "                z_enc = torch.cat(noised, dim=0)\n",
    "\n",
    "                torch.cuda.empty_cache()\n",
    "                samples = decode(\n",
    "                    z_enc, c, t_enc,\n",
    "                    unconditional_guidance_scale=scale,\n",
    "                    unconditional_conditioning=uc[0],\n",
    "                    initial_img=content_image,\n",
    "                    resnet=resnet,\n",
    "                    guidance=True,\n",
    "                    loss_guidance_scale=1,\n",
    "                    mask=mask,\n",
    "                    initial_latent=init_latent,\n",
    "                )\n",
    "                print(z_enc.shape, uc[0].shape, t_enc)\n",
    "\n",
    "                x_samples = model.decode_first_stage(samples)\n",
    "                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
    "\n",
    "                for x_sample in x_samples:\n",
    "                    pixels = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
    "                    output = Image.fromarray(pixels.astype(np.uint8))\n",
    "                    output.save(os.path.join(outpath, content_name+f'-{grid_count:04}.png'))\n",
    "                    grid_count += 1\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08dd2dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the learned embeddings into the diffusion model's embedding manager.\n",
    "model.embedding_manager.load('./logs/39.jpg2023-06-28T16-53-37_celeba_39/checkpoints/embeddings.pt')\n",
    "\n",
    "# Build the IR-SE-50 backbone that is later passed to `main` as `resnet`.\n",
    "facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')\n",
    "# Deserialise onto the CPU first (as load_model_from_config does) so loading does not\n",
    "# depend on the CUDA device the checkpoint was saved from; the module is moved to\n",
    "# `device` right after.\n",
    "facenet.load_state_dict(torch.load(\".cache/torch/hub/checkpoints/model_ir_se50.pth\", map_location=\"cpu\"))\n",
    "facenet.eval()\n",
    "facenet = facenet.to(device)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca318f6c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\n",
    "# Run the pipeline on one image; the returned image is the cell's output.\n",
    "main(\n",
    "    prompt='*',\n",
    "    content_dir='./dataset/39.jpg',\n",
    "    ddim_steps=50,\n",
    "    strength=0.6,\n",
    "    seed=125,\n",
    "    model=model,\n",
    "    resnet=facenet,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb7fbf81",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa770b54",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "4bfdbc5ecf268fe8cbe1003c5e2c130e872c62898120a6963bc33993ee6594f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
