import numpy as np
from scipy.integrate import quad
import multiprocessing as mp
from functools import partial
import os
import torch

def laplace_fusion_mean(
    mu_base,
    sigma_base,
    mu2_mat,
    sigma2_mat,
    return_component_norms=False,
    divide_unit_laplace=True,
    moment='mean',
    n_jobs=1,
    mu_div=None,
    sigma_div=None,
    threshold=0.0,
    component_weights=None,
    natural_prior_weight=None,
):
    """
    Compute a first moment of a Laplace base density fused with a Laplace mixture.

    When `divide_unit_laplace` is True (default), the density is proportional to:
      base_distribution * mixture_distribution / divisor
    Otherwise it is proportional to:
      base_distribution * mixture_distribution
    where:
      - base_distribution: Laplace(mu_base, sigma_base)
      - mixture_distribution: sum_j w_j * Laplace(mu2_mat[j], sigma2_mat[j])
      - divisor: Laplace(mu_div, sigma_div), which defaults to Laplace(0, 1)

    This function validates and prepares the inputs, then delegates to the
    serial or the parallel implementation.

    Parameters
    ----------
    mu_base : array-like
        Mean vector of the base Laplace distribution, shape (N,).
    sigma_base : array-like
        Scale vector of the base Laplace distribution, shape (N,).
    mu2_mat : array-like
        Mean matrix for mixture components, shape (M, N) where:
        - M is the number of mixture components
        - N is the dimensionality of the space
    sigma2_mat : array-like
        Scale matrix for mixture components, shape (M, N).
    return_component_norms : bool, optional
        If True, also return the normalization constants for each component.
        Default is False.
    divide_unit_laplace : bool, optional
        If True, divide by the divisor Laplace density. Default is True.
        Forced to True whenever `mu_div` or `sigma_div` is supplied.
    moment : {'mean', 'abs', 'relu'}, optional
        Which moment to compute: 'mean' (E[X], default), 'abs' (E[|X|]) or
        'relu' (E[ReLU(X - threshold)]).
    n_jobs : int or None, optional
        Number of parallel jobs for dimension-level parallelization.
        - None or 1 (default is 1): Use serial computation
        - > 1: Use parallel computation; the job count is capped by the CPU
          count and by N
        - -1: Use all available CPU cores (capped by N)
        Serial computation is always used when N == 1.
    mu_div, sigma_div : float or array-like or None, optional
        Location and scale of the divisor Laplace density. Scalars are
        expanded to shape (N,); other inputs are broadcast to (N,). A missing
        value defaults to 0.0 (location) or 1.0 (scale).
    threshold : float or array-like, optional
        Only used for ``moment='relu'``. A non-zero threshold is implemented
        as a translation of every location parameter by ``-threshold``
        (base, mixture components, divisor and the natural-prior component).
    component_weights : array-like or None, optional
        Optional non-negative weights for each mixture component, shape (M,).
        A scalar is expanded to M equal weights. If None (default), all
        components are weighted equally (weight = 1). Only relative
        magnitudes matter, so the array does **not** need to sum to 1.
    natural_prior_weight : float or None, optional
        If not None, adds a unit Laplace component (zero mean, unit scale) to
        the mixture with the specified weight. The component is appended after
        the user-provided components, so the number of components becomes
        M + 1. Must be positive if provided.

    Returns
    -------
    np.ndarray or tuple
        If return_component_norms is False:
            Moment vector of the resulting distribution, shape (N,).
        If return_component_norms is True:
            Tuple (mean, component_norms, component_norms_per_dim) as produced
            by the serial / parallel implementation; the component axis has
            length M (M + 1 when `natural_prior_weight` is used).

    Raises
    ------
    ValueError
        If `component_weights` has the wrong length or a negative entry, if
        `natural_prior_weight` is not positive, or if a parallel run is
        requested with an `n_jobs` that is neither -1 nor a positive integer.
    """
    # Convert inputs to numpy arrays
    mu_base, sigma_base = np.asarray(mu_base), np.asarray(sigma_base)
    mu2_mat, sigma2_mat = np.asarray(mu2_mat), np.asarray(sigma2_mat)

    # ------------------------------------------------------------------
    # Optional threshold handling for ReLU moments
    # ------------------------------------------------------------------
    # E[ReLU(X - x0)] equals E[ReLU(Y)] with Y = X - x0, i.e. the same
    # computation with every location parameter shifted by -x0.  The shift is
    # restricted to the ReLU moment so that 'mean' / 'abs' keep their meaning.
    shift_for_relu = (
        moment == 'relu'
        and (threshold is not None)
        and np.any(np.asarray(threshold) != 0)
    )
    if shift_for_relu:
        mu_base = mu_base - threshold
        mu2_mat = mu2_mat - threshold
    M, N = mu2_mat.shape  # M: number of mixture components, N: dimensions

    # ------------------------------------------------------------------
    # Handle optional component weights (mixture coefficients)
    # ------------------------------------------------------------------
    if component_weights is None:
        component_weights_arr = np.ones(M)
    else:
        component_weights_arr = np.asarray(component_weights, dtype=float)
        if component_weights_arr.ndim == 0:
            component_weights_arr = np.full(M, component_weights_arr)
        elif len(component_weights_arr) != M:
            raise ValueError(
                f"component_weights must have length {M}, got {len(component_weights_arr)}"
            )
        if np.any(component_weights_arr < 0):
            raise ValueError("component_weights must be non-negative")

    # Normalising the weights is not required – they act as relative coefficients.

    # ------------------------------------------------------------------
    # Handle natural prior weight (adds unit Laplace component)
    # ------------------------------------------------------------------
    if natural_prior_weight is not None and natural_prior_weight > 0:
        # The prior is Laplace(0, 1) in the original coordinates, so in the
        # shifted ReLU coordinates its location must move by -threshold just
        # like every other location parameter.
        prior_mu = np.zeros((1, N))
        if shift_for_relu:
            prior_mu = prior_mu - threshold
        mu2_mat = np.vstack([mu2_mat, prior_mu])
        sigma2_mat = np.vstack([sigma2_mat, np.ones((1, N))])  # Unit scale
        component_weights_arr = np.append(component_weights_arr, natural_prior_weight)

        # Update M to reflect the new number of components
        M = M + 1
    elif natural_prior_weight is not None and natural_prior_weight <= 0:
        raise ValueError("natural_prior_weight must be positive if provided")

    # Determine whether to divide by an additional Laplace distribution.
    # If the user supplied non-None `mu_div` or `sigma_div`, we force division.
    if (mu_div is not None) or (sigma_div is not None):
        divide_unit_laplace = True

    if divide_unit_laplace:
        # Default to unit Laplace if the user did not supply explicit parameters.
        if mu_div is None:
            mu_div = 0.0
        if sigma_div is None:
            sigma_div = 1.0

        # Broadcast the denominator parameters to shape (N,)
        mu_div = np.asarray(mu_div)
        sigma_div = np.asarray(sigma_div)
        if mu_div.ndim == 0:
            mu_div = np.full(N, mu_div)
        else:
            mu_div = np.broadcast_to(mu_div, (N,))

        if sigma_div.ndim == 0:
            sigma_div = np.full(N, sigma_div)
        else:
            sigma_div = np.broadcast_to(sigma_div, (N,))

        # Apply the same shift to the denominator means for ReLU moments.
        if shift_for_relu:
            mu_div = mu_div - threshold
    else:
        # Placeholders – will be ignored downstream when divide_unit_laplace is False.
        mu_div = np.zeros(N)
        sigma_div = np.ones(N)

    # Determine if we should use parallel computation
    use_parallel = n_jobs is not None and n_jobs != 1 and N > 1

    if use_parallel:
        # Resolve the worker count; never start more workers than dimensions.
        if n_jobs == -1:
            n_jobs = min(mp.cpu_count(), N)
        elif n_jobs < 1:
            # Fail early with a clear message instead of inside mp.Pool.
            raise ValueError(f"n_jobs must be a positive integer or -1, got {n_jobs}")
        else:
            n_jobs = min(n_jobs, mp.cpu_count(), N)

        # Use parallel computation
        return _laplace_fusion_mean_parallel_impl(
            M,
            N,
            mu_base,
            sigma_base,
            mu2_mat,
            sigma2_mat,
            mu_div,
            sigma_div,
            return_component_norms,
            divide_unit_laplace,
            moment,
            n_jobs,
            threshold,
            component_weights_arr,
            natural_prior_weight,
        )
    else:
        # Use serial computation
        return _laplace_fusion_mean_serial_impl(
            M,
            N,
            mu_base,
            sigma_base,
            mu2_mat,
            sigma2_mat,
            mu_div,
            sigma_div,
            return_component_norms,
            divide_unit_laplace,
            moment,
            threshold,
            component_weights_arr,
            natural_prior_weight,
        )

