{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cc47bbc7",
   "metadata": {},
   "source": [
    "### Input data\n",
    "\n",
    "1. The predicted probabilities file `best_embs_full.npy` can be obtained after running `train.py`.\n",
    "2. `processed_data_test.csv` is from `prepare_data.ipynb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ad2e3e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Predicted probabilities saved by the training run, then the processed test table.\n",
    "test_embs = np.load('best_embs_full.npy')\n",
    "test_interaction_df = (\n",
    "    pd.read_csv('processed_data_test.csv', index_col=0)\n",
    "    .assign(prob=test_embs)  # attach the predictions as a `prob` column\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b984bda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the test table, including its `prob` column, to disk.\n",
    "test_interaction_df.to_csv(path_or_buf='covid_results.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e68f1c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the columns used by the evaluation cells out as plain Python lists.\n",
    "targets, labels, names, origins = (\n",
    "    test_interaction_df[column].tolist()\n",
    "    for column in ('target', 'groups', 'names', 'origins')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bcca769",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_curve, auc\n",
    "\n",
    "def calculate_auprc_per_label(y_true, y_pred, label_types):\n",
    "    # Create a dictionary to store true labels and predictions for each label type\n",
    "    label_data = {}\n",
    "    \n",
    "    for label, true_val, pred_val in zip(label_types, y_true, y_pred):\n",
    "        if label not in label_data:\n",
    "            label_data[label] = {'y_true': [], 'y_pred': []}\n",
    "        \n",
    "        label_data[label]['y_true'].append(true_val)\n",
    "        label_data[label]['y_pred'].append(pred_val)\n",
    "    \n",
    "    # Calculate AUPRC for each label type\n",
    "    auprc_per_label = {}\n",
    "    \n",
    "    for label, data in label_data.items():\n",
    "        if len(data['y_true']) > 3:\n",
    "            precision, recall, _ = precision_recall_curve(data['y_true'], data['y_pred'])\n",
    "            auprc = auc(recall, precision)\n",
    "            auprc_per_label[label] = (auprc, np.unique(data['y_true'], return_counts=True)[1])\n",
    "    return auprc_per_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d5a9c1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUPRC broken down by the values of the `groups` column.\n",
    "calculate_auprc_per_label(y_true=targets, y_pred=test_embs, label_types=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9f2e85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUPRC broken down by the values of the `origins` column.\n",
    "calculate_auprc_per_label(y_true=targets, y_pred=test_embs, label_types=origins)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
