# AI Summary: Multi-level BAV model (σV_low/med/high) with per-trial visual noise levels.
# Provides NLL via Gauss–Hermite quadrature, sampling, random-trial wrappers returning
# V_levels, and BavSampler with 4-feature inputs. Backward-compat: 5-D θ uses one σV.

"""
Multi-level Bayesian Audio-Visual (BAV) localization model with causal inference.

Overview
--------
This module implements a BAV model where the *visual* sensory noise can take one of three
levels per trial (low/med/high), indicated by an integer tensor `V_levels ∈ {0,1,2}`.
The rest of the model follows the standard BAV formulation:

- Sensory measurements: x_V ~ N(S_V, σ_V^2), x_A ~ N(ρ·S_A, σ_A^2)
- Gaussian spatial prior over s:   s ~ N(μ, σ_s^2), with μ fixed to 0
- Causal inference: model-averaging between C=1 (common source) and C=2 (independent)
- Motor noise on the reported estimate:  R ~ N(ŝ, σ_m^2)
- Lapse process: with small prob. (0.02) emit a uniform response on [-45°, 45°]

Key features
------------
1) nll_bav_constant_gaussian(...) evaluates a summed negative log-likelihood using
   separable 2D Gauss–Hermite quadrature over (x_V, x_A). For efficiency, trials are
   split by `V_levels` and each subset is integrated with its corresponding σ_V.
2) sample_bav_responses(...) draws synthetic responses by sampling measurements,
   computing model-averaged posterior estimates, then adding motor noise and lapses.
3) Random-trial wrappers sample (S_V, S_A, response_types, V_levels) for convenience.
4) BavSampler assembles batched (context, buffer, target) splits with input features
   (response_type, V_level, S_A, S_V). The parameter vector θ is 7-D in the new layout.

Parameter vector θ (unconstrained → transformed)
------------------------------------------------
New (len=7):
  [ log σV_low, log σV_med, log σV_high, log σA, log σs, log σm, logit(p_same) ]
Compat (len=5): all three σV levels use the single σV from indices [0], and the rest map
  as in the constant-noise model. In both cases, lapse=0.02 and μ=0 are fixed internally.

Default priors for θ sampling (BavSampler._sample_theta)
--------------------------------------------------------
All priors are **truncated Gaussians** in the unconstrained space, written below as
N(mean, std). Each draw is truncated at ±2 standard deviations around its mean and
resampled until within bounds (after 100 unsuccessful resampling rounds any remaining
offenders are hard-clipped to the bounds). The three visual-noise levels are
hierarchical: each level is centred on the previous level's draw plus a fixed offset.
  - log σV_low    ~ N(0.00, 1.5)                  truncated to [−3.00, +3.00]
  - log σV_med    ~ N(log σV_low + 1.00, 1.0)     truncated to mean ± 2.00
  - log σV_high   ~ N(log σV_med + 0.75, 0.5)     truncated to mean ± 1.00
  - log σA        ~ N(1.75, 0.5)                  truncated to [+0.75, +2.75]
  - log σs        ~ N(2.50, 1.0)                  truncated to [+0.50, +4.50]
  - log σm        ~ N(0.00, 0.5)                  truncated to [−1.00, +1.00]
  - logit(p_same) ~ N(1.50, 1.5)                  truncated to [−1.50, +4.50]

Rationale: Truncation avoids extreme values that can destabilize GH integration or produce
unreasonable noise scales, while keeping priors weakly informative.

Public APIs (signatures)
------------------------
nll_bav_constant_gaussian(RHO_A, theta, R, S_V, S_A, response_types, V_levels, *, gh_deg=51, chunk_size=None)
sample_bav_responses(RHO_A, theta, S_V, S_A, response_types, V_levels, *, N=1)
sample_bav_responses_random_trials(RHO_A, theta, *, N=1, num_points=98, device=None, dtype=torch.float32)
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import numpy as np
import torch

from src.utils import DataAttr

__all__ = [
    "nll_bav_constant_gaussian",
    "sample_bav_responses",
    "sample_bav_responses_random_trials",
    "BavSampler",
]

# Normalizing constant sqrt(2π) of the univariate Gaussian density.
SQRT_TWO_PI = math.sqrt(2.0 * math.pi)


def _gaussian_pdf(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Evaluate the density of N(mu, sigma^2) at x, elementwise.

    Args:
        x: points at which to evaluate the density.
        mu: mean; must broadcast against x.
        sigma: standard deviation; must broadcast against x.

    Returns:
        Tensor of density values with the broadcast shape of the inputs.
    """
    # Standardized residual; squaring it gives the exponent's quadratic form.
    z = (x - mu) / sigma
    unnormalized = torch.exp(-0.5 * z ** 2)
    return unnormalized / (SQRT_TWO_PI * sigma)


