{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles, figsizes, axes, fonts\n",
    "from tueplots.figsizes import _GOLDEN_RATIO\n",
    "import seaborn\n",
    "\n",
    "pd.set_option(\"display.max_rows\", 400)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export latex tables sorted by maximum difference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\n",
    "    \"cifar10\",\n",
    "    \"cifar100_False\",\n",
    "    \"fashion_mnist\",\n",
    "    \"mnist\",\n",
    "    \"n24news_image_False\",\n",
    "    # text\n",
    "    \"dbpedia_14\",\n",
    "    \"trec_False\",\n",
    "    \"n24news_text_False\",\n",
    "]\n",
    "\n",
    "names_mapping = {\n",
    "    \"rexnet_100\": \"RexNet\",\n",
    "    \"vit_base_patch16_384\": \"ViT-B/16\",\n",
    "    \"vit_small_patch16_224\": \"ViT-S/16\",\n",
    "    \"vit_base_resnet50_384\": \"RViT-B/16\",\n",
    "    \"openai/clip-vit-base-patch32\": \"CViT-B/32\",\n",
    "    \"bert-base-cased\": \"BERT-C\",\n",
    "    \"bert-base-uncased\": \"BERT-U\",\n",
    "    \"google/electra-base-discriminator\": \"ELECTRA\",\n",
    "    \"roberta-base\": \"RoBERTa\",\n",
    "    \"albert-base-v2\": \"ALBERT\",\n",
    "    \"xlm-roberta-base\": \"XLM-RoBERTa\",\n",
    "    \"openai/clip-vit-base-patch32\": \"CViT-B/32\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "df = pd.read_csv(\n",
    "    PROJECT_ROOT / f\"results_paper/stitching/{datasets[0]}_train_test_1/stitching.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from anypy.latex import df2table_meanstd\n",
    "\n",
    "\n",
    "def compute_max_diff(row):\n",
    "    score_eu = row[(\"score\", \"Euclidean\")]\n",
    "    score_cos = row[(\"score\", \"Cosine\")]\n",
    "    score_l1 = row[(\"score\", \"L1\")]\n",
    "    return max(score_cos, score_eu, score_l1) - min(score_cos, score_eu, score_l1)\n",
    "\n",
    "\n",
    "def make_projection_columns(df):\n",
    "    pivot = (\n",
    "        df[\n",
    "            (df.encoding_space != df.decoding_space)\n",
    "            & (df.projections != \"CenterCosine\")\n",
    "            & (df.projections != \"NormEuclidean\")\n",
    "            & (df.projections != \"Absolute\")\n",
    "            & (df.aggregation == \"LayerNorm\")\n",
    "        ]\n",
    "    ).drop(columns=[\"aggregation\"])\n",
    "\n",
    "    pivot = pivot.drop(columns=[\"num_anchors\", \"l1\", \"mse\", \"linear_cka\", \"cosine_sim\"]).pivot_table(\n",
    "        aggfunc=\"mean\",\n",
    "        columns=[\"projections\"],\n",
    "        index=[\n",
    "            \"seed\",\n",
    "            \"encoding_space\",\n",
    "            \"decoding_space\",\n",
    "        ],\n",
    "    )\n",
    "    #\n",
    "    pivot[\"MaxDiff\"] = pivot.apply(lambda x: compute_max_diff(x), axis=1)\n",
    "    pivot = pivot.sort_values(by=\"MaxDiff\", ascending=False).reset_index()\n",
    "    pivot.columns = pivot.columns.map(\"|\".join).str.strip(\"|\")\n",
    "    return pivot\n",
    "\n",
    "\n",
    "def get_latex_table(dataset_name):\n",
    "    df = pd.read_csv(\n",
    "        PROJECT_ROOT / f\"results_paper/stitching/{dataset_name}_train_test_1/stitching.tsv\",\n",
    "        sep=\"\\t\",\n",
    "    )\n",
    "\n",
    "    df[\"encoding_space\"] = df[\"encoding_space\"].map(names_mapping)\n",
    "    df[\"decoding_space\"] = df[\"decoding_space\"].map(names_mapping)\n",
    "    pivot = make_projection_columns(\n",
    "        df.drop(\n",
    "            columns=[\n",
    "                \"classifier\",\n",
    "            ]\n",
    "        )\n",
    "    )\n",
    "    res, tex = df2table_meanstd(\n",
    "        pivot,\n",
    "        rows=[\n",
    "            \"encoding_space\",\n",
    "            \"decoding_space\",\n",
    "        ],\n",
    "        metrics=[\n",
    "            \"score|Cosine\",\n",
    "            \"score|Euclidean\",\n",
    "            \"score|L1\",\n",
    "            \"score|Linf\",\n",
    "            \"MaxDiff\",\n",
    "        ],\n",
    "        postprocess_pivot=lambda x: x.sort_values(by=(\"mean\", \"MaxDiff\"), ascending=False),\n",
    "        caption=f\"Stitching results for {dataset_name}. The table shows the mean and standard deviation of the test accuracy for the different projection methods sorted by maximum difference between projections, reported in the first column.\",\n",
    "    )\n",
    "    return res, tex\n",
    "\n",
    "\n",
    "all_tex = []\n",
    "for dataset in datasets:\n",
    "    _, tex = get_latex_table(dataset)\n",
    "    all_tex.append(tex)\n",
    "\n",
    "print(\"\\n\\n\".join(all_tex))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
