{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4768c33",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Which results files to load (see `read_df` for the file-name pattern).\n",
    "fine_grained: bool = True  # selects the 'fine_grained' vs 'coarse_grained' results files\n",
    "train_perc: float = 0.25  # training-percentage tag in the results file names\n",
    "# The anchor dataset is not configured here: the in-domain table is loaded with\n",
    "# \"amazon_translated\" anchors and the out-of-domain table with \"wikimatrix\" anchors.\n",
    "# Metrics present in the TSVs that the tables do not report.\n",
    "COLUMNS_TO_DROP = [\"precision\", \"recall\"]\n",
    "\n",
    "\n",
    "def read_df(fine_grained, anchor_dataset_name, train_perc):\n",
    "    \"\"\"Load one results table of the Amazon multilingual-stitching experiment.\n",
    "\n",
    "    The TSV file name encodes the granularity (fine/coarse), the anchor dataset and the\n",
    "    training percentage; it is resolved relative to the current working directory.\n",
    "    The first TSV column becomes the DataFrame index.\n",
    "    \"\"\"\n",
    "    granularity = \"fine_grained\" if fine_grained else \"coarse_grained\"\n",
    "    path = f\"nlp_multilingual-stitching-amazon-{granularity}-{anchor_dataset_name}-{train_perc}.tsv\"\n",
    "    return pd.read_csv(path, sep=\"\\t\", index_col=0)\n",
    "\n",
    "\n",
    "def rearrange_embedtype_as_column(mydf, domain):\n",
    "    \"\"\"Move the embedding type from rows into the top level of the column header.\n",
    "\n",
    "    `mydf` is expected to have exactly 7 two-level columns, relabelled here by position as:\n",
    "    seed, embed_type, train_lang, test_lang, (domain, fscore), (domain, mae), stitched.\n",
    "    Rows whose embed_type is \"relative\" / \"absolute\" get their metrics relabelled to\n",
    "    (\"Relative\" / \"Absolute\", domain, metric); the two halves are then inner-joined on\n",
    "    train_lang, test_lang, seed and stitched. The result has three-level columns and\n",
    "    no embed_type column.\n",
    "    \"\"\"\n",
    "\n",
    "    def relabel(rows, embed_label):\n",
    "        # set_axis returns a relabelled copy instead of mutating `.columns` of a mask slice.\n",
    "        return rows.set_axis(\n",
    "            pd.MultiIndex.from_tuples(\n",
    "                [\n",
    "                    (\"seed\", \"\", \"\"),\n",
    "                    (\"embed_type\", \"\", \"\"),\n",
    "                    (\"train_lang\", \"\", \"\"),\n",
    "                    (\"test_lang\", \"\", \"\"),\n",
    "                    (embed_label, domain, \"fscore\"),\n",
    "                    (embed_label, domain, \"mae\"),\n",
    "                    (\"stitched\", \"\", \"\"),\n",
    "                ]\n",
    "            ),\n",
    "            axis=1,\n",
    "        )\n",
    "\n",
    "    embed_type = mydf[(\"embed_type\", \"\")]\n",
    "    relative_out = relabel(mydf[embed_type == \"relative\"], \"Relative\")\n",
    "    absolute_out = relabel(mydf[embed_type == \"absolute\"], \"Absolute\")\n",
    "    return pd.merge(\n",
    "        relative_out.drop(columns=[\"embed_type\"]),\n",
    "        absolute_out.drop(columns=[\"embed_type\"]),\n",
    "        on=[(\"train_lang\", \"\", \"\"), (\"test_lang\", \"\", \"\"), (\"seed\", \"\", \"\"), (\"stitched\", \"\", \"\")],\n",
    "    )\n",
    "\n",
    "\n",
    "def load_domain_results(anchor_dataset_name, domain):\n",
    "    \"\"\"Load one results TSV and reshape it into (Relative|Absolute, domain, metric) columns.\n",
    "\n",
    "    F-scores are multiplied by 100. Columns are relabelled by position, which assumes the\n",
    "    TSV columns left after dropping COLUMNS_TO_DROP are, in order: seed, embed_type,\n",
    "    train_lang, test_lang, fscore, mae, stitched.\n",
    "    \"\"\"\n",
    "    results = read_df(fine_grained=fine_grained, anchor_dataset_name=anchor_dataset_name, train_perc=train_perc)\n",
    "    results = results.drop(columns=COLUMNS_TO_DROP)\n",
    "    results[\"fscore\"] = results[\"fscore\"] * 100\n",
    "    results.columns = pd.MultiIndex.from_tuples(\n",
    "        [\n",
    "            (\"seed\", \"\"),\n",
    "            (\"embed_type\", \"\"),\n",
    "            (\"train_lang\", \"\"),\n",
    "            (\"test_lang\", \"\"),\n",
    "            (domain, \"fscore\"),\n",
    "            (domain, \"mae\"),\n",
    "            (\"stitched\", \"\"),\n",
    "        ],\n",
    "    )\n",
    "    return rearrange_embedtype_as_column(results, domain=domain)\n",
    "\n",
    "\n",
    "# The in-domain table is built with \"amazon_translated\" anchors, the out-of-domain one with \"wikimatrix\".\n",
    "full_in_domain = load_domain_results(\"amazon_translated\", \"In Domain\")\n",
    "full_out_domain = load_domain_results(\"wikimatrix\", \"Out Domain\")\n",
    "\n",
    "# Pair in-domain and out-of-domain rows sharing seed, languages and `stitched`, then drop seed/stitched.\n",
    "run_keys = [(\"seed\", \"\", \"\"), (\"train_lang\", \"\", \"\"), (\"test_lang\", \"\", \"\"), (\"stitched\", \"\", \"\")]\n",
    "merged_domains = pd.merge(full_in_domain, full_out_domain, on=run_keys)\n",
    "full_df = merged_domains.drop(columns=[(\"seed\", \"\", \"\"), (\"stitched\", \"\", \"\")])\n",
    "\n",
    "train_lang = \"Train Lang\"\n",
    "test_lang = \"Test Lang\"\n",
    "full_df = full_df.rename(columns={\"train_lang\": train_lang, \"test_lang\": test_lang})\n",
    "# Final column order. Absolute/Out Domain is deliberately not selected (presumably it\n",
    "# repeats Absolute/In Domain since absolute embeddings do not use anchors - TODO confirm).\n",
    "full_df = full_df[\n",
    "    [\n",
    "        (train_lang, \"\", \"\"),\n",
    "        (test_lang, \"\", \"\"),\n",
    "        (\"Absolute\", \"In Domain\", \"fscore\"),\n",
    "        (\"Absolute\", \"In Domain\", \"mae\"),\n",
    "        (\"Relative\", \"In Domain\", \"fscore\"),\n",
    "        (\"Relative\", \"In Domain\", \"mae\"),\n",
    "        (\"Relative\", \"Out Domain\", \"fscore\"),\n",
    "        (\"Relative\", \"Out Domain\", \"mae\"),\n",
    "    ]\n",
    "]\n",
    "full_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1ac6f98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_latex(df, label):\n",
    "    \"\"\"Render `df` as a LaTeX table and return it as a string.\n",
    "\n",
    "    Cells are not escaped, so LaTeX already in them (e.g. MEAN_STD_FORMAT math) passes\n",
    "    through. Caption and label are built from the notebook-level `fine_grained` and\n",
    "    `train_perc` values as they are when the function is called.\n",
    "    \"\"\"\n",
    "    granularity = \"fine\" if fine_grained else \"coarse\"\n",
    "    table_options = dict(\n",
    "        multirow=True,\n",
    "        sparsify=True,\n",
    "        multicolumn_format=\"c\",\n",
    "        escape=False,\n",
    "        caption=f\"Fine-grained: {fine_grained}, Train perc: {train_perc}\",\n",
    "        label=f\"tab:multilingual-{label}-{granularity}-grained\",\n",
    "    )\n",
    "    return df.to_latex(**table_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a63f782d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template for a LaTeX table cell: $<mean> \\pm <std>$ with two decimals each.\n",
    "MEAN_STD_FORMAT = r\"${:.2f} \\pm {:.2f}$\"\n",
    "# Show every row of the result tables instead of a truncated preview.\n",
    "pd.set_option(\"display.max_rows\", None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4db68dc0",
   "metadata": {},
   "source": [
    "# Supplementary material: all train/test language pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7d1a8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Latex\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "def format_mean_std_table(stats):\n",
    "    \"\"\"Collapse each (mean, std) column pair of `stats` into one MEAN_STD_FORMAT string column.\n",
    "\n",
    "    `stats` is expected to be `full_df` grouped by language pair and aggregated with\n",
    "    [\"mean\", \"std\"], so its columns are (embed, domain, metric, agg) tuples. For each\n",
    "    reported (embed, domain) group and metric, the two numeric columns are replaced by one\n",
    "    (embed, domain, FScore|MAE, \"\") column, appended in loop order. Returns a new\n",
    "    DataFrame; `stats` itself is not modified.\n",
    "    \"\"\"\n",
    "    table = stats.copy()\n",
    "    # full_df has no Absolute/Out Domain columns, so only these three groups are formatted.\n",
    "    for embed, domain in ((\"Absolute\", \"In Domain\"), (\"Relative\", \"In Domain\"), (\"Relative\", \"Out Domain\")):\n",
    "        for metric, new_name in ((\"fscore\", \"FScore\"), (\"mae\", \"MAE\")):\n",
    "            mean_col = (embed, domain, metric, \"mean\")\n",
    "            std_col = (embed, domain, metric, \"std\")\n",
    "            table[(embed, domain, new_name, \"\")] = [\n",
    "                MEAN_STD_FORMAT.format(mean, std) for mean, std in zip(table[mean_col], table[std_col])\n",
    "            ]\n",
    "            table = table.drop(columns=[mean_col, std_col])\n",
    "    return table\n",
    "\n",
    "\n",
    "# Mean and std over all rows (runs) of each train/test language pair. Passing \"std\" rather\n",
    "# than np.std keeps the groupby std (ddof=1) that np.std is currently mapped to, without the\n",
    "# FutureWarning recent pandas versions emit for numpy callables in `agg`.\n",
    "supmat_stats = (\n",
    "    full_df.groupby([(train_lang, \"\", \"\"), (test_lang, \"\", \"\")])\n",
    "    .agg([\"mean\", \"std\"])\n",
    "    .round(2)\n",
    ")\n",
    "print(to_latex(format_mean_std_table(supmat_stats), \"full\"))\n",
    "supmat_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbf6dee0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "85b0093d",
   "metadata": {},
   "source": [
    "# Main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52124655",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Latex\n",
    "from IPython.display import display\n",
    "\n",
    "# Same aggregation as the SupMat table, restricted to runs whose train language is \"en\".\n",
    "english_runs = full_df[full_df[(train_lang, \"\", \"\")] == \"en\"]\n",
    "main_stats = (\n",
    "    english_runs.groupby([(train_lang, \"\", \"\"), (test_lang, \"\", \"\")])\n",
    "    .agg([\"mean\", \"std\"])\n",
    "    .round(2)\n",
    ")\n",
    "print(to_latex(format_mean_std_table(main_stats), \"en\"))\n",
    "main_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4baf885",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb4ecff2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