def _gauss_hermite_tensor(n: int, device: torch.device, dtype: torch.dtype):
    """Build an n-point physicists' Gauss–Hermite rule as torch tensors.

    The rule approximates ∫ e^{−y²} f(y) dy. To integrate against a Gaussian
    N(mean, σ²), callers substitute x = mean + √2·σ·y.

    Args:
        n: number of quadrature nodes (more nodes: higher accuracy, more cost).
        device: device on which the returned tensors are placed.
        dtype: dtype of the returned tensors.

    Returns:
        (y, w): nodes and weights, each a tensor of shape (n,).
    """
    nodes_np, weights_np = np.polynomial.hermite.hermgauss(n)

    def _to_tensor(values) -> torch.Tensor:
        # Convert the NumPy output to the requested placement and precision.
        return torch.as_tensor(values, dtype=dtype, device=device)

    return _to_tensor(nodes_np), _to_tensor(weights_np)


def _unpack_theta(theta: torch.Tensor, device: torch.device, dtype: torch.dtype):
    """Transform an unconstrained θ into positive noise scales and a probability.

    Two layouts are understood, selected by the number of elements in θ:

    * 7 or more elements (new layout):
        [ log σV_low, log σV_med, log σV_high, log σA, log σs, log σm, logit(p_same) ]
    * fewer than 7 elements (compat layout, intended for len=5):
        [ log σV, log σA, log σs, log σm, logit(p_same) ]
        where the single σV is shared by all three visual noise levels.

    The lapse probability (0.02) and the prior mean μ (0.0) are fixed here and
    are not read from θ.

    Returns:
        (σV_low, σV_med, σV_high, σA, σs, σm, lapse, p_same, μ), every entry a
        torch tensor on (device, dtype).
    """
    flat = theta.to(device=device, dtype=dtype).flatten()

    if flat.numel() >= 7:
        # One entry per visual level; remaining parameters start at index 3.
        visual = tuple(torch.exp(flat[i]) for i in range(3))
        offset = 3
    else:
        # Compat layout: a single σV stands in for low/med/high.
        shared_sigV = torch.exp(flat[0])
        visual = (shared_sigV, shared_sigV, shared_sigV)
        offset = 1

    sigA = torch.exp(flat[offset])
    sigS = torch.exp(flat[offset + 1])
    sigM = torch.exp(flat[offset + 2])
    p_same = torch.sigmoid(flat[offset + 3])

    # Constants that are not part of θ.
    lapse = torch.tensor(0.02, device=device, dtype=dtype)
    mu_p = torch.tensor(0.0, device=device, dtype=dtype)

    return (*visual, sigA, sigS, sigM, lapse, p_same, mu_p)


