{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ce5a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import pandas as pd\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# Valid penalties & importances\n",
    "# ---------------------------------------------------------\n",
    "\n",
    "VALID_PENALTIES = {\n",
    "    \"Fast_Shap\",\n",
    "    \"Shapley\",\n",
    "    \"Jacob_L1\",\n",
    "    \"Jacob_F\"\n",
    "}\n",
    "\n",
    "VALID_IMPORTANCES = {\n",
    "    \"Shapley\",\n",
    "    \"Jacobian\"\n",
    "}\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# Small helpers\n",
    "# ---------------------------------------------------------\n",
    "\n",
    "def infer_importance_from_penalty(penalty: str) -> str | None:\n",
    "    if penalty in {\"Fast_Shap\", \"Shapley\"}:\n",
    "        return \"Shapley\"\n",
    "    elif penalty in {\"Jacob_F\", \"Jacob_L1\"}:\n",
    "        return \"Jacobian\"\n",
    "    elif penalty == \"Layer_Weight\":\n",
    "        return \"Layer_Weight\"\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "\n",
    "def parse_filename(path: str, expected_dataset: str | None = None) -> tuple | None:\n",
    "    name = os.path.splitext(os.path.basename(path))[0]\n",
    "    parts = name.split(\"_\")\n",
    "\n",
    "    offset = 0\n",
    "    if parts[0].lower() in {\"ablation\", \"simulation\"}:\n",
    "        offset = 1\n",
    "\n",
    "    if len(parts) - offset < 5:\n",
    "        return None\n",
    "\n",
    "    dataset = parts[offset]\n",
    "    series_str = parts[offset + 1]\n",
    "    subject_str = parts[offset + 2]\n",
    "    model = parts[offset + 3]\n",
    "    penalty = \"_\".join(parts[offset + 4:])\n",
    "\n",
    "    if expected_dataset is not None and dataset != expected_dataset:\n",
    "        return None\n",
    "\n",
    "    if penalty not in VALID_PENALTIES:\n",
    "        return None\n",
    "\n",
    "    importance = infer_importance_from_penalty(penalty)\n",
    "    if importance is None or importance not in VALID_IMPORTANCES:\n",
    "        return None\n",
    "\n",
    "    try:\n",
    "        series = int(series_str)\n",
    "        subject = int(subject_str)\n",
    "    except ValueError:\n",
    "        return None\n",
    "\n",
    "    return dataset, series, subject, model, penalty, importance\n",
    "\n",
    "def pick_metric_columns(df: pd.DataFrame, importance: str) -> tuple[str, str] | None:\n",
    "    candidates = [\n",
    "        (f\"AUROC_diag_{importance}\", f\"AUPRC_diag_{importance}\"),\n",
    "        (f\"AUROC_{importance}\", f\"AUPRC_{importance}\"),\n",
    "        (\"AUROC\", \"AUPRC\"),\n",
    "    ]\n",
    "    for auroc_col, auprc_col in candidates:\n",
    "        if auroc_col in df.columns and auprc_col in df.columns:\n",
    "            return auroc_col, auprc_col\n",
    "    return None\n",
    "\n",
    "def detect_param_and_metric_cols(df: pd.DataFrame):\n",
    "    \"\"\"\n",
    "    Everything except AUROC/AUPRC*/val_loss is treated as hyperparameter.\n",
    "    \"\"\"\n",
    "    metric_prefixes = (\"AUROC\", \"AUPRC\")\n",
    "    metric_cols = []\n",
    "    param_cols = []\n",
    "    for col in df.columns:\n",
    "        if col == \"val_loss\":\n",
    "            metric_cols.append(col)\n",
    "        elif col.startswith(metric_prefixes):\n",
    "            metric_cols.append(col)\n",
    "        else:\n",
    "            param_cols.append(col)\n",
    "    return param_cols, metric_cols\n",
    "\n",
    "def drop_error_rows(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Drop rows containing 'Error(' or 'BrokenPipeError' anywhere.\n",
    "    \"\"\"\n",
    "    if df.empty:\n",
    "        return df\n",
    "\n",
    "    # Convert everything to string so we can search easily\n",
    "    s = df.astype(str)\n",
    "\n",
    "    error_mask = s.apply(\n",
    "        lambda row: row.str.contains(\"Error(\", regex=False).any()\n",
    "                    or row.str.contains(\"BrokenPipeError\", regex=False).any(),\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    n_bad = error_mask.sum()\n",
    "    if n_bad > 0:\n",
    "        print(f\"[INFO] Dropping {n_bad} rows containing error messages\")\n",
    "\n",
    "    return df.loc[~error_mask].reset_index(drop=True)\n",
    "\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# Main logic: collect metrics & summarize\n",
    "# ---------------------------------------------------------\n",
    "\n",
    "def collect_file_metrics(root_dir: str, dataset: str, series: int) -> pd.DataFrame:\n",
    "    pattern = os.path.join(root_dir, f\"*{dataset}_{series}_*.csv\")\n",
    "    files = glob.glob(pattern)\n",
    "\n",
    "    if not files:\n",
    "        print(f\"[INFO] No files found matching pattern: {pattern}\")\n",
    "\n",
    "    records = []\n",
    "\n",
    "    for path in files:\n",
    "        parsed = parse_filename(path, expected_dataset=dataset)\n",
    "        if parsed is None:\n",
    "            print(f\"[INFO] Skipping file with unrecognized name: {os.path.basename(path)}\")\n",
    "            continue\n",
    "\n",
    "        dataset_i, series_i, subject_i, model, penalty, importance = parsed\n",
    "        if series_i != series:\n",
    "            print(f\"[INFO] Skipping {os.path.basename(path)} due to series mismatch.\")\n",
    "            continue\n",
    "\n",
    "        # 1) Load CSV, skipping malformed lines\n",
    "        try:\n",
    "            df = pd.read_csv(path, on_bad_lines=\"skip\")\n",
    "        except Exception as e:\n",
    "            print(f\"[WARN] Could not load {path}: {e}\")\n",
    "            continue\n",
    "\n",
    "        if df.empty:\n",
    "            print(f\"[WARN] {path} is empty; skipping.\")\n",
    "            continue\n",
    "\n",
    "        # 2) Drop rows that contain error messages like BrokenPipeError(...)\n",
    "        df = drop_error_rows(df)\n",
    "        if df.empty:\n",
    "            print(f\"[WARN] {path} has only error rows after filtering; skipping.\")\n",
    "            continue\n",
    "\n",
    "        # 3) Find AUROC/AUPRC columns for this importance\n",
    "        metric_cols = pick_metric_columns(df, importance)\n",
    "        if metric_cols is None:\n",
    "            print(f\"[WARN] {path} has no recognizable AUROC/AUPRC columns; skipping.\")\n",
    "            continue\n",
    "        auroc_col, auprc_col = metric_cols\n",
    "\n",
    "        # 4) Ensure metric columns are numeric\n",
    "        df[auroc_col] = pd.to_numeric(df[auroc_col], errors=\"coerce\")\n",
    "        df[auprc_col] = pd.to_numeric(df[auprc_col], errors=\"coerce\")\n",
    "\n",
    "        # Drop rows where AUROC is NaN (so idxmax works)\n",
    "        df_valid = df.dropna(subset=[auroc_col])\n",
    "        if df_valid.empty:\n",
    "            print(f\"[WARN] {path}: all {auroc_col} values are NaN after cleaning; skipping.\")\n",
    "            continue\n",
    "\n",
    "        # 5) Take the row with max AUROC\n",
    "        best_idx = df_valid[auroc_col].idxmax()\n",
    "        best_row = df_valid.loc[best_idx]\n",
    "\n",
    "        # Split hyperparameters vs metrics (optional, but you had it)\n",
    "        param_cols, _ = detect_param_and_metric_cols(df_valid)\n",
    "\n",
    "        record = {\n",
    "            \"dataset\": dataset_i,\n",
    "            \"series\": series_i,\n",
    "            \"subject\": subject_i,\n",
    "            \"model\": model,\n",
    "            \"penalty\": penalty,\n",
    "            \"importance\": importance,\n",
    "            \"max_AUROC\": best_row[auroc_col],\n",
    "            \"max_AUPRC\": best_row[auprc_col],\n",
    "        }\n",
    "\n",
    "        # Add hyperparameters from the best row\n",
    "        for col in param_cols:\n",
    "            record[col] = best_row[col]\n",
    "\n",
    "        records.append(record)\n",
    "\n",
    "    if not records:\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    df_out = pd.DataFrame(records)\n",
    "\n",
    "    # Ensure metric columns are numeric\n",
    "    df_out[\"max_AUROC\"] = pd.to_numeric(df_out[\"max_AUROC\"], errors=\"coerce\")\n",
    "    df_out[\"max_AUPRC\"] = pd.to_numeric(df_out[\"max_AUPRC\"], errors=\"coerce\")\n",
    "\n",
    "    return df_out\n",
    "\n",
    "def summarize_over_subjects(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    For each (dataset, series, model, penalty, importance),\n",
    "    compute mean/std of max_AUROC and max_AUPRC over subjects.\n",
    "    \"\"\"\n",
    "    if df.empty:\n",
    "        return df\n",
    "\n",
    "    df = df.copy()\n",
    "    df[\"max_AUROC\"] = pd.to_numeric(df[\"max_AUROC\"], errors=\"coerce\")\n",
    "    df[\"max_AUPRC\"] = pd.to_numeric(df[\"max_AUPRC\"], errors=\"coerce\")\n",
    "    df = df.dropna(subset=[\"max_AUROC\", \"max_AUPRC\"])\n",
    "\n",
    "    grouped = (\n",
    "        df.groupby([\"dataset\", \"series\", \"model\", \"penalty\", \"importance\"])\n",
    "        .agg(\n",
    "            max_AUROC_mean=(\"max_AUROC\", \"mean\"),\n",
    "            max_AUROC_std=(\"max_AUROC\", \"std\"),\n",
    "            max_AUPRC_mean=(\"max_AUPRC\", \"mean\"),\n",
    "            max_AUPRC_std=(\"max_AUPRC\", \"std\"),\n",
    "            n_subjects=(\"subject\", \"nunique\"),\n",
    "        )\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    return grouped\n",
    "\n",
    "\n",
    "# ---------------------------------------------------------\n",
    "# Main entry point\n",
    "# ---------------------------------------------------------\n",
    "\n",
    "def main(subdir: str, dataset: str = \"fMRI\", series: int = 1):\n",
    "    base_dir = \"\"\n",
    "\n",
    "    root_dir = os.path.join(base_dir, subdir)\n",
    "    out = os.path.join(base_dir, f\"{subdir}_{dataset}_{series}.csv\")\n",
    "\n",
    "    # print(f\"[INFO] Root directory : {root_dir}\")\n",
    "    print(f\"[INFO] Dataset/series: {dataset}/{series}\")\n",
    "    # print(f\"[INFO] Output file    : {out}\")\n",
    "\n",
    "    df = collect_file_metrics(root_dir, dataset, series)\n",
    "    if df.empty:\n",
    "        print(\"No matching files or no valid metrics found.\")\n",
    "        return\n",
    "\n",
    "    summary = summarize_over_subjects(df)\n",
    "\n",
    "    print(\"=== Summary over subjects (mean ± std of max metrics) ===\")\n",
    "    print(summary.to_string(index=False))\n",
    "\n",
    "    # summary.to_csv(out, index=False)\n",
    "    # print(f\"\\nSaved summary to {out}\")\n",
    "    \n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "222bfb94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Dataset/series: VAR3/1\n",
      "=== Summary over subjects (mean ± std of max metrics) ===\n",
      "dataset  series       model   penalty importance  max_AUROC_mean  max_AUROC_std  max_AUPRC_mean  max_AUPRC_std  n_subjects\n",
      "   VAR3       1 ResidualMLP Fast_Shap    Shapley        0.999810       0.000426        0.999583       0.000932           5\n",
      "   VAR3       1 ResidualMLP   Jacob_F   Jacobian        0.998000       0.002702        0.995981       0.005055           5\n",
      "   VAR3       1 ResidualMLP  Jacob_L1   Jacobian        0.999238       0.001452        0.998256       0.003332           5\n",
      "   VAR3       1 ResidualMLP   Shapley    Shapley        0.998857       0.001801        0.997612       0.003632           5\n",
      "[INFO] Dataset/series: VAR3/2\n",
      "=== Summary over subjects (mean ± std of max metrics) ===\n",
      "dataset  series       model   penalty importance  max_AUROC_mean  max_AUROC_std  max_AUPRC_mean  max_AUPRC_std  n_subjects\n",
      "   VAR3       2 ResidualMLP Fast_Shap    Shapley        1.000000       0.000000        1.000000       0.000000           5\n",
      "   VAR3       2 ResidualMLP   Jacob_F   Jacobian        0.999810       0.000426        0.999563       0.000978           5\n",
      "   VAR3       2 ResidualMLP  Jacob_L1   Jacobian        0.999810       0.000426        0.999583       0.000932           5\n",
      "   VAR3       2 ResidualMLP   Shapley    Shapley        0.999905       0.000213        0.999785       0.000481           5\n",
      "[INFO] Dataset/series: VAR3/3\n",
      "=== Summary over subjects (mean ± std of max metrics) ===\n",
      "dataset  series       model   penalty importance  max_AUROC_mean  max_AUROC_std  max_AUPRC_mean  max_AUPRC_std  n_subjects\n",
      "   VAR3       3 ResidualMLP Fast_Shap    Shapley        0.881804       0.010751        0.834960       0.009319           5\n",
      "   VAR3       3 ResidualMLP   Jacob_F   Jacobian        0.877427       0.011148        0.831281       0.012317           5\n",
      "   VAR3       3 ResidualMLP  Jacob_L1   Jacobian        0.894768       0.007735        0.854803       0.008832           5\n",
      "   VAR3       3 ResidualMLP   Shapley    Shapley        0.898490       0.008984        0.864854       0.008113           5\n",
      "[INFO] Dataset/series: VAR3/4\n",
      "=== Summary over subjects (mean ± std of max metrics) ===\n",
      "dataset  series       model   penalty importance  max_AUROC_mean  max_AUROC_std  max_AUPRC_mean  max_AUPRC_std  n_subjects\n",
      "   VAR3       4 ResidualMLP Fast_Shap    Shapley        0.957940       0.003126        0.940113       0.005827           5\n",
      "   VAR3       4 ResidualMLP   Jacob_F   Jacobian        0.955417       0.002644        0.936408       0.004470           5\n",
      "   VAR3       4 ResidualMLP  Jacob_L1   Jacobian        0.978608       0.002971        0.970319       0.003741           5\n",
      "   VAR3       4 ResidualMLP   Shapley    Shapley        0.977469       0.005253        0.969705       0.004547           5\n"
     ]
    }
   ],
   "source": [
    "combined_summary = pd.concat(\n",
    "    [main('simulation_VAR3', 'VAR3', i) for i in range(1, 5)],\n",
    "    ignore_index=True\n",
    ")\n",
    "# combined_summary.to_csv('./simulation_VAR3.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MPS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
