{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aff62ef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CELL: KFold split builder (shared imports)\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import StratifiedKFold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "789a7895",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "# NOTE(review): torch / nn are not used by any cell in this notebook.\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from config import get_config\n",
    "\n",
    "# Experiment config for the k-fold CV run.\n",
    "# NOTE(review): this path is absolute from the filesystem root; confirm it resolves\n",
    "# in the runtime environment.\n",
    "CONFIG_PATH = (\"/experiments/exp_kfold_pretrained_models/config/\"\n",
    "               \"cross_validation_kfold_datasetrcalc.yml\")\n",
    "cfg = get_config(config_path=CONFIG_PATH)\n",
    "\n",
    "# default=str: print non-JSON-serializable config values (e.g. pathlib.Path) as\n",
    "# strings instead of raising TypeError and aborting the cell.\n",
    "print(json.dumps(vars(cfg), indent=2, default=str))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e398ce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fold output location, derived from the loaded config object `cfg`.\n",
    "# cfg fields read across these cells: dataset_root_dir, train_csv, val_csv, test_csv\n",
    "# (group_size / global_max are mentioned by the loader config but not read here).\n",
    "\n",
    "# Split CSV helpers from the project loader; fold CSVs reuse its existing column schema.\n",
    "from data.loader import load_split_from_csv, save_split_to_csv\n",
    "\n",
    "ROOT = Path(cfg.dataset_root_dir)\n",
    "FOLDS_DIR = ROOT.joinpath(\"folds\")\n",
    "# Without parents=True this raises FileNotFoundError when dataset_root_dir is\n",
    "# missing, which surfaces a wrong root path early.\n",
    "FOLDS_DIR.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "151ab26a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the three base splits; train and val are pooled for CV below, test stays fixed.\n",
    "train_base, val_base, test_base = (\n",
    "    load_split_from_csv(getattr(cfg, split_key), cfg.dataset_root_dir)\n",
    "    for split_key in (\"train_csv\", \"val_csv\", \"test_csv\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6862deec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d426697",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge train+val into one pool for k-fold CV; the test split stays fixed.\n",
    "trainval = train_base + val_base\n",
    "\n",
    "# StratifiedKFold.split only uses X for its length, so a placeholder is sufficient.\n",
    "# This avoids np.array() over the file paths, which fails if per-sample paths are\n",
    "# ragged lists (NOTE(review): confirm what load_split_from_csv returns as fp).\n",
    "X = np.zeros((len(trainval), 1))\n",
    "\n",
    "def pack_label(lbl_tuple, n_alpha=3, n_q0=4):\n",
    "    \"\"\"Pack an (energy, alpha, q0) index tuple into one class id for stratification.\n",
    "\n",
    "    Mixed-radix encoding: e * (n_alpha * n_q0) + a * n_q0 + q. With the defaults\n",
    "    (energy 2, alpha 3, q0 4) this gives at most 24 classes.\n",
    "\n",
    "    Raises ValueError when an index is outside its radix, because such a value\n",
    "    would silently collide with a different label tuple's class id.\n",
    "    \"\"\"\n",
    "    e, a, q = lbl_tuple\n",
    "    if e < 0 or not 0 <= a < n_alpha or not 0 <= q < n_q0:\n",
    "        raise ValueError(f\"label {lbl_tuple!r} out of range for n_alpha={n_alpha}, n_q0={n_q0}\")\n",
    "    return e * (n_alpha * n_q0) + a * n_q0 + q\n",
    "\n",
    "y = np.array([pack_label(lbl) for _, lbl in trainval])\n",
    "\n",
    "n_splits = getattr(cfg, \"n_splits\", 5)\n",
    "random_state = getattr(cfg, \"random_seed\", 42)\n",
    "\n",
    "skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
    "\n",
    "fold_paths = []\n",
    "for k, (tr_idx, va_idx) in enumerate(skf.split(X, y), start=1):\n",
    "    fold_dir = FOLDS_DIR / f\"fold_{k}\"\n",
    "    fold_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    tr_list = [trainval[i] for i in tr_idx]\n",
    "    va_list = [trainval[i] for i in va_idx]\n",
    "\n",
    "    # Save CSVs relative to ROOT; the same test split is written into every fold.\n",
    "    tr_csv = fold_dir / \"train.csv\"\n",
    "    va_csv = fold_dir / \"val.csv\"\n",
    "    te_csv = fold_dir / \"test.csv\"\n",
    "\n",
    "    save_split_to_csv(tr_list, str(tr_csv), cfg.dataset_root_dir)\n",
    "    save_split_to_csv(va_list, str(va_csv), cfg.dataset_root_dir)\n",
    "    save_split_to_csv(test_base, str(te_csv), cfg.dataset_root_dir)\n",
    "\n",
    "    fold_paths.append({\"k\": k, \"train_csv\": str(tr_csv), \"val_csv\": str(va_csv), \"test_csv\": str(te_csv)})\n",
    "\n",
    "print(f\"[OK] Built {n_splits} folds at {FOLDS_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2396ffe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CELL 0 — Setup (paths & params)\n",
    "import os\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "# ==== CONFIG: EDIT THESE ====\n",
    "# DATA_ROOT can be overridden with the JET_ML_DATA_ROOT environment variable, so the\n",
    "# notebook runs on other machines without editing the hard-coded default below.\n",
    "DATA_ROOT = Path(os.environ.get(\n",
    "    \"JET_ML_DATA_ROOT\",\n",
    "    \"/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled\",\n",
    "))\n",
    "\n",
    "TRAIN_CSV = DATA_ROOT / \"file_labels_aggregated_ds1008_g500_train.csv\"\n",
    "VAL_CSV   = DATA_ROOT / \"file_labels_aggregated_ds1008_g500_val.csv\"\n",
    "\n",
    "# Output directory for folds.\n",
    "# NOTE: this rebinds FOLDS_DIR, which the earlier cfg-driven cells set to <dataset_root_dir>/folds.\n",
    "FOLDS_DIR = DATA_ROOT / \"folds_agg_ds1008_g500\"\n",
    "N_SPLITS = 5\n",
    "RANDOM_SEED = 42\n",
    "SHUFFLE_BEFORE_SPLIT = True  # set False to keep incoming order\n",
    "# ============================\n",
    "FOLDS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "print(f\"[INFO] Folds will be written to: {FOLDS_DIR}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5295511a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CELL 1 — Load aggregated CSVs (train + val) and sanity-check\n",
    "def _assert_columns(df: pd.DataFrame, file:str):\n",
    "    \"\"\"Raise ValueError if `df` lacks any column of the aggregated-CSV schema.\n",
    "\n",
    "    `file` is only used to make the error message name the offending CSV.\n",
    "    \"\"\"\n",
    "    required = {\"agg_id\", \"file_paths\", \"energy_loss\", \"alpha\", \"q0\"}\n",
    "    missing = required - set(df.columns)\n",
    "    if missing:\n",
    "        raise ValueError(f\"{file} is missing columns: {sorted(missing)}\")\n",
    "\n",
    "def _coerce_types(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Return a copy of `df` with int64 label columns and string id/path columns.\n",
    "\n",
    "    Label columns (energy_loss, alpha, q0) are parsed with pd.to_numeric, which raises\n",
    "    on non-numeric strings. Missing or non-whole-number labels raise ValueError rather\n",
    "    than silently staying float. The input frame is not modified.\n",
    "    \"\"\"\n",
    "    df = df.copy()\n",
    "    for c in [\"energy_loss\", \"alpha\", \"q0\"]:\n",
    "        values = pd.to_numeric(df[c])\n",
    "        # downcast=\"integer\" alone leaves NaN / fractional values as float without error\n",
    "        if values.isna().any() or not (values % 1 == 0).all():\n",
    "            raise ValueError(f\"column {c!r} must hold whole-number class indices\")\n",
    "        df[c] = values.astype(\"int64\")\n",
    "    # Keep agg_id as a string\n",
    "    df[\"agg_id\"] = df[\"agg_id\"].astype(str)\n",
    "    # file_paths must remain the pipe-separated string\n",
    "    df[\"file_paths\"] = df[\"file_paths\"].astype(str)\n",
    "    return df\n",
    "\n",
    "# Read both aggregated CSVs, then validate and normalize their dtypes.\n",
    "train_df, val_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV))\n",
    "\n",
    "_assert_columns(train_df, str(TRAIN_CSV))\n",
    "_assert_columns(val_df, str(VAL_CSV))\n",
    "\n",
    "train_df, val_df = map(_coerce_types, (train_df, val_df))\n",
    "print(f\"[INFO] Loaded aggregated CSVs: train={len(train_df)}, val={len(val_df)}\")\n",
    "\n",
    "# Pool train+val into one frame; optionally shuffle it (seeded) before fold assignment.\n",
    "all_df = pd.concat([train_df, val_df], ignore_index=True)\n",
    "all_df = (\n",
    "    all_df.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)\n",
    "    if SHUFFLE_BEFORE_SPLIT\n",
    "    else all_df\n",
    ")\n",
    "print(f\"[INFO] Combined rows (train+val): {len(all_df)}\")\n",
    "all_df.head()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