def _laplace_fusion_mean_serial_impl(
    M,
    N,
    mu_base,
    sigma_base,
    mu2_mat,
    sigma2_mat,
    mu_div,
    sigma_div,
    return_component_norms=False,
    divide_unit_laplace=True,
    moment='mean',
    threshold=0.0,
    component_weights_arr=None,
    natural_prior_weight=None,
):
    """
    Single-process back end of laplace_fusion_mean.

    Gathers the per-component, per-dimension normalisers and moments, folds
    the component weights into the log normalisers, and combines everything
    into the fused moment vector using a max-shift in log space.

    Returns
    -------
    np.ndarray or tuple
        The fused moment vector, shape (N,). When `return_component_norms` is
        True, a tuple ``(moment_vector, component_norms, component_norms_per_dim)``
        where `component_norms` is the exponential of the weighted log
        normalisers (shape (M,)) and `component_norms_per_dim` is the
        exponential of the per-dimension log normalisers (shape (M, N)).
    """
    log_norms, first_moments, log_norms_per_dim = _compute_component_moments(
        M,
        N,
        mu_base,
        sigma_base,
        mu2_mat,
        sigma2_mat,
        divide_unit_laplace,
        moment,
        mu_div,
        sigma_div,
        threshold,
        component_weights_arr,
        natural_prior_weight,
    )

    if component_weights_arr is not None:
        # A zero weight becomes -inf in the log domain, which removes the
        # component from every later sum.
        has_mass = component_weights_arr > 0
        log_weights = np.full_like(log_norms, -np.inf)
        log_weights[has_mass] = np.log(component_weights_arr[has_mass])
        log_norms = log_norms + log_weights

    # Shift by the largest log normaliser so the exponentials stay bounded.
    shift = np.max(log_norms)
    total_Z, total_N = _compute_stabilized_mean(
        M, N, log_norms, first_moments, log_norms_per_dim, shift
    )

    # The common shift factor is present in both terms and cancels here.
    fused_moment = total_N / total_Z

    if not return_component_norms:
        return fused_moment

    # Leave log space for the values handed back to the caller.
    norms = np.exp(log_norms)
    norms_per_dim = np.exp(log_norms_per_dim)
    return fused_moment, norms, norms_per_dim

def _laplace_fusion_mean_parallel_impl(
    M,
    N,
    mu_base,
    sigma_base,
    mu2_mat,
    sigma2_mat,
    mu_div,
    sigma_div,
    return_component_norms=False,
    divide_unit_laplace=True,
    moment='mean',
    n_jobs=None,
    threshold=0.0,
    component_weights_arr=None,
    natural_prior_weight=None,
):
    """
    Multi-process back end of laplace_fusion_mean.

    Each dimension is handed to `_process_dimension_parallel` through a
    `multiprocessing.Pool` with `n_jobs` workers. Every worker result is
    unpacked as ``(dimension index, per-component log normalisers,
    per-component moments)`` and written into (M, N) tables, after which the
    weighting and max-shift combination mirror the serial implementation.

    Returns
    -------
    np.ndarray or tuple
        The fused moment vector, shape (N,). When `return_component_norms` is
        True, a tuple ``(moment_vector, component_norms, component_norms_per_dim)``.
    """
    # One argument tuple per dimension: scalars for the base / divisor
    # parameters and the matching column of the mixture parameters.
    per_dim_jobs = [
        (
            dim,
            mu_base[dim],
            sigma_base[dim],
            mu2_mat[:, dim],
            sigma2_mat[:, dim],
            mu_div[dim],
            sigma_div[dim],
            divide_unit_laplace,
            moment,
            threshold,
            component_weights_arr,
            natural_prior_weight,
        )
        for dim in range(N)
    ]

    with mp.Pool(processes=n_jobs) as pool:
        per_dim_results = pool.map(_process_dimension_parallel, per_dim_jobs)

    # Scatter the per-dimension columns back into (M, N) tables.
    log_norms_per_dim = np.zeros((M, N))
    first_moments = np.zeros((M, N))
    for dim, dim_log_norms, dim_moments in per_dim_results:
        log_norms_per_dim[:, dim] = dim_log_norms
        first_moments[:, dim] = dim_moments

    # A product of normalisers over dimensions is a sum in log space.
    log_norms = log_norms_per_dim.sum(axis=1)

    if component_weights_arr is not None:
        # A zero weight becomes -inf in the log domain, which removes the
        # component from every later sum.
        has_mass = component_weights_arr > 0
        log_weights = np.full_like(log_norms, -np.inf)
        log_weights[has_mass] = np.log(component_weights_arr[has_mass])
        log_norms = log_norms + log_weights

    # Shift by the largest log normaliser so the exponentials stay bounded.
    shift = np.max(log_norms)
    total_Z, total_N = _compute_stabilized_mean(
        M, N, log_norms, first_moments, log_norms_per_dim, shift
    )

    # The common shift factor is present in both terms and cancels here.
    fused_moment = total_N / total_Z

    if not return_component_norms:
        return fused_moment

    # Leave log space for the values handed back to the caller.
    norms = np.exp(log_norms)
    norms_per_dim = np.exp(log_norms_per_dim)
    return fused_moment, norms, norms_per_dim

def _compute_component_moments(
    M,
    N,
    mu_base,
    sigma_base,
    mu2_mat,
    sigma2_mat,
    divide_unit_laplace=True,
    moment="mean",
    mu_div=None,
    sigma_div=None,
    threshold=0.0,
    component_weights_arr=None,
    natural_prior_weight=None,
):
    """
    Tabulate the univariate normaliser and moment for every (component, dimension).

    Calls `univariate_laplace_mixture_moment` once per mixture component and
    dimension and stores the logarithm of the normaliser together with the
    returned (unnormalised) moment.

    Parameters
    ----------
    M : int
        Number of mixture components.
    N : int
        Number of dimensions.
    mu_base, sigma_base : array-like
        Base distribution parameters, indexed per dimension.
    mu2_mat, sigma2_mat : array-like
        Mixture parameters, indexed as ``[component, dimension]``.
    divide_unit_laplace : bool
        Forwarded to the univariate routine.
    moment : str
        Moment type, forwarded to the univariate routine.
    mu_div, sigma_div : array-like
        Divisor parameters, indexed per dimension.
    threshold, natural_prior_weight
        Forwarded unchanged to the univariate routine.
    component_weights_arr : array-like or None
        When given, entry ``j`` is forwarded as the component weight;
        otherwise 1.0 is forwarded.

    Returns
    -------
    tuple
        (log_Z_per_component, Ms_values, log_Z_values) containing:
        - log_Z_per_component: Sum over dimensions of the log normalisers, shape (M,)
        - Ms_values: Moment values for each component and dimension, shape (M, N)
        - log_Z_values: Log normalisers for each component and dimension, shape (M, N)
    """
    log_Z_values = np.zeros((M, N))
    Ms_values = np.zeros((M, N))
    log_Z_per_component = np.zeros(M)

    for comp in range(M):
        for dim in range(N):
            norm, first_moment = univariate_laplace_mixture_moment(
                mu_base[dim],
                sigma_base[dim],
                mu2_mat[comp, dim],
                sigma2_mat[comp, dim],
                divide_unit_laplace,
                moment,
                mu_div[dim],
                sigma_div[dim],
                threshold,
                1.0 if component_weights_arr is None else component_weights_arr[comp],
                natural_prior_weight,
            )

            # Keep the normaliser in log space; a non-positive value maps to -inf.
            if norm > 0:
                log_Z_values[comp, dim] = np.log(norm)
            else:
                log_Z_values[comp, dim] = -np.inf
            Ms_values[comp, dim] = first_moment

        # Product of the per-dimension normalisers == sum of their logs.
        log_Z_per_component[comp] = np.sum(log_Z_values[comp])

    return log_Z_per_component, Ms_values, log_Z_values

def _compute_stabilized_mean(M, N, log_Z_per_component, Ms_values, log_Z_values, max_log_Z):
    """
    Compute the stabilized mean using max-shift technique to prevent overflow.
    
    Parameters
    ----------
    M : int
        Number of mixture components.
    N : int
        Number of dimensions.
    log_Z_per_component : array-like
        Log normalization constant for each component, shape (M,).
    Ms_values : array-like
        First moment values for each component and dimension, shape (M, N).
    log_Z_values : array-like
        Log normalization for each component and dimension, shape (M, N).
    max_log_Z : float
        Maximum log normalization constant used for stabilization.
        
    Returns
    -------
    tuple
        (total_Z, total_N) containing:
        - total_Z: Total normalization constant (scalar)
        - total_N: Unnormalized mean vector, shape (N,)
    """
    total_Z = 0.0
    total_N = np.zeros(N)
    
    for j in range(M):
        # Get log normalization for this component
        log_Z_j = log_Z_per_component[j]
        
        # Apply max-shift for numerical stability
        Z_j_shifted = np.exp(log_Z_j - max_log_Z)
        
        # Accumulate for the normalization constant
        total_Z += Z_j_shifted
        
        # For each dimension, compute the contribution to the mean
        for i in range(N):
            M_i = Ms_values[j, i]
            log_Z_i = log_Z_values[j, i]
            
            # Only compute if both Z_j and Z_i are non-zero/finite
            if np.isfinite(log_Z_j) and np.isfinite(log_Z_i):
                # Compute N_j[i] = M_i * Z_j / Z_i with the max-shift
                # This represents the contribution of component j to dimension i
                N_j_i = M_i * np.exp(log_Z_j - log_Z_i - max_log_Z)
                total_N[i] += N_j_i
    
    return total_Z, total_N

