"""Benchmarking version of demo.py with full telemetry for ICML reporting.

This script runs the TFMPE hemodynamics inference pipeline with:
1. System information logging (via sys_bench.py)
2. JIT-aware timing with warmup (via bench_utils.py)
3. Energy measurement via RAPL/nvidia-smi
4. Smoothed PPC plots using Savitzky-Golay filtering
5. Comprehensive JSON output for reproducibility

Usage:
    # Quick test run (minimal settings for e2e validation)
    python demo_bench.py --quick-test

    # Full benchmark run
    python demo_bench.py --n_samples 500 --n_iter 1000 --n_posterior 10000

    # Output benchmark results as LaTeX
    python demo_bench.py --latex
"""

import jax
import os

jax.config.update("jax_enable_x64", False)
os.environ["JAX_PLATFORM_NAME"] = "gpu"

import argparse
import json
import time
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple, Callable, Any
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.signal import savgol_filter
from scipy.interpolate import CubicSpline

import jax.numpy as jnp
from jax import random, lax
from tqdm.auto import tqdm

import diffrax
import optax
from flax import nnx

# TFMPE imports
from tfmpe.estimators.tfmpe import TFMPE, NormalDistribution
from tfmpe.estimators.training import fit_bottom_up
from tfmpe.preprocessing.tokens import Tokens
from tfmpe.preprocessing.utils import Independence, Labeller
from tfmpe.nn.transformer import Transformer, TransformerConfig
from jaxtyping import PRNGKeyArray

# Local imports
from sys_bench import collect_system_info, to_dict as sys_info_to_dict
from bench_utils import (
    TimingResult, EnergyResult, BenchmarkResult,
    time_function, benchmark_function,
    timed_block, energy_monitored,
    EnergyMonitor, sync_jax_devices,
    format_timing_table, format_benchmark_latex,
)

# Import the hemodynamics simulator components from demo.py
from demo import (
    MMHG_IN_PA, SITE_ORDER,
    Vessel, Network, BaseParams, RCRParameters,
    create_arch_network, pack_geometry, terminal_ids_ordered,
    st_like_refs, psi_from_A, pressure_from_A, A_from_pressure,
    characteristic_speed, default_rcr, make_stepper,
    simulate_5cycles_then_sample,
    create_prior_fn, create_local_fn, create_simulator_fn,
    _repeat_context_for_sampling, _numpy_seed_from_key,
)


# =============================================================================
# BENCHMARKING CONFIGURATION
# =============================================================================

@dataclass
class BenchConfig:
    """All tunable knobs for a single benchmarking run.

    Fields are grouped by the pipeline stage that consumes them. Use
    ``to_dict`` to obtain a plain dictionary for the JSON results file.
    """

    # ---- CFD simulation ------------------------------------------------
    N_t: int = 50            # samples per resampled output waveform
    nx: int = 81             # spatial grid points per vessel (passed as nx)
    dt_init: float = 2e-4    # initial time step handed to the simulator
    eta: float = 0.05        # NOTE(review): semantics defined outside this section

    # ---- Transformer architecture --------------------------------------
    latent_dim: int = 32
    n_encoder: int = 2
    n_decoder: int = 2
    n_heads: int = 2
    n_ff: int = 2

    # ---- Training ------------------------------------------------------
    n_rounds: int = 1
    n_samples_per_round: int = 200
    n_val_samples: int = 20
    n_iter_per_round: int = 500
    batch_size: int = 32
    learning_rate: float = 1e-3

    # ---- Inference -----------------------------------------------------
    n_posterior_samples: int = 500

    # ---- Benchmark repetition / telemetry ------------------------------
    n_sim_warmup: int = 3        # untimed CFD calls before timing starts
    n_sim_repeats: int = 10      # timed CFD repetitions
    n_sample_warmup: int = 2     # untimed posterior-sampling calls (JIT warmup)
    n_sample_repeats: int = 5    # timed posterior-sampling repetitions
    measure_energy: bool = True

    # ---- Output --------------------------------------------------------
    output_dir: str = "tfmpe_bench_results"

    # ---- Posterior predictive plot -------------------------------------
    ppc_smooth_window: int = 7     # Savitzky-Golay window; should be odd
    ppc_smooth_polyorder: int = 3  # Savitzky-Golay polynomial order
    n_ppc_curves: int = 20         # posterior predictive curves per panel
    skip_ppc: bool = False         # True skips PPC generation (fast tests)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a ``{name: value}`` dict (via ``asdict``)."""
        return asdict(self)


@dataclass
class CFDTelemetry:
    """Timing and energy record for the CFD forward simulator.

    Times are in seconds, energies in joules. Energy fields default to
    ``None`` so that "not measured" is distinguishable from a measured zero.
    """

    # -- call counts --
    n_simulations: int = 0
    n_patient_sims: int = 0  # patient-level simulations in total
    # -- per-simulation wall-clock statistics --
    wall_time_per_sim_mean_s: float = 0.0
    wall_time_per_sim_median_s: float = 0.0
    wall_time_per_sim_std_s: float = 0.0
    wall_time_jit_compile_s: float = 0.0
    # -- energy (None => not measured) --
    energy_per_sim_mean_j: Optional[float] = None
    # -- totals --
    total_sim_time_s: float = 0.0
    total_energy_j: Optional[float] = None
    # -- problem-size metadata --
    n_time_points: int = 0
    n_spatial_points: int = 0
    n_cycles: int = 5
    # -- provenance --
    telemetry_method: str = "none"    # expected: "nvidia-smi", "rapl" or "none"
    parallelism: str = "python_loop"  # free-text description of the strategy


@dataclass
class TFMPETelemetry:
    """Timing and energy record for TFMPE training and sampling.

    Times are in seconds, energies in joules; ``None`` energy means the
    quantity was not measured.
    """

    # ===== Training =====
    training_wall_time_s: float = 0.0
    training_energy_j: Optional[float] = None
    n_training_samples: int = 0
    n_training_iterations: int = 0
    final_train_loss: float = 0.0
    final_val_loss: float = 0.0

    # ===== Likelihood sampling: y ~ q_phi(y | theta) (Stage 2) =====
    like_sampling_wall_time_s: float = 0.0
    like_time_per_patient_mean_s: float = 0.0
    like_time_per_patient_median_s: float = 0.0
    n_like_draws: int = 0

    # ===== Posterior sampling: theta ~ q_phi(theta | y_obs) =====
    posterior_sampling_wall_time_s: float = 0.0
    posterior_sampling_energy_j: Optional[float] = None
    n_posterior_samples: int = 0
    posterior_time_per_sample_mean_s: float = 0.0
    posterior_time_per_sample_median_s: float = 0.0

    # ===== Model / run metadata =====
    precision: str = "float32"
    n_model_parameters: int = 0
    solver_settings: str = ""       # human-readable ODE solver description
    telemetry_method: str = "none"  # expected: "nvidia-smi", "rapl" or "none"


@dataclass
class Counts:
    """Explicit simulation/sample counts for reproducibility and derived metrics.

    Key insight: Stage 2 training is simulator-free (uses likelihood sampling
    y ~ q_phi(y|theta) instead of CFD), which is the source of TFMPE's speedup.
    """

    n_patients: int = 0  # patient networks per simulation

    # Stage 1 -- local likelihood p(y|theta) with n=1, driven by CFD
    n_param_draws_stage1: int = 0
    n_patient_sims_stage1: int = 0  # expected n_param_draws_stage1 * 1

    # Stage 2 -- global posterior with n=n_patients, no CFD calls
    n_param_draws_stage2: int = 0
    n_patient_sims_stage2: int = 0  # expected 0: Stage 2 samples the likelihood
    n_like_draws_stage2: int = 0    # likelihood samples (y|theta) used in Stage 2

    # Totals (CFD total covers observation generation + Stage 1 + validation)
    n_cfd_sims_total: int = 0
    n_like_samples_total: int = 0


