{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------------------------------------------------\n",
    "# 0.  CONFIG  – change only these two paths\n",
    "# ---------------------------------------------------------------------\n",
    "PRED_CSV      = Path(\"PedXBench/data/outputs/llm_runs/llm_2stage_fc_predictions_o1.csv\")\n",
    "MANUAL_XLSX   = Path(\"PedXBench/data/outputs/manually_annotated_labels_100_final.xlsx\")\n",
    "\n",
    "\n",
    "PRED_LABELCOL = \"resolved_label\"      # in the LLM file\n",
    "MAN_LABELCOL  = \"resovled_label_A\"    # in the manual file\n",
    "# ---------------------------------------------------------------------\n",
    "\n",
    "import re, pandas as pd\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, classification_report,\n",
    "    accuracy_score, f1_score\n",
    ")\n",
    "\n",
    "########################################################################\n",
    "# 1.  helper – extract a canon-id like  \"NDA_21505\"\n",
    "########################################################################\n",
    "ID_RE   = re.compile(r\"\\b(?P<prefix>NDA|ANDA|BLA)\\s*[-_/]?\\s*(?P<num>\\d{5,7})\",\n",
    "                     re.I)\n",
    "\n",
    "def extract_canon(cell: str) -> str | None:\n",
    "    if not isinstance(cell, str):\n",
    "        return None\n",
    "    m = ID_RE.search(cell)\n",
    "    if m:\n",
    "        return f\"{m.group('prefix').upper()}_{int(m.group('num')):05d}\"\n",
    "    return None                   # nothing found → row will be dropped later\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Merged rows: 200  /  gold rows: 128\n",
      "\n",
      "=== Confusion matrix ===\n",
      "                pred_None  pred_Partial  pred_Full  pred_Unlabeled\n",
      "gold_None             106             5          0              14\n",
      "gold_Partial           20            34          0               3\n",
      "gold_Full               0             0          1               1\n",
      "gold_Unlabeled          9             0          0               7\n",
      "\n",
      "=== Classification report ===\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "        None      0.785     0.848     0.815       125\n",
      "     Partial      0.872     0.596     0.708        57\n",
      "        Full      1.000     0.500     0.667         2\n",
      "   Unlabeled      0.280     0.438     0.341        16\n",
      "\n",
      "    accuracy                          0.740       200\n",
      "   macro avg      0.734     0.595     0.633       200\n",
      "weighted avg      0.772     0.740     0.745       200\n",
      "\n",
      "Accuracy : 0.740\n",
      "Macro-F1 : 0.633\n"
     ]
    }
   ],
   "source": [
    "import re, json\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix,\n",
    "                             classification_report)\n",
    "\n",
    "# ───────────────────────────── paths ─────────────────────────────\n",
    "# Input files for the evaluation: the LLM predictions (CSV) and the manually\n",
    "# annotated gold labels (XLSX).\n",
    "# NOTE(review): both paths are anchored at the filesystem root (leading \"/\");\n",
    "# confirm this is intended rather than a path relative to the project.\n",
    "PRED_CSV      = Path(\"/PedXBench/data/processed/llm_2stage_fc_predictions_full.csv\") # LLM predictions\n",
    "MANUAL_XLSX   = Path(\"/PedXBench/data/outputs/manually_annotated_labels_100_final.xlsx\")\n",
    "\n",
    "\n",
    "PRED_LABEL_COL   = \"resolved_label\"      # column in CSV with the LLM label\n",
    "# NOTE(review): \"resovled\" presumably mirrors the spreadsheet header – confirm.\n",
    "MANUAL_LABEL_COL = \"resovled_label_A\"    # column in XLSX with the gold label\n",
    "\n",
    "ID_PAT = re.compile(r\"\\d{5,7}\")          # 5–7 digit NDA/BLA/ANDA number\n",
    "\n",
    "def canon(cell: str | float) -> str | None:\n",
    "    \"\"\"Return the first 5-7 digit number inside *cell*, else None.\"\"\"\n",
    "    if pd.isna(cell):\n",
    "        return None\n",
    "    m = ID_PAT.findall(str(cell))\n",
    "    return m[0].lstrip(\"0\") if m else None    # strip leading zeros for safety\n",
    "\n",
    "LABEL_MAP = {\n",
    "    \"notextrapolated\": \"None\",\n",
    "    \"none\":            \"None\",\n",
    "    \"partial\":         \"Partial\",\n",
    "    \"full\":            \"Full\",\n",
    "    \"unlabeled\":       \"Unlabeled\",\n",
    "    \"unlabelled\":      \"Unlabeled\",\n",
    "}\n",
    "\n",
    "def norm(lbl):\n",
    "    if not isinstance(lbl, str):\n",
    "        return np.nan\n",
    "    return LABEL_MAP.get(lbl.strip().lower(), lbl)\n",
    "\n",
    "# ─────────────────────── load & clean ────────────────────────────\n",
    "# Read every column as text and keep blank cells as \"\" instead of NaN, so\n",
    "# application numbers are handled as strings rather than parsed as numbers.\n",
    "read_opts = dict(dtype=str, keep_default_na=False)\n",
    "\n",
    "pred_raw = pd.read_csv(PRED_CSV, **read_opts)\n",
    "# Prefer an existing canon_id column, otherwise fall back to app_id; fail\n",
    "# with a clear message when the file has neither.\n",
    "pred_id_col = next((c for c in (\"canon_id\", \"app_id\") if c in pred_raw.columns),\n",
    "                   None)\n",
    "if pred_id_col is None:\n",
    "    raise KeyError(f\"{PRED_CSV} has neither a 'canon_id' nor an 'app_id' column; \"\n",
    "                   f\"found: {list(pred_raw.columns)}\")\n",
    "\n",
    "# Build the join key, normalise the label, drop rows missing either one.\n",
    "pred = (pred_raw\n",
    "          .assign(canon_id=lambda d: d[pred_id_col].map(canon))\n",
    "          .rename(columns={PRED_LABEL_COL: \"pred_label\"})\n",
    "          .assign(pred_label=lambda d: d[\"pred_label\"].map(norm))\n",
    "          .dropna(subset=[\"canon_id\", \"pred_label\"]))\n",
    "\n",
    "# Same treatment for the gold sheet (the ID header ends with a space).\n",
    "manual = (pd.read_excel(MANUAL_XLSX, engine=\"openpyxl\", **read_opts)\n",
    "            .assign(canon_id=lambda d:\n",
    "                    d[\"FDA Application Number(s) \"].map(canon))\n",
    "            .rename(columns={MANUAL_LABEL_COL: \"gold_label\"})\n",
    "            .assign(gold_label=lambda d: d[\"gold_label\"].map(norm))\n",
    "            .dropna(subset=[\"canon_id\", \"gold_label\"]))\n",
    "\n",
    "# ───────────────────── merge & evaluate ──────────────────────────\n",
    "# The inner join repeats a row once per matching row on the other side, so\n",
    "# a canon_id that occurs several times is counted several times in every\n",
    "# metric below.  Report such duplicates instead of letting them pass\n",
    "# silently; the join itself is left unchanged.\n",
    "dup_gold = int(manual[\"canon_id\"].duplicated().sum())\n",
    "dup_pred = int(pred[\"canon_id\"].duplicated().sum())\n",
    "if dup_gold or dup_pred:\n",
    "    print(f\"Duplicate canon_id rows – gold: {dup_gold}, pred: {dup_pred} \"\n",
    "          \"(each repeated ID multiplies rows in the merge)\")\n",
    "\n",
    "merged = pd.merge(manual, pred, on=\"canon_id\", how=\"inner\")\n",
    "print(f\"Merged rows: {len(merged)}  /  gold rows: {len(manual)}\")\n",
    "\n",
    "if merged.empty:\n",
    "    raise SystemExit(\"❌ No overlapping IDs – check the canon-ID extraction.\")\n",
    "\n",
    "y_true, y_pred = merged[\"gold_label\"], merged[\"pred_label\"]\n",
    "# Fixed class order shared by the matrix, the report and the macro-F1.\n",
    "order = [\"None\", \"Partial\", \"Full\", \"Unlabeled\"]\n",
    "\n",
    "print(\"\\n=== Confusion matrix ===\")\n",
    "print(pd.DataFrame(confusion_matrix(y_true, y_pred, labels=order),\n",
    "                   index=[f\"gold_{l}\" for l in order],\n",
    "                   columns=[f\"pred_{l}\" for l in order]))\n",
    "\n",
    "print(\"\\n=== Classification report ===\")\n",
    "print(classification_report(y_true, y_pred, labels=order,\n",
    "                            digits=3, zero_division=0))\n",
    "\n",
    "acc  = accuracy_score(y_true, y_pred)\n",
    "# labels=order keeps the macro average over the same four classes as the\n",
    "# report above, even if an unexpected label string slips through norm().\n",
    "f1   = f1_score(y_true, y_pred, labels=order, average=\"macro\",\n",
    "                zero_division=0)\n",
    "print(f\"Accuracy : {acc:.3f}\")\n",
    "print(f\"Macro-F1 : {f1:.3f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peds-agent-venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
