"""bav_parametric_torch.py
=================================
IMPORTANT: Deprecated — see `bav_samplerv2.py` for the current version.

PyTorch implementation of a Bayesian Audio-Visual (BAV) localisation
model with

* **constant sensory noise** (σ_V, σ_A),
* **Gaussian spatial prior** 𝒩(μ, σ_s²),
* **Model-Averaging causal inference**,
* **Gaussian motor noise** (σ_m), and
* **Auditory rescaling** by a caller-supplied factor ``RHO_A`` (``BavSampler`` defaults to 4/3).

The module exposes two vectorised functions that both operate in an
**unbounded parameter space** — a single flat tensor ``theta`` whose
entries live on ℝ but are internally transformed to their proper ranges.

Functions
---------
```
nll_bav_constant_gaussian(RHO_A, theta, R, S_V, S_A, response_types, ...)
    → scalar negative log-likelihood (∑ over trials)

sample_bav_responses(RHO_A, theta, S_V, S_A, response_types, N=1)
    → synthetic responses R_sim (shape: (N, batch))
```

Unconstrained parameter vector ``theta`` (only the first 5 entries are read)
-----------------------------------------------------------------------------
```
Idx  Name        Raw value in θ        Transform           Effective range
---  ----------  --------------------  ------------------  ---------------
0    log_σ_V     any real             σ_V = exp(θ₀)        (0, ∞)
1    log_σ_A     any real             σ_A = exp(θ₁)        (0, ∞)
2    log_σ_s     any real             σ_s = exp(θ₂)        (0, ∞)
3    log_σ_m     any real             σ_m = exp(θ₃)        (0, ∞)
4    logit_p     any real             p_same = σ(θ₄)       (0, 1)
```
``σ(z)`` denotes the logistic (sigmoid) function. The lapse rate and the
prior mean μ are no longer fitted: lapse is fixed at 0.02 and μ at 0. Any
entries beyond index 4 (e.g. in a legacy 7-element θ) are ignored.
"""

# AI Summary: Implements Bayesian audiovisual localisation NLL & sampler with flat-array convenience wrappers.
# Core update: replaces rectangular grid integration with separable Gauss–Hermite
# quadrature (configurable order) for more accurate and efficient expectation
# under Gaussian sensory noise. Adds helper to generate nodes/weights via NumPy.
# Adds `nll_bav_constant_gaussian_flat` (98-trial API) and
# `sample_bav_responses_flat`; the stimulus-grid cache (`_stimulus_grid`) is
# currently commented out.

from __future__ import annotations

import math
from typing import Optional, Tuple

import numpy as np
import torch

from src.utils import DataAttr


# Public API of this (deprecated) module; `BavSampler` is the batch generator
# built on top of the functional API.
__all__ = [
    "nll_bav_constant_gaussian",
    "sample_bav_responses",
    "nll_bav_constant_gaussian_flat",
    "sample_bav_responses_flat",
    "BavSampler",
]

# Normalising constant of the standard normal pdf, i.e. sqrt(2π) (NOT 2π).
SQRT_TWO_PI = math.sqrt(2.0 * math.pi)
# Historical, misleadingly named alias kept for backward compatibility.
two_pi = SQRT_TWO_PI


# ---------------------------------------------------------------------
# Canonical 7×7 stimulus grid & caching for flat-array wrappers
# ---------------------------------------------------------------------

# DEFAULT_STIM_VALUES = torch.tensor(
#     [-15.0, -10.0, -5.0, 0.0, 5.0, 10.0, 15.0], dtype=torch.float32
# )
# _GRID_CACHE: dict[
#     tuple[torch.device, torch.dtype], tuple[torch.Tensor, torch.Tensor, torch.Tensor]
# ] = {}


# def _stimulus_grid(device: torch.device, dtype: torch.dtype = torch.float32):
#     """
#     Return (S_V, S_A, response_types) tensors for the fixed 49-trial grid.

#     Results are cached per (device, dtype) pair to reduce tensor allocation
#     overhead when the wrapper is called repeatedly.
#     """
#     key = (device, dtype)
#     if key in _GRID_CACHE:
#         return _GRID_CACHE[key]

#     stim = DEFAULT_STIM_VALUES.to(device=device, dtype=dtype)
#     grid = torch.cartesian_prod(stim, stim)  # (49, 2)
#     S_V_grid, S_A_grid = grid[:, 0], grid[:, 1]

#     S_V = torch.cat([S_V_grid, S_V_grid], dim=0)  # 49 BV + 49 BA
#     S_A = torch.cat([S_A_grid, S_A_grid], dim=0)