def _chunk_nll(
    RHO_A: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    R: torch.Tensor,
    rt: torch.Tensor,
    rel_V: torch.Tensor,
    rel_A: torch.Tensor,
    weight_mat: torch.Tensor,
    mu_p: torch.Tensor,
    sigma_m: torch.Tensor,
    lapse: torch.Tensor,
    p_same: torch.Tensor,
    iv_V: torch.Tensor,
    iv_A: torch.Tensor,
    iv_s: torch.Tensor,
    weight_sum_c1: torch.Tensor,
    weight_V: torch.Tensor,
    weight_A: torch.Tensor,
    inv00: torch.Tensor,
    inv11: torch.Tensor,
    inv01: torch.Tensor,
    log_norm_c1: torch.Tensor,
    log_norm_c2_V: torch.Tensor,
    log_norm_c2_A: torch.Tensor,
) -> torch.Tensor:
    """Negative log-likelihood, summed over one chunk of trials.

    All parameter-dependent constants (precisions, inverse-covariance entries,
    log normalizers, quadrature offsets and weights) are supplied by the caller
    so that this function only does the per-trial grid work. Processing trials
    in chunks bounds the size of the (B, N_V, N_A) intermediates.

    Steps:
      1) Lay out a quadrature grid of measurements (x_V, x_A) per trial.
      2) Evaluate log p(x|C=1) and log p(x|C=2), then the posterior Pr(C=1|x).
      3) Form the per-cause posterior means and their model average ŝ.
      4) Evaluate the motor-noise density p(R | ŝ) and integrate it over the grid.
      5) Mix in the lapse process and sum −log over trials.

    Shapes:
        S_V, S_A, R, rt: (B,)
        rel_V: (N_V,), rel_A: (N_A,), weight_mat: (N_V, N_A)

    Returns:
        Scalar tensor holding the summed NLL of the chunk.
    """
    # 1) Measurement grid: visual nodes vary along dim 1, auditory nodes along dim 2.
    meas_V = S_V[:, None, None] + rel_V[None, :, None]
    meas_A = RHO_A * (S_A[:, None, None] + rel_A[None, None, :])

    # 2) Log-likelihood of the measurements under each causal structure.
    dev_V = meas_V - mu_p
    dev_A = meas_A - mu_p

    mahal_c1 = inv00 * dev_V * dev_V + 2 * inv01 * dev_V * dev_A + inv11 * dev_A * dev_A
    log_lik_c1 = log_norm_c1 - 0.5 * mahal_c1

    # Marginal measurement variances under C=2, recovered from the precisions.
    marg_var_V = 1.0 / iv_V + 1.0 / iv_s
    marg_var_A = 1.0 / iv_A + 1.0 / iv_s
    log_lik_c2 = (
        log_norm_c2_V - 0.5 * (dev_V * dev_V) / marg_var_V
        + log_norm_c2_A - 0.5 * (dev_A * dev_A) / marg_var_A
    )

    # Posterior of a common cause, normalized in the log domain before exponentiating.
    log_joint_c1 = torch.log(p_same) + log_lik_c1
    log_joint_c2 = torch.log1p(-p_same) + log_lik_c2
    post_common = torch.exp(log_joint_c1 - torch.logaddexp(log_joint_c1, log_joint_c2))

    # 3) Precision-weighted posterior means; under C=2 the reported modality decides.
    est_c1 = (iv_V * meas_V + iv_A * meas_A + iv_s * mu_p) / weight_sum_c1
    est_c2_V = (iv_V * meas_V + iv_s * mu_p) / weight_V
    est_c2_A = (iv_A * meas_A + iv_s * mu_p) / weight_A
    is_visual_report = rt[:, None, None] == 0
    est_c2 = torch.where(is_visual_report, est_c2_V, est_c2_A)

    s_hat = post_common * est_c1 + (1.0 - post_common) * est_c2

    # 4) Response density at every grid point, collapsed with the quadrature weights.
    resp_density = _gaussian_pdf(R[:, None, None], s_hat, sigma_m)
    p_resp = torch.sum(resp_density * weight_mat, dim=(1, 2))

    # 5) Lapses are uniform over a 90-degree range; epsilon guards the log.
    p_resp = (1.0 - lapse) * p_resp + lapse / 90.0
    return -torch.sum(torch.log(p_resp + 1e-12))


