{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import polars as pl\n",
    "\n",
    "ROOT_DIR = \"/storage/shared/mimic-iv/meds_v0.3.2/\"  # Replace with your actual root directory\n",
    "CODES_METADATA_PATH = Path(ROOT_DIR) / \"meds\" / \"metadata\" / \"codes.parquet\"\n",
    "\n",
    "pl.Config.set_fmt_str_lengths(100)\n",
    "codes_metadata = pl.read_parquet(CODES_METADATA_PATH)\n",
    "\n",
    "# Description substrings (matched literally) that identify each lab in the MEDS code metadata.\n",
    "# A code belongs to a lab if its `description` contains ANY of the listed strings.\n",
    "# NOTE: \"platets\" is a misspelling of \"platelets\", kept because it becomes part of the\n",
    "# generated task names and output paths.\n",
    "LAB_DESCRIPTIONS = {\n",
    "    \"creatinine\": [\n",
    "        \"Creatinine [Mass/volume] in Blood\",\n",
    "        \"Creatinine [Mass/volume] in Serum or Plasma\",\n",
    "    ],\n",
    "    \"hemoglobin\": [\n",
    "        \"Hemoglobin [Mass/volume] in Blood by calculation\",\n",
    "        \"Hemoglobin [Mass/volume] in Blood\",\n",
    "    ],\n",
    "    \"hematocrit\": [\n",
    "        \"Hematocrit [Volume Fraction] of Blood by Automated count\",\n",
    "        \"Hematocrit [Volume Fraction] of Blood by Estimated\",\n",
    "    ],\n",
    "    \"leukocytes\": [\"Leukocytes [#/volume] in Blood by Automated count\"],\n",
    "    \"platets\": [\"Platelets [#/volume] in Blood by Automated count\"],\n",
    "}\n",
    "\n",
    "\n",
    "def find_lab_codes(metadata: pl.DataFrame, descriptions: list[str]) -> str:\n",
    "    \"\"\"Build a regex alternation of every code whose description contains any of `descriptions`.\n",
    "\n",
    "    Args:\n",
    "        metadata: MEDS code metadata with `code` and `description` columns.\n",
    "        descriptions: Substrings searched for literally (not as regex) in `description`.\n",
    "\n",
    "    Returns:\n",
    "        The matching codes joined with \"|\", used as an ACES `code: {regex: ...}` pattern.\n",
    "        NOTE(review): codes are not regex-escaped, so a code containing regex metacharacters\n",
    "        would be interpreted as a pattern rather than matched verbatim.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no code matches. Joining an empty list yields the regex \"\", which\n",
    "            matches every code and would silently turn the lab predicate into \"any event\".\n",
    "    \"\"\"\n",
    "    matches_any = pl.any_horizontal(\n",
    "        [pl.col(\"description\").str.contains(text, literal=True) for text in descriptions]\n",
    "    )\n",
    "    codes = metadata.filter(matches_any)[\"code\"].to_list()\n",
    "    if not codes:\n",
    "        raise ValueError(f\"No codes found whose description contains any of {descriptions}\")\n",
    "    return \"|\".join(codes)\n",
    "\n",
    "\n",
    "# Regex alternation of codes per lab, consumed by `get_aces_config`.\n",
    "lab_to_codes = {\n",
    "    lab: find_lab_codes(codes_metadata, descriptions) for lab, descriptions in LAB_DESCRIPTIONS.items()\n",
    "}\n",
    "\n",
    "\n",
    "def get_aces_config(location, lab, time_interval, extrema):\n",
    "    \"\"\"Render the ACES task config (YAML text) for an abnormal `lab` within `time_interval` after `location`.\n",
    "\n",
    "    Args:\n",
    "        location: Event-code prefix used as the trigger (matched with the regex \"<location>//.*\").\n",
    "        lab: Key into the module-level `lab_to_codes` dict of `|`-joined code regexes.\n",
    "        time_interval: Length of the target window after the trigger, e.g. \"30d\".\n",
    "        extrema: `(min_val, max_val)` with exactly one element set:\n",
    "            - `min_val`: the lab is abnormal when its value is <= min_val.\n",
    "            - `max_val`: the lab is abnormal when its value is >= max_val.\n",
    "\n",
    "    Returns:\n",
    "        The YAML config as a string. Its target window must contain at least one `lab` event\n",
    "        and is labeled by the `abnormal_lab` predicate.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If `lab` is not in `lab_to_codes`.\n",
    "        ValueError: If `lab` maps to no codes, or if both or neither of min/max are set.\n",
    "    \"\"\"\n",
    "    lab_codes = lab_to_codes[lab]\n",
    "    if not lab_codes:\n",
    "        # An empty alternation is the regex \"\", which would match every code.\n",
    "        raise ValueError(f\"No codes registered for lab {lab!r}\")\n",
    "    min_val, max_val = extrema\n",
    "    if min_val is None and max_val is None:\n",
    "        raise ValueError(\"Defined neither min nor max\")\n",
    "    if min_val is not None and max_val is not None:\n",
    "        raise ValueError(\"Can't define both min and max\")\n",
    "    if min_val is not None:\n",
    "        # YES MIN SHOULD BE NAMED MAX!!! The abnormal range is every value at or below the min cutoff,\n",
    "        # so the cutoff is the *upper* bound (`value_max`) ACES matches the predicate against.\n",
    "        extrema_type = \"max\"\n",
    "        value = min_val\n",
    "    else:\n",
    "        # Symmetrically, a max cutoff is the *lower* bound (`value_min`) of the abnormal range.\n",
    "        extrema_type = \"min\"\n",
    "        value = max_val\n",
    "    lab_requirement = f\"\"\"  abnormal_lab:\n",
    "    code: {{regex: \"{lab_codes}\"}}\n",
    "    value_{extrema_type}: {value}\n",
    "    value_{extrema_type}_inclusive: True\n",
    "    \"\"\"\n",
    "    return f\"\"\"#This config checks for an abnormal {lab} lab within {time_interval} after {location}\n",
    "predicates:\n",
    "  trigger_event:\n",
    "    code: {{regex: \"{location}//.*\"}}\n",
    "  lab:\n",
    "    code: {{regex: \"{lab_codes}\"}}\n",
    "{lab_requirement}\n",
    "\n",
    "trigger: trigger_event\n",
    "\n",
    "windows:\n",
    "  input:\n",
    "    start: NULL\n",
    "    end: trigger\n",
    "    start_inclusive: True\n",
    "    end_inclusive: True\n",
    "    index_timestamp: end\n",
    "  target:\n",
    "    start: input.end\n",
    "    end: start + {time_interval}\n",
    "    start_inclusive: True\n",
    "    end_inclusive: True\n",
    "    has:\n",
    "      lab: (1, None)\n",
    "    label: abnormal_lab\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# Abnormality cutoff per lab as (min_val, max_val); exactly one element must be set.\n",
    "# min_val flags values <= min_val, max_val flags values >= max_val (see get_aces_config).\n",
    "# NOTE(review): hemoglobin flags values >= 2.0, which looks copied from creatinine; if values are in\n",
    "# g/dL this marks virtually every measurement abnormal -- confirm the intended cutoff.\n",
    "lab_extrema = {\n",
    "    \"creatinine\": (None, 2.0),\n",
    "    \"hemoglobin\": (None, 2.0),\n",
    "    \"hematocrit\": (24, None),\n",
    "    \"leukocytes\": (5, None),\n",
    "    \"platets\": (20, None),\n",
    "}\n",
    "LOCATIONS = [\"HOSPITAL_ADMISSION\", \"HOSPITAL_DISCHARGE\", \"ICU_ADMISSION\", \"ICU_DISCHARGE\"]\n",
    "TIME_INTERVALS = [\"30d\", \"60d\", \"90d\"]\n",
    "TASK_CONFIG_DIR = Path(\"../ZERO_SHOT_TUTORIAL/configs/tasks/eic\")\n",
    "\n",
    "# Write one ACES config per (location, lab, interval) and print the quoted task names.\n",
    "tasks = []\n",
    "for location in LOCATIONS:\n",
    "    # Iterating the cutoff table keeps the lab list and the cutoffs from drifting apart.\n",
    "    for lab, extrema in lab_extrema.items():\n",
    "        for time_interval in TIME_INTERVALS:\n",
    "            config = get_aces_config(location, lab, time_interval, extrema)\n",
    "            task_name = f\"abnormal_lab/{location.lower()}/{lab}/{time_interval}\"\n",
    "            fp = TASK_CONFIG_DIR / f\"{task_name}.yaml\"\n",
    "            fp.parent.mkdir(parents=True, exist_ok=True)\n",
    "            fp.write_text(config)\n",
    "            tasks.append(f'\"{task_name}\"')\n",
    "print(\"\\n\".join(tasks))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "meds-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