#     response_types = torch.cat(
#         [
#             torch.zeros(49, dtype=torch.long, device=device),  # BV
#             torch.ones(49, dtype=torch.long, device=device),  # BA
#         ],
#         dim=0,
#     )

#     _GRID_CACHE[key] = (S_V, S_A, response_types)
#     return _GRID_CACHE[key]


# -----------------------------------------------------------------------------
# Helper: univariate Gaussian pdf
# -----------------------------------------------------------------------------


def _gaussian_pdf(
    x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor
) -> torch.Tensor:  # noqa: D401,E501
    return torch.exp(-0.5 * ((x - mu) / sigma) ** 2) / (two_pi * sigma)


# -----------------------------------------------------------------------------
# Helper: Gauss–Hermite nodes and weights for ∫ e^{−x²} f(x) dx
# -----------------------------------------------------------------------------


def _gauss_hermite_tensor(
    n: int, device: torch.device, dtype: torch.dtype = torch.float32
):
    """Return nodes *y* and weights *w* (both shape (n,)) as torch tensors."""
    y, w = np.polynomial.hermite.hermgauss(n)  # physicists' Hermite; exp(−x²) weight
    nodes = torch.as_tensor(y, dtype=dtype, device=device)
    weights = torch.as_tensor(w, dtype=dtype, device=device)
    return nodes, weights


# -----------------------------------------------------------------------------
# Utility: bounded parameter conversion
# -----------------------------------------------------------------------------


def _unpack_theta(theta: torch.Tensor, device: torch.device):
    """Convert unconstrained θ → bounded parameters (scalars)."""
    theta = theta.to(device)
    sigma_V, sigma_A, sigma_s, sigma_m = torch.exp(theta[0:4])
    # lapse = torch.sigmoid(theta[4])
    lapse = torch.sigmoid(torch.logit(torch.tensor(0.02)))
    p_same = torch.sigmoid(theta[4])
    # mu_p = theta[6]
    mu_p = 0.0
    return sigma_V, sigma_A, sigma_s, sigma_m, lapse, p_same, mu_p


# -----------------------------------------------------------------------------
# Negative log‑likelihood (Gauss–Hermite integration)
# -----------------------------------------------------------------------------


