{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dee1ca4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3278c18",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from enum import auto\n",
    "from pathlib import Path\n",
    "from typing import Callable, Dict, Optional, Tuple, Type, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import rich\n",
    "import torch\n",
    "import typer\n",
    "from torchmetrics import (\n",
    "    ErrorRelativeGlobalDimensionlessSynthesis,\n",
    "    MeanSquaredError,\n",
    "    MetricCollection,\n",
    "    PeakSignalNoiseRatio,\n",
    "    StructuralSimilarityIndexMeasure,\n",
    ")\n",
    "\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "from rae.utils.evaluation import parse_checkpoint_id, parse_checkpoints_tree, parse_checkpoint\n",
    "from collections import defaultdict\n",
    "\n",
    "try:\n",
    "    # be ready for 3.10 when it drops\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "\n",
    "from rae.utils.evaluation import plot_latent_space\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "\n",
    "EXPERIMENT_ROOT = PROJECT_ROOT / \"experiments\" / \"fig:latent-rotation-comparison\"\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "PREDICTIONS_TSV = EXPERIMENT_ROOT / \"predictions.tsv\"\n",
    "PERFORMANCE_TSV = EXPERIMENT_ROOT / \"performance.tsv\"\n",
    "\n",
    "DATASET_SANITY = {\n",
    "    \"mnist\": (\"rae.data.vision.mnist.MNISTDataset\", \"test\"),\n",
    "    \"fmnist\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    \"cifar10\": (\"rae.data.vision.cifar10.CIFAR10Dataset\", \"test\"),\n",
    "    \"cifar100\": (\"rae.data.vision.cifar100.CIFAR100Dataset\", \"test\"),\n",
    "}\n",
    "MODEL_SANITY = {\n",
    "    \"vae\": \"rae.modules.vae.VanillaVAE\",\n",
    "    \"ae\": \"rae.modules.ae.VanillaAE\",\n",
    "    \"rel_vae\": \"rae.modules.rel_vae.VanillaRelVAE\",\n",
    "    \"rel_ae\": \"rae.modules.rel_ae.VanillaRelAE\",\n",
    "}\n",
    "\n",
    "\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99d998ac",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148843f7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import parse_checkpoint\n",
    "\n",
    "\n",
    "def get_latents(images_batch, model, key=Output.DEFAULT_LATENT, return_df: bool = True):\n",
    "    latents = model(images_batch)[key].detach().squeeze()\n",
    "    latents2d = latents[:, [0, 1]]\n",
    "    df = None\n",
    "    if return_df:\n",
    "        df = pd.DataFrame(\n",
    "            {\n",
    "                \"x\": latents2d[:, 0].tolist(),\n",
    "                \"y\": latents2d[:, 1].tolist(),\n",
    "                \"class\": classes,\n",
    "                \"target\": targets,\n",
    "                \"index\": indexes,\n",
    "            }\n",
    "        )\n",
    "    return df, latents"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed0388f8",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Latent Rotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61c6b02f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "MODELS = checkpoints[\"mnist\"][\"ae\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21f3b22",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "PL_MODULE = LightningAutoencoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86aac3b7",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "decad270",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.utils.evaluation import get_dataset\n",
    "\n",
    "images = []\n",
    "targets = []\n",
    "indexes = []\n",
    "classes = []\n",
    "\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "seed_everything(0)\n",
    "\n",
    "val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])\n",
    "K = 2_000\n",
    "idxs = torch.randperm(len(val_dataset))[:K]\n",
    "\n",
    "for idx in idxs:\n",
    "    sample = val_dataset[idx]\n",
    "    indexes.append(sample[\"index\"].item())\n",
    "    images.append(sample[\"image\"])\n",
    "    targets.append(sample[\"target\"])\n",
    "    classes.append(sample[\"class\"])\n",
    "\n",
    "images_batch = torch.stack(images, dim=0)\n",
    "images_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935508b7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "LIM = 2\n",
    "N_ROWS = 1\n",
    "N_COLS = LIM\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "\n",
    "cmap = plt.cm.get_cmap(\"Set1\", 10)\n",
    "norm = plt.Normalize(min(targets), max(targets))\n",
    "\n",
    "\n",
    "def plot_row(df, title, equal=True, sharey=False, sharex=False, dpi=150):\n",
    "    fig, axes = plt.subplots(dpi=dpi, nrows=N_ROWS, ncols=N_COLS, sharey=sharey, sharex=sharex, squeeze=True)\n",
    "\n",
    "    for j, ax in enumerate(axes):\n",
    "        if j == 0:\n",
    "            ax.set_title(title)\n",
    "        if equal:\n",
    "            ax.set_aspect(\"equal\")\n",
    "        plot_latent_space(ax, df[j], targets=[0, 1], size=0.75, bg_alpha=0.25, alpha=1, cmap=cmap, norm=norm)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f79a0882",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def latents_distance(latents):\n",
    "    dists = []\n",
    "    for i in range(len(latents)):\n",
    "        for j in range(i + 1, len(latents)):\n",
    "            x = latents[i][1]\n",
    "            y = latents[j][1]\n",
    "            # dist = ((x - y)**2).sum(dim=-1).sqrt().mean()\n",
    "            dist = F.mse_loss(x, y, reduction=\"sum\")\n",
    "            # dist = ((x - y) ** 2).mean(dim=-1).mean()\n",
    "            dists.append(dist)\n",
    "    return sum(dists) / len(dists)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29a42ea9",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "521d35bc",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model_rel, _ = parse_checkpoint(\n",
    "    module_class=PL_MODULE,\n",
    "    checkpoint_path=checkpoints[\"mnist\"][\"rel_ae\"][0],\n",
    "    map_location=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a33ee29b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "anchors_batch = model_rel.metadata.anchors_images\n",
    "anchors_batch.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13da9076",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd4313a4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "MODELS = checkpoints[\"mnist\"][\"ae\"]\n",
    "MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c47ff8cd",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "ae_latents = []\n",
    "anchors_latents = []\n",
    "for ckpt in MODELS:\n",
    "    model, _ = parse_checkpoint(\n",
    "        module_class=PL_MODULE,\n",
    "        checkpoint_path=ckpt,\n",
    "        map_location=\"cpu\",\n",
    "    )\n",
    "    df, latents = get_latents(images_batch, model, return_df=True)\n",
    "    _, a_latents = get_latents(anchors_batch, model, return_df=False)\n",
    "    ae_latents.append((df, latents))\n",
    "    anchors_latents.append(a_latents)\n",
    "\n",
    "import copy\n",
    "\n",
    "original_ae_latents = copy.deepcopy(ae_latents)\n",
    "original_anchor_latents = copy.deepcopy(anchors_latents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a75aa6db",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "f = plot_row([df for df, _ in original_ae_latents[:LIM]], \"AE\", True, True, True)\n",
    "latents_distance(ae_latents[:LIM])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94a04de3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Random isometry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "190c63ab",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "de83c0c4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Rel Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cec70ae",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e466ea65",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Rel Attention Quantized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf7c14e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# Absolut\n",
    "from scipy.stats import ortho_group\n",
    "\n",
    "\n",
    "import torch.nn.functional as F\n",
    "\n",
    "raw_latents = original_ae_latents[0][1]\n",
    "raw_anchor_latents = original_anchor_latents[0]\n",
    "\n",
    "\n",
    "anchors_latents = [raw_anchor_latents]\n",
    "ae_latents = [\n",
    "    (\n",
    "        pd.DataFrame(\n",
    "            {\n",
    "                \"x\": raw_latents[:, 0].tolist(),\n",
    "                \"y\": raw_latents[:, 1].tolist(),\n",
    "                \"class\": classes,\n",
    "                \"target\": targets,\n",
    "                \"index\": indexes,\n",
    "            }\n",
    "        ),\n",
    "        raw_latents,\n",
    "    )\n",
    "]\n",
    "\n",
    "for i in range(4):\n",
    "\n",
    "    # random_isometry = torch.as_tensor(ortho_group.rvs(raw_latents.shape[-1]), dtype=torch.float)\n",
    "    #\n",
    "    # random_isometry = random_isometry + torch.randn_like(random_isometry) * 0.01\n",
    "    # # random_isometry[0, :] += torch.randn_like(random_isometry[0])* 0.1\n",
    "\n",
    "    transformed_latents = transform(raw_latents)\n",
    "    anchors_transformed = transform(raw_anchor_latents)\n",
    "\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"x\": transformed_latents[:, 0].tolist(),\n",
    "            \"y\": transformed_latents[:, 1].tolist(),\n",
    "            \"class\": classes,\n",
    "            \"target\": targets,\n",
    "            \"index\": indexes,\n",
    "        }\n",
    "    )\n",
    "    ae_latents.append((df, transformed_latents))\n",
    "    anchors_latents.append(anchors_transformed)\n",
    "\n",
    "f = plot_row([df for df, _ in ae_latents[:LIM]], f\"AE: {latents_distance(ae_latents[:LIM])}\", True, True, True)\n",
    "\n",
    "# Relative\n",
    "from rae.modules.attention import *\n",
    "\n",
    "# Qunatized\n",
    "for quant_mode, bin_size in (\n",
    "    (None, None),\n",
    "    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.1),\n",
    "    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.3),\n",
    "):\n",
    "    rel_latents = []\n",
    "    rel_attention = RelativeAttention(\n",
    "        n_anchors=anchors_batch.shape,\n",
    "        n_classes=len(set(targets)),\n",
    "        similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "        values_mode=ValuesMethod.SIMILARITIES,\n",
    "        normalization_mode=NormalizationMode.L2,\n",
    "        output_normalization_mode=OutputNormalization.NONE,\n",
    "        similarities_quantization_mode=quant_mode,\n",
    "        similarities_bin_size=bin_size,\n",
    "        # absolute_quantization_mode=quant_mode,\n",
    "        # absolute_bin_size=bin_size\n",
    "    )\n",
    "    assert sum(x.numel() for x in rel_attention.parameters()) == 0\n",
    "    for (_, latents), a_latents in zip(ae_latents, anchors_latents):\n",
    "        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]\n",
    "        rellatents2d = rel[:, [0, 1]]\n",
    "        pca = PCA(n_components=2)\n",
    "        rellatents2d = pca.fit(rel.detach())\n",
    "        rellatents2d = pca.transform(rel.detach())\n",
    "        df = pd.DataFrame(\n",
    "            {\n",
    "                \"x\": rellatents2d[:, 0].tolist(),\n",
    "                \"y\": rellatents2d[:, 1].tolist(),\n",
    "                \"class\": classes,\n",
    "                \"target\": targets,\n",
    "                \"index\": indexes,\n",
    "            }\n",
    "        )\n",
    "        rel_latents.append((df, rel))\n",
    "    f = plot_row(\n",
    "        [df for df, _ in rel_latents[:LIM]],\n",
    "        f\"QAtt, bin size: {bin_size}: {latents_distance(rel_latents[:LIM])}\",\n",
    "        True,\n",
    "        True,\n",
    "        True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2132e49f",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Optimal transofrm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca2553b4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.modules.attention import *\n",
    "\n",
    "ae, _ = parse_checkpoint(\n",
    "    module_class=PL_MODULE,\n",
    "    checkpoint_path=checkpoints[\"mnist\"][\"ae\"][0],\n",
    "    map_location=\"cpu\",\n",
    ")\n",
    "\n",
    "att = RelativeAttention(\n",
    "    n_anchors=anchors_batch.shape,\n",
    "    n_classes=len(set(targets)),\n",
    "    similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "    values_mode=ValuesMethod.SIMILARITIES,\n",
    "    normalization_mode=NormalizationMode.L2,\n",
    "    output_normalization_mode=OutputNormalization.NONE,\n",
    "    similarities_quantization_mode=None,\n",
    "    similarities_bin_size=None,\n",
    "    # absolute_quantization_mode=AbsoluteQuantizationMode.DIFFERENTIABLE_ROUND,\n",
    "    # absolute_bin_size=bin_size\n",
    ")\n",
    "att_q = RelativeAttention(\n",
    "    n_anchors=anchors_batch.shape,\n",
    "    n_classes=len(set(targets)),\n",
    "    similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "    values_mode=ValuesMethod.SIMILARITIES,\n",
    "    normalization_mode=NormalizationMode.L2,\n",
    "    output_normalization_mode=OutputNormalization.NONE,\n",
    "    similarities_quantization_mode=SimilaritiesQuantizationMode.CUSTOM_ROUND,\n",
    "    similarities_bin_size=0.1,\n",
    "    # absolute_quantization_mode=AbsoluteQuantizationMode.DIFFERENTIABLE_ROUND,\n",
    "    # absolute_bin_size=0.1\n",
    ")\n",
    "\n",
    "ae.eval()\n",
    "images_z = ae(images_batch)[Output.DEFAULT_LATENT].detach()\n",
    "anchors_z = ae(anchors_batch)[Output.DEFAULT_LATENT].detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "684949d6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import torch\n",
    "from torch.optim.adam import Adam\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# Absolut\n",
    "from scipy.stats import ortho_group\n",
    "\n",
    "opt_isometry = torch.tensor(ortho_group.rvs(images_z.shape[-1]), dtype=torch.float, requires_grad=True)\n",
    "opt_shift = torch.zeros(images_z.shape[-1], dtype=torch.float, requires_grad=True)\n",
    "\n",
    "\n",
    "opt = Adam([opt_isometry, opt_shift], lr=1e-4)\n",
    "\n",
    "\n",
    "def transform(x):\n",
    "    return x @ opt_isometry + opt_shift\n",
    "\n",
    "\n",
    "R = 1000\n",
    "Q = 1\n",
    "I = 1000\n",
    "S = 0\n",
    "for i in (bar := tqdm(range(100))):\n",
    "\n",
    "    rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]\n",
    "    rel_iso = att(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]\n",
    "    rel_dist = F.mse_loss(rel, rel_iso, reduction=\"sum\")\n",
    "    rel_loss = -rel_dist * R\n",
    "\n",
    "    qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]\n",
    "    qrel_iso = att_q(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]\n",
    "    qrel_dist = F.mse_loss(qrel, rel_iso, reduction=\"sum\")\n",
    "    qrel_loss = qrel_dist * Q\n",
    "\n",
    "    t_temp = opt_isometry @ opt_isometry.T\n",
    "    iso_loss = ((t_temp - t_temp.diag().diag()) ** 2).sum() * I\n",
    "    # iso_loss = (t_temp ** 2 - torch.eye(t_temp.shape[0])).sum() * I\n",
    "    shift_loss = opt_shift.abs().sum() * S\n",
    "    loss = rel_loss + qrel_loss + iso_loss + shift_loss\n",
    "\n",
    "    bar.set_description(f\"Rel: {rel_loss.item():3f} \\t Qua: {qrel_loss.item():3f} \\t  Iso: {iso_loss.item():3f}\")\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    opt.zero_grad()\n",
    "\n",
    "rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]\n",
    "rel_iso = att(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]\n",
    "print(\"Relative mse:\", F.mse_loss(rel, rel_iso, reduction=\"sum\"))\n",
    "\n",
    "qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]\n",
    "qrel_iso = att_q(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]\n",
    "print(\"Quantized mse:\", F.mse_loss(qrel, qrel_iso, reduction=\"sum\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02bdc2f0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ade66dc",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "ae.eval()\n",
    "images_z = ae(images_batch)[Output.DEFAULT_LATENT].detach()\n",
    "anchors_z = ae(anchors_batch)[Output.DEFAULT_LATENT].detach()\n",
    "\n",
    "\n",
    "rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]\n",
    "rel_iso = att(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]\n",
    "print(\"Relative mse:\", F.mse_loss(rel, rel_iso, reduction=\"sum\"))\n",
    "\n",
    "qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]\n",
    "qrel_iso = att_q(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]\n",
    "print(\"Quantized mse:\", F.mse_loss(qrel, qrel_iso, reduction=\"sum\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f633fbc8",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "qrel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74aba10",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "b = torch.as_tensor(0.5)\n",
    "x = torch.linspace(-1, 1, 200)\n",
    "y = x - torch.sin(2 * torch.pi * x) / (2 * torch.pi)\n",
    "\n",
    "a = 1\n",
    "f = 1 / b\n",
    "s = 0\n",
    "y = x - a * torch.cos(2 * torch.pi * f * x + s) / (2 * torch.pi * f)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, dpi=150)\n",
    "f = ax.plot(\n",
    "    x,\n",
    "    y,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