def nll_bav_constant_gaussian(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    R: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    V_levels: torch.Tensor,
    *,
    gh_deg: int = 51,
    chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """Summed NLL with three σ_V levels selected by V_levels ∈ {0,1,2}.

    Args:
        RHO_A: auditory rescaling factor ρ (float/tensor; broadcasts).
        theta: parameters; supports len 7 (3×σV + σA, σs, σm, p_same) or len 5 (compat).
        R: (T,) observed responses (deg). Its device and dtype are used for all
            internal computation.
        S_V, S_A: (T,) true stimulus locations (deg).
        response_types: (T,) 0=BV (visual report), 1=BA (auditory report).
        V_levels: (T,) integers in {0,1,2} indicating low/med/high visual noise.
            Trials whose level is not 0, 1 or 2 match no partition and therefore
            contribute nothing to the returned sum.
        gh_deg: Gauss–Hermite nodes per dimension (51 is accurate but costlier).
        chunk_size: optional memory bound; each level's trials are processed in
            pieces of at most this many trials. Must be positive when given.

    Returns:
        Scalar tensor: negative log-likelihood summed across trials.

    Raises:
        ValueError: if `chunk_size` is given and is not positive. A negative value
            would otherwise produce an empty chunk loop and silently drop trials.

    Implementation notes:
        - The Gauss–Hermite rule is built once and shared by both dimensions.
        - Everything that depends only on σ_A and σ_s is computed once, outside
          the per-level loop; only σ_V-dependent constants are rebuilt per level.
        - Trials are partitioned by `V_levels` and each partition is gathered
          once before chunking.
    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer or None.")

    device, dtype = R.device, R.dtype
    sigV_low, sigV_med, sigV_high, sigA, sigS, sigM, lapse, p_same, mu_p = _unpack_theta(
        theta, device, dtype
    )
    RHO_A = torch.as_tensor(RHO_A, device=device, dtype=dtype)

    # One quadrature rule serves both the visual and the auditory dimension.
    nodes, weights = _gauss_hermite_tensor(gh_deg, device, dtype)
    rel_A = sigA * math.sqrt(2.0) * nodes
    # Product weights; the 1/π comes from the two Gaussian changes of variables.
    weight_mat = (weights[:, None] * weights[None, :]) / math.pi

    # Level-independent quantities (depend only on σ_A and σ_s).
    vA = sigA * sigA
    vS = sigS * sigS
    iv_A = 1.0 / vA
    iv_s = 1.0 / vS
    weight_A = iv_A + iv_s
    vAbar = vA + vS
    log_norm_c2_A = -0.5 * (math.log(2 * math.pi) + torch.log(vAbar))

    nll_total = torch.tensor(0.0, device=device, dtype=dtype)

    for lvl, sigV in enumerate((sigV_low, sigV_med, sigV_high)):
        # Select the subset of trials that belong to this visual noise level
        idx = torch.nonzero(V_levels == lvl, as_tuple=False).squeeze(-1)
        n_lvl = idx.numel()
        if n_lvl == 0:
            continue

        # Level-specific variance/precision
        vV = sigV * sigV
        iv_V = 1.0 / vV

        # C=1 (common source) covariance Σ for x = (xV, xA):
        # Σ = [[vV+vS, vS],[vS, vA+vS]], with det and inverse entries
        a = vV + vS
        b = vS
        d = vAbar
        det = a * d - b * b
        inv00 = d / det
        inv11 = a / det
        inv01 = -b / det
        log_norm_c1 = -0.5 * (math.log((2 * math.pi) ** 2) + torch.log(det))

        # C=2 (independent sources, same prior per modality): visual marginal
        log_norm_c2_V = -0.5 * (math.log(2 * math.pi) + torch.log(a))

        weight_sum_c1 = iv_V + iv_A + iv_s
        weight_V = iv_V + iv_s

        # rel_V depends on σ_V of this level
        rel_V = sigV * math.sqrt(2.0) * nodes

        # Gather this level's trials once, then walk them in chunks.
        sub_R = R[idx]
        sub_S_V = S_V[idx]
        sub_S_A = S_A[idx]
        sub_rt = response_types[idx]

        step = n_lvl if chunk_size is None else chunk_size
        for start in range(0, n_lvl, step):
            sl = slice(start, start + step)  # slicing clamps at the end
            nll_total = nll_total + _chunk_nll(
                RHO_A,
                sub_S_V[sl],
                sub_S_A[sl],
                sub_R[sl],
                sub_rt[sl],
                rel_V,
                rel_A,
                weight_mat,
                mu_p,
                sigM,
                lapse,
                p_same,
                iv_V,
                iv_A,
                iv_s,
                weight_sum_c1,
                weight_V,
                weight_A,
                inv00,
                inv11,
                inv01,
                log_norm_c1,
                log_norm_c2_V,
                log_norm_c2_A,
            )

    return nll_total


def sample_bav_responses(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    V_levels: torch.Tensor,
    *,
    N: int = 1,
) -> torch.Tensor:
    """Draw synthetic responses with per-trial visual noise selected by V_levels.

    Args:
        RHO_A: auditory rescaling factor ρ (float/tensor).
        theta: parameter vector (7-D new or 5-D compat); lapse, μ fixed internally.
        S_V, S_A: (T,) stimuli (deg). The device and dtype of S_V are used for
            all internal computation.
        response_types: (T,) 0=BV (visual), 1=BA (auditory).
        V_levels: (T,) 0/1/2 select σV_low/med/high for each trial. Converted to
            an integer index on S_V's device, so float-coded levels are accepted.
        N: number of responses per trial to sample.

    Returns:
        (N, T) sampled responses.
    """
    device, dtype = S_V.device, S_V.dtype
    (sigV_low, sigV_med, sigV_high, sigA, sigS, sigM, lapse, p_same, mu_p) = _unpack_theta(
        theta, device, dtype
    )
    RHO_A = torch.as_tensor(RHO_A, device=device, dtype=dtype)

    T = S_V.numel()
    # Build per-trial visual noise σ_V by indexing levels. Indexing requires an
    # integer tensor, so normalize dtype/device first.
    level_idx = V_levels.to(device=device, dtype=torch.long)
    sigV_trial = torch.stack([sigV_low, sigV_med, sigV_high])[level_idx]  # (T,)

    # 1) Sensory measurements x_V, x_A with appropriate noise levels.
    # Generative auditory rescaling: scale the noisy observation itself, but
    # KEEP inference covariance unscaled (σ_A) in all subsequent computations.
    x_V = S_V.unsqueeze(0) + sigV_trial.unsqueeze(0) * torch.randn((N, T), device=device, dtype=dtype)
    x_A = RHO_A * (S_A.unsqueeze(0) + sigA * torch.randn((N, T), device=device, dtype=dtype))

    # 2) Variances and precisions. (T,)-shaped terms broadcast against (N, T);
    # σ_A/σ_s terms are 0-dim. These stay as plain tensor arithmetic: wrapping
    # them in torch.tensor(...) would copy-construct from a tensor, which warns
    # and detaches the value from the autograd graph.
    vV = sigV_trial * sigV_trial  # (T,)
    vA = sigA * sigA
    vS = sigS * sigS
    iv_V = 1.0 / vV
    iv_A = 1.0 / vA
    iv_s = 1.0 / vS

    # C=1 covariance of (x_V, x_A): [[vV+vS, vS], [vS, vA+vS]]
    a = vV + vS
    b = vS
    d = vA + vS
    det = a * d - b * b
    inv00 = d / det
    inv11 = a / det
    inv01 = -b / det
    log_norm_c1 = -0.5 * (math.log((2 * math.pi) ** 2) + torch.log(det))

    # C=2: independent marginals per modality
    vVbar = vV + vS
    vAbar = vA + vS
    log_norm_c2_V = -0.5 * (math.log(2 * math.pi) + torch.log(vVbar))
    log_norm_c2_A = -0.5 * (math.log(2 * math.pi) + torch.log(vAbar))

    zV = x_V - mu_p
    zA = x_A - mu_p
    quad_c1 = inv00.unsqueeze(0) * (zV * zV) + 2 * inv01.unsqueeze(0) * (zV * zA) + inv11.unsqueeze(0) * (zA * zA)
    log_p_c1 = log_norm_c1.unsqueeze(0) - 0.5 * quad_c1
    log_p_c2 = (
        log_norm_c2_V.unsqueeze(0) - 0.5 * (zV * zV) / vVbar.unsqueeze(0)
        + log_norm_c2_A - 0.5 * (zA * zA) / vAbar
    )

    # Posterior Pr(C=1|x) from the log-odds of the two causal structures.
    logit_pc1 = torch.log(p_same) + log_p_c1 - (torch.log1p(-p_same) + log_p_c2)
    post_c1 = torch.sigmoid(logit_pc1)

    # 3) Posterior means and model-averaged estimate
    weight_sum_c1 = iv_V + iv_A + iv_s
    weight_V = iv_V + iv_s
    weight_A = iv_A + iv_s

    mu_c1 = (iv_V.unsqueeze(0) * x_V + iv_A * x_A + iv_s * mu_p) / weight_sum_c1.unsqueeze(0)
    mu_c2_V = (iv_V.unsqueeze(0) * x_V + iv_s * mu_p) / weight_V.unsqueeze(0)
    mu_c2_A = (iv_A * x_A + iv_s * mu_p) / weight_A

    # Under C=2 the estimate follows the modality that is being reported.
    rt = response_types.to(device).unsqueeze(0)
    mu_c2 = torch.where(rt == 0, mu_c2_V, mu_c2_A)
    s_hat = post_c1 * mu_c1 + (1.0 - post_c1) * mu_c2

    # 4) Motor noise and lapses (lapse responses are uniform on [-45, 45])
    R_noisy = s_hat + sigM * torch.randn_like(s_hat)
    if float(lapse) > 0.0:
        mask = torch.rand_like(R_noisy) < lapse
        R_uniform = -45.0 + 90.0 * torch.rand_like(R_noisy)
        R_noisy = torch.where(mask, R_uniform, R_noisy)
    return R_noisy


def _sample_inputs(
    device: torch.device,
    num_points: int = 98,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draw a random set of trials: stimuli, report modality and visual level.

    Sampling scheme:
        - S_A is picked uniformly from {-15, -10, -5, 0, 5, 10, 15}.
        - S_V equals S_A with probability 0.5, otherwise S_V ~ Uniform(-20, 20).
        - response_types is 0 (BV) or 1 (BA) with equal probability.
        - V_levels is uniform over {0, 1, 2} (low/med/high σ_V).

    Args:
        device: device on which all tensors are created.
        num_points: number of trials to draw; must be positive.
        dtype: floating dtype of the stimulus tensors.

    Returns:
        (S_V, S_A, response_types, V_levels), each of shape (num_points,).
        The last two are int64 tensors.

    Raises:
        ValueError: if num_points is not positive.
    """
    if num_points <= 0:
        raise ValueError("num_points must be positive.")

    # Auditory locations come from a fixed grid.
    auditory_grid = torch.tensor(
        [-15.0, -10.0, -5.0, 0.0, 5.0, 10.0, 15.0], device=device, dtype=dtype
    )
    grid_choice = torch.randint(0, auditory_grid.numel(), (num_points,), device=device)
    S_A = auditory_grid[grid_choice]

    # Half of the trials (in expectation) are spatially congruent.
    congruent = torch.rand(num_points, device=device) < 0.5
    free_visual = 40.0 * torch.rand(num_points, device=device, dtype=dtype) - 20.0
    S_V = torch.where(congruent, S_A, free_visual).to(dtype)

    report_is_auditory = torch.rand(num_points, device=device) < 0.5
    response_types = report_is_auditory.long()

    V_levels = torch.randint(0, 3, (num_points,), device=device, dtype=torch.long)

    return S_V, S_A, response_types, V_levels


def sample_bav_responses_random_trials(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    *,
    N: int = 1,
    num_points: int = 98,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate random trials and simulate responses to them in one call.

    Args:
        RHO_A: auditory rescaling factor ρ.
        theta: parameter vector (7-D new or 5-D compat); moved to `device`.
        N: number of simulated responses per trial.
        num_points: number of trials to generate.
        device: where to create the trials; CPU when None.
        dtype: floating dtype of the generated stimuli.

    Returns:
        (S_V, S_A, response_types, V_levels, R_sim) where the first four have
        shape (num_points,) and R_sim has shape (N, num_points).
    """
    target_device = torch.device("cpu") if device is None else device

    trials = _sample_inputs(target_device, num_points, dtype)
    S_V, S_A, response_types, V_levels = trials

    R_sim = sample_bav_responses(
        RHO_A, theta.to(target_device), S_V, S_A, response_types, V_levels, N=N
    )
    return S_V, S_A, response_types, V_levels, R_sim


class BavSampler:
    """Assembles (context, buffer, target) splits with 4-feature inputs.

    Features:
        x = (response_type, V_level, S_A, S_V), y = responses
    Parameter sampling:
        θ is sampled in 7-D unconstrained space with **truncated Gaussian** priors on
        log noise parameters (±2 SD; see `_sample_theta`); internal transforms per `_unpack_theta`.
    Notes:
        If num_context == 0, a 1×4 dummy context row [2.0, 0.0, 0.0, 0.0] is inserted.
    """

    THETA_DIM = 7

    def __init__(self, RHO_A: float = 4.0 / 3.0, device: str = "cpu", dtype: torch.dtype = torch.float32):
        """Store the auditory rescaling factor and the tensor placement/precision.

        Args:
            RHO_A: auditory rescaling factor ρ passed to the simulator.
            device: device on which θ and trials are created.
            dtype: floating dtype of θ and of the generated stimuli.
        """
        self.RHO_A = RHO_A
        self.device = device
        self.dtype = dtype

    @staticmethod
    def _trunc_normal(mean: torch.Tensor, std: float, max_iters: int = 100) -> torch.Tensor:
        """Sample elementwise from N(mean, std²) truncated to [mean − 2·std, mean + 2·std].

        Out-of-range entries are redrawn for up to `max_iters` rounds; anything
        still out of range afterwards is clipped to the bounds.

        Args:
            mean: per-element means; also fixes the shape, device and dtype of the result.
            std: standard deviation shared by all elements.
            max_iters: maximum number of resampling rounds.

        Returns:
            Tensor shaped like `mean` with every entry inside the truncation bounds.
        """
        std_t = torch.as_tensor(std, device=mean.device, dtype=mean.dtype)
        lower = mean - 2.0 * std_t
        upper = mean + 2.0 * std_t

        x = mean + std_t * torch.randn_like(mean)
        for _ in range(max_iters):
            out_of_range = (x < lower) | (x > upper)
            if not out_of_range.any():
                return x
            n_bad = int(out_of_range.sum().item())
            # Redraw only the offending entries around their own means.
            x[out_of_range] = mean[out_of_range] + std_t * torch.randn(
                n_bad, device=mean.device, dtype=mean.dtype
            )

        # Extremely unlikely fallback: hard-clip remaining offenders to bounds
        return torch.minimum(torch.maximum(x, lower), upper)

    def _sample_theta(self, batch_size: int) -> torch.Tensor:
        """Sample θ in the unconstrained space with **truncated Gaussian** priors.

        Priors are written N(mean, std); each is truncated at ±2 SD around its
        mean and **resampled** until within bounds:
        - log σV_low    ~ N(0.00, 1.5)
        - log σV_med    ~ N(log σV_low + 1, 1)
        - log σV_high   ~ N(log σV_med + 0.75, 0.5)
        - log σA        ~ N(1.75, 0.5)
        - log σs        ~ N(2.5, 1)
        - log σm        ~ N(0.00, 0.5)
        - logit(p_same) ~ N(1.5, 1.5)

        Args:
            batch_size: number of independent θ vectors to draw.

        Returns:
            (batch_size, 7) tensor on `self.device` with dtype `self.dtype`, columns
            ordered as listed above.
        """
        # This class is not an nn.Module, so placement comes from the values
        # stored at construction rather than from any parameters.
        device = torch.device(self.device)
        dtype = self.dtype

        def constant_mean(value: float) -> torch.Tensor:
            return torch.full((batch_size,), value, device=device, dtype=dtype)

        # 1) Hierarchical visual-noise scales on log-scale: each level is centred
        #    on the previous level's draw plus an offset.
        log_sigmaV_low = self._trunc_normal(constant_mean(0.0), 1.5)
        log_sigmaV_med = self._trunc_normal(log_sigmaV_low + 1.0, 1.0)
        log_sigmaV_high = self._trunc_normal(log_sigmaV_med + 0.75, 0.5)

        # 2) Other log-scales (independent)
        log_sigmaA = self._trunc_normal(constant_mean(1.75), 0.5)
        log_sigmas = self._trunc_normal(constant_mean(2.5), 1.0)
        log_sigmam = self._trunc_normal(constant_mean(0.0), 0.5)

        # 3) Logit of p_same
        logit_p_same = self._trunc_normal(constant_mean(1.5), 1.5)

        # Column order must match the 7-D layout read by `_unpack_theta`.
        return torch.stack(
            [
                log_sigmaV_low,
                log_sigmaV_med,
                log_sigmaV_high,
                log_sigmaA,
                log_sigmas,
                log_sigmam,
                logit_p_same,
            ],
            dim=1,
        )

    def _sample_num_context(self, context_range: Tuple[int, int]) -> int:
        """Sample a context count from [low, high] inclusive, or uniformly from a list.

        Args:
            context_range: a (low, high) pair, or a sequence of more than two
                candidate counts to choose from.

        Returns:
            The sampled number of context points.

        Raises:
            ValueError: if a pair has low > high, or the sequence has fewer than
                two entries.
        """
        n = len(context_range)
        if n == 2:
            low, high = context_range
            if low > high:
                raise ValueError("Invalid `context_range`: low > high.")
            return torch.randint(low, high + 1, (1,)).item()
        if n > 2:
            choices = torch.as_tensor(context_range)
            return choices[torch.randint(0, choices.numel(), (1,))].item()
        raise ValueError("`context_range` must have length 2 or > 2.")

    def generate_batch(
        self,
        batch_size: int,
        num_context: Optional[int] = None,
        num_buffer: int = 50,
        num_target: int = 50,
        context_range: Tuple[int, int] = (3, 47),
    ) -> DataAttr:
        """Generate a batch of random trials split into (context, buffer, target).

        Args:
            batch_size: number of independent θ draws / tasks.
            num_context: number of context points per task (if None, sampled from range).
            num_buffer: number of buffer points per task.
            num_target: number of target points per task.
            context_range: used when num_context is None; inclusive [low, high].

        Returns:
            DataAttr with fields:
              - xc, yc: (B, C, 4), (B, C, 1)
              - xb, yb: (B, Bf, 4), (B, Bf, 1)
              - xt, yt: (B, T, 4), (B, T, 1)
        """
        if num_context is None:
            num_context = self._sample_num_context(context_range)
        num_total = num_context + num_buffer + num_target

        thetas = self._sample_theta(batch_size)
        S_Vs, S_As, rts, Vls, Y = [], [], [], [], []

        # One simulated task per θ row; a single response (N=1) is drawn per trial.
        for th in thetas:
            S_V, S_A, rt, Vl, resp = sample_bav_responses_random_trials(
                self.RHO_A, th, num_points=num_total, device=self.device, dtype=self.dtype
            )
            S_Vs.append(S_V)
            S_As.append(S_A)
            rts.append(rt)
            Vls.append(Vl)
            Y.append(resp.squeeze(0))

        # Stack per-task tensors and assemble features
        S_Vs, S_As, rts, Vls = (torch.stack(t) for t in (S_Vs, S_As, rts, Vls))
        # response_type and V_level are integer-coded; cast them explicitly to the
        # stimulus dtype instead of relying on implicit promotion inside stack.
        feat_dtype = S_As.dtype
        x = torch.stack((rts.to(feat_dtype), Vls.to(feat_dtype), S_As, S_Vs), dim=-1)  # (B, N, 4)
        y = torch.stack(Y).unsqueeze(-1)                                               # (B, N, 1)

        # Random split indices (shared across tasks for simplicity)
        perm = torch.randperm(num_total, device=x.device)
        ctx_idx = perm[:num_context]
        buf_idx = perm[num_context : num_context + num_buffer]
        tar_idx = perm[num_context + num_buffer :]

        if num_context == 0:
            # Dummy one-row context when none requested
            xc = torch.tensor([2.0, 0.0, 0.0, 0.0], device=x.device, dtype=x.dtype)[None, None, :].expand(
                batch_size, 1, 4
            )
            yc = torch.tensor([0.0], device=y.device, dtype=y.dtype)[None, None, :].expand(batch_size, 1, 1)
        else:
            xc = x[:, ctx_idx]
            yc = y[:, ctx_idx]

        return DataAttr(xc=xc, yc=yc, xb=x[:, buf_idx], yb=y[:, buf_idx], xt=x[:, tar_idx], yt=y[:, tar_idx])

    def generate_test_batch(self, batch_size: int, num_target: int = 400) -> DataAttr:
        """No-context, all-target batch generation helper (useful for evaluation)."""
        return self.generate_batch(batch_size, num_context=0, num_buffer=0, num_target=num_target)
    