{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac04003-4112-4af3-833b-c714159a48de",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Begin with some analysis\n",
    "import pandas as pd\n",
    "\n",
    "data_file = \"./ehr_data/diagnosis.csv\"\n",
    "\n",
    "print(\"Loading diagnosis data...\")\n",
    "df_diag = pd.read_csv(data_file, dtype=str)\n",
    "\n",
    "print(f\"Loaded {len(df_diag)} rows.\")\n",
    "print(df_diag.columns)\n",
    "\n",
    "print(\"\\nTop 20 most common diagnosis names:\")\n",
    "print(df_diag['DX_NAME'].value_counts().head(20))\n",
    "\n",
    "print(\"\\nTop 20 most common diagnosis codes:\")\n",
    "print(df_diag['DX_CODE'].value_counts().head(20))\n",
    "\n",
    "keywords = {\n",
    "    'sleep_apnea': ['sleep apnea', 'obstructive sleep apnea'],\n",
    "    'asthma': ['asthma'],\n",
    "    'obesity': ['obesity'],\n",
    "    'diabetes': ['diabetes'],\n",
    "    'hypertension': ['hypertension'],\n",
    "    'depression': ['depression', 'mood disorder'],\n",
    "    'anxiety': ['anxiety'],\n",
    "    'adhd': ['adhd'],\n",
    "}\n",
    "\n",
    "def match_any(text, terms):\n",
    "    text = str(text).lower()\n",
    "    return any(t in text for t in terms)\n",
    "\n",
    "for label, terms in keywords.items():\n",
    "    df_diag[label] = df_diag['DX_NAME'].apply(lambda x: int(match_any(x, terms)))\n",
    "\n",
    "print(\"\\nComorbidity counts:\")\n",
    "print(df_diag[list(keywords.keys())].sum().sort_values(ascending=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c94587-5532-4cf4-a293-ec209cdf82b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "demo_path = \"./ehr_data/demographic.csv\"\n",
    "df = pd.read_csv(demo_path, dtype=str)\n",
    "\n",
    "df['RACE_DESCR'] = df['RACE_DESCR'].fillna('UNKNOWN').str.strip().str.upper()\n",
    "race_counts = df['RACE_DESCR'].value_counts(dropna=False)\n",
    "\n",
    "print(\"Unique race descriptions:\\n\")\n",
    "print(race_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5617d0e3-b976-4d15-b4cb-45ac3574ad49",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "\n",
    "data_dir = \"./ehr_data\"\n",
    "out_dir = \"./output_embeddings\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "keywords = {\n",
    "    'asthma': ['asthma'],\n",
    "    'obesity': ['obesity'],\n",
    "    'diabetes': ['diabetes'],\n",
    "    'hypertension': ['hypertension'],\n",
    "    'anxiety': ['anxiety'],\n",
    "    'depression': ['depression', 'mood disorder'],\n",
    "    'adhd': ['adhd'],\n",
    "    'seizure': ['seizure', 'epilepsy'],\n",
    "    'gerd': ['reflux', 'gerd', 'gastroesophageal'],\n",
    "    'cerebral_palsy': ['cerebral palsy'],\n",
    "    'autism': ['autism'],\n",
    "    'dev_delay': ['developmental delay', 'speech delay', 'language delay'],\n",
    "}\n",
    "keyword_names = list(keywords.keys())\n",
    "\n",
    "print(\"Loading files...\")\n",
    "diag = pd.read_csv(os.path.join(data_dir, \"diagnosis.csv\"), dtype=str)\n",
    "demo = pd.read_csv(os.path.join(data_dir, \"demographic.csv\"), dtype=str)\n",
    "study = pd.read_csv(os.path.join(data_dir, \"sleep_study.csv\"), dtype=str)\n",
    "study[\"AGE_AT_SLEEP_STUDY_DAYS\"] = pd.to_numeric(study[\"AGE_AT_SLEEP_STUDY_DAYS\"], errors='coerce')\n",
    "demo = demo.set_index(\"STUDY_PAT_ID\")\n",
    "\n",
    "print(\"Checking existing embeddings...\")\n",
    "existing_pairs = set()\n",
    "for file in glob(os.path.join(out_dir, \"*_embeddings.npy\")):\n",
    "    base = os.path.basename(file).replace(\"_embeddings.npy\", \"\")\n",
    "    existing_pairs.add(base)\n",
    "print(f\"Found {len(existing_pairs)} embedding pairs.\")\n",
    "\n",
    "print(\"Extracting comorbidity flags...\")\n",
    "diag['DX_NAME'] = diag['DX_NAME'].fillna(\"\").str.lower()\n",
    "patient_flags = defaultdict(lambda: [0] * len(keyword_names))\n",
    "for _, row in diag.iterrows():\n",
    "    pid = row['STUDY_PAT_ID']\n",
    "    dx_name = row['DX_NAME']\n",
    "    for j, name in enumerate(keyword_names):\n",
    "        if any(k in dx_name for k in keywords[name]):\n",
    "            patient_flags[pid][j] = 1\n",
    "\n",
    "gender_map = {\"F\": 0, \"M\": 1}\n",
    "race_list = [\"WHITE\", \"BLACK\", \"ASIAN\", \"OTHER\", \"MULTIRACIAL\", \"UNKNOWN\"]\n",
    "race_map = {\n",
    "    \"WHITE\": \"WHITE\",\n",
    "    \"BLACK OR AFRICAN AMERICAN\": \"BLACK\",\n",
    "    \"ASIAN\": \"ASIAN\",\n",
    "    \"MULTIPLE RACE\": \"MULTIRACIAL\",\n",
    "    \"REFUSE TO ANSWER\": \"UNKNOWN\",\n",
    "    \"UNKNOWN\": \"UNKNOWN\",\n",
    "    \"NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER\": \"OTHER\",\n",
    "    \"AMERICAN INDIAN OR ALASKA NATIVE\": \"OTHER\",\n",
    "}\n",
    "\n",
    "def build_ehr_vector(pid, sid):\n",
    "    eid = f\"{pid}_{sid}\"\n",
    "    if eid not in existing_pairs:\n",
    "        return None\n",
    "    row = study[(study[\"STUDY_PAT_ID\"] == pid) & (study[\"SLEEP_STUDY_ID\"] == sid)]\n",
    "    if row.empty:\n",
    "        return None\n",
    "    age = row[\"AGE_AT_SLEEP_STUDY_DAYS\"].values[0]\n",
    "    age = float(age) if not pd.isnull(age) else np.nan\n",
    "\n",
    "    if pid in demo.index:\n",
    "        gender = demo.loc[pid, \"PCORI_GENDER_CD\"].upper()\n",
    "        gender_onehot = [0, 0, 0]\n",
    "        gender_onehot[gender_map.get(gender, 2)] = 1\n",
    "        race_raw = str(demo.loc[pid, \"RACE_DESCR\"]).strip().upper()\n",
    "        race_clean = race_map.get(race_raw, \"UNKNOWN\")\n",
    "        race_onehot = [int(race_clean == r) for r in race_list]\n",
    "        hispanic = int(str(demo.loc[pid, \"PCORI_HISPANIC_CD\"]).upper() == \"Y\")\n",
    "    else:\n",
    "        gender_onehot = [0, 0, 1]\n",
    "        race_onehot = [0] * len(race_list)\n",
    "        hispanic = 0\n",
    "\n",
    "    comorbs = patient_flags.get(pid, [0] * len(keyword_names))\n",
    "    return np.array([age] + gender_onehot + race_onehot + [hispanic] + comorbs, dtype=np.float32)\n",
    "\n",
    "pairs = study[[\"STUDY_PAT_ID\", \"SLEEP_STUDY_ID\"]].dropna().drop_duplicates()\n",
    "count = 0\n",
    "\n",
    "print(\"Building and saving EHR features...\")\n",
    "for _, (pid, sid) in pairs.iterrows():\n",
    "    eid = f\"{pid}_{sid}\"\n",
    "    vec = build_ehr_vector(pid, sid)\n",
    "    if vec is not None:\n",
    "        outpath = os.path.join(out_dir, f\"{eid}_ehr_feature.npy\")\n",
    "        np.save(outpath, vec)\n",
    "        count += 1\n",
    "\n",
    "print(f\"Saved {count} EHR feature files.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ba3d74c-8161-448f-8e73-45c4b27967d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "ehr_files = sorted(glob(os.path.join(feature_dir, \"*_ehr_feature.npy\")))\n",
    "\n",
    "gender_labels = [\"Female\", \"Male\", \"Other\"]\n",
    "race_labels = [\"White\", \"Black\", \"Asian\", \"Other\", \"Multiracial\", \"Unknown\"]\n",
    "comorbidity_labels = [\n",
    "    'asthma', 'obesity', 'diabetes', 'hypertension', 'anxiety',\n",
    "    'depression', 'adhd', 'seizure', 'gerd', 'cerebral_palsy', 'autism', 'dev_delay'\n",
    "]\n",
    "\n",
    "n_features = 1 + len(gender_labels) + len(race_labels) + 1 + len(comorbidity_labels)\n",
    "\n",
    "print(f\"Loading {len(ehr_files)} vectors...\")\n",
    "data, ids = [], []\n",
    "\n",
    "for f in ehr_files:\n",
    "    vec = np.load(f)\n",
    "    if vec.shape[0] == n_features:\n",
    "        data.append(vec)\n",
    "        ids.append(os.path.basename(f).replace(\"_ehr_feature.npy\", \"\"))\n",
    "    else:\n",
    "        print(f\"Skipped {f}, wrong shape {vec.shape}\")\n",
    "\n",
    "df = pd.DataFrame(\n",
    "    data, index=ids,\n",
    "    columns=[\"age_days\"] + gender_labels + race_labels + [\"hispanic\"] + comorbidity_labels\n",
    ")\n",
    "\n",
    "print(\"Dataframe shape:\", df.shape)\n",
    "\n",
    "print(\"\\nSummary Statistics:\")\n",
    "print(df.describe(include='all'))\n",
    "\n",
    "plt.figure()\n",
    "df['age_days'].hist(bins=30)\n",
    "plt.title(\"Age Distribution\")\n",
    "plt.xlabel(\"Age (days)\")\n",
    "plt.ylabel(\"Count\")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "df_gender = df[gender_labels].sum().sort_values(ascending=False)\n",
    "df_gender.plot(kind='bar', title=\"Gender Distribution\", figsize=(6, 4))\n",
    "plt.ylabel(\"Count\")\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "df_race = df[race_labels].sum().sort_values(ascending=False)\n",
    "df_race.plot(kind='bar', title=\"Race Distribution\", figsize=(6, 4))\n",
    "plt.ylabel(\"Count\")\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "df_comorb = df[comorbidity_labels].sum().sort_values(ascending=False)\n",
    "df_comorb.plot(kind='barh', title=\"Comorbidity Prevalence\", figsize=(7, 6))\n",
    "plt.xlabel(\"Count\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
