{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aefdb534",
   "metadata": {},
   "source": [
    "\n",
    "# Figure 1 (g500, 1008 dataset): 4×3 Grid of Sample Events\n",
    "\n",
    "This notebook:\n",
    "- Reads the **1008 g500** dataset CSVs (train/val)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30687fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "\n",
    "# Get the absolute path of the project root, two levels above the working directory (where config.py lives)\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), \"../..\"))\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.insert(0, parent_dir)\n",
    "\n",
    "print(\"Added to sys.path:\", parent_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1618c456",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import json\n",
    "\n",
    "from config import get_config\n",
    "from train_utils.gpu_utils import get_device_summary\n",
    "from data.loader import get_dataloaders\n",
    "from train_utils.resume import init_resume_state\n",
    "from train_utils.resume import fill_trackers_from_history\n",
    "from train_utils.resume import load_pretrained_model\n",
    "from train_utils.training_loop import run_training_loop\n",
    "from train_utils.scheduler_utils import create_scheduler\n",
    "from train_utils.training_summary import finalize_training_summary\n",
    "from train_utils.training_summary import print_best_model_summary\n",
    "from train_utils.plot_metrics import plot_train_val_metrics\n",
    "from train_utils.plot_metrics import plot_loss_accuracy\n",
    "from train_utils.plot_metrics import plot_confusion_matrices\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91280b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "from models.model_vit import create_model\n",
    "cfg=get_config(config_path=\"/\" \\\n",
    "\"experiments/exp_plotting_event_aggregation_evolotion/config/\" \\\n",
    "\"vit_tiny_patch16_224_gaussian_bs32_ep50_lr1e-04_p12_ds1008_g500_sched-RLRP_preload_p4.yml\")\n",
    "\n",
    "# from models.model_mamba import create_model\n",
    "# cfg=get_config(config_path=\"/\" \\\n",
    "# \"experiments/exp_alpha_s_evaluation_with_vit/config/\" \\\n",
    "# \"vit_tiny_patch16_224_gaussian_bs32_ep50_lr1e-04_p12_ds7200000_g500_sched-RLRP_preload_p4.yml\")\n",
    "\n",
    "# cfg=get_config()\n",
    "print(json.dumps(vars(cfg), indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574e8ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38bc0afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "device= get_device_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37be9c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data\n",
    "train_loader, val_loader, test_loader = get_dataloaders(cfg, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ff94323",
   "metadata": {},
   "source": [
    "\n",
    "- Picks **one sample** for each combination of **$Q_0 \\in \\{1.0, 1.5, 2.0, 2.5\\}$** and **$\\alpha_s \\in \\{0.2, 0.3, 0.4\\}$**.\n",
    "- Builds a **4×3** montage image (rows = $Q_0$; cols = $\\alpha_s$) **without subplots** by composing tiles with PIL.\n",
    "- Saves `figure1_g500_4x3.png` and shows it via a single `imshow` call (no seaborn, no custom colors).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05d6fb59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_avg_hist2d(entry, dataset_root, global_max=121.79151153564453):\n",
    "    import numpy as np\n",
    "    imgs = [\n",
    "        np.load(os.path.join(dataset_root, p)).astype(np.float32)/global_max\n",
    "        for p in entry['file_paths'].split('|')\n",
    "    ]\n",
    "    avg = np.mean(imgs, axis=0)\n",
    "    avg_masked = np.ma.masked_where(avg==0, avg)\n",
    "    return avg_masked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4621234",
   "metadata": {},
   "outputs": [],
   "source": [
    "sizes1=(1,2,4,8,16,32,64,128,256,500)\n",
    "sizes2=(1,4,10,16,25,32,50,75,100,150,200,300,400,500)\n",
    "sizes3=(1,2,4,8,10,16,25,32,50,64,100,125,128,200,256,300,400,500)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2d46bae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_evolution_hist2d_for_cell(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    global_max=121.79151153564453,\n",
    "    alpha_idx=0,          # 0->0.2, 1->0.3, 2->0.4\n",
    "    q0_idx=0,             # 0->1.0, 1->1.5, 2->2.0, 3->2.5\n",
    "    module_value=None,    # None, \"MATTER\", or \"MATTER-LBT\"\n",
    "    sizes=sizes3,\n",
    "    random_state=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot how the averaged 2D histogram of one (alpha, q0, module) cell evolves with sample size.\n",
    "\n",
    "    Raw event .npy paths are gathered from shuffled aggregated CSV rows with matching labels\n",
    "    until at least max(sizes) files are available. For every n in ``sizes`` the mean of the\n",
    "    first n normalized images is drawn in its own panel. Panels fill a grid of at most two\n",
    "    rows with ceil(len(sizes) / rows) columns and share one logarithmic colour scale.\n",
    "    The figure is saved as ``evolution_hist2d_alpha<..>_q0<..>_<module>`` (.png and .pdf)\n",
    "    in the working directory and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_csv : str\n",
    "        Aggregated label CSV, relative to ``dataset_root``. The columns read are\n",
    "        'alpha', 'q0', 'energy_loss' and 'file_paths' ('|'-separated relative paths).\n",
    "    dataset_root : str\n",
    "        Dataset directory; '~' is expanded.\n",
    "    global_max : float\n",
    "        Divisor applied to every image after casting to float32.\n",
    "    alpha_idx, q0_idx : int\n",
    "        Class indices matched against the 'alpha' and 'q0' columns.\n",
    "    module_value : str or None\n",
    "        \"MATTER\" or \"MATTER-LBT\" to also filter on 'energy_loss'; None applies no such filter.\n",
    "    sizes : sequence of int\n",
    "        Numbers of events to average, one panel each; every value must be >= 1.\n",
    "    random_state : int\n",
    "        Seed for shuffling the matching CSV rows.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``sizes`` is empty or holds a value below 1, if no CSV row matches, or if\n",
    "        fewer than max(sizes) files can be collected.\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    import matplotlib.patches as mpatches\n",
    "\n",
    "    sizes = tuple(sizes)\n",
    "    if not sizes or min(sizes) < 1:\n",
    "        # n < 1 would silently index cumsum from the end (cumsum[n-1])\n",
    "        raise ValueError(\"sizes must be a non-empty sequence of positive integers.\")\n",
    "\n",
    "    # ---- label maps\n",
    "    energy_map = {0: 'MATTER', 1: 'MATTER-LBT'}\n",
    "    energy_inv = {v: k for k, v in energy_map.items()}\n",
    "    alpha_vals = {0: 0.2, 1: 0.3, 2: 0.4}\n",
    "    q0_vals    = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
    "\n",
    "    # ---- load CSV and keep only the rows of the requested cell\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    df = pd.read_csv(os.path.join(dataset_root, agg_csv))\n",
    "\n",
    "    mask = (df['alpha'] == alpha_idx) & (df['q0'] == q0_idx)\n",
    "    if module_value is not None:\n",
    "        mask &= df['energy_loss'] == energy_inv[module_value]\n",
    "    df_cell = df[mask].sample(frac=1.0, random_state=random_state)  # reproducible shuffle\n",
    "    if df_cell.empty:\n",
    "        raise ValueError(\"No entries found for the requested (alpha, q0, module).\")\n",
    "\n",
    "    # ---- collect raw file paths until there are enough to cover max(sizes)\n",
    "    need = max(sizes)\n",
    "    file_list = []\n",
    "    for paths_field in df_cell['file_paths']:\n",
    "        # each row holds a '|' separated list of raw event .npy files\n",
    "        file_list.extend(os.path.join(dataset_root, p) for p in paths_field.split('|'))\n",
    "        if len(file_list) >= need:\n",
    "            break\n",
    "    if len(file_list) < need:\n",
    "        raise ValueError(f\"Only found {len(file_list)} files, fewer than requested {need}.\")\n",
    "\n",
    "    # ---- load the first 'need' images, normalize, and build prefix sums so that the\n",
    "    #      mean of the first n images is simply cumsum[n-1] / n\n",
    "    imgs = np.stack(\n",
    "        [np.load(p).astype(np.float32) / float(global_max) for p in file_list[:need]],\n",
    "        axis=0,\n",
    "    )\n",
    "    cumsum = np.cumsum(imgs, axis=0)\n",
    "\n",
    "    # ---- averaged image per sample size and the shared colour range (zeros are masked)\n",
    "    avgs, vmins, vmaxs = [], [], []\n",
    "    for n in sizes:\n",
    "        avg = cumsum[n-1] / float(n)\n",
    "        avg = np.ma.masked_where(avg == 0, avg)\n",
    "        avgs.append(avg)\n",
    "        has_pos = np.any(avg > 0)\n",
    "        mn = np.min(avg[avg > 0]) if has_pos else 1e-6\n",
    "        mx = np.max(avg) if has_pos else 1.0\n",
    "        vmins.append(max(mn, 1e-6))  # LogNorm needs a strictly positive lower bound\n",
    "        vmaxs.append(mx)\n",
    "    vmin, vmax = min(vmins), max(vmaxs)\n",
    "    print(f\"[evolution] shared log color scale: vmin={vmin:.3e}, vmax={vmax:.3e}\")\n",
    "\n",
    "    # ---- plotting: at most two rows, as many columns as the panel count requires\n",
    "    nrows = min(2, len(sizes))\n",
    "    ncols = math.ceil(len(sizes) / nrows)\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows, ncols, figsize=(ncols*2.3, nrows*2.4), dpi=300,\n",
    "        sharex=True, sharey=True, squeeze=False  # squeeze=False keeps axes 2-D for any grid\n",
    "    )\n",
    "\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    last_pcm = None\n",
    "    for ax, n, avg in zip(axes.ravel(), sizes, avgs):  # row-major panel order\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "        last_pcm = ax.pcolormesh(\n",
    "            x_edges, y_edges, avg,\n",
    "            norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "            cmap='inferno', shading='auto'\n",
    "        )\n",
    "        ax.set_title(f\"n = {n}\", fontsize=9, pad=4)\n",
    "    # an odd number of panels leaves the last grid slot empty; hide it\n",
    "    for ax in axes.ravel()[len(sizes):]:\n",
    "        ax.set_visible(False)\n",
    "\n",
    "    # outer ticks / labels on leftmost and bottom axes\n",
    "    for ax in axes[:, 0]:\n",
    "        ax.set_yticks([-math.pi, 0, math.pi])\n",
    "        ax.set_yticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=9)\n",
    "        ax.set_ylabel(r'$\\eta$', fontsize=9, rotation=90, labelpad=-4)\n",
    "    for ax in axes[-1, :]:\n",
    "        ax.set_xticks([-math.pi, 0, math.pi])\n",
    "        ax.set_xticklabels([r'$-\\pi$', '0', r'$\\pi$'], fontsize=9)\n",
    "        ax.set_xlabel(r'$\\phi$', fontsize=9, labelpad=-2)\n",
    "\n",
    "    # physical values of the visualized cell, used in the output file name\n",
    "    alpha_str = alpha_vals[alpha_idx]\n",
    "    q0_str = q0_vals[q0_idx]\n",
    "    module_str = (module_value if module_value is not None\n",
    "                  else energy_map[int(df_cell.iloc[0]['energy_loss'])])\n",
    "\n",
    "    # bottom colorbar, placed manually below the panel grid\n",
    "    if last_pcm is not None:\n",
    "        cax = fig.add_axes([0.10, -0.02, 0.60, 0.03])\n",
    "        cbar = fig.colorbar(last_pcm, cax=cax, orientation='horizontal', location='bottom')\n",
    "        cbar.ax.set_title(r\"$p_T$\", fontsize=9, pad=6, loc='left', x=-0.12, y=0.1)\n",
    "        cbar.ax.tick_params(labelsize=9)\n",
    "\n",
    "    # legend with symbol explanations (invisible patches act as text-only entries)\n",
    "    legend_elements = [\n",
    "        mpatches.Patch(color='none', label=r\"$\\eta$: pseudorapidity\"),\n",
    "        mpatches.Patch(color='none', label=r\"$\\phi$: azimuthal\\;angle\"),\n",
    "        mpatches.Patch(color='none', label=r\"$p_T$: transverse\\;momentum\"),\n",
    "    ]\n",
    "    fig.legend(handles=legend_elements, loc='lower right', bbox_to_anchor=(0.98, -0.06),\n",
    "               fontsize=9, frameon=True, framealpha=0.9, handlelength=0, handletextpad=-0.2)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # save ('.' in the values becomes 'p' so the stem has no extra dots)\n",
    "    base = f\"evolution_hist2d_alpha{alpha_str}_q0{q0_str}_{module_str}\".replace('.', 'p')\n",
    "    fig.savefig(base + \".png\", dpi=300, bbox_inches=\"tight\")\n",
    "    fig.savefig(base + \".pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "#MATTER, α=0.2 (idx 0), Q0=1.0 (idx 0)\n",
    "plot_evolution_hist2d_for_cell(alpha_idx=0, q0_idx=0, module_value=\"MATTER\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61ad23bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MATTER-LBT, α=0.4 (idx 2), Q0=2.5 (idx 3)\n",
    "plot_evolution_hist2d_for_cell(alpha_idx=2, q0_idx=3, module_value=\"MATTER-LBT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8df4b784",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_evolution_hist2d_for_cell(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    global_max=121.79151153564453,\n",
    "    alpha_idx=0, q0_idx=0, module_value=None,\n",
    "    sizes=sizes2,\n",
    "    layout=\"2x5\",           # \"1xN\" -> single row; anything else -> two-row grid\n",
    "    random_state=0,\n",
    "    cmap='inferno',\n",
    "    add_axes_labels=True,\n",
    "    add_legend=True,\n",
    "    add_colorbar=True,\n",
    "    add_events_count_as_title=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot how the averaged 2D histogram of one (alpha, q0, module) cell evolves with sample size.\n",
    "\n",
    "    For every n in ``sizes`` the mean of the first n normalized event images is drawn in its\n",
    "    own panel; all panels share one logarithmic colour scale. The figure is saved as\n",
    "    ``evolution_hist2d_alpha<..>_q0<..>_<module>_1xN`` or ``..._2x5`` (.png and .pdf) in the\n",
    "    working directory and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_csv : str\n",
    "        Aggregated label CSV, relative to ``dataset_root``. The columns read are\n",
    "        'alpha', 'q0', 'energy_loss' and 'file_paths' ('|'-separated relative paths).\n",
    "    dataset_root : str\n",
    "        Dataset directory; '~' is expanded.\n",
    "    global_max : float\n",
    "        Divisor applied to every image after casting to float32.\n",
    "    alpha_idx, q0_idx : int\n",
    "        Class indices matched against the 'alpha' and 'q0' columns.\n",
    "    module_value : str or None\n",
    "        \"MATTER\" or \"MATTER-LBT\" to also filter on 'energy_loss'; None applies no such filter.\n",
    "    sizes : sequence of int\n",
    "        Numbers of events to average, one panel each; every value must be >= 1.\n",
    "    layout : str\n",
    "        \"1xN\" (case-insensitive) puts all panels on one row. Any other value uses a grid of\n",
    "        at most two rows with ceil(len(sizes) / rows) columns (2x5 for ten sizes); the\n",
    "        file suffix stays ``_2x5`` in that case.\n",
    "    random_state : int\n",
    "        Seed for shuffling the matching CSV rows.\n",
    "    cmap : str\n",
    "        Matplotlib colormap name.\n",
    "    add_axes_labels, add_legend, add_colorbar, add_events_count_as_title : bool\n",
    "        Toggle the outer tick labels, the symbol legend, the bottom colorbar and the\n",
    "        per-panel \"n = ...\" titles.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``sizes`` is empty or holds a value below 1, if no CSV row matches, or if\n",
    "        fewer than max(sizes) files can be collected.\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np, pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "    import matplotlib.patches as mpatches\n",
    "\n",
    "    sizes = tuple(sizes)\n",
    "    if not sizes or min(sizes) < 1:\n",
    "        # n < 1 would silently index cumsum from the end (cumsum[n-1])\n",
    "        raise ValueError(\"sizes must be a non-empty sequence of positive integers.\")\n",
    "\n",
    "    energy_map = {0:'MATTER',1:'MATTER-LBT'}\n",
    "    energy_inv = {v:k for k,v in energy_map.items()}\n",
    "    alpha_vals = {0:0.2,1:0.3,2:0.4}\n",
    "    q0_vals    = {0:1.0,1:1.5,2:2.0,3:2.5}\n",
    "\n",
    "    # --- select and shuffle the CSV rows of the requested cell ---\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    df = pd.read_csv(os.path.join(dataset_root, agg_csv))\n",
    "    mask = (df['alpha']==alpha_idx) & (df['q0']==q0_idx)\n",
    "    if module_value is not None:\n",
    "        mask &= df['energy_loss']==energy_inv[module_value]\n",
    "    df_cell = df[mask].sample(frac=1.0, random_state=random_state)\n",
    "    if df_cell.empty:\n",
    "        raise ValueError(\"No entries for requested (alpha,q0,module).\")\n",
    "\n",
    "    # --- gather raw event files until max(sizes) are available ---\n",
    "    need = max(sizes)\n",
    "    files = []\n",
    "    for paths_field in df_cell['file_paths']:\n",
    "        files.extend(os.path.join(dataset_root, p) for p in paths_field.split('|'))\n",
    "        if len(files) >= need:\n",
    "            break\n",
    "    if len(files) < need:\n",
    "        raise ValueError(f\"Found {len(files)} files, fewer than {need} requested.\")\n",
    "\n",
    "    # prefix sums: the mean of the first n images is cumsum[n-1] / n\n",
    "    imgs = np.stack([np.load(p).astype(np.float32)/float(global_max) for p in files[:need]], axis=0)\n",
    "    cumsum = np.cumsum(imgs, axis=0)\n",
    "\n",
    "    avgs, vmins, vmaxs = [], [], []\n",
    "    for n in sizes:\n",
    "        a = cumsum[n-1]/float(n)\n",
    "        a = np.ma.masked_where(a==0, a)\n",
    "        avgs.append(a)\n",
    "        if np.any(a>0):\n",
    "            # LogNorm needs a strictly positive lower bound\n",
    "            vmins.append(max(np.min(a[a>0]),1e-6)); vmaxs.append(np.max(a))\n",
    "        else:\n",
    "            vmins.append(1e-6); vmaxs.append(1.0)\n",
    "    vmin, vmax = min(vmins), max(vmaxs)\n",
    "\n",
    "    # --- layout ---\n",
    "    single_row = layout.lower() == \"1xn\"\n",
    "    if single_row:\n",
    "        nrows, ncols = 1, len(sizes)\n",
    "        figsize = (ncols*2.2, 2.6)\n",
    "    else:\n",
    "        # derive the grid from the panel count instead of asserting exactly ten sizes\n",
    "        nrows = min(2, len(sizes))\n",
    "        ncols = math.ceil(len(sizes) / nrows)\n",
    "        figsize = (ncols*2.3, nrows*2.4)\n",
    "\n",
    "    # squeeze=False keeps axes 2-D for every grid shape\n",
    "    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, dpi=300,\n",
    "                             sharex=True, sharey=True, squeeze=False)\n",
    "\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    last_pcm = None\n",
    "    for ax, n, avg in zip(axes.ravel(), sizes, avgs):  # row-major panel order\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "        last_pcm = ax.pcolormesh(x_edges, y_edges, avg,\n",
    "                                 norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "                                 cmap=cmap, shading='auto')\n",
    "        if add_events_count_as_title:\n",
    "            ax.set_title(f\"n = {n}\", fontsize=9, pad=4)\n",
    "    # an odd number of panels leaves the last grid slot empty; hide it\n",
    "    for ax in axes.ravel()[len(sizes):]:\n",
    "        ax.set_visible(False)\n",
    "\n",
    "    if add_axes_labels:\n",
    "        # axes labels only on outer edges\n",
    "        for ax in axes[:,0]:\n",
    "            ax.set_yticks([-math.pi,0,math.pi])\n",
    "            ax.set_yticklabels([r'$-\\pi$','0',r'$\\pi$'], fontsize=9)\n",
    "            ax.set_ylabel(r'$\\eta$', fontsize=9, labelpad=-4)\n",
    "        for ax in axes[-1,:]:\n",
    "            ax.set_xticks([-math.pi,0,math.pi])\n",
    "            ax.set_xticklabels([r'$-\\pi$','0',r'$\\pi$'], fontsize=9)\n",
    "            ax.set_xlabel(r'$\\phi$', fontsize=9, labelpad=-2)\n",
    "\n",
    "    # physical values of the visualized cell, used in the output file name\n",
    "    alpha_str, q0_str = alpha_vals[alpha_idx], q0_vals[q0_idx]\n",
    "    module_str = module_value if module_value is not None else energy_map[int(df_cell.iloc[0]['energy_loss'])]\n",
    "\n",
    "    # colorbar under the figure; make it a bit wider for 1xN\n",
    "    if (last_pcm is not None) and add_colorbar:\n",
    "        if single_row:\n",
    "            cax = fig.add_axes([0.03, -0.06, 0.8, 0.04])\n",
    "        else:\n",
    "            cax = fig.add_axes([0.10, -0.02, 0.60, 0.03])\n",
    "        cbar = fig.colorbar(last_pcm, cax=cax, orientation='horizontal', location='bottom')\n",
    "        cbar.ax.set_title(r\"$p_T$\", fontsize=9, pad=6, loc='left', x=-0.02, y=0.14)\n",
    "        cbar.ax.tick_params(labelsize=9)\n",
    "\n",
    "    if add_legend:\n",
    "        # invisible patches act as text-only legend entries\n",
    "        legend_elements = [\n",
    "            mpatches.Patch(color='none', label=r\"$\\eta$: pseudorapidity     $p_T$: transverse momentum\"),\n",
    "            mpatches.Patch(color='none', label=r\"$\\phi$: azimuthal angle   $n$: number of events\"),\n",
    "        ]\n",
    "        fig.legend(handles=legend_elements, loc='lower right',\n",
    "                bbox_to_anchor=(0.995, -0.15 if single_row else -0.06),\n",
    "                fontsize=9, frameon=True, framealpha=0.9, handlelength=0, handletextpad=-0.2)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    base = f\"evolution_hist2d_alpha{str(alpha_str).replace('.','p')}_q0{str(q0_str).replace('.','p')}_{module_str}\"\n",
    "    suffix = \"_1xN\" if single_row else \"_2x5\"  # suffix kept for existing output names\n",
    "    fig.savefig(base + suffix + \".png\", dpi=300, bbox_inches=\"tight\")\n",
    "    fig.savefig(base + suffix + \".pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa2d7821",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=0, q0_idx=0, module_value=\"MATTER\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\",\n",
    "    add_legend=False,\n",
    "    add_colorbar=False,\n",
    "    add_axes_labels=False,\n",
    "    add_events_count_as_title=False\n",
    ")\n",
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=2, q0_idx=3, module_value=\"MATTER-LBT\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\",\n",
    "    add_legend=False,\n",
    "    add_colorbar=False,\n",
    "    add_axes_labels=False,\n",
    "    add_events_count_as_title=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af2b63a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=0, q0_idx=0, module_value=\"MATTER\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\",\n",
    "    add_legend=False,\n",
    "    add_colorbar=False,\n",
    "    add_axes_labels=False\n",
    ")\n",
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=2, q0_idx=3, module_value=\"MATTER-LBT\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\",\n",
    "    add_legend=False,\n",
    "    add_colorbar=False,\n",
    "    add_axes_labels=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb6f3f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=0, q0_idx=0, module_value=\"MATTER\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\",\n",
    "    add_legend=False,\n",
    "    add_colorbar=False\n",
    ")\n",
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=2, q0_idx=3, module_value=\"MATTER-LBT\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\",\n",
    "    add_legend=False,\n",
    "    add_colorbar=False\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7015bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=0, q0_idx=0, module_value=\"MATTER\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\"\n",
    ")\n",
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=2, q0_idx=3, module_value=\"MATTER-LBT\",\n",
    "    sizes=sizes1,\n",
    "    layout=\"1xN\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e44106",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the value of `sizes` was missing in this call (SyntaxError).\n",
    "# sizes2, the function's own default, is assumed here; confirm the intended list.\n",
    "plot_evolution_hist2d_for_cell(\n",
    "    alpha_idx=2, q0_idx=3, module_value=\"MATTER-LBT\",\n",
    "    sizes=sizes2,\n",
    "    layout=\"1xN\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ad20a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_hist2d_artistic(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    module_value=None,                # None, \"MATTER\", \"MATTER-LBT\"\n",
    "    global_max=121.79151153564453,\n",
    "    cmap=\"inferno\",\n",
    "    background=\"black\",               # \"black\"; any other value gives a white canvas\n",
    "    keep_colorbar=False,              # True for a thin bar, False for pure abstract\n",
    "    random_state=0,\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw a label-free 4x3 gallery of averaged event images (rows: q0 index, columns: alpha index).\n",
    "\n",
    "    For each (q0, alpha) index pair one aggregated CSV row is sampled and the mean of the\n",
    "    .npy images listed in its 'file_paths' field is shown with axes switched off. All tiles\n",
    "    share one logarithmic colour scale. Pairs with no matching row stay empty. The figure\n",
    "    is saved as ``jet_images_art_grid_4x3_<background>_<cmap>`` (.png and .pdf) in the\n",
    "    working directory and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_csv : str\n",
    "        Aggregated label CSV, relative to ``dataset_root``.\n",
    "    dataset_root : str\n",
    "        Dataset directory; '~' is expanded.\n",
    "    module_value : str or None\n",
    "        \"MATTER\" or \"MATTER-LBT\" to restrict rows by 'energy_loss'; None keeps all rows.\n",
    "    global_max : float\n",
    "        Divisor applied to every image after casting to float32.\n",
    "    cmap : str\n",
    "        Matplotlib colormap name.\n",
    "    background : str\n",
    "        \"black\" (case-insensitive) for a black canvas; anything else yields white.\n",
    "    keep_colorbar : bool\n",
    "        Add a thin horizontal colorbar at the bottom of the figure.\n",
    "    random_state : int\n",
    "        Seed of the generator used to sample one row per tile.\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np, pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "\n",
    "    # module name -> integer code stored in the 'energy_loss' column\n",
    "    energy_inv = {'MATTER': 0, 'MATTER-LBT': 1}\n",
    "\n",
    "    # load\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    df = pd.read_csv(os.path.join(dataset_root, agg_csv))\n",
    "    if module_value is not None:\n",
    "        df = df[df['energy_loss'] == energy_inv[module_value]]\n",
    "\n",
    "    alphas = [0,1,2]\n",
    "    q0s    = [0,1,2,3]\n",
    "    n_rows, n_cols = len(q0s), len(alphas)\n",
    "\n",
    "    # figure: large for a gallery feel\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(n_cols*3.0, n_rows*3.0), dpi=300,\n",
    "        sharex=True, sharey=True\n",
    "    )\n",
    "\n",
    "    # canvas colour and the contrasting foreground used for colorbar tick labels\n",
    "    if background.lower() == \"black\":\n",
    "        canvas, fg = \"black\", \"white\"\n",
    "    else:\n",
    "        canvas, fg = \"white\", \"black\"\n",
    "    fig.patch.set_facecolor(canvas)\n",
    "    for ax in axes.ravel():\n",
    "        ax.set_facecolor(canvas)\n",
    "\n",
    "    # bins (fixed)\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    # one averaged image per tile, plus the values needed for the shared colour range\n",
    "    vmins, vmaxs = [], []\n",
    "    cell = [[None]*n_cols for _ in range(n_rows)]\n",
    "    rng = np.random.RandomState(random_state)\n",
    "\n",
    "    def _avg(entry):\n",
    "        \"\"\"Mean of the entry's normalized images with zero-valued bins masked.\"\"\"\n",
    "        paths = [os.path.join(dataset_root, p) for p in entry['file_paths'].split('|')]\n",
    "        arrs  = [np.load(p).astype(np.float32)/global_max for p in paths]\n",
    "        avg = np.mean(arrs, axis=0)\n",
    "        return np.ma.masked_where(avg==0, avg)\n",
    "\n",
    "    for r, q in enumerate(q0s):\n",
    "        for c, a in enumerate(alphas):\n",
    "            subset = df[(df['alpha']==a) & (df['q0']==q)]\n",
    "            if subset.empty:\n",
    "                continue\n",
    "            entry = subset.sample(n=1, random_state=rng).iloc[0]\n",
    "            img = _avg(entry)\n",
    "            cell[r][c] = img\n",
    "            if np.any(img>0):\n",
    "                # LogNorm needs a strictly positive lower bound\n",
    "                vmins.append(max(np.min(img[img>0]), 1e-6))\n",
    "                vmaxs.append(np.max(img))\n",
    "\n",
    "    # shared log scale (fallback range when no tile has positive content)\n",
    "    vmin = min(vmins) if vmins else 1e-6\n",
    "    vmax = max(vmaxs) if vmaxs else 1.0\n",
    "\n",
    "    last = None\n",
    "    for r in range(n_rows):\n",
    "        for c in range(n_cols):\n",
    "            ax = axes[r, c]\n",
    "            ax.axis(\"off\")                     # minimalist: no ticks, no frame\n",
    "            if cell[r][c] is None:\n",
    "                continue\n",
    "            last = ax.pcolormesh(\n",
    "                x_edges, y_edges, cell[r][c],\n",
    "                norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "                cmap=cmap, shading='auto'\n",
    "            )\n",
    "\n",
    "    # optional ultra-thin colorbar\n",
    "    if keep_colorbar and last is not None:\n",
    "        cax = fig.add_axes([0.10, 0.02, 0.80, 0.015])  # thin, unobtrusive\n",
    "        cbar = fig.colorbar(last, cax=cax, orientation='horizontal')\n",
    "        cbar.outline.set_visible(False)\n",
    "        cbar.ax.tick_params(length=0, labelsize=8, colors=fg)\n",
    "\n",
    "    plt.subplots_adjust(wspace=0.02, hspace=0.02)      # tight tiles\n",
    "    base = f\"jet_images_art_grid_4x3_{background}_{cmap}\"\n",
    "    fig.savefig(base + \".png\", dpi=400, bbox_inches=\"tight\", facecolor=fig.get_facecolor())\n",
    "    fig.savefig(base + \".pdf\", bbox_inches=\"tight\", facecolor=fig.get_facecolor())\n",
    "    plt.show()\n",
    "\n",
    "# 4×3 gallery, black background, inferno, no colorbar\n",
    "plot_grid_hist2d_artistic(background=\"black\", cmap=\"inferno\", keep_colorbar=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a65cccba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_evolution_hist2d_artistic(\n",
    "    agg_csv='file_labels_aggregated_ds1008_g500_train.csv',\n",
    "    dataset_root=\"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
    "    alpha_idx=0, q0_idx=0, module_value=None,          # choose the cell\n",
    "    global_max=121.79151153564453,\n",
    "    # NOTE(review): the original default was missing (SyntaxError); the values of\n",
    "    # sizes1 are assumed here. Confirm the intended list.\n",
    "    sizes=(1, 2, 4, 8, 16, 32, 64, 128, 256, 500),\n",
    "    cmap=\"inferno\",\n",
    "    background=\"black\",                                # \"black\"; any other value gives white\n",
    "    keep_colorbar=False,\n",
    "    random_state=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw a label-free 1xN strip showing how the averaged event image evolves with sample size.\n",
    "\n",
    "    For every n in ``sizes`` the mean of the first n normalized event images of the chosen\n",
    "    (alpha, q0, module) cell is drawn with axes switched off; all panels share one\n",
    "    logarithmic colour scale. The figure is saved as\n",
    "    ``jet_evolution_art_1xN_<background>_<cmap>`` (.png and .pdf) in the working directory\n",
    "    and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_csv : str\n",
    "        Aggregated label CSV, relative to ``dataset_root``. The columns read are\n",
    "        'alpha', 'q0', 'energy_loss' and 'file_paths' ('|'-separated relative paths).\n",
    "    dataset_root : str\n",
    "        Dataset directory; '~' is expanded.\n",
    "    alpha_idx, q0_idx : int\n",
    "        Class indices matched against the 'alpha' and 'q0' columns.\n",
    "    module_value : str or None\n",
    "        \"MATTER\" or \"MATTER-LBT\" to also filter on 'energy_loss'; None applies no such filter.\n",
    "    global_max : float\n",
    "        Divisor applied to every image after casting to float32.\n",
    "    sizes : sequence of int\n",
    "        Numbers of events to average, one panel each; every value must be >= 1.\n",
    "    cmap : str\n",
    "        Matplotlib colormap name.\n",
    "    background : str\n",
    "        \"black\" (case-insensitive) for a black canvas; anything else yields white.\n",
    "    keep_colorbar : bool\n",
    "        Add a thin horizontal colorbar at the bottom of the figure.\n",
    "    random_state : int\n",
    "        Seed for shuffling the matching CSV rows.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``sizes`` is empty or holds a value below 1, if no CSV row matches, or if\n",
    "        fewer than max(sizes) files can be collected.\n",
    "    \"\"\"\n",
    "    import os, math\n",
    "    import numpy as np, pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import matplotlib.colors as colors\n",
    "\n",
    "    sizes = tuple(sizes)\n",
    "    if not sizes or min(sizes) < 1:\n",
    "        # n < 1 would silently index cumsum from the end (cumsum[n-1])\n",
    "        raise ValueError(\"sizes must be a non-empty sequence of positive integers.\")\n",
    "\n",
    "    # label maps\n",
    "    energy_map = {0:'MATTER', 1:'MATTER-LBT'}\n",
    "    energy_inv = {v:k for k,v in energy_map.items()}\n",
    "\n",
    "    # load & filter\n",
    "    dataset_root = os.path.expanduser(dataset_root)\n",
    "    df = pd.read_csv(os.path.join(dataset_root, agg_csv))\n",
    "    mask = (df['alpha']==alpha_idx) & (df['q0']==q0_idx)\n",
    "    if module_value is not None:\n",
    "        mask &= df['energy_loss']==energy_inv[module_value]\n",
    "    df_cell = df[mask].sample(frac=1.0, random_state=random_state)\n",
    "    if df_cell.empty:\n",
    "        raise ValueError(\"No entries found for the requested (alpha, q0, module).\")\n",
    "\n",
    "    # gather raw event files until max(sizes) are available\n",
    "    need = max(sizes)\n",
    "    files = []\n",
    "    for paths_field in df_cell['file_paths']:\n",
    "        files.extend(os.path.join(dataset_root, p) for p in paths_field.split('|'))\n",
    "        if len(files) >= need:\n",
    "            break\n",
    "    if len(files) < need:\n",
    "        raise ValueError(f\"Only {len(files)} files available; need {need}.\")\n",
    "\n",
    "    # load & cumulative means: the mean of the first n images is cumsum[n-1] / n\n",
    "    imgs = np.stack([np.load(p).astype(np.float32)/global_max for p in files[:need]], axis=0)\n",
    "    cumsum = np.cumsum(imgs, axis=0)\n",
    "    avgs = []\n",
    "    for n in sizes:\n",
    "        mean_n = cumsum[n-1] / float(n)\n",
    "        avgs.append(np.ma.masked_where(mean_n == 0, mean_n))\n",
    "\n",
    "    # shared log scale\n",
    "    posmins = []\n",
    "    vmaxs = []\n",
    "    for a in avgs:\n",
    "        if np.any(a>0):\n",
    "            # LogNorm needs a strictly positive lower bound\n",
    "            posmins.append(max(np.min(a[a>0]), 1e-6))\n",
    "            vmaxs.append(np.max(a))\n",
    "    vmin = min(posmins) if posmins else 1e-6\n",
    "    vmax = max(vmaxs)  if vmaxs  else 1.0\n",
    "\n",
    "    # figure: very wide panoramic strip; squeeze=False keeps axes an array even for one panel\n",
    "    n_panels = len(sizes)\n",
    "    fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*2.6, 2.8), dpi=300,\n",
    "                             sharex=True, sharey=True, squeeze=False)\n",
    "    panels = axes.ravel()\n",
    "\n",
    "    # canvas colour and the contrasting foreground used for colorbar tick labels\n",
    "    if background.lower() == \"black\":\n",
    "        canvas, fg = \"black\", \"white\"\n",
    "    else:\n",
    "        canvas, fg = \"white\", \"black\"\n",
    "    fig.patch.set_facecolor(canvas)\n",
    "    for ax in panels:\n",
    "        ax.set_facecolor(canvas)\n",
    "\n",
    "    x_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "    y_edges = np.linspace(-math.pi, math.pi, 33)\n",
    "\n",
    "    last = None\n",
    "    for ax, avg in zip(panels, avgs):\n",
    "        ax.axis(\"off\")  # minimalist: no ticks, no titles\n",
    "        last = ax.pcolormesh(\n",
    "            x_edges, y_edges, avg,\n",
    "            norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
    "            cmap=cmap, shading='auto'\n",
    "        )\n",
    "\n",
    "    # optional thin colorbar\n",
    "    if keep_colorbar and last is not None:\n",
    "        cax = fig.add_axes([0.05, 0.05, 0.90, 0.02])\n",
    "        cbar = fig.colorbar(last, cax=cax, orientation='horizontal')\n",
    "        cbar.outline.set_visible(False)\n",
    "        cbar.ax.tick_params(length=0, labelsize=8, colors=fg)\n",
    "\n",
    "    plt.subplots_adjust(wspace=0.02, hspace=0.0)\n",
    "    base = f\"jet_evolution_art_1xN_{background}_{cmap}\"\n",
    "    fig.savefig(base + \".png\", dpi=400, bbox_inches=\"tight\", facecolor=fig.get_facecolor())\n",
    "    fig.savefig(base + \".pdf\", bbox_inches=\"tight\", facecolor=fig.get_facecolor())\n",
    "    plt.show()\n",
    "# 1×N evolution strip for (α=0.2 idx=0, Q0=1.0 idx=0, MATTER), white background with thin bar\n",
    "plot_evolution_hist2d_artistic(alpha_idx=0, q0_idx=0, module_value=\"MATTER\",\n",
    "                               background=\"white\", cmap=\"magma\", keep_colorbar=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
