{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "EXPERIMENT_ROOT = Path(\".\").parent\n",
    "PERFORMANCE_TSV = EXPERIMENT_ROOT / \"performance.tsv\"\n",
    "PREDICTIONS_TSV = EXPERIMENT_ROOT / \"predictions.tsv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "predictions_df_orig = pd.read_csv(PREDICTIONS_TSV, sep=\"\\t\", index_col=0)\n",
    "predictions_df_orig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchmetrics import Accuracy, F1Score, MetricCollection, Recall, Precision\n",
    "\n",
    "CONSIDERED_METRICS = {\n",
    "    \"acc/macro\": lambda num_classes: Accuracy(average=\"micro\", num_classes=num_classes),\n",
    "    \"acc/micro\": lambda num_classes: Accuracy(average=\"macro\", num_classes=num_classes),\n",
    "    \"acc/weighted\": lambda num_classes: Accuracy(average=\"weighted\", num_classes=num_classes),\n",
    "    \"f1/macro\": lambda num_classes: F1Score(average=\"macro\", num_classes=num_classes),\n",
    "    \"f1/micro\": lambda num_classes: F1Score(average=\"micro\", num_classes=num_classes),\n",
    "    \"f1/weighted\": lambda num_classes: F1Score(average=\"weighted\", num_classes=num_classes),\n",
    "    \"recall/macro\": lambda num_classes: Recall(average=\"macro\", num_classes=num_classes),\n",
    "    \"recall/micro\": lambda num_classes: Recall(average=\"micro\", num_classes=num_classes),\n",
    "    \"recall/weighted\": lambda num_classes: Recall(average=\"weighted\", num_classes=num_classes),\n",
    "    \"precision/macro\": lambda num_classes: Precision(average=\"macro\", num_classes=num_classes),\n",
    "    \"precision/micro\": lambda num_classes: Precision(average=\"micro\", num_classes=num_classes),\n",
    "    \"precision/weighted\": lambda num_classes: Precision(average=\"weighted\", num_classes=num_classes),\n",
    "}\n",
    "\n",
    "PERFORMANCE_TSV.unlink(missing_ok=True)\n",
    "\n",
    "DATASET_NUM_CLASSES = {\"trec-coarse\": 6, \"trec-fine\": 24, \"trec\": 6}\n",
    "performance = {\n",
    "    **{\n",
    "        x: []\n",
    "        for x in (\n",
    "            \"run_id_a\",\n",
    "            \"run_id_b\",\n",
    "            \"model_type\",\n",
    "            \"dataset_name\",\n",
    "            \"embedder_a\",\n",
    "            \"embedder_b\",\n",
    "        )\n",
    "    },\n",
    "    **{k: [] for k in CONSIDERED_METRICS.keys()},\n",
    "}\n",
    "\n",
    "KEYS = [\"stitching\", \"run_id_a\", \"run_id_b\", \"model_type\", \"dataset_name\", \"embedder_a\", \"embedder_b\"]\n",
    "predictions_df = predictions_df_orig.groupby(KEYS)\n",
    "for (values, aggregate_df) in predictions_df:\n",
    "    key2value = dict(zip(KEYS, values))\n",
    "    aggregate_df: pd.DataFrame\n",
    "\n",
    "    metrics = MetricCollection(\n",
    "        {\n",
    "            key: metric(num_classes=DATASET_NUM_CLASSES[key2value[\"dataset_name\"]])\n",
    "            for key, metric in CONSIDERED_METRICS.items()\n",
    "        }\n",
    "    )\n",
    "    run_predictions = torch.as_tensor(aggregate_df[\"pred\"].values)\n",
    "    run_targets = torch.as_tensor(aggregate_df[\"target\"].values)\n",
    "\n",
    "    metrics.update(run_predictions, run_targets)\n",
    "\n",
    "    for key, value in key2value.items():\n",
    "        if key in performance:\n",
    "            performance[key].append(value)\n",
    "\n",
    "    for metric_name, metric_value in metrics.compute().items():\n",
    "        performance[metric_name].append(metric_value.item())\n",
    "performance_df = pd.DataFrame(performance)\n",
    "performance_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
