{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aefdb534",
   "metadata": {},
   "source": [
    "\n",
    "# Figure 1 (g500, 1008 dataset): 4×3 Grid of Sample Events\n",
    "\n",
    "This notebook:\n",
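    "- Loads the experiment config, builds the **1008 g500** train/val/test dataloaders, and reads the aggregated **train** CSV (`file_labels_aggregated_ds1008_g500_train.csv`) used by the plots below."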
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30687fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the project root (two levels above this notebook, where config.py lives)\n",
    "# at the front of sys.path so the project imports in the next cell resolve.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.insert(0, parent_dir)\n",
    "\n",
    "print(\"Added to sys.path:\", parent_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1618c456",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import json\n",
    "\n",
    "from config import get_config\n",
    "from train_utils.gpu_utils import get_device_summary\n",
    "from data.loader import get_dataloaders\n",
    "from train_utils.resume import init_resume_state\n",
    "from train_utils.resume import fill_trackers_from_history\n",
    "from train_utils.resume import load_pretrained_model\n",
    "from train_utils.training_loop import run_training_loop\n",
    "from train_utils.scheduler_utils import create_scheduler\n",
    "from train_utils.training_summary import finalize_training_summary\n",
    "from train_utils.training_summary import print_best_model_summary\n",
    "from train_utils.plot_metrics import plot_train_val_metrics\n",
    "from train_utils.plot_metrics import plot_loss_accuracy\n",
    "from train_utils.plot_metrics import plot_confusion_matrices\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91280b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "from models.model_vit import create_model\n",
    "\n",
    "# Experiment config for this figure (ViT-tiny, ds1008, g500). The path string is\n",
    "# passed through unchanged, including its leading \"/\"; presumably get_config\n",
    "# resolves it against the project root (TODO confirm).\n",
    "CONFIG_PATH = (\n",
    "    \"/\"\n",
    "    \"experiments/exp_plotting_sample_g500_events_4x3/config/\"\n",
    "    \"vit_tiny_patch16_224_gaussian_bs32_ep50_lr1e-04_p12_ds1008_g500_sched-RLRP_preload_p4.yml\"\n",
    ")\n",
    "cfg = get_config(config_path=CONFIG_PATH)\n",
    "\n",
    "# default=str keeps the dump from raising if a config value is not JSON-serializable.\n",
    "print(json.dumps(vars(cfg), indent=2, default=str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574e8ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the configured output directory exists (no-op if it already does).\n",
    "# NOTE(review): the figure cells below call savefig with bare file names, i.e.\n",
    "# relative to the current working directory rather than cfg.output_dir.\n",
    "os.makedirs(name=cfg.output_dir, exist_ok=True)\n",
    "print(\"[INFO] Saving all outputs to: \" + str(cfg.output_dir))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38bc0afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the compute device through the project helper; the result is handed\n",
    "# to get_dataloaders in the next cell.\n",
    "device = get_device_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37be9c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data: build the train / validation / test loaders for the configured dataset.\n",
    "(train_loader, val_loader, test_loader) = get_dataloaders(\n",
    "    cfg,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ff94323",
   "metadata": {},
   "source": [
    "\n",
    "- Picks **one sample** for each combination of **$Q_0 \\in \\{1.0, 1.5, 2.0, 2.5\\}$** and **$\\alpha_s \\in \\{0.2, 0.3, 0.4\\}$**.\n",
    "- Builds a **4×3** grid (rows = $Q_0$; cols = $\\alpha_s$) with matplotlib subplots, drawing each cell's averaged jet image with `pcolormesh` on a shared logarithmic color scale (`jet` colormap).\n",
    "- Saves the grid as `jet_images_g500_4x3_one_sample_per_combo.png`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c23e6de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Cell 6: Enhanced Single Plotter with Hist2D Style\n",
    "def plot_single_jet(x, y):\n",
    "    \"\"\"\n",
    "    Plot the first jet image of a batch with human-readable labels.\n",
    "\n",
    "    Args:\n",
    "        x (torch.Tensor | np.ndarray): Image batch. ``x[0]`` is squeezed to a\n",
    "            2-D (H, W) image, e.g. a (B, 1, 32, 32) batch yields a 32x32 image.\n",
    "        y (dict): Label batch with keys 'energy_loss_output', 'alpha_output'\n",
    "            and 'q0_output'; element 0 of each is the class index of ``x[0]``.\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a label index is not one of the known classes.\n",
    "    \"\"\"\n",
    "    import math\n",
    "    import numpy as np\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "\n",
    "    # Maps from class index to real parameter values\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    # First sample of the batch as a 2-D NumPy array (moved off the device if needed)\n",
    "    img = x[0].squeeze()\n",
    "    if hasattr(img, 'cpu'):\n",
    "        img = img.detach().cpu().numpy()\n",
    "    img = np.asarray(img)\n",
    "\n",
    "    # Human-readable labels of the first sample\n",
    "    e_str = energy_map[y['energy_loss_output'][0].item()]\n",
    "    α = alpha_vals[y['alpha_output'][0].item()]\n",
    "    Q0 = q0_vals[y['q0_output'][0].item()]\n",
    "\n",
    "    # Mask empty pixels. LogNorm needs strictly positive limits, so take them from\n",
    "    # the positive pixels only (dummy range for an image with no positive pixel).\n",
    "    img_masked = np.ma.masked_where(img == 0, img)\n",
    "    positive = img[img > 0]\n",
    "    vmin = float(positive.min()) if positive.size else 1e-6\n",
    "    vmax = float(positive.max()) if positive.size else 1.0\n",
    "\n",
    "    # Bin edges spanning [-π, π] on both axes\n",
    "    x_edges = np.linspace(-math.pi, math.pi, img.shape[1] + 1)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, img.shape[0] + 1)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5, 5), dpi=200)\n",
    "    pcm = ax.pcolormesh(\n",
    "        x_edges, y_edges, img_masked,\n",
    "        norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "        cmap='jet', shading='auto'\n",
    "    )\n",
    "    fig.colorbar(pcm, ax=ax, label='Normalized Intensity')\n",
    "    ax.set_title(f'{e_str}, αₛ={α}, Q₀={Q0}', fontsize=12)\n",
    "\n",
    "    # Same [-π, 0, π] ticks on both axes\n",
    "    ticks = [-math.pi, 0, math.pi]\n",
    "    tick_labels = [r'$-\\pi$', '0', r'$\\pi$']\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels(tick_labels)\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels(tick_labels)\n",
    "    ax.set_xlabel('X (φ)')\n",
    "    ax.set_ylabel('Y (η)')\n",
    "    plt.show()\n",
    "# plot_single_jet(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76bce7f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Cell 7: Grid of Hist2D Plots with Real Labels (12×10)\n",
    "def plot_grid_hist2d(agg_csv=None, root_dir=None, global_max=None):\n",
    "    \"\"\"\n",
    "    Plot a 12×10 grid of averaged hist2D jet images: one row per (alpha, q0)\n",
    "    combination, ten sampled aggregated entries per row.\n",
    "\n",
    "    All panels share one logarithmic color scale, so the single colorbar is\n",
    "    valid for every panel. The figure is saved as PNG and PDF in the current\n",
    "    working directory.\n",
    "\n",
    "    Args:\n",
    "        agg_csv (str, optional): Aggregated CSV path, resolved against the\n",
    "            dataset root when relative. Defaults to the ds1008 g500 train CSV.\n",
    "        root_dir (str, optional): Dataset root that the CSV's 'file_paths'\n",
    "            are relative to. Defaults to the local benchmark dataset directory.\n",
    "        global_max (float, optional): Every image is divided by this before\n",
    "            averaging. Defaults to 121.79151153564453.\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "    \"\"\"\n",
    "    import os\n",
    "    import math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    from itertools import product\n",
    "\n",
    "    # Fall back to the previously hard-coded setup only for arguments that were\n",
    "    # not given (explicit arguments used to be silently overwritten).\n",
    "    if root_dir is None:\n",
    "        root_dir = \"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\"\n",
    "    if agg_csv is None:\n",
    "        agg_csv = 'file_labels_aggregated_ds1008_g500_train.csv'\n",
    "    if global_max is None:\n",
    "        global_max = 121.79151153564453\n",
    "    dataset_root = os.path.expanduser(root_dir)\n",
    "    # os.path.join keeps an absolute agg_csv as-is\n",
    "    agg_csv = os.path.join(dataset_root, os.path.expanduser(agg_csv))\n",
    "\n",
    "    # Reverse-maps for real values\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    df = pd.read_csv(agg_csv)\n",
    "\n",
    "    # All (alpha_idx, q0_idx) combos → 3×4 = 12 rows, 10 samples per row\n",
    "    combos = list(product([0, 1, 2], [0, 1, 2, 3]))\n",
    "    n_rows, n_cols = len(combos), 10\n",
    "\n",
    "    # First pass: load every panel so one color scale can be shared by all.\n",
    "    panels = {}  # (row, col) -> (averaged image, CSV entry)\n",
    "    for i, (a_idx, q_idx) in enumerate(combos):\n",
    "        subset = df[(df['alpha'] == a_idx) & (df['q0'] == q_idx)]\n",
    "        if subset.empty:\n",
    "            continue  # DataFrame.sample raises on an empty frame\n",
    "        samples = subset.sample(n=n_cols, replace=len(subset) < n_cols, random_state=0)\n",
    "        for j, (_, entry) in enumerate(samples.iterrows()):\n",
    "            imgs = [\n",
    "                np.load(os.path.join(dataset_root, p)).astype(np.float32) / global_max\n",
    "                for p in entry['file_paths'].split('|')\n",
    "            ]\n",
    "            panels[(i, j)] = (np.mean(imgs, axis=0), entry)\n",
    "\n",
    "    # Shared log scale over all strictly positive pixels (LogNorm needs vmin > 0)\n",
    "    positive = [avg[avg > 0] for avg, _ in panels.values()]\n",
    "    positive = [vals for vals in positive if vals.size]\n",
    "    vmin = min(float(vals.min()) for vals in positive) if positive else 1e-6\n",
    "    vmax = max(float(vals.max()) for vals in positive) if positive else 1.0\n",
    "    norm = colors.LogNorm(vmin=vmin, vmax=vmax)\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols*1.5, n_rows*1.2),\n",
    "        sharex='col', sharey='row',\n",
    "        dpi=200,\n",
    "        squeeze=False\n",
    "    )\n",
    "\n",
    "    # Tight layout\n",
    "    fig.subplots_adjust(\n",
    "        left   = 0.15,  # room for row labels\n",
    "        right  = 0.97,\n",
    "        top    = 0.96,\n",
    "        bottom = 0.02,\n",
    "        hspace = 0.2,\n",
    "        wspace = 0.1\n",
    "    )\n",
    "\n",
    "    # Bin edges: 32 bins over [-π, π]\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    pcm = None\n",
    "    for i, (a_idx, q_idx) in enumerate(combos):\n",
    "        for j in range(n_cols):\n",
    "            ax = axes[i, j]\n",
    "            ax.set_xticks([]); ax.set_yticks([])\n",
    "            if (i, j) not in panels:\n",
    "                ax.text(0.5, 0.5, 'N/A', transform=ax.transAxes,\n",
    "                        ha='center', va='center', fontsize=6)\n",
    "                continue\n",
    "            avg, entry = panels[(i, j)]\n",
    "            pcm = ax.pcolormesh(\n",
    "                x_edges, y_edges, np.ma.masked_where(avg == 0, avg),\n",
    "                norm=norm, cmap='jet', shading='auto'\n",
    "            )\n",
    "            # Real-value row label (module taken from the row's first sample)\n",
    "            if j == 0:\n",
    "                e_str = energy_map[entry['energy_loss']]\n",
    "                α = alpha_vals[a_idx]\n",
    "                Q0 = q0_vals[q_idx]\n",
    "                ax.text(-0.35, 0.5,\n",
    "                        f'{e_str}\\nαₛ={α}\\nQ₀={Q0}',\n",
    "                        transform=ax.transAxes,\n",
    "                        va='center', ha='right',\n",
    "                        fontsize=8)\n",
    "\n",
    "    # Shared ticks bottom row & left column\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks([-math.pi, 0, math.pi])\n",
    "        ax.set_xticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=6)\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks([-math.pi, 0, math.pi])\n",
    "        ax.set_yticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=6)\n",
    "\n",
    "    # Colorbar (valid for every panel because the norm is shared)\n",
    "    if pcm is not None:\n",
    "        cbar = fig.colorbar(pcm, ax=axes, fraction=0.015, pad=0.01)\n",
    "        cbar.set_label(r'$p_T$', fontsize=8)\n",
    "\n",
    "    fig.suptitle('10 Aggregated Samples per (E, αₛ, Q₀) – Hist2D, X,Y ∈ [-π,π]', y=0.995, fontsize=12)\n",
    "    fig.savefig('jet_images_g500_12x10_ten_sample_per_combo.png', dpi=300, bbox_inches=\"tight\")\n",
    "    fig.savefig('jet_images_g500_12x10_ten_sample_per_combo.pdf', bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "plot_grid_hist2d()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05d6fb59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_avg_hist2d(entry, dataset_root, global_max=121.79151153564453):\n",
    "    \"\"\"\n",
    "    Load and average the images referenced by one aggregated CSV row.\n",
    "\n",
    "    Despite the name nothing is plotted; callers draw the returned array.\n",
    "\n",
    "    Args:\n",
    "        entry (Mapping): Row whose 'file_paths' value is a '|'-separated list\n",
    "            of .npy paths relative to ``dataset_root``.\n",
    "        dataset_root (str): Directory the paths are resolved against.\n",
    "        global_max (float): Every image is divided by this before averaging.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ma.MaskedArray: Mean of the normalized images with exactly-zero\n",
    "        pixels masked.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If 'file_paths' contains no non-empty path.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    import numpy as np\n",
    "\n",
    "    # Ignore empty segments (e.g. from a trailing '|'), which would otherwise make\n",
    "    # np.load try to open dataset_root itself.\n",
    "    paths = [p for p in entry['file_paths'].split('|') if p]\n",
    "    if not paths:\n",
    "        raise ValueError(f\"entry has no file paths: {entry['file_paths']!r}\")\n",
    "    imgs = [\n",
    "        np.load(os.path.join(dataset_root, p)).astype(np.float32) / global_max\n",
    "        for p in paths\n",
    "    ]\n",
    "    avg = np.mean(imgs, axis=0)\n",
    "    return np.ma.masked_where(avg == 0, avg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3bab69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Grid of Hist2D Plots with Real Labels (12 rows × n_samples columns)\n",
    "def plot_grid_hist2d_one_sample_per_combo(agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "                     dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "                     n_samples=2):\n",
    "    \"\"\"\n",
    "    Plot averaged hist2D jet images for each (alpha, q0) combination.\n",
    "\n",
    "    One row per combination (3 alphas × 4 q0 = 12 rows) and ``n_samples``\n",
    "    sampled aggregated entries per row. All panels share one logarithmic\n",
    "    color scale, so the single colorbar applies to every panel.\n",
    "    Combinations with no CSV rows are drawn as 'N/A'.\n",
    "\n",
    "    Args:\n",
    "        agg_csv (str): Aggregated CSV file, resolved against ``dataset_root``.\n",
    "        dataset_root (str): Dataset root directory (``~`` is expanded).\n",
    "        n_samples (int): Sampled entries (columns) per row. Defaults to 2.\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "    \"\"\"\n",
    "    import os\n",
    "    import math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    from itertools import product\n",
    "\n",
    "    # Paths\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    agg_csv = os.path.join(dataset_root, agg_csv)\n",
    "\n",
    "    # Reverse-maps for real values\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    df = pd.read_csv(agg_csv)\n",
    "\n",
    "    # All (alpha_idx, q0_idx) combos → 3×4 = 12 rows\n",
    "    combos = list(product([0, 1, 2], [0, 1, 2, 3]))\n",
    "    n_rows, n_cols = len(combos), n_samples\n",
    "\n",
    "    # First pass: load every panel so one color scale can be shared by all.\n",
    "    panels = {}  # (row, col) -> (masked averaged image, CSV entry)\n",
    "    for i, (a_idx, q_idx) in enumerate(combos):\n",
    "        subset = df[(df['alpha'] == a_idx) & (df['q0'] == q_idx)]\n",
    "        if subset.empty:\n",
    "            continue  # DataFrame.sample raises on an empty frame\n",
    "        samples = subset.sample(n=n_cols, replace=len(subset) < n_cols, random_state=0)\n",
    "        for j, (_, entry) in enumerate(samples.iterrows()):\n",
    "            panels[(i, j)] = (plot_avg_hist2d(entry, dataset_root), entry)\n",
    "\n",
    "    # Shared log scale over all strictly positive pixels (LogNorm needs vmin > 0);\n",
    "    # compressed() drops the masked (zero) pixels first.\n",
    "    positive = [img.compressed() for img, _ in panels.values()]\n",
    "    positive = [vals[vals > 0] for vals in positive]\n",
    "    positive = [vals for vals in positive if vals.size]\n",
    "    vmin = min(float(vals.min()) for vals in positive) if positive else 1e-6\n",
    "    vmax = max(float(vals.max()) for vals in positive) if positive else 1.0\n",
    "    norm = colors.LogNorm(vmin=vmin, vmax=vmax)\n",
    "\n",
    "    # squeeze=False keeps axes 2-D, so axes[i, j] also works when n_samples == 1\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols*1.5, n_rows*1.2),\n",
    "        sharex='col', sharey='row',\n",
    "        dpi=300,\n",
    "        squeeze=False\n",
    "    )\n",
    "\n",
    "    # Tight layout\n",
    "    fig.subplots_adjust(\n",
    "        left   = 0.15,  # room for row labels\n",
    "        right  = 0.97,\n",
    "        top    = 0.96,\n",
    "        bottom = 0.02,\n",
    "        hspace = 0.2,\n",
    "        wspace = 0.1\n",
    "    )\n",
    "\n",
    "    # Bin edges: 32 bins over [-π, π]\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    pcm = None\n",
    "    for i, (a_idx, q_idx) in enumerate(combos):\n",
    "        for j in range(n_cols):\n",
    "            ax = axes[i, j]\n",
    "            ax.set_xticks([]); ax.set_yticks([])\n",
    "            if (i, j) not in panels:\n",
    "                ax.text(0.5, 0.5, 'N/A', transform=ax.transAxes,\n",
    "                        ha='center', va='center', fontsize=6)\n",
    "                continue\n",
    "            img, entry = panels[(i, j)]\n",
    "            pcm = ax.pcolormesh(\n",
    "                x_edges, y_edges, img,\n",
    "                norm=norm, cmap='jet', shading='auto'\n",
    "            )\n",
    "            # Real-value row label (module taken from the row's first sample)\n",
    "            if j == 0:\n",
    "                e_str = energy_map[entry['energy_loss']]\n",
    "                α = alpha_vals[a_idx]\n",
    "                Q0 = q0_vals[q_idx]\n",
    "                ax.text(-0.35, 0.5,\n",
    "                        f'{e_str}\\nαₛ={α}\\nQ₀={Q0}',\n",
    "                        transform=ax.transAxes,\n",
    "                        va='center', ha='right',\n",
    "                        fontsize=8)\n",
    "\n",
    "    # Shared ticks bottom row & left column\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks([-math.pi, 0, math.pi])\n",
    "        ax.set_xticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=6)\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks([-math.pi, 0, math.pi])\n",
    "        ax.set_yticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=6)\n",
    "\n",
    "    # Colorbar (valid for every panel because the norm is shared)\n",
    "    if pcm is not None:\n",
    "        cbar = fig.colorbar(pcm, ax=axes, fraction=0.015, pad=0.01)\n",
    "        cbar.set_label('Normalized Intensity', fontsize=8)\n",
    "\n",
    "    # Title states the actual number of samples drawn per row\n",
    "    plt.suptitle(f'{n_samples} Aggregated Samples per (E, αₛ, Q₀) – Hist2D, X,Y ∈ [-π,π]', y=0.995, fontsize=12)\n",
    "    plt.show()\n",
    "\n",
    "plot_grid_hist2d_one_sample_per_combo()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ad7bc31",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_hist2d_one_sample_per_row(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv', \n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot a vertical stack of hist2D images, one row per (alpha, q0) combination.\n",
    "\n",
    "    Each row shows the averaged image of one sampled aggregated entry\n",
    "    (``random_state=0``). All rows share one logarithmic color scale so the\n",
    "    single colorbar applies to every panel; combinations without CSV rows are\n",
    "    drawn as \"No Data\".\n",
    "\n",
    "    Args:\n",
    "        agg_csv (str): Aggregated CSV file, resolved against ``dataset_root``.\n",
    "        dataset_root (str): Dataset root directory (``~`` is expanded).\n",
    "\n",
    "    Returns:\n",
    "        None\n",
    "    \"\"\"\n",
    "\n",
    "    import os, math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    from itertools import product\n",
    "\n",
    "    # Paths\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    agg_csv = os.path.join(dataset_root, agg_csv)\n",
    "\n",
    "    # Reverse maps\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    # Load data\n",
    "    df = pd.read_csv(agg_csv)\n",
    "\n",
    "    # 12 combos (3 αs × 4 Q₀)\n",
    "    combos = list(product([0,1,2], [0,1,2,3]))\n",
    "    n_rows = len(combos)\n",
    "\n",
    "    # One column; squeeze=False keeps axes 2-D so axes[i, 0] works for any n_rows\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, 1,\n",
    "        figsize=(2.5, n_rows * 1.5),\n",
    "        dpi=300,\n",
    "        squeeze=False\n",
    "    )\n",
    "\n",
    "    # Bin edges\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    # First pass: pick one entry per combination and load its averaged image\n",
    "    picks = [None] * n_rows\n",
    "    for i, (a_idx, q_idx) in enumerate(combos):\n",
    "        subset = df[(df['alpha']==a_idx) & (df['q0']==q_idx)]\n",
    "        if not subset.empty:\n",
    "            entry = subset.sample(n=1, random_state=0).iloc[0]\n",
    "            picks[i] = (plot_avg_hist2d(entry, dataset_root), entry)\n",
    "\n",
    "    # Shared log scale over all strictly positive pixels (LogNorm needs vmin > 0);\n",
    "    # compressed() drops the masked (zero) pixels first.\n",
    "    positive = [pick[0].compressed() for pick in picks if pick is not None]\n",
    "    positive = [vals[vals > 0] for vals in positive]\n",
    "    positive = [vals for vals in positive if vals.size]\n",
    "    vmin = min(float(vals.min()) for vals in positive) if positive else 1e-6\n",
    "    vmax = max(float(vals.max()) for vals in positive) if positive else 1.0\n",
    "    norm = colors.LogNorm(vmin=vmin, vmax=vmax)\n",
    "\n",
    "    pcm = None\n",
    "    for i, (a_idx, q_idx) in enumerate(combos):\n",
    "        ax = axes[i, 0]\n",
    "        if picks[i] is None:\n",
    "            ax.text(0.5, 0.5, \"No Data\", ha=\"center\", va=\"center\")\n",
    "            ax.axis(\"off\")\n",
    "            continue\n",
    "        img, entry = picks[i]\n",
    "        pcm = ax.pcolormesh(\n",
    "            x_edges, y_edges, img,\n",
    "            norm=norm,\n",
    "            cmap='jet', shading='auto'\n",
    "        )\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "\n",
    "        # Row label\n",
    "        e_str = energy_map[entry['energy_loss']]\n",
    "        α = alpha_vals[a_idx]\n",
    "        Q0 = q0_vals[q_idx]\n",
    "        ax.set_ylabel(f'{e_str}\\nαₛ={α}\\nQ₀={Q0}', fontsize=8, rotation=0, labelpad=30, va='center')\n",
    "\n",
    "    plt.suptitle('One Sample per (E, αₛ, Q₀) – Hist2D', y=0.995, fontsize=12)\n",
    "    # Lay out the panels before adding the colorbar: tight_layout warns that the\n",
    "    # colorbar axes created by fig.colorbar(ax=...) are incompatible with it.\n",
    "    plt.tight_layout(rect=[0,0,1,0.97])\n",
    "    if pcm is not None:\n",
    "        cbar = fig.colorbar(pcm, ax=axes, fraction=0.02, pad=0.02)\n",
    "        cbar.set_label('Normalized Intensity', fontsize=8)\n",
    "    plt.show()\n",
    "\n",
    "plot_grid_hist2d_one_sample_per_row()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "296031d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_hist2d_one_sample_per_combo_4x3(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    module_value=None,  # None, \"MATTER\", or \"MATTER-LBT\"\n",
    "    random_state=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw a 4x3 grid of averaged hist2D jet images: one row per Q0 value, one\n",
    "    column per alpha_s value, one randomly chosen aggregated entry per cell.\n",
    "\n",
    "    Every cell uses the same logarithmic color scale (smallest positive and\n",
    "    largest pixel over all picked images) and one colorbar is drawn.\n",
    "\n",
    "    Args:\n",
    "        agg_csv (str): Aggregated CSV file name, joined onto ``dataset_root``.\n",
    "        dataset_root (str): Dataset root directory (``~`` is expanded).\n",
    "        module_value (str | None): 'MATTER' or 'MATTER-LBT' restricts the picks\n",
    "            to that energy-loss module; None draws from every entry.\n",
    "        random_state (int): Seed of the RandomState shared by all cell picks.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``module_value`` is not None, 'MATTER' or 'MATTER-LBT'.\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "\n",
    "    # Index -> label lookups, plus the inverse used by the module filter\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    energy_inv = {label: code for code, label in energy_map.items()}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "    col_alphas = [0, 1, 2]      # column c shows alpha index col_alphas[c]\n",
    "    row_q0s    = [0, 1, 2, 3]   # row r shows q0 index row_q0s[r]\n",
    "\n",
    "    root = os.path.expanduser(dataset_root)\n",
    "    table = pd.read_csv(os.path.join(root, agg_csv))\n",
    "\n",
    "    # Optional energy-loss module filter\n",
    "    if module_value is not None:\n",
    "        module_code = energy_inv.get(module_value)\n",
    "        if module_code is None:\n",
    "            raise ValueError(\"module_value must be 'MATTER' or 'MATTER-LBT' or None\")\n",
    "        table = table[table['energy_loss'] == module_code]\n",
    "\n",
    "    n_rows, n_cols = len(row_q0s), len(col_alphas)\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols * 2.2, n_rows * 2.2),\n",
    "        dpi=300,\n",
    "        sharex=True, sharey=True\n",
    "    )\n",
    "    edges = np.linspace(-math.pi, math.pi, 33)  # 32 bins over [-π, π], both axes\n",
    "\n",
    "    # Pass 1: one pick per cell; record per-image limits for the shared scale\n",
    "    rng = np.random.RandomState(random_state)\n",
    "    picks = {}\n",
    "    lows, highs = [], []\n",
    "    for r, q_idx in enumerate(row_q0s):\n",
    "        for c, a_idx in enumerate(col_alphas):\n",
    "            candidates = table[(table['alpha'] == a_idx) & (table['q0'] == q_idx)]\n",
    "            if candidates.empty:\n",
    "                continue\n",
    "            entry = candidates.sample(n=1, random_state=rng).iloc[0]\n",
    "            img = plot_avg_hist2d(entry, root)\n",
    "            picks[r, c] = (img, entry)\n",
    "            is_pos = img > 0\n",
    "            smallest = np.nanmin(img[is_pos]) if is_pos.any() else 1e-6\n",
    "            lows.append(max(smallest, 1e-6))  # LogNorm needs vmin > 0\n",
    "            highs.append(np.nanmax(img))\n",
    "\n",
    "    vmin = min(lows) if lows else 1e-6\n",
    "    vmax = max(highs) if highs else 1.0\n",
    "\n",
    "    # Pass 2: draw every cell on the shared scale\n",
    "    last_pcm = None\n",
    "    for r, q_idx in enumerate(row_q0s):\n",
    "        for c, a_idx in enumerate(col_alphas):\n",
    "            ax = axes[r, c]\n",
    "            ax.set_xticks([]); ax.set_yticks([])\n",
    "            if (r, c) not in picks:\n",
    "                ax.text(0.5, 0.5, \"N/A\", ha=\"center\", va=\"center\", fontsize=9)\n",
    "                continue\n",
    "            img, entry = picks[r, c]\n",
    "            last_pcm = ax.pcolormesh(\n",
    "                edges, edges, img,\n",
    "                norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "                cmap='jet', shading='auto'\n",
    "            )\n",
    "            if r == 0:  # column header: alpha_s\n",
    "                ax.set_title(f\"αₛ = {alpha_vals[a_idx]}\", fontsize=10, pad=6)\n",
    "            if c == 0:  # row header: Q0 plus the picked entry's module (or the fixed one)\n",
    "                module_name = energy_map[entry['energy_loss']] if module_value is None else module_value\n",
    "                ax.text(-0.15, 0.5, f\"Q₀ = {q0_vals[q_idx]}\\n{module_name}\",\n",
    "                        va='center', ha='right',\n",
    "                        rotation=90, transform=ax.transAxes, fontsize=10)\n",
    "\n",
    "    # Tick labels only on the outer edges\n",
    "    tick_pos = [-math.pi, 0, math.pi]\n",
    "    tick_txt = [r'$-\\pi$', '0', r'$\\pi$']\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks(tick_pos)\n",
    "        ax.set_xticklabels(tick_txt, fontsize=8)\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks(tick_pos)\n",
    "        ax.set_yticklabels(tick_txt, fontsize=8)\n",
    "\n",
    "    if last_pcm is not None:\n",
    "        cbar = fig.colorbar(last_pcm, ax=axes, fraction=0.02, pad=0.02)\n",
    "        cbar.set_label('Normalized Intensity', fontsize=9)\n",
    "\n",
    "    title_mod = f\" — {module_value}\" if module_value else \"\"\n",
    "    plt.suptitle(f'One sample per (αₛ, Q₀){title_mod} — Hist2D, X,Y ∈ [−π, π]', y=0.995, fontsize=12)\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.97])\n",
    "    plt.show()\n",
    "plot_grid_hist2d_one_sample_per_combo_4x3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77d893aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_hist2d_one_sample_per_combo_4x3(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    module_value=None,  # None, \"MATTER\", or \"MATTER-LBT\"\n",
    "    random_state=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot a 4x3 grid (rows=Q0, cols=alpha_s), one Hist2D sample per (alpha,q0) combo.\n",
    "    Adds a single horizontal colorbar at the bottom labeled 'Pseudorapidity'.\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "\n",
    "    # ---- paths & constants\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    agg_csv_path = os.path.join(dataset_root, agg_csv)\n",
    "\n",
    "    # label maps\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    energy_inv = {v: k for k, v in energy_map.items()}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    alphas = [0, 1, 2]   # indices for columns\n",
    "    q0s    = [0, 1, 2, 3]  # indices for rows\n",
    "\n",
    "    # ---- load\n",
    "    df = pd.read_csv(agg_csv_path)\n",
    "    if module_value is not None:\n",
    "        df = df[df['energy_loss'] == energy_inv[module_value]]\n",
    "\n",
    "    # ---- figure\n",
    "    n_rows, n_cols = len(q0s), len(alphas)\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols * 2.2, n_rows * 2.2),\n",
    "        dpi=300,\n",
    "        sharex=True, sharey=True\n",
    "    )\n",
    "\n",
    "    # bins\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    # collect one sample per cell\n",
    "    cell_images = [[None for _ in alphas] for _ in q0s]\n",
    "    vmin_list, vmax_list = [], []\n",
    "    rng = np.random.RandomState(random_state)\n",
    "\n",
    "    for r, q_idx in enumerate(q0s):\n",
    "        for c, a_idx in enumerate(alphas):\n",
    "            subset = df[(df['alpha'] == a_idx) & (df['q0'] == q_idx)]\n",
    "            if subset.empty:\n",
    "                continue\n",
    "            entry = subset.sample(n=1, random_state=rng).iloc[0]\n",
    "            img = plot_avg_hist2d(entry, dataset_root)\n",
    "            cell_images[r][c] = (img, entry)\n",
    "            mn = np.nanmin(img[img > 0]) if (img > 0).any() else 1e-6\n",
    "            mx = np.nanmax(img)\n",
    "            vmin_list.append(max(mn, 1e-6))\n",
    "            vmax_list.append(mx)\n",
    "\n",
    "    # shared scale\n",
    "    vmin = min(vmin_list) if vmin_list else 1e-6\n",
    "    vmax = max(vmax_list) if vmax_list else 1.0\n",
    "    print(f\"Color scale: vmin={vmin}, vmax={vmax}\")\n",
    "    last_pcm = None\n",
    "    for r, q_idx in enumerate(q0s):\n",
    "        for c, a_idx in enumerate(alphas):\n",
    "            ax = axes[r, c]\n",
    "            ax.set_xticks([]); ax.set_yticks([])\n",
    "            payload = cell_images[r][c]\n",
    "            if payload is None:\n",
    "                ax.text(0.5, 0.5, \"N/A\", ha=\"center\", va=\"center\", fontsize=9)\n",
    "                continue\n",
    "            img, entry = payload\n",
    "            pcm = ax.pcolormesh(\n",
    "                x_edges, y_edges, img,\n",
    "                norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "                cmap='jet', shading='auto'\n",
    "            )\n",
    "            last_pcm = pcm\n",
    "            if r == 0:\n",
    "                ax.set_title(f\"αₛ = {alpha_vals[a_idx]}\", fontsize=10, pad=6)\n",
    "            if c == 0:\n",
    "                row_label = f\"Q₀ = {q0_vals[q_idx]}\"\n",
    "                if module_value is None:\n",
    "                    row_label = f\"{row_label}:{energy_map[entry['energy_loss']]}\"\n",
    "                else:\n",
    "                    row_label = f\"{row_label}\\n{module_value}\"\n",
    "                ax.text(-0.22, 0.5, row_label, va='center', ha='right',\n",
    "                        rotation=90, transform=ax.transAxes, fontsize=10)\n",
    "\n",
    "    # outer ticks\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks([-math.pi, 0, math.pi])\n",
    "        ax.set_xticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=9)\n",
    "        ax.set_xlabel(r'$\\phi$', fontsize=9, labelpad=-1)\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks([-math.pi, 0, math.pi])\n",
    "        ax.set_yticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=9)\n",
    "        ax.set_ylabel(r'$\\eta$', fontsize=9, rotation=90, labelpad=-5)\n",
    "\n",
    "    # horizontal colorbar at bottom\n",
    "    if last_pcm is not None:\n",
    "        # [left, bottom, width, height] in figure coordinates (0–1)\n",
    "        cax = fig.add_axes([0.1, -0.02, 0.56, 0.03])  \n",
    "        # left=0.1 → margin from left edge\n",
    "        # bottom=0.05 → distance from bottom\n",
    "        # width=0.35 → how long the bar is\n",
    "        # height=0.03 → thickness of the bar\n",
    "        cbar = fig.colorbar(\n",
    "            last_pcm, cax=cax,\n",
    "            fraction=0.03, pad=-0.2,\n",
    "            orientation='horizontal', location='bottom'\n",
    "        )\n",
    "        # cbar.set_label(r'p_T', fontsize=9)\n",
    "        cbar.ax.set_title(r\"$p_T$\", fontsize=9, pad=6, loc='left',x=-0.12, y=0.1)\n",
    "        cbar.ax.tick_params(labelsize=9)   # change 10 to whatever size you want\n",
    "\n",
    "    \n",
    "    import matplotlib.patches as mpatches\n",
    "\n",
    "    # --- after plotting and colorbar ---\n",
    "    # Create invisible patches just for the legend labels\n",
    "    legend_elements = [\n",
    "        mpatches.Patch(color='none', label=r\"$\\eta$: pseudorapidity\"),\n",
    "        mpatches.Patch(color='none', label=r\"$\\phi$: azimuthal angle\"),\n",
    "        mpatches.Patch(color='none', label=r\"$p_T$: transverse momentum\")\n",
    "    ]\n",
    "\n",
    "    # Add legend to the bottom right\n",
    "    fig.legend(\n",
    "        handles=legend_elements,\n",
    "        loc='lower right',\n",
    "        bbox_to_anchor=(0.97, -0.06), \n",
    "        fontsize=9,\n",
    "        frameon=True,\n",
    "        framealpha=0.9,\n",
    "        handlelength=0,  # no marker box\n",
    "        handletextpad=-0.2\n",
    "    )\n",
    "    # plt.tight_layout(rect=[0, 0.05, 1, 0.95])\n",
    "    plt.tight_layout()\n",
    "    base=\"jet_images_g500_4x3_one_sample_per_combo\"\n",
    "    fig.savefig(base + \".png\", dpi=300, bbox_inches=\"tight\")\n",
    "    fig.savefig(base + \".pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "plot_grid_hist2d_one_sample_per_combo_4x3()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "325a1e7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_hist2d_alpha_rows_q0_cols(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    module_value=None,  # None, \"MATTER\", or \"MATTER-LBT\" (optional hard filter)\n",
    "    random_state=0,\n",
    "    save_basename=\"jet_images_g500_3x4_one_sample_per_combo\",\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot a 3x4 grid of single-event Hist2D jet images.\n",
    "\n",
    "    Rows are alpha_s (0.2, 0.3, 0.4); columns are Q0 (1.0, 1.5, 2.0, 2.5). Column\n",
    "    titles also name the energy-loss module: MATTER for Q0 = 1.0 and MATTER-LBT\n",
    "    otherwise, or `module_value` when a filter is given. One event is sampled per\n",
    "    (alpha_s, Q0) combination; all panels share one LogNorm color scale and a\n",
    "    horizontal colorbar at the bottom. Combinations with no events show \"N/A\".\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_csv : str\n",
    "        CSV filename inside `dataset_root`. It needs integer-coded 'alpha', 'q0'\n",
    "        and 'energy_loss' columns, plus whatever `plot_avg_hist2d` reads per row.\n",
    "    dataset_root : str\n",
    "        Dataset directory (``~`` is expanded); also passed to `plot_avg_hist2d`.\n",
    "    module_value : None, \"MATTER\" or \"MATTER-LBT\"\n",
    "        Optional hard filter on the energy-loss module.\n",
    "    random_state : int or None\n",
    "        Seed for the ``np.random.RandomState`` that picks the event in each cell.\n",
    "    save_basename : str or None\n",
    "        Save the figure as ``<save_basename>.png`` (300 dpi) and ``.pdf`` in the\n",
    "        working directory; pass None to skip saving.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `module_value` is not None, \"MATTER\" or \"MATTER-LBT\".\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    import matplotlib.patches as mpatches\n",
    "\n",
    "    # ---- label maps (the CSV stores integer class indices)\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    energy_inv = {v: k for k, v in energy_map.items()}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    # Fail fast with a readable message instead of a bare KeyError after the CSV load.\n",
    "    if module_value is not None and module_value not in energy_inv:\n",
    "        raise ValueError(\n",
    "            f\"module_value must be None or one of {sorted(energy_inv)}, got {module_value!r}\"\n",
    "        )\n",
    "\n",
    "    # ---- paths\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    agg_csv_path = os.path.join(dataset_root, agg_csv)\n",
    "\n",
    "    # 3x4 layout: rows = alpha_s (3), cols = Q0 (4)\n",
    "    alphas = [0, 1, 2]        # rows\n",
    "    q0s    = [0, 1, 2, 3]     # cols\n",
    "\n",
    "    # ---- load\n",
    "    df = pd.read_csv(agg_csv_path)\n",
    "    if module_value is not None:\n",
    "        df = df[df['energy_loss'] == energy_inv[module_value]]\n",
    "\n",
    "    # ---- figure (rows, cols) = (alpha, Q0)\n",
    "    n_rows, n_cols = len(alphas), len(q0s)\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols * 2.25, n_rows * 2.35),\n",
    "        dpi=300,\n",
    "        sharex=True, sharey=True\n",
    "    )\n",
    "    fig.subplots_adjust(left=0.08, right=0.98, top=0.90, bottom=0.16, wspace=0.1, hspace=0.1)\n",
    "\n",
    "    # bins\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    # collect one sample per cell\n",
    "    cell_images = [[None for _ in q0s] for _ in alphas]\n",
    "    vmin_list, vmax_list = [], []\n",
    "    rng = np.random.RandomState(random_state)\n",
    "\n",
    "    for r, a_idx in enumerate(alphas):\n",
    "        for c, q_idx in enumerate(q0s):\n",
    "            subset = df[(df['alpha'] == a_idx) & (df['q0'] == q_idx)]\n",
    "            if subset.empty:\n",
    "                continue\n",
    "            entry = subset.sample(n=1, random_state=rng).iloc[0]\n",
    "            img = plot_avg_hist2d(entry, dataset_root)\n",
    "            cell_images[r][c] = (img, entry)\n",
    "            if (img > 0).any():\n",
    "                vmin_list.append(max(np.nanmin(img[img > 0]), 1e-6))\n",
    "                vmax_list.append(np.nanmax(img))\n",
    "            else:\n",
    "                vmin_list.append(1e-6); vmax_list.append(1.0)\n",
    "\n",
    "    # shared color scale\n",
    "    vmin = min(vmin_list) if vmin_list else 1e-6\n",
    "    vmax = max(vmax_list) if vmax_list else 1.0\n",
    "    print(f\"Color scale: vmin={vmin}, vmax={vmax}\")\n",
    "\n",
    "    def module_for_q0(q_idx, fallback_energy_int):\n",
    "        \"\"\"Module name for a Q0 column title; an explicit module_value filter wins.\"\"\"\n",
    "        if module_value is not None:\n",
    "            return module_value\n",
    "        # Physics rule: Q0 = 1.0 => MATTER, Q0 in {1.5, 2.0, 2.5} => MATTER-LBT.\n",
    "        # Only a Q0 outside those values uses the sampled event's own label.\n",
    "        q0 = q0_vals[q_idx]\n",
    "        if q0 == 1.0:\n",
    "            return 'MATTER'\n",
    "        if q0 in (1.5, 2.0, 2.5):\n",
    "            return 'MATTER-LBT'\n",
    "        return energy_map.get(int(fallback_energy_int), 'MATTER')\n",
    "\n",
    "    # One norm instance shared by every panel so all panels map values to color identically.\n",
    "    norm = colors.LogNorm(vmin=vmin, vmax=vmax)\n",
    "\n",
    "    last_pcm = None\n",
    "    for r, a_idx in enumerate(alphas):\n",
    "        for c, q_idx in enumerate(q0s):\n",
    "            ax = axes[r, c]\n",
    "            payload = cell_images[r][c]\n",
    "            if payload is None:\n",
    "                ax.text(0.5, 0.5, \"N/A\", ha=\"center\", va=\"center\", fontsize=9)\n",
    "                continue\n",
    "            img, entry = payload\n",
    "            # NOTE(review): pcolormesh treats img rows as y (eta) and columns as x (phi);\n",
    "            # confirm plot_avg_hist2d returns the image in that orientation.\n",
    "            last_pcm = ax.pcolormesh(\n",
    "                x_edges, y_edges, img,\n",
    "                norm=norm, cmap='jet', shading='auto'\n",
    "            )\n",
    "\n",
    "            # column titles (Q0 / module) on top row\n",
    "            if r == 0:\n",
    "                mod_lbl = module_for_q0(q_idx, entry['energy_loss'])\n",
    "                ax.set_title(f\"Q₀ = {q0_vals[q_idx]} : {mod_lbl}\", fontsize=10, pad=6)\n",
    "\n",
    "            # row labels (alpha_s) on left column\n",
    "            if c == 0:\n",
    "                row_label = f\"αₛ = {alpha_vals[a_idx]}\"\n",
    "                ax.text(-0.20, 0.5, row_label, va='center', ha='right',\n",
    "                        rotation=90, transform=ax.transAxes, fontsize=10)\n",
    "\n",
    "    # Outer ticks. With sharex/sharey every panel shares one tick locator/formatter,\n",
    "    # so these ticks apply to all panels (per-panel set_xticks([]) would be overridden);\n",
    "    # plt.subplots already hides the tick labels of inner panels.\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks([-math.pi, 0, math.pi])\n",
    "        ax.set_xticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=9)\n",
    "        ax.set_xlabel(r'$\\phi$', fontsize=9, labelpad=-1)\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks([-math.pi, 0, math.pi])\n",
    "        ax.set_yticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=9)\n",
    "        ax.set_ylabel(r'$\\eta$', fontsize=9, rotation=90, labelpad=-5)\n",
    "\n",
    "    # horizontal colorbar at bottom ([left, bottom, width, height] in figure coords)\n",
    "    if last_pcm is not None:\n",
    "        cax = fig.add_axes([0.12, 0.06, 0.62, 0.025])\n",
    "        cbar = fig.colorbar(last_pcm, cax=cax, orientation='horizontal')\n",
    "        cbar.ax.set_title(r\"$p_T$\", fontsize=9, pad=6, loc='left', x=-0.05, y=0.22)\n",
    "        cbar.ax.tick_params(labelsize=9)\n",
    "\n",
    "    # legend with physics labels (no markers)\n",
    "    legend_elements = [\n",
    "        mpatches.Patch(color='none', label=r\"$\\eta$: pseudorapidity\"),\n",
    "        mpatches.Patch(color='none', label=r\"$\\phi$: azimuthal angle\"),\n",
    "        mpatches.Patch(color='none', label=r\"$p_T$: transverse momentum\")\n",
    "    ]\n",
    "    fig.legend(\n",
    "        handles=legend_elements,\n",
    "        loc='lower right',\n",
    "        bbox_to_anchor=(0.98, 0.01),\n",
    "        fontsize=9,\n",
    "        frameon=True, framealpha=0.9,\n",
    "        handlelength=0, handletextpad=0.2\n",
    "    )\n",
    "    if save_basename:\n",
    "        fig.savefig(save_basename + \".png\", dpi=300, bbox_inches=\"tight\")\n",
    "        fig.savefig(save_basename + \".pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "plot_grid_hist2d_alpha_rows_q0_cols()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
