import sys
import argparse
import copy
import logging
import math
import random
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, Tuple, Union

import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import torch
from torch import nn
from omegaconf import OmegaConf

from src.data.bav_sampler_deprecated import BavSampler
from src.models.ace import AmortizedConditioningEngine, InferenceEngine2
from src.models.modules import Embedder, MixtureGaussian, Transformer

from datetime import datetime


# Dated log file shared by every run started on the same calendar day.
date_str = datetime.now().strftime("%Y-%m-%d")
log_file = f"logs/app-{date_str}.log"
Path("logs").mkdir(exist_ok=True)  # must exist before FileHandler opens log_file

# Mirror every record to the log file and to stdout. force=True discards any
# logging configuration that was installed before this module was imported.
_log_handlers = [
    logging.FileHandler(log_file, encoding="utf-8"),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    handlers=_log_handlers,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
    force=True,
)

logging.getLogger("myapp").info("This goes to file AND terminal")

# -----------------------------------------------------------------------------
# Constants / Globals
# -----------------------------------------------------------------------------
# Shared y-axis limits for the response plots produced by this script.
Y_LIM = (-20, 23)
# Default checkpoints of the models trained on rho=4/3 and on rho=1 data.
DEFAULT_CKPT_RHO43 = Path("checkpoints/bavrho43_model/best_model.pt")
DEFAULT_CKPT_RHO1 = Path("checkpoints/bavrho1_model/best_model.pt")


# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------

def _print_device_type(module: nn.Module) -> str:
    """Print just the module's device TYPE (e.g., 'cuda', 'cpu', 'mps')."""
    t = next(iter(module.parameters()), None)
    if t is None:
        t = next(iter(module.buffers()), None)
    dev_type = t.device.type if t is not None else "cpu"
    print(dev_type, flush=True)
    return dev_type

