{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a6d7d46",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dea7fa6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from rae.utils.evaluation import parse_checkpoints_tree\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "\n",
    "try:\n",
    "    # be ready for 3.10 when it drops\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "\n",
    "import hydra\n",
    "\n",
    "from rae.data.vision.datamodule import MyDataModule\n",
    "\n",
    "\n",
    "# Only surface ERROR-level log records from the root logger\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "\n",
    "# Paths resolve relative to the notebook's current working directory\n",
    "EXPERIMENT_ROOT = Path(\".\").parent\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "PREDICTIONS_TSV = EXPERIMENT_ROOT / \"predictions.tsv\"\n",
    "PERFORMANCE_TSV = EXPERIMENT_ROOT / \"performance.tsv\"\n",
    "\n",
    "# NOTE(review): presumably dataset name -> (dataset class path, split); not used further in this notebook\n",
    "DATASET_SANITY = dict(\n",
    "    mnist=(\"rae.data.vision.mnist.MNISTDataset\", \"test\"),\n",
    "    fmnist=(\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    cifar10=(\"rae.data.vision.cifar10.CIFAR10Dataset\", \"test\"),\n",
    "    cifar100=(\"rae.data.vision.cifar100.CIFAR100Dataset\", \"test\"),\n",
    ")\n",
    "# NOTE(review): presumably model kind -> model class path; not used further in this notebook\n",
    "MODEL_SANITY = dict(\n",
    "    abs=\"rae.modules.vision.resnet.ResNet\",\n",
    "    rel=\"rae.modules.vision.relresnet.RelResNet\",\n",
    ")\n",
    "\n",
    "# Indexed below as checkpoints[dataset][model_name]; RUNS is not used further here\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "153b60be",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a73cd9d5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import parse_checkpoint\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "\n",
    "def get_latents(\n",
    "    images_batch,\n",
    "    ckpt,\n",
    "    pca=None,\n",
    "    *,\n",
    "    class_labels=None,\n",
    "    target_labels=None,\n",
    "    sample_indexes=None,\n",
    "    pl_module=None,\n",
    "):\n",
    "    \"\"\"Encode `images_batch` with the model in `ckpt` and project its latents to 2D.\n",
    "\n",
    "    Args:\n",
    "        images_batch: batch of images fed to the model in one forward pass.\n",
    "        ckpt: checkpoint restored with `parse_checkpoint` (on CPU).\n",
    "        pca: PCA to reuse for the projection. When None and the latents are not\n",
    "            already 2D, a new 2-component PCA is fitted on these latents.\n",
    "        class_labels: per-sample `class` column; defaults to the global `classes`.\n",
    "        target_labels: per-sample `target` column; defaults to the global `targets`.\n",
    "        sample_indexes: per-sample `index` column; defaults to the global `indexes`.\n",
    "        pl_module: Lightning module class; defaults to the global `PL_MODULE`.\n",
    "\n",
    "    Returns:\n",
    "        (df, pca): a DataFrame with columns x, y, class, target, index and the\n",
    "        PCA used for the projection (returned as passed in when the latents\n",
    "        are already 2D).\n",
    "    \"\"\"\n",
    "    # Fall back to notebook globals (resolved at call time) so existing calls keep working\n",
    "    if pl_module is None:\n",
    "        pl_module = PL_MODULE\n",
    "    if class_labels is None:\n",
    "        class_labels = classes\n",
    "    if target_labels is None:\n",
    "        target_labels = targets\n",
    "    if sample_indexes is None:\n",
    "        sample_indexes = indexes\n",
    "\n",
    "    model, _ = parse_checkpoint(\n",
    "        module_class=pl_module,\n",
    "        checkpoint_path=ckpt,\n",
    "        map_location=\"cpu\",\n",
    "    )\n",
    "    # Inference only: put any dropout/batch-norm layers in eval mode and skip autograd\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        latents = model(images_batch)[Output.DEFAULT_LATENT].detach()\n",
    "\n",
    "    if latents.shape[-1] == 2:\n",
    "        # Already 2D: plot the raw latents without any projection\n",
    "        latents2d = latents\n",
    "    else:\n",
    "        if pca is None:\n",
    "            pca = PCA(n_components=2)\n",
    "            pca.fit(latents)\n",
    "\n",
    "        latents2d = pca.transform(latents)\n",
    "\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"x\": latents2d[:, 0].tolist(),\n",
    "            \"y\": latents2d[:, 1].tolist(),\n",
    "            \"class\": class_labels,\n",
    "            \"target\": target_labels,\n",
    "            \"index\": sample_indexes,\n",
    "        }\n",
    "    )\n",
    "    return df, pca"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d7943eb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ad999176",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Latent Rotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a3862bb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# MNIST \"small_ae\" runs used for the latent-rotation figures below\n",
    "MODELS = checkpoints[\"mnist\"][\"small_ae\"]\n",
    "MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e7b9500",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Lightning module class passed to parse_checkpoint / get_dataset to restore the runs\n",
    "PL_MODULE = LightningAutoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a9c155f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import get_dataset\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# Fixed seed so the same K validation samples are drawn on every run\n",
    "seed_everything(0)\n",
    "\n",
    "val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])\n",
    "K = 2_000\n",
    "idxs = torch.randperm(len(val_dataset))[:K]\n",
    "\n",
    "# Read each sampled item once (in order), then split its fields into parallel lists\n",
    "samples = [val_dataset[idx] for idx in idxs]\n",
    "indexes = [sample[\"index\"].item() for sample in samples]\n",
    "images = [sample[\"image\"] for sample in samples]\n",
    "targets = [sample[\"target\"] for sample in samples]\n",
    "classes = [sample[\"class\"] for sample in samples]\n",
    "\n",
    "images_batch = torch.stack(images, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119d3971",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# One DataFrame of 2D latents per checkpoint (a fresh PCA per checkpoint when projection is needed)\n",
    "all_latents_df = [get_latents(images_batch, ckpt)[0] for ckpt in MODELS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1cd71a3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Keep every run for the overview grid\n",
    "TO_CONSIDER = range(len(all_latents_df))\n",
    "latents_df = [all_latents_df[run] for run in TO_CONSIDER]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea5f75f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import plot_latent_space\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "N_ROWS = 2\n",
    "# Round up so an odd number of runs is not silently dropped from the grid\n",
    "N_COLS = (len(latents_df) + N_ROWS - 1) // N_ROWS\n",
    "\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "# plt.get_cmap replaces plt.cm.get_cmap, which is deprecated in recent matplotlib\n",
    "cmap = plt.get_cmap(\"Set1\", 10)\n",
    "norm = plt.Normalize(latents_df[0][\"target\"].min(), latents_df[0][\"target\"].max())\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D even when there is a single column\n",
    "fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)\n",
    "\n",
    "# axes.flat walks the grid row by row, so panel i shows latents_df[i]\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    if i >= len(latents_df):\n",
    "        # More panels than runs (odd count): hide the spare axis\n",
    "        ax.axis(\"off\")\n",
    "        continue\n",
    "    ax.set_aspect(\"equal\")\n",
    "    plot_latent_space(ax, latents_df[i], targets=[0, 2], size=0.5, cmap=cmap, norm=norm, bg_alpha=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cecd58bb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Hand-picked runs (indices into all_latents_df) for the exported latent_rotation figure\n",
    "TO_CONSIDER = [4, 6, 5, 8]\n",
    "chosen_latents_df = [all_latents_df[run] for run in TO_CONSIDER]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faee89b7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import plot_latent_space\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "N_ROWS = 1\n",
    "N_COLS = len(chosen_latents_df)\n",
    "\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "# plt.get_cmap replaces plt.cm.get_cmap, which is deprecated in recent matplotlib\n",
    "cmap = plt.get_cmap(\"Set1\", 10)\n",
    "# Normalize on the runs actually plotted in this figure\n",
    "norm = plt.Normalize(chosen_latents_df[0][\"target\"].min(), chosen_latents_df[0][\"target\"].max())\n",
    "\n",
    "# squeeze=False keeps `axes` iterable even when only one run is chosen\n",
    "fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)\n",
    "\n",
    "for ax, chosen_df in zip(axes.flat, chosen_latents_df):\n",
    "    ax.set_aspect(\"equal\")\n",
    "    plot_latent_space(ax, chosen_df, targets=[0, 2], size=0.75, bg_alpha=0.15, cmap=cmap, norm=norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcf5a8a9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Export as SVG; the following cell converts it to PDF\n",
    "fig.savefig(fname=\"latent_rotation.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2575b80",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Remove the SVG only if the PDF conversion succeeded\n",
    "!rsvg-convert -f pdf -o latent_rotation.pdf latent_rotation.svg && rm latent_rotation.svg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "241eed6e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
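    "# Latent Rotations: Single PCA Proof\n",
    "\n",
    "The same samples are encoded by every MNIST `ae` run and projected to 2D in two ways: with a PCA fitted independently on each run (first row) and with a single PCA fitted once and reused for all runs (second row)."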
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec695bb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Second section: the MNIST \"ae\" runs (the first section used \"small_ae\")\n",
    "PL_MODULE = LightningAutoencoder\n",
    "MODELS = checkpoints[\"mnist\"][\"ae\"]\n",
    "MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b5ff84",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Import here so this cell does not rely on an import made in another cell\n",
    "from rae.utils.evaluation import get_dataset\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# Fixed seed so the same K validation samples are drawn on every run\n",
    "seed_everything(0)\n",
    "\n",
    "val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])\n",
    "\n",
    "K = 2_000\n",
    "idxs = torch.randperm(len(val_dataset))[:K]\n",
    "\n",
    "# Read each sampled item once (in order), then split its fields into parallel lists\n",
    "samples = [val_dataset[idx] for idx in idxs]\n",
    "indexes = [sample[\"index\"].item() for sample in samples]\n",
    "images = [sample[\"image\"] for sample in samples]\n",
    "targets = [sample[\"target\"] for sample in samples]\n",
    "classes = [sample[\"class\"] for sample in samples]\n",
    "\n",
    "images_batch = torch.stack(images, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6df00135",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Shared PCA: fitted on the first run that needs a projection, then reused for every later run\n",
    "pca = None\n",
    "latents_single_pca = []\n",
    "for ckpt in MODELS:\n",
    "    run_df, pca = get_latents(images_batch, ckpt, pca=pca)\n",
    "    latents_single_pca.append(run_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b79ed58d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Independent PCA: each run gets its own freshly fitted PCA (when its latents are not already 2D)\n",
    "latents_independent_pca = []\n",
    "for ckpt in MODELS:\n",
    "    run_df, _ = get_latents(images_batch, ckpt, pca=None)\n",
    "    latents_independent_pca.append(run_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c524e82",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Show at most the first five runs, bounded by this section's own lists (not by all_latents_df)\n",
    "TO_CONSIDER = list(range(min(5, len(latents_single_pca), len(latents_independent_pca))))\n",
    "latents_single_pca = [latents_single_pca[i] for i in TO_CONSIDER]\n",
    "latents_independent_pca = [latents_independent_pca[i] for i in TO_CONSIDER]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac5c8029",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from tueplots import figsizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "494cc99d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "from rae.utils.evaluation import plot_latent_space\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "N_ROWS = 1\n",
    "N_COLS = len(latents_single_pca)\n",
    "\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "\n",
    "# plt.get_cmap replaces plt.cm.get_cmap, which is deprecated in recent matplotlib\n",
    "cmap = plt.get_cmap(\"Set1\", 10)\n",
    "norm = plt.Normalize(latents_single_pca[0][\"target\"].min(), latents_single_pca[0][\"target\"].max())\n",
    "\n",
    "# Row 1: each run projected with its own, independently fitted PCA.\n",
    "# squeeze=False keeps `axes` 2-D, so iterating axes.flat also works for a single run.\n",
    "fig, axes = plt.subplots(\n",
    "    ncols=N_COLS,\n",
    "    nrows=N_ROWS,\n",
    "    sharey=True,\n",
    "    sharex=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "\n",
    "for j, (ax, df) in enumerate(zip(axes.flat, latents_independent_pca)):\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.set_title(f\"Train {j}\")\n",
    "    plot_latent_space(ax, df, targets=[0, 1], size=0.5, bg_alpha=0.1, alpha=0.7, cmap=cmap, norm=norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a049d7b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Export row 1 as SVG; converted to PDF further below\n",
    "fig.savefig(fname=\"pca-proof-row1.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c427ef3b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Row 2: the same runs projected with the shared (single) PCA.\n",
    "# Reuses N_ROWS, N_COLS, cmap and norm from the row-1 cell.\n",
    "fig, axes = plt.subplots(\n",
    "    ncols=N_COLS,\n",
    "    nrows=N_ROWS,\n",
    "    sharey=False,\n",
    "    sharex=False,\n",
    "    squeeze=False,\n",
    ")\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D, so iterating axes.flat also works for a single run\n",
    "for ax, df in zip(axes.flat, latents_single_pca):\n",
    "    ax.set_aspect(\"equal\")\n",
    "    plot_latent_space(ax, df, targets=[0, 1], size=0.5, bg_alpha=0.1, alpha=0.7, cmap=cmap, norm=norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5473c78",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Export row 2 as SVG; converted to PDF in the next cell\n",
    "fig.savefig(fname=\"pca-proof-row2.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e94a0ea9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Remove each SVG only if its PDF conversion succeeded\n",
    "!rsvg-convert -f pdf -o pca-proof-row1.pdf pca-proof-row1.svg && rm pca-proof-row1.svg\n",
    "!rsvg-convert -f pdf -o pca-proof-row2.pdf pca-proof-row2.svg && rm pca-proof-row2.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7272da5d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
