{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e34d34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload all imported modules before executing code, so edits to project modules (e.g. rae.*) are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "668ef655",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from enum import auto\n",
    "from pathlib import Path\n",
    "from typing import Callable, Dict, Optional, Tuple, Type, Union\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import rich\n",
    "import torch\n",
    "import typer\n",
    "from torchmetrics import (\n",
    "    ErrorRelativeGlobalDimensionlessSynthesis,\n",
    "    MeanSquaredError,\n",
    "    MetricCollection,\n",
    "    PeakSignalNoiseRatio,\n",
    "    StructuralSimilarityIndexMeasure,\n",
    ")\n",
    "\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "from rae.utils.evaluation import parse_checkpoint_id, parse_checkpoints_tree, parse_checkpoint\n",
    "from collections import defaultdict\n",
    "\n",
    "try:\n",
    "    # be ready for 3.10 when it drops\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "\n",
    "from rae.utils.evaluation import plot_latent_space\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "# Only surface ERROR-level (and above) records from the root logger.\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "# NOTE(review): BATCH_SIZE is not referenced by any visible cell of this notebook.\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "\n",
    "# Checkpoints for this figure are read from PROJECT_ROOT/experiments/fig:latent-rotation-comparison/checkpoints.\n",
    "EXPERIMENT_ROOT = PROJECT_ROOT / \"experiments\" / \"fig:latent-rotation-comparison\"\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "\n",
    "# `checkpoints` is indexed below as checkpoints[dataset][model_type][i]; presumably a nested\n",
    "# mapping down to a list of checkpoint paths -- verify in parse_checkpoints_tree.\n",
    "# NOTE(review): RUNS is not used in the visible cells.\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)\n",
    "\n",
    "PL_MODULE = LightningAutoencoder\n",
    "# First MNIST autoencoder checkpoint; it is used both to build the validation dataset and to load the model.\n",
    "MODEL = checkpoints[\"mnist\"][\"ae\"][0]\n",
    "MODEL"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a815d9b",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d10345c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning import seed_everything\n",
    "from rae.utils.evaluation import get_dataset\n",
    "\n",
    "# Seed the global RNGs so the random sample draw below is reproducible.\n",
    "seed_everything(0)\n",
    "\n",
    "val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODEL)\n",
    "K = 512  # number of validation samples to embed and plot\n",
    "idxs = torch.randperm(len(val_dataset))[:K]\n",
    "\n",
    "# Per-sample fields, kept aligned by position (to_df pairs them with the latents).\n",
    "images, targets, indexes, classes = [], [], [], []\n",
    "for idx in idxs:\n",
    "    sample = val_dataset[idx]\n",
    "    images.append(sample[\"image\"])\n",
    "    targets.append(sample[\"target\"])\n",
    "    indexes.append(sample[\"index\"].item())\n",
    "    classes.append(sample[\"class\"])\n",
    "\n",
    "images_batch = torch.stack(images, dim=0)\n",
    "images_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcab275b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained autoencoder on CPU; the second value returned by parse_checkpoint is not needed here.\n",
    "model, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODEL, map_location=\"cpu\")\n",
    "\n",
    "# Anchor images stored in the model metadata; they are encoded alongside the samples below.\n",
    "anchors_batch = model.metadata.anchors_images\n",
    "anchors_batch.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84c42756",
   "metadata": {},
   "source": [
    "## Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce1fcc0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_row(\n",
    "    axes,\n",
    "    dfs,\n",
    "    title=None,\n",
    "    equal=True,\n",
    "):\n",
    "    \"\"\"Draw one latent-space scatter plot per axis along a row of subplots.\n",
    "\n",
    "    Args:\n",
    "        axes: sequence of matplotlib Axes; axis j shows ``dfs[j]``.\n",
    "        dfs: DataFrames built by ``to_df``; must have at least ``len(axes)`` entries.\n",
    "        title: optional title, set on the first axis only.\n",
    "        equal: if True, force an equal aspect ratio on every axis.\n",
    "\n",
    "    Returns:\n",
    "        The Figure that owns ``axes``, or None when ``axes`` is empty.\n",
    "\n",
    "    Note:\n",
    "        ``cmap`` and ``norm`` are read from notebook globals at call time.\n",
    "    \"\"\"\n",
    "    owner = None\n",
    "    for j, ax in enumerate(axes):\n",
    "        if j == 0 and title is not None:\n",
    "            ax.set_title(title)\n",
    "        if equal:\n",
    "            ax.set_aspect(\"equal\")\n",
    "        plot_latent_space(ax, dfs[j], targets=[0, 1], size=0.75, bg_alpha=0.25, alpha=1, cmap=cmap, norm=norm)\n",
    "        # Return the figure these axes belong to, not a global `fig` this function never receives.\n",
    "        owner = ax.figure\n",
    "    return owner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8987afff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def latents_distance(latents, return_mean: bool = True):\n",
    "    dists = []\n",
    "    all_dists = []\n",
    "\n",
    "    for i in range(len(latents)):\n",
    "        for j in range(i + 1, len(latents)):\n",
    "            x = latents[i][1]\n",
    "            y = latents[j][1]\n",
    "            # dist = ((x - y)**2).sum(dim=-1).sqrt().mean()\n",
    "            dist = F.pairwise_distance(x, y, p=2).mean()\n",
    "            # dist = F.mse_loss(x, y, reduction=\"mean\")\n",
    "            # dist = ((x - y) ** 2).mean(dim=-1).mean()\n",
    "            dists.append(f\"pair=({i}, {j}): {dist:.2e}\")\n",
    "            all_dists.append(dist)\n",
    "    if return_mean:\n",
    "        return f\"{sum(all_dists) / len(all_dists):.2e}\"\n",
    "    else:\n",
    "        return \" \".join(dists)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58c1d8bf",
   "metadata": {},
   "source": [
    "## Latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d90f213",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import parse_checkpoint\n",
    "import torch.nn.functional as F\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "\n",
    "def get_latents(images_batch, model, key=Output.DEFAULT_LATENT):\n",
    "    \"\"\"Run ``model`` on a batch and return the output stored under ``key``.\n",
    "\n",
    "    The forward pass runs under ``torch.no_grad()`` so no autograd graph is built;\n",
    "    the result is detached and all singleton dimensions are squeezed.\n",
    "    NOTE(review): ``squeeze()`` also drops the batch dimension when the batch has size 1.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        latents = model(images_batch)[key]\n",
    "    return latents.detach().squeeze()\n",
    "\n",
    "\n",
    "def to_df(latents, fit_pca: bool = True):\n",
    "    \"\"\"Project latents to 2-D and attach the per-sample metadata.\n",
    "\n",
    "    Args:\n",
    "        latents: one row per sample, aligned with the notebook-level ``classes``,\n",
    "            ``targets`` and ``indexes`` lists.\n",
    "        fit_pca: if True, fit a fresh 2-component PCA on ``latents``; otherwise\n",
    "            keep the first two coordinates unchanged.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``x``, ``y``, ``class``, ``target``, ``index``.\n",
    "    \"\"\"\n",
    "    points = PCA(n_components=2).fit_transform(latents.cpu()) if fit_pca else latents[:, [0, 1]]\n",
    "    columns = {\n",
    "        \"x\": points[:, 0].tolist(),\n",
    "        \"y\": points[:, 1].tolist(),\n",
    "        \"class\": classes,\n",
    "        \"target\": targets,\n",
    "        \"index\": indexes,\n",
    "    }\n",
    "    return pd.DataFrame(columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acd2a36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the sampled images and the anchors with the same model, in that order.\n",
    "images_latents, anchors_latents = (get_latents(batch, model) for batch in (images_batch, anchors_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae566bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: both sets are expected to share their last (feature) dimension,\n",
    "# since transform_latents applies the same square matrix to both.\n",
    "images_latents.shape, anchors_latents.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78a32467",
   "metadata": {},
   "source": [
    "## Build figure: absolute vs. relative latent spaces under random isometries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a20a5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.attention import *\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from torch.optim.adam import Adam\n",
    "from sklearn.decomposition import PCA\n",
    "from scipy.stats import ortho_group\n",
    "\n",
    "seed_everything(0)\n",
    "\n",
    "\n",
    "def transform_latents(x, seed=0, noise_std=0.001):\n",
    "    \"\"\"Perturb ``x`` with small Gaussian noise, then apply the transform ``TRANSFORMS[seed]``.\n",
    "\n",
    "    Args:\n",
    "        x: latents whose last dimension matches the size of the stored matrices.\n",
    "        seed: index into the notebook-level ``TRANSFORMS`` list (not an RNG seed).\n",
    "        noise_std: standard deviation of the additive noise applied before the transform.\n",
    "\n",
    "    Returns:\n",
    "        ``(x + noise) @ Q + t``, where ``(Q, t)`` is ``TRANSFORMS[seed]``.\n",
    "    \"\"\"\n",
    "    opt_isometry, shift = TRANSFORMS[seed]\n",
    "\n",
    "    # Noise comes from torch's global RNG, so it depends on the seed set above and on call order.\n",
    "    noise = torch.randn_like(x) * noise_std\n",
    "    return (x + noise) @ opt_isometry + shift\n",
    "\n",
    "\n",
    "LIM = 6  # columns: the original latents plus LIM - 1 transformed copies\n",
    "N_ROWS = 4  # absolute, relative, and two quantized relative rows\n",
    "N_COLS = LIM\n",
    "\n",
    "# One random orthogonal matrix plus a zero translation per transformed copy.\n",
    "TRANSFORMS = [\n",
    "    (\n",
    "        torch.tensor(ortho_group.rvs(images_latents.shape[-1]), dtype=torch.float),\n",
    "        torch.zeros(images_latents.shape[-1], dtype=torch.float),\n",
    "    )\n",
    "    for _ in range(LIM)\n",
    "]\n",
    "\n",
    "# plt.rcParams.update(bundles.icml2022())\n",
    "# plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "\n",
    "# `plt.cm.get_cmap` is deprecated (removed in Matplotlib 3.9); `plt.get_cmap` takes the same arguments.\n",
    "cmap = plt.get_cmap(\"Set1\", 10)\n",
    "norm = plt.Normalize(min(targets), max(targets))\n",
    "\n",
    "\n",
    "SHARE_X = True\n",
    "SHARE_Y = True\n",
    "DPI = 300\n",
    "\n",
    "fig, [\n",
    "    abs_axes,\n",
    "    rel_axes,\n",
    "    quant_axes1,\n",
    "    quant_axes2,\n",
    "] = plt.subplots(dpi=DPI, nrows=N_ROWS, ncols=N_COLS, sharey=SHARE_Y, sharex=SHARE_X, squeeze=True)\n",
    "\n",
    "\n",
    "# Column 0 holds the original latents; columns 1..LIM-1 hold transformed copies.\n",
    "all_images_latents = [\n",
    "    images_latents,\n",
    "    *[transform_latents(images_latents, seed=idx) for idx in range(LIM - 1)],\n",
    "]\n",
    "all_anchors_latents = [\n",
    "    anchors_latents,\n",
    "    *[transform_latents(anchors_latents, seed=idx) for idx in range(LIM - 1)],\n",
    "]\n",
    "\n",
    "\n",
    "# Row 1: absolute latents.\n",
    "plot_row(\n",
    "    abs_axes,\n",
    "    [to_df(x) for x in all_images_latents],\n",
    ")\n",
    "\n",
    "\n",
    "dists_str = [f\"AE; mean L2 dists: {latents_distance(all_images_latents)}\"]\n",
    "\n",
    "# n_anchors is a count: pass the number of anchors, not the whole torch.Size.\n",
    "# NOTE(review): confirm RelativeAttention expects an int here.\n",
    "N_ANCHORS = anchors_batch.shape[0]\n",
    "# int(...) keeps the class count correct whether targets are Python ints or 0-dim\n",
    "# tensors (tensors hash by identity, so a set of them would count every sample).\n",
    "N_CLASSES = len({int(t) for t in targets})\n",
    "\n",
    "# Rows 2-4: relative representations, first unquantized, then with k-means\n",
    "# quantization of the similarities (10 and 100 clusters).\n",
    "for axes, quant_mode, bin_size in (\n",
    "    (rel_axes, None, None),\n",
    "    (quant_axes1, AbsoluteQuantizationMode.KMEANS, 10),\n",
    "    (quant_axes2, AbsoluteQuantizationMode.KMEANS, 100),\n",
    "):\n",
    "    rel_attention = RelativeAttention(\n",
    "        n_anchors=N_ANCHORS,\n",
    "        n_classes=N_CLASSES,\n",
    "        similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "        values_mode=ValuesMethod.SIMILARITIES,\n",
    "        normalization_mode=NormalizationMode.L2,\n",
    "        similarities_quantization_mode=quant_mode,\n",
    "        similarities_num_clusters=bin_size,\n",
    "        output_normalization_mode=None,  # OutputNormalization.L2,\n",
    "        similarities_bin_size=bin_size,\n",
    "    )\n",
    "    # Sanity check: the relative projection has no trainable parameters.\n",
    "    assert sum(x.numel() for x in rel_attention.parameters()) == 0\n",
    "\n",
    "    rel_latents = [\n",
    "        rel_attention(x=img_latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]\n",
    "        for (img_latents, a_latents) in zip(all_images_latents, all_anchors_latents)\n",
    "    ]\n",
    "    plot_row(\n",
    "        axes,\n",
    "        [to_df(x) for x in rel_latents],\n",
    "    )\n",
    "    dists_str.append(f\"Att, bin size: {bin_size}; mean L2 dists: {latents_distance(rel_latents)}\")\n",
    "dists_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f7c4ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "FIG_STEM = \"epsisometry\"\n",
    "svg_path = Path(f\"{FIG_STEM}.svg\")\n",
    "pdf_path = Path(f\"{FIG_STEM}.pdf\")\n",
    "\n",
    "fig.savefig(svg_path, bbox_inches=\"tight\")\n",
    "# check=True raises if rsvg-convert is missing or fails, so the SVG is only\n",
    "# deleted once the PDF has actually been written.\n",
    "subprocess.run([\"rsvg-convert\", \"-f\", \"pdf\", \"-o\", str(pdf_path), str(svg_path)], check=True)\n",
    "svg_path.unlink()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59ba8a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: shape of the sampled-image latents (repeats the earlier shape cell).\n",
    "images_latents.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "373a4287",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: mean difference between the latents and a copy perturbed with 0.01-std Gaussian noise.\n",
    "# Up to rounding this is minus the mean of the noise, so it should be close to 0.\n",
    "(images_latents - (images_latents + torch.randn_like(images_latents) * 0.01)).mean(0).mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c3f3de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: mean per-sample L2 displacement caused by 0.001-std noise (the scale used in transform_latents).\n",
    "torch.nn.functional.pairwise_distance(images_latents, images_latents + torch.randn_like(images_latents) * 0.001).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4149f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# Scratch experiment: snap every latent (samples and anchors) to its nearest k-means centroid.\n",
    "absolute_num_clusters = 10\n",
    "# Stage-named copies: raw NumPy inputs and quantized tensors no longer share the names x / anchors.\n",
    "x_np = images_latents.detach().cpu().numpy()\n",
    "anchors_np = anchors_latents.detach().cpu().numpy()\n",
    "\n",
    "# Fit a single codebook on samples followed by anchors.\n",
    "kmeans = KMeans(n_clusters=absolute_num_clusters, random_state=0)\n",
    "kmeans.fit(np.concatenate([x_np, anchors_np]))\n",
    "\n",
    "x_quantized = torch.as_tensor(kmeans.cluster_centers_[kmeans.predict(x_np)])\n",
    "anchors_quantized = torch.as_tensor(kmeans.cluster_centers_[kmeans.predict(anchors_np)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5185745d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: the centroid assigned to each point k-means was fitted on (samples first, then anchors).\n",
    "kmeans.cluster_centers_[kmeans.labels_]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c6de555",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: codebook shape, (n_clusters, n_features).\n",
    "kmeans.cluster_centers_.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