def univariate_laplace_mixture_moment(
    mu_base,
    sigma_base,
    mu2,
    sigma2,
    divide_unit_laplace=True,
    moment='mean',
    mu_div=None,
    sigma_div=None,
    threshold=0.0,
    component_weight=1.0,
    natural_prior_weight=None,
):
    """
    Compute normalization Z and (unnormalized) moment M for a 1D fused density.

    The density is
        If `divide_unit_laplace` is True (default):
            f(x) ∝ Laplace(x; mu_base, sigma_base)
                   * Laplace(x; mu2, sigma2)
                   / Laplace(x; mu_div, sigma_div)
            with (mu_div, sigma_div) defaulting to (0, 1).
        Else:
            f(x) ∝ Laplace(x; mu_base, sigma_base) * Laplace(x; mu2, sigma2)

    Parameters
    ----------
    mu_base : float
        Mean of the base Laplace distribution.
    sigma_base : float
        Scale of the base Laplace distribution.
    mu2 : float
        Mean of the mixture component Laplace distribution.
    sigma2 : float
        Scale of the mixture component Laplace distribution.
    divide_unit_laplace : bool
        If True, divide by the divisor Laplace density.
    moment : str, optional
        'mean' (E[X]), 'abs' (E[|X|]) or 'relu'; forwarded to the region
        integration. For 'abs' and 'relu', 0 is added as a breakpoint.
    mu_div, sigma_div : float or None, optional
        Divisor parameters; only used when `divide_unit_laplace` is True.
    threshold, component_weight, natural_prior_weight
        Accepted for signature compatibility; not used in this function body.

    Returns
    -------
    tuple
        (Z, M) as returned by `_compute_stable_moments`.

    Notes
    -----
    The real line is split at the "kink points" where one of the Laplace
    factors changes form; each region is analysed by
    `_analyze_density_regions` and integrated by `_compute_stable_moments`,
    with the largest endpoint exponent used as the stabilising shift.
    """
    # Scales -> rates
    rate_base = 1.0 / sigma_base
    rate_mix = 1.0 / sigma2

    if divide_unit_laplace:
        # Missing divisor parameters fall back to the unit Laplace.
        mu_div = 0.0 if mu_div is None else mu_div
        sigma_div = 1.0 if sigma_div is None else sigma_div
        rate_div = 1.0 / sigma_div
        breakpoints = {-np.inf, mu_base, mu2, mu_div, np.inf}
    else:
        # Placeholder values; they are passed along but the divisor is disabled.
        mu_div, sigma_div, rate_div = 0.0, 1.0, 1.0
        breakpoints = {-np.inf, mu_base, mu2, np.inf}

    # |x| and ReLU(x) change form at 0, so no region may cross it; this keeps
    # the straddle-zero check inside the region integrators from firing.
    if moment in ("abs", "relu"):
        breakpoints.add(0.0)

    regions, endpoint_exponents = _analyze_density_regions(
        sorted(breakpoints),
        mu_base,
        mu2,
        rate_base,
        rate_mix,
        divide_unit_laplace,
        mu_div,
        rate_div,
    )

    # The largest endpoint exponent serves as the max-shift.
    stabiliser = max(endpoint_exponents)
    Z, M = _compute_stable_moments(regions, stabiliser, moment)
    return Z, M

def _analyze_density_regions(
    kink_points,
    mu_base,
    mu2,
    rate_base,
    rate_mix,
    divide_unit_laplace=True,
    mu_div=None,
    rate_div=None,
):
    """
    Determine the log-linear density parameters on each region between kinks.

    On every non-degenerate interval between adjacent kink points the
    log-density is treated as ``d + kappa * x``; the slope ``kappa`` and
    offset ``d`` follow from the sign of ``x - mu`` for each Laplace factor,
    evaluated at a sample point inside the interval.

    Parameters
    ----------
    kink_points : list
        Sorted list of points where the density changes form.
    mu_base, mu2 : float
        Mean parameters.
    rate_base, rate_mix : float
        Rate parameters (1/sigma).
    divide_unit_laplace : bool
        If True, include the divisor Laplace term (with opposite sign).
    mu_div, rate_div : float or None
        Divisor mean and rate. When division is enabled, a missing value
        defaults to the unit Laplace (mean 0.0, rate 1.0).

    Returns
    -------
    tuple
        (regions_data, exponents) containing:
        - regions_data: List of tuples (x_lo, x_hi, kappa, d) for each region
        - exponents: List of exponents ``d + kappa * x`` at the finite region
          endpoints, for use as a max-shift

    Raises
    ------
    ValueError
        Propagated from `_check_tail_convergence` when an infinite tail diverges.
    """
    if divide_unit_laplace:
        # Resolve the divisor defaults once, before the loop. Previously only
        # `rate_div` had a fallback, so calling with the signature defaults
        # failed on `sample_point - None`.
        if mu_div is None:
            mu_div = 0.0
        if rate_div is None:
            rate_div = 1.0

    regions_data = []
    exponents = []

    # Walk over adjacent pairs of kink points
    for x_lo, x_hi in zip(kink_points[:-1], kink_points[1:]):
        # Skip degenerate intervals
        if x_lo == x_hi:
            continue

        # The signs are constant inside the interval, so one sample suffices.
        sample_point = _choose_interval_sample(x_lo, x_hi)
        sign_base = np.sign(sample_point - mu_base)
        sign_mix = np.sign(sample_point - mu2)

        # Contributions of the two numerator Laplace factors
        kappa = -rate_base * sign_base - rate_mix * sign_mix
        d = rate_base * sign_base * mu_base + rate_mix * sign_mix * mu2

        if divide_unit_laplace:
            # The divisor enters with the opposite sign in both slope and offset.
            sign_div = np.sign(sample_point - mu_div)
            kappa = kappa + rate_div * sign_div
            d = d - rate_div * sign_div * mu_div

        # Check for divergence in infinite tails
        _check_tail_convergence(x_lo, x_hi, kappa)

        regions_data.append((x_lo, x_hi, kappa, d))

        # Collect exponents at finite endpoints for max-shift
        if np.isfinite(x_lo):
            exponents.append(d + kappa * x_lo)
        if np.isfinite(x_hi):
            exponents.append(d + kappa * x_hi)

    return regions_data, exponents

def _choose_interval_sample(x_lo, x_hi):
    """
    Choose a representative sample point within an interval.
    
    Parameters
    ----------
    x_lo, x_hi : float
        Lower and upper bounds of the interval.
        
    Returns
    -------
    float
        A sample point within the interval.
    """
    if np.isneginf(x_lo):
        return x_hi - 1.0  # Sample point 1 unit below upper bound
    elif np.isposinf(x_hi):
        return x_lo + 1.0  # Sample point 1 unit above lower bound
    else:
        return 0.5 * (x_lo + x_hi)  # Midpoint of interval

def _check_tail_convergence(x_lo, x_hi, kappa):
    """
    Check for convergence of the density in infinite tails.
    
    Parameters
    ----------
    x_lo, x_hi : float
        Lower and upper bounds of the interval.
    kappa : float
        Slope of the log-density in the region.
        
    Raises
    ------
    ValueError
        If the density diverges in an infinite tail.
    """
    if np.isneginf(x_lo) and kappa <= 0:
        raise ValueError("Divergent lower tail in univariate moment")
    if np.isposinf(x_hi) and kappa >= 0:
        raise ValueError("Divergent upper tail in univariate moment")

def _compute_stable_moments(regions_data, max_exponent, moment='mean'):
    """
    Sum the per-region integrals into Z and M using a max-shift.

    Every region offset is reduced by `max_exponent` before integration and
    the accumulated sums are multiplied by ``exp(max_exponent)`` at the end.

    Parameters
    ----------
    regions_data : list
        List of tuples (x_lo, x_hi, kappa, d) for each region.
    max_exponent : float
        Shift used for numerical stabilization.
    moment : str, optional
        Moment type forwarded to the region integrators.

    Returns
    -------
    tuple
        (Z, M) containing normalization constant and first moment.
    """
    norm_acc = 0.0
    moment_acc = 0.0

    for lower, upper, slope, offset in regions_data:
        shifted_offset = offset - max_exponent

        # A (numerically) zero slope means the log-density is flat here.
        is_flat = abs(slope) < 1e-12
        if is_flat:
            part_Z, part_M = _integrate_constant_region(
                lower, upper, shifted_offset, moment
            )
        else:
            part_Z, part_M = _integrate_exponential_region(
                lower, upper, slope, shifted_offset, moment
            )

        norm_acc += part_Z
        moment_acc += part_M

    # Undo the max-shift
    rescale = np.exp(max_exponent)
    return norm_acc * rescale, moment_acc * rescale

def _integrate_constant_region(x_lo, x_hi, d_shifted, moment='mean'):
    """
    Compute integral contributions for a constant log-density region.
    
    Parameters
    ----------
    x_lo, x_hi : float
        Lower and upper bounds of the region.
    d_shifted : float
        Shifted offset parameter.
    moment : str, optional
        'mean' for the mean (E[X]), 'abs' for the expected value of |X|.
        
    Returns
    -------
    tuple
        (Z_i, M_i) containing:
        - Z_i: Contribution to normalization constant
        - M_i: Contribution to first moment
    """
    # Constant term e^(d_shifted)
    val = np.exp(d_shifted)
    
    # For constant function f(x) = val, the integral is val * (x_hi - x_lo)
    Z_i = val * (x_hi - x_lo)
    
    # For x * f(x) = x * val, the integral is val * (x_hi^2 - x_lo^2) / 2
    M_i = val * (x_hi**2 - x_lo**2) / 2.0

    if moment == 'abs':
        # Determine the sign of x in this region (cannot straddle 0)
        assert not (x_lo < 0 < x_hi), "Region straddles zero in constant region integration"
        sign = -1 if x_hi <= 0 else 1  # region fully negative or non-negative
        M_i *= sign
    elif moment == 'relu':
        # For ReLU, the integrand is zero for x<0. If the entire region is non-positive
        # then the first moment contribution is zero. If the region is non-negative we
        # keep the computed value. Regions should never straddle zero because 0 is a
        # kink point that we explicitly added.
        if x_hi <= 0:
            M_i = 0.0
        elif x_lo >= 0:
            pass  # keep computed M_i
        else:
            assert False, "Region straddles zero in constant region integration"
    
    return Z_i, M_i