def nll_bav_constant_gaussian(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    R: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    *,
    gh_deg: int = 51,
    chunk_size: Optional[int] = None,
    # deprecated — retained for backward‑compatibility; ignored
    grid_step: float | None = None,
    grid_range_sd: float | None = None,
) -> torch.Tensor:
    """Summed NLL across trials for the BAV model using Gauss–Hermite quadrature.

    Parameters
    ----------
    RHO_A : float or 0-dim tensor
        Auditory rescaling: auditory measurements are centred on
        ``RHO_A * S_A``.
    theta : tensor
        Unconstrained parameter vector, decoded by ``_unpack_theta``.
    R, S_V, S_A, response_types : (T,) tensors
        Observed responses, visual / auditory stimulus locations, and report
        type per trial (0 → visual report, otherwise auditory report).
    gh_deg : int, optional
        Number of Gauss–Hermite nodes *per dimension* (default **51**). Larger
        values increase accuracy at higher computational cost.
    chunk_size : int, optional
        If provided, evaluates trials in chunks of this size to reduce GPU
        memory usage. Must be positive.
    grid_step, grid_range_sd
        Deprecated and ignored.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the summed negative log-likelihood. All
        normalising constants are computed with torch ops, so gradients with
        respect to ``theta`` flow through them.

    Raises
    ------
    ValueError
        If ``chunk_size`` is given and is not positive.
    """

    device = R.device
    (
        sigma_V,
        sigma_A,
        sigma_s,
        sigma_m,
        lapse,
        p_same,
        mu_p,
    ) = _unpack_theta(theta, device)

    # Variances / precisions (0-dim tensors)
    v_V, v_A, v_s = sigma_V**2, sigma_A**2, sigma_s**2
    iv_V, iv_A, iv_s = 1.0 / v_V, 1.0 / v_A, 1.0 / v_s

    # Normalising constants use torch.log, not math.log: math.log converts the
    # tensor to a Python float, silently detaching these terms from autograd
    # (wrong gradients w.r.t. theta) and forcing a host/device sync.
    log_2pi = math.log(2.0 * math.pi)

    # --- constants for p(x|C=1): bivariate normal, shared source --------
    a, b, d = v_V + v_s, v_s, v_A + v_s
    det_c1 = a * d - b * b
    inv00, inv11, inv01 = d / det_c1, a / det_c1, -b / det_c1
    log_norm_c1 = -0.5 * (2.0 * log_2pi + torch.log(det_c1))

    # --- constants for p(x|C=2): independent sources -------------------
    v_Vbar, v_Abar = v_V + v_s, v_A + v_s
    log_norm_c2_V = -0.5 * (log_2pi + torch.log(v_Vbar))
    log_norm_c2_A = -0.5 * (log_2pi + torch.log(v_Abar))

    weight_sum_c1 = iv_V + iv_A + iv_s
    weight_V, weight_A = iv_V + iv_s, iv_A + iv_s

    # --- Gauss–Hermite grid (shared across trials) -------------------
    # Both dimensions use the same rule, so compute it once.
    nodes_V, w_V = _gauss_hermite_tensor(gh_deg, device, R.dtype)
    nodes_A, w_A = nodes_V, w_V
    # Change of variables x = √2·σ·y turns ∫ e^{−y²} f dy into an expectation
    # under N(0, σ²); the 1/π factor is (1/√π)² for the 2-D product rule.
    rel_V = sigma_V * math.sqrt(2.0) * nodes_V  # (N_V,)
    rel_A = sigma_A * math.sqrt(2.0) * nodes_A  # (N_A,)
    weight_mat = (w_V[:, None] * w_A[None, :]) / math.pi  # (N_V, N_A)

    # --- chunked computation -----------------------------------------
    n_trials = R.numel()
    if chunk_size is None:
        # One chunk; max(..., 1) keeps range() valid for an empty R.
        chunk_size = max(n_trials, 1)
    elif chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")

    nll_total = torch.zeros((), device=device, dtype=R.dtype)
    for start in range(0, n_trials, chunk_size):
        end = min(start + chunk_size, n_trials)
        nll_total = nll_total + _chunk_nll(
            RHO_A,
            S_V[start:end],
            S_A[start:end],
            R[start:end],
            response_types[start:end],
            rel_V,
            rel_A,
            weight_mat,
            mu_p,
            sigma_m,
            lapse,
            p_same,
            iv_V,
            iv_A,
            iv_s,
            weight_sum_c1,
            weight_V,
            weight_A,
            inv00,
            inv11,
            inv01,
            log_norm_c1,
            log_norm_c2_V,
            log_norm_c2_A,
        )

    return nll_total


# -----------------------------------------------------------------------------
# Per‑chunk helper (vectorised across trials)
# -----------------------------------------------------------------------------


