{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0746ba9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ⚙️ Configuration\n",
    "import os\n",
    "import pandas as pd\n",
    "from collections import defaultdict\n",
    "import random\n",
    "\n",
    "# Paths (update if needed)\n",
    "# Override without editing the notebook by setting the JET_ML_DATASET_ROOT\n",
    "# environment variable; otherwise the home directory ('~') is used.\n",
    "# Known dataset directories:\n",
    "#   ~/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\n",
    "#   ~/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_1000000_balanced_unshuffled/\n",
    "dataset_root = os.path.expanduser(\n",
    "    os.environ.get('JET_ML_DATASET_ROOT', '~')\n",
    ")\n",
    "file_labels_csv = os.path.join(dataset_root, 'file_labels.csv')\n",
    "\n",
    "# Reproducibility: seed Python's global RNG, which random.shuffle draws from\n",
    "# (e.g. when aggregate_file_labels shuffles each label's file paths).\n",
    "random_seed = 42\n",
    "random.seed(random_seed)\n",
    "\n",
    "# Aggregation settings\n",
    "group_size = 500\n",
    "agg_csv_out = os.path.join(\n",
    "    dataset_root,\n",
    "    f'file_labels_aggregated_g{group_size}.csv'\n",
    ")\n",
    "\n",
    "print(f\"Dataset root: {dataset_root}\")\n",
    "print(f\"Original labels CSV: {file_labels_csv}\")\n",
    "print(f\"Aggregated CSV will be saved to: {agg_csv_out}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe00c99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🧱 Build Aggregation CSV from file_labels.csv\n",
    "# %% Cell 2: Build Aggregation CSV\n",
    "def aggregate_file_labels(group_size=500, labels_csv=None, out_csv=None, seed=None):\n",
    "    \"\"\"\n",
    "    Aggregate file labels from the original CSV into groups of a fixed size.\n",
    "\n",
    "    Rows are bucketed by their (energy_loss, alpha, q0) label. Each bucket's\n",
    "    file paths are shuffled and cut into disjoint chunks of exactly\n",
    "    ``group_size`` paths; leftover paths that cannot fill a chunk are dropped.\n",
    "    Every chunk becomes one output row with columns\n",
    "    agg_id, file_paths ('|'-joined), energy_loss, alpha, q0.\n",
    "\n",
    "    Args:\n",
    "        group_size (int): Number of files per aggregated entry (must be > 0).\n",
    "        labels_csv (str, optional): Input CSV with file_path, energy_loss,\n",
    "            alpha and q0 columns. Defaults to the global ``file_labels_csv``.\n",
    "        out_csv (str, optional): Output CSV path. Defaults to the global\n",
    "            ``agg_csv_out``, whose name embeds the config cell's group_size,\n",
    "            so pass ``out_csv`` explicitly when using a different group_size.\n",
    "        seed (int, optional): Seed for a private RNG used for shuffling.\n",
    "            When None, Python's global ``random`` state is used.\n",
    "    Returns:\n",
    "        pd.DataFrame: The aggregated entries written to ``out_csv``.\n",
    "    Raises:\n",
    "        ValueError: If ``group_size`` is not positive.\n",
    "    \"\"\"\n",
    "    if group_size <= 0:\n",
    "        raise ValueError(f'group_size must be positive, got {group_size}')\n",
    "    labels_csv = file_labels_csv if labels_csv is None else labels_csv\n",
    "    out_csv = agg_csv_out if out_csv is None else out_csv\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "\n",
    "    df = pd.read_csv(labels_csv)\n",
    "\n",
    "    # Bucket paths by label; dict order follows each label's first appearance.\n",
    "    label_to_paths = defaultdict(list)\n",
    "    for energy_loss, alpha, q0, path in zip(\n",
    "        df['energy_loss'], df['alpha'], df['q0'], df['file_path']\n",
    "    ):\n",
    "        label_to_paths[(energy_loss, alpha, q0)].append(path)\n",
    "\n",
    "    agg_entries = []\n",
    "    for (energy_loss, alpha, q0), paths in label_to_paths.items():\n",
    "        rng.shuffle(paths)\n",
    "        # The range stops before a partial trailing chunk, so every group is full.\n",
    "        for start in range(0, len(paths) - group_size + 1, group_size):\n",
    "            agg_entries.append({\n",
    "                'agg_id': f'agg_{len(agg_entries):06d}',\n",
    "                'file_paths': '|'.join(paths[start:start + group_size]),\n",
    "                'energy_loss': energy_loss,\n",
    "                'alpha': alpha,\n",
    "                'q0': q0,\n",
    "            })\n",
    "\n",
    "    # Explicit columns keep the header row even when no group could be filled.\n",
    "    agg_df = pd.DataFrame(\n",
    "        agg_entries, columns=['agg_id', 'file_paths', 'energy_loss', 'alpha', 'q0']\n",
    "    )\n",
    "    agg_df.to_csv(out_csv, index=False)\n",
    "    print(f\"✅ Saved {len(agg_df)} aggregated entries to {out_csv}\")\n",
    "    return agg_df\n",
    "# aggregate_file_labels(group_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbdd3017",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 📦 Dataset Class with GPU Aggregation\n",
    "# %% Cell 3: AggregatedJetDataset Definition\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "import numpy as np\n",
    "\n",
    "class AggregatedJetDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Map-style dataset whose items are averages of groups of jet images.\n",
    "\n",
    "    Each row of the aggregated CSV lists '|'-separated .npy paths (relative\n",
    "    to ``root_dir``) and class labels in the energy_loss, alpha and q0\n",
    "    columns, which are returned as int64 tensors.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, agg_csv, root_dir, global_max, device='cuda'):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            agg_csv (str): Aggregated CSV with file_paths, energy_loss, alpha, q0 columns.\n",
    "            root_dir (str): Directory the relative .npy paths are resolved against.\n",
    "            global_max (float): Divisor applied to every image for normalization.\n",
    "            device (str): Device for returned tensors; CPU is used when CUDA is unavailable.\n",
    "        \"\"\"\n",
    "        self.df = pd.read_csv(agg_csv)\n",
    "        self.root_dir = root_dir\n",
    "        self.global_max = global_max\n",
    "        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.df)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Return ``(image, labels)`` for aggregated entry ``idx``.\n",
    "\n",
    "        image: float32 tensor holding the pixel-wise mean of the group's images,\n",
    "            each divided by ``global_max``. A 2-D (H, W) mean gets a leading\n",
    "            channel axis, giving (1, H, W) and DataLoader batches of (B, 1, H, W).\n",
    "        labels: dict of int64 tensors of shape (1,) keyed by output-head name.\n",
    "        \"\"\"\n",
    "        row = self.df.iloc[idx]\n",
    "        # Average on the CPU and move only the result to the device (one copy\n",
    "        # per item rather than one per file).\n",
    "        stacked = np.stack([\n",
    "            np.load(os.path.join(self.root_dir, rel_path)).astype(np.float32) / self.global_max\n",
    "            for rel_path in row['file_paths'].split('|')\n",
    "        ])\n",
    "        img_avg = torch.from_numpy(stacked.mean(axis=0).astype(np.float32))\n",
    "        if img_avg.dim() == 2:\n",
    "            img_avg = img_avg.unsqueeze(0)  # (H, W) -> (1, H, W)\n",
    "        img_avg = img_avg.to(self.device)\n",
    "        labels = {\n",
    "            'energy_loss_output': torch.tensor([row['energy_loss']], dtype=torch.long, device=self.device),\n",
    "            'alpha_output':       torch.tensor([row['alpha']],       dtype=torch.long, device=self.device),\n",
    "            'q0_output':          torch.tensor([row['q0']],          dtype=torch.long, device=self.device)\n",
    "        }\n",
    "        return img_avg, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "614cc08a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 📊 Stratified Split and DataLoader Builder\n",
    "# %% Cell 4: Stratified Split & DataLoader Builder\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "def stratified_split(df, test_frac=0.1, val_frac=0.1, seed=42):\n",
    "    \"\"\"\n",
    "    Split ``df`` into stratified train/val/test DataFrames.\n",
    "\n",
    "    Stratification uses the joint (energy_loss, alpha, q0) label so each label\n",
    "    combination keeps roughly the same proportion in every split.\n",
    "\n",
    "    Args:\n",
    "        df (pd.DataFrame): Must contain energy_loss, alpha and q0 columns.\n",
    "        test_frac (float): Fraction of all rows assigned to the test split.\n",
    "        val_frac (float): Fraction of all rows assigned to the validation split.\n",
    "        seed (int): random_state passed to both train_test_split calls.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (df_train, df_val, df_test), disjoint and together covering ``df``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a fraction is not positive or their sum is not below 1.\n",
    "    \"\"\"\n",
    "    if test_frac <= 0 or val_frac <= 0 or test_frac + val_frac >= 1:\n",
    "        raise ValueError(\n",
    "            f'need test_frac > 0, val_frac > 0 and test_frac + val_frac < 1, '\n",
    "            f'got test_frac={test_frac}, val_frac={val_frac}'\n",
    "        )\n",
    "    label_cols = ['energy_loss', 'alpha', 'q0']\n",
    "    holdout_frac = test_frac + val_frac\n",
    "\n",
    "    y = df[label_cols].astype(str).agg('_'.join, axis=1)\n",
    "    df_train, df_temp = train_test_split(df, test_size=holdout_frac, stratify=y, random_state=seed)\n",
    "\n",
    "    # The second split's \"test\" part is the test share of the holdout, i.e.\n",
    "    # test_frac / holdout_frac; using val_frac here would swap val and test sizes.\n",
    "    y_temp = df_temp[label_cols].astype(str).agg('_'.join, axis=1)\n",
    "    df_val, df_test = train_test_split(\n",
    "        df_temp, test_size=test_frac / holdout_frac, stratify=y_temp, random_state=seed\n",
    "    )\n",
    "    return df_train, df_val, df_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44540427",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_split_csvs(\n",
    "    agg_csv: str,\n",
    "    root_dir: str,\n",
    "    test_frac: float = 0.1,\n",
    "    val_frac: float = 0.1,\n",
    "    seed: int = 42\n",
    ") -> tuple[str, str, str]:\n",
    "    \"\"\"\n",
    "    Split an aggregated CSV with stratified_split and save each part to disk.\n",
    "\n",
    "    The output files reuse the input file's stem (which carries the group\n",
    "    size, e.g. file_labels_aggregated_g500) plus a _train/_val/_test suffix\n",
    "    and are written into ``root_dir`` without an index column.\n",
    "\n",
    "    Returns:\n",
    "        tuple[str, str, str]: Paths of the train, val and test CSVs.\n",
    "    \"\"\"\n",
    "    parts = stratified_split(\n",
    "        pd.read_csv(agg_csv), test_frac=test_frac, val_frac=val_frac, seed=seed\n",
    "    )\n",
    "    stem = os.path.splitext(os.path.basename(agg_csv))[0]\n",
    "\n",
    "    written = []\n",
    "    for part_df, suffix in zip(parts, ('train', 'val', 'test')):\n",
    "        out_path = os.path.join(root_dir, f'{stem}_{suffix}.csv')\n",
    "        part_df.to_csv(out_path, index=False)\n",
    "        written.append(out_path)\n",
    "    return tuple(written)\n",
    "def build_dataloaders_from_splits(\n",
    "    train_csv: str,\n",
    "    val_csv:   str,\n",
    "    test_csv:  str,\n",
    "    root_dir:  str,\n",
    "    global_max: float,\n",
    "    batch_size: int = 32,\n",
    "    device: str = 'cuda'\n",
    ") -> tuple[DataLoader, DataLoader, DataLoader]:\n",
    "    \"\"\"\n",
    "    Given filepaths to the train/val/test CSVs (generated by create_split_csvs),\n",
    "    instantiate AggregatedJetDataset and return the three DataLoaders.\n",
    "\n",
    "    Args:\n",
    "        train_csv, val_csv, test_csv (str): Paths to the split CSV files.\n",
    "        root_dir (str): Directory the .npy paths inside the CSVs are relative to.\n",
    "        global_max (float): Normalization divisor passed to each dataset.\n",
    "        batch_size (int): Batch size shared by all three loaders.\n",
    "        device (str): Device forwarded to every AggregatedJetDataset; the\n",
    "            default 'cuda' matches the dataset's own default, 'cpu' keeps\n",
    "            tensors on the host.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (train_loader, val_loader, test_loader); only the training\n",
    "        loader shuffles.\n",
    "    \"\"\"\n",
    "    # Create Datasets\n",
    "    train_ds = AggregatedJetDataset(train_csv, root_dir, global_max, device=device)\n",
    "    val_ds   = AggregatedJetDataset(val_csv,   root_dir, global_max, device=device)\n",
    "    test_ds  = AggregatedJetDataset(test_csv,  root_dir, global_max, device=device)\n",
    "\n",
    "    # Build DataLoaders. num_workers stays at its default (0): the datasets\n",
    "    # create tensors on `device` inside __getitem__, which may fail in forked\n",
    "    # worker processes when that device is CUDA.\n",
    "    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
    "    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)\n",
    "    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    return train_loader, val_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9454a0bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ✅ Test DataLoader Pipeline\n",
    "# %% Cell 5: Quick Test\n",
    "\n",
    "# train_csv, val_csv, test_csv=create_split_csvs(\n",
    "#     agg_csv_out,\n",
    "#     dataset_root,\n",
    "#     test_frac=0.1,\n",
    "#     val_frac=0.1,\n",
    "#     seed=42\n",
    "# )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "376618b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_loader, val_loader, test_loader = build_dataloaders_from_splits(\n",
    "#     train_csv,\n",
    "#     val_csv,\n",
    "#     test_csv,\n",
    "#     dataset_root,\n",
    "#     global_max=121.79151153564453,\n",
    "#     batch_size=512\n",
    "# ) \n",
    "# x, y = next(iter(train_loader))\n",
    "# print(\"Batch images:\", x.shape)\n",
    "# for k, v in y.items():\n",
    "#     print(f\"{k}: {v.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f7b360",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Function to Load & Stratified-Split Aggregated CSV\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "def load_and_stratified_split(csv_path, test_frac=0.1, val_frac=0.1, seed=42):\n",
    "    \"\"\"\n",
    "    Read an aggregated CSV (e.g. file_labels_aggregated_g500.csv) and return\n",
    "    stratified train/val/test splits.\n",
    "\n",
    "    The stratification key is the energy_loss, alpha and q0 values joined as\n",
    "    strings with '_'. A first split separates train from a holdout of\n",
    "    test_frac + val_frac; a second split divides that holdout into val and test.\n",
    "\n",
    "    Args:\n",
    "        csv_path (str): Path to the aggregated CSV.\n",
    "        test_frac (float): Fraction of data to be used for testing.\n",
    "        val_frac (float): Fraction of data to be used for validation.\n",
    "        seed (int): random_state for both train_test_split calls.\n",
    "\n",
    "    Returns:\n",
    "        df_train, df_val, df_test (pd.DataFrame): Split DataFrames\n",
    "    \"\"\"\n",
    "    def _strat_key(frame):\n",
    "        # One string per row combining the three target columns.\n",
    "        return frame[['energy_loss', 'alpha', 'q0']].astype(str).agg('_'.join, axis=1)\n",
    "\n",
    "    data = pd.read_csv(csv_path)\n",
    "    holdout_frac = test_frac + val_frac\n",
    "\n",
    "    train_part, holdout_part = train_test_split(\n",
    "        data, test_size=holdout_frac, stratify=_strat_key(data), random_state=seed\n",
    "    )\n",
    "    # Test share of the holdout, computed as 1 - val / (val + test).\n",
    "    val_part, test_part = train_test_split(\n",
    "        holdout_part,\n",
    "        test_size=1 - val_frac / holdout_frac,\n",
    "        stratify=_strat_key(holdout_part),\n",
    "        random_state=seed,\n",
    "    )\n",
    "    return train_part, val_part, test_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "554a8167",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stratified split of the aggregated CSV (fractions and seed spelled out)\n",
    "csv_path = agg_csv_out\n",
    "df_train, df_val, df_test = load_and_stratified_split(\n",
    "    csv_path, test_frac=0.1, val_frac=0.1, seed=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "063bb0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One output path per split, named after this group size's aggregated CSV\n",
    "train_csv_out, val_csv_out, test_csv_out = (\n",
    "    os.path.join(dataset_root, f'file_labels_aggregated_g{group_size}_{split}.csv')\n",
    "    for split in ('train', 'val', 'test')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61556961",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist each split without the pandas index column\n",
    "df_train.to_csv(path_or_buf=train_csv_out, index=False)\n",
    "df_val.to_csv(path_or_buf=val_csv_out, index=False)\n",
    "df_test.to_csv(path_or_buf=test_csv_out, index=False)\n",
    "\n",
    "print(f\"Saved:\\n{train_csv_out}\\n{val_csv_out}\\n{test_csv_out}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fa77888",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Cell 6: Enhanced Single Plotter with Hist2D Style\n",
    "def plot_single_jet(x, y, idx=0):\n",
    "    \"\"\"\n",
    "    Plot one jet image from a batch, titled with human-readable labels.\n",
    "\n",
    "    Args:\n",
    "        x (torch.Tensor | np.ndarray): Batch of images; ``x[idx]`` must squeeze\n",
    "            to a 2-D (H, W) image, e.g. a (B, 1, 32, 32) DataLoader batch.\n",
    "        y (dict): Label batch with 'energy_loss_output', 'alpha_output' and\n",
    "            'q0_output'; ``y[key][idx]`` must hold a single class index.\n",
    "        idx (int): Index of the sample to plot (default: the first one).\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the selected image has no positive pixels, because a\n",
    "            logarithmic colour scale cannot be drawn for it.\n",
    "    \"\"\"\n",
    "    import math\n",
    "    import numpy as np\n",
    "\n",
    "    # Maps from class index to real parameter values\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    img = x[idx].squeeze()\n",
    "    if hasattr(img, 'cpu'):  # torch tensor, possibly on the GPU\n",
    "        img = img.cpu().numpy()\n",
    "    img = np.asarray(img)\n",
    "\n",
    "    # Human-readable labels for the selected sample\n",
    "    e_str = energy_map[y['energy_loss_output'][idx].item()]\n",
    "    α = alpha_vals[y['alpha_output'][idx].item()]\n",
    "    Q0 = q0_vals[y['q0_output'][idx].item()]\n",
    "\n",
    "    # LogNorm cannot display non-positive values, so mask them out.\n",
    "    img_masked = np.ma.masked_less_equal(img, 0)\n",
    "    if img_masked.count() == 0:\n",
    "        raise ValueError('Cannot plot on a log scale: the image has no positive pixels')\n",
    "\n",
    "    # Plotting libraries are imported only once the input is known to be plottable.\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "\n",
    "    # Bin edges spanning [-π, π] on both axes\n",
    "    x_edges = np.linspace(-math.pi, math.pi, img.shape[1] + 1)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, img.shape[0] + 1)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5, 5), dpi=200)\n",
    "    pcm = ax.pcolormesh(\n",
    "        x_edges, y_edges, img_masked,\n",
    "        norm=colors.LogNorm(vmin=img_masked.min(), vmax=img_masked.max()),\n",
    "        cmap='jet', shading='auto'\n",
    "    )\n",
    "    fig.colorbar(pcm, ax=ax, label='Normalized Intensity')\n",
    "    ax.set_title(f'{e_str}, αₛ={α}, Q₀={Q0}', fontsize=12)\n",
    "\n",
    "    ticks = [-math.pi, 0, math.pi]\n",
    "    tick_labels = [r'$-\\pi$', '0', r'$\\pi$']\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels(tick_labels)\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels(tick_labels)\n",
    "    ax.set_xlabel('X (φ)')\n",
    "    ax.set_ylabel('Y (η)')\n",
    "    plt.show()\n",
    "# plot_single_jet(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c21800",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Cell 7: Grid of Hist2D Plots with Real Labels (12×10)\n",
    "def _load_grid_rows(df, root_dir, global_max, combos, n_cols, seed):\n",
    "    \"\"\"\n",
    "    Sample ``n_cols`` aggregated entries per (alpha, q0) combo and average their images.\n",
    "\n",
    "    Returns one list per combo of ``(energy_loss, masked_avg)`` pairs, where\n",
    "    ``masked_avg`` is the mean of the entry's .npy images divided by\n",
    "    ``global_max``, with non-positive pixels masked because a logarithmic\n",
    "    colour scale cannot show them. Combos without matching rows yield [].\n",
    "    \"\"\"\n",
    "    import os\n",
    "    import numpy as np\n",
    "\n",
    "    rows = []\n",
    "    for a_idx, q_idx in combos:\n",
    "        subset = df[(df['alpha'] == a_idx) & (df['q0'] == q_idx)]\n",
    "        panels = []\n",
    "        if not subset.empty:\n",
    "            # Sample with replacement only when there are fewer entries than columns.\n",
    "            samples = subset.sample(n=n_cols, replace=len(subset) < n_cols, random_state=seed)\n",
    "            for _, entry in samples.iterrows():\n",
    "                imgs = [\n",
    "                    np.load(os.path.join(root_dir, p)).astype(np.float32) / global_max\n",
    "                    for p in entry['file_paths'].split('|')\n",
    "                ]\n",
    "                avg = np.mean(imgs, axis=0)\n",
    "                panels.append((entry['energy_loss'], np.ma.masked_less_equal(avg, 0)))\n",
    "        rows.append(panels)\n",
    "    return rows\n",
    "\n",
    "\n",
    "def plot_grid_hist2d(agg_csv, root_dir, global_max, n_cols=10, seed=0):\n",
    "    \"\"\"\n",
    "    Plot a grid of averaged jet images with one row per (alpha, q0) combination.\n",
    "\n",
    "    All panels share one logarithmic colour scale, so the single colorbar\n",
    "    applies to every panel.\n",
    "\n",
    "    Args:\n",
    "        agg_csv (str): Path to the aggregated CSV file.\n",
    "        root_dir (str): Directory the CSV's relative .npy paths are resolved against.\n",
    "        global_max (float): Global maximum value for normalization.\n",
    "        n_cols (int): Samples shown per row (default 10).\n",
    "        seed (int): random_state used when sampling each row (default 0).\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If none of the sampled images has a positive pixel.\n",
    "    \"\"\"\n",
    "    import math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    from itertools import product\n",
    "\n",
    "    # Reverse-maps from class index to real values\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    # All (alpha_idx, q0_idx) combos → 3×4 = 12 rows\n",
    "    combos = list(product([0, 1, 2], [0, 1, 2, 3]))\n",
    "    rows = _load_grid_rows(pd.read_csv(agg_csv), root_dir, global_max, combos, n_cols, seed)\n",
    "\n",
    "    # One LogNorm shared by every panel keeps the single colorbar truthful.\n",
    "    positive = [img.compressed() for row in rows for _, img in row if img.count()]\n",
    "    if not positive:\n",
    "        raise ValueError(f'No positive pixels to plot for {agg_csv}')\n",
    "    vmin = min(vals.min() for vals in positive)\n",
    "    vmax = max(vals.max() for vals in positive)\n",
    "\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    norm = colors.LogNorm(vmin=vmin, vmax=vmax)\n",
    "\n",
    "    n_rows = len(combos)\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols*1.5, n_rows*1.2),\n",
    "        sharex='col', sharey='row',\n",
    "        dpi=200,\n",
    "        squeeze=False  # keep a 2-D axes array even when n_cols == 1\n",
    "    )\n",
    "\n",
    "    # Tight layout\n",
    "    fig.subplots_adjust(\n",
    "        left   = 0.15,  # room for row labels\n",
    "        right  = 0.97,\n",
    "        top    = 0.96,\n",
    "        bottom = 0.02,\n",
    "        hspace = 0.2,\n",
    "        wspace = 0.1\n",
    "    )\n",
    "\n",
    "    pcm = None\n",
    "    for i, ((a_idx, q_idx), panels) in enumerate(zip(combos, rows)):\n",
    "        for j, (_, avg_masked) in enumerate(panels):\n",
    "            ax = axes[i, j]\n",
    "            h, w = avg_masked.shape\n",
    "            # Bin edges over [-π, π] sized from the image itself\n",
    "            pcm = ax.pcolormesh(\n",
    "                np.linspace(-math.pi, math.pi, w + 1),\n",
    "                np.linspace(-math.pi, math.pi, h + 1),\n",
    "                avg_masked,\n",
    "                norm=norm, cmap='jet', shading='auto'\n",
    "            )\n",
    "            ax.set_xticks([]); ax.set_yticks([])\n",
    "        # Row label lists every energy-loss module actually sampled in this row.\n",
    "        energies = sorted({e for e, _ in panels})\n",
    "        e_str = ' / '.join(energy_map[e] for e in energies) or 'no samples'\n",
    "        axes[i, 0].text(-0.35, 0.5,\n",
    "                        f'{e_str}\\nαₛ={alpha_vals[a_idx]}\\nQ₀={q0_vals[q_idx]}',\n",
    "                        transform=axes[i, 0].transAxes,\n",
    "                        va='center', ha='right',\n",
    "                        fontsize=8)\n",
    "\n",
    "    # Shared ticks bottom row & left column\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks([-math.pi, 0, math.pi])\n",
    "        ax.set_xticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=6)\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks([-math.pi, 0, math.pi])\n",
    "        ax.set_yticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=6)\n",
    "\n",
    "    # Colorbar: every panel uses the same norm, so any mesh represents them all.\n",
    "    cbar = fig.colorbar(pcm, ax=axes, fraction=0.015, pad=0.01)\n",
    "    cbar.set_label('Normalized Intensity', fontsize=8)\n",
    "\n",
    "    fig.suptitle(f'{n_cols} Aggregated Samples per (αₛ, Q₀) – Hist2D, X,Y ∈ [-π,π]', y=0.995, fontsize=12)\n",
    "    plt.show()\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