def _integrate_exponential_region(x_lo, x_hi, kappa, d_shifted, moment='mean'):
    """
    Compute integral contributions for an exponential log-density region.
    
    Parameters
    ----------
    x_lo, x_hi : float
        Lower and upper bounds of the region.
    kappa : float
        Slope of the log-density.
    d_shifted : float
        Shifted offset parameter.
    moment : str, optional
        'mean' for the mean (E[X]), 'abs' for the expected value of |X|.
        
    Returns
    -------
    tuple
        (Z_i, M_i) containing:
        - Z_i: Contribution to normalization constant
        - M_i: Contribution to first moment
    """
    # Handle the lower bound
    if np.isneginf(x_lo):
        # For -infinity, the function values are 0
        e_lo = 0.0
        xe_lo = 0.0
    else:
        # Calculate e^(d + kappa * x_lo - max_exp)
        e_lo = np.exp(d_shifted + kappa * x_lo)
        xe_lo = x_lo * e_lo
    
    # Handle the upper bound
    if np.isposinf(x_hi):
        # For +infinity, the function values are 0
        e_hi = 0.0
        xe_hi = 0.0
    else:
        # Calculate e^(d + kappa * x_hi - max_exp)
        e_hi = np.exp(d_shifted + kappa * x_hi)
        xe_hi = x_hi * e_hi
    
    # Compute integral for f(x) = e^(d + kappa * x - max_exp)
    # ∫ f(x) dx = [e^(d + kappa * x - max_exp) / kappa]_x_lo^x_hi
    Z_i = (e_hi - e_lo) / kappa
    
    # Compute integral for x * f(x) = x * e^(d + kappa * x - max_exp)
    # ∫ x * f(x) dx = [x * e^(d + kappa * x) / kappa - e^(d + kappa * x) / kappa^2]_x_lo^x_hi
    M_i = (xe_hi - xe_lo) / kappa - (e_hi - e_lo) / (kappa**2)

    if moment == 'abs':
        # Ensure region does not straddle zero
        assert not (x_lo < 0 < x_hi), "Region straddles zero in exponential region integration"
        sign = -1 if x_hi <= 0 else 1
        M_i *= sign
    elif moment == 'relu':
        # For ReLU, contribution is zero for x<0.
        assert not (x_lo < 0 < x_hi), "Region straddles zero in exponential region integration"
        if x_hi <= 0:
            M_i = 0.0
        # If region is non-negative, keep M_i as-is.
    
    return Z_i, M_i

def laplace_fusion_mean_broadcast(
    N,
    mu_base_1d,
    sigma_base_1d,
    mu2_m,
    sigma2_m,
    return_component_norms=False,
    moment="mean",
    n_jobs=None,
    mu_div_1d=None,
    sigma_div_1d=None,
    threshold: float = 0.0,
    component_weights_1d=None,
    natural_prior_weight=None,
):
    """
    Expand compact parameter specifications and delegate to laplace_fusion_mean.

    Per-dimension inputs are expanded to length ``N`` and mixture inputs to
    shape ``(M, N)``; the expanded arrays are then forwarded to
    ``laplace_fusion_mean`` with ``divide_unit_laplace=True``.

    Expansion rule for vectors (base, divisor and weight arguments):
      - 0-d (scalar) input: constant vector filled with that value.
      - 1-D input shorter than the target: the pattern is repeated cyclically
        and cut to the target length.
      - 1-D input at least as long as the target: only the leading entries
        are kept.

    Parameters
    ----------
    N : int
        Target number of dimensions.
    mu_base_1d, sigma_base_1d : float or array-like
        Location / scale of the base distribution, expanded to length N.
    mu2_m, sigma2_m : array-like
        Location / scale of the mixture components. A 1-D array of length M
        gives every component the same value in all N dimensions. A 2-D array
        of shape (M, K) is tiled along the second axis when K < N and cut to
        its first N columns otherwise.
    return_component_norms : bool, optional
        Forwarded unchanged to laplace_fusion_mean.
    moment : str, optional
        Forwarded unchanged to laplace_fusion_mean.
    n_jobs : int or None, optional
        Forwarded unchanged to laplace_fusion_mean.
    mu_div_1d, sigma_div_1d : float, array-like or None, optional
        Location / scale of the divisor distribution. When both are None,
        ``None`` is forwarded for both. When only one is given, the other
        defaults to 0.0 (location) or 1.0 (scale) before expansion to length N.
    threshold : float, optional
        Forwarded unchanged to laplace_fusion_mean.
    component_weights_1d : float, array-like or None, optional
        Per-component weights, expanded to length M with the vector rule
        above. None is forwarded as None.
    natural_prior_weight : float or None, optional
        Forwarded unchanged to laplace_fusion_mean.

    Returns
    -------
    object
        Whatever laplace_fusion_mean returns for the expanded arguments
        (according to this wrapper's earlier documentation: the mean vector
        of shape (N,), or a tuple that also carries component norms when
        ``return_component_norms`` is True).

    Raises
    ------
    ValueError
        If ``mu2_m`` or ``sigma2_m`` is neither 1-D nor 2-D, or if they
        describe a different number of components.

    Examples
    --------
    # 10-D problem described by scalars and three mixture components
    result = laplace_fusion_mean_broadcast(
        N=10,
        mu_base_1d=0.0,
        sigma_base_1d=0.5,
        mu2_m=[1.0, 2.0, 3.0],
        sigma2_m=[0.6, 0.7, 0.8],
    )
    """

    def _fit_vector(values, length):
        # Apply the scalar / cyclic-repeat / truncate rule described above.
        if values.ndim == 0:
            return np.full(length, values)
        available = len(values)
        if available >= length:
            return values[:length]
        copies = int(np.ceil(length / available))
        return np.tile(values, copies)[:length]

    def _fit_components(params):
        # Build the (M, N) matrix; None signals an unsupported rank so the
        # caller can raise with the argument-specific message.
        if params.ndim == 1:
            column = params.reshape(len(params), 1)
            return np.tile(column, (1, N))
        if params.ndim != 2:
            return None
        width = params.shape[1]
        if width >= N:
            return params[:, :N]
        copies = int(np.ceil(N / width))
        return np.tile(params, (1, copies))[:, :N]

    # Normalise the four mandatory inputs to arrays first.
    mu_base_1d = np.asarray(mu_base_1d)
    sigma_base_1d = np.asarray(sigma_base_1d)
    mu2_m = np.asarray(mu2_m)
    sigma2_m = np.asarray(sigma2_m)

    # Base distribution: one value per dimension.
    mu_base = _fit_vector(mu_base_1d, N)
    sigma_base = _fit_vector(sigma_base_1d, N)

    # Mixture components: one row per component, one column per dimension.
    mu2_mat = _fit_components(mu2_m)
    if mu2_mat is None:
        raise ValueError(f"mu2_m must be 1D or 2D, got shape {mu2_m.shape}")

    sigma2_mat = _fit_components(sigma2_m)
    if sigma2_mat is None:
        raise ValueError(f"sigma2_m must be 1D or 2D, got shape {sigma2_m.shape}")

    if mu2_mat.shape[0] != sigma2_mat.shape[0]:
        raise ValueError(f"mu2_m and sigma2_m must have the same number of components, "
                         f"got {mu2_mat.shape[0]} and {sigma2_mat.shape[0]}")
    M = sigma2_mat.shape[0]

    # Divisor distribution: left as None unless at least one part was given.
    if mu_div_1d is None and sigma_div_1d is None:
        mu_div = None
        sigma_div = None
    else:
        mu_div_1d = np.asarray(0.0 if mu_div_1d is None else mu_div_1d)
        sigma_div_1d = np.asarray(1.0 if sigma_div_1d is None else sigma_div_1d)
        mu_div = _fit_vector(mu_div_1d, N)
        sigma_div = _fit_vector(sigma_div_1d, N)

    # Component weights follow the vector rule with M as the target length.
    if component_weights_1d is None:
        component_weights = None
    else:
        component_weights_1d = np.asarray(component_weights_1d)
        component_weights = _fit_vector(component_weights_1d, M)

    return laplace_fusion_mean(
        mu_base,
        sigma_base,
        mu2_mat,
        sigma2_mat,
        return_component_norms,
        divide_unit_laplace=True,
        moment=moment,
        n_jobs=n_jobs,
        mu_div=mu_div,
        sigma_div=sigma_div,
        threshold=threshold,
        component_weights=component_weights,
        natural_prior_weight=natural_prior_weight,
    )

