{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32cca86",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9075904f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "from rae.utils.evaluation import parse_checkpoints_tree, parse_checkpoint\n",
    "\n",
    "try:\n",
    "    # be ready for 3.10 when it drops\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "from rae.utils.evaluation import get_dataset\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "import logging\n",
    "from typing import Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.utils import shuffle\n",
    "\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "\n",
    "EXPERIMENT_ROOT = PROJECT_ROOT / \"experiments\" / \"sec:model-reusability-ae\"\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9ba6857",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2, resize=None):\n",
    "    if resize is not None:\n",
    "        images = resize(images)\n",
    "    images = images.cpu().detach()\n",
    "    ax.imshow(torchvision.utils.make_grid(images.cpu(), images_per_row, padding=padding, pad_value=1).permute(1, 2, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10f59280",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d59d5194",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "PL_MODULE = LightningAutoencoder\n",
    "\n",
    "num_samples = 20\n",
    "\n",
    "mnist = get_dataset(pl_module=PL_MODULE, ckpt=checkpoints[\"mnist\"][\"ae\"][0])\n",
    "fmnist = get_dataset(pl_module=PL_MODULE, ckpt=checkpoints[\"fmnist\"][\"ae\"][0])\n",
    "cifar10 = get_dataset(pl_module=PL_MODULE, ckpt=checkpoints[\"cifar10\"][\"ae\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d518dbe5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "seed_everything(0)\n",
    "\n",
    "\n",
    "def get_class2idx(dataset, k: int = 10):\n",
    "    shuffled_idxs, shuffled_targets = shuffle(\n",
    "        np.asarray(list(range(len(dataset)))),\n",
    "        np.asarray(dataset.targets),\n",
    "        random_state=0,\n",
    "    )\n",
    "    all_targets = sorted(set(shuffled_targets))\n",
    "    class2idxs = {target: shuffled_idxs[shuffled_targets == target][:k] for target in all_targets}\n",
    "    return class2idxs\n",
    "\n",
    "\n",
    "mnist_class2idx = get_class2idx(mnist)\n",
    "fmnist_class2idx = get_class2idx(fmnist)\n",
    "cifar10_class2idx = get_class2idx(cifar10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a36d8c2",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "\n",
    "# Sample selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21a09d0e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "mnist_class2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c26401d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "fmnist_class2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "626049da",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cifar10_class2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f05bf405",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "[mnist_class2idx[x][0] for x in mnist_class2idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1e8fb01",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import default_collate\n",
    "\n",
    "mnist_idxs = [8225, 7407, 4721, 8940, 2846, 5334, 598]  # [mnist_class2idx[x][6] for x in mnist_class2idx]\n",
    "fmnist_idxs = [8711, 19, 6058, 2702, 382, 3891, 122]  # [fmnist_class2idx[x][1] for x in mnist_class2idx]\n",
    "cifar10_idxs = [11, 60, 6, 84, 98, 8940, 2606]  # + [cifar10_class2idx[x][7] for x in mnist_class2idx]\n",
    "\n",
    "batch_mnist = default_collate([mnist[i] for i in mnist_idxs])\n",
    "batch_fmnist = default_collate([fmnist[i] for i in fmnist_idxs])\n",
    "batch_cifar10 = default_collate([cifar10[i] for i in cifar10_idxs])\n",
    "\n",
    "\n",
    "fig, [ax1, ax2, ax3] = plt.subplots(\n",
    "    3,\n",
    "    1,\n",
    "    dpi=150,\n",
    ")\n",
    "plot_images(\n",
    "    ax1,\n",
    "    batch_mnist[\"image\"],\n",
    ")\n",
    "plot_images(\n",
    "    ax2,\n",
    "    batch_fmnist[\"image\"],\n",
    ")\n",
    "plot_images(\n",
    "    ax3,\n",
    "    batch_cifar10[\"image\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd75020b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "plot_images(ax1, batch_mnist['image'], )\n",
    "# Visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25b49acb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.pl_modules.pl_stitching_module import StitchingModule\n",
    "from torchvision.transforms import Resize\n",
    "\n",
    "resize = Resize((28, 28))\n",
    "\n",
    "\n",
    "def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2, resize=None):\n",
    "    ax.axis(\"off\")\n",
    "    ax.set_aspect(\"equal\")\n",
    "\n",
    "    if resize is not None:\n",
    "        images = resize(images)\n",
    "    images = images.cpu().detach()\n",
    "    ax.imshow(torchvision.utils.make_grid(images.cpu(), images_per_row, padding=padding, pad_value=1).permute(1, 2, 0))\n",
    "\n",
    "\n",
    "def plot_stitching(ax, ckpt_a, ckpt_b, images, padding=2, resize=resize):\n",
    "    model_a, _ = parse_checkpoint(\n",
    "        module_class=PL_MODULE,\n",
    "        checkpoint_path=ckpt_a,\n",
    "        map_location=\"cpu\",\n",
    "    )\n",
    "\n",
    "    model_b, _ = parse_checkpoint(\n",
    "        module_class=PL_MODULE,\n",
    "        checkpoint_path=ckpt_b,\n",
    "        map_location=\"cpu\",\n",
    "    )\n",
    "    recon_a = model_a(images)[Output.RECONSTRUCTION]\n",
    "    model_ab = StitchingModule(model_a, model_b)\n",
    "    recon_ab = model_ab(images)[Output.RECONSTRUCTION]\n",
    "\n",
    "    plot_images(ax, torch.cat([recon_a, recon_ab]), images_per_row=recon_a.shape[0], padding=padding, resize=resize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246b89ec",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "N_ROWS = 1\n",
    "N_COLS = 3\n",
    "RATIO = 0.3\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "\n",
    "\n",
    "fig, [source_mnist_ax, source_fmnist_ax, source_cifar10_ax] = plt.subplots(\n",
    "    N_ROWS,\n",
    "    N_COLS,\n",
    "    dpi=300,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    "    # constrained_layout=True\n",
    ")\n",
    "fig.subplots_adjust(hspace=0.02, wspace=0.01)\n",
    "\n",
    "\n",
    "plot_images(source_mnist_ax, batch_mnist[\"image\"], resize=resize)\n",
    "plot_images(source_fmnist_ax, batch_fmnist[\"image\"], resize=resize)\n",
    "plot_images(source_cifar10_ax, batch_cifar10[\"image\"], resize=resize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c134aa4f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "fig.savefig(\"source.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o source.pdf source.svg\n",
    "!rm source.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "512393b3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "seed_everything(2)\n",
    "A_idx = 0\n",
    "B_idx = 5\n",
    "\n",
    "\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "N_ROWS = 4\n",
    "N_COLS = 3\n",
    "RATIO = 0.3\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "\n",
    "\n",
    "fig, [\n",
    "    [ae_mnist_ax, ae_fmnist_ax, ae_cifar10_ax],\n",
    "    [relae_mnist_ax, relae_fmnist_ax, relae_cifar10_ax],\n",
    "    [vae_mnist_ax, vae_fmnist_ax, vae_cifar10_ax],\n",
    "    [relvae_mnist_ax, relvae_fmnist_ax, relvae_cifar10_ax],\n",
    "] = plt.subplots(\n",
    "    N_ROWS,\n",
    "    N_COLS,\n",
    "    dpi=300,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    "    # constrained_layout=True\n",
    ")\n",
    "fig.subplots_adjust(hspace=0.02, wspace=0.01)\n",
    "\n",
    "# fig.tight_layout()\n",
    "# fig.subplots_adjust(hspace=0, wspace=0)\n",
    "# fig.subplots_adjust(hspace = .001)\n",
    "\n",
    "\n",
    "plot_stitching(\n",
    "    ae_mnist_ax,\n",
    "    checkpoints[\"mnist\"][\"ae\"][A_idx],\n",
    "    checkpoints[\"mnist\"][\"ae\"][B_idx],\n",
    "    batch_mnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    ae_fmnist_ax,\n",
    "    checkpoints[\"fmnist\"][\"ae\"][A_idx],\n",
    "    checkpoints[\"fmnist\"][\"ae\"][B_idx],\n",
    "    batch_fmnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    ae_cifar10_ax,\n",
    "    checkpoints[\"cifar10\"][\"ae\"][A_idx],\n",
    "    checkpoints[\"cifar10\"][\"ae\"][B_idx],\n",
    "    batch_cifar10[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "\n",
    "plot_stitching(\n",
    "    vae_mnist_ax,\n",
    "    checkpoints[\"mnist\"][\"vae_nocal\"][0],\n",
    "    checkpoints[\"mnist\"][\"vae_nocal\"][1],\n",
    "    batch_mnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    vae_fmnist_ax,\n",
    "    checkpoints[\"fmnist\"][\"vae_nocal\"][1],\n",
    "    checkpoints[\"fmnist\"][\"vae_nocal\"][0],\n",
    "    batch_fmnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    vae_cifar10_ax,\n",
    "    checkpoints[\"cifar10\"][\"vae\"][A_idx],\n",
    "    checkpoints[\"cifar10\"][\"vae\"][B_idx],\n",
    "    batch_cifar10[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "\n",
    "plot_stitching(\n",
    "    relae_mnist_ax,\n",
    "    checkpoints[\"mnist\"][\"rel_ae\"][A_idx],\n",
    "    checkpoints[\"mnist\"][\"rel_ae\"][B_idx],\n",
    "    batch_mnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    relae_fmnist_ax,\n",
    "    checkpoints[\"fmnist\"][\"rel_ae\"][A_idx],\n",
    "    checkpoints[\"fmnist\"][\"rel_ae\"][B_idx],\n",
    "    batch_fmnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    relae_cifar10_ax,\n",
    "    checkpoints[\"cifar10\"][\"rel_ae\"][A_idx],\n",
    "    checkpoints[\"cifar10\"][\"rel_ae\"][B_idx],\n",
    "    batch_cifar10[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "\n",
    "plot_stitching(\n",
    "    relvae_mnist_ax,\n",
    "    checkpoints[\"mnist\"][\"rel_vae\"][A_idx],\n",
    "    checkpoints[\"mnist\"][\"rel_vae\"][B_idx],\n",
    "    batch_mnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    relvae_fmnist_ax,\n",
    "    checkpoints[\"fmnist\"][\"rel_vae\"][A_idx],\n",
    "    checkpoints[\"fmnist\"][\"rel_vae\"][B_idx],\n",
    "    batch_fmnist[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")\n",
    "plot_stitching(\n",
    "    relvae_cifar10_ax,\n",
    "    checkpoints[\"cifar10\"][\"rel_vae\"][A_idx],\n",
    "    checkpoints[\"cifar10\"][\"rel_vae\"][B_idx],\n",
    "    batch_cifar10[\"image\"],\n",
    "    padding=2,\n",
    "    resize=resize,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9aeb616",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "fig.savefig(\"stitching.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o stitching.pdf stitching.svg\n",
    "!rm stitching.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58be0e69",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
