#!/usr/bin/env python3
"""Evaluate a trained TabularACE model on synthetic 2D functions and visualize.

For each of N functions sampled from a source (MLP-SCM prior or synthetic grid):
- Pick Nc=128 context points and Nb=8 buffer points
- Use Nt (default 512) targets
- Fit TabICL-style scaler on context; transform context/buffer/target
- Run a single forward pass (no AR inference) to get mixture predictions
- Plot side-by-side: true target function (heatmap) vs model mean prediction

Saves a grid figure to outputs/tabular_eval_grid.png by default.
"""

import argparse
import json
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

from src.models.ace import AmortizedConditioningEngine
from src.models.modules import Transformer, MixtureGaussian
from src.models.tabular_embedder import TabularACE
from src.models.masks import create_training_block_mask
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr

# Placeholder for the TabICL MLP-SCM prior class. It stays None at import time
# so that the script can run without tabicl when the prior is not used;
# _generate_mlpscm() imports the class on first use and caches it here.
MLPSCM = None  # populated lazily by _generate_mlpscm()


def _infer_arch_from_state_dict(sd: dict) -> dict:
    """Recover TabularACE hyperparameters from the tensor shapes in ``sd``.

    Intended for checkpoints whose stored config is absent or does not match
    the current defaults. Only key names and ``.shape`` attributes are read.

    Args:
        sd: Mapping from parameter name to an object exposing ``.shape``.

    Returns:
        Dict with keys ``embed_dim``, ``num_cls_tokens``, ``concat_cls``,
        ``transformer_layers``, ``dim_model_eff``, ``dim_feedforward``,
        ``num_components``, ``max_buffer_size``, ``num_isab_blocks``,
        ``num_inducing_points`` and ``row_num_blocks``.

    Raises:
        RuntimeError: If the first backbone attention weight, or every known
            embedder token tensor, is missing from ``sd``.
    """
    def first_present(candidates):
        # Older checkpoints may store the same tensor under another prefix.
        return next((name for name in candidates if name in sd), None)

    def count_indexed(prefix, suffix, index_pos):
        # Highest numeric index (+1) among keys shaped like prefix<idx>...suffix.
        highest = 0
        for name in sd.keys():
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            try:
                highest = max(highest, int(name.split(".")[index_pos]) + 1)
            except Exception:
                pass
        return highest

    ar_candidates = [
        "tabular_embedder.ar_tokens",
        "ar_token",
        "embedder.ar_tokens",
    ]

    # Effective model width: output dim of the first layer's query projection.
    q_name = first_present(["backbone.layers.0.attn.q_proj.weight"])
    if q_name is None:
        raise RuntimeError("Could not locate backbone attention weights to infer dims")
    dim_model_eff = sd[q_name].shape[0]

    # Feed-forward width from the first layer, else a derived fallback.
    ff_name = first_present(["backbone.layers.0.ff1.weight"])
    if ff_name is not None:
        dim_ff_eff = sd[ff_name].shape[0]
    else:
        dim_ff_eff = max(256, 2 * dim_model_eff)

    # Backbone depth.
    num_layers = count_indexed("backbone.layers.", ".attn.q_proj.weight", 2)

    # CLS tokens are stored as (C, E); without them fall back to the AR tokens
    # for E and assume a single CLS token.
    ar_name = first_present(ar_candidates)
    cls_name = first_present([
        "tabular_embedder.cls_tokens",
        "embedder.cls_tokens",
    ])
    if cls_name is not None:
        num_cls_tokens, embed_dim = sd[cls_name].shape
    elif ar_name is not None:
        embed_dim = sd[ar_name].shape[1]
        num_cls_tokens = 1
    else:
        raise RuntimeError("Could not locate embedder tokens to infer embed dim")

    # CLS tokens count as concatenated when the backbone width equals C * E.
    concat_cls = (dim_model_eff == embed_dim * num_cls_tokens)

    # One AR token per buffer slot.
    max_buffer_size = 32 if ar_name is None else sd[ar_name].shape[0]

    # Mixture head: a rank-3 w1 gives (components, ff dim, model dim).
    head_name = first_present(["head.head.w1", "head.w1"])
    head_is_rank3 = (
        head_name is not None
        and hasattr(sd[head_name], "shape")
        and len(sd[head_name].shape) == 3
    )
    if head_is_rank3:
        head_shape = sd[head_name].shape
        num_components = head_shape[0]
        # The head's ff dim overrides the backbone-derived one; the model dim
        # (head_shape[2]) is ignored in favour of the backbone value.
        dim_ff_eff = head_shape[1]
    else:
        num_components = 20

    # Column encoder: count ISAB blocks; the last matching key seen decides
    # the number of inducing points.
    num_isab_blocks = 0
    num_inducing_points = 64
    for name, tensor in sd.items():
        is_inducing = (
            name.startswith("tabular_embedder.column_processor.blocks.")
            and name.endswith(".inducing_points")
        )
        if not is_inducing:
            continue
        try:
            block_idx = int(name.split(".")[3])
            num_isab_blocks = max(num_isab_blocks, block_idx + 1)
            num_inducing_points = tensor.shape[0]
        except Exception:
            pass

    # Row encoder depth; checkpoints without explicit layers default to 1.
    row_num_blocks = count_indexed(
        "tabular_embedder.row_encoder.layers.", ".self_attn.in_proj_weight", 3
    )
    if row_num_blocks == 0:
        row_num_blocks = 1

    return {
        "embed_dim": embed_dim,
        "num_cls_tokens": num_cls_tokens,
        "concat_cls": concat_cls,
        "transformer_layers": num_layers,
        "dim_model_eff": dim_model_eff,
        "dim_feedforward": dim_ff_eff,
        "num_components": num_components,
        "max_buffer_size": max_buffer_size,
        "num_isab_blocks": num_isab_blocks if num_isab_blocks > 0 else 1,
        "num_inducing_points": num_inducing_points,
        "row_num_blocks": row_num_blocks,
    }