def numerical_1D_fusion_mean(
    mu1,
    mu2,
    sigma1,
    sigma2,
    mu_div=0.0,
    sigma_div=1.0,
    moment="mean",
    debug=False,
):
    """
    Reference value of a 1-D fusion moment obtained by piecewise quadrature.

    The unnormalised density is

        exp(-|x - mu1| / sigma1 - |x - mu2| / sigma2 + |x - mu_div| / sigma_div)

    and is integrated with ``scipy.integrate.quad`` over a finite window that
    is split at the kinks of the density and at multiples of the scales around
    them. This is slower and less accurate than a closed form and is meant
    for verification only.

    Parameters
    ----------
    mu1, sigma1 : float
        Location and scale of the first Laplace factor.
    mu2, sigma2 : float
        Location and scale of the second Laplace factor.
    mu_div, sigma_div : float, optional
        Location and scale of the Laplace factor that is divided out.
    moment : {'mean', 'abs', 'relu'}, optional
        Quantity to average: x, |x| or max(x, 0).
    debug : bool, optional
        Print per-interval contributions and error estimates.

    Returns
    -------
    float
        Ratio of the integrated moment to the integrated density.

    Raises
    ------
    ValueError
        If ``moment`` is unknown, or if the accumulated normalisation
        constant is smaller than 1e-12 in magnitude.
    """
    def log_unnorm_pdf(x):
        term1 = -np.abs(x - mu1) / sigma1
        term2 = -np.abs(x - mu2) / sigma2
        term3 = np.abs(x - mu_div) / sigma_div
        return term1 + term2 + term3

    def unnorm_pdf(x):
        log_pdf = log_unnorm_pdf(x)
        # Below this exponent the value is negligible; return an exact zero.
        if log_pdf < -700:
            return 0.0
        return np.exp(log_pdf)

    if moment == "mean":
        moment_label = "E[X]"

        def weighted_x(x):
            return x * unnorm_pdf(x)
    elif moment == "abs":
        moment_label = "E[|X|]"

        def weighted_x(x):
            return np.abs(x) * unnorm_pdf(x)
    elif moment == "relu":
        moment_label = "E[ReLU(X)]"

        def weighted_x(x):
            return 0.0 if x < 0 else x * unnorm_pdf(x)
    else:
        raise ValueError(f"Unknown moment {moment}")

    max_scale = max(sigma1, sigma2, sigma_div)
    min_scale = min(sigma1, sigma2, sigma_div)

    # Points of interest: the three kinks plus multiples of both numerator
    # scales around both numerator locations.
    key_points = [mu1, mu2, mu_div]
    for mean in [mu1, mu2]:
        for mult in [1, 3, 5, 10]:
            key_points.append(mean - mult * sigma1)
            key_points.append(mean + mult * sigma1)
            key_points.append(mean - mult * sigma2)
            key_points.append(mean + mult * sigma2)
    key_points = sorted(set(key_points))

    # Integration window: generous margin beyond the outermost key point.
    outer_lower = min(key_points) - 20 * max_scale
    outer_upper = max(key_points) + 20 * max_scale

    if abs(mu1 - mu2) > 50 * (sigma1 + sigma2):
        # Widely separated locations: a coarse grid around each centre.
        raw_bounds = [
            outer_lower,
            mu1 - 10 * sigma1,
            mu1 - 5 * sigma1,
            mu1,
            mu1 + 5 * sigma1,
            mu1 + 10 * sigma1,
            mu_div - 5 * min_scale,
            mu_div,
            mu_div + 5 * min_scale,
            mu2 - 10 * sigma2,
            mu2 - 5 * sigma2,
            mu2,
            mu2 + 5 * sigma2,
            mu2 + 10 * sigma2,
            outer_upper,
        ]
        # Rounding merges break points that differ only by float noise.
        bounds = sorted({round(b, 6) for b in raw_bounds})
    else:
        # Overlapping locations: a fine grid around every key point.
        raw_bounds = []
        for point in key_points:
            raw_bounds.extend([
                point - 5 * min_scale,
                point - min_scale,
                point,
                point + min_scale,
                point + 5 * min_scale,
            ])
        # The outer limits are kept unrounded, as before.
        bounds = sorted({outer_lower, outer_upper} | {round(b, 6) for b in raw_bounds})

    Z = 0
    N = 0
    errors_Z = 0
    errors_N = 0
    for a, b in zip(bounds, bounds[1:]):
        if a == b:
            continue
        try:
            Z_i, Z_err = quad(unnorm_pdf, a, b, epsabs=1e-12, epsrel=1e-12, limit=100)
            N_i, N_err = quad(weighted_x, a, b, epsabs=1e-12, epsrel=1e-12, limit=100)
            Z += Z_i
            N += N_i
            errors_Z += Z_err
            errors_N += N_err
            if debug and (Z_i > 1e-10 or abs(N_i) > 1e-10):
                print(f"  Interval [{a:.2f}, {b:.2f}]: Z_i={Z_i:.6e}, N_i={N_i:.6e}")
        except Exception as e:
            # Best effort: an interval that fails to integrate is skipped so
            # that the remaining intervals still contribute.
            if debug:
                print(f"  Integration error in [{a:.2f}, {b:.2f}]: {e}")
            continue

    if abs(Z) < 1e-12:
        raise ValueError(f"Normalization constant too small for stable calculation: {Z:.6e}")
    result = N / Z

    if debug:
        print(f"Numerical integration: Z={Z:.6e} (±{errors_Z:.6e}), N={N:.6e} (±{errors_N:.6e})")
        # Label the printed value with the moment that was actually computed.
        print(f"Expected value: {moment_label}={result:.6f} ≈ {N:.6e}/{Z:.6e}")
        theoretical_error = (abs(result) * errors_Z/Z + errors_N/Z)
        print(f"Estimated numerical error: ±{theoretical_error:.6e}")
    return result


# -----------------------------------------------------------------------------
# 2-D numerical integration utility (for test validation purposes).
# -----------------------------------------------------------------------------


def numerical_2D_fusion_mean(
    mu_base,
    sigma_base,
    mu2_mat,
    sigma2_mat,
    mu_div,
    sigma_div,
    debug=False,
):
    """Estimate the 2-D fused mean by numerical double integration.

    The integrand is ``base * (sum of M components) / divisor`` where every
    factor is a product of two 1-D Laplace densities (one per coordinate),
    evaluated in log space. Three integrals over the whole plane are computed
    with ``scipy.integrate.nquad``: the normalisation constant and one first
    moment per coordinate.

    Parameters
    ----------
    mu_base, sigma_base : array-like, shape (2,)
        Location and scale of the base factor.
    mu2_mat, sigma2_mat : array-like, shape (M, 2)
        Locations and scales of the M mixture components.
    mu_div, sigma_div : array-like, shape (2,)
        Location and scale of the divisor factor.
    debug : bool, optional
        Print the normalisation constant and its error estimate.

    Returns
    -------
    np.ndarray, shape (2,)
        First moments divided by the normalisation constant.
    """

    import math
    from scipy.integrate import nquad

    mu_base, sigma_base = np.asarray(mu_base), np.asarray(sigma_base)
    mu2_mat, sigma2_mat = np.asarray(mu2_mat), np.asarray(sigma2_mat)
    mu_div, sigma_div = np.asarray(mu_div), np.asarray(sigma_div)

    assert mu_base.shape == (2,) and sigma_base.shape == (2,), "mu_base/sigma_base must be length-2 vectors"
    assert mu_div.shape == (2,) and sigma_div.shape == (2,), "mu_div/sigma_div must be length-2 vectors"
    M = mu2_mat.shape[0]
    assert mu2_mat.shape == sigma2_mat.shape == (M, 2), "mu2_mat/sigma2_mat must be (M,2)"

    def log_laplace(x, mu, sigma):
        # Log of a normalised 1-D Laplace density.
        return -math.log(2.0 * sigma) - abs(x - mu) / sigma

    def log_pair(x1, x2, mu, sigma):
        # Log of the product density over both coordinates.
        return log_laplace(x1, mu[0], sigma[0]) + log_laplace(x2, mu[1], sigma[1])

    def log_unnorm_pdf(x1, x2):
        base_part = log_pair(x1, x2, mu_base, sigma_base)

        # Log-sum-exp over the mixture components.
        comp_logs = np.array(
            [log_pair(x1, x2, mu2_mat[j], sigma2_mat[j]) for j in range(M)]
        )
        peak = np.max(comp_logs)
        mix_part = peak + np.log(np.sum(np.exp(comp_logs - peak)))

        div_part = log_pair(x1, x2, mu_div, sigma_div)
        return base_part + mix_part - div_part

    def unnorm_pdf(x1, x2):
        log_val = log_unnorm_pdf(x1, x2)
        # Negligible values are returned as an exact zero.
        return 0.0 if log_val < -700 else math.exp(log_val)

    opts = {"epsabs": 1e-5, "epsrel": 1e-5}
    whole_plane = [(-np.inf, np.inf), (-np.inf, np.inf)]

    def integrate(integrand):
        # The integrand's first positional argument is the x2 coordinate and
        # its second is x1, matching the order of ``whole_plane``.
        return nquad(integrand, whole_plane, opts=[opts, opts])

    # Normalisation constant, then one first moment per coordinate.
    Z, Z_err = integrate(lambda x2, x1: unnorm_pdf(x1, x2))
    N1, _ = integrate(lambda x2, x1: x1 * unnorm_pdf(x1, x2))
    N2, _ = integrate(lambda x2, x1: x2 * unnorm_pdf(x1, x2))

    if debug:
        print(f"2D numerical integration: Z={Z:.6e} (err≈{Z_err:.1e})")

    return np.array([N1 / Z, N2 / Z])

