{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03549e6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind __file__ to the current directory.\n",
    "# NOTE(review): presumably for code that expects __file__ to exist (a notebook does not normally set it) - confirm.\n",
    "__file__ = \".\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63b70aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "from rae.utils.evaluation import parse_checkpoints_tree, parse_checkpoint\n",
    "\n",
    "try:\n",
    "    # enum.StrEnum is part of the standard library from Python 3.11; older interpreters use the backport\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "from rae.utils.evaluation import get_dataset\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "import logging\n",
    "from typing import Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.utils import shuffle\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# Seed the random number generators through pytorch_lightning so reruns are repeatable.\n",
    "seed_everything(0)\n",
    "\n",
    "# Only records at ERROR level or above pass the root logger.\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise on the CPU.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7496cd7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints for this experiment are looked up under <PROJECT_ROOT>/experiments/sec:data-manifold/checkpoints.\n",
    "EXPERIMENT_ROOT = PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\"\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "# parse_checkpoints_tree returns a pair; the first item is indexed below by what look like a dataset key and a model key.\n",
    "# NOTE(review): the exact layout of `checkpoints` and `RUNS` is defined in rae.utils.evaluation - confirm there.\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)\n",
    "# Use the first entry stored under [\"fmnist\"][\"ae\"].\n",
    "ckpt = checkpoints[\"fmnist\"][\"ae\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d7a1c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import hydra\n",
    "\n",
    "PL_MODULE = LightningAutoencoder\n",
    "\n",
    "\n",
    "def get_dataset(pl_module, ckpt):\n",
    "    \"\"\"Rebuild the datamodule described by a checkpoint and return its datasets.\n",
    "\n",
    "    NOTE: this definition replaces the `get_dataset` name imported from\n",
    "    `rae.utils.evaluation` earlier in the notebook.\n",
    "\n",
    "    Args:\n",
    "        pl_module: module class handed to `parse_checkpoint` as `module_class`.\n",
    "        ckpt: checkpoint path handed to `parse_checkpoint` as `checkpoint_path`.\n",
    "\n",
    "    Returns:\n",
    "        A `(train_dataset, val_dataset, metadata)` tuple, where `val_dataset`\n",
    "        is the first entry of the datamodule's `val_datasets`.\n",
    "    \"\"\"\n",
    "    # Only the configuration is needed here; the first returned item is discarded.\n",
    "    _loaded, config = parse_checkpoint(\n",
    "        checkpoint_path=ckpt,\n",
    "        module_class=pl_module,\n",
    "        map_location=\"cpu\",\n",
    "    )\n",
    "\n",
    "    # Instantiate the datamodule from the checkpoint configuration and prepare its splits.\n",
    "    data_module = hydra.utils.instantiate(config.nn.data, _recursive_=False)\n",
    "    data_module.setup()\n",
    "\n",
    "    return data_module.train_dataset, data_module.val_datasets[0], data_module.metadata\n",
    "\n",
    "\n",
    "# NOTE: the second value returned by get_dataset is the validation split; it is used as the test set here.\n",
    "train_dataset, test_dataset, metadata = get_dataset(pl_module=PL_MODULE, ckpt=ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963e0199",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training batches are reshuffled on every pass; evaluation batches keep the dataset order.\n",
    "train_dl = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=8)\n",
    "test_dl = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True, num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65eeb4c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Anchor images are read from the datamodule metadata and moved to the compute device.\n",
    "anchors = metadata.anchors_images.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "223b22a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "from rae.modules.blocks import build_dynamic_encoder_decoder\n",
    "from torch.nn import CrossEntropyLoss, MSELoss\n",
    "from torch.optim import Adam\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from torch import nn\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from torch.nn import functional as F\n",
    "import math\n",
    "from rae.modules.rel_ae import VanillaRelAE\n",
    "from rae.modules.attention import RelativeAttention\n",
    "\n",
    "\n",
    "def fit(dataset_dl, lr=1e-3, epochs=1, seed=0, hidden_dims=[3, 6, 12, 24], batch_lim=100):\n",
    "    \"\"\"Train a fresh `VanillaRelAE` to reconstruct images and return it with its last loss.\n",
    "\n",
    "    Relies on the notebook globals `metadata`, `anchors` and `DEVICE`.\n",
    "\n",
    "    Args:\n",
    "        dataset_dl: iterable of batches; each batch is indexed with the key `image`.\n",
    "        lr: learning rate given to Adam.\n",
    "        epochs: number of passes over `dataset_dl`.\n",
    "        seed: value given to `seed_everything` before the model is built.\n",
    "        hidden_dims: forwarded unchanged to `VanillaRelAE`.\n",
    "        batch_lim: maximum number of batches consumed per epoch; `None` consumes the whole loader.\n",
    "\n",
    "    Returns:\n",
    "        `(model, loss)`: the model in eval mode on the CPU and the MSE of the\n",
    "        last processed batch as a Python float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no batch was processed (e.g. `epochs=0`, `batch_lim=0` or an empty loader).\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "    model = VanillaRelAE(\n",
    "        metadata=metadata,\n",
    "        input_size=None,\n",
    "        latent_dim=None,\n",
    "        hidden_dims=hidden_dims,\n",
    "        relative_attention=RelativeAttention(\n",
    "            n_anchors=anchors.shape[0],\n",
    "            n_classes=len(metadata.class_to_idx),\n",
    "            similarity_mode=\"inner\",\n",
    "            values_mode=\"similarities\",\n",
    "            normalization_mode=\"l2\",\n",
    "        ),\n",
    "        remove_encoder_last_activation=False,\n",
    "    )\n",
    "\n",
    "    model = model.to(DEVICE)\n",
    "    opt = Adam(model.parameters(), lr=lr)\n",
    "    loss_fn = MSELoss()\n",
    "\n",
    "    # Stays None when the loops never run, so the failure can be reported explicitly\n",
    "    # instead of surfacing as an unbound local variable.\n",
    "    last_loss = None\n",
    "    for _ in (tqdm_bar := tqdm(range(epochs), leave=False, desc=\"epoch\")):\n",
    "        # islice caps the number of batches per epoch; a stop of None means no cap.\n",
    "        for batch in itertools.islice(dataset_dl, batch_lim):\n",
    "            batch_x = batch[\"image\"].to(DEVICE, non_blocking=True)\n",
    "            pred_y = model.decode(**model.encode(batch_x))[\"reconstruction\"]\n",
    "            loss = loss_fn(pred_y, batch_x)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            opt.zero_grad()\n",
    "            last_loss = loss.detach()\n",
    "        if last_loss is not None:\n",
    "            tqdm_bar.set_description(f\"Loss: {last_loss:2f}\")\n",
    "\n",
    "    if last_loss is None:\n",
    "        raise ValueError(\"fit() did not process any batch; check `epochs`, `batch_lim` and the data loader\")\n",
    "\n",
    "    model = model.eval().cpu()\n",
    "\n",
    "    return model, last_loss.cpu().item()\n",
    "\n",
    "\n",
    "# Baseline run with the default hyper-parameters; its model is used by the reconstruction plot further down.\n",
    "best_model, best_loss = fit(train_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656cc474",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2, resize=None):\n",
    "    \"\"\"Draw a batch of images as a single grid on a matplotlib axis.\n",
    "\n",
    "    Args:\n",
    "        ax: axis to draw on; its frame and ticks are switched off.\n",
    "        images: batch of images passed to `torchvision.utils.make_grid`.\n",
    "        title: optional axis title; nothing is set when it is `None`.\n",
    "        images_per_row: number of images placed on each grid row.\n",
    "        padding: padding between neighbouring images, filled with the value 1.\n",
    "        resize: optional callable applied to `images` before drawing.\n",
    "    \"\"\"\n",
    "    ax.axis(\"off\")\n",
    "    ax.set_aspect(\"equal\")\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    if resize is not None:\n",
    "        images = resize(images)\n",
    "    # Drop the autograd graph and make sure the data lives on the CPU before plotting.\n",
    "    images = images.detach().cpu()\n",
    "    grid = torchvision.utils.make_grid(images, nrow=images_per_row, padding=padding, pad_value=1)\n",
    "    # Move the first axis of the grid last, which is the layout imshow is given here.\n",
    "    ax.imshow(grid.permute(1, 2, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eadc2128",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two stacked panels: original anchors on top, their reconstructions below.\n",
    "fig, [source, pred] = plt.subplots(\n",
    "    2,\n",
    "    1,\n",
    "    dpi=150,\n",
    "    sharey=False,\n",
    "    sharex=False,\n",
    ")\n",
    "# Only the first ten anchors are shown; fit() returns the model on the CPU, so the inputs are moved there too.\n",
    "plot_images(source, anchors.cpu()[:10])\n",
    "plot_images(pred, best_model.decode(**best_model.encode(anchors.cpu()[:10]))[\"reconstruction\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ebc148",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (model, final loss) pair per configuration, in grid order.\n",
    "models2loss = []\n",
    "\n",
    "\n",
    "# Grid search over epochs x learning rate x seed x per-epoch batch limit.\n",
    "# batch_lim=None makes fit() consume the whole loader on every epoch.\n",
    "for epoch, lr, seed, batch_lim in tqdm(itertools.product([1, 2], [1e-5, 1e-3, 1e-1], [1, 2, 3], [10, 100, None])):\n",
    "    # Each run must receive its own configuration, otherwise every model is trained with fit()'s defaults.\n",
    "    models2loss.append(fit(train_dl, lr=lr, epochs=epoch, seed=seed, batch_lim=batch_lim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95f74db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the runs by ascending final loss so that index 0 is the run with the lowest loss.\n",
    "models2loss = sorted(models2loss, key=lambda x: x[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1631c7a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first run of the loss-sorted list becomes the reference; its anchor similarities are the comparison target.\n",
    "best_model, best_loss = models2loss[0]\n",
    "best_similarities = best_model.encode(anchors.cpu())[\"similarities\"]\n",
    "best_similarities.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf85efdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def latents_distance(latents1, latents2):\n",
    "    \"\"\"Return the mean L2 pairwise distance between two sets of latents as a Python float.\n",
    "\n",
    "    Args:\n",
    "        latents1: first tensor of latents.\n",
    "        latents2: second tensor of latents, paired element-wise with `latents1`.\n",
    "    \"\"\"\n",
    "    # One distance per pair, then a single scalar averaged over all pairs.\n",
    "    per_pair = F.pairwise_distance(latents1, latents2, p=2)\n",
    "    return per_pair.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e37ace9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every run, measure how far its anchor similarities are from those of the reference run,\n",
    "# and keep the matching loss at the same index.\n",
    "dists_to_best = []\n",
    "losses = []\n",
    "for model, loss in models2loss:\n",
    "    similarities = model.encode(anchors.cpu())[\"similarities\"]\n",
    "\n",
    "    dists_to_best.append(latents_distance(similarities, best_similarities))\n",
    "    losses.append(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b24d2fa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure settings: 300 dpi plus the tueplots ICML 2022 bundle and full-width figure size.\n",
    "plt.rcParams.update({\"figure.dpi\": 300})\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=1, nrows=1, height_to_width_ratio=0.4))\n",
    "\n",
    "fig, ax = plt.subplots(nrows=1, ncols=1, sharey=True, sharex=True, squeeze=True)\n",
    "\n",
    "\n",
    "# One point per run: distance to the reference run on x, final loss on y.\n",
    "ax.scatter(x=dists_to_best, y=losses, s=5)\n",
    "ax.set_xlabel(\"Distances to best model\")\n",
    "ax.set_ylabel(\"Loss\")\n",
    "# ax.set_title(\"title\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2b7442",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the figure as SVG, convert it to PDF with the rsvg-convert command-line tool,\n",
    "# then delete the intermediate SVG. The two `!` lines run in the system shell.\n",
    "fig.savefig(\"lossVSdist.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o lossVSdist.pdf lossVSdist.svg\n",
    "!rm lossVSdist.svg"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
