{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate the CSV: DO NOT EXECUTE this section if you already have the CSV file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from collections import namedtuple\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets.dataset_dict import DatasetDict\n",
    "from torch import cosine_similarity\n",
    "from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef\n",
    "from tqdm import tqdm\n",
    "\n",
    "from latent_invariances.utils.relreps import SIMPLE_PROJECTION_TYPE\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "# Root folder of the stored embeddings; each dataset is loaded from its own sub-folder.\n",
    "EMBEDDINGS_DIR = PROJECT_ROOT / \"data\" / \"model_zoo\" / \"embeddings_all\"\n",
    "\n",
    "SPLIT = \"test\"\n",
    "NUM_ANCHORS = 800\n",
    "# Seed of the anchor sampling: without it every run draws different anchors and the CSV is not reproducible.\n",
    "ANCHORS_SEED = 0\n",
    "# Fall back to the CPU when no GPU is visible instead of failing on the first tensor allocation.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "DATASETS = [\n",
    "    \"mnist\",\n",
    "    \"fashion_mnist\",\n",
    "    \"cifar10\",\n",
    "    \"cifar100\",\n",
    "]\n",
    "\n",
    "MODEL_NAME = [\n",
    "    \"ae\",\n",
    "    \"vae\",\n",
    "    \"linearized_ae\",\n",
    "    \"linearized_vae\",\n",
    "]\n",
    "\n",
    "SEEDS_NAME = [\n",
    "    \"0\",\n",
    "    \"1\",\n",
    "    \"2\",\n",
    "    \"3\",\n",
    "    \"4\",\n",
    "]\n",
    "\n",
    "\n",
    "# One row of the output table: a (dataset, model, projection) setting evaluated on one pair of seeds.\n",
    "Result = namedtuple(\n",
    "    \"Result\",\n",
    "    [\n",
    "        \"dataset\",\n",
    "        \"model\",\n",
    "        \"seed_name_1\",\n",
    "        \"seed_name_2\",\n",
    "        \"num_anchors\",\n",
    "        \"column_name_1\",\n",
    "        \"column_name_2\",\n",
    "        \"projection\",\n",
    "        \"spearman_mean\",\n",
    "        \"pearson_mean\",\n",
    "        \"spearman_std\",\n",
    "        \"pearson_std\",\n",
    "        \"cosine_mean\",\n",
    "        \"mse_mean\",\n",
    "        \"l1_mean\",\n",
    "        \"cosine_std\",\n",
    "        \"mse_std\",\n",
    "        \"l1_std\",\n",
    "        \"norm_cosine_mean\",\n",
    "        \"norm_mse_mean\",\n",
    "        \"norm_l1_mean\",\n",
    "        \"norm_cosine_std\",\n",
    "        \"norm_mse_std\",\n",
    "        \"norm_l1_std\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "# A single seeded generator shared by the whole run keeps the anchor choice reproducible.\n",
    "anchors_rng = torch.Generator().manual_seed(ANCHORS_SEED)\n",
    "\n",
    "performance = []\n",
    "for dataset_name, model_name in tqdm(list(itertools.product(DATASETS, MODEL_NAME))):\n",
    "    dataset_embeds_dir = EMBEDDINGS_DIR / dataset_name\n",
    "    embeds_dataset = DatasetDict.load_from_disk(str(dataset_embeds_dir), keep_in_memory=True)\n",
    "    embeds_dataset.set_format(\"np\")\n",
    "\n",
    "    # Anchors are drawn from the train split; the evaluated points come from SPLIT.\n",
    "    anchors_idxs = torch.randperm(len(embeds_dataset[\"train\"]), generator=anchors_rng)[:NUM_ANCHORS]\n",
    "\n",
    "    # Load anchors and points once per seed so that every seed pair below reuses the same tensors.\n",
    "    column_name2embeds = {}\n",
    "    for seed_name in SEEDS_NAME:\n",
    "        column_name = f\"{model_name}_{seed_name}\"\n",
    "        column_name2embeds[column_name] = {\n",
    "            \"anchors\": torch.as_tensor(embeds_dataset[\"train\"][anchors_idxs][column_name], device=DEVICE).flatten(\n",
    "                start_dim=1\n",
    "            ),\n",
    "            \"embeds\": torch.as_tensor(embeds_dataset[SPLIT][column_name], device=DEVICE).flatten(start_dim=1),\n",
    "        }\n",
    "\n",
    "    for projection_name, projection_fun in SIMPLE_PROJECTION_TYPE.items():\n",
    "        print(f\"{dataset_name}_{model_name}_{projection_name}\")\n",
    "\n",
    "        for seed_name_1, seed_name_2 in tqdm(\n",
    "            list(itertools.combinations(SEEDS_NAME, r=2)),\n",
    "            desc=f\"{dataset_name}_{model_name}_{projection_name}\",\n",
    "            leave=False,\n",
    "        ):\n",
    "            column_name_1 = f\"{model_name}_{seed_name_1}\"\n",
    "            column_name_2 = f\"{model_name}_{seed_name_2}\"\n",
    "\n",
    "            anchor_1 = column_name2embeds[column_name_1][\"anchors\"]\n",
    "            anchor_2 = column_name2embeds[column_name_2][\"anchors\"]\n",
    "\n",
    "            embeds_1 = column_name2embeds[column_name_1][\"embeds\"]\n",
    "            embeds_2 = column_name2embeds[column_name_2][\"embeds\"]\n",
    "\n",
    "            # Project each seed's points with respect to that seed's own anchors.\n",
    "            rel_1 = projection_fun(anchors=anchor_1, points=embeds_1)\n",
    "            rel_2 = projection_fun(anchors=anchor_2, points=embeds_2)\n",
    "\n",
    "            # Discrepancies between the two projected spaces; reduction=\"none\" keeps every entry\n",
    "            # so that both the mean and the std can be taken when the row is built.\n",
    "            cosine = cosine_similarity(rel_1, rel_2).cpu()\n",
    "            mse = torch.nn.functional.mse_loss(rel_1, rel_2, reduction=\"none\").cpu()\n",
    "            l1 = torch.nn.functional.l1_loss(rel_1, rel_2, reduction=\"none\").cpu()\n",
    "\n",
    "            # NOTE(review): inputs are transposed, presumably to obtain one coefficient per point\n",
    "            # - confirm against the multi-output semantics of torchmetrics.\n",
    "            spearman = spearman_corrcoef(rel_1.T, rel_2.T)\n",
    "            pearson = pearson_corrcoef(rel_1.T, rel_2.T)\n",
    "\n",
    "            # Same discrepancies after normalizing the projections along dim=1.\n",
    "            norm_rel_1 = torch.nn.functional.normalize(rel_1, dim=1)\n",
    "            norm_rel_2 = torch.nn.functional.normalize(rel_2, dim=1)\n",
    "            norm_cosine = cosine_similarity(norm_rel_1, norm_rel_2).cpu()\n",
    "            norm_mse = torch.nn.functional.mse_loss(norm_rel_1, norm_rel_2, reduction=\"none\").cpu()\n",
    "            norm_l1 = torch.nn.functional.l1_loss(norm_rel_1, norm_rel_2, reduction=\"none\").cpu()\n",
    "\n",
    "            performance.append(\n",
    "                Result(\n",
    "                    dataset=dataset_name,\n",
    "                    model=model_name,\n",
    "                    seed_name_1=seed_name_1,\n",
    "                    seed_name_2=seed_name_2,\n",
    "                    num_anchors=NUM_ANCHORS,\n",
    "                    column_name_1=column_name_1,\n",
    "                    column_name_2=column_name_2,\n",
    "                    projection=projection_name,\n",
    "                    spearman_mean=spearman.mean().item(),\n",
    "                    pearson_mean=pearson.mean().item(),\n",
    "                    spearman_std=spearman.std().item(),\n",
    "                    pearson_std=pearson.std().item(),\n",
    "                    cosine_mean=cosine.mean().item(),\n",
    "                    mse_mean=mse.mean().item(),\n",
    "                    l1_mean=l1.mean().item(),\n",
    "                    cosine_std=cosine.std().item(),\n",
    "                    mse_std=mse.std().item(),\n",
    "                    l1_std=l1.std().item(),\n",
    "                    norm_cosine_mean=norm_cosine.mean().item(),\n",
    "                    norm_mse_mean=norm_mse.mean().item(),\n",
    "                    norm_l1_mean=norm_l1.mean().item(),\n",
    "                    norm_cosine_std=norm_cosine.std().item(),\n",
    "                    norm_mse_std=norm_mse.std().item(),\n",
    "                    norm_l1_std=norm_l1.std().item(),\n",
    "                )\n",
    "            )\n",
    "\n",
    "df = pd.DataFrame(performance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Persist one row per evaluated seed pair; the file is tab-separated despite its .csv extension.\n",
    "performance_df = pd.DataFrame(performance)\n",
    "performance_df.to_csv(\"modelzoo_all_ae_performance.csv\", sep=\"\\t\", index=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# If you already have the CSV, start from here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Tab-separated file written by the generation section above.\n",
    "performance_csv = \"modelzoo_all_ae_performance.csv\"\n",
    "data = pd.read_csv(performance_csv, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import plotly.express as px\n",
    "from typing import Sequence, Optional, Tuple, Dict\n",
    "\n",
    "# NOTE(review): the two imports below are not used in this cell and look like editor auto-imports - confirm and drop.\n",
    "from matplotlib import axis\n",
    "from torch.distributed.algorithms import join\n",
    "\n",
    "\n",
    "from latent_invariances.utils.latex import convert_to_latex_with_meanstd\n",
    "\n",
    "# Metric columns that get averaged, and the keys they are grouped by.\n",
    "METRIC = [\n",
    "    \"cosine_mean\",\n",
    "    \"mse_mean\",\n",
    "    \"l1_mean\",\n",
    "    \"spearman_mean\",\n",
    "    \"pearson_mean\",\n",
    "]\n",
    "ROWS = [\"dataset\", \"model\", \"projection\"]\n",
    "\n",
    "pd.set_option(\"display.max_rows\", None)\n",
    "pd.set_option(\"display.max_columns\", None)\n",
    "pd.set_option(\"display.precision\", 3)\n",
    "\n",
    "# Work on a copy so that the frame read from disk (`data`) is never modified.\n",
    "df = data.copy()\n",
    "\n",
    "# Average every metric over the rows of each (dataset, model, projection) group.\n",
    "df_group = df[METRIC + ROWS].groupby(ROWS).agg(\"mean\")\n",
    "metrics = list(METRIC)\n",
    "\n",
    "a = df_group.reset_index()[ROWS + metrics]\n",
    "# drop=True: after filtering, renumber the rows without leaking the old positions into an extra `index` column.\n",
    "a = a[a.projection != \"CoB\"].reset_index(drop=True)\n",
    "px.bar(\n",
    "    a,\n",
    "    x=\"model\",\n",
    "    y=\"pearson_mean\",\n",
    "    color=\"projection\",\n",
    "    barmode=\"group\",\n",
    "    range_y=[0.8, 1.01],\n",
    "    facet_col=\"dataset\",\n",
    ").show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "# Output folder of the figures saved below; created up front, together with any missing parent.\n",
    "FIGURES_DIR = PROJECT_ROOT.joinpath(\"figures\", \"aes\", \"absolute\")\n",
    "FIGURES_DIR.mkdir(exist_ok=True, parents=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inconsistencies on ALL datasets (Pearson/Spearman)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles, figsizes, axes, fonts\n",
    "from tueplots.figsizes import _GOLDEN_RATIO\n",
    "import seaborn as sns\n",
    "import matplotlib as mpl\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = 1\n",
    "RATIO = _GOLDEN_RATIO\n",
    "\n",
    "# Models shown on the x axis, in plotting order, mapped to their tick labels.\n",
    "# Pinning the order keeps the hard-coded labels attached to the right boxes.\n",
    "MODEL_LABELS = {\"ae\": \"AE\", \"vae\": \"VAE\", \"linearized_ae\": \"LinAE\", \"linearized_vae\": \"LinVAE\"}\n",
    "# Projections left out of these figures.\n",
    "EXCLUDED_PROJECTIONS = [\"Absolute\", \"CenterCosine\", \"NormEuclidean\", \"Linf\"]\n",
    "\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(axes.lines())\n",
    "\n",
    "\n",
    "# NOTE(review): set_theme() runs after the tueplots updates above and may override part of them\n",
    "# (e.g. fonts) - confirm that this order is intended.\n",
    "sns.set_theme()\n",
    "\n",
    "\n",
    "for metric, metric_name in [[\"pearson_mean\", \"Pearson Correlation\"], [\"spearman_mean\", \"Spearman Correlation\"]]:\n",
    "    for dataset in [\"mnist\", \"fashion_mnist\", \"cifar10\", \"cifar100\"]:\n",
    "        fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, sharex=True, sharey=True)\n",
    "        sns.boxenplot(\n",
    "            data=df[(df.dataset == dataset) & ~df.projection.isin(EXCLUDED_PROJECTIONS)],\n",
    "            x=\"model\",\n",
    "            y=metric,\n",
    "            hue=\"projection\",\n",
    "            order=list(MODEL_LABELS),\n",
    "            palette=\"Spectral_r\",\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "        ax.set_title(dataset)\n",
    "        ax.set_xlabel(\"\")\n",
    "        ax.set_ylabel(metric_name)\n",
    "        ax.set_xticklabels(list(MODEL_LABELS.values()))\n",
    "        ax.legend(title=\"Projections\")\n",
    "        sns.move_legend(ax, title=\"Projections\", loc=\"lower center\", frameon=True, ncol=4, bbox_to_anchor=(0.45, 0))\n",
    "\n",
    "        filename = FIGURES_DIR / f\"{dataset}_{metric}\"\n",
    "        filename.parent.mkdir(parents=True, exist_ok=True)\n",
    "        svg_path = Path(f\"{filename}.svg\")\n",
    "        pdf_path = Path(f\"{filename}.pdf\")\n",
    "\n",
    "        fig.savefig(svg_path, bbox_inches=\"tight\", pad_inches=0)\n",
    "        # check=True stops here when the conversion fails, so the SVG is only deleted once the PDF exists.\n",
    "        # Passing the arguments as a list also keeps paths containing spaces intact.\n",
    "        subprocess.run([\"rsvg-convert\", \"-f\", \"pdf\", \"-o\", str(pdf_path), str(svg_path)], check=True)\n",
    "        svg_path.unlink()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OLD STUFF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import plotly.express as px\n",
    "from typing import Sequence, Optional, Tuple, Dict\n",
    "\n",
    "from latent_invariances.utils.latex import convert_to_latex_with_meanstd\n",
    "\n",
    "# Metric columns reported in the table and the keys that index its rows.\n",
    "METRIC = [\n",
    "    \"cosine_mean\",\n",
    "    \"mse_mean\",\n",
    "    \"l1_mean\",\n",
    "    \"spearman_mean\",\n",
    "    \"pearson_mean\",\n",
    "]\n",
    "ROWS = [\"model\", \"dataset\", \"projection\"]\n",
    "\n",
    "# Shorter column headers for three of the metrics.\n",
    "label_mapping = {\n",
    "    \"cosine_mean\": \"cosine\",\n",
    "    \"l1_mean\": \"l1\",\n",
    "    \"mse_mean\": \"mse\",\n",
    "}\n",
    "\n",
    "latex_outputs = convert_to_latex_with_meanstd(\n",
    "    df,\n",
    "    ROWS,\n",
    "    METRIC,\n",
    "    label_mapping=label_mapping,\n",
    "    caption=\"aesstitching\",\n",
    "    precision=2,\n",
    ")\n",
    "# Only the second element of the returned value is printed.\n",
    "print(latex_outputs[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "# Available metric columns:\n",
    "# cosine_mean, cosine_std, mse_mean, mse_std, l1_mean, l1_std,\n",
    "# norm_cosine_mean, norm_cosine_std, norm_mse_mean, norm_mse_std, norm_l1_mean, norm_l1_std\n",
    "\n",
    "METRIC = \"mse_mean\"\n",
    "MODEL = \"vae\"  # ae, vae, linearized_ae, linearized_vae, or None to keep every model\n",
    "DATASET = None  # cifar10, cifar100, mnist, fashion_mnist, or None to keep every dataset\n",
    "\n",
    "\n",
    "def compute_relative_metrics(df, reference_row):\n",
    "    \"\"\"Express every row of `df` as a percentage of one reference row.\n",
    "\n",
    "    Args:\n",
    "        df: frame whose values are divided by the reference row.\n",
    "        reference_row: index label, looked up with `df.loc`, of the row used as the 100% baseline.\n",
    "\n",
    "    Returns:\n",
    "        A frame shaped like `df` holding `df / reference * 100`.\n",
    "    \"\"\"\n",
    "    reference_values = df.loc[reference_row]\n",
    "    return (df / reference_values) * 100\n",
    "\n",
    "\n",
    "# Optional filters, then the mean of METRIC for each (model, projection) pair.\n",
    "rdf = df\n",
    "if DATASET is not None:\n",
    "    rdf = rdf[rdf.dataset == DATASET]\n",
    "\n",
    "if MODEL is not None:\n",
    "    rdf = rdf[rdf.model == MODEL]\n",
    "\n",
    "rdf = rdf[[\"model\", \"projection\", METRIC]].groupby([\"model\", \"projection\"]).agg(\"mean\")\n",
    "\n",
    "# Baseline: the Cosine projection of the selected model (of \"ae\" when no model is selected).\n",
    "reference_row = (MODEL if MODEL is not None else \"ae\", \"Cosine\")\n",
    "relative_errors = compute_relative_metrics(rdf, reference_row)\n",
    "\n",
    "plot_df = relative_errors.reset_index()\n",
    "\n",
    "fig = px.bar(\n",
    "    plot_df,\n",
    "    facet_col=\"model\",\n",
    "    x=\"projection\",\n",
    "    y=METRIC,\n",
    "    range_y=[0, 200],\n",
    "    title=f\"{METRIC} relative to {reference_row} (%) [{(MODEL if MODEL is not None else '')}, {(DATASET if DATASET is not None else '')}]\",\n",
    ")\n",
    "# Dashed horizontal line at 100%, the level of the reference row itself.\n",
    "fig.add_hline(y=100, line_width=3, line_dash=\"dash\", line_color=\"red\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relative_errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display floats in scientific notation with two decimals.\n",
    "pd.set_option(\"display.float_format\", lambda x: \"%.2e\" % x)\n",
    "\n",
    "group_keys = [\"model\", \"projection\"]\n",
    "metric_columns = [\"cosine_mean\", \"mse_mean\", \"l1_mean\", \"spearman_mean\", \"pearson_mean\"]\n",
    "\n",
    "# Mean of each metric for every (model, projection) pair, with the keys back as regular columns.\n",
    "grouped_df = df[group_keys + metric_columns].groupby(group_keys).agg(\"mean\").reset_index()\n",
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "columns_to_normalize = [\"cosine_mean\", \"mse_mean\", \"l1_mean\", \"spearman_mean\", \"pearson_mean\"]\n",
    "\n",
    "# Fit the scaler on the metric columns, then overwrite them in place with their rescaled values.\n",
    "scaler = MinMaxScaler()\n",
    "scaler.fit(grouped_df[columns_to_normalize])\n",
    "grouped_df[columns_to_normalize] = scaler.transform(grouped_df[columns_to_normalize])\n",
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = \"cosine_mean\"  # cosine_mean, mse_mean, l1_mean, spearman_mean, pearson_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# One facet per model, one bar per projection, for the metric selected in the previous cell.\n",
    "fig = px.bar(\n",
    "    grouped_df,\n",
    "    facet_col=\"model\",\n",
    "    x=\"projection\",\n",
    "    y=metric,\n",
    "    range_y=[0, 0.5],\n",
    "    text_auto=\".2\",\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# Create the output folder first: writing the image fails when `plots/` does not exist.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "fig.write_image(f\"plots/ae_norm_{metric}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt\n",
    "\n",
    "# model = \"ae\"  # ae, vae\n",
    "\n",
    "# if model == \"ae\":\n",
    "#     df_plot = grouped_df[grouped_df.model == \"ae\"]\n",
    "# else:\n",
    "#     df_plot = grouped_df[grouped_df.model == \"vae\"]\n",
    "\n",
    "# projection = df_plot[\"projection\"]\n",
    "# value = df_plot[metric].round(3)\n",
    "\n",
    "# fig, ax = plt.subplots()\n",
    "\n",
    "# cm = plt.cm.get_cmap(\"Spectral_r\")\n",
    "# bars = ax.bar(projection, value, color=cm(range(len(projection))))\n",
    "\n",
    "# # cm = plt.cm.get_cmap('Spectral_r')\n",
    "# # bars = ax.bar(projection, value, color=cm((value - value.min()) / (value.max() - value.min())))\n",
    "\n",
    "# ax.set_ylabel(metric)\n",
    "\n",
    "# for bar in bars:\n",
    "#     height = bar.get_height()\n",
    "#     ax.text(bar.get_x() + bar.get_width() / 2, height, str(height), ha=\"center\", va=\"bottom\")\n",
    "\n",
    "# # Rotate and align x-axis labels\n",
    "# ax.set_xticklabels(projection, rotation=45, ha=\"right\")\n",
    "\n",
    "# plt.savefig(f\"plots/{model}_norm_{metric}.png\", dpi=300, bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# grouped_df = (\n",
    "#     df[[\"model\", \"projection\", \"cosine_mean\", \"mse_mean\", \"l1_mean\", \"spearman_mean\", \"pearson_mean\"]]\n",
    "#     .groupby([\"model\", \"projection\"])\n",
    "#     .agg(\"mean\")\n",
    "#     .reset_index()\n",
    "# )\n",
    "\n",
    "# pd.DataFrame(grouped_df).to_csv(\"grouped_ae_vae_performance.csv\", index=False, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "GROUPED_CSV = Path(\"grouped_ae_vae_performance.csv\")\n",
    "\n",
    "if not GROUPED_CSV.exists():\n",
    "    # The only code that writes this file is commented out in the cell above, so on a fresh\n",
    "    # checkout the read below would fail. Rebuild the per-(model, projection) means from `data`.\n",
    "    (\n",
    "        data[[\"model\", \"projection\", \"cosine_mean\", \"mse_mean\", \"l1_mean\", \"spearman_mean\", \"pearson_mean\"]]\n",
    "        .groupby([\"model\", \"projection\"])\n",
    "        .agg(\"mean\")\n",
    "        .reset_index()\n",
    "        .to_csv(GROUPED_CSV, index=False, sep=\"\\t\")\n",
    "    )\n",
    "\n",
    "cm_performance = pd.read_csv(GROUPED_CSV, sep=\"\\t\")\n",
    "cm_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "model = \"vae\"  # ae, vae\n",
    "\n",
    "cm_df = cm_performance[cm_performance.model == model]\n",
    "\n",
    "# Drop the (now constant) model column by name instead of by position, so that the result\n",
    "# does not depend on the column order of the CSV; projections become the row labels.\n",
    "scores_df = cm_df.drop(columns=\"model\").set_index(\"projection\")\n",
    "\n",
    "# Pairwise correlation between the metric columns, computed across projections.\n",
    "correlation_matrix = scores_df.corr()\n",
    "\n",
    "cmap = sns.diverging_palette(20, 220, n=200)\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "\n",
    "sns.heatmap(correlation_matrix, annot=True, cmap=cmap, fmt=\".2f\", linewidths=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = data.copy()\n",
    "\n",
    "group_keys = [\"model\", \"projection\", \"dataset\"]\n",
    "mean_columns = [\"cosine_mean\", \"mse_mean\", \"l1_mean\", \"spearman_mean\", \"pearson_mean\"]\n",
    "std_columns = [\"cosine_std\", \"mse_std\", \"l1_std\", \"spearman_std\", \"pearson_std\"]\n",
    "# Projection names removed from the exported table.\n",
    "excluded_projections = [\"Normalized Absolute\", \"L1\", \"Standardized Euclidean\"]\n",
    "\n",
    "# Average every selected column within each (model, projection, dataset) group.\n",
    "grouped_dataset = df[group_keys + mean_columns + std_columns].groupby(group_keys).agg(\"mean\").reset_index()\n",
    "grouped_dataset = grouped_dataset[~grouped_dataset.projection.isin(excluded_projections)]\n",
    "\n",
    "grouped_dataset.to_csv(\"grouped_dataset_ae_vae_performance.csv\", sep=\"\\t\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "grouped_dataset = pd.read_csv(\"grouped_dataset_ae_vae_performance.csv\", sep=\"\\t\")\n",
    "\n",
    "metric_mean = \"cosine_mean\"  # cosine_mean, mse_mean, l1_mean, spearman_mean, pearson_mean\n",
    "title = \"Cosine\"\n",
    "metric_std = \"cosine_std\"  # cosine_std, mse_std, l1_std, spearman_std, pearson_std\n",
    "model = \"vae\"  # ae, vae\n",
    "\n",
    "grouped_dataset = grouped_dataset[grouped_dataset.model == model]\n",
    "\n",
    "datasets = grouped_dataset[\"dataset\"].unique()\n",
    "projections = grouped_dataset[\"projection\"].unique()\n",
    "# pyplot.get_cmap is used because plt.cm.get_cmap is no longer available in recent matplotlib releases.\n",
    "color_map = plt.get_cmap(\"tab10\", len(projections))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "# Layout: one group of adjacent bars per dataset (one bar per projection), groups separated by a gap.\n",
    "dataset_spacing = 0.55\n",
    "bar_width = 0.4\n",
    "\n",
    "group_width = len(projections) * bar_width + dataset_spacing\n",
    "\n",
    "x_pos = np.arange(len(datasets)) * group_width\n",
    "\n",
    "for i, dataset in enumerate(datasets):\n",
    "    data_subset = grouped_dataset[grouped_dataset[\"dataset\"] == dataset]\n",
    "\n",
    "    for j, projection in enumerate(projections):\n",
    "        score = data_subset[data_subset[\"projection\"] == projection][metric_mean].values[0]\n",
    "        std = data_subset.loc[data_subset[\"projection\"] == projection, metric_std].values[0]\n",
    "\n",
    "        bar_x_pos = x_pos[i] + j * bar_width + bar_width / 2\n",
    "\n",
    "        ax.bar(bar_x_pos, score, width=bar_width, color=color_map(j), zorder=2)\n",
    "        ax.errorbar(bar_x_pos, score, yerr=std, color=\"black\", capsize=3, zorder=3)\n",
    "\n",
    "\n",
    "# Set up labels and title\n",
    "# NOTE(review): the tick offset below is a fixed constant; it looks tuned for a specific number of\n",
    "# projections - confirm it still centres the labels if that number changes.\n",
    "ax.set_xticks(x_pos + bar_width + (bar_width + 0.6))\n",
    "ax.set_xticklabels(datasets, rotation=45, ha=\"right\")\n",
    "ax.set_ylabel(title)\n",
    "ax.set_xlabel(\"\")\n",
    "# ax.set_ylim(0, 1.4)\n",
    "ax.grid(axis=\"y\", color=\"gray\", linestyle=\"-\", alpha=0.3, zorder=1)\n",
    "\n",
    "# One legend patch per projection, coloured like its bars.\n",
    "legend_elements = [\n",
    "    plt.Rectangle((0, 0), 1, 1, color=color_map(j), label=projection) for j, projection in enumerate(projections)\n",
    "]\n",
    "legend_ncol = len(projections)\n",
    "ax.legend(\n",
    "    handles=legend_elements,\n",
    "    bbox_to_anchor=(0.5, 1.1),\n",
    "    loc=\"upper center\",\n",
    "    ncol=legend_ncol,\n",
    ")\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing folders, so make sure `plots/` exists first.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(f\"plots/hist_{model}_{title}_alldata.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "from pathlib import Path\n",
    "\n",
    "svg_path = Path(\"plots/hist_vae_Cosine_alldata.svg\")\n",
    "pdf_path = svg_path.with_suffix(\".pdf\")\n",
    "\n",
    "# check=True stops here when the conversion fails, so the SVG is only deleted once the PDF exists.\n",
    "subprocess.run([\"rsvg-convert\", \"-f\", \"pdf\", \"-o\", str(pdf_path), str(svg_path)], check=True)\n",
    "svg_path.unlink()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot slides"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "grouped_dataset = pd.read_csv(\"grouped_dataset_ae_vae_performance.csv\", sep=\"\\t\")\n",
    "\n",
    "metric_mean = \"cosine_mean\"  # cosine_mean, mse_mean, l1_mean, spearman_mean, pearson_mean\n",
    "title = \"Cosine\"\n",
    "metric_std = \"cosine_std\"  # cosine_std, mse_std, l1_std, spearman_std, pearson_std\n",
    "model = \"vae\"  # ae, vae\n",
    "\n",
    "grouped_dataset = grouped_dataset[grouped_dataset.model == model]\n",
    "\n",
    "datasets = grouped_dataset[\"dataset\"].unique()\n",
    "projections = grouped_dataset[\"projection\"].unique()\n",
    "# pyplot.get_cmap is used because plt.cm.get_cmap is no longer available in recent matplotlib releases.\n",
    "color_map = plt.get_cmap(\"tab10\", len(projections))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "# Layout: one group of adjacent bars per dataset (one bar per projection), groups separated by a gap.\n",
    "dataset_spacing = 0.55\n",
    "bar_width = 0.4\n",
    "\n",
    "group_width = len(projections) * bar_width + dataset_spacing\n",
    "\n",
    "x_pos = np.arange(len(datasets)) * group_width\n",
    "\n",
    "for i, dataset in enumerate(datasets):\n",
    "    data_subset = grouped_dataset[grouped_dataset[\"dataset\"] == dataset]\n",
    "\n",
    "    for j, projection in enumerate(projections):\n",
    "        score = data_subset[data_subset[\"projection\"] == projection][metric_mean].values[0]\n",
    "        std = data_subset.loc[data_subset[\"projection\"] == projection, metric_std].values[0]\n",
    "\n",
    "        bar_x_pos = x_pos[i] + j * bar_width + bar_width / 2\n",
    "\n",
    "        ax.bar(bar_x_pos, score, width=bar_width, color=color_map(j), zorder=2)\n",
    "        ax.errorbar(bar_x_pos, score, yerr=std, color=\"white\", capsize=3, zorder=3)\n",
    "\n",
    "\n",
    "# Set up labels and title\n",
    "# NOTE(review): the tick offset below is a fixed constant; it looks tuned for a specific number of\n",
    "# projections - confirm it still centres the labels if that number changes.\n",
    "ax.set_xticks(x_pos + bar_width + (bar_width + 0.6))\n",
    "ax.set_xticklabels(datasets, rotation=45, ha=\"right\")\n",
    "ax.set_ylabel(title)\n",
    "ax.set_xlabel(\"\")\n",
    "ax.grid(axis=\"y\", color=\"white\", linestyle=\"-\", alpha=0.3, zorder=1)\n",
    "\n",
    "# White spines, labels and ticks: the figure is saved with a transparent background.\n",
    "ax.spines[\"bottom\"].set_color(\"white\")\n",
    "ax.spines[\"left\"].set_color(\"white\")\n",
    "ax.spines[\"right\"].set_color(\"white\")\n",
    "ax.spines[\"top\"].set_color(\"white\")\n",
    "ax.xaxis.label.set_color(\"white\")\n",
    "ax.yaxis.label.set_color(\"white\")\n",
    "ax.tick_params(colors=\"white\", which=\"both\")\n",
    "\n",
    "# One legend patch per projection, coloured like its bars.\n",
    "legend_elements = [\n",
    "    plt.Rectangle((0, 0), 1, 1, color=color_map(j), label=projection) for j, projection in enumerate(projections)\n",
    "]\n",
    "legend_ncol = len(projections)\n",
    "ax.legend(\n",
    "    handles=legend_elements,\n",
    "    bbox_to_anchor=(0.5, 1.1),\n",
    "    loc=\"upper center\",\n",
    "    ncol=legend_ncol,\n",
    "    fancybox=True,\n",
    "    framealpha=0,\n",
    "    labelcolor=\"white\",\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing folders, so make sure `plots/slides/` exists first.\n",
    "Path(\"plots/slides\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(f\"plots/slides/{model}_hist.png\", bbox_inches=\"tight\", format=\"png\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
