{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from latent_invariances.stitching import data_config\n",
    "from nn_core.common import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each sub-directory of results_paper/stitching holds one dataset's stitching.tsv;\n",
    "# the dataset name is the directory-name prefix before the first \"_\".\n",
    "STITCHING_RESULTS_DIR = PROJECT_ROOT / \"results_paper\" / \"stitching\"\n",
    "dataset_dfs = [\n",
    "    pd.read_csv(dataset_dir / \"stitching.tsv\", sep=\"\\t\").assign(dataset=dataset_dir.name.split(\"_\")[0])\n",
    "    # Skip stray files (e.g. .DS_Store) and sort so the row order does not depend on the filesystem.\n",
    "    for dataset_dir in sorted(STITCHING_RESULTS_DIR.iterdir())\n",
    "    if dataset_dir.is_dir()\n",
    "]\n",
    "assert dataset_dfs, f\"no result directories found in {STITCHING_RESULTS_DIR}\"\n",
    "# ignore_index: every per-dataset frame starts at 0, so rebuild a unique row index after concatenating.\n",
    "df = pd.concat(dataset_dfs, ignore_index=True)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values of the `projections` column kept in the analysis.\n",
    "projections = {\"Cosine\", \"Euclidean\", \"L1\", \"Linf\"}\n",
    "# Every dataset found on disk; replace with e.g. {\"trec\"} to restrict the analysis.\n",
    "datasets = set(df.dataset.unique())\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep cross-model stitching only (encoder != decoder) with LayerNorm aggregation and a linear\n",
    "# classifier, for the selected projections and datasets; CSPDarkNet is dropped on both sides.\n",
    "EXCLUDED_SPACES = {\"cspdarknet53\"}\n",
    "df = df[\n",
    "    (df.encoding_space != df.decoding_space)\n",
    "    & (df.aggregation == \"LayerNorm\")\n",
    "    & (df.classifier == \"linear\")\n",
    "    & (df.projections.isin(projections))\n",
    "    & (df.dataset.isin(datasets))\n",
    "    & ~df.encoding_space.isin(EXCLUDED_SPACES)\n",
    "    & ~df.decoding_space.isin(EXCLUDED_SPACES)\n",
    "].copy()  # explicit copy: later cells assign columns on this frame (avoids SettingWithCopyWarning)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of every numeric metric for each (dataset, encoder, decoder, projection) combination.\n",
    "# numeric_only=True: pandas >= 2 raises on non-numeric columns instead of silently dropping them.\n",
    "stitching_means = (\n",
    "    df.drop(columns=['aggregation', 'classifier'])\n",
    "    .groupby(['dataset', 'encoding_space', 'decoding_space', 'projections'])\n",
    "    .mean(numeric_only=True)\n",
    ")\n",
    "# Lift the row limit for this table only, instead of globally for every later cell's output.\n",
    "with pd.option_context(\"display.max_rows\", 19000):\n",
    "    display(stitching_means)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short display names for the encoders, used as tick labels in the figures below.\n",
    "encoder_aliases = {\n",
    "    \"bert-base-cased\": \"BERT-C\",\n",
    "    \"bert-base-uncased\": \"BERT-U\",\n",
    "    \"google/electra-base-discriminator\": \"ELECTRA\",\n",
    "    \"roberta-base\": \"RoBERTa\",\n",
    "    \"albert-base-v2\": \"ALBERT\",\n",
    "    \"xlm-roberta-base\": \"XLM-R\",\n",
    "    \"openai/clip-vit-base-patch32\": \"CViT-B/32\",\n",
    "    \"vit_base_patch16_384\": \"ViT-B/16\",\n",
    "    \"vit_base_resnet50_384\": \"RViT-B/16\",\n",
    "    \"rexnet_100\": \"RexNet\",\n",
    "    \"vit_small_patch16_224\": \"ViT-S/16\",\n",
    "    \"cspdarknet53\": \"CSPDarkNet\",\n",
    "}\n",
    "# .replace (unlike .map) keeps names missing from the table instead of turning them into NaN,\n",
    "# and re-running this cell on already-aliased names is a no-op.\n",
    "df = df.assign(\n",
    "    encoding_space=df['encoding_space'].replace(encoder_aliases),\n",
    "    decoding_space=df['decoding_space'].replace(encoder_aliases),\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interactive overview of linear CKA per encoder, one box per projection.\n",
    "fig = px.box(\n",
    "    df[\n",
    "        (df.projections != \"CenterCosine\")\n",
    "        & (df.projections != \"NormEuclidean\")\n",
    "        & (df.aggregation == \"LayerNorm\")\n",
    "        & (df.projections != \"Absolute\")\n",
    "        & (df.encoding_space != df.decoding_space)\n",
    "    ],\n",
    "    x=\"encoding_space\",\n",
    "    y=\"linear_cka\",\n",
    "    color=\"projections\",\n",
    ")\n",
    "# Box traces are grouped/spaced by boxmode/boxgap; barmode/bargap only apply to bar traces.\n",
    "fig.update_layout(boxmode=\"group\", boxgap=0.1)\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot of the current results frame; the plotting cells below restart from copies of it.\n",
    "df_back = df.copy(deep=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory for exported figures; the plotting cells create sub-folders as needed.\n",
    "FIGURES_DIR = PROJECT_ROOT / \"figures\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tueplots import bundles, figsizes, axes, fonts\n",
    "from tueplots.figsizes import _GOLDEN_RATIO\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "# Single-panel figures; height/width ratio taken from tueplots' golden-ratio constant.\n",
    "N_ROWS, N_COLS = 1, 1\n",
    "RATIO = _GOLDEN_RATIO\n",
    "\n",
    "# ICLR 2024 formatting via tueplots, applied in order: dpi, font bundle, figure size, line widths.\n",
    "for rc_overrides in (\n",
    "    {\"figure.dpi\": 150},\n",
    "    bundles.iclr2024(),\n",
    "    figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO),\n",
    "    axes.lines(),\n",
    "):\n",
    "    plt.rcParams.update(rc_overrides)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import subprocess\n",
    "\n",
    "\n",
    "df = df_back.copy()\n",
    "\n",
    "# Legend labels for the projections (raw string: \"\\i\" is not a valid Python escape).\n",
    "invariance_aliases = {\n",
    "    \"Cosine\" : \"Cosine\",\n",
    "    \"Euclidean\": \"Euclidean\",\n",
    "    \"L1\": \"L1\",\n",
    "    \"Linf\": r\"$L_{\\infty}$\",\n",
    "}\n",
    "df.loc[:, 'projections'] = df['projections'].map(invariance_aliases)\n",
    "\n",
    "# Metric column plotted on the y-axis, with its axis label.\n",
    "METRICS = ((\"linear_cka\", \"Linear CKA\"), (\"score\", \"Accuracy\"))\n",
    "# Models dropped (as encoder and as decoder) from the linear-CKA teaser panels.\n",
    "CKA_TEASER_EXCLUDED = {\"ViT-S/16\", \"RViT-B/16\"}\n",
    "\n",
    "# Apply the seaborn theme first and the ICLR 2024 tueplots settings after it: sns.set_theme()\n",
    "# resets fonts and font sizes, so calling it last silently undoes the paper formatting.\n",
    "sns.set_theme()\n",
    "plt.rcParams.update({\"figure.dpi\": 300})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(axes.lines())\n",
    "\n",
    "for dataset, (y, y_label), teaser in itertools.product(datasets, METRICS, (True, False)):\n",
    "    # Teaser panels drop the L_inf projection and, for linear CKA, the excluded models.\n",
    "    model_condition = (\n",
    "        ~df.encoding_space.isin(CKA_TEASER_EXCLUDED) & ~df.decoding_space.isin(CKA_TEASER_EXCLUDED)\n",
    "        if teaser and y == \"linear_cka\"\n",
    "        else True\n",
    "    )\n",
    "    projection_condition = (df.projections != invariance_aliases[\"Linf\"]) if teaser else True\n",
    "\n",
    "    fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, sharex=True, sharey=True)\n",
    "    sns.boxenplot(\n",
    "        data=df[(df.dataset == dataset) & (df.projections != \"Absolute\") & model_condition & projection_condition],\n",
    "        x=\"encoding_space\",\n",
    "        y=y,\n",
    "        hue=\"projections\",\n",
    "        palette=\"Spectral_r\",\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    ax.set_title(dataset.capitalize())\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(y_label)\n",
    "\n",
    "    ax.legend(title='Projections')\n",
    "    sns.move_legend(ax, title=\"Invariance\", loc='lower right', ncol=2)\n",
    "\n",
    "    teaser_string = \"_teaser\" if teaser else \"\"\n",
    "    filename = FIGURES_DIR / f\"figure1/{dataset}_{y}{teaser_string}\"\n",
    "    filename.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Save as SVG and convert losslessly to PDF. No shell, so paths with spaces work; check=True\n",
    "    # raises (keeping the SVG) instead of deleting it when rsvg-convert fails.\n",
    "    svg_path = filename.parent / f\"{filename.name}.svg\"\n",
    "    pdf_path = filename.parent / f\"{filename.name}.pdf\"\n",
    "    fig.savefig(svg_path, bbox_inches=\"tight\", pad_inches=0)\n",
    "    subprocess.run([\"rsvg-convert\", \"-f\", \"pdf\", \"-o\", pdf_path, svg_path], check=True)\n",
    "    svg_path.unlink()\n",
    "\n",
    "    # Close explicitly so figures don't accumulate; switching to the Agg backend instead would\n",
    "    # also disable inline display for every later cell.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "df = df_back.copy()\n",
    "# Legend labels shown under the \"Invariance\" legend title (raw string for the LaTeX backslash).\n",
    "invariance_aliases = {\n",
    "    \"Cosine\" : \"Conformal\",\n",
    "    \"Euclidean\": \"E(n)\",\n",
    "    \"L1\": \"ISO$_{L1}$(n)\",\n",
    "    \"Linf\": r\"ISO$_{L_{\\infty}}$(n)\",\n",
    "}\n",
    "df.loc[:, 'projections'] = df['projections'].map(invariance_aliases)\n",
    "\n",
    "# (dataset, legend placement) per teaser figure. Deliberately not named `datasets`: reusing that\n",
    "# name would break the figure-1 cell above if it is re-run after this one.\n",
    "teaser_datasets = [\n",
    "    (\"mnist\", {'loc': \"best\", 'frameon': True}),\n",
    "    (\"fashion\", {'loc': \"best\", 'frameon': True}),\n",
    "]\n",
    "# Models left out of the teaser as encoders.\n",
    "TEASER_EXCLUDED_ENCODERS = {\"RViT-B/16\", \"ViT-S/16\"}\n",
    "\n",
    "# seaborn theme first, then the ICLR 2024 tueplots settings: sns.set_theme() resets fonts and\n",
    "# font sizes, so calling it last silently undoes the paper formatting.\n",
    "sns.set_theme()\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "plt.rcParams.update(bundles.iclr2024())\n",
    "plt.rcParams.update(figsizes.iclr2024(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "plt.rcParams.update(axes.lines())\n",
    "\n",
    "for dataset, legend_args in teaser_datasets:\n",
    "    fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, sharex=True, sharey=True)\n",
    "\n",
    "    # NOTE(review): only the encoder side is filtered here, whereas the figure-1 teaser drops these\n",
    "    # models as decoders too -- confirm the asymmetry is intended.\n",
    "    teaser_df = df[\n",
    "        (df.dataset == dataset)\n",
    "        & (df.projections != invariance_aliases[\"Linf\"])\n",
    "        & (df.projections != \"Absolute\")\n",
    "        & ~df.encoding_space.isin(TEASER_EXCLUDED_ENCODERS)\n",
    "    ]\n",
    "    sns.boxenplot(\n",
    "        data=teaser_df,\n",
    "        x=\"encoding_space\",\n",
    "        y=\"linear_cka\",\n",
    "        hue=\"projections\",\n",
    "        palette=\"Spectral_r\",\n",
    "        ax=ax,\n",
    "    )\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"Space Similarity\")\n",
    "\n",
    "    ax.legend(title='Invariance')\n",
    "    sns.move_legend(ax, title=\"Invariance\", **legend_args)\n",
    "\n",
    "    filename = FIGURES_DIR / f\"teaser_{dataset}\"\n",
    "    filename.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Save as SVG and convert losslessly to PDF (no shell; check=True keeps the SVG on failure).\n",
    "    svg_path = filename.parent / f\"{filename.name}.svg\"\n",
    "    pdf_path = filename.parent / f\"{filename.name}.pdf\"\n",
    "    fig.savefig(svg_path, bbox_inches=\"tight\", pad_inches=0)\n",
    "    subprocess.run([\"rsvg-convert\", \"-f\", \"pdf\", \"-o\", pdf_path, svg_path], check=True)\n",
    "    svg_path.unlink()\n",
    "\n",
    "    # Show after saving and never switch to the non-interactive Agg backend, which would make\n",
    "    # plt.show() a no-op for every figure after the first one.\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
