{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0f5cbe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.utils import resample\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# --- Stratified 3-way split function ---\n",
    "def stratified_split(df, test_frac=0.1, val_frac=0.1, seed=42):\n",
    "    y = df[['energy_loss','alpha','q0']].astype(str).agg('_'.join, axis=1)\n",
    "    df_train, df_temp = train_test_split(df, test_size=test_frac+val_frac, stratify=y, random_state=seed)\n",
    "    y_temp = df_temp[['energy_loss','alpha','q0']].astype(str).agg('_'.join, axis=1)\n",
    "    df_val, df_test = train_test_split(df_temp,\n",
    "                                       test_size=val_frac/(test_frac+val_frac),\n",
    "                                       stratify=y_temp,\n",
    "                                       random_state=seed)\n",
    "    return df_train, df_val, df_test\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f522756",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Config ---\n",
    "INPUT_CSV = \"/data/\" \\\n",
    "\"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\" \\\n",
    "\"file_labels_aggregated_ds7200000_g500.csv\"\n",
    "OUTPUT_CSV = \"/data/\" \\\n",
    "\"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\" \\\n",
    "\"file_labels_aggregated_ds1008_g500.csv\"\n",
    "\n",
    "TRAIN_CSV = \"/data/\" \\\n",
    "\"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\" \\\n",
    "\"file_labels_aggregated_ds1008_g500_train.csv\"\n",
    "\n",
    "VAL_CSV = \"/data/\" \\\n",
    "\"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\" \\\n",
    "\"file_labels_aggregated_ds1008_g500_val.csv\"\n",
    "\n",
    "TEST_CSV = \"/data/\" \\\n",
    "\"jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\" \\\n",
    "\"file_labels_aggregated_ds1008_g500_test.csv\"\n",
    "TARGET_TOTAL = 1008\n",
    "SEED = 42\n",
    "\n",
    "# --- Load full data ---\n",
    "df = pd.read_csv(INPUT_CSV)\n",
    "\n",
    "# --- Create balanced 1000-sample dataset (equal per label combo) ---\n",
    "label_cols = ['energy_loss', 'alpha', 'q0']\n",
    "df['label_combo'] = df[label_cols].astype(str).agg('_'.join, axis=1)\n",
    "n_classes = df['label_combo'].nunique()\n",
    "samples_per_class = TARGET_TOTAL // n_classes\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2079395",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(samples_per_class)\n",
    "print(n_classes)\n",
    "TARGET_TOTAL % n_classes "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5eace13",
   "metadata": {},
   "outputs": [],
   "source": [
    "if TARGET_TOTAL % n_classes != 0:\n",
    "    raise ValueError(f\"{TARGET_TOTAL} is not divisible by {n_classes} unique label combinations.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "195e5250",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_balanced = (\n",
    "    df.groupby('label_combo', group_keys=False)\n",
    "    .apply(lambda g: resample(g, replace=True, n_samples=samples_per_class, random_state=SEED))\n",
    "    .reset_index(drop=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef47ef92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show df_balanced['label_combo'].value_counts()\n",
    "print(df_balanced['label_combo'].value_counts())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b40fbbd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_balanced = df_balanced.drop(columns=['label_combo'])\n",
    "df_balanced.to_csv(OUTPUT_CSV, index=False)\n",
    "print(f\"✅ saved:\")\n",
    "print(f\"  → Dataset: {len(df_balanced)} rows → {OUTPUT_CSV}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f318bf85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Stratified split into train/val/test ---\n",
    "df_train, df_val, df_test = stratified_split(df_balanced, test_frac=0.1, val_frac=0.1, seed=SEED)\n",
    "\n",
    "# --- Save to disk ---\n",
    "df_train.to_csv(TRAIN_CSV, index=False)\n",
    "df_val.to_csv(VAL_CSV, index=False)\n",
    "df_test.to_csv(TEST_CSV, index=False)\n",
    "\n",
    "print(f\"✅ Train/val split completed and saved:\")\n",
    "print(f\"  → Train: {len(df_train)} rows → {TRAIN_CSV}\")\n",
    "print(f\"  → Val:   {len(df_val)} rows   → {VAL_CSV}\")\n",
    "print(f\"  → Test:  {len(df_test)} rows   → {TEST_CSV}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