def _chunk_nll(
    RHO_A: torch.Tensor | float,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    R: torch.Tensor,
    rt: torch.Tensor,
    rel_V: torch.Tensor,
    rel_A: torch.Tensor,
    weight_mat: torch.Tensor,
    mu_p: torch.Tensor | float,
    sigma_m: torch.Tensor,
    lapse: torch.Tensor,
    p_same: torch.Tensor,
    iv_V: torch.Tensor | float,
    iv_A: torch.Tensor | float,
    iv_s: torch.Tensor | float,
    weight_sum_c1: torch.Tensor | float,
    weight_V: torch.Tensor | float,
    weight_A: torch.Tensor | float,
    inv00: torch.Tensor | float,
    inv11: torch.Tensor | float,
    inv01: torch.Tensor | float,
    log_norm_c1: torch.Tensor | float,
    log_norm_c2_V: torch.Tensor | float,
    log_norm_c2_A: torch.Tensor | float,
) -> torch.Tensor:
    """Compute summed NLL for *this* batch of trials using G–H quadrature.

    For each trial, the response likelihood is approximated by a weighted sum
    over a tensor-product grid of sensory measurements (x_V, x_A):
    ``rel_V`` / ``rel_A`` are the pre-scaled Gauss–Hermite offsets and
    ``weight_mat`` the matching (N_V, N_A) weights. At every grid point the
    model-averaged estimate ``s_hat`` is formed, the Gaussian motor-noise
    density of ``R`` around it is evaluated, and the weighted sum is mixed
    with a uniform lapse density of 1/90.

    Parameters
    ----------
    S_V, S_A, R, rt : (B,) tensors
        Per-trial stimuli, observed responses and report type
        (``rt == 0`` selects the visual-only estimate under C=2, anything
        else the auditory one).
    remaining arguments
        Scalars precomputed once by ``nll_bav_constant_gaussian`` (precisions,
        inverse-covariance entries, log normalisers, lapse, p_same, ...).

    Returns
    -------
    torch.Tensor
        0-dim tensor: −Σ log p(R | stimuli) over the B trials.
    """

    # Expand Gauss–Hermite offsets to trials; the two singleton axes broadcast
    # to the full (B, N_V, N_A) measurement grid in the expressions below.
    xV = S_V[:, None, None] + rel_V[None, :, None]  # (B, N_V, 1)
    xA = RHO_A * (S_A[:, None, None]) + rel_A[None, None, :]  # (B, 1, N_A)

    # Posterior P(C=1|x): compare the common-cause (bivariate normal) and
    # independent-causes likelihoods of the measurement pair.
    zV, zA = xV - mu_p, xA - mu_p
    # Quadratic form zᵀ Σ_C1⁻¹ z from the precomputed inverse entries.
    quad_c1 = inv00 * zV * zV + 2 * inv01 * zV * zA + inv11 * zA * zA
    log_p_c1 = log_norm_c1 - 0.5 * quad_c1
    # 1/iv_V + 1/iv_s is the marginal variance v_V + v_s of x_V under C=2
    # (the caller passes iv_* = 1 / variance); likewise for the auditory term.
    log_p_c2 = (
        log_norm_c2_V
        - 0.5 * (zV**2) / (1.0 / iv_V + 1.0 / iv_s)
        + log_norm_c2_A
        - 0.5 * (zA**2) / (1.0 / iv_A + 1.0 / iv_s)
    )

    log_ps = torch.log(p_same)
    # Bayes' rule in log space; logaddexp forms the log evidence stably.
    post_c1 = torch.exp(
        log_ps
        + log_p_c1
        - torch.logaddexp(log_ps + log_p_c1, torch.log1p(-p_same) + log_p_c2)
    )

    # Posterior means of s
    mu_c1 = (iv_V * xV + iv_A * xA + iv_s * mu_p) / weight_sum_c1
    mu_c2_V = (iv_V * xV + iv_s * mu_p) / weight_V
    mu_c2_A = (iv_A * xA + iv_s * mu_p) / weight_A
    # Under separate causes, report the estimate for the modality asked for.
    mu_c2 = torch.where(rt[:, None, None] == 0, mu_c2_V, mu_c2_A)

    # Model averaging: mix the two causal structures by their posterior.
    s_hat = post_c1 * mu_c1 + (1.0 - post_c1) * mu_c2

    # Response likelihood
    ll_r = _gaussian_pdf(R[:, None, None], s_hat, sigma_m)  # (B, N_V, N_A)
    prob_r = torch.sum(ll_r * weight_mat, dim=(1, 2))  # (B,)

    # Lapse mixture & NLL (uniform lapse density 1/90; 1e-12 guards log(0))
    prob_r = (1.0 - lapse) * prob_r + lapse / 90.0
    return -torch.sum(torch.log(prob_r + 1e-12))


# ==============================================================
#  BAV RESPONSE SAMPLER  (constant-noise, model-averaging)
# ==============================================================


