{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dee1ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell is executed,\n",
    "# so edits to the project code are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5cff816",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from enum import auto\n",
    "from pathlib import Path\n",
    "from typing import Callable, Dict, Optional, Tuple, Type, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import rich\n",
    "import torch\n",
    "import typer\n",
    "from torchmetrics import (\n",
    "    ErrorRelativeGlobalDimensionlessSynthesis,\n",
    "    MeanSquaredError,\n",
    "    MetricCollection,\n",
    "    PeakSignalNoiseRatio,\n",
    "    StructuralSimilarityIndexMeasure,\n",
    ")\n",
    "\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "from rae.utils.evaluation import parse_checkpoint_id, parse_checkpoints_tree, parse_checkpoint\n",
    "from collections import defaultdict\n",
    "\n",
    "try:\n",
    "    # StrEnum joined the standard library in Python 3.11; older interpreters use the backport\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "\n",
    "from rae.utils.evaluation import plot_latent_space\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "# Silence everything below ERROR on the root logger.\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "# NOTE(review): BATCH_SIZE is not referenced again in this notebook.\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "\n",
    "# All files of this experiment live under <PROJECT_ROOT>/experiments/fig:latent-rotation-comparison.\n",
    "EXPERIMENT_ROOT = PROJECT_ROOT / \"experiments\" / \"fig:latent-rotation-comparison\"\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "# NOTE(review): the two TSV paths below are not referenced again in this notebook.\n",
    "PREDICTIONS_TSV = EXPERIMENT_ROOT / \"predictions.tsv\"\n",
    "PERFORMANCE_TSV = EXPERIMENT_ROOT / \"performance.tsv\"\n",
    "\n",
    "# Short dataset name -> (dotted dataset class path, presumably the split to use).\n",
    "# NOTE(review): not referenced again in this notebook.\n",
    "DATASET_SANITY = {\n",
    "    \"mnist\": (\"rae.data.vision.mnist.MNISTDataset\", \"test\"),\n",
    "    \"fmnist\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    \"cifar10\": (\"rae.data.vision.cifar10.CIFAR10Dataset\", \"test\"),\n",
    "    \"cifar100\": (\"rae.data.vision.cifar100.CIFAR100Dataset\", \"test\"),\n",
    "}\n",
    "# Short model name -> dotted module class path.\n",
    "# NOTE(review): not referenced again in this notebook.\n",
    "MODEL_SANITY = {\n",
    "    \"vae\": \"rae.modules.vae.VanillaVAE\",\n",
    "    \"ae\": \"rae.modules.ae.VanillaAE\",\n",
    "    \"rel_vae\": \"rae.modules.rel_vae.VanillaRelVAE\",\n",
    "    \"rel_ae\": \"rae.modules.rel_ae.VanillaRelAE\",\n",
    "}\n",
    "\n",
    "\n",
    "# `checkpoints` is indexed below as checkpoints[dataset][model] and iterated checkpoint by checkpoint.\n",
    "# NOTE(review): RUNS is not referenced again in this notebook.\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9692327b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import parse_checkpoint\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "\n",
    "def get_latents(images_batch, ckpt, pca=None, key=Output.DEFAULT_LATENT):\n",
    "    \"\"\"Run the model stored in ``ckpt`` on a batch and collect its latents.\n",
    "\n",
    "    Args:\n",
    "        images_batch: batch of images passed to the model's forward call.\n",
    "        ckpt: checkpoint handed to ``parse_checkpoint`` as ``checkpoint_path``.\n",
    "        pca: not used by this function; returned unchanged as the third value.\n",
    "        key: entry of the model output from which the latents are read.\n",
    "\n",
    "    Returns:\n",
    "        ``(df, latents, pca)``. ``df`` has the columns ``x`` and ``y`` (latent\n",
    "        dimensions 0 and 1) plus ``class``, ``target`` and ``index``; ``latents``\n",
    "        is the detached, squeezed tensor read from the model output.\n",
    "\n",
    "    Note:\n",
    "        Reads the notebook globals ``PL_MODULE``, ``classes``, ``targets`` and\n",
    "        ``indexes``; the three lists must describe the samples in ``images_batch``.\n",
    "    \"\"\"\n",
    "    model, _ = parse_checkpoint(\n",
    "        module_class=PL_MODULE,\n",
    "        checkpoint_path=ckpt,\n",
    "        map_location=\"cpu\",\n",
    "    )\n",
    "    # Inference only: do not build an autograd graph for the whole batch.\n",
    "    # NOTE(review): train/eval mode is whatever parse_checkpoint leaves the model in - confirm.\n",
    "    with torch.no_grad():\n",
    "        latents = model(images_batch)[key].detach().squeeze()\n",
    "\n",
    "    # Only the first two latent dimensions are plotted.\n",
    "    latents2d = latents[:, [0, 1]]\n",
    "\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"x\": latents2d[:, 0].tolist(),\n",
    "            \"y\": latents2d[:, 1].tolist(),\n",
    "            \"class\": classes,\n",
    "            \"target\": targets,\n",
    "            \"index\": indexes,\n",
    "        }\n",
    "    )\n",
    "    return df, latents, pca"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cccb05d8",
   "metadata": {},
   "source": [
    "# Latent Rotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7d4fec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the plain AE checkpoints; MODELS[0] is used below to build the dataset.\n",
    "MODELS = checkpoints[\"mnist\"][\"ae\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20fee225",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lightning module class passed to parse_checkpoint / get_dataset when loading checkpoints.\n",
    "PL_MODULE = LightningAutoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50e684b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import get_dataset\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# Seed first so the random subset below is the same on every run.\n",
    "seed_everything(0)\n",
    "\n",
    "val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])\n",
    "K = 2_000\n",
    "# K distinct dataset positions, in random order.\n",
    "idxs = torch.randperm(len(val_dataset))[:K]\n",
    "\n",
    "# Fetch every sample once, in the sampled order, then split it by field.\n",
    "samples = [val_dataset[idx] for idx in idxs]\n",
    "indexes = [sample[\"index\"].item() for sample in samples]\n",
    "images = [sample[\"image\"] for sample in samples]\n",
    "targets = [sample[\"target\"] for sample in samples]\n",
    "classes = [sample[\"class\"] for sample in samples]\n",
    "\n",
    "images_batch = torch.stack(images, dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13da9076",
   "metadata": {},
   "source": [
    "## AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd4313a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints iterated by the next cell; the bare name on the last line displays them.\n",
    "MODELS = checkpoints[\"mnist\"][\"ae\"]\n",
    "MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c47ff8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (DataFrame, latents) pair per AE checkpoint; the third return value of get_latents is dropped.\n",
    "ae_latents = [\n",
    "    get_latents(images_batch, ckpt, None)[:2]\n",
    "    for ckpt in MODELS\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "771f505a",
   "metadata": {},
   "source": [
    "## RelAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e2e4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch MODELS to the RelAE checkpoints; the bare name on the last line displays them.\n",
    "MODELS = checkpoints[\"mnist\"][\"rel_ae\"]\n",
    "MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a9cf1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (DataFrame, latents) pair per RelAE checkpoint, read from the similarities output.\n",
    "rel_ae_latents = [\n",
    "    get_latents(images_batch, ckpt, None, key=Output.SIMILARITIES)[:2]\n",
    "    for ckpt in MODELS\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "127f5604",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the first RelAE checkpoint and display the mean of its reconstruction output.\n",
    "model, _ = parse_checkpoint(\n",
    "    module_class=PL_MODULE,\n",
    "    checkpoint_path=checkpoints[\"mnist\"][\"rel_ae\"][0],\n",
    "    map_location=\"cpu\",\n",
    ")\n",
    "# NOTE(review): the variable is called `sim` but holds Output.RECONSTRUCTION - confirm the intended key.\n",
    "sim = model(images_batch)[Output.RECONSTRUCTION]\n",
    "sim.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1558e38a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same forward pass as the previous cell, run again on the `model` that is already loaded.\n",
    "sim = model(images_batch)[Output.RECONSTRUCTION]\n",
    "sim.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebea7ab6",
   "metadata": {},
   "source": [
    "## Quantized RelAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b025f6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f97bdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys of `checkpoints[\"mnist\"]` for the quantized RelAE variants.\n",
    "quantized_rel_aes = [\"rel_ae_0.1\", \"rel_ae_0.2\", \"rel_ae_0.3\", \"rel_ae_0.5\"]\n",
    "\n",
    "# variant name -> list of (DataFrame, latents), one entry per checkpoint.\n",
    "quantized_rel_latents = defaultdict(list)\n",
    "\n",
    "# The loop variable is a string, so it is named `model_name`: calling it `model`\n",
    "# would overwrite the model object loaded in an earlier cell.\n",
    "for model_name in tqdm(quantized_rel_aes):\n",
    "    ckpts = checkpoints[\"mnist\"][model_name]\n",
    "\n",
    "    for ckpt in tqdm(ckpts, leave=False):\n",
    "        df, latents, _ = get_latents(images_batch, ckpt, None, key=Output.SIMILARITIES)\n",
    "        quantized_rel_latents[model_name].append((df, latents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6118be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the first checkpoint of the \"rel_ae_0.1\" variant for the encode/decode check in the next cells.\n",
    "model, _ = parse_checkpoint(\n",
    "    module_class=PL_MODULE,\n",
    "    checkpoint_path=checkpoints[\"mnist\"][\"rel_ae_0.1\"][0],\n",
    "    map_location=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c983bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction obtained by calling encode and decode explicitly.\n",
    "x = model.decode(**model.encode(images_batch))[Output.RECONSTRUCTION]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b69c161f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction obtained through the model's forward call.\n",
    "y = model(images_batch)[Output.RECONSTRUCTION]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4df266",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True when the two reconstructions agree within torch.allclose's default tolerances.\n",
    "torch.allclose(x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2132e49f",
   "metadata": {},
   "source": [
    "# Visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3705d07",
   "metadata": {},
   "outputs": [],
   "source": [
    "def latents_distance(latents):\n",
    "    \"\"\"Average pairwise mean squared difference between runs.\n",
    "\n",
    "    Args:\n",
    "        latents: sequence of ``(df, tensor)`` pairs; only the tensor at position 1\n",
    "            of each pair is used. The tensors are subtracted element-wise.\n",
    "\n",
    "    Returns:\n",
    "        For every unordered pair of entries, the squared difference of their\n",
    "        tensors averaged over the last dimension and then over what remains;\n",
    "        the result is the mean of these values across all pairs.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if fewer than two entries are given, as there is no pair to compare.\n",
    "    \"\"\"\n",
    "    n_runs = len(latents)\n",
    "    if n_runs < 2:\n",
    "        raise ValueError(f\"latents_distance needs at least two runs to compare, got {n_runs}\")\n",
    "\n",
    "    pair_dists = [\n",
    "        ((latents[i][1] - latents[j][1]) ** 2).mean(dim=-1).mean()\n",
    "        for i in range(n_runs)\n",
    "        for j in range(i + 1, n_runs)\n",
    "    ]\n",
    "    return sum(pair_dists) / len(pair_dists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01884026",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of runs shown per figure: used for N_COLS and for the [:LIM] slices below.\n",
    "LIM = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f001efd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first AE run provides the target range used to normalise colours in every plot.\n",
    "template_df = ae_latents[0][0]\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = LIM\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "\n",
    "# plt.cm.get_cmap is deprecated since matplotlib 3.7 and removed in 3.9;\n",
    "# plt.get_cmap accepts the same (name, lut) arguments.\n",
    "cmap = plt.get_cmap(\"Set1\", 10)\n",
    "norm = plt.Normalize(template_df[\"target\"].min(), template_df[\"target\"].max())\n",
    "\n",
    "\n",
    "def plot_row(df, title, equal=True, sharey=False, sharex=False, dpi=150):\n",
    "    \"\"\"Draw one latent-space panel per entry of ``df`` in an N_ROWS x N_COLS grid.\n",
    "\n",
    "    Args:\n",
    "        df: sequence of DataFrames, indexed by panel position (one per panel).\n",
    "        title: text set as the y-label of the first panel.\n",
    "        equal: when true, every panel gets an equal aspect ratio.\n",
    "        sharey: forwarded to ``plt.subplots``.\n",
    "        sharex: forwarded to ``plt.subplots``.\n",
    "        dpi: forwarded to ``plt.subplots``.\n",
    "\n",
    "    Returns:\n",
    "        The figure that was created.\n",
    "\n",
    "    Note:\n",
    "        Reads the notebook globals ``N_ROWS``, ``N_COLS``, ``cmap`` and ``norm``.\n",
    "    \"\"\"\n",
    "    # squeeze=False always returns a 2-D array of axes; with squeeze=True a 1x1 grid\n",
    "    # yields a bare Axes (not iterable) and a multi-row grid yields rows, not axes.\n",
    "    fig, axes = plt.subplots(dpi=dpi, nrows=N_ROWS, ncols=N_COLS, sharey=sharey, sharex=sharex, squeeze=False)\n",
    "\n",
    "    for j, ax in enumerate(axes.flat):\n",
    "        if j == 0:\n",
    "            ax.set_ylabel(title)\n",
    "        if equal:\n",
    "            ax.set_aspect(\"equal\")\n",
    "        plot_latent_space(ax, df[j], targets=[0, 1], size=0.75, bg_alpha=0.25, alpha=1, cmap=cmap, norm=norm)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "546f1c3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first LIM AE runs side by side, then display their average pairwise distance.\n",
    "f = plot_row(\n",
    "    df=[frame for frame, _ in ae_latents[:LIM]],\n",
    "    title=\"AE\",\n",
    "    equal=True,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    ")\n",
    "latents_distance(ae_latents[:LIM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb40e1fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig(\"ae.svg\", bbox_inches=\"tight\")\n",
    "# Convert to PDF; the intermediate SVG is removed only if the conversion succeeded,\n",
    "# so a missing or failing rsvg-convert no longer deletes the only copy of the figure.\n",
    "!rsvg-convert -f pdf -o ae.pdf ae.svg && rm ae.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a6cac3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first LIM RelAE runs side by side, then display their average pairwise distance.\n",
    "f = plot_row(\n",
    "    df=[frame for frame, _ in rel_ae_latents[:LIM]],\n",
    "    title=\"RelAE\",\n",
    "    equal=True,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    ")\n",
    "latents_distance(rel_ae_latents[:LIM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a1aeb02",
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig(\"rel_ae.svg\", bbox_inches=\"tight\")\n",
    "# Convert to PDF; the intermediate SVG is removed only if the conversion succeeded,\n",
    "# so a missing or failing rsvg-convert no longer deletes the only copy of the figure.\n",
    "!rsvg-convert -f pdf -o rel_ae.pdf rel_ae.svg && rm rel_ae.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a367c549",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first LIM runs of the 0.1 variant, then display their average pairwise distance.\n",
    "f = plot_row(\n",
    "    df=[frame for frame, _ in quantized_rel_latents[\"rel_ae_0.1\"][:LIM]],\n",
    "    title=\"RelAE 0.1\",\n",
    "    equal=True,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    ")\n",
    "latents_distance(quantized_rel_latents[\"rel_ae_0.1\"][:LIM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46c4277d",
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig(\"rel_ae_0.1.svg\", bbox_inches=\"tight\")\n",
    "# Convert to PDF; the intermediate SVG is removed only if the conversion succeeded,\n",
    "# so a missing or failing rsvg-convert no longer deletes the only copy of the figure.\n",
    "!rsvg-convert -f pdf -o 'rel_ae_0.1.pdf' 'rel_ae_0.1.svg' && rm 'rel_ae_0.1.svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "985a1085",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first LIM runs of the 0.2 variant, then display their average pairwise distance.\n",
    "f = plot_row(\n",
    "    df=[frame for frame, _ in quantized_rel_latents[\"rel_ae_0.2\"][:LIM]],\n",
    "    title=\"RelAE 0.2\",\n",
    "    equal=True,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    ")\n",
    "latents_distance(quantized_rel_latents[\"rel_ae_0.2\"][:LIM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ae90527",
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig(\"rel_ae_0.2.svg\", bbox_inches=\"tight\")\n",
    "# Convert to PDF; the intermediate SVG is removed only if the conversion succeeded,\n",
    "# so a missing or failing rsvg-convert no longer deletes the only copy of the figure.\n",
    "!rsvg-convert -f pdf -o 'rel_ae_0.2.pdf' 'rel_ae_0.2.svg' && rm 'rel_ae_0.2.svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0ec225f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first LIM runs of the 0.3 variant, then display their average pairwise distance.\n",
    "f = plot_row(\n",
    "    df=[frame for frame, _ in quantized_rel_latents[\"rel_ae_0.3\"][:LIM]],\n",
    "    title=\"RelAE 0.3\",\n",
    "    equal=True,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    ")\n",
    "latents_distance(quantized_rel_latents[\"rel_ae_0.3\"][:LIM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d91839f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig(\"rel_ae_0.3.svg\", bbox_inches=\"tight\")\n",
    "# Convert to PDF; the intermediate SVG is removed only if the conversion succeeded,\n",
    "# so a missing or failing rsvg-convert no longer deletes the only copy of the figure.\n",
    "!rsvg-convert -f pdf -o 'rel_ae_0.3.pdf' 'rel_ae_0.3.svg' && rm 'rel_ae_0.3.svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e0efe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first LIM runs of the 0.5 variant, then display their average pairwise distance.\n",
    "f = plot_row(\n",
    "    df=[frame for frame, _ in quantized_rel_latents[\"rel_ae_0.5\"][:LIM]],\n",
    "    title=\"RelAE 0.5\",\n",
    "    equal=True,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    ")\n",
    "latents_distance(quantized_rel_latents[\"rel_ae_0.5\"][:LIM])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fdcaee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig(\"rel_ae_0.5.svg\", bbox_inches=\"tight\")\n",
    "# Convert to PDF; the intermediate SVG is removed only if the conversion succeeded,\n",
    "# so a missing or failing rsvg-convert no longer deletes the only copy of the figure.\n",
    "!rsvg-convert -f pdf -o 'rel_ae_0.5.pdf' 'rel_ae_0.5.svg' && rm 'rel_ae_0.5.svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca2553b4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db48899",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
