{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49a470b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Summary over subjects:\n",
      "        method  max_AUROC_mean  max_AUROC_std  max_AUPRC_mean  max_AUPRC_std  \\\n",
      "0  Fast_Shap_1        0.847822       0.022497        0.634066       0.036209   \n",
      "1  Fast_Shap_3        0.848327       0.027940        0.639948       0.033239   \n",
      "2  Fast_Shap_5        0.850979       0.021889        0.637849       0.042150   \n",
      "3         None        0.670739       0.010170        0.328472       0.040196   \n",
      "4      Shapley        0.856345       0.021524        0.676456       0.019346   \n",
      "\n",
      "   min_val_loss_mean  min_val_loss_std  n_subjects  \n",
      "0           0.963972          0.032089           5  \n",
      "1           0.963075          0.037860           5  \n",
      "2           0.960913          0.033666           5  \n",
      "3           1.060134          0.038475           5  \n",
      "4           0.946260          0.037556           5  \n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import glob\n",
    "import pandas as pd\n",
    "import re\n",
    "import numpy as np\n",
    "\n",
    "VALID_METHODS = {\"Fast_Shap_1\", \"Fast_Shap_3\", \"Fast_Shap_5\", \"Shapley\", \"None\"}\n",
    "\n",
    "def parse_filename(filename: str) -> tuple | None:\n",
    "    \"\"\"\n",
    "    Parse filenames like:\n",
    "        sensitivity_fMRI_1_0_ResidualMLP_Fast_Shap_1.csv\n",
    "    Returns: dataset, series, subject, model, method\n",
    "    \"\"\"\n",
    "    name = os.path.splitext(os.path.basename(filename))[0]\n",
    "\n",
    "    # Remove optional prefixes\n",
    "    prefix_match = re.match(r\"^(sensitivity)_\", name)\n",
    "    if prefix_match:\n",
    "        name = name[prefix_match.end():]\n",
    "\n",
    "    parts = name.split(\"_\")\n",
    "    if len(parts) < 5:\n",
    "        return None\n",
    "\n",
    "    dataset = parts[0]\n",
    "    try:\n",
    "        series = int(parts[1])\n",
    "        subject = int(parts[2])\n",
    "    except ValueError:\n",
    "        return None\n",
    "\n",
    "    model = parts[3]\n",
    "    method = \"_\".join(parts[4:])\n",
    "\n",
    "    if method not in VALID_METHODS:\n",
    "        return None\n",
    "\n",
    "    return dataset, series, subject, model, method\n",
    "\n",
    "\n",
    "def collect_all_metrics(root_dir: str, dataset: str, series: int) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Scan folder, parse filenames, load CSVs.\n",
    "    Extracts the BEST value for each metric INDEPENDENTLY:\n",
    "      - max_AUROC: Global max of the AUROC column.\n",
    "      - max_AUPRC: Global max of the AUPRC column.\n",
    "      - min_val_loss: Global min of the val_loss column.\n",
    "    \"\"\"\n",
    "    pattern = os.path.join(root_dir, f\"*{dataset}_{series}_*.csv\")\n",
    "    files = glob.glob(pattern)\n",
    "    if not files:\n",
    "        print(f\"[INFO] No files found in {root_dir} matching series {series}\")\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    records = []\n",
    "\n",
    "    for fpath in files:\n",
    "        parsed = parse_filename(fpath)\n",
    "        if parsed is None:\n",
    "            continue\n",
    "        dataset_i, series_i, subject_i, model, method = parsed\n",
    "        if series_i != series:\n",
    "            continue\n",
    "\n",
    "        # Load CSV\n",
    "        try:\n",
    "            df = pd.read_csv(fpath, on_bad_lines=\"skip\")\n",
    "        except Exception as e:\n",
    "            print(f\"[WARN] Failed to load {fpath}: {e}\")\n",
    "            continue\n",
    "        if df.empty:\n",
    "            continue\n",
    "\n",
    "        # -----------------------------------------------------------\n",
    "        # 1. Identify Columns\n",
    "        # -----------------------------------------------------------\n",
    "        # Metric columns\n",
    "        auroc_col, auprc_col = None, None\n",
    "        for col_pair in [(f\"AUROC_{method}\", f\"AUPRC_{method}\"), (\"AUROC\", \"AUPRC\")]:\n",
    "            if col_pair[0] in df.columns and col_pair[1] in df.columns:\n",
    "                auroc_col, auprc_col = col_pair\n",
    "                break\n",
    "        \n",
    "        # Val loss column\n",
    "        val_loss_col = None\n",
    "        possible_loss_names = [\"val_loss\", \"loss_val\", \"valid_loss\", \"Validation Loss\"]\n",
    "        for name in possible_loss_names:\n",
    "            if name in df.columns:\n",
    "                val_loss_col = name\n",
    "                break\n",
    "\n",
    "        if auroc_col is None:\n",
    "            continue\n",
    "\n",
    "        # -----------------------------------------------------------\n",
    "        # 2. Extract Independent Bests\n",
    "        # -----------------------------------------------------------\n",
    "        # Ensure numeric types\n",
    "        df[auroc_col] = pd.to_numeric(df[auroc_col], errors=\"coerce\")\n",
    "        df[auprc_col] = pd.to_numeric(df[auprc_col], errors=\"coerce\")\n",
    "        \n",
    "        # Calculate Max AUROC / AUPRC\n",
    "        # (We use max() on the whole column directly)\n",
    "        max_auroc = df[auroc_col].max()\n",
    "        max_auprc = df[auprc_col].max()\n",
    "        \n",
    "        # Calculate Min Val Loss\n",
    "        min_val_loss = np.nan\n",
    "        if val_loss_col:\n",
    "            df[val_loss_col] = pd.to_numeric(df[val_loss_col], errors=\"coerce\")\n",
    "            min_val_loss = df[val_loss_col].min()\n",
    "\n",
    "        # Skip if data is completely junk (all NaNs)\n",
    "        if pd.isna(max_auroc):\n",
    "            continue\n",
    "\n",
    "        record = {\n",
    "            \"dataset\": dataset_i,\n",
    "            \"series\": series_i,\n",
    "            \"subject\": subject_i,\n",
    "            \"model\": model,\n",
    "            \"method\": method,\n",
    "            \"max_AUROC\": max_auroc,\n",
    "            \"max_AUPRC\": max_auprc,\n",
    "            \"min_val_loss\": min_val_loss\n",
    "        }\n",
    "        records.append(record)\n",
    "\n",
    "    df_out = pd.DataFrame(records)\n",
    "    return df_out\n",
    "\n",
    "\n",
    "def summarize_metrics(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Compute mean/std across subjects for each (dataset, series, model, method)\n",
    "    \"\"\"\n",
    "    if df.empty:\n",
    "        return df\n",
    "\n",
    "    grouped = (\n",
    "        df.groupby([\"method\"])\n",
    "        .agg(\n",
    "            max_AUROC_mean=(\"max_AUROC\", \"mean\"),\n",
    "            max_AUROC_std=(\"max_AUROC\", \"std\"),\n",
    "            max_AUPRC_mean=(\"max_AUPRC\", \"mean\"),\n",
    "            max_AUPRC_std=(\"max_AUPRC\", \"std\"),\n",
    "            min_val_loss_mean=(\"min_val_loss\", \"mean\"),\n",
    "            min_val_loss_std=(\"min_val_loss\", \"std\"),\n",
    "            n_subjects=(\"subject\", \"nunique\"),\n",
    "        )\n",
    "        .reset_index()\n",
    "    )\n",
    "    return grouped\n",
    "\n",
    "# ----------------------------\n",
    "# Example usage\n",
    "# ----------------------------\n",
    "if __name__ == \"__main__\":\n",
    "    # Update this path\n",
    "    root_dir = \"\"\n",
    "    dataset = \"fMRI\"\n",
    "    series = 1\n",
    "\n",
    "    df_all = collect_all_metrics(root_dir, dataset, series)\n",
    "    \n",
    "    if not df_all.empty:\n",
    "        summary = summarize_metrics(df_all)\n",
    "        \n",
    "        print(\"\\nSummary over subjects:\")\n",
    "        pd.set_option('display.max_columns', None)\n",
    "        print(summary)\n",
    "    else:\n",
    "        print(\"No valid data found.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MPS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