def sample_bav_responses(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    *,
    N: int = 1,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """
    Draw synthetic motor responses for the Bayesian Audio-Visual (BAV)
    localisation model with

    * constant sensory noise (σ_V, σ_A),
    * Gaussian spatial prior 𝒩(μ, σ_s²) with μ fixed at 0,
    * **model-averaging** causal inference,
    * Gaussian motor noise (σ_m), and
    * a fixed lapse rate of 0.02 (lapsed responses ~ Uniform(−45, 45)).

    ``theta`` lives in an **unbounded space**; only its first five entries
    are read (further entries are ignored):

    ╔════╤══════════════╤════════════════════════════════════╗
    ║idx │ meaning      │   forward transform                ║
    ╟────┼──────────────┼────────────────────────────────────╢
    ║ 0  │ log σ_V      │ σ_V    = exp(θ₀)                  ║
    ║ 1  │ log σ_A      │ σ_A    = exp(θ₁)                  ║
    ║ 2  │ log σ_s      │ σ_s    = exp(θ₂)                  ║
    ║ 3  │ log σ_m      │ σ_m    = exp(θ₃)                  ║
    ║ 4  │ logit(pₛ)    │ p_same = sigmoid(θ₄)              ║
    ╚════╧══════════════╧════════════════════════════════════╝

    Parameters
    ----------
    RHO_A : float or 0-dim tensor
        Auditory rescaling: auditory measurements are centred on
        ``RHO_A * S_A``.
    theta : tensor with at least 5 elements
        Flat unconstrained parameter vector.
    S_V, S_A : (T,) tensors
        True stimulus locations (deg) for each trial.
    response_types : (T,) tensor
        0 → BV (visual report), any other value → BA (auditory report).
    N : int, optional
        Number of responses *per trial* to sample (default 1).
    generator : torch.Generator, optional
        Random generator used for every draw (default: torch's global RNG).
        It must live on the same device as ``S_V``.

    Returns
    -------
    R_sim : (N, T) tensor
        Simulated responses (deg). Sensory noise is drawn in ``S_V``'s dtype
        when it is floating point, otherwise in torch's default dtype.

    Raises
    ------
    ValueError
        If ``theta`` has fewer than five elements.
    """
    device = S_V.device
    theta = theta.to(device)
    if theta.numel() < 5:
        raise ValueError(
            f"theta must have at least 5 elements, got {theta.numel()}."
        )

    # -----------------------------------------------------------
    # 1.  Unpack θ  (0-dim tensors); lapse and μ are fixed constants
    # -----------------------------------------------------------
    sigma_V, sigma_A, sigma_s, sigma_m = torch.exp(theta[:4])
    p_same = torch.sigmoid(theta[4])
    lapse = 0.02
    mu_p = 0.0

    v_V, v_A, v_s = sigma_V**2, sigma_A**2, sigma_s**2
    iv_V, iv_A, iv_s = 1.0 / v_V, 1.0 / v_A, 1.0 / v_s

    # -----------------------------------------------------------
    # 2.  Draw sensory measurements  x_V , x_A
    # -----------------------------------------------------------
    # Draw noise in the stimulus dtype so float64 inputs get float64 noise
    # (previously the noise was always float32).
    noise_dtype = S_V.dtype if S_V.is_floating_point() else torch.get_default_dtype()
    T = S_V.numel()
    eps_V = torch.randn((N, T), generator=generator, device=device, dtype=noise_dtype)
    eps_A = torch.randn((N, T), generator=generator, device=device, dtype=noise_dtype)
    x_V = S_V.unsqueeze(0) + sigma_V * eps_V
    # Auditory rescaling
    x_A = (RHO_A * S_A).unsqueeze(0) + sigma_A * eps_A

    # -----------------------------------------------------------
    # 3.  Compute posterior Pr(C=1 | x_V, x_A)  (vectorised)
    # -----------------------------------------------------------
    # torch.log keeps everything on-device (no host sync), matching the NLL.
    log_2pi = math.log(2.0 * math.pi)

    # Constants for p(x | C=1)  ~ N([μ, μ],  Σ_C1 )
    a, b, d = v_V + v_s, v_s, v_A + v_s  # Σ_C1 entries
    det_c1 = a * d - b * b
    inv00, inv11, inv01 = d / det_c1, a / det_c1, -b / det_c1
    log_norm_c1 = -0.5 * (2.0 * log_2pi + torch.log(det_c1))

    # Constants for independent-cause likelihood
    v_V_bar, v_A_bar = v_V + v_s, v_A + v_s
    log_norm_c2_V = -0.5 * (log_2pi + torch.log(v_V_bar))
    log_norm_c2_A = -0.5 * (log_2pi + torch.log(v_A_bar))

    zV = x_V - mu_p
    zA = x_A - mu_p

    quad_c1 = inv00 * zV**2 + 2 * inv01 * zV * zA + inv11 * zA**2
    log_p_c1 = log_norm_c1 - 0.5 * quad_c1
    log_p_c2 = (
        log_norm_c2_V - 0.5 * zV**2 / v_V_bar + log_norm_c2_A - 0.5 * zA**2 / v_A_bar
    )

    #   P(C=1|x)   (shape: N × T)
    logit_pc1 = torch.log(p_same) + log_p_c1 - (torch.log1p(-p_same) + log_p_c2)
    post_c1 = torch.sigmoid(logit_pc1)

    # -----------------------------------------------------------
    # 4.  Posterior means  μ̂_C1  and  μ̂_C2
    # -----------------------------------------------------------
    weight_sum_c1 = iv_V + iv_A + iv_s
    weight_V = iv_V + iv_s
    weight_A = iv_A + iv_s

    mu_c1 = (iv_V * x_V + iv_A * x_A + iv_s * mu_p) / weight_sum_c1
    mu_c2_V = (iv_V * x_V + iv_s * mu_p) / weight_V
    mu_c2_A = (iv_A * x_A + iv_s * mu_p) / weight_A

    # Choose μ̂_C2 according to requested report (BV or BA)
    rt = response_types.unsqueeze(0)  # shape (1, T)
    mu_c2 = torch.where(rt == 0, mu_c2_V, mu_c2_A)

    # Model-averaged estimate
    s_hat = post_c1 * mu_c1 + (1.0 - post_c1) * mu_c2

    # -----------------------------------------------------------
    # 5.  Add motor noise  &  lapses
    # -----------------------------------------------------------
    R_noisy = s_hat + sigma_m * torch.randn(
        s_hat.shape, generator=generator, device=device, dtype=s_hat.dtype
    )

    # A `lapse` fraction of responses is replaced by Uniform(−45, 45) draws.
    lapse_mask = (
        torch.rand(R_noisy.shape, generator=generator, device=device, dtype=R_noisy.dtype)
        < lapse
    )
    R_uniform = -45.0 + 90.0 * torch.rand(
        R_noisy.shape, generator=generator, device=device, dtype=R_noisy.dtype
    )
    return torch.where(lapse_mask, R_uniform, R_noisy)


# ---------------------------------------------------------------------
# S_V S_A rt sampler
# ---------------------------------------------------------------------


def _sample_inputs(
    device: torch.device,
    num_points: int = 98,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate (S_V, S_A, response_types) on the given device/dtype.

    1) S_A ~ Uniform over {-15, -10, -5, 0, 5, 10, 15}  (discrete)
    2) With 50% probability set S_V = S_A; otherwise S_V ~ Uniform(-20, 20) (continuous)
    3) response_types: 50% BV (0) and 50% BA (1).
    """
    if num_points <= 0:
        raise ValueError("num_points must be positive.")

    # 1) Discrete uniform S_A over the 7 canonical positions
    sa_values = torch.tensor(
        [-15.0, -10.0, -5.0, 0.0, 5.0, 10.0, 15.0],
        device=device,
        dtype=dtype,
    )
    idx = torch.randint(
        low=0, high=sa_values.numel(), size=(num_points,), device=device
    )
    S_A = sa_values[idx]

    # 2) S_V equals S_A w.p. 0.5, else uniform(-20, 20)
    same_mask = torch.rand(num_points, device=device) < 0.5
    S_V_alt = -20.0 + 40.0 * torch.rand(num_points, device=device, dtype=dtype)
    S_V = torch.where(same_mask, S_A, S_V_alt).to(dtype)

    # 3) Response types: i.i.d. Bernoulli(0.5)
    response_types = (torch.rand(num_points, device=device) < 0.5).to(torch.long)

    return S_V.to(dtype), S_A.to(dtype), response_types


# ---------------------------------------------------------------------
# Flat-array convenience wrappers
# ---------------------------------------------------------------------


def nll_bav_constant_gaussian_flat(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    R_flat: torch.Tensor,
    *,
    gh_deg: int = 51,
    chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Wrapper around :func:`nll_bav_constant_gaussian` for the canonical
    98-trial layout.

    ``R_flat`` must contain exactly 98 responses (it is flattened first):
    49 visual-report (BV) trials followed by 49 auditory-report (BA) trials.
    Within each half, trials follow ``torch.cartesian_prod(stim, stim)`` over
    ``stim = [-15, -10, -5, 0, 5, 10, 15]``, with ``S_V`` as the first
    (slow-varying) coordinate and ``S_A`` as the second.

    Raises
    ------
    ValueError
        If ``R_flat`` does not have exactly 98 elements.
    """
    if R_flat.numel() != 98:
        raise ValueError("R_flat must contain exactly 98 elements (49 BV + 49 BA).")

    device, dtype = R_flat.device, R_flat.dtype
    R = R_flat.reshape(-1)

    # Rebuild the deterministic canonical stimulus grid. The previous code
    # called `_sample_inputs(device, dtype)`: that passed `dtype` as
    # `num_points` (TypeError) and would otherwise have scored the responses
    # against *random* stimuli unrelated to the observed trials.
    stim = torch.tensor(
        [-15.0, -10.0, -5.0, 0.0, 5.0, 10.0, 15.0], device=device, dtype=dtype
    )
    grid = torch.cartesian_prod(stim, stim)  # (49, 2): columns (S_V, S_A)
    S_V = grid[:, 0].repeat(2)  # 49 BV + 49 BA
    S_A = grid[:, 1].repeat(2)
    rt = torch.tensor([0, 1], device=device).repeat_interleave(49)  # 0=BV, 1=BA

    return nll_bav_constant_gaussian(
        RHO_A, theta, R, S_V, S_A, rt, gh_deg=gh_deg, chunk_size=chunk_size
    )


