{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualisation from latent codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys  \n",
    "sys.path.insert(0, '/Users/andrzej/Personal/Projects/disentanglement-multi-task')\n",
    "from models.ae import AEModel\n",
    "from common.data_loader import get_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the precomputed latent codes from disk and wrap them in a tensor.\n",
    "okragOli = torch.from_numpy(\n",
    "    np.load('../inverse_circle_umap_n_neighbors_50_min_dist_0.1.npy')\n",
    ")\n",
    "\n",
    "# Standardise every row on its own: the statistics are taken along dim=1\n",
    "# and kept as column vectors so they broadcast against the input.\n",
    "means = okragOli.mean(dim=1, keepdim=True)\n",
    "stds = okragOli.std(dim=1, keepdim=True)\n",
    "normalized_data = okragOli.sub(means).div(stds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from architectures import encoders, decoders\n",
    "\n",
    "# The architecture classes are looked up by name in the project modules.\n",
    "encoder_name = \"SimpleConv64\"\n",
    "decoder_name = \"SimpleConv64\"\n",
    "encoder = getattr(encoders, encoder_name)\n",
    "decoder = getattr(decoders, decoder_name)\n",
    "\n",
    "# Autoencoder kept on the CPU. Both halves receive the same (8, 3, 64)\n",
    "# arguments -- presumably latent size, image channels and image size;\n",
    "# TODO confirm against the architecture constructors.\n",
    "model = AEModel(\n",
    "    encoder(8, 3, 64),\n",
    "    decoder(8, 3, 64),\n",
    ").to(torch.device('cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the saved training state, remapping all tensors onto the CPU.\n",
    "checkpoint = torch.load(\n",
    "    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',\n",
    "    map_location=torch.device('cpu'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the weights stored under model_states -> G, then put the model\n",
    "# in evaluation mode.\n",
    "generator_state = checkpoint['model_states']['G']\n",
    "model.load_state_dict(generator_state)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the normalised latents. Gradients are not needed for visualisation,\n",
    "# so the autograd graph is not built.\n",
    "with torch.no_grad():\n",
    "    results = model.decoder(normalized_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.utils\n",
    "\n",
    "\n",
    "def visualize_recon(recon_image, file_name=\"test.png\"):\n",
    "    \"\"\"Tile a batch of images into one grid and write it to disk.\n",
    "\n",
    "    Args:\n",
    "        recon_image: images in a form accepted by\n",
    "            ``torchvision.utils.make_grid``.\n",
    "        file_name: destination of the saved grid; defaults to the previously\n",
    "            hard-coded ``test.png``.\n",
    "    \"\"\"\n",
    "    grid = torchvision.utils.make_grid(recon_image)\n",
    "    torchvision.utils.save_image(grid, file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the decoded latents to disk as a single image grid.\n",
    "visualize_recon(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Latent traversal visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the checkpoint onto the CPU and restore the weights stored under\n",
    "# model_states -> G before switching to evaluation mode.\n",
    "checkpoint = torch.load(\n",
    "    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',\n",
    "    map_location=torch.device('cpu'),\n",
    ")\n",
    "model.load_state_dict(checkpoint['model_states']['G'])\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loader over the train split of the 'shapes3d_multitask' dataset. The third\n",
    "# positional argument is 1 here -- presumably the batch size, so one sample\n",
    "# is drawn per step (TODO confirm against get_dataloader's signature).\n",
    "train_loader = get_dataloader(\n",
    "    'shapes3d_multitask',\n",
    "    '/Users/andrzej/Personal/Projects/data/test_dsets',\n",
    "    1,\n",
    "    123,\n",
    "    64,\n",
    "    split=\"train\",\n",
    "    test_prec=0.125,\n",
    "    val_prec=0.125,\n",
    "    num_workers=1,\n",
    "    pin_memory=True,\n",
    "    n_task_headers=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from common.utils import grid2gif, get_data_for_visualization, prepare_data_for_visualization\n",
    "import os\n",
    "\n",
    "# Module-level settings read by visualize_traverse.\n",
    "# z_dim: number of latent dimensions traversed (one grid row per dimension).\n",
    "z_dim = 8\n",
    "# l_dim, num_labels, traverse_z and traverse_c only enter the computation of\n",
    "# the number of traversal rows per GIF frame.\n",
    "l_dim = 0\n",
    "traverse_z = True\n",
    "traverse_c = False\n",
    "num_labels = 0\n",
    "# image_size and num_channels are used to reshape the decoded frames.\n",
    "image_size = 64\n",
    "num_channels = train_loader.dataset.num_channels()\n",
    "\n",
    "def set_z(z, latent_id, val):\n",
    "    \"\"\"Shift column ``latent_id`` of ``z`` by ``val``, in place.\n",
    "\n",
    "    Despite the name, ``val`` is added to the existing entries rather than\n",
    "    assigned to them.\n",
    "    \"\"\"\n",
    "    z[:, latent_id] += val\n",
    "\n",
    "def encode_deterministic(**kwargs):\n",
    "    \"\"\"Encode ``kwargs['images']`` and standardise each latent vector.\n",
    "\n",
    "    A 3-D input is treated as a single image and gets a leading batch\n",
    "    dimension. Every encoded row is then centred and scaled by its own mean\n",
    "    and standard deviation (taken along dim=1). Any other keyword argument\n",
    "    is accepted but ignored.\n",
    "    \"\"\"\n",
    "    batch = kwargs['images']\n",
    "    if batch.dim() == 3:\n",
    "        batch = batch.unsqueeze(0)\n",
    "\n",
    "    codes = model.encode(batch)\n",
    "    row_mean = codes.mean(dim=1, keepdim=True)\n",
    "    row_std = codes.std(dim=1, keepdim=True)\n",
    "    return (codes - row_mean) / row_std\n",
    "\n",
    "def decode_deterministic(**kwargs):\n",
    "    \"\"\"Decode ``kwargs['latent']`` with the global ``model``.\n",
    "\n",
    "    A 1-D latent is promoted to a batch of one first. Any other keyword\n",
    "    argument is accepted but ignored.\n",
    "    \"\"\"\n",
    "    code = kwargs['latent']\n",
    "    batched = code.unsqueeze(0) if code.dim() == 1 else code\n",
    "    return model.decode(batched)\n",
    "\n",
    "def visualize_traverse(limit: tuple, spacing, data=None, test=False, data_to_visualisation=None):\n",
    "    \"\"\"Render latent traversals as PNG grids and animated GIFs.\n",
    "\n",
    "    Every sample returned by ``prepare_data_for_visualization`` is encoded;\n",
    "    each of the ``z_dim`` latent dimensions is then shifted in turn by every\n",
    "    offset in ``interp_values`` and decoded again.\n",
    "\n",
    "    For each sample key this writes, in the current directory:\n",
    "      * ``traverse_<key>.png`` -- a grid with the input image repeated on the\n",
    "        first row, its reconstruction on the second, and one row per latent\n",
    "        dimension (one column per offset);\n",
    "      * ``traverse_<key>.gif`` -- assembled by ``grid2gif`` from one temporary\n",
    "        frame per offset.\n",
    "\n",
    "    Args:\n",
    "        limit: ``(low, high)`` range of offsets added to a latent dimension.\n",
    "        spacing: step between consecutive offsets.\n",
    "        data: unused; kept so existing calls keep working.\n",
    "        test: unused; kept so existing calls keep working.\n",
    "        data_to_visualisation: forwarded to ``prepare_data_for_visualization``.\n",
    "\n",
    "    Returns:\n",
    "        The image grid built for the last sample key.\n",
    "    \"\"\"\n",
    "    interp_values = torch.arange(limit[0], limit[1]+spacing, spacing)\n",
    "    num_cols = interp_values.size(0)\n",
    "\n",
    "    sample_images_dict, sample_labels_dict = prepare_data_for_visualization(data_to_visualisation)\n",
    "    encodings = dict()\n",
    "    for key in sample_images_dict.keys():\n",
    "        encodings[key] = encode_deterministic(images=sample_images_dict[key], labels=sample_labels_dict[key])\n",
    "\n",
    "    gifs = []\n",
    "    for key in encodings:\n",
    "        latent_orig = encodings[key]\n",
    "        label_orig = sample_labels_dict[key]\n",
    "        orig_image = sample_images_dict[key]\n",
    "        print('latent_orig: {}, label_orig: {}'.format(latent_orig, label_orig))\n",
    "\n",
    "        # Decode the unmodified latent to get the plain reconstruction.\n",
    "        recon = decode_deterministic(latent=latent_orig.detach(), labels=label_orig)\n",
    "\n",
    "        # First row: the input image; second row: its reconstruction.\n",
    "        samples = []\n",
    "        for _ in interp_values:\n",
    "            samples.append(orig_image.unsqueeze(0))\n",
    "        for _ in interp_values:\n",
    "            samples.append(recon)\n",
    "\n",
    "        for zid in range(z_dim):\n",
    "            for val in interp_values:\n",
    "                latent = latent_orig.clone()\n",
    "                # Shift the dimension exactly once so the applied offset\n",
    "                # equals the value taken from interp_values.\n",
    "                set_z(latent, zid, val)\n",
    "                sample = decode_deterministic(latent=latent, labels=label_orig)\n",
    "\n",
    "                samples.append(sample)\n",
    "                gifs.append(sample)\n",
    "\n",
    "        samples = torch.cat(samples, dim=0).cpu()\n",
    "        samples = torchvision.utils.make_grid(samples, nrow=num_cols)\n",
    "\n",
    "        file_name = os.path.join(\".\", '{}_{}.{}'.format(\"traverse\", key, \"png\"))\n",
    "        torchvision.utils.save_image(samples, file_name)\n",
    "\n",
    "    total_rows = (num_labels * l_dim\n",
    "                  + z_dim * int(traverse_z)\n",
    "                  + num_labels * int(traverse_c))\n",
    "    gifs = torch.cat(gifs)\n",
    "    # Regroup the flat list of decoded frames per sample, then swap the row and\n",
    "    # column axes so that gifs[i][j] holds every row of offset j for sample i.\n",
    "    gifs = gifs.view(len(encodings), total_rows, num_cols,\n",
    "                     num_channels, image_size, image_size).transpose(1, 2)\n",
    "    for i, key in enumerate(encodings.keys()):\n",
    "        # One temporary frame per offset; the extension is passed without a\n",
    "        # leading dot because the format string already supplies it.\n",
    "        frame_paths = [\n",
    "            os.path.join('.', '{}_{}_{}.{}'.format('tmp', key, str(j).zfill(2), 'png'))\n",
    "            for j in range(num_cols)\n",
    "        ]\n",
    "        for j, frame_path in enumerate(frame_paths):\n",
    "            torchvision.utils.save_image(tensor=gifs[i][j].cpu(),\n",
    "                                         fp=frame_path,\n",
    "                                         nrow=total_rows, pad_value=1)\n",
    "\n",
    "        file_name = os.path.join('.', '{}_{}.{}'.format('traverse', key, 'gif'))\n",
    "\n",
    "        grid2gif(str(os.path.join('.', '{}_{}*.{}').format('tmp', key, 'png')),\n",
    "                 file_name, delay=10)\n",
    "\n",
    "        # Delete temp image files\n",
    "        for frame_path in frame_paths:\n",
    "            os.remove(frame_path)\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first batch from the loader; it is later passed to\n",
    "# visualize_traverse as data_to_visualisation.\n",
    "data = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint stored under 3dshapes-10-multi into the shared model.\n",
    "checkpoint = torch.load(\n",
    "    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',\n",
    "    map_location=torch.device('cpu'),\n",
    ")\n",
    "model.load_state_dict(checkpoint['model_states']['G'])\n",
    "model.eval()\n",
    "\n",
    "# Request offsets from -1 to 1 in steps of 0.1.\n",
    "min_, max_, spacing_ = -1, 1, 0.1\n",
    "samples = visualize_traverse(\n",
    "    limit=(min_, max_),\n",
    "    spacing=spacing_,\n",
    "    data_to_visualisation=data,\n",
    ")\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "npimg = samples.detach().numpy()\n",
    "print(npimg.shape)\n",
    "\n",
    "plt.figure(figsize=(50, 40))\n",
    "# Move the leading axis to the end before handing the array to imshow.\n",
    "plt.imshow(npimg.transpose(1, 2, 0))\n",
    "\n",
    "# Tick positions are hard-coded pixel coordinates of the saved grid.\n",
    "column_labels = ['-1', '0', '1']\n",
    "row_labels = ['input', 'recon', '1', '2', '3', '4', '5', '6', '7', '8']\n",
    "plt.xticks(np.arange(0, 1389, step=694), column_labels, fontsize=20)\n",
    "plt.yticks(np.arange(0, 662, step=70), row_labels, fontsize=20)\n",
    "\n",
    "plt.savefig(\"travers-multi-10.png\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint stored under 3dshapes-single-5 into the shared model.\n",
    "checkpoint = torch.load(\n",
    "    '/Users/andrzej/Personal/results/3dshapes-single-5/last',\n",
    "    map_location=torch.device('cpu'),\n",
    ")\n",
    "model.load_state_dict(checkpoint['model_states']['G'])\n",
    "model.eval()\n",
    "\n",
    "# Request offsets from -1 to 1 in steps of 0.1.\n",
    "min_, max_, spacing_ = -1, 1, 0.1\n",
    "samples = visualize_traverse(\n",
    "    limit=(min_, max_),\n",
    "    spacing=spacing_,\n",
    "    data_to_visualisation=data,\n",
    ")\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.figure(figsize=(50, 40))\n",
    "\n",
    "npimg = samples.detach().numpy()\n",
    "print(npimg.shape)\n",
    "# Move the leading axis to the end before handing the array to imshow.\n",
    "plt.imshow(npimg.transpose(1, 2, 0))\n",
    "\n",
    "# Tick positions are hard-coded pixel coordinates of the saved grid.\n",
    "column_labels = ['-1', '0', '1']\n",
    "row_labels = ['input', 'recon', '1', '2', '3', '4', '5', '6', '7', '8']\n",
    "plt.xticks(np.arange(0, 1389, step=694), column_labels, fontsize=20)\n",
    "plt.yticks(np.arange(0, 662, step=70), row_labels, fontsize=20)\n",
    "\n",
    "plt.savefig(\"travers-single-5.png\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint stored under 3dshapes-random into the shared model.\n",
    "checkpoint = torch.load(\n",
    "    '/Users/andrzej/Personal/results/3dshapes-random/last',\n",
    "    map_location=torch.device('cpu'),\n",
    ")\n",
    "model.load_state_dict(checkpoint['model_states']['G'])\n",
    "model.eval()\n",
    "\n",
    "# Request offsets from -1 to 1 in steps of 0.1.\n",
    "min_, max_, spacing_ = -1, 1, 0.1\n",
    "samples = visualize_traverse(\n",
    "    limit=(min_, max_),\n",
    "    spacing=spacing_,\n",
    "    data_to_visualisation=data,\n",
    ")\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.figure(figsize=(50, 40))\n",
    "\n",
    "npimg = samples.detach().numpy()\n",
    "print(npimg.shape)\n",
    "# Move the leading axis to the end before handing the array to imshow.\n",
    "plt.imshow(npimg.transpose(1, 2, 0))\n",
    "\n",
    "# Tick positions are hard-coded pixel coordinates of the saved grid.\n",
    "column_labels = ['-1', '0', '1']\n",
    "row_labels = ['input', 'recon', '1', '2', '3', '4', '5', '6', '7', '8']\n",
    "plt.xticks(np.arange(0, 1389, step=694), column_labels, fontsize=20)\n",
    "plt.yticks(np.arange(0, 662, step=70), row_labels, fontsize=20)\n",
    "\n",
    "plt.savefig(\"travers-random.png\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot the traversal grid with axis labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Re-plot the most recent traversal grid held in `samples`.\n",
    "npimg = samples.detach().numpy()\n",
    "print(npimg.shape)\n",
    "# Move the leading axis to the end before handing the array to imshow.\n",
    "plt.imshow(npimg.transpose(1, 2, 0))\n",
    "\n",
    "# Tick positions are hard-coded pixel coordinates of the saved grid.\n",
    "column_labels = ['-1', '0', '1']\n",
    "row_labels = ['input', 'recon', '1', '2', '3', '4', '5', '6', '7', '8']\n",
    "plt.xticks(np.arange(0, 1389, step=694), column_labels, fontsize=20)\n",
    "plt.yticks(np.arange(0, 662, step=70), row_labels, fontsize=20)\n",
    "\n",
    "plt.savefig(\"test.png\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.utils\n",
    "\n",
    "\n",
    "def visualize_recon(input_image, recon_image, test=False):\n",
    "    \"\"\"Save inputs and reconstructions side by side in one image.\n",
    "\n",
    "    Both batches are tiled with ``torchvision.utils.make_grid``, joined\n",
    "    horizontally with a 10-pixel strip of ones between them, and written to\n",
    "    ``reconstruction_random.png``.\n",
    "\n",
    "    Args:\n",
    "        input_image: original images accepted by ``make_grid``.\n",
    "        recon_image: reconstructed images accepted by ``make_grid``; the\n",
    "            resulting grid must have the same height as the input grid.\n",
    "        test: unused; kept so existing calls keep working.\n",
    "    \"\"\"\n",
    "    input_grid = torchvision.utils.make_grid(input_image)\n",
    "    recon_grid = torchvision.utils.make_grid(recon_image)\n",
    "\n",
    "    # Build the separator from the grid itself so its channel count, dtype and\n",
    "    # device always match, instead of assuming 3-channel float32 on the CPU.\n",
    "    white_line = input_grid.new_ones((input_grid.size(0), input_grid.size(1), 10))\n",
    "\n",
    "    samples = torch.cat([input_grid, white_line, recon_grid], dim=2)\n",
    "\n",
    "    torchvision.utils.save_image(samples, 'reconstruction_random.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.utils\n",
    "\n",
    "\n",
    "def visualize_recon(recon_image, file_name=\"input.png\"):\n",
    "    \"\"\"Tile a batch of images into one grid and write it to disk.\n",
    "\n",
    "    Args:\n",
    "        recon_image: images in a form accepted by\n",
    "            ``torchvision.utils.make_grid``.\n",
    "        file_name: destination of the saved grid; defaults to the previously\n",
    "            hard-coded ``input.png``.\n",
    "    \"\"\"\n",
    "    grid = torchvision.utils.make_grid(recon_image)\n",
    "    torchvision.utils.save_image(grid, file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same loader arguments as for the traversal, except that the third\n",
    "# positional argument is 64 instead of 1 (presumably the batch size --\n",
    "# TODO confirm against get_dataloader's signature).\n",
    "train_loader = get_dataloader('shapes3d_multitask', '/Users/andrzej/Personal/Projects/data/test_dsets', 64,\n",
    "                              123, 64, split=\"train\", test_prec=0.125, val_prec=0.125,\n",
    "                              num_workers=1, pin_memory=True,n_task_headers=1)\n",
    "\n",
    "# Advance the iterator with the built-in next(): a .next() method is not part\n",
    "# of the Python 3 iterator protocol and is not available on every loader\n",
    "# iterator. Only element 0 of the first batch is kept.\n",
    "k = iter(train_loader)\n",
    "batch = next(k)[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the first input batch to disk as a single image grid.\n",
    "visualize_recon(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the checkpoint onto the CPU and restore the weights stored under\n",
    "# model_states -> G before switching to evaluation mode.\n",
    "checkpoint = torch.load(\n",
    "    '/Users/andrzej/Personal/results/3dshapes-10-multi/last',\n",
    "    map_location=torch.device('cpu'),\n",
    ")\n",
    "model.load_state_dict(checkpoint['model_states']['G'])\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index order that swaps positions 1 and 2; used to index dim 1 of `batch`.\n",
    "permute=[0,2,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview `batch` with dim 1 moved to the end (displayed only, not stored).\n",
    "batch.permute(0,2,3,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview `batch` with dim 1 reordered by `permute` (displayed only, not stored).\n",
    "batch[:, permute]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
