# AI Summary: Multi-level BAV model (σV_low/med/high) with per-trial visual noise levels.
# Provides NLL via Gauss–Hermite quadrature, sampling, random-trial wrappers returning
# V_levels, and BavSampler with 4-feature inputs. Adds gradient-based MLE fitting via
# fit_bav_mle(...) and fit_bav_mle_from_xy(...). Backward-compat: 5-D θ uses one σV.
# Adds paired-batch generator to simulate on identical trials with RHO_A ∈ {4/3, 1}.

"""
Multi-level Bayesian Audio-Visual (BAV) localization model with causal inference.

Overview
--------
This module implements a BAV model where the *visual* sensory noise can take one of three
levels per trial (low/med/high), indicated by an integer tensor `V_levels ∈ {0,1,2}`.
The rest of the model follows the standard BAV formulation:

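- Sensory measurements: x_V ~ N(S_V, σ_V^2) and x_A = ρ·(S_A + ε_A) with ε_A ~ N(0, σ_A^2),
  i.e. x_A ~ N(ρ·S_A, ρ^2·σ_A^2); the observer's inference still assumes the unscaled σ_A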
- Gaussian spatial prior over s:   s ~ N(μ, σ_s^2), with μ fixed to 0
- Causal inference: model-averaging between C=1 (common source) and C=2 (independent)
- Motor noise on the reported estimate:  R ~ N(ŝ, σ_m^2)
- Lapse process: with small prob. (0.02) emit a uniform response on [-45°, 45°]

Key features
------------
1) nll_bav_constant_gaussian(...) evaluates a summed negative log-likelihood using
   separable 2D Gauss–Hermite quadrature over (x_V, x_A). For efficiency, trials are
   split by `V_levels` and each subset is integrated with its corresponding σ_V.
2) sample_bav_responses(...) draws synthetic responses by sampling measurements,
   computing model-averaged posterior estimates, then adding motor noise and lapses.
3) Random-trial wrappers sample (S_V, S_A, response_types, V_levels) for convenience.
4) BavSampler assembles batched (context, buffer, target) splits with input features
   (response_type, V_level, S_A, S_V). The parameter vector θ is 7-D in the new layout.
5) NEW: fit_bav_mle(...) and fit_bav_mle_from_xy(...) perform gradient-based MLE for θ.

Parameter vector θ (unconstrained → transformed)
------------------------------------------------
New (len=7):
  [ log σV_low, log σV_med, log σV_high, log σA, log σs, log σm, logit(p_same) ]
Compat (len=5): all three σV levels use the single σV from indices [0], and the rest map
  as in the constant-noise model. In both cases, lapse=0.02 and μ=0 are fixed internally.

Default priors for θ sampling (BavSampler._sample_theta)
--------------------------------------------------------
All priors are **truncated Gaussians** in the unconstrained space. Each one is truncated at
±2 standard deviations around its (possibly hierarchical) mean and resampled until it falls
within bounds. Below, N(m, s) denotes mean m and standard deviation s:
  - log σV_low    ~ N(0.00, 1.5)                  truncated to [−3.00, +3.00]
  - log σV_med    ~ N(log σV_low + 1.00, 1.0)     truncated to mean ± 2.0
  - log σV_high   ~ N(log σV_med + 0.75, 0.5)     truncated to mean ± 1.0
  - log σA        ~ N(1.75, 0.5)                  truncated to [+0.75, +2.75]
  - log σs        ~ N(2.50, 1.0)                  truncated to [+0.50, +4.50]
  - log σm        ~ N(0.00, 0.5)                  truncated to [−1.00, +1.00]
  - logit(p_same) ~ N(1.50, 1.5)                  truncated to [−1.50, +4.50]

Rationale: Truncation avoids extreme values that can destabilize GH integration or produce
unreasonable noise scales, while keeping priors weakly informative.

Public APIs (signatures)
------------------------
nll_bav_constant_gaussian(RHO_A, theta, R, S_V, S_A, response_types, V_levels, *, gh_deg=51, chunk_size=None)
sample_bav_responses(RHO_A, theta, S_V, S_A, response_types, V_levels, *, N=1)
sample_bav_responses_random_trials(RHO_A, theta, *, N=1, num_points=98, device=None, dtype=torch.float32)
fit_bav_mle(RHO_A, R, S_V, S_A, response_types, V_levels, *, init_theta=None, theta_dim=7, lr=0.05, steps=500, gh_deg=31, chunk_size=None, weight_decay=0.0, patience=50, verbose=False, return_history=False)
fit_bav_mle_from_xy(RHO_A, x, y, **kwargs)
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import numpy as np
import torch
from dataclasses import dataclass
#from src.utils import DataAttr

@dataclass
class DataAttr:
    """Bundle of the six neural-process tensors: (x, y) for context, buffer and target.

    Every field is optional and defaults to ``None``.
    """

    # Context split (inputs, outputs)
    xc: Optional[torch.Tensor] = None
    yc: Optional[torch.Tensor] = None
    # Buffer split (inputs, outputs)
    xb: Optional[torch.Tensor] = None
    yb: Optional[torch.Tensor] = None
    # Target split (inputs, outputs)
    xt: Optional[torch.Tensor] = None
    yt: Optional[torch.Tensor] = None

    def to(self, device):
        """Return a new ``DataAttr`` whose non-``None`` tensors are moved to ``device``.

        ``None`` fields stay ``None``; ``self`` is left untouched.
        """

        def _moved(tensor):
            return None if tensor is None else tensor.to(device)

        return DataAttr(
            xc=_moved(self.xc),
            yc=_moved(self.yc),
            xb=_moved(self.xb),
            yb=_moved(self.yb),
            xt=_moved(self.xt),
            yt=_moved(self.yt),
        )

# Names exported by ``from <this module> import *``.
__all__ = [
    "nll_bav_constant_gaussian",
    "sample_bav_responses",
    "sample_bav_responses_random_trials",
    "BavSampler",
    "fit_bav_mle",
    "fit_bav_mle_from_xy",
]

# Normalising constant √(2π) of the univariate Gaussian pdf (math.tau is exactly 2π).
SQRT_TWO_PI = math.sqrt(math.tau)


def _gaussian_pdf(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Evaluate the normal density N(x; mu, sigma^2) elementwise.

    Args:
        x: evaluation points.
        mu: mean, broadcastable against ``x``.
        sigma: standard deviation, broadcastable against ``x``.

    Returns:
        Densities with the broadcast shape of the inputs.
    """
    standardized = (x - mu) / sigma
    normaliser = SQRT_TWO_PI * sigma
    return torch.exp(-0.5 * standardized ** 2) / normaliser


def _gauss_hermite_tensor(n: int, device: torch.device, dtype: torch.dtype):
    """Physicists' Gauss–Hermite rule, returned as torch tensors.

    The rule approximates ∫ e^{−y²} f(y) dy ≈ Σ_i w_i f(y_i). To integrate against a
    Gaussian N(mean, σ²), map nodes with x = mean + √2·σ·y and scale weights by 1/√π
    per dimension.

    Args:
        n: number of nodes (more nodes = higher accuracy, higher cost).
        device, dtype: placement of the returned tensors.

    Returns:
        (nodes, weights), each a 1-D tensor of length ``n``.
    """

    def _as_torch(values):
        return torch.as_tensor(values, dtype=dtype, device=device)

    nodes, weights = np.polynomial.hermite.hermgauss(n)
    return _as_torch(nodes), _as_torch(weights)


def _unpack_theta(theta: torch.Tensor, device: torch.device, dtype: torch.dtype):
    """Map unconstrained θ → model parameters (supports 7-D new or 5-D compat layout).

    New layout (len=7):
        [ log σV_low, log σV_med, log σV_high, log σA, log σs, log σm, logit(p_same) ]
    Compat layout (len=5):
        [ log σV, log σA, log σs, log σm, logit(p_same) ]
        → same σV is used for all three visual noise levels.

    Fixed constants:
        lapse = 0.02, μ = 0.0

    Returns:
        (σV_low, σV_med, σV_high, σA, σs, σm, lapse, p_same, μ)
        All returned as torch tensors on (device, dtype).
    """
    th = theta.to(device=device, dtype=dtype).flatten()
    if th.numel() >= 7:
        sigV_low, sigV_med, sigV_high = torch.exp(th[0]), torch.exp(th[1]), torch.exp(th[2])
        sigA, sigS, sigM = torch.exp(th[3]), torch.exp(th[4]), torch.exp(th[5])
        p_same = torch.sigmoid(th[6])
    else:
        # Backward compatibility (len=5): reuse single σV for all levels
        sigV = torch.exp(th[0])
        sigV_low = sigV_med = sigV_high = sigV
        sigA, sigS, sigM = torch.exp(th[1]), torch.exp(th[2]), torch.exp(th[3])
        p_same = torch.sigmoid(th[4])

    lapse = torch.tensor(0.02, device=device, dtype=dtype)  # fixed lapse prob
    mu_p = torch.tensor(0.0, device=device, dtype=dtype)    # fixed prior mean
    return sigV_low, sigV_med, sigV_high, sigA, sigS, sigM, lapse, p_same, mu_p


def _chunk_nll(
    RHO_A: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    R: torch.Tensor,
    rt: torch.Tensor,
    rel_V: torch.Tensor,
    rel_A: torch.Tensor,
    weight_mat: torch.Tensor,
    mu_p: torch.Tensor,
    sigma_m: torch.Tensor,
    lapse: torch.Tensor,
    p_same: torch.Tensor,
    iv_V: torch.Tensor,
    iv_A: torch.Tensor,
    iv_s: torch.Tensor,
    weight_sum_c1: torch.Tensor,
    weight_V: torch.Tensor,
    weight_A: torch.Tensor,
    inv00: torch.Tensor,
    inv11: torch.Tensor,
    inv01: torch.Tensor,
    log_norm_c1: torch.Tensor,
    log_norm_c2_V: torch.Tensor,
    log_norm_c2_A: torch.Tensor,
) -> torch.Tensor:
    """Summed NLL for one chunk of trials that share a single σ_V level.

    Called by ``nll_bav_constant_gaussian`` once per (visual-noise level, chunk). All
    σ-dependent constants are precomputed by the caller, so this function only
    evaluates the integrand on the Gauss–Hermite grid and reduces it:

      1) Build GH grids: x_V = S_V + rel_V and x_A = ρ·(S_A + rel_A)
      2) Evaluate log p(x|C=1), log p(x|C=2) and the posterior Pr(C=1|x)
      3) Compute posterior means μ̂_{C=1}, μ̂_{C=2} (C=2 uses the reported modality)
         and the model-averaged estimate ŝ
      4) Evaluate the motor-noise likelihood N(R; ŝ, σ_m²) and integrate over x with
         the GH weights
      5) Mix with a uniform lapse density and sum −log likelihoods over trials

    Args:
        RHO_A: auditory rescaling ρ applied to the auditory measurement grid.
        S_V, S_A, R, rt: (B,) stimuli, responses and response types. rt == 0 selects
            the visual C=2 estimate; any other value selects the auditory one.
        rel_V, rel_A: (N_V,), (N_A,) GH offsets (√2·σ·nodes) for each modality.
        weight_mat: (N_V, N_A) product GH weights. Expected to already include the
            1/π normalisation (as built in nll_bav_constant_gaussian).
        mu_p: prior mean μ.
        sigma_m: motor-noise standard deviation σ_m.
        lapse: lapse probability.
        p_same: prior probability of a common cause, Pr(C=1).
        iv_V, iv_A, iv_s: precisions 1/σ_V², 1/σ_A², 1/σ_s².
        weight_sum_c1, weight_V, weight_A: precision sums used as denominators of
            the posterior means (iv_V+iv_A+iv_s, iv_V+iv_s, iv_A+iv_s in the caller).
        inv00, inv11, inv01: entries of the inverse C=1 covariance of (x_V, x_A).
        log_norm_c1, log_norm_c2_V, log_norm_c2_A: Gaussian log-normalisers of the
            C=1 bivariate density and the two C=2 univariate marginals.

    Returns:
        Scalar tensor: −Σ log p(R_i) over the B trials of this chunk.
    """
    # 1) Expand GH grids for this batch of trials.
    #    xV: (B, N_V, 1), xA: (B, 1, N_A); later ops broadcast to (B, N_V, N_A).
    #    ρ multiplies the noisy auditory measurement (offset added before scaling).
    xV = S_V[:, None, None] + rel_V[None, :, None]
    xA = RHO_A * (S_A[:, None, None] + rel_A[None, None, :])

    # 2) Posterior over causal structure P(C=1|x): compute in log domain then exponentiate
    zV, zA = xV - mu_p, xA - mu_p
    # Quadratic form z^T Σ^{-1} z of the C=1 bivariate Gaussian over (x_V, x_A)
    quad_c1 = inv00 * zV * zV + 2 * inv01 * zV * zA + inv11 * zA * zA
    log_p_c1 = log_norm_c1 - 0.5 * quad_c1

    # C=2 marginal variances σ_V²+σ_s² and σ_A²+σ_s², recovered from the precisions
    vVbar = 1.0 / iv_V + 1.0 / iv_s
    vAbar = 1.0 / iv_A + 1.0 / iv_s
    log_p_c2 = log_norm_c2_V - 0.5 * (zV * zV) / vVbar + log_norm_c2_A - 0.5 * (zA * zA) / vAbar

    # Pr(C=1|x) = p·p1 / (p·p1 + (1−p)·p2), normalised with logaddexp for stability
    post_c1 = torch.exp(
        torch.log(p_same) + log_p_c1
        - torch.logaddexp(torch.log(p_same) + log_p_c1, torch.log1p(-p_same) + log_p_c2)
    )

    # 3) Posterior means under C=1 and C=2 (report-dependent for C=2)
    mu_c1 = (iv_V * xV + iv_A * xA + iv_s * mu_p) / weight_sum_c1
    mu_c2_V = (iv_V * xV + iv_s * mu_p) / weight_V
    mu_c2_A = (iv_A * xA + iv_s * mu_p) / weight_A
    # rt == 0 → visual report; anything else → auditory report
    mu_c2 = torch.where(rt[:, None, None] == 0, mu_c2_V, mu_c2_A)

    # Model-averaged estimate ŝ
    s_hat = post_c1 * mu_c1 + (1.0 - post_c1) * mu_c2

    # 4) Likelihood of response given estimate, integrated over x with GH weights
    #    ll_r: (B, N_V, N_A); summing over the grid axes leaves (B,)
    ll_r = _gaussian_pdf(R[:, None, None], s_hat, sigma_m)
    prob_r = torch.sum(ll_r * weight_mat, dim=(1, 2))

    # 5) Lapse mixture and final NLL for the chunk.
    #    lapse / 90.0 is a uniform density over a 90°-wide range (the sampler draws
    #    lapses on [-45, 45]); it is applied regardless of whether R lies in that range.
    prob_r = (1.0 - lapse) * prob_r + lapse / 90.0
    # 1e-12 floor keeps the log finite if a probability underflows to zero
    return -torch.sum(torch.log(prob_r + 1e-12))


def nll_bav_constant_gaussian(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    R: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    V_levels: torch.Tensor,
    *,
    gh_deg: int = 51,
    chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """Summed NLL with three σ_V levels selected by V_levels ∈ {0,1,2}.

    Args:
        RHO_A: auditory rescaling factor ρ (float/tensor; broadcasts).
        theta: parameters; supports len 7 (3×σV + σA, σs, σm, p_same) or len 5 (compat).
        R: (T,) observed responses (deg).
        S_V, S_A: (T,) true stimulus locations (deg).
        response_types: (T,) 0=BV (visual report), 1=BA (auditory report).
        V_levels: (T,) values in {0,1,2} indicating low/med/high visual noise.
        gh_deg: Gauss–Hermite nodes per dimension (51 is accurate but costlier).
        chunk_size: optional memory bound; trials are processed in pieces of at most
            this many trials (per visual-noise level).

    Returns:
        Scalar tensor: negative log-likelihood summed across trials.

    Raises:
        ValueError: if ``chunk_size`` is not None and < 1. A negative step would
            otherwise silently skip every trial and return 0.
        ValueError: if ``V_levels`` contains a value other than 0, 1 or 2. Such
            trials would otherwise be silently left out of the sum.

    Implementation notes:
        - One GH rule is computed and shared by both dimensions. rel_A depends only
          on σ_A; rel_V depends on the σ_V of the current level.
        - Level-independent constants (σ_A and σ_s terms) are computed once.
        - Trials are partitioned by ``V_levels``. Each subset is gathered once, then
          sliced per chunk, and uses its own σ_V-level constants.
    """
    if chunk_size is not None and chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer or None.")

    device, dtype = R.device, R.dtype
    sigV_low, sigV_med, sigV_high, sigA, sigS, sigM, lapse, p_same, mu_p = _unpack_theta(
        theta, device, dtype
    )
    RHO_A = torch.as_tensor(RHO_A, device=device, dtype=dtype)

    valid_level = (V_levels == 0) | (V_levels == 1) | (V_levels == 2)
    if not bool(valid_level.all()):
        raise ValueError("V_levels must only contain 0 (low), 1 (med) or 2 (high).")

    # The same GH rule serves both dimensions; 1/π = (1/√π)² normalises the 2-D product.
    nodes, w = _gauss_hermite_tensor(gh_deg, device, dtype)
    weight_mat = (w[:, None] * w[None, :]) / math.pi
    sqrt2 = math.sqrt(2.0)
    rel_A = sigA * sqrt2 * nodes

    # Level-independent variances, precisions and normalisers
    vA = sigA * sigA
    vS = sigS * sigS
    iv_A = 1.0 / vA
    iv_s = 1.0 / vS
    weight_A = iv_A + iv_s
    # d = σ_A² + σ_s²: both the (1,1) entry of the C=1 covariance and the C=2
    # auditory marginal variance.
    d = vA + vS
    log_norm_c2_A = -0.5 * (math.log(2 * math.pi) + torch.log(d))

    nll_total = torch.tensor(0.0, device=device, dtype=dtype)

    for lvl, sigV in enumerate((sigV_low, sigV_med, sigV_high)):
        # Select the subset of trials that belong to this visual noise level
        idx = torch.nonzero(V_levels == lvl, as_tuple=False).squeeze(-1)
        n_lvl = idx.numel()
        if n_lvl == 0:
            continue

        vV = sigV * sigV
        iv_V = 1.0 / vV

        # C=1 (common source) covariance of x = (xV, xA):
        # Σ = [[vV+vS, vS], [vS, vA+vS]]; closed-form 2×2 inverse entries below.
        a = vV + vS  # also the C=2 visual marginal variance
        det = a * d - vS * vS
        inv00 = d / det
        inv11 = a / det
        inv01 = -vS / det
        log_norm_c1 = -0.5 * (math.log((2 * math.pi) ** 2) + torch.log(det))
        log_norm_c2_V = -0.5 * (math.log(2 * math.pi) + torch.log(a))

        weight_sum_c1 = iv_V + iv_A + iv_s
        weight_V = iv_V + iv_s

        # rel_V depends on σ_V of this level
        rel_V = sigV * sqrt2 * nodes

        # Gather this level's trials once; chunks below are cheap slices of it.
        sub_SV, sub_SA = S_V[idx], S_A[idx]
        sub_R, sub_rt = R[idx], response_types[idx]
        step = n_lvl if chunk_size is None else chunk_size
        for start in range(0, n_lvl, step):
            sl = slice(start, start + step)  # slicing past the end is safe
            nll_total = nll_total + _chunk_nll(
                RHO_A,
                sub_SV[sl],
                sub_SA[sl],
                sub_R[sl],
                sub_rt[sl],
                rel_V,
                rel_A,
                weight_mat,
                mu_p,
                sigM,
                lapse,
                p_same,
                iv_V,
                iv_A,
                iv_s,
                weight_sum_c1,
                weight_V,
                weight_A,
                inv00,
                inv11,
                inv01,
                log_norm_c1,
                log_norm_c2_V,
                log_norm_c2_A,
            )

    return nll_total


def sample_bav_responses(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    V_levels: torch.Tensor,
    *,
    N: int = 1,
) -> torch.Tensor:
    """Draw synthetic responses with per-trial visual noise selected by V_levels.

    Generative process (per trial and per draw):
        x_V = S_V + σ_V[level]·ε_V,    x_A = ρ·(S_A + σ_A·ε_A)
    The observer infers with the unscaled σ_A (ρ only enters through x_A). It forms
    Pr(C=1|x) and averages the C=1 and C=2 posterior means. Motor noise σ_m and
    lapses (uniform on [-45, 45] with the fixed lapse probability) are then applied.

    Args:
        RHO_A: auditory rescaling factor ρ (float/tensor).
        theta: parameter vector (7-D new or 5-D compat); lapse, μ fixed internally.
        S_V, S_A: (T,) stimuli (deg).
        response_types: (T,) 0=BV (visual), 1=BA (auditory).
        V_levels: (T,) 0/1/2 select σV_low/med/high for each trial. Integer or
            integral-valued float tensors (e.g. the V_level feature column) are accepted.
        N: number of responses per trial to sample.

    Returns:
        (N, T) sampled responses.

    Raises:
        ValueError: if V_levels contains a value other than 0, 1 or 2. Negative
            indices would otherwise silently wrap around to σV_high.
    """
    device, dtype = S_V.device, S_V.dtype
    (sigV_low, sigV_med, sigV_high, sigA, sigS, sigM, lapse, p_same, mu_p) = _unpack_theta(
        theta, device, dtype
    )
    RHO_A = torch.as_tensor(RHO_A, device=device, dtype=dtype)

    levels = torch.as_tensor(V_levels, device=device)
    if not bool(((levels == 0) | (levels == 1) | (levels == 2)).all()):
        raise ValueError("V_levels must only contain 0 (low), 1 (med) or 2 (high).")

    T = S_V.numel()
    # Build per-trial visual noise σ_V by indexing levels (indexing needs an int dtype)
    sigV_levels = torch.stack([sigV_low, sigV_med, sigV_high])
    sigV_trial = sigV_levels[levels.long()]  # (T,)

    # 1) Sensory measurements x_V, x_A with appropriate noise levels.
    # Generative auditory rescaling: scale the noisy observation itself, but
    # KEEP inference covariance unscaled (σ_A) in all subsequent computations.
    x_V = S_V.unsqueeze(0) + sigV_trial.unsqueeze(0) * torch.randn((N, T), device=device, dtype=dtype)
    x_A = RHO_A * (S_A.unsqueeze(0) + sigA * torch.randn((N, T), device=device, dtype=dtype))

    # 2) Compute posterior terms for C=1 and C=2
    vV = sigV_trial * sigV_trial  # (T,)
    vA = sigA * sigA
    vS = sigS * sigS
    # Plain tensor arithmetic. Re-wrapping a tensor with torch.tensor(...) copies it,
    # emits a UserWarning on every call, and detaches it from the autograd graph.
    iv_V = 1.0 / vV
    iv_A = 1.0 / vA
    iv_s = 1.0 / vS

    # C=1 covariance Σ = [[vV+vS, vS], [vS, vA+vS]] (per trial) and its inverse entries
    a = vV + vS
    b = vS
    d = vA + vS
    det = a * d - b * b
    inv00 = d / det
    inv11 = a / det
    inv01 = -b / det
    log_norm_c1 = -0.5 * (math.log((2 * math.pi) ** 2) + torch.log(det))

    # C=2: independent marginals per modality
    vVbar = vV + vS
    vAbar = vA + vS
    log_norm_c2_V = -0.5 * (math.log(2 * math.pi) + torch.log(vVbar))
    log_norm_c2_A = -0.5 * (math.log(2 * math.pi) + torch.log(vAbar))

    zV = x_V - mu_p
    zA = x_A - mu_p
    quad_c1 = inv00.unsqueeze(0) * (zV * zV) + 2 * inv01.unsqueeze(0) * (zV * zA) + inv11.unsqueeze(0) * (zA * zA)
    log_p_c1 = log_norm_c1.unsqueeze(0) - 0.5 * quad_c1
    log_p_c2 = (
        log_norm_c2_V.unsqueeze(0) - 0.5 * (zV * zV) / vVbar.unsqueeze(0)
        + log_norm_c2_A - 0.5 * (zA * zA) / vAbar
    )

    # Pr(C=1|x) as a sigmoid of the posterior log-odds
    logit_pc1 = torch.log(p_same) + log_p_c1 - (torch.log1p(-p_same) + log_p_c2)
    post_c1 = torch.sigmoid(logit_pc1)

    # 3) Posterior means and model-averaged estimate
    weight_sum_c1 = iv_V + iv_A + iv_s
    weight_V = iv_V + iv_s
    weight_A = iv_A + iv_s

    mu_c1 = (iv_V.unsqueeze(0) * x_V + iv_A * x_A + iv_s * mu_p) / weight_sum_c1.unsqueeze(0)
    mu_c2_V = (iv_V.unsqueeze(0) * x_V + iv_s * mu_p) / weight_V.unsqueeze(0)
    mu_c2_A = (iv_A * x_A + iv_s * mu_p) / weight_A

    # rt == 0 → report the visual C=2 estimate, otherwise the auditory one
    rt = response_types.unsqueeze(0)
    mu_c2 = torch.where(rt == 0, mu_c2_V, mu_c2_A)
    s_hat = post_c1 * mu_c1 + (1.0 - post_c1) * mu_c2

    # 4) Motor noise and lapses (uniform on [-45, 45])
    R_noisy = s_hat + sigM * torch.randn_like(s_hat)
    if float(lapse) > 0.0:
        mask = torch.rand_like(R_noisy) < lapse
        R_uniform = -45.0 + 90.0 * torch.rand_like(R_noisy)
        R_noisy = torch.where(mask, R_uniform, R_noisy)
    return R_noisy


def _sample_inputs(
    device: torch.device,
    num_points: int = 98,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draw random trial configurations (stimuli, report type, visual-noise level).

    Sampling scheme, per trial:
        - S_A is drawn uniformly from the grid {-15, -10, -5, 0, 5, 10, 15}.
        - S_V equals S_A with probability 0.5, otherwise S_V ~ Uniform(-20, 20).
        - response_types ~ Bernoulli(0.5), with 0 = BV and 1 = BA (long dtype).
        - V_levels is uniform over {0, 1, 2} (low/med/high σ_V; long dtype).

    Args:
        device: placement of the generated tensors.
        num_points: number of trials; must be positive.
        dtype: floating dtype for S_V and S_A.

    Returns:
        (S_V, S_A, response_types, V_levels), each a 1-D tensor of length num_points.

    Raises:
        ValueError: if num_points is not positive.
    """
    if num_points <= 0:
        raise ValueError("num_points must be positive.")

    n = num_points
    # NOTE: the order of the random draws below fixes the RNG stream; keep it stable.
    auditory_grid = torch.arange(-15, 16, 5, device=device, dtype=dtype)
    S_A = auditory_grid[torch.randint(0, auditory_grid.numel(), (n,), device=device)]

    coincident = torch.rand(n, device=device) < 0.5
    S_V_indep = torch.rand(n, device=device, dtype=dtype) * 40.0 - 20.0
    S_V = torch.where(coincident, S_A, S_V_indep).to(dtype)

    response_types = torch.rand(n, device=device).lt(0.5).long()
    V_levels = torch.randint(0, 3, (n,), device=device, dtype=torch.long)
    return S_V, S_A, response_types, V_levels


def sample_bav_responses_random_trials(
    RHO_A: torch.Tensor,
    theta: torch.Tensor,
    *,
    N: int = 1,
    num_points: int = 98,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample a random set of trials and simulate BAV responses on them.

    Trials come from ``_sample_inputs``; responses come from ``sample_bav_responses``
    with θ moved to the target device.

    Args:
        RHO_A: auditory rescaling factor ρ.
        theta: parameter vector (7-D new or 5-D compat).
        N: responses simulated per trial.
        num_points: number of trials.
        device: placement of all tensors (CPU when None).
        dtype: floating dtype of the stimuli and responses.

    Returns:
        (S_V, S_A, response_types, V_levels, R_sim), where the first four have
        shape (num_points,) and R_sim has shape (N, num_points).
    """
    placement = torch.device("cpu") if device is None else device
    trials = _sample_inputs(placement, num_points, dtype)
    responses = sample_bav_responses(RHO_A, theta.to(placement), *trials, N=N)
    return (*trials, responses)


class BavSampler:
    """Assembles (context, buffer, target) splits with 4-feature inputs.

    Features:
        x = (response_type, V_level, S_A, S_V), y = responses
    Parameter sampling:
        θ is sampled in 7-D unconstrained space with **truncated Gaussian** priors on
        the log noise scales and logit(p_same). Each prior is truncated at ±2 SD around
        a possibly hierarchical mean; see `_sample_theta`. Internal transforms follow
        `_unpack_theta`. Sampled θ lives on the sampler's configured ``device``/``dtype``.
    Notes:
        If num_context == 0, a 1×4 dummy context row [2.0, 0.0, 0.0, 0.0] with a 0.0
        response is inserted for every task.
    """

    THETA_DIM = 7

    def __init__(self, RHO_A: float = 4.0 / 3.0, device: str = "cpu", dtype: torch.dtype = torch.float32):
        self.RHO_A = RHO_A
        self.device = device
        self.dtype = dtype

    def _sample_theta(self, batch_size: int) -> torch.Tensor:
        """Sample θ in the unconstrained space with **truncated Gaussian** priors.

        Priors (applied to the *log* scale for noise std devs), each truncated at ±2 SD
        around the mean and **resampled** until within bounds:
        - log σV_low    ~ N(0.00, 1.5)
        - log σV_med    ~ N(log σV_low + 1, 1)
        - log σV_high   ~ N(log σV_med + 0.75, 0.5)
        - log σA        ~ N(1.75, 0.5)
        - log σs        ~ N(2.5, 1)
        - log σm        ~ N(0.00, 0.5)
        - logit(p_same) ~ N(1.5, 1.5)

        Returns:
            (batch_size, 7) tensor on the sampler's device/dtype, columns ordered as
            [log σV_low, log σV_med, log σV_high, log σA, log σs, log σm, logit(p_same)].
        """
        # BavSampler is a plain object without nn.Module parameters, so placement
        # comes from the constructor arguments.
        device = torch.device(self.device)
        dtype = self.dtype

        def trunc_normal(mean: torch.Tensor, std: float, max_iters: int = 100) -> torch.Tensor:
            """Vectorized truncated normal: resample until within [mean-2σ, mean+2σ]."""
            std_t = torch.as_tensor(std, device=device, dtype=dtype)
            x = mean + std_t * torch.randn_like(mean, device=device, dtype=dtype)
            lower = mean - 2.0 * std_t
            upper = mean + 2.0 * std_t

            mask = (x < lower) | (x > upper)
            iters = 0
            while mask.any() and iters < max_iters:
                n_bad = int(mask.sum().item())
                x_new = mean[mask] + std_t * torch.randn(n_bad, device=device, dtype=dtype)
                x = x.clone()
                x[mask] = x_new
                mask = (x < lower) | (x > upper)
                iters += 1

            if mask.any():
                # Extremely unlikely fallback: hard-clip remaining offenders to bounds
                x = torch.minimum(torch.maximum(x, lower), upper)
            return x

        def const(value: float) -> torch.Tensor:
            return torch.full((batch_size,), value, device=device, dtype=dtype)

        # 1) Hierarchical visual-noise scales on log-scale
        log_sigmaV_low = trunc_normal(const(0.0), 1.5)
        log_sigmaV_med = trunc_normal(log_sigmaV_low + 1.0, 1.0)
        log_sigmaV_high = trunc_normal(log_sigmaV_med + 0.75, 0.5)

        # 2) Other log-scales (independent)
        log_sigmaA = trunc_normal(const(1.75), 0.5)
        log_sigmas = trunc_normal(const(2.5), 1.0)
        log_sigmam = trunc_normal(const(0.0), 0.5)

        # 3) Logit of p_same
        logit_p_same = trunc_normal(const(1.5), 1.5)

        # Stack in the layout expected by _unpack_theta
        return torch.stack(
            [
                log_sigmaV_low,
                log_sigmaV_med,
                log_sigmaV_high,
                log_sigmaA,
                log_sigmas,
                log_sigmam,
                logit_p_same,
            ],
            dim=1,
        )

    def _sample_num_context(self, context_range: Tuple[int, int]) -> int:
        """Sample a context count from [low, high] inclusive, or uniformly from a list.

        Raises:
            ValueError: if a 2-element range has low > high, or the length is < 2.
        """
        n = len(context_range)
        if n == 2:
            low, high = context_range
            if low > high:
                raise ValueError("Invalid `context_range`: low > high.")
            return torch.randint(low, high + 1, (1,)).item()
        if n > 2:
            choices = torch.as_tensor(context_range)
            return choices[torch.randint(0, choices.numel(), (1,))].item()
        raise ValueError("`context_range` must have length 2 or > 2.")

    @staticmethod
    def _stack_features(S_Vs, S_As, rts, Vls) -> torch.Tensor:
        """Stack per-task 1-D lists into (B, N, 4) features ordered (rt, V_level, S_A, S_V)."""
        S_V = torch.stack(S_Vs)
        # Integer columns (rt, V_level) are cast to the stimulus dtype explicitly.
        columns = (
            torch.stack(rts).to(S_V.dtype),
            torch.stack(Vls).to(S_V.dtype),
            torch.stack(S_As).to(S_V.dtype),
            S_V,
        )
        return torch.stack(columns, dim=-1)

    @staticmethod
    def _split_indices(num_total: int, num_context: int, num_buffer: int, device):
        """Random (context, buffer, target) index split, shared across tasks."""
        perm = torch.randperm(num_total, device=device)
        return (
            perm[:num_context],
            perm[num_context : num_context + num_buffer],
            perm[num_context + num_buffer :],
        )

    @staticmethod
    def _assemble(x: torch.Tensor, y: torch.Tensor, splits, num_context: int, batch_size: int) -> DataAttr:
        """Build a DataAttr from full (B, N, ·) tensors and a precomputed index split."""
        ctx_idx, buf_idx, tar_idx = splits
        if num_context == 0:
            # Dummy one-row context when none requested
            xc = torch.tensor([2.0, 0.0, 0.0, 0.0], device=x.device, dtype=x.dtype)[None, None, :].expand(
                batch_size, 1, 4
            )
            yc = torch.tensor([0.0], device=y.device, dtype=y.dtype)[None, None, :].expand(batch_size, 1, 1)
        else:
            xc = x[:, ctx_idx]
            yc = y[:, ctx_idx]
        return DataAttr(
            xc=xc,
            yc=yc,
            xb=x[:, buf_idx],
            yb=y[:, buf_idx],
            xt=x[:, tar_idx],
            yt=y[:, tar_idx],
        )

    def generate_batch(
        self,
        batch_size: int,
        num_context: Optional[int] = None,
        num_buffer: int = 50,
        num_target: int = 50,
        context_range: Tuple[int, int] = (3, 47),
    ) -> Tuple[DataAttr, torch.Tensor]:
        """Generate a batch of random trials split into (context, buffer, target).

        Args:
            batch_size: number of independent θ draws / tasks.
            num_context: number of context points per task (if None, sampled from range).
            num_buffer: number of buffer points per task.
            num_target: number of target points per task.
            context_range: used when num_context is None; inclusive [low, high].

        Returns:
            (data, thetas) where ``data`` is a DataAttr with fields
              - xc, yc: (B, C, 4), (B, C, 1)   (C = 1 dummy row when num_context == 0)
              - xb, yb: (B, Bf, 4), (B, Bf, 1)
              - xt, yt: (B, T, 4), (B, T, 1)
            and ``thetas`` is the (B, 7) unconstrained θ used for each task.
        """
        if num_context is None:
            num_context = self._sample_num_context(context_range)
        num_total = num_context + num_buffer + num_target

        thetas = self._sample_theta(batch_size)
        S_Vs, S_As, rts, Vls, Y = [], [], [], [], []

        for th in thetas:
            S_V, S_A, rt, Vl, resp = sample_bav_responses_random_trials(
                self.RHO_A, th, num_points=num_total, device=self.device, dtype=self.dtype
            )
            S_Vs.append(S_V)
            S_As.append(S_A)
            rts.append(rt)
            Vls.append(Vl)
            Y.append(resp.squeeze(0))

        x = self._stack_features(S_Vs, S_As, rts, Vls)  # (B, N, 4)
        y = torch.stack(Y).unsqueeze(-1)                # (B, N, 1)

        splits = self._split_indices(num_total, num_context, num_buffer, x.device)
        return self._assemble(x, y, splits, num_context, batch_size), thetas

    def generate_test_batch(self, batch_size: int, num_target: int = 400) -> Tuple[DataAttr, torch.Tensor]:
        """No-context, all-target batch generation helper (useful for evaluation).

        Returns the same ``(data, thetas)`` pair as :meth:`generate_batch`.
        """
        return self.generate_batch(batch_size, num_context=0, num_buffer=0, num_target=num_target)

    def generate_paired_batches_same_trials(
        self,
        batch_size: int,
        num_context: Optional[int] = None,
        num_buffer: int = 50,
        num_target: int = 50,
        context_range: Tuple[int, int] = (3, 47),
        rho_a_first: float = 4.0 / 3.0,
        rho_a_second: float = 1.0,
    ) -> Tuple[DataAttr, DataAttr, torch.Tensor]:
        """Generate two datasets on the *same* θ and trials, differing only in RHO_A.

        Mirrors generate_batch(), but reuses the *identical* S_A, S_V, V_levels, and
        response_types for both simulations. The first dataset uses
        RHO_A=rho_a_first (default 4/3) and the second RHO_A=rho_a_second (default 1).

        Returns:
            (data_first, data_second, thetas)
                where each DataAttr contains the same x/splits but different y.
        """
        if num_context is None:
            num_context = self._sample_num_context(context_range)
        num_total = num_context + num_buffer + num_target

        # Sample θ once (shared)
        thetas = self._sample_theta(batch_size)

        S_Vs, S_As, rts, Vls = [], [], [], []
        Y_first, Y_second = [], []

        for th in thetas:
            # Draw *one* set of trials per θ, then simulate both RHO_A settings on it
            S_V, S_A, rt, Vl = _sample_inputs(self.device, num_total, self.dtype)
            th_d = th.to(dtype=self.dtype)
            resp_first = sample_bav_responses(rho_a_first, th_d, S_V, S_A, rt, Vl, N=1).squeeze(0)
            resp_second = sample_bav_responses(rho_a_second, th_d, S_V, S_A, rt, Vl, N=1).squeeze(0)

            S_Vs.append(S_V)
            S_As.append(S_A)
            rts.append(rt)
            Vls.append(Vl)
            Y_first.append(resp_first)
            Y_second.append(resp_second)

        x_shared = self._stack_features(S_Vs, S_As, rts, Vls)  # (B, N, 4)
        y_first = torch.stack(Y_first).unsqueeze(-1)           # (B, N, 1)
        y_second = torch.stack(Y_second).unsqueeze(-1)         # (B, N, 1)

        # Shared random split so both datasets have identical x/splits
        splits = self._split_indices(num_total, num_context, num_buffer, x_shared.device)
        data_first = self._assemble(x_shared, y_first, splits, num_context, batch_size)
        data_second = self._assemble(x_shared, y_second, splits, num_context, batch_size)
        return data_first, data_second, thetas


# -------------------------------
# Maximum-Likelihood (gradient) fit
# -------------------------------

def _default_init_theta(theta_dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Provide a sensible default initialization in unconstrained space.
    For 7D: [log σV_low, log σV_med, log σV_high, log σA, log σs, log σm, logit(p_same)]
    For 5D: [log σV, log σA, log σs, log σm, logit(p_same)]
    """
    if theta_dim == 7:
        init = torch.tensor([0.0, 0.75, 1.5, 0.75, 0.0, 0.0, 0.0], device=device, dtype=dtype)
    elif theta_dim == 5:
        init = torch.tensor([0.75, 0.75, 0.0, 0.0, 0.0], device=device, dtype=dtype)
    else:
        raise ValueError("theta_dim must be 5 or 7.")
    return init


def fit_bav_mle(
    RHO_A: torch.Tensor,
    R: torch.Tensor,
    S_V: torch.Tensor,
    S_A: torch.Tensor,
    response_types: torch.Tensor,
    V_levels: torch.Tensor,
    *,
    init_theta: Optional[torch.Tensor] = None,
    theta_dim: int = 7,
    lr: float = 0.05,
    steps: int = 500,
    gh_deg: int = 31,
    chunk_size: Optional[int] = None,
    weight_decay: float = 0.0,
    patience: Optional[int] = 50,
    verbose: bool = False,
    return_history: bool = False,
) -> dict:
    """Fit θ by minimizing the summed NLL with Adam.

    Each step evaluates ``nll_bav_constant_gaussian`` at the current θ, back-propagates,
    clips the gradient norm to 5.0 and takes one Adam step. The "best" θ is snapshotted
    *before* the optimizer step that follows its evaluation, so ``theta_opt`` and ``loss``
    in the result always refer to the same parameter vector.

    Args:
        RHO_A: auditory rescaling factor ρ (float/tensor; broadcasts).
        R, S_V, S_A, response_types, V_levels: trial tensors, shape (T,) (flattened here).
        init_theta: optional initial θ in the unconstrained space. When omitted, a default
            of length ``theta_dim`` is used.
        theta_dim: 7 (multi-level σV) or 5 (compat single σV); only used for the default init.
        lr: Adam learning rate.
        steps: maximum number of optimizer steps; must be >= 1.
        gh_deg, chunk_size: forwarded to the NLL for performance control. A lower ``gh_deg``
            trades quadrature accuracy for speed.
        weight_decay: L2 on unconstrained θ (Adam's built-in weight decay).
        patience: stop after this many consecutive steps without improving the best loss
            (None disables early stopping).
        verbose: print progress every ~10% of steps.
        return_history: include the per-step loss history in the result dict.

    Returns:
        dict with:
          - 'theta_opt': unconstrained θ that achieved the lowest evaluated loss (torch.Tensor)
          - 'theta_transformed': dict of transformed parameters
          - 'loss': best NLL (float; ``inf`` if no finite loss was ever observed)
          - 'history': list of floats (when requested)
          - 'steps_run': int, number of loss evaluations performed

    Raises:
        ValueError: if ``steps < 1``.
    """
    import math

    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}.")

    # Normalize inputs to flat (T,) tensors; level/type codes are cast to integers.
    R = R.flatten()
    S_V = S_V.flatten()
    S_A = S_A.flatten()
    response_types = response_types.to(torch.long).flatten()
    V_levels = V_levels.to(torch.long).flatten()

    device = R.device
    dtype = R.dtype
    if init_theta is None:
        init_theta = _default_init_theta(theta_dim, device, dtype)

    theta = torch.nn.Parameter(init_theta.clone().detach().to(device=device, dtype=dtype))
    opt = torch.optim.Adam([theta], lr=lr, weight_decay=weight_decay)

    history = []
    # Seed with the initial θ and +inf, so a NaN/inf loss can never become the "best".
    best_loss = float("inf")
    best_theta = theta.detach().clone()
    since_improve = 0
    steps_run = 0

    # Simple progress cadence
    log_every = max(1, steps // 10)

    for step in range(1, steps + 1):
        steps_run = step
        opt.zero_grad(set_to_none=True)
        loss = nll_bav_constant_gaussian(
            RHO_A, theta, R, S_V, S_A, response_types, V_levels, gh_deg=gh_deg, chunk_size=chunk_size
        )
        loss_val = float(loss.item())

        if return_history:
            history.append(loss_val)

        if verbose and (step % log_every == 0 or step == 1 or step == steps):
            print(f"[fit_bav_mle] step {step}/{steps}  NLL={loss_val:.3f}")

        # Back-propagating a non-finite loss would fill θ with NaNs; keep the best θ so far.
        if not math.isfinite(loss_val):
            if verbose:
                print(f"[fit_bav_mle] Non-finite NLL at step {step}; stopping.")
            break

        # Snapshot θ *before* opt.step(): this is exactly the θ that produced loss_val.
        if loss_val < best_loss:
            best_loss = loss_val
            best_theta = theta.detach().clone()
            since_improve = 0
        else:
            since_improve += 1

        if patience is not None and since_improve >= patience:
            if verbose:
                print(f"[fit_bav_mle] Early stopping at step {step} (patience {patience})")
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_([theta], max_norm=5.0)
        opt.step()

    # Transform best θ to human-readable parameters
    sigV_low, sigV_med, sigV_high, sigA, sigS, sigM, lapse, p_same, _ = _unpack_theta(
        best_theta, device, dtype
    )
    transformed = {
        "sigma_V_low": float(sigV_low.item()),
        "sigma_V_med": float(sigV_med.item()),
        "sigma_V_high": float(sigV_high.item()),
        "sigma_A": float(sigA.item()),
        "sigma_s": float(sigS.item()),
        "sigma_m": float(sigM.item()),
        "p_same": float(p_same.item()),
        "lapse": float(lapse.item()),
    }

    out = {
        "theta_opt": best_theta,
        "theta_transformed": transformed,
        "loss": float(best_loss),
        "steps_run": steps_run,
    }
    if return_history:
        out["history"] = history
    return out


def fit_bav_mle_from_xy(
    RHO_A: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    **kwargs,
) -> dict:
    """Convenience wrapper to fit θ directly from (x, y) like those produced by BavSampler.

    Splits the four feature columns into the per-trial tensors expected by
    ``fit_bav_mle`` and forwards every keyword argument unchanged.

    Args:
        RHO_A: auditory rescaling factor ρ.
        x: (T, 4) features = (response_type, V_level, S_A, S_V) for a single task.
            Batched sampler output of shape (B, N, 4) must be sliced to one task first.
        y: responses, shape (T,) or (T, 1); must contain exactly T values.
        **kwargs: forwarded to ``fit_bav_mle`` (e.g. init_theta, steps, lr, patience).

    Returns:
        Same dict as fit_bav_mle(...).

    Raises:
        ValueError: if ``x`` is not (T, 4) or ``y`` does not hold exactly one response per trial.
    """
    if x.ndim != 2 or x.shape[-1] != 4:
        raise ValueError("x must have shape (T, 4) with columns [response_type, V_level, S_A, S_V].")

    num_trials = x.shape[0]
    if y.numel() != num_trials:
        # reshape(-1) alone would silently accept e.g. a (T, 2) y as 2T responses.
        raise ValueError(
            f"y must hold one response per trial: expected {num_trials} values, "
            f"got {y.numel()} (shape {tuple(y.shape)})."
        )

    # fit_bav_mle creates θ with R's dtype, so continuous inputs must be floating point.
    float_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    R = y.reshape(-1).to(device=x.device, dtype=float_dtype)
    response_types = x[:, 0].to(torch.long)
    V_levels = x[:, 1].to(torch.long)
    S_A = x[:, 2].to(float_dtype)
    S_V = x[:, 3].to(float_dtype)
    return fit_bav_mle(RHO_A, R, S_V, S_A, response_types, V_levels, **kwargs)