@dataclass
class DerivedMetrics:
    """Per-unit timing components used to derive efficiency metrics.

    Cost model (n_s = simulations per parameter draw):
        NPE    : N * n_s * t_sim
        TFMPE  : N1 * t_sim + N2 * n_s * t_like
        Speedup = (N * n_s * t_sim) / (N1 * t_sim + N2 * n_s * t_like)
    """

    # CFD forward simulation, per patient
    t_sim_mean_s: float = 0.0
    t_sim_median_s: float = 0.0
    # Likelihood sampling y ~ q_phi, per patient
    t_like_mean_s: float = 0.0
    t_like_median_s: float = 0.0
    # Posterior sampling theta ~ q_phi, per draw
    t_post_mean_s: float = 0.0
    t_post_median_s: float = 0.0


@dataclass
class BenchmarkResults:
    """Aggregate of everything a benchmark run reports.

    ``config``, ``system_info`` and ``raw_timings`` are plain dicts; the other
    sections are telemetry dataclasses that ``to_dict`` flattens via ``asdict``.
    """

    config: Dict[str, Any] = field(default_factory=dict)
    system_info: Dict[str, Any] = field(default_factory=dict)
    cfd: CFDTelemetry = field(default_factory=CFDTelemetry)
    tfmpe: TFMPETelemetry = field(default_factory=TFMPETelemetry)
    counts: Counts = field(default_factory=Counts)
    derived: DerivedMetrics = field(default_factory=DerivedMetrics)
    raw_timings: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict; dataclass sections become nested dicts.

        Key order: config, system_info, cfd, tfmpe, counts, derived, raw_timings.
        """
        dataclass_sections = {
            section: asdict(getattr(self, section))
            for section in ("cfd", "tfmpe", "counts", "derived")
        }
        return {
            "config": self.config,
            "system_info": self.system_info,
            **dataclass_sections,
            "raw_timings": self.raw_timings,
        }


# =============================================================================
# INSTRUMENTED SIMULATOR
# =============================================================================

class InstrumentedSimulator:
    """Callable CFD-simulator wrapper that times every patient-level simulation.

    Each call to ``simulate_5cycles_then_sample`` is timed on its own with
    ``time.perf_counter``; ``sync_jax_devices()`` is called before the clock
    stops. After every ``__call__`` the accumulated timings are pushed into
    the shared ``CFDTelemetry`` object.

    NOTE(review): no warmup happens here, so if the simulator is JIT-compiled
    the first recorded time includes compilation.
    """

    def __init__(
        self,
        net: Network,
        base: BaseParams,
        config: BenchConfig,
        n_terminals: int,
        telemetry: CFDTelemetry,
    ):
        self.net = net
        self.base = base
        self.config = config
        self.n_terminals = n_terminals
        self.telemetry = telemetry
        # Raw per-patient simulation durations (seconds), appended in call order.
        self.sim_times: List[float] = []
        # Parameter draws processed since construction / last reset.
        self.n_param_draws: int = 0
        # Individual patient simulations since construction / last reset.
        self.n_patient_sims: int = 0

    def __call__(self, rng: PRNGKeyArray, params_dict, n: int):
        """Simulate every (parameter draw, patient) pair and return the flows.

        ``params_dict`` may be a dict or an object exposing ``decode()``. It is
        indexed as ``log_beta/log_mu/log_Qin[:, 0, 0]`` and
        ``log_Rt/log_C[:, :, :, 0]``.

        Returns ``({"y": y}, None)``. ``y`` is float32 with shape
        ``(nsamp, n, n_terminals, N_t, 1)``; the second element is always
        ``None``.

        Raises RuntimeError if any simulated value is non-finite.
        """
        if hasattr(params_dict, "decode"):
            params_dict = params_dict.decode()

        n_patients = int(n)

        # Host-side float64 copies of the parameters.
        log_beta_np, log_mu_np, log_Qin_np = (
            np.array(params_dict[key][:, 0, 0], dtype=np.float64)
            for key in ("log_beta", "log_mu", "log_Qin")
        )
        log_Rt_np, log_C_np = (
            np.array(params_dict[key][:, :, :, 0], dtype=np.float64)
            for key in ("log_Rt", "log_C")
        )
        nsamp = int(log_beta_np.shape[0])

        # NOTE(review): neither the seeded generator nor the terminal ids are
        # consumed below; the calls are kept so behaviour matches exactly.
        _seeded_rng = np.random.default_rng(_numpy_seed_from_key(rng))
        per_draw_outputs = []
        _terminal_ids = terminal_ids_ordered(self.net)

        progress = tqdm(
            range(nsamp), desc=f"Simulating (n_patients={n_patients})", leave=False
        )
        for draw in progress:
            theta_global = jnp.array(
                [log_beta_np[draw], log_mu_np[draw], log_Qin_np[draw]],
                dtype=jnp.float32,
            )
            y_draw = np.zeros(
                (n_patients, self.n_terminals, self.config.N_t), dtype=np.float64
            )

            for patient in range(n_patients):
                # Column 0: log R_T, column 1: log C_T, one row per terminal.
                theta_local_np = np.zeros((self.n_terminals, 2), dtype=np.float64)
                for site in range(self.n_terminals):
                    theta_local_np[site] = (
                        log_Rt_np[draw, patient, site],
                        log_C_np[draw, patient, site],
                    )
                theta_local = jnp.array(theta_local_np, dtype=jnp.float32)

                tic = time.perf_counter()
                resampled = simulate_5cycles_then_sample(
                    self.net, self.base, theta_global, theta_local,
                    N_t=self.config.N_t, nx=self.config.nx, dt_init=self.config.dt_init
                )
                # Block on device work so the measurement covers the whole simulation.
                sync_jax_devices()
                self.sim_times.append(time.perf_counter() - tic)
                self.n_patient_sims += 1

                for site, site_name in enumerate(SITE_ORDER):
                    y_draw[patient, site, :] = np.asarray(
                        resampled[site_name]["q"], dtype=np.float64
                    )

            per_draw_outputs.append(y_draw)
            self.n_param_draws += 1

        y_np = np.stack(per_draw_outputs, axis=0)
        if not np.isfinite(y_np).all():
            raise RuntimeError("Non-finite values in simulator outputs.")

        y = jnp.asarray(y_np, dtype=jnp.float32)[..., None]
        self._update_telemetry()
        return {"y": y}, None

    def _update_telemetry(self):
        """Write summary statistics of all recorded times into ``self.telemetry``.

        Does nothing while no simulation has been timed.
        """
        if not self.sim_times:
            return
        times = self.sim_times
        tel = self.telemetry
        tel.n_simulations = len(times)
        tel.n_patient_sims = self.n_patient_sims
        tel.wall_time_per_sim_mean_s = float(np.mean(times))
        tel.wall_time_per_sim_median_s = float(np.median(times))
        tel.wall_time_per_sim_std_s = float(np.std(times))
        tel.total_sim_time_s = float(np.sum(times))
        tel.n_time_points = self.config.N_t
        tel.n_spatial_points = self.config.nx
        tel.parallelism = "python_loop_over_params_and_patients_no_vmap"

    def get_raw_times(self) -> List[float]:
        """Return an independent copy of the recorded per-simulation times."""
        return list(self.sim_times)

    def reset_times(self):
        """Clear recorded times and counters (e.g. between pipeline stages)."""
        self.sim_times = []
        self.n_param_draws = 0
        self.n_patient_sims = 0


# =============================================================================
# SMOOTHED PPC PLOTTING
# =============================================================================

def smooth_waveform(
    y: np.ndarray,
    method: str = "savgol",
    window_length: int = 7,
    polyorder: int = 3,
) -> np.ndarray:
    """Smooth a 1-D waveform with a Savitzky-Golay filter or a smoothing spline.

    For PPC plots, Savitzky-Golay is preferred because it:
    - Preserves peak locations and heights better than moving average
    - Doesn't introduce phase shift (unlike causal filters), provided the
      window is odd, which this function enforces
    - Is computationally efficient
    - Works well with the relatively smooth CFD waveforms

    Parameters
    ----------
    y : np.ndarray
        Input waveform (1-D).
    method : str
        ``"savgol"`` (default) or ``"spline"``. ``"spline"`` fits a cubic
        smoothing spline whose penalty is chosen by generalized
        cross-validation (``scipy.interpolate.make_smoothing_spline``).
    window_length : int
        Requested Savitzky-Golay window (default: 7). It is shrunk to fit
        ``len(y)`` and rounded down to an odd number. If that is not larger
        than ``polyorder``, it is enlarged to the smallest odd value that is.
    polyorder : int
        Polynomial order for Savitzky-Golay (default: 3).

    Returns
    -------
    y_smooth : np.ndarray
        Smoothed waveform. ``y`` is returned unchanged when it is too short to
        smooth: fewer points than the smallest valid Savitzky-Golay window, or
        fewer than 5 points for ``"spline"``.

    Raises
    ------
    ValueError
        If ``method`` is not ``"savgol"`` or ``"spline"``.
    """
    n = len(y)

    if method == "savgol":
        # Largest odd window that fits inside the signal.
        wl = min(window_length, n)
        if wl % 2 == 0:
            wl -= 1
        # savgol_filter requires polyorder < window_length. Enlarge to the
        # smallest *odd* such window: an even window is not centred and shifts
        # the output by half a sample.
        min_wl = polyorder + 1 if polyorder % 2 == 0 else polyorder + 2
        wl = max(wl, min_wl)
        if wl > n:
            # The default mode="interp" needs window_length <= len(y).
            return y
        return savgol_filter(y, window_length=wl, polyorder=polyorder)

    if method == "spline":
        if n < 5:
            # make_smoothing_spline needs at least 5 data points.
            return y
        from scipy.interpolate import make_smoothing_spline

        t = np.arange(n, dtype=np.float64)
        spline = make_smoothing_spline(t, np.asarray(y, dtype=np.float64))
        return spline(t)

    raise ValueError(
        f"Unknown smoothing method {method!r}; expected 'savgol' or 'spline'."
    )


def plot_posterior_predictive_smoothed(
    net: Network,
    base: BaseParams,
    config: BenchConfig,
    theta_g_samples: np.ndarray,
    theta_l_samples: np.ndarray,
    y_obs: np.ndarray,
    output_dir: str,
):
    """Create PPC plot with smoothed observed and posterior predictive waveforms.

    The observed waveform from the CFD simulator can appear jagged due to
    numerical artifacts from the Lax-Wendroff scheme. We apply Savitzky-Golay
    smoothing to produce publication-quality plots.

    Savitzky-Golay is chosen over alternatives because:
    - It preserves the shape of peaks better than moving average
    - It doesn't introduce phase shift (important for time-series)
    - The polynomial fitting reduces high-frequency noise while
      maintaining the underlying waveform morphology

    ``min(config.n_ppc_curves, len(theta_g_samples))`` posterior draws are
    chosen at random (without replacement), and each is simulated exactly once.
    One simulation yields the flow at every site, so all panels show the same
    draws. Draws whose simulation fails are skipped, and the number of
    failures is printed.

    Parameters
    ----------
    net, base : Network, BaseParams
        Forwarded to ``simulate_5cycles_then_sample``; ``base.T`` sets the
        time axis.
    config : BenchConfig
        Supplies N_t, nx, dt_init, n_ppc_curves and the smoothing settings.
    theta_g_samples, theta_l_samples : np.ndarray
        Posterior draws of global / local parameters, indexed along axis 0.
    y_obs : np.ndarray
        Observations indexed ``[sample, patient, site, time, 0]``; only
        sample 0 / patient 0 is drawn.
    output_dir : str
        Directory that receives ``posterior_predictive.png`` and ``.pdf``.
    """
    n_sites = len(SITE_ORDER)
    n_rows = max(1, (n_sites + 1) // 2)  # two panels per row (2x2 for 4 sites)
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows), squeeze=False)
    axes = axes.flatten()
    t_grid = np.linspace(0, base.T, config.N_t, endpoint=False)

    def smooth(y):
        # Same Savitzky-Golay settings for observed and predicted curves.
        return smooth_waveform(
            y,
            method="savgol",
            window_length=config.ppc_smooth_window,
            polyorder=config.ppc_smooth_polyorder,
        )

    # Simulate each selected draw ONCE. The simulator returns all sites, so
    # re-simulating per panel would multiply the CFD cost by n_sites.
    n_draws = len(theta_g_samples)
    indices = np.random.choice(
        n_draws, min(config.n_ppc_curves, n_draws), replace=False
    )
    predicted = []  # one {site_name: flow waveform} dict per successful draw
    n_failed = 0
    last_error = None
    for idx in tqdm(indices, desc="PPC simulations", leave=False):
        try:
            resampled = simulate_5cycles_then_sample(
                net, base,
                jnp.array(theta_g_samples[idx]),
                jnp.array(theta_l_samples[idx]),
                N_t=config.N_t, nx=config.nx, dt_init=config.dt_init,
            )
            predicted.append(
                {site: np.asarray(resampled[site]['q']) for site in SITE_ORDER}
            )
        except Exception as err:  # best-effort: a failed draw is just not plotted
            n_failed += 1
            last_error = err
    if n_failed:
        print(
            f"  Warning: {n_failed}/{len(indices)} PPC simulations failed and "
            f"were omitted (last error: {last_error!r})"
        )

    for s_idx, site_name in enumerate(SITE_ORDER):
        ax = axes[s_idx]
        y_site_raw = y_obs[0, 0, s_idx, :, 0]  # sample 0, patient 0
        ax.plot(t_grid, smooth(y_site_raw), 'k-', linewidth=2.5,
                label='Observed', zorder=10)

        for curves in predicted:
            ax.plot(t_grid, smooth(curves[site_name]), 'b-', alpha=0.3, linewidth=0.8)

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Flow (m³/s)')
        ax.set_title(f'{site_name}')
        ax.grid(True, alpha=0.3)
        if s_idx == 0:
            ax.plot([], [], 'b-', alpha=0.5, label='Posterior predictive')
            ax.legend()

    # Hide grid cells left over when the number of sites is odd.
    for ax in axes[n_sites:]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(
        os.path.join(output_dir, 'posterior_predictive.png'),
        dpi=150, bbox_inches='tight'
    )
    fig.savefig(
        os.path.join(output_dir, 'posterior_predictive.pdf'),
        bbox_inches='tight'
    )
    plt.close(fig)


# =============================================================================
# VISUALIZATION FUNCTIONS (from demo.py)
# =============================================================================

def _draw_marginal_hist(ax, samples: np.ndarray, true_value: float):
    """Draw one marginal histogram with 'True' and 'Mean' reference lines."""
    ax.hist(samples, bins=30, density=True, alpha=0.7, color='steelblue', edgecolor='white')
    ax.axvline(true_value, color='red', linestyle='--', linewidth=2, label='True')
    ax.axvline(np.mean(samples), color='green', linestyle='-', linewidth=2, label='Mean')


def plot_posterior_marginals(
    theta_g_samples: np.ndarray,
    theta_l_samples: np.ndarray,
    true_theta_g: np.ndarray,
    true_theta_l: np.ndarray,
    output_dir: str
):
    """Plot posterior marginal histograms against the true parameter values.

    Writes ``posterior_global.png`` (one panel per global parameter) and
    ``posterior_local.png`` (one row per terminal site, columns log R_T and
    log C_T) into ``output_dir``.

    Parameters
    ----------
    theta_g_samples : np.ndarray
        Global draws indexed ``[draw, param]``, params in the order
        (log beta_scale, log mu, log Q_in).
    theta_l_samples : np.ndarray
        Local draws indexed ``[draw, site, param]``, params (log R_T, log C_T).
    true_theta_g, true_theta_l : np.ndarray
        Ground truth indexed ``[param]`` and ``[site, param]``.
    output_dir : str
        Destination directory.
    """
    # Global parameters
    param_names = [r'$\log \beta_{scale}$', r'$\log \mu$', r'$\log Q_{in}$']
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), squeeze=False)
    for i, (ax, name) in enumerate(zip(axes[0], param_names)):
        _draw_marginal_hist(ax, theta_g_samples[:, i], true_theta_g[i])
        ax.set_xlabel(name, fontsize=12)
        ax.set_ylabel('Density' if i == 0 else '')
        ax.legend()
        ax.set_title(f'Posterior: {name}')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'posterior_global.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)

    # Local parameters. squeeze=False keeps axes 2-D even with a single
    # terminal, so axes[s, j] is always valid.
    n_terminals = theta_l_samples.shape[1]
    fig, axes = plt.subplots(n_terminals, 2, figsize=(10, 3 * n_terminals), squeeze=False)
    for s in range(n_terminals):
        for j, param_name in enumerate([r'$\log R_T$', r'$\log C_T$']):
            ax = axes[s, j]
            _draw_marginal_hist(ax, theta_l_samples[:, s, j], true_theta_l[s, j])
            ax.set_xlabel(param_name)
            ax.set_title(f'{SITE_ORDER[s]}: {param_name}')
            if s == 0 and j == 0:
                ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'posterior_local.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


def plot_training_losses(all_losses: List, output_dir: str):
    """Plot per-round training and validation loss curves.

    Each entry of ``all_losses`` is unpacked as
    ``(train_local, val_local, train_global, val_global)``. Each round gets one
    row: local-likelihood losses on the left, global-posterior losses on the
    right. A panel whose training sequence is empty is left blank, with no
    legend and no log scale. The figure is saved as ``training_losses.png``
    in ``output_dir``.

    If ``all_losses`` is empty, nothing is plotted and no file is written
    (``plt.subplots`` cannot create a figure with zero rows).
    """
    n_rounds = len(all_losses)
    if n_rounds == 0:
        return

    # squeeze=False keeps axes 2-D even for a single round.
    fig, axes = plt.subplots(n_rounds, 2, figsize=(12, 4 * n_rounds), squeeze=False)

    for r, (train_local, val_local, train_global, val_global) in enumerate(all_losses):
        panels = (
            (axes[r, 0], train_local, val_local, 'Local Likelihood'),
            (axes[r, 1], train_global, val_global, 'Global Posterior'),
        )
        for ax, train, val, title in panels:
            if len(train) > 0:
                ax.plot(train, label='Train', alpha=0.8)
                ax.plot(val, label='Val', alpha=0.8)
                ax.legend()
                ax.set_yscale('log')
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Loss')
            ax.set_title(f'Round {r}: {title}')
            ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'training_losses.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


def create_summary_table(
    theta_g_samples: np.ndarray,
    theta_l_samples: np.ndarray,
    true_theta_g: np.ndarray,
    true_theta_l: np.ndarray,
    output_dir: str
) -> pd.DataFrame:
    """Build, save and return a table of posterior summary statistics.

    The table has one row per global parameter (log_beta, log_mu, log_Qin) and
    one per (site, local parameter) pair (log_R_T, log_C_T). Columns are the
    true value, posterior mean/std, the 2.5%/97.5% percentiles, and whether
    that central 95% interval covers the truth. All numbers are formatted as
    strings with 3 decimals.

    Writes ``posterior_summary.csv`` and ``posterior_summary.tex`` to
    ``output_dir``. The LaTeX output is explicitly escaped: since pandas 2.0
    ``to_latex`` no longer escapes by default, so the ``%`` in the percentile
    headers would start a LaTeX comment and the ``_`` in names would not
    compile.

    Parameters
    ----------
    theta_g_samples : np.ndarray
        Global draws indexed ``[draw, param]``.
    theta_l_samples : np.ndarray
        Local draws indexed ``[draw, site, param]``.
    true_theta_g, true_theta_l : np.ndarray
        Ground truth indexed ``[param]`` and ``[site, param]``.
    output_dir : str
        Destination directory.
    """
    def summarize(param: str, site: str, samples: np.ndarray, true_val: float) -> dict:
        # Percentiles computed once and reused for the coverage check.
        lo, hi = np.percentile(samples, [2.5, 97.5])
        return {
            'Parameter': param,
            'Site': site,
            'True': f'{true_val:.3f}',
            'Mean': f'{np.mean(samples):.3f}',
            'Std': f'{np.std(samples):.3f}',
            '2.5%': f'{lo:.3f}',
            '97.5%': f'{hi:.3f}',
            'Covers True': 'Yes' if lo <= true_val <= hi else 'No'
        }

    global_names = ['log_beta', 'log_mu', 'log_Qin']
    rows = [
        summarize(name, 'Global', theta_g_samples[:, i], true_theta_g[i])
        for i, name in enumerate(global_names)
    ]

    local_names = ['log_R_T', 'log_C_T']
    rows += [
        summarize(param_name, site_name, theta_l_samples[:, s_idx, j], true_theta_l[s_idx, j])
        for s_idx, site_name in enumerate(SITE_ORDER)
        for j, param_name in enumerate(local_names)
    ]

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(output_dir, 'posterior_summary.csv'), index=False)

    latex_str = df.to_latex(index=False, caption='Posterior summary statistics',
                            label='tab:posterior_summary', escape=True)
    with open(os.path.join(output_dir, 'posterior_summary.tex'), 'w', encoding='utf-8') as f:
        f.write(latex_str)

    return df


# =============================================================================
# MAIN BENCHMARK PIPELINE
# =============================================================================

def run_benchmark(config: Optional[BenchConfig] = None) -> BenchmarkResults:
    """Run the full benchmarking pipeline with telemetry."""

    if config is None:
        config = BenchConfig()

    os.makedirs(config.output_dir, exist_ok=True)

    results = BenchmarkResults()
    results.config = config.to_dict()

    print("=" * 70)
    print("TFMPE HEMODYNAMICS BENCHMARK")
    print("=" * 70)

    # Collect system info
    print("\n[0/6] Collecting system information...")
    sys_info = collect_system_info(
        tfmpe_path=str(Path(__file__).parent.parent / "tfmpe")
    )
    results.system_info = sys_info_to_dict(sys_info)

    with open(os.path.join(config.output_dir, 'sys_info.json'), 'w') as f:
        json.dump(results.system_info, f, indent=2)

    print(f"  Backend: {sys_info.jax_backend}")
    print(f"  GPU: {sys_info.gpu.name}")
    print(f"  Precision: float32")

    # Setup network and parameters
    net = create_arch_network()
    base = BaseParams(T=1.0)
    n_patients = 2
    n_terminals = len(terminal_ids_ordered(net))

    print(f"\n[1/6] Setting up model...")
    print(f"  Network: {net.n_vessels} vessels, {n_terminals} terminals")
    print(f"  Sites: {SITE_ORDER}")
    print(f"  N_t={config.N_t}, nx={config.nx}")

    labeller = Labeller.for_keys(['log_beta', 'log_mu', 'log_Qin', 'log_Rt', 'log_C', 'y'])
    independence = Independence()

    # Create functions with instrumentation
    prior_fn = create_prior_fn(net, base, n_terminals)
    local_fn = create_local_fn(net, base, n_terminals)

    # Create instrumented simulator
    cfd_telemetry = CFDTelemetry()
    simulator = InstrumentedSimulator(
        net, base, config, n_terminals, cfd_telemetry
    )
    simulator_fn = simulator  # Use as callable

    # Initialize counts
    results.counts.n_patients = n_patients

    # Generate true parameters and observations
    print("\n[2/6] Generating observations...")
    rng = random.PRNGKey(42)
    rng, key = random.split(rng)
    true_params, _ = prior_fn(key, n=n_patients, n_samples=1)

    true_theta_g = np.array([
        float(true_params['log_beta'][0, 0, 0]),
        float(true_params['log_mu'][0, 0, 0]),
        float(true_params['log_Qin'][0, 0, 0]),
    ])
    true_theta_l = np.stack([
        np.array(true_params['log_Rt'][0, 0, :, 0]),
        np.array(true_params['log_C'][0, 0, :, 0])
    ], axis=-1)

    print(f"  True beta_scale = {np.exp(true_theta_g[0]):.2e}")
    print(f"  True mu = {np.exp(true_theta_g[1]):.4f}")
    print(f"  True Q_in = {np.exp(true_theta_g[2]):.1f} mL/s")

    # Time observation generation
    print("\n  Timing observation generation...")
    with timed_block("obs_gen", sync_devices=True) as timer:
        rng, key = random.split(rng)
        y_obs_dict, _ = simulator_fn(key, true_params, n=n_patients)

    y_obs = {'y': y_obs_dict['y'].astype(jnp.float32)}
    print(f"  Observation shape: {y_obs['y'].shape}")
    print(f"  Time: {timer.elapsed_s:.2f}s")

    # Initialize TFMPE model
    print("\n[3/6] Initializing TFMPE model...")
    rng, key = random.split(rng)

    template_params, _ = prior_fn(key, n=n_patients, n_samples=10)
    params_tokens = Tokens.from_pytree(
        template_params,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )

    transformer_config = TransformerConfig(
        latent_dim=config.latent_dim,
        n_encoder=config.n_encoder,
        n_decoder=config.n_decoder,
        n_heads=config.n_heads,
        n_ff=config.n_ff,
    )

    rngs = nnx.Rngs(params=random.PRNGKey(0), dropout=random.PRNGKey(1))
    transformer = Transformer(config=transformer_config, tokens=params_tokens, rngs=rngs)
    base_dist = NormalDistribution(rngs=rngs)

    tfmpe = TFMPE(
        vf_network=transformer,
        base_dist=base_dist,
        solver=diffrax.Dopri5(),
        ode_kwargs={'rtol': 1e-3, 'atol': 1e-3}
    )

    # Count parameters
    def count_params(model):
        return sum(x.size for x in jax.tree_util.tree_leaves(nnx.state(model)))

    n_params = count_params(tfmpe)
    print(f"  Model parameters: {n_params:,}")
    results.tfmpe.n_model_parameters = n_params
    results.tfmpe.precision = "float32"

    optimizer = optax.adam(learning_rate=config.learning_rate)
    opt = nnx.Optimizer(tfmpe, optimizer, wrt=nnx.Param)
    effective_batch_size = min(config.batch_size, config.n_samples_per_round)

    # Training with timing and energy
    print("\n[4/6] Training (fit_bottom_up)...")
    print(f"  n_samples_per_round: {config.n_samples_per_round}")
    print(f"  n_iter_per_round: {config.n_iter_per_round}")
    print(f"  batch_size: {effective_batch_size}")

    energy_monitor = EnergyMonitor() if config.measure_energy else None
    if energy_monitor:
        energy_monitor.start()

    t_train_start = time.perf_counter()
    sync_jax_devices()

    rng, key = random.split(rng)
    trained_tfmpe, all_losses = fit_bottom_up(
        tfmpe=tfmpe,
        y_obs=y_obs,
        simulator_fn=simulator_fn,
        prior_fn=prior_fn,
        local_fn=local_fn,
        global_names=['log_beta', 'log_mu', 'log_Qin'],
        n_groups=n_patients,
        n_rounds=config.n_rounds,
        n_samples_per_round=config.n_samples_per_round,
        n_val_samples=config.n_val_samples,
        opt=opt,
        n_iter_per_round=config.n_iter_per_round,
        batch_size=effective_batch_size,
        rng=key,
        independence=independence,
        labeller=labeller,
    )

    sync_jax_devices()
    t_train_end = time.perf_counter()

    if energy_monitor:
        train_energy = energy_monitor.stop()
        results.tfmpe.training_energy_j = train_energy.energy_joules

    results.tfmpe.training_wall_time_s = t_train_end - t_train_start
    results.tfmpe.n_training_samples = config.n_samples_per_round
    results.tfmpe.n_training_iterations = config.n_iter_per_round

    # Extract final losses
    if all_losses:
        train_local, val_local, train_global, val_global = all_losses[-1]
        results.tfmpe.final_train_loss = float(train_global[-1]) if len(train_global) > 0 else float(train_local[-1])
        results.tfmpe.final_val_loss = float(val_global[-1]) if len(val_global) > 0 else float(val_local[-1])

    print(f"  Training time: {results.tfmpe.training_wall_time_s:.1f}s")
    if config.measure_energy:
        print(f"  Training energy: {results.tfmpe.training_energy_j:.1f}J")

    plot_training_losses(all_losses, config.output_dir)

    # Posterior sampling with timing
    print("\n[5/6] Posterior sampling...")

    y_obs_for_sampling = _repeat_context_for_sampling(y_obs, config.n_posterior_samples)
    context_tokens = Tokens.from_pytree(
        y_obs_for_sampling,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )

    params_template = {
        'log_beta': jnp.zeros((config.n_posterior_samples, 1, 1), dtype=jnp.float32),
        'log_mu': jnp.zeros((config.n_posterior_samples, 1, 1), dtype=jnp.float32),
        'log_Qin': jnp.zeros((config.n_posterior_samples, 1, 1), dtype=jnp.float32),
        'log_Rt': jnp.zeros((config.n_posterior_samples, n_patients, n_terminals, 1), dtype=jnp.float32),
        'log_C': jnp.zeros((config.n_posterior_samples, n_patients, n_terminals, 1), dtype=jnp.float32),
    }
    params_tokens = Tokens.from_pytree(
        params_template,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )

    # Warmup sampling (JIT compilation)
    print(f"  Warmup runs: {config.n_sample_warmup}")
    for _ in range(config.n_sample_warmup):
        _ = trained_tfmpe.sample_posterior(
            context=context_tokens,
            params=params_tokens
        )
        sync_jax_devices()

    # Timed sampling
    print(f"  Timed runs: {config.n_sample_repeats}")
    sample_times = []

    if config.measure_energy:
        energy_monitor = EnergyMonitor()
        energy_monitor.start()

    for _ in range(config.n_sample_repeats):
        sync_jax_devices()
        t_start = time.perf_counter()
        posterior_samples = trained_tfmpe.sample_posterior(
            context=context_tokens,
            params=params_tokens
        )
        sync_jax_devices()
        t_end = time.perf_counter()
        sample_times.append(t_end - t_start)

    if config.measure_energy:
        sample_energy = energy_monitor.stop()
        results.tfmpe.posterior_sampling_energy_j = sample_energy.energy_joules
        results.tfmpe.telemetry_method = sample_energy.source

    # Posterior sampling results (renamed from wall_time_per_sample)
    results.tfmpe.n_posterior_samples = config.n_posterior_samples
    results.tfmpe.posterior_sampling_wall_time_s = float(np.mean(sample_times))
    results.tfmpe.posterior_time_per_sample_mean_s = float(np.mean(sample_times)) / config.n_posterior_samples
    results.tfmpe.posterior_time_per_sample_median_s = float(np.median(sample_times)) / config.n_posterior_samples

    # Solver settings for reproducibility
    results.tfmpe.solver_settings = "diffrax.Dopri5(rtol=1e-3, atol=1e-3)"

    print(f"  Posterior sampling time (mean): {results.tfmpe.posterior_sampling_wall_time_s:.3f}s for {config.n_posterior_samples} samples")
    print(f"  Per-sample time: {results.tfmpe.posterior_time_per_sample_mean_s*1000:.3f}ms")

    # Extract samples (needed for likelihood timing and analysis)
    samples_dict = posterior_samples.decode()

    # Likelihood sampling timing (direct measurement)
    # Likelihood sampling (y ~ q_phi(y|theta)) uses the same ODE solver as posterior sampling,
    # but with swapped roles: context=theta, output=y. We measure this directly.
    print("\n  Timing likelihood sampling...")

    # Create context (theta) and output (y) templates for likelihood sampling
    like_n_samples = config.n_posterior_samples  # Use same sample count for comparison
    theta_template_for_like = {
        'log_beta': samples_dict['log_beta'],
        'log_mu': samples_dict['log_mu'],
        'log_Qin': samples_dict['log_Qin'],
        'log_Rt': samples_dict['log_Rt'],
        'log_C': samples_dict['log_C'],
    }

    theta_context_tokens = Tokens.from_pytree(
        theta_template_for_like,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )

    # y output template (what we're sampling)
    y_output_template = {
        'y': jnp.zeros((like_n_samples, n_patients, n_terminals, config.N_t, 1), dtype=jnp.float32),
    }
    y_output_tokens = Tokens.from_pytree(
        y_output_template,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )

    # Warmup likelihood sampling
    for _ in range(config.n_sample_warmup):
        _ = trained_tfmpe.sample_posterior(
            context=theta_context_tokens,
            params=y_output_tokens
        )
        sync_jax_devices()

    # Timed likelihood sampling
    like_sample_times = []
    for _ in range(config.n_sample_repeats):
        sync_jax_devices()
        t_start = time.perf_counter()
        _ = trained_tfmpe.sample_posterior(
            context=theta_context_tokens,
            params=y_output_tokens
        )
        sync_jax_devices()
        t_end = time.perf_counter()
        like_sample_times.append(t_end - t_start)

    # Likelihood sampling results
    # Divide by (n_samples * n_patients) to get TRUE per-patient time
    # This makes t_like directly comparable to t_sim (both are per-patient)
    results.tfmpe.like_sampling_wall_time_s = float(np.mean(like_sample_times))
    results.tfmpe.like_time_per_patient_mean_s = float(np.mean(like_sample_times)) / (like_n_samples * n_patients)
    results.tfmpe.like_time_per_patient_median_s = float(np.median(like_sample_times)) / (like_n_samples * n_patients)
    results.tfmpe.n_like_draws = like_n_samples

    print(f"  Likelihood sampling time (mean): {results.tfmpe.like_sampling_wall_time_s:.3f}s for {like_n_samples} draws × {n_patients} patients")
    print(f"  Per-patient time: {results.tfmpe.like_time_per_patient_mean_s*1000:.3f}ms")
    theta_g_samples = np.stack([
        np.array(samples_dict['log_beta'][:, 0, 0]),
        np.array(samples_dict['log_mu'][:, 0, 0]),
        np.array(samples_dict['log_Qin'][:, 0, 0]),
    ], axis=-1)

    theta_l_samples = np.stack([
        np.array(samples_dict['log_Rt'][:, 0, :, 0]),
        np.array(samples_dict['log_C'][:, 0, :, 0]),
    ], axis=-1)

    # Update CFD telemetry
    results.cfd = cfd_telemetry
    if config.measure_energy:
        results.cfd.telemetry_method = "nvidia-smi" if sample_energy.source == "nvidia-smi" else "none"
    else:
        results.cfd.telemetry_method = "none"

    # Update counts from simulator
    results.counts.n_cfd_sims_total = simulator.n_patient_sims
    # Stage 1: n=1 patient, n_samples_per_round param draws (uses CFD)
    results.counts.n_param_draws_stage1 = config.n_samples_per_round
    results.counts.n_patient_sims_stage1 = config.n_samples_per_round * 1  # n=1 in Stage 1
    # Stage 2: Training is simulator-free (uses likelihood sampling y ~ q_phi(y|theta))
    # NO CFD sims for Stage 2 training - that's the key speedup
    results.counts.n_param_draws_stage2 = config.n_samples_per_round
    results.counts.n_patient_sims_stage2 = 0  # Stage 2 training is simulator-free!
    results.counts.n_like_draws_stage2 = config.n_samples_per_round * n_patients
    # Likelihood samples actually drawn during Stage 2 training
    results.counts.n_like_samples_total = results.counts.n_like_draws_stage2
    # Note: Validation uses CFD (counted in n_cfd_sims_total), but Stage 2 training does not

    # Derived metrics (using directly measured values)
    results.derived.t_sim_mean_s = cfd_telemetry.wall_time_per_sim_mean_s
    results.derived.t_sim_median_s = cfd_telemetry.wall_time_per_sim_median_s
    results.derived.t_post_mean_s = results.tfmpe.posterior_time_per_sample_mean_s
    results.derived.t_post_median_s = results.tfmpe.posterior_time_per_sample_median_s
    results.derived.t_like_mean_s = results.tfmpe.like_time_per_patient_mean_s  # directly measured
    results.derived.t_like_median_s = results.tfmpe.like_time_per_patient_median_s  # directly measured

    # Store raw timings
    results.raw_timings = {
        "posterior_sample_times_s": sample_times,
        "like_sample_times_s": like_sample_times,
        "cfd_sim_times_s": simulator.get_raw_times(),
    }

    # Generate outputs
    print("\n[6/6] Generating outputs...")

    print("  Creating posterior summary...")
    summary_df = create_summary_table(
        theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, config.output_dir
    )

    print("  Creating posterior marginal plots...")
    plot_posterior_marginals(
        theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, config.output_dir
    )

    if config.skip_ppc:
        print("  Skipping PPC plot (--skip-ppc enabled)")
    else:
        print("  Creating smoothed PPC plot...")
        plot_posterior_predictive_smoothed(
            net, base, config, theta_g_samples, theta_l_samples,
            np.array(y_obs['y']), config.output_dir
        )

    # Save all results
    print("  Saving benchmark results...")

    np.savez(
        os.path.join(config.output_dir, 'posterior_samples.npz'),
        theta_g=theta_g_samples,
        theta_l=theta_l_samples,
        true_theta_g=true_theta_g,
        true_theta_l=true_theta_l,
        site_names=SITE_ORDER
    )

    with open(os.path.join(config.output_dir, 'bench_results.json'), 'w') as f:
        json.dump(results.to_dict(), f, indent=2)

    # Print summary
    print("\n" + "=" * 70)
    print("BENCHMARK SUMMARY")
    print("=" * 70)
    print(f"\nCFD Simulator:")
    print(f"  Total patient sims: {results.cfd.n_patient_sims}")
    print(f"  t_sim (median): {results.cfd.wall_time_per_sim_median_s:.4f}s")
    print(f"  t_sim (mean):   {results.cfd.wall_time_per_sim_mean_s:.4f}s")
    print(f"  t_sim (std):    {results.cfd.wall_time_per_sim_std_s:.4f}s")
    print(f"  Telemetry method: {results.cfd.telemetry_method}")
    print(f"  Parallelism: {results.cfd.parallelism}")

    print(f"\nTFMPE Inference:")
    print(f"  Training time: {results.tfmpe.training_wall_time_s:.1f}s")
    energy_str = f"{results.tfmpe.training_energy_j:.1f}J" if results.tfmpe.training_energy_j else "N/A"
    print(f"  Training energy: {energy_str} (via {results.tfmpe.telemetry_method})")
    print(f"  Posterior samples: {config.n_posterior_samples}")
    print(f"  t_post (per sample): {results.tfmpe.posterior_time_per_sample_mean_s*1000:.4f}ms")
    print(f"  t_like (per patient): {results.tfmpe.like_time_per_patient_mean_s*1000:.4f}ms  [comparable to t_sim]")
    print(f"  Solver: {results.tfmpe.solver_settings}")

    print(f"\nCounts:")
    print(f"  n_patients: {results.counts.n_patients}")
    print(f"  Stage 1: {results.counts.n_param_draws_stage1} param draws × 1 patient = {results.counts.n_patient_sims_stage1} CFD sims")
    print(f"  Stage 2: {results.counts.n_param_draws_stage2} param draws × {results.counts.n_patients} patients = {results.counts.n_like_draws_stage2} like draws (simulator-free!)")
    print(f"  Total CFD sims: {results.counts.n_cfd_sims_total} (obs gen + Stage 1 + validation)")
    print(f"  Total like samples: {results.counts.n_like_samples_total}")

    print(f"\nOutputs saved to: {os.path.abspath(config.output_dir)}")
    print("=" * 70)

    return results


# Characters with special meaning in LaTeX text mode, mapped to their escaped form.
_LATEX_SPECIAL_CHARS = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _tex_escape(value: Any) -> str:
    """Return ``str(value)`` with LaTeX special characters escaped.

    Telemetry strings (hardware names, solver descriptions, telemetry methods)
    are free text; an unescaped ``_``/``&``/``#`` breaks compilation and ``%``
    silently comments out the rest of the table row.
    """
    return "".join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in str(value))


def _info_field(info: Dict[str, Any], key: str, max_len: Optional[int] = None) -> str:
    """Return ``info[key]`` as escaped LaTeX text, or ``'TBD'`` if missing or None.

    ``max_len`` truncates the raw text *before* escaping so that an escape
    sequence is never cut in half.
    """
    value = info.get(key)
    text = "TBD" if value is None else str(value)
    if max_len is not None:
        text = text[:max_len]
    return _tex_escape(text)


def generate_latex_table(results: "BenchmarkResults") -> str:
    """Generate the compute-telemetry LaTeX table matching the ICML template format.

    The table has four sections: the CFD forward simulator, TFMPE inference and
    sampling, simulation/likelihood counts, and derived per-unit costs.

    Args:
        results: Benchmark results; reads ``cfd``, ``tfmpe``, ``counts``,
            ``derived`` and the ``system_info`` dict (``cpu``/``gpu``/
            ``software``/``git`` sub-dicts).

    Returns:
        The LaTeX source of a ``table`` environment, newline-joined.
        Missing or ``None`` system-info entries render as ``TBD``; zero or
        ``None`` energies render as ``N/A``. Free-text fields are LaTeX-escaped.
    """
    cfd = results.cfd
    tfmpe = results.tfmpe
    counts = results.counts
    derived = results.derived
    sys_info = results.system_info or {}

    # `or {}` also covers sections stored as null in the JSON.
    cpu_info = sys_info.get("cpu") or {}
    gpu_info = sys_info.get("gpu") or {}
    software = sys_info.get("software") or {}
    git_info = sys_info.get("git") or {}

    # Pre-escaped free-text fields.
    cpu_model = _info_field(cpu_info, "model", max_len=30)
    cpu_cores = _info_field(cpu_info, "physical_cores")
    gpu_name = _info_field(gpu_info, "name")
    gpu_memory = _info_field(gpu_info, "memory_total_mb")
    jax_version = _info_field(software, "jax")
    diffrax_version = _info_field(software, "diffrax")
    flax_version = _info_field(software, "flax")
    cfd_commit = _info_field(git_info, "cfd_commit")
    tfmpe_commit = _info_field(git_info, "tfmpe_commit")
    parallelism = _tex_escape(cfd.parallelism)
    cfd_telemetry = _tex_escape(cfd.telemetry_method)
    tfmpe_telemetry = _tex_escape(tfmpe.telemetry_method)
    precision = _tex_escape(tfmpe.precision)
    solver = _tex_escape(tfmpe.solver_settings)

    # Format energy strings (None and 0 both mean "not measured").
    train_energy_str = f"{tfmpe.training_energy_j:.1f}J" if tfmpe.training_energy_j else "N/A"
    cfd_energy_str = f"{cfd.energy_per_sim_mean_j:.2f}J" if cfd.energy_per_sim_mean_j else "N/A"

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\small",
        r"\setlength{\tabcolsep}{3pt}",
        r"\begin{tabular}{@{}ll@{}}",
        r"\toprule",
        r"\multicolumn{2}{l}{\textbf{Haemodynamics compute setup and telemetry}} \\",
        r"\midrule",
        "",
        r"\multicolumn{2}{l}{\textbf{CFD forward simulator (1D network)}} \\",
        f"\\multicolumn{{2}}{{l}}{{\\textit{{Hardware:}} CPU {cpu_model} ({cpu_cores} cores); GPU {gpu_name} ({gpu_memory} MB VRAM)}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{\\textit{{Software:}} JAX {jax_version}; Diffrax {diffrax_version}; commit \\texttt{{{cfd_commit}}}}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{\\textit{{Parallelism:}} {parallelism}}} \\\\",
        r"\multicolumn{2}{l}{\textit{Unit cost:} 1 patient sim (5 cycles, keep final) $\rightarrow$ 4 outlet flows} \\",
        f"\\multicolumn{{2}}{{l}}{{$t_{{\\mathrm{{sim}}}}$ (mean/median): {cfd.wall_time_per_sim_mean_s:.4f}s / {cfd.wall_time_per_sim_median_s:.4f}s}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{Energy / sim: {cfd_energy_str} (via {cfd_telemetry})}} \\\\",
        "",
        r"\midrule",
        r"\multicolumn{2}{l}{\textbf{TFMPE inference + posterior sampling}} \\",
        f"\\multicolumn{{2}}{{l}}{{\\textit{{Hardware:}} GPU {gpu_name}; CPU {cpu_model}}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{\\textit{{Software:}} Flax {flax_version}; precision {precision}; solver {solver}}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{\\textit{{Software:}} commit \\texttt{{{tfmpe_commit}}}}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{$t_{{\\mathrm{{post}}}}$ (per posterior sample): {tfmpe.posterior_time_per_sample_mean_s*1000:.4f}ms}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{$t_{{\\mathrm{{like}}}}$ (per likelihood sample): {tfmpe.like_time_per_patient_mean_s*1000:.4f}ms}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{Training time: {tfmpe.training_wall_time_s:.1f}s; Energy: {train_energy_str} (via {tfmpe_telemetry})}} \\\\",
        "",
        r"\midrule",
        r"\multicolumn{2}{l}{\textbf{Counts}} \\",
        f"\\multicolumn{{2}}{{l}}{{$n_{{\\mathrm{{patients}}}}$ = {counts.n_patients}; Stage 1: {counts.n_param_draws_stage1} draws $\\times$ 1 patient = {counts.n_patient_sims_stage1} CFD sims}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{Stage 2 (simulator-free): {counts.n_param_draws_stage2} draws $\\times$ {counts.n_patients} patients = {counts.n_like_draws_stage2} likelihood samples}} \\\\",
        f"\\multicolumn{{2}}{{l}}{{Total CFD sims: {counts.n_cfd_sims_total} (obs gen + Stage 1 + validation)}} \\\\",
        "",
        r"\midrule",
        r"\multicolumn{2}{l}{\textbf{Derived (illustrative)}} \\",
        f"\\multicolumn{{2}}{{l}}{{$t_{{\\mathrm{{sim}}}}$ = {derived.t_sim_mean_s:.4f}s; $t_{{\\mathrm{{like}}}}$ = {derived.t_like_mean_s*1000:.4f}ms; $t_{{\\mathrm{{post}}}}$ = {derived.t_post_mean_s*1000:.4f}ms}} \\\\",
        r"\multicolumn{2}{l}{Effective dataset time (NPE): $N n_s\, t_{\mathrm{sim}}$} \\",
        r"\multicolumn{2}{l}{Effective dataset time (TFMPE): $N_1 t_{\mathrm{sim}} + N_2 n_s\, t_{\mathrm{like}}$} \\",
        "",
        r"\bottomrule",
        r"\end{tabular}",
        r"\caption{Compute setup and runtime telemetry for haemodynamics experiments.}",
        r"\label{tab:hemo_compute_telemetry}",
        r"\end{table}",
    ]

    return "\n".join(lines)


# =============================================================================
# ENTRY POINT
# =============================================================================

def _load_bench_results(results_path: str) -> "BenchmarkResults":
    """Rebuild a ``BenchmarkResults`` from a saved ``bench_results.json`` file.

    Restores every section the LaTeX table reads: ``config``, ``system_info``,
    ``cfd``, ``tfmpe``, ``derived`` and ``counts``.
    """
    with open(results_path, encoding="utf-8") as f:
        data = json.load(f)
    results = BenchmarkResults(
        config=data.get("config", {}),
        system_info=data.get("system_info", {}),
        cfd=CFDTelemetry(**data.get("cfd", {})),
        tfmpe=TFMPETelemetry(**data.get("tfmpe", {})),
        derived=DerivedMetrics(**data.get("derived", {})),
    )
    # The LaTeX table reads `results.counts`; without this step it would show
    # the default counts instead of the recorded ones. Copy the saved values
    # onto the default `counts` object that BenchmarkResults creates (the
    # constructor call above already relies on that default existing).
    for key, value in (data.get("counts") or {}).items():
        setattr(results.counts, key, value)
    return results


def main():
    """Command-line entry point: run the benchmark, or render LaTeX from saved results.

    With ``--latex`` no benchmark is run. ``<output_dir>/bench_results.json``
    is loaded, and the telemetry table is printed and written to
    ``<output_dir>/compute_telemetry.tex``. Otherwise a ``BenchConfig`` is
    built from the CLI arguments and the benchmark is run. ``--quick-test``
    replaces most settings with minimal values and always skips the PPC plot.
    The LaTeX table is then written into the output directory.
    """
    parser = argparse.ArgumentParser(description="TFMPE Hemodynamics Benchmark")

    # Quick test mode
    parser.add_argument("--quick-test", action="store_true",
                        help="Run minimal config for e2e testing")

    # Simulation parameters
    parser.add_argument("--N_t", type=int, default=50)
    parser.add_argument("--nx", type=int, default=81)

    # Training parameters
    parser.add_argument("--n_samples", type=int, default=200)
    parser.add_argument("--n_iter", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)

    # Inference parameters
    parser.add_argument("--n_posterior", type=int, default=500)

    # Benchmarking parameters
    parser.add_argument("--n_sim_warmup", type=int, default=3)
    parser.add_argument("--n_sim_repeats", type=int, default=10)
    parser.add_argument("--n_sample_warmup", type=int, default=2)
    parser.add_argument("--n_sample_repeats", type=int, default=5)
    parser.add_argument("--no-energy", action="store_true",
                        help="Disable energy measurement")
    parser.add_argument("--skip-ppc", action="store_true",
                        help="Skip PPC plot generation (faster testing)")

    # Output
    parser.add_argument("--output_dir", type=str, default="tfmpe_bench_results")
    parser.add_argument("--latex", action="store_true",
                        help="Generate LaTeX table from existing results")

    args = parser.parse_args()

    # LaTeX generation mode: render from a previous run, do not benchmark.
    if args.latex:
        results_path = os.path.join(args.output_dir, 'bench_results.json')
        if not os.path.exists(results_path):
            print(f"No results found at {results_path}. Run benchmark first.")
            return
        results = _load_bench_results(results_path)
        latex_table = generate_latex_table(results)
        print(latex_table)

        # Also save to file
        latex_path = os.path.join(args.output_dir, 'compute_telemetry.tex')
        with open(latex_path, 'w', encoding='utf-8') as f:
            f.write(latex_table)
        print(f"\nLaTeX saved to: {latex_path}")
        return

    # Quick test mode
    if args.quick_test:
        print("Running in QUICK TEST mode (minimal settings for e2e validation)")
        config = BenchConfig(
            N_t=20,
            nx=41,
            n_samples_per_round=10,
            n_val_samples=5,
            n_iter_per_round=10,
            batch_size=5,
            n_posterior_samples=20,
            n_sim_warmup=1,
            n_sim_repeats=2,
            n_sample_warmup=1,
            n_sample_repeats=2,
            measure_energy=not args.no_energy,
            output_dir=args.output_dir,
            n_ppc_curves=5,
            skip_ppc=True,  # Skip PPC for fast e2e testing
        )
    else:
        config = BenchConfig(
            N_t=args.N_t,
            nx=args.nx,
            n_samples_per_round=args.n_samples,
            n_iter_per_round=args.n_iter,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            n_posterior_samples=args.n_posterior,
            n_sim_warmup=args.n_sim_warmup,
            n_sim_repeats=args.n_sim_repeats,
            n_sample_warmup=args.n_sample_warmup,
            n_sample_repeats=args.n_sample_repeats,
            measure_energy=not args.no_energy,
            output_dir=args.output_dir,
            skip_ppc=args.skip_ppc,
        )

    results = run_benchmark(config)

    # Generate LaTeX table
    latex_table = generate_latex_table(results)
    latex_path = os.path.join(config.output_dir, 'compute_telemetry.tex')
    with open(latex_path, 'w', encoding='utf-8') as f:
        f.write(latex_table)
    print(f"\nLaTeX table saved to: {latex_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module is side-effect free.
    main()
