{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ecb813",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Paths of the per-dataset TSV files read below.\n",
    "perf_amazon_reviews_multi_en_fine = \"nlp_stitching-amazon_reviews_multi_en_fine_grained.tsv\"\n",
    "perf_amazon_reviews_multi_en_coarse = \"nlp_stitching-amazon_reviews_multi_en_coarse.tsv\"\n",
    "perf_dbpedia_14 = \"nlp_stitching-dbpedia_14.tsv\"\n",
    "perf_trec = \"nlp_stitching-trec.tsv\"\n",
    "\n",
    "# Every file is parsed identically: tab-separated, first column used as the index.\n",
    "tsv_options = {\"sep\": \"\\t\", \"index_col\": 0}\n",
    "\n",
    "trec = pd.read_csv(perf_trec, **tsv_options)\n",
    "dbpedia = pd.read_csv(perf_dbpedia_14, **tsv_options)\n",
    "amazon_coarse = pd.read_csv(perf_amazon_reviews_multi_en_coarse, **tsv_options)\n",
    "amazon_fine = pd.read_csv(perf_amazon_reviews_multi_en_fine, **tsv_options)\n",
    "\n",
    "# With IPython's default settings only a cell's final expression is rendered.\n",
    "amazon_fine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "335685a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_results(df, drop_transformers=None):\n",
    "    \"\"\"Aggregate the scores per (embed_type, stitched) group, print the table and return it.\n",
    "\n",
    "    Args:\n",
    "        df: results frame; it must contain the columns ``seed``, ``precision``,\n",
    "            ``recall``, ``embed_transformer``, ``classifier_transformer``,\n",
    "            ``embed_type`` and ``stitched``. Every remaining column is aggregated.\n",
    "        drop_transformers: optional iterable of transformer names; rows whose\n",
    "            ``embed_transformer`` or ``classifier_transformer`` is one of them are\n",
    "            excluded. ``None`` (the default) or an empty iterable keeps every row.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame indexed by ``(embed_type, stitched)`` holding the ``mean``, ``std``\n",
    "        and ``count`` of each remaining column. The input frame is not modified.\n",
    "    \"\"\"\n",
    "    # None replaces the mutable ``[]`` default; list() also accepts tuples and generators.\n",
    "    excluded = list(drop_transformers) if drop_transformers is not None else []\n",
    "    if excluded:\n",
    "        uses_excluded = df[\"embed_transformer\"].isin(excluded) | df[\"classifier_transformer\"].isin(excluded)\n",
    "        df = df[~uses_excluded]\n",
    "    summary = (\n",
    "        df.drop(columns=[\"seed\", \"precision\", \"recall\", \"embed_transformer\", \"classifier_transformer\"])\n",
    "        .groupby([\"embed_type\", \"stitched\"])\n",
    "        # String aliases select pandas' own groupby mean/std. Passing np.mean/np.std is\n",
    "        # deprecated since pandas 2.1 and will later call NumPy directly (std with ddof=0).\n",
    "        .agg([\"mean\", \"std\", \"count\"])\n",
    "    )\n",
    "    print(summary)\n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da8d207",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the TREC table, excluding rows whose embed or classifier transformer is xlm-roberta-base.\n",
    "trec_df = display_results(\n",
    "    trec,\n",
    "    drop_transformers=[\"xlm-roberta-base\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f269e2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the DBpedia table, excluding rows whose embed or classifier transformer is xlm-roberta-base.\n",
    "dbpedia_df = display_results(\n",
    "    dbpedia,\n",
    "    drop_transformers=[\"xlm-roberta-base\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b58c3613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the coarse Amazon table, excluding rows whose embed or classifier transformer is xlm-roberta-base.\n",
    "amazon_coarse_df = display_results(\n",
    "    amazon_coarse,\n",
    "    drop_transformers=[\"xlm-roberta-base\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6ea66cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the fine-grained Amazon table, excluding rows whose embed or classifier transformer is xlm-roberta-base.\n",
    "amazon_fine_df = display_results(\n",
    "    amazon_fine,\n",
    "    drop_transformers=[\"xlm-roberta-base\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9891d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def latex_float(f):\n",
    "    \"\"\"Format ``f`` in fixed-point notation with two decimals, e.g. ``1.5`` -> ``\"1.50\"``.\n",
    "\n",
    "    The ``.2f`` presentation type never produces exponent notation, so no\n",
    "    conversion of an ``e`` exponent into LaTeX is needed.\n",
    "    \"\"\"\n",
    "    return \"{0:.2f}\".format(f)\n",
    "\n",
    "\n",
    "def extract_mean_std(df: pd.DataFrame, model_type: str, stitching: bool) -> str:\n",
    "    r\"\"\"Render the f-score of one group as the LaTeX snippet ``$<mean> \\pm <std>$``, in percent.\n",
    "\n",
    "    Args:\n",
    "        df: aggregated table as returned by ``display_results``: rows keyed by\n",
    "            ``(embed_type, stitched)`` with ``fscore`` columns ``mean`` and ``std``.\n",
    "        model_type: first key of the row to read (an ``embed_type`` value).\n",
    "        stitching: second key of the row to read (a ``stitched`` value).\n",
    "\n",
    "    Returns:\n",
    "        The formatted table cell, or ``\"?\"`` when the lookup of the row or of its\n",
    "        f-score statistics raises ``KeyError`` or ``AttributeError``.\n",
    "    \"\"\"\n",
    "    # Express every value as a percentage rounded to two decimals before the lookup.\n",
    "    df = df * 100\n",
    "    df = df.round(2)\n",
    "    try:\n",
    "        fscore = df.loc[model_type, stitching][\"fscore\"]\n",
    "        mean_txt, std_txt = latex_float(fscore[\"mean\"]), latex_float(fscore[\"std\"])\n",
    "    except (AttributeError, KeyError):\n",
    "        # Best effort: a missing group or column becomes a placeholder cell.\n",
    "        return \"?\"\n",
    "    return rf\"${mean_txt} \\pm {std_txt}$\"\n",
    "\n",
    "\n",
    "# One LaTeX table row: model type, stitching label, then one cell per dataset.\n",
    "classification_rel = r\"{} & {} & {} & {} & {} \\\\[1ex]\"\n",
    "\n",
    "# (value of the ``stitched`` key, label printed in the table)\n",
    "stitching_modes = ((False, \"Non-Stitch\"), (True, \"Stitch\"))\n",
    "\n",
    "# Aggregated tables in table-column order. The last entry is amazon_fine_df, not the\n",
    "# un-aggregated amazon_fine frame, which is not the groupby summary extract_mean_std expects.\n",
    "aggregated_results = (trec_df, dbpedia_df, amazon_coarse_df, amazon_fine_df)\n",
    "\n",
    "for available_model_type in (\"absolute\", \"relative\"):\n",
    "    for stitching, stit_name in stitching_modes:\n",
    "        cells = [extract_mean_std(df, available_model_type, stitching) for df in aggregated_results]\n",
    "        s = classification_rel.format(available_model_type, stit_name, *cells)\n",
    "        print(s)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