def _set_random_seed(seed: int) -> None:
    """Set random seed for reproducibility across torch, numpy, random.

    Args:
        seed: Random seed value.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _build_mc_permutation(batch, n_mc: int):
    """Create Monte Carlo permutations of targets for a single-element batch.

    Input batch is expected to have attributes: xc, yc, xt, yt.

    Shapes:
        In:  batch.xc: [1, T_c, dim_x], batch.yc: [1, T_c, dim_y]
             batch.xt: [1, T_t, dim_x], batch.yt: [1, T_t, dim_y]
        Out: new_batch.xc: [n_mc, T_c, dim_x], new_batch.yc: [n_mc, T_c, dim_y]
             new_batch.xt: [n_mc, T_t, dim_x], new_batch.yt: [n_mc, T_t, dim_y]
             inv_perm:     LongTensor [n_mc, T_t] to invert the permutation

    Notes:
        * Uses Tensor.expand for broadcasting (views) to avoid extra memory.
        * Returns an inverse permutation to map outputs back to original order.
    """
    new_batch = copy.deepcopy(batch)
    new_batch.xc = new_batch.xc.expand(n_mc, -1, -1)
    new_batch.yc = new_batch.yc.expand(n_mc, -1, -1)
    new_batch.xt = new_batch.xt.expand(n_mc, -1, -1)
    new_batch.yt = new_batch.yt.expand(n_mc, -1, -1)

    b, t, d = new_batch.xt.shape
    perm = torch.stack(
        [torch.randperm(t, device=new_batch.xt.device) for _ in range(b)], dim=0
    )
    inv_perm = torch.argsort(perm, dim=1)  # [b, t]

    new_batch.xt = new_batch.xt.gather(
        dim=1, index=perm.unsqueeze(-1).expand(-1, -1, d)
    )
    new_batch.yt = new_batch.yt.gather(
        dim=1, index=perm.unsqueeze(-1).expand(-1, -1, 1)
    )

    return new_batch, inv_perm


# -----------------------------------------------------------------------------
# Model building / loading
# -----------------------------------------------------------------------------
def build_ace_model(config: OmegaConf) -> AmortizedConditioningEngine:
    """Assemble a fresh ACE model from ``config.model``.

    The pipeline is Embedder -> Transformer backbone -> MixtureGaussian head.
    All sizes come from the config, and no weights are loaded here.

    Args:
        config: OmegaConf config whose ``model`` section holds ``dim_x``,
            ``dim_y``, ``dim_model``, ``max_buffer_size``,
            ``targets_block_size_for_buffer_attend`` and the ``embedder``,
            ``backbone`` and ``head`` sub-sections.

    Returns:
        The newly constructed :class:`AmortizedConditioningEngine`.
    """
    m = config.model
    # Keyword arguments are evaluated left to right, so the components are
    # still constructed in the order embedder, backbone, head.
    return AmortizedConditioningEngine(
        embedder=Embedder(
            dim_x=m.dim_x,
            dim_y=m.dim_y,
            hidden_dim=m.embedder.hidden_dim,
            out_dim=m.dim_model,
            depth=m.embedder.depth,
        ),
        backbone=Transformer(
            num_layers=m.backbone.num_layers,
            dim_model=m.dim_model,
            num_head=m.backbone.num_heads,
            dim_feedforward=m.backbone.dim_feedforward,
            dropout=m.backbone.dropout,
        ),
        head=MixtureGaussian(
            dim_y=m.dim_y,
            dim_model=m.dim_model,
            dim_feedforward=m.head.dim_feedforward,
            num_components=m.head.num_components,
        ),
        max_buffer_size=m.max_buffer_size,
        targets_block_size_for_buffer_attend=m.targets_block_size_for_buffer_attend,
    )


def load_model(
    checkpoint_path: Union[str, Path],
    device: Union[str, torch.device] = "cpu",
    model_builder: Callable[[OmegaConf], AmortizedConditioningEngine] = build_ace_model,
    compile_model: bool = False,
) -> Tuple[AmortizedConditioningEngine, OmegaConf]:
    """Rebuild a model from a checkpoint and return it ready for inference.

    The checkpoint must contain a ``"config"`` entry (turned into an OmegaConf
    object and handed to ``model_builder``) and a ``"model_state_dict"``
    entry. Weights saved from a ``torch.compile``-wrapped model carry an
    ``_orig_mod.`` prefix, which is stripped before loading.

    Args:
        checkpoint_path: Path to the model checkpoint.
        device: Device used as ``map_location`` and as the model's target device.
        model_builder: Function that builds the model architecture from the config.
        compile_model: Whether to wrap the model using ``torch.compile``.

    Returns:
        (model, config), with the model on ``device`` and in eval mode.
    """
    ckpt = torch.load(Path(checkpoint_path), map_location=device)
    config = OmegaConf.create(ckpt["config"])
    model = model_builder(config)

    compiled_prefix = "_orig_mod."
    raw_state = ckpt["model_state_dict"]
    was_compiled = any(k.startswith(compiled_prefix) for k in raw_state)
    state_dict = (
        {k.replace(compiled_prefix, ""): v for k, v in raw_state.items()}
        if was_compiled
        else raw_state
    )

    model.load_state_dict(state_dict)
    model = model.to(device)

    if not compile_model:
        logging.info("Running without compiled model.")
    else:
        model = torch.compile(model)
        logging.info("Model compiled with torch.compile().")

    model.eval()
    return model, config


# -----------------------------------------------------------------------------
# Evaluation
# -----------------------------------------------------------------------------
import logging
import math
import time
from typing import Tuple, Optional, Union, Literal

import torch


def eval_once(
    engine_a: "InferenceEngine2",
    engine_b: "InferenceEngine2",
    batch,
    name_a: str,
    name_b: str,
    K: int = 8,
    *,
    timer: Literal["cpu", "cuda"] = "cpu",
    cuda_device: Optional[Union[int, str, torch.device]] = None,
) -> Tuple[str, float]:
    """Compare two inference engines on a sequence of batches, with simple timing.

    For every element of ``batch`` both engines are called (A first, then B)
    as ``evaluate_joint_loglikelihood(element, K=K)``. Each returned tensor is
    summed over dim 1 and then averaged, which gives one joint log-likelihood
    per sample and engine. The engine whose summed joint log-likelihood is
    larger (i.e. the sign of the total log Bayes factor) is preferred.

    Per-iteration logging:
        - Logs the combined A+B elapsed time for about every 1% of the
          samples, and always for the last one.

    Final summary (logged only when ``batch`` is non-empty):
        - Readable stats (mean±std, p50, p95, min–max) for per-iteration totals.
        - Totals for each engine and overall wall time.
        - Throughput and overhead.

    Args:
        engine_a: First engine; must provide ``evaluate_joint_loglikelihood``.
        engine_b: Second engine, same interface as ``engine_a``.
        batch: Sequence of per-sample batches (``len`` and iteration are used).
        name_a: Label of engine A in logs and in the returned preference.
        name_b: Label of engine B.
        K: Buffer batch size forwarded to both engines.
        timer: "cpu" -> time.perf_counter(); "cuda" -> torch.cuda.Event timing.
            A device spec such as "cuda:1" or a ``torch.device`` is also
            accepted. Its type picks the timer and, when ``cuda_device`` is
            None, its index picks the CUDA device. Anything else times on CPU.
        cuda_device: Which CUDA device to time on (only used for CUDA timing).

    Returns:
        (preferred_model_name, win_rate_of_A). The name is ``name_a``,
        ``name_b`` or ``"tie"``. The win rate is the fraction of samples where
        A's joint log-likelihood is strictly larger (NaN for an empty ``batch``).
    """
    def _fmt_s(x: float) -> str:
        if x < 1e-6:
            return f"{x*1e9:.1f} ns"
        if x < 1e-3:
            return f"{x*1e6:.1f} µs"
        if x < 1.0:
            return f"{x*1e3:.2f} ms"
        if x < 60.0:
            return f"{x:.3f} s"
        return f"{x/60.0:.2f} min"

    # Accept "cuda", "CUDA", "cuda:1", torch.device(...) etc.: only the device
    # type decides which timer is used.
    timer_spec = str(timer).lower()
    use_cuda_timing = timer_spec.split(":", 1)[0] == "cuda"
    if use_cuda_timing:
        if not torch.cuda.is_available():
            logging.warning("timer='cuda' requested but CUDA not available; falling back to CPU timing.")
            use_cuda_timing = False
        else:
            if cuda_device is None and ":" in timer_spec:
                cuda_device = timer_spec  # e.g. "cuda:1" -> time on that device
            if cuda_device is not None:
                torch.cuda.set_device(cuda_device)

    def _time_call(fn, *args, **kwargs):
        """Run ``fn`` and return (result, elapsed seconds)."""
        if use_cuda_timing:
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out = fn(*args, **kwargs)
            end.record()
            end.synchronize()
            return out, start.elapsed_time(end) / 1000.0  # ms -> seconds
        t0 = time.perf_counter()
        out = fn(*args, **kwargs)
        return out, (time.perf_counter() - t0)

    def _start_total():
        if use_cuda_timing:
            torch.cuda.synchronize()
            s = torch.cuda.Event(enable_timing=True)
            e = torch.cuda.Event(enable_timing=True)
            s.record()
            return ("cuda", s, e)
        return ("cpu", time.perf_counter(), None)

    def _stop_total(tok):
        kind, a, b = tok
        if kind == "cuda":
            b.record()
            b.synchronize()
            return a.elapsed_time(b) / 1000.0
        return time.perf_counter() - a

    jll_a_list = []
    jll_b_list = []
    per_iter_totals = []
    a_tot = 0.0
    b_tot = 0.0
    last_shape = None

    with torch.no_grad():
        total_tok = _start_total()

        n = len(batch)
        step = max(1, math.ceil(n / 100))  # at most ~100 progress lines

        for i, sample in enumerate(batch):
            # Run engines sequentially; only the combined elapsed time is logged.
            ll_a, dt_a = _time_call(engine_a.evaluate_joint_loglikelihood, sample, K=K)
            ll_b, dt_b = _time_call(engine_b.evaluate_joint_loglikelihood, sample, K=K)

            iter_total = dt_a + dt_b
            a_tot += dt_a
            b_tot += dt_b
            per_iter_totals.append(iter_total)

            last_shape = tuple(ll_a.shape)
            # Sum over dim 1 (presumably the target axis — TODO confirm), then
            # average whatever remains (e.g. the Monte Carlo permutations).
            jll_a_list.append(ll_a.sum(dim=1).mean())
            jll_b_list.append(ll_b.sum(dim=1).mean())

            if (i + 1) % step == 0 or (i + 1) == n:
                logging.info("[%-3d/%-3d] elapsed=%s", i + 1, n, _fmt_s(iter_total))

        grand_total = _stop_total(total_tok)

    # Metrics. torch.stack keeps the engines' dtype and works whether the
    # per-sample scalars live on the CPU or on an accelerator.
    if jll_a_list:
        jll_a = torch.stack(jll_a_list).cpu()
        jll_b = torch.stack(jll_b_list).cpu()
    else:
        jll_a = jll_b = torch.empty(0)
    win_rate_a = (jll_a > jll_b).float().mean().item()  # NaN when empty
    # Log-likelihoods are natural logs; dividing by ln(10) gives log10 units.
    total_log10_bf = (jll_a - jll_b).sum().item() / math.log(10.0)
    preferred = name_a if total_log10_bf > 0 else (name_b if total_log10_bf < 0 else "tie")

    # Readable timing summary
    it = torch.tensor(per_iter_totals or [0.0], dtype=torch.float64)
    mean_it = it.mean().item()
    std_it = it.std(unbiased=False).item()
    p50 = torch.quantile(it, 0.5).item()
    p95 = torch.quantile(it, 0.95).item()
    mn = it.min().item()
    mx = it.max().item()

    engine_sum = a_tot + b_tot
    overhead = max(0.0, grand_total - engine_sum)  # host/logging/iteration overhead etc.
    throughput = (len(per_iter_totals) / grand_total) if grand_total > 0 else float("nan")
    # use_cuda_timing is only still True when CUDA is available.
    device_str = f"cuda:{torch.cuda.current_device()}" if use_cuda_timing else "cpu"

    if last_shape is not None:
        T = last_shape[1] if len(last_shape) > 1 else None

        logging.info(
            "\n"
            "== %s vs %s ==\n"
            "ll_shape=%s | T=%s | K(buffer batch)=%d\n"
            "win rate(%s) = %.1f%% (per-sample joint LL)\n"
            "total log10 BF (A vs B) = %.3f\n"
            "overall preference => %s\n"
            "\n"
            "---- Timing (mode=%s, device=%s) ----\n"
            "Samples: %d | Throughput: %.2f samples/s\n"
            "Per-iteration total (s): mean=%s  std=%s  p50=%s  p95=%s  min=%s  max=%s\n"
            "Engine totals:   A=%s   B=%s   (sum=%s)\n"
            "Grand total:     %s\n"
            "Overhead:        %s  (non-engine time)\n",
            name_a,
            name_b,
            last_shape,
            str(T),
            K,
            name_a,
            win_rate_a * 100.0,
            total_log10_bf,
            preferred,
            ("cuda" if use_cuda_timing else "cpu"),
            device_str,
            len(per_iter_totals),
            throughput,
            _fmt_s(mean_it),
            _fmt_s(std_it),
            _fmt_s(p50),
            _fmt_s(p95),
            _fmt_s(mn),
            _fmt_s(mx),
            _fmt_s(a_tot),
            _fmt_s(b_tot),
            _fmt_s(engine_sum),
            _fmt_s(grand_total),
            _fmt_s(overhead),
        )

    return preferred, win_rate_a


def model_selection_check(
    n_samples: int,
    n_points: int = 100,
    K: int = 16,
    n_mc: int = 10,
    device: Union[str, torch.device] = "cpu",
    ckpt_rho43: Union[str, Path] = DEFAULT_CKPT_RHO43,
    ckpt_rho1: Union[str, Path] = DEFAULT_CKPT_RHO1,
) -> bool:
    """Run model selection sanity check on synthetic data for ρ=4/3 and ρ=1.

    The check works in four steps:
        1. Simulate ``n_samples`` single-element batches with ``n_points``
           targets under each rho.
        2. Expand every batch into ``n_mc`` random target permutations.
        3. Compare the rho=4/3 and rho=1 checkpoints on both datasets with
           :func:`eval_once`.
        4. Pass when each dataset prefers the model of the matching rho.

    Args:
        n_samples: Number of simulated batches per rho.
        n_points: Number of targets per batch.
        K: Buffer batch size forwarded to the engines.
        n_mc: Number of target permutations per batch.
        device: Device for data, models and (when CUDA) timing. Any
            ``torch.device`` spec works, e.g. "cpu", "cuda", "cuda:1", "mps".
        ckpt_rho43: Checkpoint of the model trained on rho=4/3.
        ckpt_rho1: Checkpoint of the model trained on rho=1.

    Returns:
        True if both comparisons prefer the matching model.
    """

    # log args
    logging.info("MODEL SELECTION CHECK: %s", "START")
    logging.info("n_samples: %d", n_samples)
    logging.info("n_points: %d", n_points)
    logging.info("K: %d", K)
    logging.info("n_mc: %d", n_mc)
    logging.info("device: %s", device)
    logging.info("ckpt_rho43: %s", ckpt_rho43)
    logging.info("ckpt_rho1: %s", ckpt_rho1)

    dev = torch.device(device)

    def _to_device(b):
        """Move the context/target tensors of batch ``b`` to ``dev`` (in place)."""
        # Unconditional: .to() is a no-op when a tensor is already on ``dev``,
        # and any device spec ("cuda:1", torch.device, "mps", ...) is handled.
        for attr in ("xc", "yc", "xt", "yt"):
            setattr(b, attr, getattr(b, attr).to(dev))
        return b

    def build_testbatch(batch_list, n_mc: int):
        """Expand each batch into ``n_mc`` target permutations (inverse perms dropped)."""
        return [_build_mc_permutation(b, n_mc=n_mc)[0] for b in batch_list]

    # --- data ---
    sampler_rho43 = BavSampler(RHO_A=4.0 / 3.0, device=device)
    sampler_rho1 = BavSampler(RHO_A=1.0, device=device)

    data_rho43 = [
        _to_device(sampler_rho43.generate_test_batch(1, num_target=n_points))
        for _ in range(n_samples)
    ]
    data_rho1 = [
        _to_device(sampler_rho1.generate_test_batch(1, num_target=n_points))
        for _ in range(n_samples)
    ]

    data_rho1_mc = build_testbatch(data_rho1, n_mc=n_mc)
    data_rho43_mc = build_testbatch(data_rho43, n_mc=n_mc)

    # --- models ---
    model_rho43, _ = load_model(ckpt_rho43, device=device)
    model_rho1, _ = load_model(ckpt_rho1, device=device)

    # --- inference engines ---
    inf_rho43 = InferenceEngine2.from_trained_model(model_rho43, 128, 128)
    inf_rho1 = InferenceEngine2.from_trained_model(model_rho1, 128, 128)

    # eval_once times with CUDA events only for timer="cuda"; pass the exact
    # device index separately so the right GPU is synchronized.
    timer = "cuda" if dev.type == "cuda" else "cpu"
    cuda_device = dev if (dev.type == "cuda" and dev.index is not None) else None

    # --- checks ---
    pref1, acc1 = eval_once(
        engine_a=inf_rho43,
        engine_b=inf_rho1,
        batch=data_rho43_mc,
        name_a="rho=4/3",
        name_b="rho=1",
        K=K,
        timer=timer,
        cuda_device=cuda_device,
    )

    pref2, acc2 = eval_once(
        engine_a=inf_rho1,
        engine_b=inf_rho43,
        batch=data_rho1_mc,
        name_a="rho=1",
        name_b="rho=4/3",
        K=K,
        timer=timer,
        cuda_device=cuda_device,
    )

    ok = (pref1 == "rho=4/3") and (pref2 == "rho=1")
    logging.info("\nMODEL SELECTION CHECK: %s", "PASS ✅" if ok else "FAIL ❌")
    logging.info("sample-level accuracy on rho=4/3 data: %.1f%%", acc1 * 100)
    logging.info("sample-level accuracy on rho=1 data:   %.1f%%", acc2 * 100)
    return ok


# -----------------------------------------------------------------------------
# Visualization helpers
# -----------------------------------------------------------------------------
def _get_indices_viz(S_V, S_A, rt):
    """Compute sorting indices and split positions for visualization.
       Sorting done by rt -> S_A -> S_V.

    Returns:
        idx:       np.ndarray indices that sort primarily by rt, then S_A, then S_V.
        rt_split:  int | None, first index where rt becomes 1 (if present).
        sa_splits: list[int], indices where S_A changes (computed within each rt block).
    """
    S_V = np.asarray(S_V)
    S_A = np.asarray(S_A)
    rt = np.asarray(rt)

    if not (S_V.shape == S_A.shape == rt.shape):
        raise ValueError("S_V, S_A, and rt must have the same shape")

    idx = np.lexsort((S_V, S_A, rt))
    rt_sorted = rt[idx]
    S_A_sorted = S_A[idx]

    rt_split = int(np.argmax(rt_sorted == 1)) if np.any(rt_sorted == 1) else None

    sa_splits: list[int] = []
    if rt_split is not None and rt_split > 0:
        _, sa_idx0 = np.unique(S_A_sorted[:rt_split], return_index=True)
        sa_splits.extend(sa_idx0[1:].tolist())
        _, sa_idx1 = np.unique(S_A_sorted[rt_split:], return_index=True)
        sa_splits.extend((rt_split + sa_idx1[1:]).tolist())
    else:
        _, sa_idx = np.unique(S_A_sorted, return_index=True)
        sa_splits.extend(sa_idx[1:].tolist())

    return idx, rt_split, sa_splits


def debug_visualize(
    seed: int = 123,
    mode: Literal["bavsimv1", "bavsimv2"] = "bavsimv2",
    num_samples: int = 5,
    num_points: int = 400,
    save_path: Union[str, Path] = "bav_responses_sorted.png",
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> str:
    """Visualize sorted BAV responses for rho=1 and rho=4/3 across random inputs.

    Draws ``num_samples`` (theta, inputs) pairs for the simulator selected by
    ``mode`` and simulates one response per pair under rho=1 (top row) and
    rho=4/3 (bottom row). Each column shares theta and inputs. Points are
    sorted by (rt, S_A, S_V) and shown as the response minus S_V before the
    rt split and minus S_A from the split on.

    Args:
        seed: Seed applied before sampling and again before each rho's responses.
        mode: "bavsimv1" (5-dim theta; inputs S_V, S_A, rt) or "bavsimv2"
            (7-dim theta; inputs S_V, S_A, rt, vl; points colored by vl).
        num_samples: Number of figure columns (1 is supported).
        num_points: Number of input points per column.
        save_path: Base path. Its suffix is dropped and
            ``_mode{mode}_seed{seed}.png`` is appended.
        device: Device handed to the simulator helpers.
        dtype: dtype handed to ``_sample_inputs``.

    Returns:
        The path to the saved figure, as a string.

    Raises:
        ValueError: If ``mode`` is unknown.
    """

    logging.info("Using mode: %s", mode)
    if mode == "bavsimv1":
        from src.data.bav_sampler_deprecated import _sample_inputs, sample_bav_responses
    elif mode == "bavsimv2":
        from src.data.bav_samplerv2 import _sample_inputs, sample_bav_responses
    else:
        raise ValueError(f"Unknown mode: {mode!r}")

    # Sample thetas & inputs (seed first for full reproducibility)
    _set_random_seed(seed)

    def _sample_theta_v1(batch_size: int) -> torch.Tensor:
        """Draw ``batch_size`` standard-normal 5-dim thetas."""
        loc = torch.zeros(5, dtype=torch.float32)
        scale = torch.ones(5, dtype=torch.float32)
        return torch.distributions.Normal(loc, scale).sample((batch_size,))

    def _sample_theta_v2(batch_size: int) -> torch.Tensor:
        """Draw 7-dim thetas ``loc + z * scale``, with z ~ N(0, 1) resampled until |z| <= 2."""
        loc = torch.tensor([0.0, 0.75, 1.5, 0.75, 0.0, 0.0, 0.0], dtype=torch.float32)
        scale = torch.ones(7, dtype=torch.float32)

        base = torch.distributions.Normal(torch.zeros_like(loc), torch.ones_like(scale))
        z = base.sample((batch_size,))
        mask = z.abs() > 2.0
        iters = 0
        while mask.any():
            z_new = base.sample((batch_size,))
            z = torch.where(mask, z_new, z)
            mask = z.abs() > 2.0
            iters += 1
            if iters > 100:  # very unlikely
                z = z.clamp_(-2.0, 2.0)
                break
        return loc + z * scale

    def _as_np(x) -> np.ndarray:
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _categorical_cmap(n: int):
        """Colormap with ``n`` discrete entries (tab20 up to 20, else nipy_spectral)."""
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
        # keeps the same (name, lut) signature.
        return plt.get_cmap("tab20" if n <= 20 else "nipy_spectral", n)

    def _theta_sampler_for_mode(m: str):
        return _sample_theta_v1 if m == "bavsimv1" else _sample_theta_v2

    def _sample_response_for_rho(
        rho: float, theta: torch.Tensor, inputs: Tuple
    ) -> torch.Tensor:
        theta = theta.to(device)
        if len(inputs) == 3:
            S_V, S_A, rt = inputs
            return sample_bav_responses(rho, theta, S_V, S_A, rt, N=1)
        if len(inputs) == 4:
            S_V, S_A, rt, vl = inputs
            return sample_bav_responses(rho, theta, S_V, S_A, rt, vl, N=1)
        raise ValueError(f"Unexpected inputs tuple length: {len(inputs)}")

    thetas = _theta_sampler_for_mode(mode)(num_samples)
    inputs_list = [
        _sample_inputs(device, num_points, dtype) for _ in range(num_samples)
    ]

    # Thetas and inputs above are shared by both rhos. Re-seeding before each
    # rho additionally makes both response sets consume the same random stream.
    _set_random_seed(seed)
    response_rho1_list = [
        _sample_response_for_rho(1.0, theta, inputs)
        for theta, inputs in zip(thetas, inputs_list)
    ]

    _set_random_seed(seed)
    response_rho43_list = [
        _sample_response_for_rho(4.0 / 3.0, theta, inputs)
        for theta, inputs in zip(thetas, inputs_list)
    ]

    # Plot. squeeze=False keeps ``axs`` 2-D so axs[row, col] also works for
    # num_samples == 1.
    fig, axs = plt.subplots(
        2, num_samples, figsize=(5 * num_samples, 6), sharey=True, squeeze=False
    )

    def _plot_sorted_response(
        ax,
        response,
        idx: np.ndarray,
        rt_split: Union[int, None],
        sa_splits: Sequence[int] | None,
        title: Union[str, None] = None,
        S_V=None,
        S_A=None,
        vl=None,
        vl_index: np.ndarray | None = None,
        vl_labels: np.ndarray | None = None,
        vl_cmap=None,
    ) -> None:
        """Plot response sorted by idx; optionally subtract S_V/S_A and color by vl.

        All per-point arrays (``response``, ``S_V``, ``S_A``, ``vl``,
        ``vl_index``) are given in the original, unsorted order and are
        reordered with ``idx`` here. If ``vl_index``/``vl_labels``/``vl_cmap``
        are provided, they are used for a consistent mapping across subplots.
        """
        r_sorted = _as_np(response)[idx]

        if S_V is not None and S_A is not None:
            S_V_sorted = _as_np(S_V)[idx]
            S_A_sorted = _as_np(S_A)[idx]
            if rt_split is not None:
                S_VA = np.concatenate([S_V_sorted[:rt_split], S_A_sorted[rt_split:]])
            else:
                S_VA = S_V_sorted  # fallback (no split)
            r_sorted = r_sorted - S_VA

        x = np.arange(len(r_sorted))

        # Choose coloring strategy
        if vl_index is not None and vl_labels is not None and vl_cmap is not None:
            vl_idx_sorted = _as_np(vl_index)[idx]
            ax.scatter(
                x,
                r_sorted,
                c=vl_idx_sorted,
                cmap=vl_cmap,
                vmin=0,
                vmax=len(vl_labels) - 1,
                s=16,
                alpha=0.7,
                rasterized=True,
            )
            # Legend once per axis
            from matplotlib.lines import Line2D

            handles = [
                Line2D(
                    [0],
                    [0],
                    marker="o",
                    linestyle="",
                    color=vl_cmap(i),
                    label=f"vl={int(v)}",
                    markersize=6,
                )
                for i, v in enumerate(vl_labels)
            ]
            ax.legend(
                handles=handles,
                frameon=True,
                loc="upper right",
                ncol=len(vl_labels),
                handletextpad=0.4,
                columnspacing=0.8,
                borderaxespad=0.2,
            )
        elif vl is not None:
            vl_arr = _as_np(vl)
            if vl_arr.shape != r_sorted.shape:
                raise ValueError(
                    f"`vl` must have shape {r_sorted.shape}, got {vl_arr.shape}"
                )
            # Reorder like the responses so each point keeps its own vl color.
            vl_sorted = vl_arr[idx]
            uniq = np.unique(vl_sorted)
            idx_map = {v: i for i, v in enumerate(uniq)}
            vl_idx = np.vectorize(idx_map.get)(vl_sorted)
            ax.scatter(
                x,
                r_sorted,
                c=vl_idx,
                cmap=_categorical_cmap(len(uniq)),
                vmin=0,
                vmax=len(uniq) - 1,
                s=16,
                alpha=0.7,
                rasterized=True,
            )
        else:
            ax.scatter(x, r_sorted, s=12, alpha=0.7, rasterized=True)

        if rt_split is not None and 0 < rt_split < len(r_sorted):
            ax.axvline(rt_split, color="red", linestyle="--", linewidth=1)

        if sa_splits:
            for s in sa_splits:
                ax.axvline(s, color="blue", linestyle="--", alpha=0.3, linewidth=1)

        if title:
            ax.set_title(title)
        ax.set_ylim(*Y_LIM)
        ax.set_xlabel("sorted index by (rt, S_A, S_V)")

    for i in range(num_samples):
        inputs = inputs_list[i]
        if mode == "bavsimv1":
            S_V, S_A, rt = inputs
            S_V_np, S_A_np, rt_np = map(_as_np, (S_V, S_A, rt))
            vl_np = None
            vl_idx = None
            labels = None
            cmap_col = None
        else:  # bavsimv2
            S_V, S_A, rt, vl = inputs
            S_V_np, S_A_np, rt_np, vl_np = map(_as_np, (S_V, S_A, rt, vl))
            # Build a consistent mapping (per column)
            labels = np.unique(vl_np)
            idx_map = {v: j for j, v in enumerate(labels)}
            vl_idx = np.vectorize(idx_map.get)(vl_np)
            cmap_col = _categorical_cmap(len(labels))

        idx, rt_split, sa_splits = _get_indices_viz(S_V_np, S_A_np, rt_np)

        # [0] picks the single simulated draw (N=1).
        resp_rho1 = response_rho1_list[i][0]
        resp_rho43 = response_rho43_list[i][0]

        for row, resp in ((0, resp_rho1), (1, resp_rho43)):
            _plot_sorted_response(
                axs[row, i],
                resp,
                idx,
                rt_split,
                sa_splits,
                S_V=S_V_np,
                S_A=S_A_np,
                vl=vl_np,
                vl_index=vl_idx,
                vl_labels=labels,
                vl_cmap=cmap_col,
            )

    axs[0, 0].set_ylabel("rho=1", fontsize=12)
    axs[1, 0].set_ylabel("rho=4/3", fontsize=12)

    inputs_set = "S_V, S_A, rt" if mode == "bavsimv1" else "S_V, S_A, rt, vl"
    save_path = Path(save_path).with_suffix("")  # drop existing suffix if any
    out_path = save_path.with_name(
        f"{save_path.name}_mode{mode}_seed{seed}"
    ).with_suffix(".png")

    fig.suptitle(
        f"BAV responses; each column uses the same randomly sampled theta and inputs ({inputs_set}). seed = {seed}",
        fontsize=16,
    )
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logging.info("Saved visualization to: %s", out_path)
    return str(out_path)


def debug_model_prediction_deprecated(
    n_points: int = 200,
    K: int = 16,
    n_mc: int = 32,
    device: Union[str, torch.device] = "cpu",
    true_data: Literal["rho1", "rho43"] = "rho1",
    seed: int = 0,
    plot_diff: bool = True,
    ckpt_rho43: Union[str, Path] = DEFAULT_CKPT_RHO43,
    ckpt_rho1: Union[str, Path] = DEFAULT_CKPT_RHO1,
) -> str:
    """
    Deprecated: used for older log likelihood evals function
    Plot model predictions vs true responses for a single synthetic batch.

    This function unpacks ``evaluate_joint_loglikelihood`` as a
    ``(predictions, ll)`` pair (the older engine API). The top two panels show
    the per-target mean ± std over the ``n_mc`` permutations of each model's
    ``predictions.yc``. The bottom panel shows the true responses. All panels
    are sorted by (rt, S_A, S_V). With ``plot_diff`` the stimulus (S_V before
    the rt split, S_A from it on) is subtracted, using the same sorted order
    for every quantity.

    ``K`` is passed to the engine (buffer batch size).

    Returns the path to the saved figure.

    Raises:
        ValueError: If ``true_data`` is not "rho1" or "rho43".
    """

    def _inverse_perm(x, inv_perm):
        """Undo the Monte Carlo target permutation along dim 1 of ``x``."""
        index = inv_perm.to(x.device).unsqueeze(-1).expand(-1, -1, x.shape[-1])
        return x.gather(1, index)

    _set_random_seed(seed)

    # --- data ---
    if true_data == "rho43":
        rho, data_name = 4.0 / 3.0, "Rho = 4/3"
    elif true_data == "rho1":
        rho, data_name = 1.0, "Rho = 1"
    else:
        raise ValueError("true_data must be 'rho1' or 'rho43'")
    data = BavSampler(RHO_A=rho).generate_test_batch(1, num_target=n_points)

    # _build_mc_permutation deep-copies, so ``data`` keeps the original order.
    data_mc, inv_perm = _build_mc_permutation(data, n_mc)
    # The models live on ``device``; the sampler output is not placed there.
    for attr in ("xc", "yc", "xt", "yt"):
        setattr(data_mc, attr, getattr(data_mc, attr).to(device))

    # --- models / engines ---
    model_rho43, _ = load_model(ckpt_rho43, device=device)
    model_rho1, _ = load_model(ckpt_rho1, device=device)
    inf_rho43 = InferenceEngine2.from_trained_model(model_rho43, 128, 128)
    inf_rho1 = InferenceEngine2.from_trained_model(model_rho1, 128, 128)

    with torch.no_grad():
        pred_rho43, ll43 = inf_rho43.evaluate_joint_loglikelihood(data_mc, K=K)
        pred_rho1, ll1 = inf_rho1.evaluate_joint_loglikelihood(data_mc, K=K)

    # Joint LL per permutation: summing over targets is order-invariant, so
    # the permutation does not need to be undone first.
    joint_ll43 = ll43.detach().cpu().sum(dim=1)
    joint_ll1 = ll1.detach().cpu().sum(dim=1)
    joint_ll43_mean, joint_ll43_std = float(joint_ll43.mean()), float(joint_ll43.std())
    joint_ll1_mean, joint_ll1_std = float(joint_ll1.mean()), float(joint_ll1.std())

    # Target inputs of the unpermuted batch; feature columns are rt, S_A, S_V.
    xt = data.xt[0].detach().cpu()
    rt, S_A, S_V = xt[:, 0], xt[:, 1], xt[:, 2]

    indices, rt_split, sa_splits = _get_indices_viz(S_V, S_A, rt)
    order = torch.as_tensor(indices, dtype=torch.long)
    S_V_sorted = S_V[order]
    S_A_sorted = S_A[order]

    # Stimulus to subtract, already in sorted order (mirrors debug_visualize).
    if not plot_diff:
        S_VA = torch.zeros_like(S_V_sorted)
    elif rt_split is None:
        S_VA = S_V_sorted
    else:
        S_VA = torch.cat([S_V_sorted[:rt_split], S_A_sorted[rt_split:]])

    true_y = data.yt[0, :, 0].detach().cpu()[order] - S_VA

    def _sorted_mean_std(pred_y):
        """Mean/std over permutations, reordered by ``order`` BEFORE subtracting S_VA."""
        samples = _inverse_perm(pred_y.detach().cpu(), inv_perm)[..., 0]  # [n_mc, T]
        return samples.mean(dim=0)[order] - S_VA, samples.std(dim=0)[order]

    sample_rho1_mean, sample_rho1_std = _sorted_mean_std(pred_rho1.yc)
    sample_rho43_mean, sample_rho43_std = _sorted_mean_std(pred_rho43.yc)

    def _draw_splits(a):
        for s in sa_splits:
            a.axvline(s, color="green", linestyle="--", linewidth=1)
        if rt_split is not None:
            a.axvline(rt_split, color="red", linestyle="--", linewidth=1)

    # Plot
    fig, ax = plt.subplots(3, 1, figsize=(10, 9), sharex=True)

    model_panels = (
        (
            ax[0],
            sample_rho1_mean,
            sample_rho1_std,
            "rho=1",
            f"BAV Model Samples from empty context (trained on rho=1), joint ll {joint_ll1_mean:.2e} ± {joint_ll1_std:.2e}",
        ),
        (
            ax[1],
            sample_rho43_mean,
            sample_rho43_std,
            "rho=4/3",
            f"BAV Model Samples from empty context (trained on rho=4/3), joint ll {joint_ll43_mean:.2e} ± {joint_ll43_std:.2e}",
        ),
    )
    for panel, mean, std, tag, title in model_panels:
        xs = np.arange(len(mean))
        mean_np = mean.numpy()
        panel.set_title(title, size=9)
        panel.scatter(
            xs,
            mean_np,
            label=f"{tag} mean",
            alpha=0.7,
            s=8,
            rasterized=True,
        )
        panel.errorbar(
            xs,
            mean_np,
            yerr=std.numpy(),
            fmt="none",
            label=f"{tag} std",
            alpha=0.5,
            elinewidth=0.5,
            capsize=3,
            capthick=0.5,
        )
        _draw_splits(panel)
        panel.set_ylim(*Y_LIM)

    ax[2].set_title(f"True Responses {data_name}", size=9)
    ax[2].scatter(
        np.arange(len(true_y)),
        true_y.numpy(),
        label="True Responses",
        alpha=0.7,
        s=8,
        rasterized=True,
    )
    _draw_splits(ax[2])
    ax[2].set_ylim(*Y_LIM)
    ax[2].set_xlabel("sorted index by (rt, S_A, S_V)")

    fig.suptitle(
        f"BAV Model Predictions vs True Responses (n_mc = {n_mc}) on {data_name} True data, seed {seed}",
        fontsize=12,
    )
    out_path = Path(f"bav_model_predictions_data{true_data}_seed{seed}.png")
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logging.info("Saved prediction debug figure to: %s", out_path)
    return str(out_path)


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BAV Model Testing")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["model_selection", "debug_v", "debug_p"],
        default="model_selection",
        help="Run mode",
    )
    parser.add_argument("--n_samples", type=int, default=64, help="Number of samples")
    parser.add_argument("--n_points", type=int, default=400, help="Number of points")
    parser.add_argument(
        "--K", type=int, default=16, help="Buffer batch size (passed to engine)"
    )
    parser.add_argument("--device", type=str, default="cpu", help="Device to use")
    parser.add_argument(
        "--n_mc", type=int, default=10, help="Number of Monte Carlo samples"
    )
    parser.add_argument(
        "--plot_diff", action="store_true", help="Whether to plot differences"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--ckpt_rho1",
        type=str,
        default=str(DEFAULT_CKPT_RHO1),
        help="Path to rho=1 checkpoint",
    )
    parser.add_argument(
        "--ckpt_rho43",
        type=str,
        default=str(DEFAULT_CKPT_RHO43),
        help="Path to rho=4/3 checkpoint",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="INFO",
        help="Logging level",
    )

    args = parser.parse_args()
    # The root logger already has handlers (basicConfig with force=True at
    # import time), so another basicConfig() call would be a silent no-op and
    # --log_level would be ignored. Set the level on the root logger directly.
    logging.getLogger().setLevel(getattr(logging, args.log_level))

    logging.info("Starting BAV model testing...")

    if args.mode == "debug_v":
        logging.info("Running visualization debug...")
        debug_visualize(seed=args.seed, num_points=args.n_points, device=args.device)

    elif args.mode == "debug_p":
        logging.info("Running prediction debug...")
        seed_list = [0, 1, 2, 3, 4, 5]
        for s in seed_list:
            # For each seed: rho=4/3 data first, then rho=1 data.
            for true_data in ("rho43", "rho1"):
                debug_model_prediction_deprecated(
                    n_points=args.n_points,
                    K=args.K,
                    device=args.device,
                    n_mc=args.n_mc,
                    true_data=true_data,
                    seed=s,
                    plot_diff=args.plot_diff,
                    ckpt_rho43=args.ckpt_rho43,
                    ckpt_rho1=args.ckpt_rho1,
                )
    elif args.mode == "model_selection":
        logging.info("Running model_selection check...")
        model_selection_check(
            n_samples=args.n_samples,
            n_points=args.n_points,
            K=args.K,
            device=args.device,
            n_mc=args.n_mc,
            ckpt_rho43=args.ckpt_rho43,
            ckpt_rho1=args.ckpt_rho1,
        )