def importance_sampling_fusion(mu_base, sigma_base, mu2_mat, sigma2_mat, Nsamples=500000):
    """
    Monte Carlo estimate of the fused mean via importance sampling.

    Samples are drawn from the base Laplace distribution with a fixed seed
    (12345), so repeated calls with the same arguments return the same value.
    Each sample is weighted by the unnormalised mixture sum divided by an
    unnormalised unit Laplace, ``sum_j exp(-sum_d |x_d - mu_jd| / sigma_jd)
    / exp(-sum_d |x_d|)``, evaluated in log space with log-sum-exp.

    The work is vectorised over samples: the only Python loop runs over the
    M mixture components, and no (M, Nsamples, N) array is materialised.

    Parameters
    ----------
    mu_base, sigma_base : array-like, shape (N,)
        Location and scale of the base Laplace distribution (the proposal).
    mu2_mat, sigma2_mat : array-like, shape (M, N)
        Locations and scales of the mixture components.
    Nsamples : int, optional
        Number of samples to draw.

    Returns
    -------
    np.ndarray, shape (N,)
        Self-normalised importance-sampling estimate of the mean.

    Raises
    ------
    ValueError
        If there are no mixture components (or no samples to reduce over).
    """
    mu_base = np.asarray(mu_base)
    sigma_base = np.asarray(sigma_base)
    mu2_mat = np.asarray(mu2_mat)
    sigma2_mat = np.asarray(sigma2_mat)
    N = len(mu_base)
    M = mu2_mat.shape[0]
    if M == 0:
        raise ValueError("mu2_mat must contain at least one mixture component")

    rng = np.random.default_rng(12345)
    samples = rng.laplace(loc=mu_base, scale=sigma_base, size=(Nsamples, N))

    # Row j holds the log of component j's unnormalised density at each sample.
    comp_logs = np.empty((M, Nsamples))
    for j in range(M):
        comp_logs[j] = -np.sum(np.abs(samples - mu2_mat[j]) / sigma2_mat[j], axis=1)

    # Log-sum-exp across components, one value per sample.
    max_log = comp_logs.max(axis=0)
    log_mix = max_log + np.log(np.exp(comp_logs - max_log).sum(axis=0))

    # Unnormalised unit-Laplace divisor.
    log_den = -np.abs(samples).sum(axis=1)
    logw = log_mix - log_den

    # Shift by the maximum before exponentiating; the shift cancels in the ratio.
    w = np.exp(logw - logw.max())
    weighted = samples * w[:, None]
    return weighted.sum(axis=0) / w.sum()

def laplace_moment(dist, moment='mean', threshold=0.0):
    """
    Compute moments of a (possibly batched) PyTorch Laplace distribution.

    Supported moments:
        - 'mean':     E[X]
        - 'abs':      E[|X|]
        - 'relu':     E[ReLU(X - x0)] with optional ``threshold`` = x0 (default 0.0)

    The closed-form for the ReLU moment of a 1-D Laplace(µ, σ) random variable is::

        if x0 < µ:   (µ - x0) + (σ / 2) * exp(-(µ - x0)/σ)
        else:        (σ / 2) * exp(-(x0 - µ)/σ)

    Both branches are evaluated element-wise and combined with
    ``torch.where``. The exponent of each branch is replaced by zero on the
    elements where that branch is not selected, so the unselected branch can
    never overflow to ``inf``; without this, back-propagation multiplies a
    zero gradient by ``inf`` and produces NaN gradients for ``loc``/``scale``.

    Args:
        dist (torch.distributions.Laplace): Laplace distribution (can be batched)
        moment (str): 'mean', 'abs', or 'relu'
        threshold (float or torch.Tensor): x0 value for ReLU; ignored otherwise.
            Must be expandable to the shape of ``dist.loc``.

    Returns
    -------
        torch.Tensor: Requested moment, same shape as ``dist.loc``.

    Raises
    ------
        ValueError: If ``moment`` is not one of the supported names.
    """

    mu = dist.loc
    sigma = dist.scale

    if moment == 'mean':
        return mu

    if moment == 'abs':
        abs_mu = torch.abs(mu)
        return abs_mu + sigma * torch.exp(-abs_mu / sigma)

    if moment == 'relu':
        # Allow threshold to be tensor-broadcastable with mu
        x0 = torch.as_tensor(threshold, dtype=mu.dtype, device=mu.device)
        if x0.shape != mu.shape:
            x0 = x0.expand_as(mu)

        diff = x0 - mu  # negative where x0 < mu
        below_mask = diff < 0

        # Feed each branch only the values it is selected for (zero elsewhere):
        # the exponents are then always <= 0, so exp() stays finite in both
        # the forward and the backward pass.
        zero = torch.zeros_like(diff)
        diff_below = torch.where(below_mask, diff, zero)
        diff_above = torch.where(below_mask, zero, diff)

        # Case 1: x0 < mu
        case1 = (-diff_below) + 0.5 * sigma * torch.exp(diff_below / sigma)
        # Case 2: x0 >= mu
        case2 = 0.5 * sigma * torch.exp(-diff_above / sigma)

        return torch.where(below_mask, case1, case2)

    raise ValueError(f"Unknown moment: {moment}")

def _process_dimension_parallel(args):
    """
    Worker that evaluates every mixture component for a single dimension.

    Parameters
    ----------
    args : tuple
        Twelve-element tuple, in this order:
        (i, mu_base_i, sigma_base_i, mu2_mat_i, sigma2_mat_i, mu_div_i,
        sigma_div_i, divide_unit_laplace, moment, threshold,
        component_weights_arr, natural_prior_weight).
        ``i`` is the dimension index and is returned unchanged;
        ``mu2_mat_i`` / ``sigma2_mat_i`` are indexed per component;
        ``component_weights_arr`` is either None (a weight of 1.0 is used for
        every component) or indexable per component. All remaining entries
        are forwarded to ``univariate_laplace_mixture_moment``.

    Returns
    -------
    tuple
        (i, log_Zs_i, Ms_i) where:
        - i: dimension index
        - log_Zs_i: array of length M with the log of the first value returned
          by ``univariate_laplace_mixture_moment`` for each component
          (-inf when that value is not positive)
        - Ms_i: array of length M with the second returned value per component
    """

    (
        i,
        mu_base_i,
        sigma_base_i,
        mu2_mat_i,
        sigma2_mat_i,
        mu_div_i,
        sigma_div_i,
        divide_unit_laplace,
        moment,
        threshold,
        component_weights_arr,
        natural_prior_weight,
    ) = args

    M = len(mu2_mat_i)
    log_Zs_i = np.zeros(M)
    Ms_i = np.zeros(M)

    for j in range(M):
        norm_j, moment_j = univariate_laplace_mixture_moment(
            mu_base_i,
            sigma_base_i,
            mu2_mat_i[j],
            sigma2_mat_i[j],
            divide_unit_laplace,
            moment,
            mu_div_i,
            sigma_div_i,
            threshold,
            1.0 if component_weights_arr is None else component_weights_arr[j],
            natural_prior_weight,
        )

        # Keep the normaliser in log space so that combining dimensions later
        # is a sum rather than an overflow-prone product.
        if norm_j > 0:
            log_Zs_i[j] = np.log(norm_j)
        else:
            log_Zs_i[j] = -np.inf
        Ms_i[j] = moment_j

    return i, log_Zs_i, Ms_i

# -----------------------------------------------------------------------------
# Additional numerical helper for validating the ReLU threshold option
# -----------------------------------------------------------------------------


def numerical_1D_fusion_relu_threshold(
    mu1,
    mu2,
    sigma1,
    sigma2,
    x0,
    mu_div=0.0,
    sigma_div=1.0,
    debug=False,
):
    """Quadrature reference for E[ReLU(X - x0)] under the 1-D fusion density.

    The unnormalised density is
    ``exp(-|x - mu1| / sigma1 - |x - mu2| / sigma2 + |x - mu_div| / sigma_div)``.
    Both the density and ``max(0, x - x0)`` times the density are integrated
    over the whole real line with ``scipy.integrate.quad``; no change of
    variables is applied. Intended for testing the ``threshold`` option only.

    ``debug`` is accepted for signature compatibility and is not used.

    Returns the ratio of the two integrals and raises ``ValueError`` when the
    normalisation integral is exactly zero.
    """

    def log_unnorm_pdf(x):
        first = -np.abs(x - mu1) / sigma1
        second = -np.abs(x - mu2) / sigma2
        divisor = np.abs(x - mu_div) / sigma_div
        return first + second + divisor

    def unnorm_pdf(x):
        log_pdf = log_unnorm_pdf(x)
        # Negligible values are returned as an exact zero.
        return 0.0 if log_pdf < -700 else np.exp(log_pdf)

    def weighted_relu(x):
        excess = x - x0
        relu = excess if excess > 0.0 else 0.0
        return relu * unnorm_pdf(x)

    whole_line = (-np.inf, np.inf)
    Z, _ = quad(unnorm_pdf, *whole_line, epsabs=1e-10, epsrel=1e-10)
    N, _ = quad(weighted_relu, *whole_line, epsabs=1e-10, epsrel=1e-10)

    if Z == 0:
        raise ValueError("Normalization constant is zero in numerical integration")
    return N / Z


