{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import importlib\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from types import MethodType\n",
    "from datasets.datasets import get_dataloader\n",
    "from models.heatmaps import gen_heatmaps\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import h5py\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the trained run; 'checkpoint' is replaced below by the latest epoch found on disk.\n",
    "new_para = {'log': 'log/generator_pro2_discriminator_pro2_k10_tau0.01_bedroom_128',\n",
    "            'data_root': '../data/bedroom',\n",
    "            'checkpoint': 'epoch_0.model'}\n",
    "\n",
    "# Collect the epoch numbers of all 'epoch_<N>.model' files. Any other file in the directory is\n",
    "# skipped instead of crashing int().\n",
    "checkpoint_dir = os.path.join(new_para['log'], 'checkpoints')\n",
    "prefix, suffix = 'epoch_', '.model'\n",
    "model_names = sorted(\n",
    "    int(name[len(prefix):-len(suffix)])\n",
    "    for name in os.listdir(checkpoint_dir)\n",
    "    if name.startswith(prefix) and name.endswith(suffix) and name[len(prefix):-len(suffix)].isdigit()\n",
    ")\n",
    "if not model_names:\n",
    "    raise FileNotFoundError('no epoch_<N>.model checkpoint found in ' + checkpoint_dir)\n",
    "checkpoint = model_names[-1]  # latest epoch; the next cell derives stage and alpha from it\n",
    "new_para['checkpoint'] = 'epoch_' + str(checkpoint) + '.model'\n",
    "print(new_para['checkpoint'])\n",
    "\n",
    "# Hyper-parameters saved with the run, overridden by the values above.\n",
    "with open(os.path.join(new_para['log'], 'parameters.json'), 'rt') as f:\n",
    "    para = json.load(f)\n",
    "para.update(new_para)\n",
    "\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
    "# Inference only: disable autograd for the rest of the notebook (';' hides the returned object).\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epochs spent at each resolution stage; presumably mirrors the schedule used in training.\n",
    "n_epochs = [100, 100, 100, 100]\n",
    "\n",
    "# stage_list[e] is the stage index for epoch e. The builtin int is used because the np.int alias\n",
    "# was removed in NumPy 1.24.\n",
    "stage_list = np.concatenate([np.ones(n) * i for i, n in enumerate(n_epochs)]).astype(int)\n",
    "\n",
    "# alpha_list[e] is the alpha passed to the generator for epoch e: it ramps 0 -> 1 over the first\n",
    "# `offset` epochs of every stage and then stays at 1.\n",
    "offset = 50\n",
    "alpha_list = []\n",
    "for n in n_epochs:\n",
    "    alpha_list.append(np.linspace(0, 1, offset))\n",
    "    alpha_list.append(np.linspace(1, 1, n - offset))\n",
    "alpha_list = np.concatenate(alpha_list)\n",
    "\n",
    "if checkpoint < len(stage_list):\n",
    "    stage = stage_list[checkpoint]\n",
    "    alpha = alpha_list[checkpoint]\n",
    "else:\n",
    "    # Past the end of the schedule: last stage with alpha fixed at 1.\n",
    "    stage = stage_list[-1]\n",
    "    alpha = 1\n",
    "\n",
    "batch_sizes = [256, 128, 64, 32]\n",
    "image_sizes = [64, 128, 256, 512]\n",
    "image_size = image_sizes[stage]\n",
    "# Keypoints are presumably normalised to [-1, 1]; k * m + m maps them to pixels in [0, image_size].\n",
    "image_size_multiplier = image_size / 2\n",
    "\n",
    "plt.rc('savefig', dpi=image_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the generator class named in the run's parameters and load the selected checkpoint.\n",
    "model_module = importlib.import_module('models.' + para['model'])\n",
    "generator = model_module.Generator({'z_dim': para['z_dim'], 'n_keypoints': para['n_keypoints'],\n",
    "                                    'n_embedding': para['n_embedding'], 'tau': para['tau']})\n",
    "\n",
    "# map_location returning `storage` unchanged keeps tensors on CPU while loading, so a GPU\n",
    "# checkpoint also loads on a CPU-only machine; the module is moved to `device` afterwards.\n",
    "checkpoint_state = torch.load(os.path.join(para['log'], 'checkpoints', para['checkpoint']),\n",
    "                              map_location=lambda storage, location: storage)\n",
    "generator.load_state_dict(checkpoint_state['generator'])\n",
    "generator.to(device)\n",
    "generator.eval()\n",
    "\n",
    "# Noise inputs for the editing experiments; filled in by the next three cells.\n",
    "test_batch = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Editing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the keypoints locations\n",
    "# Re-run to draw a fresh random input_noise0 for test_batch (shape from generator.noise_shapes[0]).\n",
    "test_batch['input_noise0'] = torch.randn(1, *(generator.noise_shapes[0])).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the background\n",
    "# Re-run to draw a fresh random input_noise1 for test_batch (shape from generator.noise_shapes[1]).\n",
    "test_batch['input_noise1'] = torch.randn(1, *(generator.noise_shapes[1])).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the keypoints details\n",
    "# Re-run to draw a fresh random input_noise2 for test_batch (shape from generator.noise_shapes[2]).\n",
    "test_batch['input_noise2'] = torch.randn(1, *(generator.noise_shapes[2])).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render test_batch and overlay its keypoints. The global `keypoints` (pixel coordinates) set here\n",
    "# is reused by the interpolation helpers below, so run this cell before them.\n",
    "with torch.no_grad():\n",
    "    img = generator(test_batch, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "    keypoints = generator.gen_keypoints(test_batch).detach().cpu().numpy().squeeze() * image_size_multiplier + image_size_multiplier\n",
    "img = np.clip(img, 0, 1)  # imshow clips floats to [0, 1] anyway; clipping here avoids its warning\n",
    "kp_ids = list(range(para['n_keypoints']))\n",
    "\n",
    "fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "axes = fig.subplots(1, 3)\n",
    "axes[0].imshow(img)\n",
    "axes[1].imshow(img)\n",
    "axes[1].scatter(keypoints[:, 1], keypoints[:, 0], c=kp_ids, s=60, marker='+')\n",
    "axes[2].imshow(img)\n",
    "axes[2].scatter(keypoints[:, 1], keypoints[:, 0], c=kp_ids, s=500, marker='+', linewidth=7.5)\n",
    "\n",
    "# Label every keypoint with its index in the middle panel.\n",
    "for kp_idx in kp_ids:\n",
    "    axes[1].annotate(str(kp_idx), (keypoints[kp_idx, 1], keypoints[kp_idx, 0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Keypoint Embedding Interpolation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolating background feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_bg_interpolation(bg2, num=6):\n",
    "    \"\"\"Interpolate the background embedding of test_batch linearly toward the one produced by `bg2`.\n",
    "\n",
    "    input_noise0 and input_noise2 come from test_batch; only input_noise1 is replaced by `bg2` for\n",
    "    the target. Each of the `num` columns shows, top to bottom: the image at that weight, the same\n",
    "    image with the global `keypoints` overlaid, and |original image - image|.\n",
    "\n",
    "    Args:\n",
    "        bg2: replacement for test_batch['input_noise1'].\n",
    "        num: number of interpolation steps (columns), including both endpoints.\n",
    "    \"\"\"\n",
    "    lam = torch.linspace(0, 1, num)\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(3, num)\n",
    "    attr1 = generator.gen_atrributes(test_batch)\n",
    "    attr2 = generator.gen_atrributes({'input_noise0': test_batch['input_noise0'], 'input_noise1': bg2, 'input_noise2': test_batch['input_noise2']})\n",
    "    # A shallow copy is enough: 'bg_emb' is reassigned below, never modified in place.\n",
    "    attr = attr1.copy()\n",
    "    original_img = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "    # One column per weight: the loop must follow `num`, otherwise num < 6 indexes past the axes\n",
    "    # and num > 6 leaves columns empty.\n",
    "    for i in range(num):\n",
    "        attr['bg_emb'] = (1 - lam[i]) * attr1['bg_emb'] + lam[i] * attr2['bg_emb']\n",
    "        img = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        axes[0, i].imshow(img)\n",
    "        axes[0, i].set_title('lambda: {0:.2f}'.format(lam[i]))\n",
    "\n",
    "        axes[1, i].imshow(img)\n",
    "        axes[1, i].scatter(keypoints[:, 1], keypoints[:, 0], c=list(range(para['n_keypoints'])), s=60, marker='+')\n",
    "\n",
    "        axes[2, i].imshow(np.abs(original_img - img))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target background noise for the interpolation below (same shape as input_noise1).\n",
    "bg2 = torch.randn(1, *(generator.noise_shapes[1])).to(device)\n",
    "# Alternative: reuse a background from the sample grid at the end of the notebook\n",
    "# (test_batchs only exists after those cells have been run).\n",
    "# bg2 = test_batchs[26]['input_noise1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns go from test_batch's background (lambda = 0) to bg2's background (lambda = 1).\n",
    "view_bg_interpolation(bg2, num=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolating all keypoint features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_keypoints_interpolation(kp2, num=6):\n",
    "    \"\"\"Blend every keypoint embedding of test_batch toward the embeddings generated from `kp2`.\n",
    "\n",
    "    `kp2` stands in for input_noise2 when building the target; input_noise0/1 come from test_batch.\n",
    "    Row 0: the image for each blend weight (titled with the weight). Row 1: the same image with the\n",
    "    global `keypoints` overlaid.\n",
    "\n",
    "    Args:\n",
    "        kp2: replacement for test_batch['input_noise2'].\n",
    "        num: number of blend weights (columns) from 0 to 1 inclusive.\n",
    "    \"\"\"\n",
    "    weights = torch.linspace(0, 1, num)\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, num)\n",
    "    source = generator.gen_atrributes(test_batch)\n",
    "    target_noise = {'input_noise0': test_batch['input_noise0'], 'input_noise1': test_batch['input_noise1'], 'input_noise2': kp2}\n",
    "    target = generator.gen_atrributes(target_noise)\n",
    "    blended = source.copy()\n",
    "    kp_ids = list(range(para['n_keypoints']))\n",
    "\n",
    "    for col, w in enumerate(weights):\n",
    "        blended['kp_emb'] = (1 - w) * source['kp_emb'] + w * target['kp_emb']\n",
    "        rendered = generator.use_atrributes(blended, stage=stage, alpha=alpha)['img']\n",
    "        img = rendered.detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        top, bottom = axes[0, col], axes[1, col]\n",
    "        top.imshow(img)\n",
    "        top.set_title('lambda: {0:.2f}'.format(w))\n",
    "        bottom.imshow(img)\n",
    "        bottom.scatter(keypoints[:, 1], keypoints[:, 0], c=kp_ids, s=60, marker='+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A second random batch; its input_noise2 is the target for the keypoint-embedding interpolations.\n",
    "test_kp_batch = {'input_noise{}'.format(noise_i): torch.randn(1, *noise_shape).to(device) for noise_i, noise_shape in enumerate(generator.noise_shapes)}\n",
    "# Alternative (needs the sample-grid cells at the end to have run first):\n",
    "# test_kp_batch = test_batchs[18]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview test_kp_batch (the keypoint-embedding target): with and without its keypoints.\n",
    "fig = plt.figure(figsize=(2, 1), dpi=image_size * 2, facecolor='w', edgecolor='k')\n",
    "axes = fig.subplots(1, 2)\n",
    "\n",
    "img2 = generator(test_kp_batch, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "img2 = np.clip(img2, 0, 1)  # imshow clips floats to [0, 1] anyway; this avoids its warning\n",
    "keypoints3 = generator.gen_keypoints(test_kp_batch).detach().cpu().numpy().squeeze() * image_size_multiplier + image_size_multiplier\n",
    "for ax in axes:\n",
    "    ax.imshow(img2)\n",
    "    ax.axis('off')\n",
    "# Trailing ';' suppresses the PathCollection repr (the old last line also echoed axis limits).\n",
    "axes[0].scatter(keypoints3[:, 1], keypoints3[:, 0], c=list(range(para['n_keypoints'])), s=60, marker='+');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the keypoint embeddings change; the markers are test_batch's (global) keypoints.\n",
    "view_keypoints_interpolation(test_kp_batch['input_noise2'], num=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolating one keypoint feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_interpolation_one_keypoint(kp2, keypoint_idx=[0], num=6):\n",
    "    \"\"\"Interpolate only the embeddings of the keypoints in `keypoint_idx` toward those from `kp2`.\n",
    "\n",
    "    input_noise0/1 come from test_batch; `kp2` replaces input_noise2 for the target embedding.\n",
    "    Rows: the image at each weight, the same image with the selected keypoints marked, and\n",
    "    |original image - image|.\n",
    "\n",
    "    Args:\n",
    "        kp2: replacement for test_batch['input_noise2'].\n",
    "        keypoint_idx: keypoint indices whose embedding is interpolated (the default is only read).\n",
    "        num: number of interpolation steps (columns), including both endpoints.\n",
    "    \"\"\"\n",
    "    lam = torch.linspace(0.0, 1.0, num)\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(3, num)\n",
    "    attr1 = generator.gen_atrributes(test_batch)\n",
    "    attr2 = generator.gen_atrributes({'input_noise0': test_batch['input_noise0'], 'input_noise1': test_batch['input_noise1'], 'input_noise2': kp2})\n",
    "    attr = attr1.copy()\n",
    "    original_img = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "    n_keypoints = para['n_keypoints']\n",
    "\n",
    "    for i in range(num):\n",
    "        # attr1.copy() is shallow, so writing into attr['kp_emb'] in place would also overwrite\n",
    "        # attr1['kp_emb'] and every later step would blend from the previous step's result.\n",
    "        # Build each step from an untouched clone of the source embedding instead.\n",
    "        kp_emb = attr1['kp_emb'].clone()\n",
    "        for kp_idx in keypoint_idx:\n",
    "            kp_emb[0, kp_idx] = (1 - lam[i]) * attr1['kp_emb'][0, kp_idx] + lam[i] * attr2['kp_emb'][0, kp_idx]\n",
    "        attr['kp_emb'] = kp_emb\n",
    "        img = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        axes[0, i].imshow(img)\n",
    "        axes[0, i].set_title('lambda: {0:.2f}'.format(lam[i]))\n",
    "\n",
    "        axes[1, i].imshow(img)\n",
    "        for idx in keypoint_idx:\n",
    "            # vmin/vmax pin the colormap to the full keypoint range so each marker has the same\n",
    "            # colour as in the all-keypoint plots (a lone value would always map to one colour).\n",
    "            axes[1, i].scatter(keypoints[idx, 1], keypoints[idx, 0], c=[idx], vmin=0, vmax=n_keypoints - 1, s=240, marker='+')\n",
    "\n",
    "        axes[2, i].imshow(np.abs(original_img - img))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keypoint index lists are specific to the trained model; the commented call is an alternative set.\n",
    "# view_interpolation_one_keypoint(test_kp_batch['input_noise2'], keypoint_idx=[18, 8, 10, 11, 14, 26, 4, 2, 13, 15, 20, 29, 21, 22, 5], num=6)\n",
    "view_interpolation_one_keypoint(test_kp_batch['input_noise2'], keypoint_idx=[5,9], num=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolating keypoint locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_location_interpolation(pos_noise2, num=6):\n",
    "    \"\"\"Move all keypoints of test_batch linearly toward the locations generated from `pos_noise2`.\n",
    "\n",
    "    Only input_noise0 is swapped for the target; input_noise1/2 come from test_batch. Row 0 shows\n",
    "    the image at each weight, row 1 the same image with the interpolated keypoints overlaid.\n",
    "\n",
    "    Args:\n",
    "        pos_noise2: replacement for test_batch['input_noise0'].\n",
    "        num: number of interpolation steps (columns), including both endpoints.\n",
    "    \"\"\"\n",
    "    weights = torch.linspace(0, 1, num)\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, num)\n",
    "    start = generator.gen_atrributes(test_batch)\n",
    "    target = generator.gen_atrributes({'input_noise0': pos_noise2, 'input_noise1': test_batch['input_noise1'], 'input_noise2': test_batch['input_noise2']})\n",
    "    attr = start.copy()\n",
    "    kp_ids = list(range(para['n_keypoints']))\n",
    "\n",
    "    for col, w in enumerate(weights):\n",
    "        attr['keypoints'] = (1 - w) * start['keypoints'] + w * target['keypoints']\n",
    "        rendered = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img']\n",
    "        img = rendered.detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        top, bottom = axes[0, col], axes[1, col]\n",
    "        top.imshow(img)\n",
    "        top.set_title('lambda: {0:.2f}'.format(w))\n",
    "        bottom.imshow(img)\n",
    "        kp_pixels = attr['keypoints'][0].cpu() * image_size_multiplier + image_size_multiplier\n",
    "        bottom.scatter(kp_pixels[:, 1], kp_pixels[:, 0], c=kp_ids, s=60, marker='+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A second random batch; its input_noise0 is the target for the keypoint-location interpolations.\n",
    "test_pos_batch = {'input_noise{}'.format(noise_i): torch.randn(1, *noise_shape).to(device) for noise_i, noise_shape in enumerate(generator.noise_shapes)}\n",
    "# Alternative (needs the sample-grid cells at the end to have run first):\n",
    "# test_pos_batch = test_batchs[4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview test_pos_batch (the keypoint-location target): with and without its keypoints.\n",
    "fig = plt.figure(figsize=(4, 2), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "axes = fig.subplots(1, 2)\n",
    "\n",
    "img2 = generator(test_pos_batch, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "img2 = np.clip(img2, 0, 1)  # imshow clips floats to [0, 1] anyway; this avoids its warning\n",
    "keypoints2 = generator.gen_keypoints(test_pos_batch).detach().cpu().numpy().squeeze() * image_size_multiplier + image_size_multiplier\n",
    "for ax in axes:\n",
    "    ax.imshow(img2)\n",
    "    ax.axis('off')\n",
    "# Trailing ';' suppresses the PathCollection repr (the old last line also echoed axis limits).\n",
    "axes[0].scatter(keypoints2[:, 1], keypoints2[:, 0], c=list(range(para['n_keypoints'])), s=60, marker='+');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns move every keypoint from test_batch's locations (lambda = 0) to test_pos_batch's (lambda = 1).\n",
    "view_location_interpolation(test_pos_batch['input_noise0'], num=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpolating individual keypoint locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_location_interpolation_individual(pos_noise2, kp_indices=[0, 1], num=6):\n",
    "    \"\"\"Move only the keypoints in `kp_indices` linearly toward their locations under `pos_noise2`.\n",
    "\n",
    "    input_noise1/2 come from test_batch; `pos_noise2` replaces input_noise0 for the target.\n",
    "    Row 0 shows the image at each weight, row 1 the same image with all current keypoints.\n",
    "\n",
    "    Args:\n",
    "        pos_noise2: replacement for test_batch['input_noise0'].\n",
    "        kp_indices: indices of the keypoints to move (the default is only read).\n",
    "        num: number of interpolation steps (columns), including both endpoints.\n",
    "    \"\"\"\n",
    "    lam = torch.linspace(0, 1, num)\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, num)\n",
    "    attr1 = generator.gen_atrributes(test_batch)\n",
    "    attr2 = generator.gen_atrributes({'input_noise0': pos_noise2, 'input_noise1': test_batch['input_noise1'], 'input_noise2': test_batch['input_noise2']})\n",
    "    attr = attr1.copy()\n",
    "\n",
    "    for i in range(num):\n",
    "        # attr1.copy() is shallow: assigning into attr['keypoints'] in place would also overwrite\n",
    "        # attr1['keypoints'], so later steps would blend from the previous step instead of the start.\n",
    "        step_keypoints = attr1['keypoints'].clone()\n",
    "        for kp_idx in kp_indices:\n",
    "            step_keypoints[:, kp_idx, :] = (1 - lam[i]) * attr1['keypoints'][:, kp_idx, :] + lam[i] * attr2['keypoints'][:, kp_idx, :]\n",
    "        attr['keypoints'] = step_keypoints\n",
    "        img = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        axes[0, i].imshow(img)\n",
    "        axes[0, i].set_title('lambda: {0:.2f}'.format(lam[i]))\n",
    "\n",
    "        axes[1, i].imshow(img)\n",
    "        axes[1, i].scatter(attr['keypoints'][0, :, 1].cpu() * image_size_multiplier + image_size_multiplier,\n",
    "                           attr['keypoints'][0, :, 0].cpu() * image_size_multiplier + image_size_multiplier,\n",
    "                           c=list(range(para['n_keypoints'])), s=60, marker='+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move only keypoint 8 toward its location in test_pos_batch; all other keypoints stay put.\n",
    "view_location_interpolation_individual(test_pos_batch['input_noise0'], kp_indices=[8], num=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpolating individual keypoints toward custom locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_custom_location_interpolation_individual(movement, num=6, kp_indices=None):\n",
    "    \"\"\"Shift the keypoints of test_batch by `movement` in `num` linear steps and render each step.\n",
    "\n",
    "    Args:\n",
    "        movement: offset added to attr['keypoints'] at lambda = 1, in the same normalised coordinates\n",
    "            as attr['keypoints']. Assumed shape (1, n_keypoints, 2), e.g. built with\n",
    "            torch.zeros(1, para['n_keypoints'], 2).\n",
    "        num: number of interpolation steps (columns), including both endpoints.\n",
    "        kp_indices: keypoints to mark in the bottom row. Defaults to every keypoint with a non-zero\n",
    "            offset in `movement`.\n",
    "    \"\"\"\n",
    "    lam = torch.linspace(0, 1, num)\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, num)\n",
    "    attr1 = generator.gen_atrributes(test_batch)\n",
    "    target_keypoints = attr1['keypoints'] + movement\n",
    "    attr = attr1.copy()\n",
    "    if kp_indices is None:\n",
    "        # Mark the keypoints that actually move rather than a fixed index.\n",
    "        kp_indices = torch.nonzero(movement[0].abs().sum(dim=-1)).flatten().tolist()\n",
    "    kp_indices = list(kp_indices)\n",
    "    n_keypoints = para['n_keypoints']\n",
    "\n",
    "    for i in range(num):\n",
    "        attr['keypoints'] = (1 - lam[i]) * attr1['keypoints'] + lam[i] * target_keypoints\n",
    "        img = generator.use_atrributes(attr, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        axes[0, i].imshow(img)\n",
    "        axes[0, i].set_title('lambda: {0:.2f}'.format(lam[i]))\n",
    "\n",
    "        axes[1, i].imshow(img)\n",
    "        if kp_indices:\n",
    "            kp = attr['keypoints'][0, kp_indices].cpu() * image_size_multiplier + image_size_multiplier\n",
    "            # vmin/vmax keep the colours consistent with the all-keypoint plots.\n",
    "            axes[1, i].scatter(kp[:, 1], kp[:, 0], c=kp_indices, vmin=0, vmax=n_keypoints - 1, s=240, marker='+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-keypoint offset in normalised keypoint coordinates (pixels = value * image_size / 2).\n",
    "# Index 0 of the last axis is plotted as y (rows) and index 1 as x (columns) in the scatter plots.\n",
    "movement = torch.zeros(1, para['n_keypoints'], 2).to(device)\n",
    "movement[:, 6, 0] += 0.5\n",
    "# movement[:, 6, 1] = 0.2\n",
    "view_custom_location_interpolation_individual(movement, num=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Removing Parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_removing_parts(kp_indices, dpi=None):\n",
    "    \"\"\"Remove the parts in `kp_indices` from test_batch one at a time and render each result.\n",
    "\n",
    "    Column i (i = 0 .. len(kp_indices)) passes kp_indices[:i] to\n",
    "    generator.use_atrributes_removing_parts, so column 0 removes nothing; the last column passes\n",
    "    every keypoint index. The bottom row marks the keypoints listed in `kp_indices`.\n",
    "\n",
    "    Args:\n",
    "        kp_indices: keypoint indices to remove, in removal order.\n",
    "        dpi: figure resolution. Defaults to 10 * image_size, which gives a very large figure;\n",
    "            pass a smaller value for a quick preview.\n",
    "    \"\"\"\n",
    "    if dpi is None:\n",
    "        dpi = image_size * 10\n",
    "    kp_indices = list(kp_indices)\n",
    "    n_keypoints = para['n_keypoints']\n",
    "    num = len(kp_indices) + 2\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=dpi, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, num)\n",
    "    attr = generator.gen_atrributes(test_batch)\n",
    "\n",
    "    for i in range(num - 1):\n",
    "        img = generator.use_atrributes_removing_parts(attr, kp_indices[:i], stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "        axes[0, i].imshow(img)\n",
    "\n",
    "        axes[1, i].imshow(img)\n",
    "        if kp_indices:\n",
    "            # Mark the keypoints targeted for removal, coloured like the all-keypoint plots.\n",
    "            kp = attr['keypoints'][0, kp_indices].cpu() * image_size_multiplier + image_size_multiplier\n",
    "            axes[1, i].scatter(kp[:, 1], kp[:, 0], c=kp_indices, vmin=0, vmax=n_keypoints - 1, s=240, marker='+')\n",
    "\n",
    "    img = generator.use_atrributes_removing_parts(attr, list(range(n_keypoints)), stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "    axes[0, num - 1].imshow(img)\n",
    "    axes[1, num - 1].imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keypoint indices are specific to each trained model; pick the set matching the loaded run.\n",
    "# view_removing_parts(kp_indices=[0, 2, 3, 6])  # for face\n",
    "view_removing_parts(kp_indices=[9,5])  # for bedroom"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding Parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_adding_parts(kp_indices, kp_n_pos_deviation):\n",
    "    \"\"\"Compare test_batch with generator.use_atrributes_multi_parts for new part positions.\n",
    "\n",
    "    For each index in `kp_indices`, a new position is that keypoint's location plus\n",
    "    `kp_n_pos_deviation` (presumably where a copy of the part is added). Left column: the unedited\n",
    "    image, twice. Right column: the edited image, with the new positions marked on the bottom row.\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, 2)\n",
    "    attr = generator.gen_atrributes(test_batch)\n",
    "\n",
    "    def to_image(output):\n",
    "        return output['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "    base_img = to_image(generator.use_atrributes(attr, stage=stage, alpha=alpha))\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.imshow(base_img)\n",
    "\n",
    "    kp_n_pos = [attr['keypoints'][:, idx:idx + 1, :] + kp_n_pos_deviation for idx in kp_indices]\n",
    "    edited_img = to_image(generator.use_atrributes_multi_parts(attr, attr, kp_indices, kp_n_pos, stage=stage, alpha=alpha))\n",
    "    axes[0, 1].imshow(edited_img)\n",
    "    axes[1, 1].imshow(edited_img)\n",
    "\n",
    "    for pos in kp_n_pos:\n",
    "        axes[1, 1].scatter(pos[0, :, 1].cpu() * image_size_multiplier + image_size_multiplier,\n",
    "                           pos[0, :, 0].cpu() * image_size_multiplier + image_size_multiplier,\n",
    "                           c=list(range(pos.shape[1])), s=240, marker='+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Offset of each new part from its source keypoint, in normalised keypoint coordinates\n",
    "# (index 0 is plotted as y, index 1 as x); shape (1, 1, 2) so it adds to a single keypoint slice.\n",
    "kp_n_pos = torch.tensor([0., 0.4]).reshape(1, 1, 2).to(device)\n",
    "# view_adding_parts(kp_indices=[8, 6], kp_n_pos_deviation=kp_n_pos)\n",
    "view_adding_parts(kp_indices=[7], kp_n_pos_deviation=kp_n_pos)\n",
    "# view_adding_parts(kp_indices=[5, 9], kp_n_pos_deviation=kp_n_pos)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding another object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_adding_object(dist):\n",
    "    \"\"\"Compare test_batch with generator.use_atrributes_two_object(attr, dist).\n",
    "\n",
    "    Left column: the original image (bottom row with its keypoints). Right column: the edited\n",
    "    image, with the keypoints drawn twice, shifted by -dist and +dist along x (normalised keypoint\n",
    "    coordinates).\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(20, 5), dpi=image_size * 2, facecolor='w', edgecolor='k')\n",
    "    axes = fig.subplots(2, 2)\n",
    "    attr = generator.gen_atrributes(test_batch)\n",
    "    kp_ids = list(range(para['n_keypoints']))\n",
    "\n",
    "    def to_image(output):\n",
    "        return output['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "\n",
    "    def to_pixels(coord):\n",
    "        return coord.cpu() * image_size_multiplier + image_size_multiplier\n",
    "\n",
    "    original = to_image(generator.use_atrributes(attr, stage=stage, alpha=alpha))\n",
    "    axes[0, 0].imshow(original)\n",
    "    axes[1, 0].imshow(original)\n",
    "    axes[1, 0].scatter(to_pixels(attr['keypoints'][0, :, 1]), to_pixels(attr['keypoints'][0, :, 0]),\n",
    "                       c=kp_ids, s=60, marker='+')\n",
    "\n",
    "    edited = to_image(generator.use_atrributes_two_object(attr, dist, stage=stage, alpha=alpha))\n",
    "    axes[0, 1].imshow(edited)\n",
    "    axes[1, 1].imshow(edited)\n",
    "\n",
    "    for shifted_x in (attr['keypoints'][0, :, 1] - dist, attr['keypoints'][0, :, 1] + dist):\n",
    "        axes[1, 1].scatter(to_pixels(shifted_x), to_pixels(attr['keypoints'][0, :, 0]),\n",
    "                           c=kp_ids, s=60, marker='+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled; uncomment to render the two-object edit (markers at x - dist and x + dist).\n",
    "# view_adding_object(dist=0.75)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Keypoint Embedding Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE of sampled keypoint embeddings (optionally plus the background embedding), coloured by slot:\n",
    "# colours 0 .. n_keypoints-1 are keypoints, colour n_keypoints (when included) is the background.\n",
    "# Assumes gen_keypoints_feat returns (batch, n_keypoints, n_embedding) - TODO confirm.\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "include_background = True  # False: keypoint embeddings only\n",
    "n_batches, batch_size = 10, 32\n",
    "tsne_X = []\n",
    "for _ in range(n_batches):\n",
    "    emb = generator.gen_keypoints_feat(torch.randn(batch_size, *(generator.noise_shapes[2])).to(device)).detach().cpu().numpy()\n",
    "    if include_background:\n",
    "        bg_emb = generator.gen_background_embedding(torch.randn(batch_size, *(generator.noise_shapes[1])).to(device)).unsqueeze(1).detach().cpu().numpy()\n",
    "        emb = np.concatenate([emb, bg_emb], axis=1)\n",
    "    tsne_X.append(emb)\n",
    "n_slots = generator.n_keypoints + int(include_background)\n",
    "tsne_X = np.concatenate(tsne_X).reshape(-1, generator.n_embedding)\n",
    "tsne_X_embedded = TSNE(n_components=2, random_state=0).fit_transform(tsne_X)\n",
    "# Rows are sample-major (n_slots rows per sample), so the slot colours repeat once per sample.\n",
    "plt.scatter(tsne_X_embedded[:, 0], tsne_X_embedded[:, 1], c=list(range(n_slots)) * (n_batches * batch_size));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PCA of sampled keypoint embeddings (optionally plus the background embedding), coloured by slot:\n",
    "# colours 0 .. n_keypoints-1 are keypoints, colour n_keypoints (when included) is the background.\n",
    "# Assumes gen_keypoints_feat returns (batch, n_keypoints, n_embedding) - TODO confirm.\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "include_background = False  # True: also embed the background feature\n",
    "n_batches, batch_size = 10, 32\n",
    "pca_X = []\n",
    "for _ in range(n_batches):\n",
    "    emb = generator.gen_keypoints_feat(torch.randn(batch_size, *(generator.noise_shapes[2])).to(device)).detach().cpu().numpy()\n",
    "    if include_background:\n",
    "        bg_emb = generator.gen_background_embedding(torch.randn(batch_size, *(generator.noise_shapes[1])).to(device)).unsqueeze(1).detach().cpu().numpy()\n",
    "        emb = np.concatenate([emb, bg_emb], axis=1)\n",
    "    pca_X.append(emb)\n",
    "n_slots = generator.n_keypoints + int(include_background)\n",
    "pca_X = np.concatenate(pca_X).reshape(-1, generator.n_embedding)\n",
    "pca_X_embedded = PCA(n_components=2).fit_transform(pca_X)\n",
    "# Rows are sample-major (n_slots rows per sample), so the slot colours repeat once per sample.\n",
    "plt.scatter(pca_X_embedded[:, 0], pca_X_embedded[:, 1], c=list(range(n_slots)) * (n_batches * batch_size));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample Lots of Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw grid_size**2 random noise batches and show them in a grid titled by their index in\n",
    "# test_batchs; the next cell picks one of them as test_batch.\n",
    "grid_size = 6\n",
    "test_batchs = []\n",
    "fig = plt.figure(figsize=(grid_size, grid_size), dpi=256, facecolor='w', edgecolor='k')\n",
    "# squeeze=False keeps a 2-D array of axes even for grid_size == 1.\n",
    "axes = fig.subplots(grid_size, grid_size, squeeze=False)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for index, ax in enumerate(axes.flat):  # row-major, same order as the titles\n",
    "        test_batchs.append({'input_noise{}'.format(noise_i): torch.randn(1, *noise_shape).to(device) for noise_i, noise_shape in enumerate(generator.noise_shapes)})\n",
    "        img = generator(test_batchs[-1], stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "        ax.imshow(np.clip(img, 0, 1))\n",
    "        ax.set_title(str(index))\n",
    "        ax.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make one of the samples above the current test_batch and show its keypoints. The editing\n",
    "# helpers read test_batch and `keypoints` at call time, so re-run their call cells afterwards.\n",
    "sample_index = 33\n",
    "test_batch = test_batchs[sample_index]\n",
    "with torch.no_grad():\n",
    "    img = generator(test_batch, stage=stage, alpha=alpha)['img'].detach().cpu().numpy().transpose((0, 2, 3, 1)).squeeze() * 0.5 + 0.5\n",
    "    keypoints = generator.gen_keypoints(test_batch).detach().cpu().numpy().squeeze() * image_size_multiplier + image_size_multiplier\n",
    "img = np.clip(img, 0, 1)\n",
    "kp_ids = list(range(para['n_keypoints']))\n",
    "\n",
    "fig = plt.figure(figsize=(20, 5), dpi=image_size, facecolor='w', edgecolor='k')\n",
    "axes = fig.subplots(1, 3)\n",
    "axes[0].imshow(img)\n",
    "axes[1].imshow(img)\n",
    "axes[1].scatter(keypoints[:, 1], keypoints[:, 0], c=kp_ids, s=60, marker='+')\n",
    "# Alternative third panel: gen_heatmaps of the keypoints, max over dim 1.\n",
    "# axes[2].imshow(gen_heatmaps(generator.gen_keypoints(test_batch), heatmap_size=128).max(dim=1)[0].squeeze().detach().cpu().numpy())\n",
    "axes[2].imshow(img)\n",
    "axes[2].scatter(keypoints[:, 1], keypoints[:, 0], c=kp_ids, s=500, marker='+', linewidth=7.5)\n",
    "\n",
    "# Label every keypoint with its index in the middle panel.\n",
    "for kp_idx in kp_ids:\n",
    "    axes[1].annotate(str(kp_idx), (keypoints[kp_idx, 1], keypoints[kp_idx, 0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