def _adapt_tensor(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Return ``src`` sliced / zero-padded to the shape, dtype and device of ``tgt``.

    Both tensors must have the same number of dimensions. Along each axis the
    leading ``min(src_size, tgt_size)`` entries are copied; any remaining
    target entries stay zero.
    """
    src = src.detach().to(dtype=tgt.dtype, device=tgt.device)
    out = torch.zeros_like(tgt)
    overlap = tuple(slice(0, min(s, t)) for s, t in zip(src.shape, tgt.shape))
    out[overlap] = src[overlap]
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> AmortizedConditioningEngine:
    """Build the current TabularACE architecture and adaptively load a checkpoint.

    The architecture constructed here is the ground truth. For every key of
    the new model's state dict:

    * if the checkpoint has a tensor of the same rank under that key, it is
      sliced / zero-padded to the target shape and loaded;
    * if the checkpoint has the key but with a different rank, the freshly
      initialised target tensor is kept and a warning is printed (slicing
      across ranks is undefined and previously crashed with an indexing
      error);
    * if the key is absent from the checkpoint, the target initialisation is
      kept.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``; it is
            expected to be a mapping with optional ``"config"`` and
            ``"model_state_dict"`` entries.
        device: Device the model and the loaded tensors are placed on.

    Returns:
        The model in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects. Only point this at
    # trusted checkpoints; consider weights_only=True if the stored "config"
    # is made of plain containers.
    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = ckpt.get("config", {})
    sd_src = ckpt.get("model_state_dict", {})

    # Build CURRENT architecture defaults (aligned with train_tabular_online.yaml)
    # Use conservative fallbacks for evaluation-time inputs (2D functions)
    mcfg = cfg.get("model", {})
    embed_cfg = mcfg.get("embedder", {})
    back_cfg = mcfg.get("backbone", {})
    head_cfg = mcfg.get("head", {})

    # Feature count is driven by the eval source (2D); weights don't depend on F
    num_features = 2
    embed_dim = 64
    nhead = 4
    transformer_layers = 12
    num_components = int(head_cfg.get("num_components", 20))
    max_buffer_size = 32
    num_target_points = int(mcfg.get("num_target_points", 512))
    targets_block = int(mcfg.get("targets_block_size_for_buffer_attend", 32))
    dropout = float(back_cfg.get("dropout", 0.0))
    num_isab_blocks = int(embed_cfg.get("num_isab_blocks", 3))
    num_inducing_points = int(embed_cfg.get("num_inducing_points", 128))
    row_rope_base = int(embed_cfg.get("row_rope_base", 100000))
    row_num_blocks = int(embed_cfg.get("num_layers", 3))
    num_cls_tokens = int(embed_cfg.get("num_cls_tokens", 4))
    concat_cls = bool(embed_cfg.get("concat_cls", True))

    model = TabularACE(
        num_features=num_features,
        embed_dim=embed_dim,
        transformer_layers=transformer_layers,
        nhead=nhead,
        dim_feedforward=None,
        num_components=num_components,
        max_buffer_size=max_buffer_size,
        num_target_points=num_target_points,
        targets_block_size_for_buffer_attend=targets_block,
        dropout=dropout,
        num_isab_blocks=num_isab_blocks,
        num_inducing_points=num_inducing_points,
        row_rope_base=row_rope_base,
        row_num_blocks=row_num_blocks,
        num_cls_tokens=num_cls_tokens,
        concat_cls=concat_cls,
        ff_factor=2.0,
    ).to(device)

    # Adapt source weights into target model shapes: slice/pad deterministically
    tgt_sd = model.state_dict()

    adapted = {}
    converted, carried = 0, 0
    rank_mismatched = []
    for k, tgt in tgt_sd.items():
        src = sd_src.get(k)
        if src is not None and src.ndim == tgt.ndim:
            adapted[k] = _adapt_tensor(src, tgt)
            converted += 1
            continue
        if src is not None:
            # Same key, different rank: no meaningful slice/pad exists.
            rank_mismatched.append(k)
        # carry target init
        adapted[k] = tgt
        carried += 1

    model.load_state_dict(adapted, strict=True)
    print(f"[eval_tabular_grid] Adaptive load complete: converted={converted}, carried_default={carried}.")
    if rank_mismatched:
        print(
            "[eval_tabular_grid] WARNING: kept default init for "
            f"{len(rank_mismatched)} key(s) whose rank differs from the checkpoint: "
            f"{rank_mismatched}"
        )
    if converted == 0:
        # Nothing matched: the model below is effectively untrained.
        print(
            "[eval_tabular_grid] WARNING: no checkpoint tensors matched the model; "
            "evaluating randomly initialised weights. Check that the checkpoint "
            "contains a 'model_state_dict' with compatible key names."
        )
    model.eval()
    return model


def mixture_mean(means: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Compute the mixture mean from component means and mixture weights.

    Args:
        means: Component means, shape ``[B, T, K, Dy]``.
        weights: Mixture weights, shape ``[B, T, K, Dy]`` or ``[B, T, K, 1]``;
            a trailing singleton dimension is broadcast over ``Dy``.

    Returns:
        Tensor of shape ``[B, T, Dy]``: the weighted sum over the component
        axis ``K``. The weights are used as given (no renormalisation), so
        the result is a true mean only if they sum to one over ``K``.
    """
    # Broadcasting covers both weight layouts, so no shape branch is needed.
    return (weights * means).sum(dim=2)


def _choose_uniform_context(X_np: np.ndarray, nc: int) -> np.ndarray:
    """Select up to ``nc`` row indices of ``X_np`` spread over its first two columns.

    The bounding box of the first two columns is cut into a G x G grid with
    ``G = ceil(sqrt(nc))``. Points are visited in a fixed pseudo-random order
    (generator seeded with 0) and the first point landing in each still-empty
    cell is kept. If fewer than ``nc`` cells get occupied, the shortfall is
    filled with randomly drawn unused points (or all of them when too few
    remain).

    Returns:
        Integer index array with at most ``nc`` entries.
    """
    assert X_np.shape[1] >= 2, "Uniform context selection expects 2D inputs"
    num_points = X_np.shape[0]
    cells_per_axis = int(np.ceil(np.sqrt(nc)))

    # Equal-width cell boundaries along each of the two axes.
    lower = X_np[:, :2].min(axis=0)
    upper = X_np[:, :2].max(axis=0)
    bounds_0 = np.linspace(lower[0], upper[0], cells_per_axis + 1)
    bounds_1 = np.linspace(lower[1], upper[1], cells_per_axis + 1)

    # Cell coordinates per point; clipping folds the max-valued points back
    # into the last cell.
    cell_0 = np.clip(np.digitize(X_np[:, 0], bounds_0) - 1, 0, cells_per_axis - 1)
    cell_1 = np.clip(np.digitize(X_np[:, 1], bounds_1) - 1, 0, cells_per_axis - 1)

    rng = np.random.default_rng(0)
    picked = []
    occupied = set()
    for point in rng.permutation(num_points):
        cell = (cell_0[point], cell_1[point])
        if cell in occupied:
            continue
        occupied.add(cell)
        picked.append(point)
        if len(picked) >= nc:
            break

    shortfall = nc - len(picked)
    if shortfall > 0:
        unused = np.setdiff1d(np.arange(num_points), np.array(picked, dtype=int), assume_unique=False)
        if unused.size >= shortfall:
            filler = rng.choice(unused, size=shortfall, replace=False)
        else:
            filler = unused
        picked = list(picked) + list(filler)
    return np.array(picked[:nc], dtype=int)


def _generate_mlpscm():
    """Instantiate a randomly configured MLP-SCM prior over 2 features / 1 output.

    The ``MLPSCM`` class is imported on first call and cached in the
    module-level ``MLPSCM`` global. If the regular import fails, the
    repo-local ``tabicl/src`` directory (two levels above this file) is
    appended to ``sys.path`` and the import is retried.

    Number of causes, MLP depth, init std and dropout probability are drawn
    from the global NumPy RNG on every call; the remaining settings are fixed.
    """
    global MLPSCM
    if MLPSCM is None:
        # Lazy import (and repo-local fallback) only when mlpscm is used
        try:
            from tabicl.prior.mlp_scm import MLPSCM as _MLP
        except Exception:
            import sys as _sys
            vendored_src = Path(__file__).resolve().parent.parent / "tabicl" / "src"
            _sys.path.append(str(vendored_src))
            from tabicl.prior.mlp_scm import MLPSCM as _MLP
        MLPSCM = _MLP

    # Draw order matters for reproducibility under a fixed global seed.
    sampled_num_causes = np.random.randint(2, 5)
    sampled_num_layers = np.random.randint(3, 6)
    sampled_init_std = np.random.uniform(0.8, 2.0)
    sampled_dropout = np.random.uniform(0.05, 0.2)

    prior_kwargs = dict(
        seq_len=5000,
        num_features=2,
        num_outputs=1,
        is_causal=True,
        num_causes=sampled_num_causes,
        y_is_effect=True,
        in_clique=False,
        sort_features=True,
        num_layers=sampled_num_layers,
        hidden_dim=32,
        mlp_activations=torch.nn.Tanh,
        init_std=sampled_init_std,
        block_wise_dropout=True,
        mlp_dropout_prob=sampled_dropout,
        sampling="mixed",
        pre_sample_cause_stats=True,
        noise_std=0.01,
        pre_sample_noise_std=False,
        device="cpu",
    )
    return MLPSCM(**prior_kwargs)


def _generate_synth_grid(grid_size: int = 64, noise_std: float = 0.01):
    """Sample one random synthetic function on a uniform grid over [0,1]^2.

    The grid is centred at (0.5, 0.5) and passed through a random rotation
    with per-axis scaling. One of eight recipes (polywave, rbf_mix, ripples,
    checker, saddle, spiral, fourier, ridge) is then evaluated on the warped
    coordinates, optionally followed by a tanh or signed-sqrt squashing, and
    Gaussian noise of scale ``noise_std`` is added. All randomness comes from
    the global NumPy RNG.

    Returns:
        ``(X_all, y_all)`` as float32 tensors of shape (G*G, 2) and (G*G, 1);
        ``X_all`` holds the unwarped grid coordinates.
    """
    uniform = np.random.uniform
    randint = np.random.randint
    two_pi = 2 * np.pi

    axis_x = np.linspace(0.0, 1.0, grid_size)
    axis_y = np.linspace(0.0, 1.0, grid_size)
    Xg, Yg = np.meshgrid(axis_x, axis_y)

    # Random affine warp (rotation + anisotropic scale) around the centre.
    X0 = Xg - 0.5
    Y0 = Yg - 0.5
    theta = uniform(0, 2 * np.pi)
    s1 = uniform(0.6, 1.4)
    s2 = uniform(0.6, 1.4)
    cth, sth = np.cos(theta), np.sin(theta)
    Xw = s1 * (cth * X0 - sth * Y0)
    Yw = s2 * (sth * X0 + cth * Y0)

    # Each recipe draws its own parameters when called; the order of draws
    # inside a recipe is part of its behaviour under a fixed seed.
    def polywave():
        a = uniform(0.3, 0.8)
        b = uniform(0.2, 0.6)
        fx1 = uniform(1.0, 3.0)
        fy1 = uniform(1.0, 3.0)
        fx2 = uniform(0.5, 2.0)
        fy2 = uniform(0.5, 2.0)
        quad = uniform(0.2, 0.5)
        lin = uniform(0.1, 0.3)
        return (
            a * np.sin(2 * np.pi * (fx1 * Xw + fy1 * Yw))
            + b * np.cos(2 * np.pi * (fx2 * Xw - fy2 * Yw))
            + quad * (Xw)**2
            - lin * (Yw)
        )

    def rbf_mix():
        # Sum of a few anisotropic, individually rotated Gaussian bumps.
        count = randint(2, 5)
        centers_x = uniform(-0.3, 0.3, size=count)
        centers_y = uniform(-0.3, 0.3, size=count)
        widths_x = uniform(0.05, 0.2, size=count)
        widths_y = uniform(0.05, 0.2, size=count)
        amplitudes = uniform(-1.0, 1.0, size=count)
        total = np.zeros_like(Xw)
        for bx, by, wx, wy, amp in zip(centers_x, centers_y, widths_x, widths_y, amplitudes):
            rot = uniform(0, 2 * np.pi)
            cr, sr = np.cos(rot), np.sin(rot)
            Xr = cr * (Xw - bx) - sr * (Yw - by)
            Yr = sr * (Xw - bx) + cr * (Yw - by)
            r2 = (Xr / wx)**2 + (Yr / wy)**2
            total += amp * np.exp(-0.5 * r2)
        return total

    def ripples():
        amp_fine = uniform(0.3, 0.7)
        fx = uniform(3, 8)
        fy = uniform(3, 8)
        amp_coarse = uniform(0.2, 0.5)
        skew = uniform(0.2, 0.8)
        return (
            amp_fine * np.sin(2 * np.pi * fx * Xw)
            * np.sin(2 * np.pi * fy * Yw)
            + amp_coarse * np.sin(2 * np.pi * (Xw + skew * Yw))
        )

    def checker():
        kx = randint(2, 8)
        ky = randint(2, 8)
        phx = uniform(0, 2 * np.pi)
        phy = uniform(0, 2 * np.pi)
        return np.sin(2 * np.pi * kx * Xw + phx) * np.sin(2 * np.pi * ky * Yw + phy)

    def saddle():
        a = uniform(0.4, 1.0)
        b = uniform(-0.5, 0.5)
        return a * (Xw**2 - Yw**2) + b * (Xw * Yw) + 0.3 * np.sin(2 * np.pi * (Xw + 0.5 * Yw))

    def spiral():
        radius = np.sqrt(Xw**2 + Yw**2)
        angle = np.arctan2(Yw, Xw)
        a = uniform(6, 14)
        b = uniform(3, 8)
        return np.sin(a * radius + b * angle)

    def fourier():
        terms = randint(3, 6)
        total = np.zeros_like(Xw)
        for _ in range(terms):
            ux, uy = uniform(-4, 4), uniform(-4, 4)
            phase = uniform(0, 2 * np.pi)
            amp = uniform(-0.7, 0.7)
            total += amp * np.sin(2 * np.pi * (ux * Xw + uy * Yw) + phase)
        return total

    def ridge():
        theta_r = uniform(0, 2 * np.pi)
        cr, sr = np.cos(theta_r), np.sin(theta_r)
        proj = cr * Xw + sr * Yw
        m = uniform(4, 12)
        c0 = uniform(-0.3, 0.3)
        return np.tanh(m * (proj - c0)) + 0.3 * np.cos(2 * np.pi * (Xw + 0.3 * Yw))

    recipes = {
        "polywave": polywave,
        "rbf_mix": rbf_mix,
        "ripples": ripples,
        "checker": checker,
        "saddle": saddle,
        "spiral": spiral,
        "fourier": fourier,
        "ridge": ridge,
    }
    kinds = list(recipes)
    kind = np.random.choice(kinds)
    f = recipes.get(str(kind), ridge)()

    # Random output nonlinearity for extra variety; the second draw only
    # happens when the first one did not select tanh.
    if np.random.rand() < 0.3:
        f = np.tanh(f)
    else:
        if np.random.rand() < 0.3:
            f = np.sign(f) * np.sqrt(np.abs(f))

    f = f + noise_std * np.random.randn(*f.shape)
    coords = np.stack([Xg.ravel(), Yg.ravel()], axis=1)
    values = f.reshape(-1, 1)
    return torch.from_numpy(coords).float(), torch.from_numpy(values).float()


def eval_one_function(model: AmortizedConditioningEngine, device: torch.device, nc: int, nb: int, nt: int,
                      norm_method: str = "power", uniform_context: bool = False,
                      source: str = "mlpscm", grid_size: int = 64, noise_std: float = 0.01):
    """Sample one 2D function, split it, and run a single forward pass.

    Steps: draw a densely sampled function from ``source``; pick disjoint
    context / buffer / target index sets; fit a ``TabICLScaler`` on the
    context split and transform all three splits with it; build the batch and
    a training-style block mask; run ``model`` under ``torch.no_grad()`` and
    reduce the mixture output to its mean.

    Args:
        model: Model called as ``model(batch, mask)``; its output must expose
            ``means`` and ``weights``.
        device: Device for the batch tensors and the mask.
        nc, nb, nt: Requested numbers of context, buffer and target points.
            If the source provides fewer than ``nc + nb + nt`` points the
            splits are shorter; a warning is printed and the mask is built
            from the actual split sizes.
        norm_method: Passed to ``TabICLScaler`` as ``normalization_method``.
        uniform_context: If true, context points are chosen with
            ``_choose_uniform_context`` instead of uniformly at random.
        source: ``"mlpscm"`` uses the MLP-SCM prior; any other value uses the
            synthetic grid generator.
        grid_size, noise_std: Forwarded to the synthetic grid generator.

    Returns:
        ``(xc_s, yc_s, xt_s, yt_s, pred_mean)``: context inputs/targets and
        target inputs/targets as returned by the scaler (i.e. in scaled
        space), plus the predicted mixture mean for the targets as a NumPy
        array with the batch dimension removed.

    Raises:
        ValueError: If context points or target points were requested but the
            source is too small to provide any.
    """
    # Generate a random 2D function sampled densely
    if source == "mlpscm":
        prior = _generate_mlpscm()
        with torch.no_grad():
            X_all, y_all = prior()
    else:
        X_all, y_all = _generate_synth_grid(grid_size=grid_size, noise_std=noise_std)
    if y_all.dim() == 1:
        y_all = y_all.unsqueeze(-1)

    # Select Nc, Nb, Nt from X_all
    N = X_all.shape[0]
    if uniform_context:
        idx_c = _choose_uniform_context(X_all.cpu().numpy(), nc)
        unused = np.ones(N, dtype=bool)
        unused[idx_c] = False
        rem = np.nonzero(unused)[0]
        rem_perm = np.random.permutation(rem)
        idx_b = rem_perm[:nb]
        idx_t = rem_perm[nb:nb + nt]
        idx_c_t = torch.from_numpy(idx_c)
        idx_b_t = torch.from_numpy(idx_b)
        idx_t_t = torch.from_numpy(idx_t)
    else:
        perm = torch.randperm(N)
        idx_c_t = perm[:nc]
        idx_b_t = perm[nc:nc + nb]
        idx_t_t = perm[nc + nb:nc + nb + nt]

    # Slicing silently truncates when the source is too small, so work with
    # the sizes actually obtained rather than the requested ones.
    n_ctx, n_buf, n_tgt = idx_c_t.numel(), idx_b_t.numel(), idx_t_t.numel()
    if (nc > 0 and n_ctx == 0) or (nt > 0 and n_tgt == 0):
        raise ValueError(
            f"Source provides only {N} points, not enough for nc={nc}, nb={nb}, nt={nt} "
            f"(got context={n_ctx}, buffer={n_buf}, target={n_tgt})"
        )
    if (n_ctx, n_buf, n_tgt) != (nc, nb, nt):
        print(
            f"[eval_tabular_grid] WARNING: source has {N} points; using "
            f"context={n_ctx}, buffer={n_buf}, target={n_tgt} instead of "
            f"nc={nc}, nb={nb}, nt={nt}."
        )

    xc = X_all[idx_c_t]
    yc = y_all[idx_c_t]
    xb = X_all[idx_b_t]
    yb = y_all[idx_b_t]
    xt = X_all[idx_t_t]
    yt = y_all[idx_t_t]

    # Fit scaler on context and transform all splits
    scaler = TabICLScaler(normalization_method=norm_method, outlier_threshold=4.0, random_state=0)
    scaler.fit(xc.numpy(), yc.numpy())
    xc_s, yc_s = scaler.transform_xy(xc.numpy(), yc.numpy())
    xb_s, yb_s = scaler.transform_xy(xb.numpy(), yb.numpy())
    xt_s, yt_s = scaler.transform_xy(xt.numpy(), yt.numpy())

    # Build batch on device
    batch = DataAttr(
        xc=torch.from_numpy(xc_s).unsqueeze(0).to(device=device, dtype=torch.float32),
        yc=torch.from_numpy(yc_s).unsqueeze(0).to(device=device, dtype=torch.float32),
        xb=torch.from_numpy(xb_s).unsqueeze(0).to(device=device, dtype=torch.float32),
        yb=torch.from_numpy(yb_s).unsqueeze(0).to(device=device, dtype=torch.float32),
        xt=torch.from_numpy(xt_s).unsqueeze(0).to(device=device, dtype=torch.float32),
        yt=torch.from_numpy(yt_s).unsqueeze(0).to(device=device, dtype=torch.float32),
    )

    # Create block mask for training-style forward; section lengths must
    # match the batch that was actually built.
    R = n_ctx + n_buf + n_tgt
    mask = create_training_block_mask(
        current_total_q_len=R,
        current_total_kv_len=R,
        current_context_section_len=n_ctx,
        current_buffer_section_len=n_buf,
        device=device,
    )

    with torch.no_grad():
        out = model(batch, mask)
        means = out.means  # [B, T, K, Dy]
        weights = out.weights  # [B, T, K, Dy] or [B, T, K, 1]
        pred_mean = mixture_mean(means, weights)  # [B, T, Dy]

    # Everything returned lives in the scaler-transformed space, so truth and
    # prediction can be plotted on the same coordinate system.
    return (xc_s, yc_s, xt_s, yt_s, pred_mean.squeeze(0).cpu().numpy())


def plot_side_by_side(ax_true, ax_pred, Xc, Yc, Xt, Yt, Ypred, title_idx: int):
    """Draw the true and the predicted surface next to each other.

    Target values (``Yt``) and predictions (``Ypred``) are linearly
    interpolated from the target locations ``Xt`` onto an 80 x 80 grid that
    spans the joint bounding box of context and target inputs. Both panels
    share one colour range and overlay the context points coloured by ``Yc``.

    Args:
        ax_true, ax_pred: Axes for the true / predicted panel.
        Xc, Yc: Context inputs (first two columns used) and values (column 0).
        Xt, Yt: Target inputs (first two columns used) and values (column 0).
        Ypred: Predicted values at ``Xt`` (column 0 used).
        title_idx: Function number shown in the panel titles.

    Returns:
        The two ``contourf`` results ``(true, predicted)``.
    """
    # Bounding box over context + target points.
    x_lo = min(Xc[:, 0].min(), Xt[:, 0].min())
    x_hi = max(Xc[:, 0].max(), Xt[:, 0].max())
    y_lo = min(Xc[:, 1].min(), Xt[:, 1].min())
    y_hi = max(Xc[:, 1].max(), Xt[:, 1].max())

    resolution = 80
    Xi, Yi = np.meshgrid(
        np.linspace(x_lo, x_hi, resolution),
        np.linspace(y_lo, y_hi, resolution),
    )

    # Interpolate truth and prediction from the scattered target locations;
    # grid nodes outside their convex hull become NaN.
    target_xy = Xt[:, :2]
    Zi_true = griddata(target_xy, Yt[:, 0], (Xi, Yi), method='linear', fill_value=np.nan)
    Zi_pred = griddata(target_xy, Ypred[:, 0], (Xi, Yi), method='linear', fill_value=np.nan)

    # One colour range for both panels so they are directly comparable.
    surfaces = [Zi_true, Zi_pred]
    vmin = np.nanmin(surfaces)
    vmax = np.nanmax(surfaces)

    def draw_panel(ax, surface, title):
        filled = ax.contourf(Xi, Yi, surface, levels=15, cmap='RdBu_r', alpha=0.9, vmin=vmin, vmax=vmax)
        ax.scatter(Xc[:, 0], Xc[:, 1], c=Yc[:, 0], cmap='RdBu_r', s=30, edgecolors='black', linewidth=1.0,
                   vmin=vmin, vmax=vmax, zorder=10)
        ax.set_title(title, fontsize=10)
        ax.set_aspect('equal')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        return filled

    im1 = draw_panel(ax_true, Zi_true, f'Function {title_idx} (true)')
    im2 = draw_panel(ax_pred, Zi_pred, f'Function {title_idx} (predicted mean)')
    return im1, im2


def main():
    """CLI entry point: load a checkpoint, evaluate N functions, save the figure.

    One figure row is produced per function (true surface on the left,
    predicted mean on the right). The parent directory of ``--out`` is
    created if needed.
    """
    parser = argparse.ArgumentParser(description="Eval trained TabularACE on MLPSCM 2D functions (grid plots)")
    parser.add_argument('--checkpoint', required=True, help='Path to trained checkpoint .pt')
    parser.add_argument('--device', default='cpu', help='cpu or cuda')
    parser.add_argument('--nfuncs', type=int, default=6)
    parser.add_argument('--nc', type=int, default=128)
    parser.add_argument('--nb', type=int, default=8)
    parser.add_argument('--nt', type=int, default=512)
    parser.add_argument('--norm', default='power', choices=['power', 'quantile', 'quantile_rtdl', 'none'])
    parser.add_argument('--source', default='mlpscm', choices=['mlpscm', 'synth'], help='Function source')
    parser.add_argument('--grid-size', type=int, default=64, help='Grid size for synth source (GxG)')
    parser.add_argument('--noise-std', type=float, default=0.01, help='Noise std for synth source')
    parser.add_argument('--out', default='outputs/tabular_eval_grid.png')
    parser.add_argument('--uniform-context', action='store_true', help='Choose context points approximately uniformly over x-space')
    args = parser.parse_args()

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)

    num_rows = args.nfuncs
    # squeeze=False keeps the axes array 2-D even for a single row.
    fig, axes = plt.subplots(num_rows, 2, figsize=(12, 3 * num_rows), squeeze=False)

    eval_kwargs = dict(
        nc=args.nc,
        nb=args.nb,
        nt=args.nt,
        norm_method=args.norm,
        uniform_context=args.uniform_context,
        source=args.source,
        grid_size=args.grid_size,
        noise_std=args.noise_std,
    )
    for func_number, (ax_true, ax_pred) in enumerate(axes, start=1):
        xc, yc, xt, yt, ypred = eval_one_function(model, device, **eval_kwargs)
        plot_side_by_side(ax_true, ax_pred, xc, yc, xt, yt, ypred, func_number)

    plt.tight_layout()
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(args.out, dpi=150, bbox_inches='tight')
    print(f"Saved evaluation grid to {args.out}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