def numerical_1D_fusion_relu_threshold_multi_component(
    mu1,
    mu2_array,
    sigma1,
    sigma2_array,
    x0,
    mu_div=0.0,
    sigma_div=1.0,
    debug=False,
):
    """Quadrature reference for E[ReLU(X - x0)] with several mixture components.

    The unnormalised 1-D density is

        exp(-|x - mu1| / sigma1) * sum_j exp(-|x - mu2_j| / sigma2_j)
            * exp(|x - mu_div| / sigma_div)

    evaluated in log space. Both the density and ``max(0, x - x0)`` times the
    density are integrated over the whole real line with
    ``scipy.integrate.quad``, without a change of variables. Intended for
    testing the ``threshold`` option with multiple components only.

    Parameters
    ----------
    mu1, sigma1 : float
        Location and scale of the base Laplace factor.
    mu2_array, sigma2_array : array-like
        Locations and scales of the mixture components (paired element-wise).
    x0 : float
        ReLU threshold.
    mu_div, sigma_div : float, optional
        Location and scale of the divisor Laplace factor (defaults 0.0, 1.0).
    debug : bool, optional
        Print both integrals and their ratio.

    Returns
    -------
    float
        Ratio of the weighted integral to the normalisation integral.

    Raises
    ------
    ValueError
        If the normalisation integral is exactly zero.
    """

    def log_unnorm_pdf(x):
        base_term = -np.abs(x - mu1) / sigma1

        # Log-sum-exp across the mixture components.
        comp_logs = []
        for mu2, sigma2 in zip(mu2_array, sigma2_array):
            comp_logs.append(-np.abs(x - mu2) / sigma2)
        comp_logs = np.array(comp_logs)
        peak = np.max(comp_logs)
        mix_term = peak + np.log(np.sum(np.exp(comp_logs - peak)))

        div_term = np.abs(x - mu_div) / sigma_div
        return base_term + mix_term + div_term

    def unnorm_pdf(x):
        log_pdf = log_unnorm_pdf(x)
        # Negligible values are returned as an exact zero.
        return 0.0 if log_pdf < -700 else np.exp(log_pdf)

    def weighted_relu(x):
        excess = x - x0
        relu = excess if excess > 0.0 else 0.0
        return relu * unnorm_pdf(x)

    whole_line = (-np.inf, np.inf)
    Z, _ = quad(unnorm_pdf, *whole_line, epsabs=1e-10, epsrel=1e-10)
    N, _ = quad(weighted_relu, *whole_line, epsabs=1e-10, epsrel=1e-10)

    if Z == 0:
        raise ValueError("Normalization constant is zero in numerical integration")

    if debug:
        print(f"Numerical integration: Z={Z:.6e}, N={N:.6e}, result={N/Z:.6f}")

    return N / Z

def numerical_1D_fusion_mean_with_weights(
    mu1,
    mu2_array,
    sigma1,
    sigma2_array,
    weights,
    mu_div=0.0,
    sigma_div=1.0,
    moment="mean",
    debug=False,
):
    """Quadrature reference for a 1-D fusion moment with weighted components.

    The unnormalised density is

        exp(-|x - mu1| / sigma1) * sum_j w_j * exp(-|x - mu2_j| / sigma2_j)
            * exp(|x - mu_div| / sigma_div)

    evaluated in log space and integrated over the whole real line with
    ``scipy.integrate.quad``. Intended for testing the component-weights
    option only.

    Components with a weight of exactly zero contribute nothing and are
    dropped before integration, so no ``log(0)`` is ever evaluated. The
    log-weights are computed once instead of on every integrand evaluation.

    Parameters
    ----------
    mu1, sigma1 : float
        Location and scale of the base Laplace factor.
    mu2_array, sigma2_array : array-like
        Locations and scales of the mixture components.
    weights : array-like
        Non-negative weight of each component (paired element-wise with
        ``mu2_array`` / ``sigma2_array``). Only relative magnitudes matter.
    mu_div, sigma_div : float, optional
        Location and scale of the divisor Laplace factor (defaults 0.0, 1.0).
    moment : {'mean', 'abs', 'relu'}, optional
        Quantity to average: x, |x| or max(x, 0).
    debug : bool, optional
        Print both integrals and their ratio.

    Returns
    -------
    float
        E[X], E[|X|] or E[ReLU(X)] depending on ``moment``.

    Raises
    ------
    ValueError
        If a weight is negative or NaN, if no component has a positive
        weight, if ``moment`` is unknown, or if the normalisation integral is
        exactly zero.
    """
    # Validate the weights and keep only components that actually contribute.
    kept_weights, kept_mus, kept_sigmas = [], [], []
    for w, mu2, sigma2 in zip(weights, mu2_array, sigma2_array):
        if not w >= 0:
            raise ValueError(f"Component weights must be non-negative, got {w}")
        if w > 0:
            kept_weights.append(w)
            kept_mus.append(mu2)
            kept_sigmas.append(sigma2)
    if not kept_weights:
        raise ValueError("At least one mixture component must have a positive weight")

    log_w = np.log(np.asarray(kept_weights, dtype=float))
    comp_mu = np.asarray(kept_mus, dtype=float)
    comp_sigma = np.asarray(kept_sigmas, dtype=float)

    def log_unnorm_pdf(x):
        # Base distribution term
        term1 = -np.abs(x - mu1) / sigma1

        # Weighted mixture term (log-sum-exp over components)
        log_comps = log_w - np.abs(x - comp_mu) / comp_sigma
        max_log_comp = np.max(log_comps)
        term2 = max_log_comp + np.log(np.sum(np.exp(log_comps - max_log_comp)))

        # Denominator term
        term3 = np.abs(x - mu_div) / sigma_div

        return term1 + term2 + term3

    def unnorm_pdf(x):
        log_pdf = log_unnorm_pdf(x)
        # Negligible values are returned as an exact zero.
        if log_pdf < -700:
            return 0.0
        return np.exp(log_pdf)

    if moment == "mean":
        def weighted_x(x):
            return x * unnorm_pdf(x)
    elif moment == "abs":
        def weighted_x(x):
            return np.abs(x) * unnorm_pdf(x)
    elif moment == "relu":
        def weighted_x(x):
            return max(0.0, x) * unnorm_pdf(x)
    else:
        raise ValueError(f"Unknown moment {moment}")

    Z, _ = quad(unnorm_pdf, -np.inf, np.inf, epsabs=1e-10, epsrel=1e-10)
    N, _ = quad(weighted_x, -np.inf, np.inf, epsabs=1e-10, epsrel=1e-10)

    if Z == 0:
        raise ValueError("Normalization constant is zero in numerical integration")

    if debug:
        print(f"Numerical integration: Z={Z:.6e}, N={N:.6e}, result={N/Z:.6f}")

    return N / Z

# -----------------------------------------------------------------------------
# Gaussian fusion helper (for reference / experimentation only)
# -----------------------------------------------------------------------------

def gaussian_mixture_fusion(
    base_mu,
    base_var,
    comp_mus,
    comp_vars,
):
    """Construct the Gaussian mixture that is proportional to
        g_base(x) * ( sum_i g_i(x) ) / g_div(x),
    where:
        g_div  ~  N(0, I)
        g_base ~  N(base_mu, diag(base_var))
        g_i    ~  N(comp_mus[i], diag(comp_vars[i]))

    The resulting (unnormalised) density is still a Gaussian mixture.  This
    helper returns an explicit ``MixtureSameFamily`` distribution describing
    the normalised mixture **as well as** the per-component means/variances and
    normalised component weights.

    All inputs are converted with ``torch.as_tensor`` to the current default
    floating dtype. No tensor is created on a fixed device inside this
    function; intermediate results stay on the device of the inputs.

    Parameters
    ----------
    base_mu : Tensor[D]
        Mean of the base Gaussian (g_base).
    base_var : Tensor[D]
        Diagonal variances of the base Gaussian.
    comp_mus : Tensor[N, D]
        Means of the mixture components g_i.
    comp_vars : Tensor[N, D]
        Diagonal variances of the mixture components g_i.

    Returns
    -------
    mixture : torch.distributions.MixtureSameFamily
        The normalised Gaussian mixture distribution.
    weights : Tensor[N]
        Normalised mixture weights that sum to 1.
    means_prime : Tensor[N, D]
        Means of the adjusted Gaussians after combining with g_base and the
        divisor distribution.
    var_prime : Tensor[N, D]
        Diagonal variances of the adjusted Gaussians.

    Raises
    ------
    ValueError
        If the input shapes are inconsistent, or if any adjusted precision
        (1/base_var + 1/comp_var - 1) is not strictly positive.
    """

    from torch.distributions import Normal, Independent, Categorical, MixtureSameFamily

    # Ensure tensors and float dtype
    float_dtype = torch.get_default_dtype()
    base_mu = torch.as_tensor(base_mu, dtype=float_dtype)
    base_var = torch.as_tensor(base_var, dtype=float_dtype)
    comp_mus = torch.as_tensor(comp_mus, dtype=float_dtype)
    comp_vars = torch.as_tensor(comp_vars, dtype=float_dtype)

    if base_mu.ndim != 1:
        raise ValueError("base_mu must be 1-D (shape [D])")
    if base_var.shape != base_mu.shape:
        raise ValueError("base_var must have same shape as base_mu")
    if comp_mus.ndim != 2 or comp_mus.shape != comp_vars.shape:
        raise ValueError("comp_mus/comp_vars must be shape [N, D]")

    # Precision (inverse variance) terms
    lambda_base = 1.0 / base_var                    # [D]
    lambda_comps = 1.0 / comp_vars                  # [N, D]

    # Adjusted precisions λ'_i = λ_base + λ_i − λ_div.  The divisor has unit
    # variance, so λ_div = 1 in every dimension; subtracting the scalar avoids
    # allocating a tensor that could land on a different device than the inputs.
    lambda_prime = lambda_base.unsqueeze(0) + lambda_comps - 1.0  # [N, D]

    if torch.any(lambda_prime <= 0):
        raise ValueError("Adjusted precision became non-positive; choose different parameters.")

    # Adjusted variances Σ'_i = 1 / λ'_i
    var_prime = 1.0 / lambda_prime                  # [N, D]

    # Adjusted means μ'_i = Σ'_i (λ_base μ_base + λ_i μ_i); the divisor has
    # zero mean and therefore contributes no linear term.
    weighted_base = lambda_base * base_mu           # [D]
    weighted_comp = lambda_comps * comp_mus         # [N, D]
    numer = weighted_base.unsqueeze(0) + weighted_comp  # [N, D]
    means_prime = var_prime * numer                 # [N, D]

    # ------------------------------------------------------------------
    # Mixture weights (unnormalised log-weights for numerical stability)
    # ------------------------------------------------------------------

    # (|Σ'_i| / (|Σ_base| |Σ_i|))^{1/2} term (diagonal covariances ⇒ product)
    log_det_term = 0.5 * (
        var_prime.log().sum(-1)  # log |Σ'_i|
        - base_var.log().sum()   # −log |Σ_base|
        - comp_vars.log().sum(-1)  # −log |Σ_i|
    )  # [N]

    # Quadratic forms μ^T Λ μ (with diagonal precision)
    quad_base = (base_mu * (lambda_base * base_mu)).sum()            # scalar
    quad_comps = (comp_mus * (lambda_comps * comp_mus)).sum(-1)      # [N]
    quad_adj = (means_prime * (lambda_prime * means_prime)).sum(-1)  # [N]

    log_alpha_tilde = log_det_term - 0.5 * (quad_base + quad_comps - quad_adj)

    # Stable softmax: shift by the maximum before exponentiating
    max_log = log_alpha_tilde.max()
    alpha_tilde = torch.exp(log_alpha_tilde - max_log)
    weights = alpha_tilde / alpha_tilde.sum()

    # ------------------------------------------------------------------
    # Build a proper MixtureSameFamily distribution
    # ------------------------------------------------------------------
    component_dist = Independent(Normal(means_prime, var_prime.sqrt()), 1)
    mixture_dist = MixtureSameFamily(Categorical(probs=weights), component_dist)

    return mixture_dist, weights, means_prime, var_prime

