{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Macro metrics ===\n",
      "precision_macro: 0.0100\n",
      "recall_macro: 0.0033\n",
      "f1_macro: 0.0050\n",
      "exact_match_rate: 0.0000\n",
      "any_hit_rate: 0.0500\n",
      "avg_num_final_dx: 4.2500\n",
      "avg_confidence: 0.9089\n",
      "Wrote recllama_output_with_metrics.csv\n"
     ]
    }
   ],
   "source": [
    "import ast, json, re\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def norm_code(c: str) -> str:\n",
    "    c = str(c).strip()\n",
    "    c = c.replace(\"ICD-9:\", \"\").replace(\"ICD9:\", \"\").replace(\"ICD:\", \"\")\n",
    "    c = c.replace(\".\", \"\").replace(\" \", \"\")\n",
    "    m = re.search(r\"\\b(\\d{3,5})\\b\", c)\n",
    "    return m.group(1) if m else \"\"\n",
    "\n",
    "def parse_gold_list(s):\n",
    "    try:\n",
    "        vals = ast.literal_eval(str(s))\n",
    "        return sorted({norm_code(v) for v in vals if norm_code(v)})\n",
    "    except Exception:\n",
    "        return []\n",
    "\n",
    "def parse_pred_codes_conf_from_json(s):\n",
    "    \"\"\"\n",
    "    final_diagnoses_json -> (codes, confidences)\n",
    "    Accepts list of dicts: [{\"label\":\"75139\",\"confidence\":0.92}, ...]\n",
    "    \"\"\"\n",
    "    try:\n",
    "        arr = json.loads(str(s))\n",
    "        if isinstance(arr, list):\n",
    "            codes, confs = [], []\n",
    "            for d in arr:\n",
    "                if not isinstance(d, dict):\n",
    "                    continue\n",
    "                code = norm_code(d.get(\"label\", \"\"))\n",
    "                if code:\n",
    "                    codes.append(code)\n",
    "                    \n",
    "                    try:\n",
    "                        confs.append(float(d.get(\"confidence\", np.nan)))\n",
    "                    except Exception:\n",
    "                        confs.append(np.nan)\n",
    "            \n",
    "            confs = [c for c in confs if isinstance(c, (int, float)) and not np.isnan(c)]\n",
    "            return sorted(set(codes)), confs\n",
    "    except Exception:\n",
    "        pass\n",
    "    return [], []\n",
    "\n",
    "def parse_pred_from_text(s):\n",
    "    \"\"\"Regex from final_explanation free text (no confidences available).\"\"\"\n",
    "    text = str(s)\n",
    "    codes = set()\n",
    "\n",
    "    for m in re.finditer(r\"(?:ICD-?9|ICD)\\s*[:#]?\\s*([0-9]{3,5}(?:\\.[0-9A-Za-z]+)?)\",\n",
    "                         text, flags=re.IGNORECASE):\n",
    "        codes.add(norm_code(m.group(1)))\n",
    "\n",
    "    window_hits = re.findall(\n",
    "        r\"(?:diagnos\\w*|possible|likely|code|ICD-?9|ICD)\\D{0,20}([0-9]{3,5}(?:\\.[0-9A-Za-z]+)?)\",\n",
    "        text, flags=re.IGNORECASE\n",
    "    )\n",
    "    for w in window_hits:\n",
    "        codes.add(norm_code(w))\n",
    "\n",
    "    if not codes:\n",
    "        for m in re.finditer(r\"\\b([0-9]{3,5})(?:\\.[0-9A-Za-z]+)?\\b\", text):\n",
    "            codes.add(norm_code(m.group(1)))\n",
    "\n",
    "    codes.discard(\"\")\n",
    "    return sorted(codes)\n",
    "\n",
    "def prf(pred, gold):\n",
    "    pset, gset = set(pred), set(gold)\n",
    "    if not pset and not gset:\n",
    "        return (1.0, 1.0, 1.0, True, False)\n",
    "    if not pset:\n",
    "        return (0.0, 0.0, 0.0, False, False)\n",
    "    tp = len(pset & gset)\n",
    "    prec = tp / len(pset)\n",
    "    rec  = tp / len(gset) if gset else 1.0\n",
    "    f1   = 0.0 if (prec+rec)==0 else 2*prec*rec/(prec+rec)\n",
    "    exact = (pset == gset)\n",
    "    any_hit = tp > 0\n",
    "    return (prec, rec, f1, exact, any_hit)\n",
    "\n",
    "\n",
    "df = pd.read_csv(\"questions_CE_RF_RSN.csv\")  \n",
    "\n",
    "# gold labels\n",
    "df[\"gold_icd9\"] = df[\"diagnoses_icd9\"].apply(parse_gold_list)\n",
    "\n",
    "\n",
    "codes_confs = df.get(\"final_diagnoses_json\", pd.Series([\"\"]*len(df))).apply(parse_pred_codes_conf_from_json)\n",
    "df[\"pred_from_json\"] = [cc[0] for cc in codes_confs]\n",
    "df[\"conf_from_json\"] = [cc[1] for cc in codes_confs]\n",
    "\n",
    "\n",
    "df[\"pred_from_text\"] = df[\"final_explanation\"].apply(parse_pred_from_text)\n",
    "\n",
    "\n",
    "def choose_pred(row):\n",
    "    if row[\"pred_from_json\"]:\n",
    "        return row[\"pred_from_json\"], row[\"conf_from_json\"]\n",
    "    return row[\"pred_from_text\"], []  \n",
    "\n",
    "chosen = df.apply(choose_pred, axis=1)\n",
    "df[\"pred_icd9\"] = [c[0] for c in chosen]\n",
    "df[\"pred_conf_list\"] = [c[1] for c in chosen]\n",
    "\n",
    "\n",
    "df[\"final_dx_count\"] = df[\"pred_icd9\"].apply(len)\n",
    "def avg_conf(xs):\n",
    "    xs = [x for x in xs if isinstance(x,(int,float)) and not np.isnan(x)]\n",
    "    return float(np.mean(xs)) if xs else np.nan\n",
    "df[\"final_dx_avg_conf\"] = df[\"pred_conf_list\"].apply(avg_conf)\n",
    "\n",
    "\n",
    "metrics = df.apply(lambda r: prf(r[\"pred_icd9\"], r[\"gold_icd9\"]), axis=1)\n",
    "df[\"prec\"], df[\"recall\"], df[\"f1\"], df[\"exact_match\"], df[\"any_hit\"] = zip(*metrics)\n",
    "\n",
    "\n",
    "macro = {\n",
    "    \"precision_macro\":   float(np.nanmean(df[\"prec\"])),\n",
    "    \"recall_macro\":      float(np.nanmean(df[\"recall\"])),\n",
    "    \"f1_macro\":          float(np.nanmean(df[\"f1\"])),\n",
    "    \"exact_match_rate\":  float(np.nanmean(df[\"exact_match\"])),\n",
    "    \"any_hit_rate\":      float(np.nanmean(df[\"any_hit\"])),\n",
    "    \"avg_num_final_dx\":  float(np.nanmean(df[\"final_dx_count\"])),\n",
    "    \"avg_confidence\":    float(np.nanmean(df[\"final_dx_avg_conf\"]))  # averaged over rows with confidences\n",
    "}\n",
    "\n",
    "print(\"=== Macro metrics ===\")\n",
    "for k, v in macro.items():\n",
    "    print(f\"{k}: {v:.4f}\")\n",
    "\n",
    "\n",
    "df.to_csv(\"recllama_output_with_metrics.csv\", index=False, encoding=\"utf-8-sig\")\n",
    "print(\"Wrote recllama_output_with_metrics.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (oregano)",
   "language": "python",
   "name": "oregano"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.22"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
