{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52a79c3a-1865-45bf-8ee1-78e34ebe3537",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================== 0) Setup & paths ==================\n",
    "import os, gc, json, random\n",
    "import numpy as np, pandas as pd\n",
    "import seaborn as sns, matplotlib.pyplot as plt\n",
    "\n",
    "from scipy import stats\n",
    "from statsmodels.stats.multitest import multipletests\n",
    "\n",
    "sns.set_context(\"talk\"); sns.set_style(\"whitegrid\")\n",
    "\n",
    "DATA_DIR    = \"./output_embeddings\"\n",
    "FEATURE_DIR = \"./tda_feature_outputs\"\n",
    "BASE_FEAT   = os.path.join(FEATURE_DIR, \"features_base.parquet\")\n",
    "OUTPUT_DIR  = \"./tda_analysis_outputs\"\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "TIME_FEATURES_PATH = None\n",
    "\n",
    "SEED   = 42\n",
    "N_PERM = 5000\n",
    "ALPHA  = 0.05\n",
    "rng    = np.random.default_rng(SEED)\n",
    "\n",
    "TDA_FEATURES = [\n",
    "    \"betti_L1\",\"betti_L2\",\n",
    "    \"H1_sum_pers\",\"H1_max_pers\",\"H1_n_bars\",\"H0_sum_pers\",\"H0_n_bars\",\n",
    "    \"ratio_sum_H1_H0\",\"ratio_count_H1_H0\",\"midlife_gap_H1_H0\",\"entropy_gap_H1_H0\",\n",
    "    \"BettiH1_AUC\",\"BettiH1_peak_loc_norm\",\n",
    "    \"PI_H1_long\",\"PI_H1_short\",\"PI_H1_long_short_ratio\",\n",
    "]\n",
    "\n",
    "print(\"[cfg] BASE_FEAT:\", BASE_FEAT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42425158-a925-44fb-923f-c2b34e2feb48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1) Index computation\n",
    "SLEEP_STAGES_ASLEEP = {1, 2, 3, 4}  # 0=Wake, others = asleep\n",
    "\n",
    "def session_paths(sid: str, data_dir: str = DATA_DIR) -> dict:\n",
    "    return {\n",
    "        \"evt1\": os.path.join(data_dir, f\"{sid}_event1.npy\"),\n",
    "        \"evt2\": os.path.join(data_dir, f\"{sid}_event2.npy\"),\n",
    "        \"sleep\": os.path.join(data_dir, f\"{sid}_sleep.npy\"),\n",
    "    }\n",
    "\n",
    "def compute_index_from_epoch_labels(sid: str) -> dict:\n",
    "    p = session_paths(sid)\n",
    "    if not all(os.path.exists(v) for v in p.values()):\n",
    "        return {\"id\": sid, \"index_all\": np.nan, \"TST_hours\": np.nan, \"n_events\": np.nan}\n",
    "\n",
    "    evt1 = np.load(p[\"evt1\"])\n",
    "    evt2 = np.load(p[\"evt2\"])\n",
    "    sleep = np.load(p[\"sleep\"])\n",
    "\n",
    "    asleep_epochs = np.isin(sleep.astype(int), list(SLEEP_STAGES_ASLEEP)).sum()\n",
    "    TST_hours = asleep_epochs / 120.0\n",
    "    if TST_hours <= 0:\n",
    "        return {\"id\": sid, \"index_all\": np.nan, \"TST_hours\": 0.0, \"n_events\": 0}\n",
    "\n",
    "    n_events = np.logical_or(evt1 > 0, evt2 > 0).sum()\n",
    "    index_all = n_events / TST_hours\n",
    "\n",
    "    return {\"id\": sid, \"index_all\": float(index_all), \"TST_hours\": float(TST_hours), \"n_events\": int(n_events)}\n",
    "\n",
    "def group_index(val: float) -> tuple[int, str]:\n",
    "    \"\"\"Return (group_id, group_name) per thresholds.\"\"\"\n",
    "    if not np.isfinite(val):\n",
    "        return (-1, \"missing\")\n",
    "    if val < 1:   return (0, \"grp0(<1)\")\n",
    "    if val < 5:   return (1, \"grp1(1-5)\")\n",
    "    if val < 10:  return (2, \"grp2(5-10)\")\n",
    "    else:         return (3, \"grp3(>=10)\")\n",
    "\n",
    "df_feat = pd.read_parquet(BASE_FEAT)\n",
    "ids = df_feat[\"id\"].astype(str).unique().tolist()\n",
    "rows = [compute_index_from_epoch_labels(sid) for sid in ids]\n",
    "df_idx = pd.DataFrame(rows)\n",
    "\n",
    "grp_id, grp_nm = [], []\n",
    "for v in df_idx[\"index_all\"].values:\n",
    "    g, n = group_index(v)\n",
    "    grp_id.append(g); grp_nm.append(n)\n",
    "df_idx[\"grp_id\"] = grp_id\n",
    "df_idx[\"grp_name\"] = grp_nm\n",
    "\n",
    "print(df_idx.describe(include=\"all\"))\n",
    "print(df_idx[\"grp_name\"].value_counts().to_string())\n",
    "\n",
    "map_path = os.path.join(OUTPUT_DIR, \"index_mapping.csv\")\n",
    "df_idx.to_csv(map_path, index=False)\n",
    "print(\"[save] index map ->\", map_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99271c8e-af33-4ac8-89da-a17e8db95643",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2) Merge with feature tables\n",
    "df_base = df_feat.drop(columns=[c for c in [\"y_sess\",\"label_name\",\"source\"] if c in df_feat.columns], errors=\"ignore\")\n",
    "\n",
    "if TIME_FEATURES_PATH and os.path.exists(TIME_FEATURES_PATH):\n",
    "    df_time = pd.read_parquet(TIME_FEATURES_PATH) if TIME_FEATURES_PATH.endswith(\".parquet\") else pd.read_csv(TIME_FEATURES_PATH)\n",
    "    df_time = df_time.add_prefix(\"time_\").rename(columns={\"time_id\": \"id\"})\n",
    "    df_merged = df_base.merge(df_time, on=\"id\", how=\"left\")\n",
    "else:\n",
    "    df_merged = df_base.copy()\n",
    "\n",
    "df_merged = df_merged.merge(\n",
    "    df_idx[[\"id\", \"index_all\", \"grp_id\", \"grp_name\", \"TST_hours\", \"n_events\"]],\n",
    "    on=\"id\", how=\"left\"\n",
    ")\n",
    "\n",
    "df_merged = df_merged[df_merged[\"grp_id\"].between(0, 3)].reset_index(drop=True)\n",
    "print(\"[info] merged shape:\", df_merged.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f90b7ef-553f-44be-a9b8-2900033e6ba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================== (NEW) Build GLOBAL time features ==================\n",
    "TIME_KEEP_PER_DIM = False\n",
    "TIME_MAX_DIMS     = 16\n",
    "\n",
    "def time_feature_path(sid: str, data_dir: str = DATA_DIR) -> str:\n",
    "    return os.path.join(data_dir, f\"{sid}_time_features.npy\")\n",
    "\n",
    "def summarize_time_array(A: np.ndarray) -> dict:\n",
    "    A = np.asarray(A, dtype=float)\n",
    "    out = {}\n",
    "    if A.ndim == 1:\n",
    "        v = np.nan_to_num(A, nan=0.0, posinf=0.0, neginf=0.0)\n",
    "        out.update({\n",
    "            \"time_f1\": float(np.mean(v)),\n",
    "            \"time_f2\": float(np.std(v, ddof=1)),\n",
    "            \"time_f3\": float(np.linalg.norm(v)),\n",
    "        })\n",
    "    elif A.ndim == 2:\n",
    "        A = np.nan_to_num(A, nan=0.0, posinf=0.0, neginf=0.0)\n",
    "        mu = np.mean(A, axis=0)\n",
    "        sd = np.std(A, axis=0, ddof=1)\n",
    "        out.update({\n",
    "            \"time_f1\": float(np.mean(mu)),\n",
    "            \"time_f2\": float(np.std(mu, ddof=1)),\n",
    "            \"time_f3\": float(np.mean(sd)),\n",
    "            \"time_f4\": float(np.std(sd, ddof=1)),\n",
    "            \"time_f5\": float(np.linalg.norm(mu)),\n",
    "            \"time_f6\": float(np.linalg.norm(sd)),\n",
    "        })\n",
    "    else:\n",
    "        v = np.nan_to_num(A.reshape(-1), nan=0.0, posinf=0.0, neginf=0.0)\n",
    "        out.update({\n",
    "            \"time_f1\": float(np.mean(v)),\n",
    "            \"time_f2\": float(np.std(v, ddof=1)),\n",
    "            \"time_f3\": float(np.linalg.norm(v)),\n",
    "        })\n",
    "    return out\n",
    "\n",
    "def build_time_df_from_files(ids) -> pd.DataFrame:\n",
    "    rows = []\n",
    "    for sid in ids:\n",
    "        p = time_feature_path(sid)\n",
    "        if not os.path.exists(p):\n",
    "            continue\n",
    "        try:\n",
    "            arr = np.load(p, allow_pickle=False)\n",
    "            feats = summarize_time_array(arr)\n",
    "            feats[\"id\"] = str(sid)\n",
    "            rows.append(feats)\n",
    "        except Exception as e:\n",
    "            print(f\"[time warn] {sid}: {e}\")\n",
    "    if not rows:\n",
    "        return pd.DataFrame(columns=[\"id\"])\n",
    "    return pd.DataFrame(rows).replace([np.inf, -np.inf], np.nan).fillna(0.0)\n",
    "\n",
    "# Build table for IDs present in base feature file\n",
    "ids = df_feat[\"id\"].astype(str).unique().tolist()\n",
    "df_time = build_time_df_from_files(ids)\n",
    "print(\"[info] time feature shape:\", df_time.shape)\n",
    "\n",
    "# Merge with base + index map\n",
    "df_base = df_feat.drop(columns=[c for c in [\"y_sess\",\"label_name\",\"source\"] if c in df_feat.columns], errors=\"ignore\")\n",
    "df_merged = (\n",
    "    df_base.merge(df_time, on=\"id\", how=\"left\")\n",
    "           .merge(df_idx[[\"id\",\"index_all\",\"grp_id\",\"grp_name\",\"TST_hours\",\"n_events\"]],\n",
    "                  on=\"id\", how=\"left\")\n",
    ")\n",
    "df_merged = df_merged[df_merged[\"grp_id\"].between(0,3)].reset_index(drop=True)\n",
    "print(\"[info] merged shape (base + time):\", df_merged.shape)\n",
    "\n",
    "# Fixed feature lists (6 + 6)\n",
    "TDA_FEATURES = [\"tda_f1\",\"tda_f2\",\"tda_f3\",\"tda_f4\",\"tda_f5\",\"tda_f6\"]\n",
    "TIME_FEATURES = [\"time_f1\",\"time_f2\",\"time_f3\",\"time_f4\",\"time_f5\",\"time_f6\"]\n",
    "ALL_FEATURES = [f for f in TDA_FEATURES + TIME_FEATURES if f in df_merged.columns]\n",
    "print(f\"[info] testing {len(ALL_FEATURES)} features\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c15d00c-734a-4236-a96a-d931eea8175a",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_merged"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acedf9f2-569b-451e-9c62-d8a92acf0f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================== (UPDATED) Permutation tests ==================\n",
    "feat_cols = ALL_FEATURES\n",
    "results = []\n",
    "g_all = df_merged[\"grp_id\"].astype(int).values\n",
    "\n",
    "for c in feat_cols:\n",
    "    v = df_merged[c].astype(float).values\n",
    "    m = np.isfinite(v)\n",
    "    x = v[m]; g = g_all[m]\n",
    "    if len(x) < 10:\n",
    "        continue\n",
    "\n",
    "    # Global KW across groups\n",
    "    H = stat_kw_H(x, g)\n",
    "    p_kw, null_kw = perm_p_value(H, stat_kw_H, x, g, N_PERM, rng)\n",
    "    results.append(dict(feature=c, test=\"PermKW_global\", stat=H, p=p_kw, q=np.nan,\n",
    "                        effect_median_diff=np.nan, cliffs_delta=np.nan, n_perm=len(null_kw)))\n",
    "\n",
    "    # One-vs-rest MWU (per group)\n",
    "    for k in sorted(np.unique(g)):\n",
    "        obs = stat_mwu_umax_ova(x, g, pos_label=k)\n",
    "        p_mwu, null_mwu = perm_p_value(\n",
    "            obs, lambda xv, gv: stat_mwu_umax_ova(xv, gv, pos_label=k), x, g, N_PERM, rng\n",
    "        )\n",
    "        xk, xr = x[g==k], x[g!=k]\n",
    "        eff_md = float(np.median(xk) - np.median(xr)) if len(xk) and len(xr) else np.nan\n",
    "        try: \n",
    "            delta = cliffs_delta(xk, xr)\n",
    "        except Exception: \n",
    "            delta = np.nan\n",
    "        results.append(dict(feature=c, test=f\"PermMWU_group{k}_vs_all\", stat=obs, p=p_mwu, q=np.nan,\n",
    "                            effect_median_diff=eff_md, cliffs_delta=delta, n_perm=len(null_mwu)))\n",
    "\n",
    "res = pd.DataFrame(results)\n",
    "if res.empty:\n",
    "    print(\"[warn] no tests ran.\")\n",
    "else:\n",
    "    for tname, sub in res.groupby(\"test\", sort=False):\n",
    "        _, qvals, *_ = multipletests(sub[\"p\"].values, alpha=ALPHA, method=\"fdr_bh\")\n",
    "        res.loc[sub.index, \"q\"] = qvals\n",
    "    res.sort_values([\"q\",\"p\",\"feature\"], inplace=True)\n",
    "\n",
    "    out_csv = os.path.join(OUTPUT_DIR, \"significance_groups_TDAplusTIME.csv\")\n",
    "    res.to_csv(out_csv, index=False)\n",
    "    print(\"[save] significance ->\", out_csv)\n",
    "    print(res[[\"feature\",\"test\",\"stat\",\"p\",\"q\",\"n_perm\",\"effect_median_diff\",\"cliffs_delta\"]]\n",
    "          .head(40).to_string(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d777f7e3-ed89-4f13-a7bd-857b68599ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- EHR-only permutation tests ---\n",
    "\n",
    "feat_cols = EHR_FEATURES\n",
    "g_all = df_merged[\"grp_id\"].astype(int).values\n",
    "\n",
    "results = []\n",
    "for c in feat_cols:\n",
    "    v = df_merged[c].astype(float).values\n",
    "    m = np.isfinite(v)\n",
    "    x = v[m]; g = g_all[m]\n",
    "    if len(x) < 10:\n",
    "        continue\n",
    "\n",
    "    # Global KW across groups\n",
    "    H = stat_kw_H(x, g)\n",
    "    p_kw, null_kw = perm_p_value(H, stat_kw_H, x, g, N_PERM, rng)\n",
    "    results.append(dict(\n",
    "        feature=c, test=\"PermKW_global\", stat=H, p=p_kw, q=np.nan,\n",
    "        effect_median_diff=np.nan, cliffs_delta=np.nan, n_perm=len(null_kw)\n",
    "    ))\n",
    "\n",
    "    # One-vs-rest MWU\n",
    "    for k in sorted(np.unique(g)):\n",
    "        obs = stat_mwu_umax_ova(x, g, pos_label=k)\n",
    "        p_mwu, null_mwu = perm_p_value(\n",
    "            obs, lambda xv, gv: stat_mwu_umax_ova(xv, gv, pos_label=k),\n",
    "            x, g, N_PERM, rng\n",
    "        )\n",
    "        xk, xr = x[g==k], x[g!=k]\n",
    "        eff_md = float(np.median(xk) - np.median(xr)) if len(xk) and len(xr) else np.nan\n",
    "        try:\n",
    "            delta = cliffs_delta(xk, xr)\n",
    "        except Exception:\n",
    "            delta = np.nan\n",
    "        results.append(dict(\n",
    "            feature=c, test=f\"PermMWU_group{k}_vs_all\", stat=obs, p=p_mwu, q=np.nan,\n",
    "            effect_median_diff=eff_md, cliffs_delta=delta, n_perm=len(null_mwu)\n",
    "        ))\n",
    "\n",
    "res_ehr = pd.DataFrame(results)\n",
    "\n",
    "if res_ehr.empty:\n",
    "    print(\"[warn] no tests ran.\")\n",
    "else:\n",
    "    for tname, sub in res_ehr.groupby(\"test\", sort=False):\n",
    "        _, qvals, *_ = multipletests(sub[\"p\"].values, alpha=0.05, method=\"fdr_bh\")\n",
    "        res_ehr.loc[sub.index, \"q\"] = qvals\n",
    "\n",
    "    res_ehr.sort_values([\"q\",\"p\",\"feature\"], inplace=True)\n",
    "\n",
    "    out_csv = os.path.join(OUTPUT_DIR, \"significance_groups_EHR.csv\")\n",
    "    res_ehr.to_csv(out_csv, index=False)\n",
    "    print(\"[save] EHR-only significance ->\", out_csv)\n",
    "    display(res_ehr[[\"feature\",\"test\",\"stat\",\"p\",\"q\",\"n_perm\",\"effect_median_diff\",\"cliffs_delta\"]].head(40))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4cbfa83-d853-4d81-9cb9-45081d79e806",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc3cead4-7a47-4529-88ba-e5f28568e928",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rcParams['font.size'] = 11\n",
    "\n",
    "TIME_FEATURES = [\"time_mean\", \"time_std\", \"time_l2\"]\n",
    "\n",
    "def plot_time_features_by_group(df_merged):\n",
    "    \"\"\"Visualize time features across groups\"\"\"\n",
    "    \n",
    "    available_feats = [f for f in TIME_FEATURES if f in df_merged.columns]\n",
    "    if not available_feats:\n",
    "        print(\"No time features found in dataframe\")\n",
    "        return\n",
    "        \n",
    "    print(f\"Plotting {len(available_feats)} time features across groups\")\n",
    "    \n",
    "    for feat in available_feats:\n",
    "        d = df_merged[[\"grp_name\", feat]].dropna()\n",
    "        if d.empty:\n",
    "            continue\n",
    "            \n",
    "        groups = sorted(d[\"grp_name\"].unique())\n",
    "        \n",
    "        fig, axes = plt.subplots(2, 2, figsize=(15, 12))\n",
    "        fig.suptitle(f\"Time Feature: {feat}\\nDistribution across Groups\", fontsize=16, fontweight='bold')\n",
    "        \n",
    "        # Boxplot + points\n",
    "        sns.boxplot(data=d, x=\"grp_name\", y=feat, ax=axes[0,0], showfliers=False, order=groups)\n",
    "        sns.stripplot(data=d, x=\"grp_name\", y=feat, ax=axes[0,0], alpha=0.5, size=3, order=groups)\n",
    "        axes[0,0].set_title(\"Boxplot with Data Points\")\n",
    "        axes[0,0].tick_params(axis='x', rotation=45)\n",
    "        for i, group in enumerate(groups):\n",
    "            n = len(d[d[\"grp_name\"] == group])\n",
    "            axes[0,0].text(i, d[feat].max() * 1.02, f'n={n}', ha='center', va='bottom', fontsize=9)\n",
    "        \n",
    "        # ECDF\n",
    "        colors = plt.cm.Set3(np.linspace(0, 1, len(groups)))\n",
    "        for i, group in enumerate(groups):\n",
    "            group_data = d[d[\"grp_name\"] == group][feat].values\n",
    "            sns.ecdfplot(group_data, ax=axes[0,1], label=group, color=colors[i], linewidth=2)\n",
    "        axes[0,1].legend(title=\"Group\")\n",
    "        axes[0,1].set_title(\"Empirical Cumulative Distribution\")\n",
    "        axes[0,1].set_xlabel(feat)\n",
    "        axes[0,1].set_ylabel(\"Cumulative Probability\")\n",
    "        \n",
    "        # Density\n",
    "        for i, group in enumerate(groups):\n",
    "            group_data = d[d[\"grp_name\"] == group][feat].values\n",
    "            if len(group_data) > 1:\n",
    "                density = gaussian_kde(group_data)\n",
    "                x_vals = np.linspace(d[feat].min(), d[feat].max(), 100)\n",
    "                axes[1,0].plot(x_vals, density(x_vals), label=group, color=colors[i], linewidth=2, alpha=0.8)\n",
    "        axes[1,0].legend(title=\"Group\")\n",
    "        axes[1,0].set_title(\"Probability Density\")\n",
    "        axes[1,0].set_xlabel(feat)\n",
    "        axes[1,0].set_ylabel(\"Density\")\n",
    "        \n",
    "        # Mean ± CI\n",
    "        means, cis, labels = [], [], []\n",
    "        for group in groups:\n",
    "            group_data = d[d[\"grp_name\"] == group][feat].values\n",
    "            if len(group_data) > 0:\n",
    "                mean_val = np.mean(group_data)\n",
    "                se_val = np.std(group_data) / np.sqrt(len(group_data)) if len(group_data) > 1 else 0\n",
    "                ci_val = 1.96 * se_val\n",
    "                means.append(mean_val); cis.append(ci_val); labels.append(group)\n",
    "        \n",
    "        axes[1,1].errorbar(range(len(means)), means, yerr=cis, fmt='o',\n",
    "                           capsize=5, capthick=2, markersize=8, linewidth=2)\n",
    "        axes[1,1].set_xticks(range(len(means)))\n",
    "        axes[1,1].set_xticklabels(labels, rotation=45)\n",
    "        axes[1,1].set_title(\"Mean ± 95% CI\")\n",
    "        axes[1,1].set_xlabel(\"Group\")\n",
    "        axes[1,1].set_ylabel(f\"Mean {feat}\")\n",
    "        axes[1,1].grid(True, alpha=0.3)\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        \n",
    "        # Summary stats\n",
    "        print(f\"\\n{feat} - Summary Statistics:\")\n",
    "        for group in groups:\n",
    "            group_data = d[d[\"grp_name\"] == group][feat]\n",
    "            if len(group_data) > 0:\n",
    "                print(f\"  {group}: n={len(group_data)}, mean={group_data.mean():.3f} ± {group_data.std():.3f}\")\n",
    "\n",
    "# Run\n",
    "plot_time_features_by_group(df_merged)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d88597a0-3470-4dde-a7b7-56ce566a7a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "outdir = \"./tda_analysis_outputs\"\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "def plot_time_features_by_group_pretty(\n",
    "    df_merged,\n",
    "    features=(\"time_f1\",),\n",
    "    group_col=\"grp_name\",\n",
    "    outdir=outdir,\n",
    "    zoom_quantiles=(0.02, 0.98),\n",
    "    show_points=True,\n",
    "    pdf_xlim=(-1, 1),\n",
    "    ecdf_xlim=(-1, 1)\n",
    "):\n",
    "    import numpy as np, seaborn as sns, matplotlib.pyplot as plt\n",
    "    from scipy.stats import gaussian_kde\n",
    "\n",
    "    sns.set_style(\"whitegrid\")\n",
    "    palette = sns.color_palette(\"tab10\")\n",
    "\n",
    "    groups = sorted(df_merged[group_col].unique())\n",
    "    color_map = {g: palette[i % len(palette)] for i, g in enumerate(groups)}\n",
    "\n",
    "    for feat in [f for f in features if f in df_merged.columns]:\n",
    "        d = df_merged[[group_col, feat]].dropna()\n",
    "        if d.empty:\n",
    "            continue\n",
    "\n",
    "        # ---------- Box + strip ----------\n",
    "        lo, hi = np.nanquantile(d[feat], zoom_quantiles)\n",
    "        pad = 0.05 * (hi - lo) if hi > lo else 1.0\n",
    "        fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)\n",
    "        sns.boxplot(data=d, x=group_col, y=feat, order=groups,\n",
    "                    ax=ax, showfliers=False, width=0.55, linewidth=1.25)\n",
    "        if show_points:\n",
    "            sns.stripplot(data=d, x=group_col, y=feat, order=groups,\n",
    "                          ax=ax, alpha=0.45, size=2.8, jitter=0.15, color=\"k\")\n",
    "        ax.set_ylim(lo - pad, hi + pad)\n",
    "        ax.set_title(\"Distribution by Group\")\n",
    "        ax.set_xlabel(\"Group\"); ax.set_ylabel(feat)\n",
    "        path = os.path.join(outdir, f\"{feat}_box.pdf\")\n",
    "        fig.savefig(path, dpi=450); plt.close(fig)\n",
    "        print(f\"Saved {path}\")\n",
    "\n",
    "        # ---------- ECDF ----------\n",
    "        fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)\n",
    "        for g in groups:\n",
    "            vals = d.loc[d[group_col] == g, feat].to_numpy()\n",
    "            if vals.size:\n",
    "                sns.ecdfplot(vals, ax=ax, label=str(g),\n",
    "                             color=color_map[g], linewidth=2)\n",
    "        ax.set_xlim(*ecdf_xlim)\n",
    "        ax.set_title(\"Empirical CDF\")\n",
    "        ax.set_xlabel(feat); ax.set_ylabel(\"Proportion\")\n",
    "        ax.legend(title=\"Group\", frameon=True, framealpha=0.9,\n",
    "                  facecolor=\"white\", fontsize=9, title_fontsize=10, loc=\"lower right\")\n",
    "        path = os.path.join(outdir, f\"{feat}_ecdf.pdf\")\n",
    "        fig.savefig(path, dpi=450); plt.close(fig)\n",
    "        print(f\"Saved {path}\")\n",
    "\n",
    "        # ---------- KDE ----------\n",
    "        fig, ax = plt.subplots(figsize=(12, 4), constrained_layout=True)\n",
    "        X = np.linspace(pdf_xlim[0], pdf_xlim[1], 500)\n",
    "        for g in groups:\n",
    "            vals = d.loc[d[group_col] == g, feat].to_numpy()\n",
    "            vals = vals[np.isfinite(vals)]\n",
    "            if vals.size > 1:\n",
    "                kde = gaussian_kde(vals)\n",
    "                ax.plot(X, kde(X), label=str(g),\n",
    "                        color=color_map[g], linewidth=2)\n",
    "        ax.set_xlim(*pdf_xlim)\n",
    "        ax.set_title(\"Probability Density\")\n",
    "        ax.set_xlabel(feat); ax.set_ylabel(\"Density\")\n",
    "        ax.legend(title=\"Group\", frameon=True, framealpha=0.9,\n",
    "                  facecolor=\"white\", fontsize=9, title_fontsize=10,\n",
    "                  loc=\"upper right\", ncol=min(len(groups), 2))\n",
    "        path = os.path.join(outdir, f\"{feat}_kde.pdf\")\n",
    "        fig.savefig(path, dpi=450); plt.close(fig)\n",
    "        print(f\"Saved {path}\")\n",
    "\n",
    "        # ---------- Mean ± CI ----------\n",
    "        fig, ax = plt.subplots(figsize=(7.2, 4.6), constrained_layout=True)\n",
    "        means, cis, labels = [], [], []\n",
    "        for g in groups:\n",
    "            vals = d.loc[d[group_col] == g, feat].to_numpy()\n",
    "            if vals.size:\n",
    "                m = float(np.nanmean(vals))\n",
    "                se = float(np.nanstd(vals, ddof=1)) / np.sqrt(np.count_nonzero(np.isfinite(vals))) if vals.size > 1 else 0.0\n",
    "                means.append(m); cis.append(1.96 * se); labels.append(str(g))\n",
    "        ax.errorbar(range(len(means)), means, yerr=cis, fmt=\"o\",\n",
    "                    capsize=5, capthick=1.6, markersize=6.5, linewidth=2)\n",
    "        ax.set_xticks(range(len(means))); ax.set_xticklabels(labels)\n",
    "        ax.set_title(\"Mean ± 95% CI\")\n",
    "        ax.set_xlabel(\"Group\"); ax.set_ylabel(feat)\n",
    "        ax.grid(alpha=0.3)\n",
    "        path = os.path.join(outdir, f\"{feat}_meanCI.pdf\")\n",
    "        fig.savefig(path, dpi=450); plt.close(fig)\n",
    "        print(f\"Saved {path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b04a231-08eb-41e9-9556-01927576e17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_time_features_by_ahi_pretty(\n",
    "    df_merged,\n",
    "    features=(\"time_mean\",),   # you can add \"time_l2\" etc if you want\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