def sample_bav_responses_flat(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    *,
    N: int = 1,
    num_points: int = 98,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Draw random trial conditions and simulated responses for them.

    Trial conditions come from ``_sample_inputs`` (``num_points`` trials);
    responses come from :func:`sample_bav_responses`.

    Parameters
    ----------
    device : torch.device or str, optional
        Target device (default: CPU). Strings such as ``"cuda"`` are accepted.

    Returns
    -------
    (S_V, S_A, response_types, R_sim)
        The first three have shape ``(num_points,)``; ``R_sim`` has shape
        ``(N, num_points)``.
    """
    device = torch.device("cpu") if device is None else torch.device(device)

    S_V, S_A, rt = _sample_inputs(device, num_points, dtype)
    R_sim = sample_bav_responses(RHO_A, theta.to(device), S_V, S_A, rt, N=N)
    return S_V, S_A, rt, R_sim


class BavSampler:
    """Generate context / buffer / target batches of simulated BAV responses.

    Every batch element gets its own θ (drawn from a standard normal in the
    unconstrained space), its own random trial conditions, and one simulated
    response per trial via ``sample_bav_responses_flat``.

    Attributes
    ----------
    THETA_DIM : int
        Length of each sampled θ (4 log-scales + logit p_same, as read by
        ``sample_bav_responses``).
    RHO_A : float
        Auditory rescaling factor forwarded to the response sampler.
    device, dtype
        Where and in which floating dtype stimuli and responses are created.
    """

    THETA_DIM = 5

    def __init__(
        self,
        RHO_A: float = 4.0 / 3.0,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.RHO_A = RHO_A
        self.device = device
        self.dtype = dtype

    def _sample_theta(self, batch_size):
        """Draw ``batch_size`` θ vectors i.i.d. from N(0, I) (float32, CPU)."""
        loc = torch.zeros(self.THETA_DIM, dtype=torch.float32)
        scale = torch.ones(self.THETA_DIM, dtype=torch.float32)
        dist = torch.distributions.Normal(loc, scale)
        theta_samples = dist.sample((batch_size,))  # [B, THETA_DIM]
        return theta_samples

    def _sample_num_context(self, context_range):
        """Sample from [low, high] if len=2, else uniformly from the list.

        Raises
        ------
        ValueError
            If ``low > high`` or ``context_range`` has fewer than 2 entries.
        """
        n = len(context_range)
        if n == 2:
            low, high = context_range
            if low > high:
                raise ValueError("Invalid `context_range`: low > high.")
            return torch.randint(low, high + 1, (1,)).item()

        if n > 2:
            choices = torch.as_tensor(context_range)
            idx = torch.randint(0, choices.numel(), (1,))
            return choices[idx].item()
        raise ValueError("`context_range` must have length 2 or > 2.")

    def generate_batch(
        self,
        batch_size: int,
        num_context: Optional[int] = None,
        num_buffer: int = 50,
        num_target: int = 50,
        context_range: Tuple[int, int] = (3, 47),
    ) -> DataAttr:
        """Simulate ``batch_size`` datasets and split them into three sets.

        The points of each dataset are randomly permuted (same permutation for
        all batch elements) and split into ``num_context`` context,
        ``num_buffer`` buffer and ``num_target`` target points. If
        ``num_context`` is None it is drawn via ``_sample_num_context``.

        Inputs ``x`` have last-axis features ``(response_type, S_A, S_V)``;
        outputs ``y`` hold the simulated response with a trailing axis of 1.
        When ``num_context == 0`` a single placeholder context point
        ``x = (2, 0, 0)``, ``y = 0`` is returned per batch element.
        """
        if num_context is None:
            num_context = self._sample_num_context(context_range)

        num_total = num_context + num_buffer + num_target

        thetas = self._sample_theta(batch_size)

        S_Vs, S_As, rts, Y = [], [], [], []
        for theta in thetas:
            S_V, S_A, rt, resp = sample_bav_responses_flat(
                self.RHO_A,
                theta,
                num_points=num_total,
                device=self.device,
                dtype=self.dtype,
            )
            S_Vs.append(S_V)
            S_As.append(S_A)
            rts.append(rt)
            Y.append(resp.squeeze(0))  # N=1 → (num_total,)

        S_Vs, S_As, rts = (torch.stack(t) for t in (S_Vs, S_As, rts))

        x = torch.stack((rts, S_As, S_Vs), dim=-1)  # (B, num_total, 3)
        y = torch.stack(Y).unsqueeze(-1)  # (B, num_total, 1)

        perm = torch.randperm(num_total, device=self.device)
        ctx_idx = perm[:num_context]
        buf_idx = perm[num_context : num_context + num_buffer]
        tar_idx = perm[num_context + num_buffer :]

        if num_context == 0:
            # Placeholder context for "no context", created on the same
            # device/dtype as x and y (it used to be allocated on the default
            # CPU device/dtype, mismatching xt/yt on GPU or non-float32 runs).
            xc = torch.tensor(
                [2.0, 0.0, 0.0], device=x.device, dtype=x.dtype
            ).repeat(batch_size, 1, 1)
            yc = torch.zeros(batch_size, 1, 1, device=y.device, dtype=y.dtype)
        else:
            xc = x[:, ctx_idx]
            yc = y[:, ctx_idx]

        return DataAttr(
            xc=xc,
            yc=yc,
            xb=x[:, buf_idx],
            yb=y[:, buf_idx],
            xt=x[:, tar_idx],
            yt=y[:, tar_idx],
        )

    def generate_test_batch(
        self,
        batch_size: int,
        num_target: int = 400,
    ) -> DataAttr:
        """Batch with no context and no buffer: ``num_target`` targets only."""
        return self.generate_batch(
            batch_size,
            num_context=0,
            num_buffer=0,
            num_target=num_target,
        )

    def generate_conditioned_test_batch(
            self,
            batch_size: int,
            num_context: int = 100,
            num_target: int = 300,
    ) -> DataAttr:
        """Batch with ``num_context`` context points, no buffer, and targets."""
        return self.generate_batch(
            batch_size,
            num_context=num_context,
            num_buffer=0,
            num_target=num_target,
        )


def _test_sample_bav_responses_flat(num_points, RHO_A: float = 4.0 / 3.0):
    """
    Minimal smoke test for sample_bav_responses_flat:
      - sets a plausible theta
      - samples N=4 response sets on CUDA if available, else CPU
      - asserts shape/dtype/device/finite values
      - prints a small preview

    ``RHO_A`` is forwarded to the sampler; the previous version omitted this
    required argument, so the sampler call raised a TypeError.
    """

    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # Build theta (7,) in the unconstrained space
    # Indices used by the current sampler:
    #   0..3: log sigmas (V, A, s, m)
    #   4   : logit(p_same)  (lapse and μ are fixed internally)
    #   5, 6: ignored; kept to match the legacy 7-entry layout
    theta = torch.empty(7, device=device, dtype=dtype)
    theta[0] = torch.log(torch.tensor(3.0, device=device, dtype=dtype))  # σ_V
    theta[1] = torch.log(torch.tensor(5.0, device=device, dtype=dtype))  # σ_A
    theta[2] = torch.log(torch.tensor(8.0, device=device, dtype=dtype))  # σ_s
    theta[3] = torch.log(torch.tensor(2.0, device=device, dtype=dtype))  # σ_m
    theta[4] = torch.logit(
        torch.tensor(0.7, device=device, dtype=dtype)
    )  # p_same ≈ 0.7
    theta[5] = torch.tensor(0.0, device=device, dtype=dtype)  # (unused)
    theta[6] = torch.tensor(0.0, device=device, dtype=dtype)  # (unused)

    # Sample N=4 draws for `num_points` random trials
    N = 4
    _, _, _, R = sample_bav_responses_flat(
        RHO_A, theta, num_points=num_points, N=N, device=device, dtype=dtype
    )

    # Basic checks
    assert R.shape == (N, num_points), f"Unexpected shape: {R.shape}"
    assert R.dtype == dtype, f"Unexpected dtype: {R.dtype}"
    # Compare device *types*: tensors report e.g. "cuda:0", and
    # torch.device("cuda") != torch.device("cuda:0").
    assert R.device.type == device.type, f"Unexpected device: {R.device}"
    assert torch.isfinite(R).all(), "Non-finite values in sampled responses"

    # Quick peek
    print("OK ✓  sample_bav_responses_flat")
    print("shape:", tuple(R.shape), "| dtype:", R.dtype, "| device:", R.device)
    print("first row, first 8 values:", R[0, :8].tolist())

    return R


if __name__ == "__main__":
    # Smoke-test the flat sampler on 400 random trials, then draw one small
    # batch (5 datasets: 3 context, 2 buffer, 4 target points each).
    _test_sample_bav_responses_flat(400)

    demo_sampler = BavSampler()
    data = demo_sampler.generate_batch(
        batch_size=5, num_context=3, num_buffer=2, num_target=4
    )