# Backward-compatibility alias: the older name ``build_adjusted_mixture`` is
# bound to the same function object as ``gaussian_mixture_fusion``
# (to be removed once downstream code updated).
build_adjusted_mixture = gaussian_mixture_fusion

# -----------------------------------------------------------------------------
# Numerical integration helper (1-D Gaussian case) for validation purposes
# -----------------------------------------------------------------------------

def numerical_1D_gaussian_fusion_mean(
    mu_base,
    mu2_array,
    sigma_base,
    sigma2_array,
    mu_div=0.0,
    sigma_div=1.0,
    debug=False,
):
    """Numerical quadrature estimate of the fused mean in the 1-D Gaussian case.

    The (unnormalised) fusion density is::

        f(x) = g_base(x) * sum_i g_i(x) / g_div(x)

    and the returned value is ``∫ x f(x) dx / ∫ f(x) dx``.

    Numerical strategy
    ------------------
    * Every term ``g_base * g_i / g_div`` is itself a scaled Gaussian whose
      centre and width are known in closed form.  These are used **only** to
      choose a finite integration window and break points, so that narrow
      peaks far from the origin are not missed by the quadrature rule.
    * The integrand is rescaled by a constant log-offset before it is
      exponentiated.  The factor cancels in the ratio ``N / Z`` but prevents
      the whole integrand from underflowing to zero.

    Parameters
    ----------
    mu_base : float
        Mean of the base Gaussian g_base.
    mu2_array : array-like of length M
        Means of the mixture components g_i.
    sigma_base : float
        Standard deviation (√variance) of g_base.
    sigma2_array : array-like of length M
        Standard deviations of the mixture components.
    mu_div : float, optional
        Mean of the divisor Gaussian g_div (default 0).
    sigma_div : float, optional
        Standard deviation of g_div (default 1).
    debug : bool, optional
        Print debug information.  The reported ``Z`` and ``N`` refer to the
        rescaled integrand (see above).

    Returns
    -------
    float
        Numerical estimate of the mean of the normalised fusion density.

    Raises
    ------
    ValueError
        If the mixture is empty, ``mu2_array`` and ``sigma2_array`` differ in
        length, any parameter is non-finite, any standard deviation is not
        strictly positive, the fused density is not normalisable
        (``1/sigma_base^2 + 1/sigma_i^2 - 1/sigma_div^2 <= 0`` for some i), or
        the numerically computed normalisation constant is zero/non-finite.
    """
    from math import exp, log, pi, sqrt
    from scipy.integrate import quad
    import numpy as np

    mu_base = float(mu_base)
    sigma_base = float(sigma_base)
    mu_div = float(mu_div)
    sigma_div = float(sigma_div)
    mu2_array = np.asarray(mu2_array, dtype=float).ravel()
    sigma2_array = np.asarray(sigma2_array, dtype=float).ravel()

    # ------------------------------------------------------------------
    # Input validation (zip() used to truncate mismatched inputs silently)
    # ------------------------------------------------------------------
    n_comps = mu2_array.size
    if n_comps == 0:
        raise ValueError("mu2_array must contain at least one mixture component")
    if sigma2_array.size != n_comps:
        raise ValueError(
            "mu2_array and sigma2_array must have the same length "
            f"(got {n_comps} and {sigma2_array.size})"
        )
    if not (
        np.all(np.isfinite((mu_base, sigma_base, mu_div, sigma_div)))
        and np.all(np.isfinite(mu2_array))
        and np.all(np.isfinite(sigma2_array))
    ):
        raise ValueError("All means and standard deviations must be finite")
    if not (sigma_base > 0.0 and sigma_div > 0.0 and np.all(sigma2_array > 0.0)):
        raise ValueError("All standard deviations must be strictly positive")

    # ------------------------------------------------------------------
    # Closed-form centre / width of each term g_base * g_i / g_div.
    # Used only to place the integration window and break points.
    # ------------------------------------------------------------------
    prec_base = 1.0 / np.square(sigma_base)
    prec_div = 1.0 / np.square(sigma_div)
    prec_comps = 1.0 / np.square(sigma2_array)
    prec_fused = prec_base + prec_comps - prec_div  # [M]
    if not np.all(np.isfinite(prec_fused)) or np.any(prec_fused <= 0.0):
        raise ValueError(
            "Fusion density is not normalisable: "
            "1/sigma_base^2 + 1/sigma_i^2 - 1/sigma_div^2 must be positive "
            "and finite for every mixture component"
        )
    sd_fused = 1.0 / np.sqrt(prec_fused)
    mean_fused = (
        prec_base * mu_base + prec_comps * mu2_array - prec_div * mu_div
    ) / prec_fused

    # ------------------------------------------------------------------
    # Log-density of the unnormalised fusion (vectorised over components)
    # ------------------------------------------------------------------
    half_log_2pi = 0.5 * log(2.0 * pi)
    log_norm_base = -half_log_2pi - log(sigma_base)
    log_norm_div = -half_log_2pi - log(sigma_div)
    log_norm_comps = -half_log_2pi - np.log(sigma2_array)

    def log_unnorm_pdf(x):
        log_b = log_norm_base - 0.5 * ((x - mu_base) / sigma_base) ** 2
        # mixture term via log-sum-exp over components
        log_comps = log_norm_comps - 0.5 * ((x - mu2_array) / sigma2_array) ** 2
        peak = log_comps.max()
        log_mix = peak + log(np.exp(log_comps - peak).sum())
        log_d = log_norm_div - 0.5 * ((x - mu_div) / sigma_div) ** 2
        return log_b + log_mix - log_d

    # Constant rescaling: f(x) <= M * max_i f(centre_i), so after subtracting
    # this offset the integrand is bounded by M and cannot overflow, while its
    # peak is of order one and cannot underflow everywhere.
    log_offset = float(max(log_unnorm_pdf(c) for c in mean_fused))

    def unnorm_pdf(x):
        return exp(log_unnorm_pdf(x) - log_offset)

    def weighted_x(x):
        return x * unnorm_pdf(x)

    # ------------------------------------------------------------------
    # Piecewise integration over a finite window
    # ------------------------------------------------------------------
    # Mass outside +/- 12 sd of *every* term is < 1e-32 of the total, far
    # below the quadrature tolerance.  Breaking at each centre and window
    # edge guarantees every peak is resolved by at least one short segment.
    n_sd = 12.0
    breakpoints = np.unique(
        np.concatenate(
            (
                mean_fused - n_sd * sd_fused,
                mean_fused,
                mean_fused + n_sd * sd_fused,
            )
        )
    )
    n_segments = len(breakpoints) - 1

    # Rescaled Z is bounded below by sqrt(2*pi) * min(sd) / M; an absolute
    # tolerance tied to that bound keeps near-empty segments cheap without
    # degrading the overall relative accuracy (1e-10, as before).
    z_floor = sqrt(2.0 * pi) * float(sd_fused.min()) / n_comps
    eps_abs = 1e-10 * z_floor / max(n_segments, 1)
    x_scale = max(abs(float(breakpoints[0])), abs(float(breakpoints[-1])))

    Z = N = Z_err = N_err = 0.0
    for seg_lo, seg_hi in zip(breakpoints[:-1], breakpoints[1:]):
        z_part, z_part_err = quad(
            unnorm_pdf, seg_lo, seg_hi, epsabs=eps_abs, epsrel=1e-10
        )
        n_part, n_part_err = quad(
            weighted_x, seg_lo, seg_hi, epsabs=eps_abs * x_scale, epsrel=1e-10
        )
        Z += z_part
        N += n_part
        Z_err += z_part_err
        N_err += n_part_err

    if debug:
        print(f"Gaussian quad: Z={Z:.6e} ±{Z_err:.1e}, N={N:.6e} ±{N_err:.1e}")
        print(
            f"Gaussian quad: integrand rescaled by exp(-({log_offset:.6e})), "
            f"window=[{breakpoints[0]:.6e}, {breakpoints[-1]:.6e}], "
            f"segments={n_segments}"
        )

    if Z == 0 or not np.isfinite(Z):
        raise ValueError("Normalization constant is zero or non-finite in numerical integration")

    return N / Z