import unittest
import numpy as np
import time
import multiprocessing as mp
from task_vae.task_posterior import (
    laplace_fusion_mean,
    numerical_1D_fusion_mean,
    importance_sampling_fusion,
    laplace_fusion_mean_broadcast,
    laplace_moment,
    gaussian_mixture_fusion,
)
import torch

class TestTaskPosterior(unittest.TestCase):
    def test_convergence_condition(self):
        """Inputs that violate the fusion's convergence condition raise ValueError."""
        # (mu_base, sigma_base, mu2_mat, sigma2_mat) for a 1D, single-component
        # fusion that laplace_fusion_mean is required to reject.
        invalid_inputs = (
            (np.array([0.0]), np.array([2.0]), np.array([[0.0]]), np.array([[2.0]])),
            (np.array([1.0]), np.array([3.0]), np.array([[2.0]]), np.array([[3.0]])),
        )
        for args in invalid_inputs:
            with self.assertRaises(ValueError):
                laplace_fusion_mean(*args, n_jobs=None)

    def test_degenerate_cases(self):
        """Zero-mean and mirror-symmetric 1D inputs fuse to exactly 0.0."""
        cases = [
            # (mu_base, sigma_base, mu2, sigma2)
            (0.0, 0.5, 0.0, 0.5),
            (0.0, 0.5, 0.0, 0.2),
            (3.0, 0.4, -3.0, 0.4),
        ]
        for mu_b, sigma_b, mu_o, sigma_o in cases:
            fused = laplace_fusion_mean(
                np.array([mu_b]), np.array([sigma_b]),
                np.array([[mu_o]]), np.array([[sigma_o]]),
                n_jobs=None,
            )
            self.assertEqual(fused[0], 0.0)

    def test_fusion_method_consistency(self):
        """Closed-form 1D fusion mean agrees with ``numerical_1D_fusion_mean``.

        The numerical integrator is a best-effort reference. If it raises or
        returns a non-finite value, the comparison for that case is skipped.
        A real disagreement fails the test: beyond 1% relative, with a 1e-8
        absolute floor so results sitting exactly at zero are not judged on
        quadrature noise alone.

        Setting ``self.skip_numerical = True`` disables the comparison while
        still exercising ``laplace_fusion_mean`` on every case.
        """
        test_cases = [
            (1.0, 0.8, 2.5, 1.2),
            (0.0, 0.5, 0.0, 0.5),
            (-1.0, 0.3, 2.0, 0.7),
            (5.0, 0.4, -3.0, 0.6),
        ]
        skip_numerical = getattr(self, "skip_numerical", False)
        for mu1, sigma1, mu2, sigma2 in test_cases:
            with self.subTest(mu1=mu1, sigma1=sigma1, mu2=mu2, sigma2=sigma2):
                log_space = laplace_fusion_mean(
                    np.array([mu1]), np.array([sigma1]),
                    np.array([[mu2]]), np.array([[sigma2]]),
                    n_jobs=None,
                )[0]
                if skip_numerical:
                    continue
                # Guard only the reference computation. An ``except Exception``
                # around the assertion would also swallow AssertionError and
                # make this test impossible to fail.
                try:
                    numerical = numerical_1D_fusion_mean(mu1, mu2, sigma1, sigma2)
                except Exception:  # deliberate best-effort: integrator may not converge
                    continue
                if not np.isfinite(numerical):
                    continue
                self.assertLess(
                    abs(log_space - numerical),
                    0.01 * abs(log_space) + 1e-8,
                )

    def test_fusion_symmetry_properties(self):
        """Fusion is order-independent, and mirror-symmetric inputs cancel to zero."""
        a = (2.0, 0.3)
        b = (-1.0, 0.7)

        def fuse(first, second):
            (m_first, s_first), (m_second, s_second) = first, second
            return laplace_fusion_mean(
                np.array([m_first]), np.array([s_first]),
                np.array([[m_second]]), np.array([[s_second]]),
                n_jobs=None,
            )[0]

        # Swapping the roles of base and mixture must not change the result.
        self.assertAlmostEqual(fuse(a, b), fuse(b, a), places=10)
        # Equal scales with opposite means: the fused mean sits at the origin.
        self.assertAlmostEqual(fuse((2.0, 0.5), (-2.0, 0.5)), 0.0, places=10)

    def test_fusion_extreme_parameter_values(self):
        """Widely separated means and strongly unequal scales give finite results."""
        far_apart = laplace_fusion_mean(
            np.array([50.0]), np.array([0.2]),
            np.array([[-30.0]]), np.array([[0.3]]),
            n_jobs=None,
        )[0]
        self.assertTrue(np.isfinite(far_apart))

        # Base scale 0.1 vs mixture scale 1.0. The result must stay above 3.5,
        # i.e. on the base-mean (5.0) side rather than the mixture-mean (2.0) side.
        sharp_base = laplace_fusion_mean(
            np.array([5.0]), np.array([0.1]),
            np.array([[2.0]]), np.array([[1.0]]),
            n_jobs=None,
        )[0]
        self.assertTrue(np.isfinite(sharp_base))
        self.assertGreater(sharp_base, 3.5)

    def test_fusion_known_reference_values(self):
        """1D single-component fusion matches coarse reference values."""
        # (mu1, sigma1, mu2, sigma2, expected). places=1 accepts differences
        # of roughly 0.05 or less, so the references need only a few digits.
        table = (
            (1.0, 0.5, 2.0, 0.5, 1.77),
            (0.0, 0.3, 0.0, 0.7, 0.0),
            (3.0, 0.4, -2.0, 0.6, 2.73),
            (10.0, 0.1, -10.0, 0.1, 0.0),
            (-40.0, 0.5, 35.0, 0.6, -39.625),
            (-25.0, 0.2, 25.0, 0.3, -24.76136),
            (100.0, 0.15, -40.0, 1.50, 100.01504),
            (120.0, 0.3, -80.0, 0.4, 119.6614),
        )
        for mu1, sigma1, mu2, sigma2, expected in table:
            fused = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array([[mu2]]), np.array([[sigma2]]),
                n_jobs=None,
            )
            self.assertAlmostEqual(fused[0], expected, places=1)

    def test_fusion_kappa_zero_cases(self):
        """Closed-form and numerical means agree to 6 places on the 'kappa = 0' cases."""
        # (mu1, sigma1, mu2, sigma2). In both cases sigma2 == 2 * sigma1.
        for mu1, sigma1, mu2, sigma2 in ((-1.0, 0.5, 1.0, 1.0), (2.0, 1.0, -2.0, 2.0)):
            closed_form = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array([[mu2]]), np.array([[sigma2]]),
                n_jobs=None,
            )[0]
            # numerical_1D_fusion_mean takes (mu1, mu2, sigma1, sigma2).
            reference = numerical_1D_fusion_mean(mu1, mu2, sigma1, sigma2)
            self.assertAlmostEqual(closed_form, reference, places=6)

    # REMOVED: test_fusion_single_component - trivial test that calls same function twice

    def test_fusion_multi_component(self):
        """1D fusion against 2-, 3- and 5-component mixtures."""

        def fuse(mu1, sigma1, mu2_values, sigma2_values):
            # Mixture parameters become (M, 1) matrices: M components, 1 dim.
            return laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array(mu2_values).reshape(-1, 1),
                np.array(sigma2_values).reshape(-1, 1),
                n_jobs=None,
            )[0]

        # Components placed symmetrically around a zero-mean base cancel out.
        self.assertAlmostEqual(fuse(0.0, 0.5, [-1.0, 1.0], [0.7, 0.7]), 0.0, places=8)

        # Asymmetric three-component mixture with a known reference value.
        asym = fuse(1.0, 0.4, [-3.0, 0.0, 5.0], [0.5, 0.7, 0.9])
        self.assertTrue(np.isfinite(asym))
        self.assertAlmostEqual(asym, 1.07800322, places=7)

        # Growing the number of equal-scale components keeps the result finite.
        # All three results are computed before any of them is checked.
        grids = ([0.0], [-1.0, 0.0, 1.0], [-2.0, -1.0, 0.0, 1.0, 2.0])
        results = [fuse(1.0, 0.5, grid, [0.5] * len(grid)) for grid in grids]
        for value in results:
            self.assertTrue(np.isfinite(value))

    def test_fusion_multi_reference_values(self):
        """1D fusion against multi-component mixtures, matched to 7 decimal places."""
        # (mu1, sigma1, mixture means, mixture scales, expected fused mean)
        cases = [
            (0.0, 0.5, [-1.0, 1.0], [0.7, 0.7], 0.0),
            (2.0, 0.8, [0.0, 4.0], [0.5, 0.5], 3.72440818),
            (2.0, 0.5, [3.0, 3.0], [0.7, 0.7], 2.71685855),
            (0.0, 0.5, [-2.0, 0.0, 2.0], [0.6, 0.6, 0.6], 0.0),
            (1.0, 0.4, [-3.0, 0.0, 5.0], [0.5, 0.7, 0.9], 1.07800322),
            (2.0, 0.4, [1.0, 2.0, 3.0], [0.5, 0.4, 0.5], 2.23302281),
            (1.0, 0.5, [-2.0, -1.0, 0.0, 1.0, 2.0], [0.3, 0.4, 0.5, 0.6, 0.7], 1.29043481),
        ]
        for mu1, sigma1, means, scales, expected in cases:
            # One row per mixture component, one column (1D).
            mu2_col = np.reshape(means, (-1, 1))
            sigma2_col = np.reshape(scales, (-1, 1))
            fused = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]), mu2_col, sigma2_col, n_jobs=None
            )
            self.assertAlmostEqual(fused[0], expected, places=7)



    def test_fusion_identical_components(self):
        """Duplicating a mixture component must not move the fused mean."""
        reference = laplace_fusion_mean(
            np.array([2.0]), np.array([0.5]),
            np.array([[3.0]]), np.array([[0.7]]),
            n_jobs=None,
        )[0]
        for copies in (2, 3):
            # `copies` identical (3.0, 0.7) components as an (M, 1) matrix.
            repeated = laplace_fusion_mean(
                np.array([2.0]), np.array([0.5]),
                np.full((copies, 1), 3.0), np.full((copies, 1), 0.7),
                n_jobs=None,
            )[0]
            self.assertAlmostEqual(reference, repeated, places=8)

    def test_fusion_multidim_equivalence(self):
        """Re-check the multi-component reference table via explicit (M, 1) matrices.

        The table is the same as the one in test_fusion_multi_reference_values.
        Here each failure message reports the full parameter set.
        """
        reference_values = [
            # mu1, sigma1, mu2 list, sigma2 list, expected result
            (0.0, 0.5, [-1.0, 1.0], [0.7, 0.7], 0.0),
            (2.0, 0.8, [0.0, 4.0], [0.5, 0.5], 3.72440818),
            (2.0, 0.5, [3.0, 3.0], [0.7, 0.7], 2.71685855),
            (0.0, 0.5, [-2.0, 0.0, 2.0], [0.6, 0.6, 0.6], 0.0),
            (1.0, 0.4, [-3.0, 0.0, 5.0], [0.5, 0.7, 0.9], 1.07800322),
            (2.0, 0.4, [1.0, 2.0, 3.0], [0.5, 0.4, 0.5], 2.23302281),
            (1.0, 0.5, [-2.0, -1.0, 0.0, 1.0, 2.0], [0.3, 0.4, 0.5, 0.6, 0.7], 1.29043481),
        ]
        for mu1, sigma1, mu2_list, sigma2_list, expected in reference_values:
            mu2_mat = np.array(mu2_list).reshape(-1, 1)        # (M, 1)
            sigma2_mat = np.array(sigma2_list).reshape(-1, 1)  # (M, 1)
            multi_result = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]), mu2_mat, sigma2_mat, n_jobs=None
            )[0]
            self.assertAlmostEqual(
                multi_result, expected, places=7,
                msg=f"Parameters: mu1={mu1}, sigma1={sigma1}, mu2_array={mu2_list}, sigma2_array={sigma2_list}",
            )

    def test_fusion_multidim_reference_values(self):
        """Element-wise fusion in 2-4 dimensions against stored reference vectors."""
        # Each entry: (mu_base, sigma_base, mu2_mat, sigma2_mat, expected).
        # mu2_mat / sigma2_mat hold one row per mixture component. Their
        # column count equals the base dimension in every case.
        cases = [
            (  # 2D, 2 components
                [0.0, 1.0], [0.5, 0.7],
                [[0.5, 1.5], [1.0, 0.0]],
                [[0.6, 0.8], [0.7, 0.5]],
                [0.36140753, 1.44153697],
            ),
            (  # 3D, 3 components
                [0.0, 0.5, 1.0], [0.7, 0.6, 0.5],
                [[0.2, 0.7, 1.2], [0.5, 1.0, 1.5], [1.0, 0.0, 2.0]],
                [[0.8, 0.7, 0.6], [0.7, 0.8, 0.9], [0.6, 0.7, 0.8]],
                [0.36499544, 0.80808288, 1.40734982],
            ),
            (  # 4D, 4 components
                [0.0, 0.5, 1.0, 1.5], [0.5, 0.6, 0.7, 0.8],
                [[0.2, 0.7, 1.2, 1.7], [0.5, 1.0, 1.5, 2.0],
                 [1.0, 0.0, 2.0, 1.0], [1.5, 1.5, 1.5, 1.5]],
                [[0.6, 0.7, 0.8, 0.9], [0.7, 0.8, 0.9, 1.0],
                 [0.8, 0.7, 0.6, 0.5], [0.9, 1.0, 0.5, 0.6]],
                [0.32253191, 0.91993398, 1.54870986, 2.06322058],
            ),
            (  # 3D, 2 components
                [0.0, 1.0, 2.0], [0.5, 0.6, 0.7],
                [[0.5, 1.5, 2.5], [1.0, 0.0, 1.0]],
                [[0.6, 0.8, 0.9], [0.7, 0.5, 0.6]],
                [0.32956509, 1.47854135, 2.59574399],
            ),
            (  # 2D, 3 components
                [0.0, 1.0], [0.5, 0.7],
                [[0.5, 1.5], [1.0, 0.0], [1.5, 0.5]],
                [[0.6, 0.8], [0.7, 0.5], [0.8, 0.6]],
                [0.4726001, 1.33752264],
            ),
        ]
        for mu_base, sigma_base, mu2_mat, sigma2_mat, expected in cases:
            result = laplace_fusion_mean(
                np.array(mu_base), np.array(sigma_base),
                np.array(mu2_mat), np.array(sigma2_mat),
                n_jobs=None,
            )
            np.testing.assert_allclose(result, np.array(expected), rtol=1e-7, atol=1e-8)

    def test_fusion_all_methods_comparison(self):
        """Test comparing all available methods for 1D fusion cases.

        For every case, the closed-form ``laplace_fusion_mean`` result must be
        a finite float. It is also compared with ``numerical_1D_fusion_mean``,
        using a 1% relative tolerance plus a 1e-8 absolute floor for results
        at zero.

        The numerical reference is best-effort. If it raises or returns a
        non-finite value, only the comparison is skipped. A real disagreement
        still fails the test.
        """
        test_cases = [
            # Standard cases
            (1.0, 0.8, 2.5, 1.2),
            (0.0, 0.5, 0.0, 0.5),
            (-1.0, 0.3, 2.0, 0.7),
            (5.0, 0.4, -3.0, 0.6),
            # Edge cases
            (0.0, 0.9, 0.0, 0.2),
            (1.0, 0.5, 2.0, 0.5),
            (0.0, 1.0, 0.0, 0.5),
            # Symmetric cases
            (10.0, 0.1, -10.0, 0.1),
            (5.0, 0.2, -5.0, 0.2),
            (0.0, 0.2, 0.0, 0.2),
            # Same means
            (50.0, 0.2, 50.0, 0.2),
            (-7.0, 0.3, -7.0, 0.3),
            (0.0, 0.1, 0.0, 0.1),
            # Special cases with kappa=0
            (-1.0, 0.5, 1.0, 1.0),
            (2.0, 1.0, -2.0, 2.0),
            # Numerically challenging cases
            (50.0, 0.1, -20.0, 2.0),
            (75.0, 0.2, 3.0, 3.0),
            (-40.0, 0.5, 35.0, 0.6),
            (-30.0, 0.3, 40.0, 0.4),
            (-25.0, 0.2, 25.0, 0.3),
            # Testing boundaries
            (80.0, 0.05, -30.0, 1.0),
            (100.0, 0.15, -40.0, 1.5),
            (-60.0, 0.1, 50.0, 0.15),
            (120.0, 0.3, -80.0, 0.4),
            (-90.0, 0.25, 70.0, 0.35),
        ]

        for mu1, sigma1, mu2, sigma2 in test_cases:
            with self.subTest(mu1=mu1, sigma1=sigma1, mu2=mu2, sigma2=sigma2):
                fusion_result = laplace_fusion_mean(
                    np.array([mu1]), np.array([sigma1]),
                    np.array([[mu2]]), np.array([[sigma2]]), n_jobs=None
                )[0]

                # Basic sanity checks (always enforced).
                self.assertIsInstance(fusion_result, float)
                self.assertFalse(np.isnan(fusion_result))
                self.assertFalse(np.isinf(fusion_result))

                # Guard only the reference computation. Catching Exception
                # around the assertion would also swallow AssertionError, so
                # the comparison could never fail.
                try:
                    numerical_result = numerical_1D_fusion_mean(mu1, mu2, sigma1, sigma2)
                except Exception:  # deliberate best-effort: convergence issues
                    continue
                if not np.isfinite(numerical_result):
                    continue

                abs_diff = abs(fusion_result - numerical_result)
                rel_diff = abs_diff / max(abs(fusion_result), 1e-10)
                self.assertLess(
                    abs_diff,
                    0.01 * abs(fusion_result) + 1e-8,
                    f"Large difference between methods: {rel_diff}",
                )

    def test_fusion_sum_equivalence(self):
        """A single mixture component and two identical copies of it fuse identically."""
        cases = (
            (1.0, 0.8, 2.5, 1.2),
            (0.0, 0.5, 0.0, 0.5),
            (-1.0, 0.3, 2.0, 0.7),
            (5.0, 0.4, -3.0, 0.6),
            (10.0, 0.1, -10.0, 0.1),
        )
        for mu1, sigma1, mu2, sigma2 in cases:
            with self.subTest(mu1=mu1, sigma1=sigma1, mu2=mu2, sigma2=sigma2):
                # Evaluate with 1 and then 2 copies of the (mu2, sigma2) component.
                single_result, double_result = [
                    laplace_fusion_mean(
                        np.array([mu1]), np.array([sigma1]),
                        np.full((copies, 1), mu2), np.full((copies, 1), sigma2),
                        n_jobs=None,
                    )[0]
                    for copies in (1, 2)
                ]
                self.assertAlmostEqual(single_result, double_result, places=10)

    def test_fusion_multiple_components(self):
        """Two- and three-component 1D mixtures produce a finite float result."""
        scenarios = [
            # Two components straddling the base mean.
            (2.0, 0.5, [[1.0], [3.0]], [[0.7], [0.7]]),
            # Three components with differing scales.
            (0.0, 0.8, [[-2.0], [0.0], [2.0]], [[0.5], [0.6], [0.7]]),
        ]
        for mu1, sigma1, mu2_rows, sigma2_rows in scenarios:
            value = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array(mu2_rows), np.array(sigma2_rows), n_jobs=None,
            )[0]
            self.assertIsInstance(value, float)
            self.assertFalse(np.isnan(value))
            self.assertFalse(np.isinf(value))

    def test_fusion_multivariate_analytical_vs_numeric(self):
        """Test analytical vs numerical methods for multivariate cases.

        The analytical result must always match the base's shape and be
        finite. It is compared with ``importance_sampling_fusion`` as a Monte
        Carlo reference, using a 10% relative tolerance on the vector norm.
        The comparison is limited to at most 4 dimensions and 4 components.

        The reference is best-effort. If it raises or returns non-finite
        values, only the comparison is skipped. A real disagreement still
        fails the test.
        """
        max_sampling_dims = 4
        max_sampling_components = 4
        test_cases = [
            # 1D cases
            {
                "mu_base": [0.5], "sigma_base": [0.8],
                "mu2_mat": [[1.0]], "sigma2_mat": [[1.2]]
            },
            {
                "mu_base": [10.0], "sigma_base": [0.2],
                "mu2_mat": [[-5.0]], "sigma2_mat": [[0.4]]
            },
            # 2D cases
            {
                "mu_base": [0.0, 1.0], "sigma_base": [0.5, 0.7],
                "mu2_mat": [[0.5, 1.5]], "sigma2_mat": [[0.6, 0.8]]
            },
            {
                "mu_base": [-1.0, 2.0], "sigma_base": [0.3, 0.6],
                "mu2_mat": [[1.0, -1.0]], "sigma2_mat": [[0.5, 0.4]]
            },
            # 3D cases
            {
                "mu_base": [0.0, 0.5, 1.0], "sigma_base": [0.7, 0.6, 0.5],
                "mu2_mat": [[0.2, 0.7, 1.2]], "sigma2_mat": [[0.8, 0.7, 0.6]]
            },
            # 4D case (limit for importance sampling)
            {
                "mu_base": [0.0, 0.5, 1.0, 1.5], "sigma_base": [0.5, 0.6, 0.7, 0.8],
                "mu2_mat": [[0.2, 0.7, 1.2, 1.7]], "sigma2_mat": [[0.6, 0.7, 0.8, 0.9]]
            },
        ]

        for i, test_case in enumerate(test_cases):
            with self.subTest(case=i):
                mu_base = np.array(test_case["mu_base"])
                sigma_base = np.array(test_case["sigma_base"])
                mu2_mat = np.array(test_case["mu2_mat"])
                sigma2_mat = np.array(test_case["sigma2_mat"])

                analytical_result = laplace_fusion_mean(
                    mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=None
                )

                # Basic checks (always enforced).
                self.assertEqual(analytical_result.shape, mu_base.shape)
                self.assertFalse(np.any(np.isnan(analytical_result)))
                self.assertFalse(np.any(np.isinf(analytical_result)))

                dims = len(mu_base)
                components = mu2_mat.shape[0]  # rows are mixture components
                if dims > max_sampling_dims or components > max_sampling_components:
                    continue

                # NOTE(review): seeding only helps if importance_sampling_fusion
                # draws from NumPy's global RNG. Confirm in task_posterior.
                np.random.seed(0)
                # Guard only the reference computation. Catching Exception
                # around the assertion would also swallow AssertionError.
                try:
                    numeric_result = importance_sampling_fusion(
                        mu_base, sigma_base, mu2_mat, sigma2_mat, Nsamples=50000
                    )
                except Exception:  # deliberate best-effort Monte Carlo reference
                    continue
                if not np.all(np.isfinite(numeric_result)):
                    continue

                abs_diff = np.linalg.norm(analytical_result - numeric_result)
                rel_diff = abs_diff / (np.linalg.norm(analytical_result) + 1e-10)
                # Larger tolerance to allow for Monte Carlo error.
                self.assertLess(rel_diff, 0.1, f"Large difference: {rel_diff}")

    # REMOVED: test_fusion_broadcast_wrapper - superseded by test_fusion_broadcast_wrapper_comprehensive

    def test_fusion_known_reference_values_extended(self):
        """Pin high-precision reference values, including the broadcast wrapper."""

        def fuse_first(mu1, sigma1, mu2_rows, sigma2_rows):
            # Return the first fused coordinate. n_jobs is not passed here.
            return laplace_fusion_mean(
                np.array(mu1), np.array(sigma1),
                np.array(mu2_rows), np.array(sigma2_rows),
            )[0]

        # 1D, single component.
        self.assertAlmostEqual(
            fuse_first([1.0], [0.8], [[2.5]], [[1.2]]), 2.461727491194927, places=10
        )
        # Mirror-symmetric means with equal scales cancel to zero.
        self.assertAlmostEqual(
            fuse_first([2.0], [0.5], [[-2.0]], [[0.5]]), 0.0, places=10
        )
        # Two components symmetric about a zero-mean base also cancel.
        self.assertAlmostEqual(
            fuse_first([0.0], [0.5], [[-1.0], [1.0]], [[0.7], [0.7]]), 0.0, places=10
        )

        # 2D base with one component: compare the whole fused vector.
        vector = laplace_fusion_mean(
            np.array([0.0, 1.0]), np.array([0.5, 0.7]),
            np.array([[0.5, 1.5]]), np.array([[0.6, 0.8]]),
        )
        np.testing.assert_allclose(vector, np.array([0.31767004, 1.59014114]), rtol=1e-7)

        # Broadcast wrapper with N=5 and scalar base: all 5 outputs are
        # expected to equal the same reference value.
        broadcast = laplace_fusion_mean_broadcast(5, 0.0, 0.5, [1.0, 2.0], [0.6, 0.7])
        np.testing.assert_allclose(broadcast, np.full(5, 0.68295355), rtol=1e-7)

    def test_fusion_performance_scaling(self):
        """Test performance with increasing dimensions and components.

        For a single mixture component and dimensions 1, 2, 3, 5 and 10, the
        result must have shape (dims,), be finite, and come back within a
        coarse wall-time budget.
        """
        time_budget_s = 1.0  # generous bound; only catches gross regressions

        for dims in (1, 2, 3, 5, 10):
            with self.subTest(dims=dims):
                # Create test parameters
                mu_base = np.linspace(0, dims / 2, dims)
                sigma_base = np.ones(dims) * 0.5 + np.linspace(0, 0.2, dims)
                mu2_mat = np.array([mu_base + 0.5])
                sigma2_mat = np.array([sigma_base + 0.1])

                # perf_counter is monotonic and high-resolution. time.time()
                # follows the system clock and can jump (e.g. NTP adjustments),
                # which makes interval measurements unreliable.
                start = time.perf_counter()
                result = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat)
                elapsed = time.perf_counter() - start

                # Basic checks
                self.assertEqual(result.shape, (dims,))
                self.assertFalse(np.any(np.isnan(result)))
                self.assertFalse(np.any(np.isinf(result)))

                self.assertLess(elapsed, time_budget_s)

    def test_fusion_edge_cases_extended(self):
        """Tiny scales, large means, and a five-component symmetric mixture."""

        def check_finite_float(value):
            self.assertIsInstance(value, float)
            self.assertFalse(np.isnan(value))
            self.assertFalse(np.isinf(value))

        # Very small scales on both sides.
        tiny_scales = laplace_fusion_mean(
            np.array([0.0]), np.array([0.01]),
            np.array([[1.0]]), np.array([[0.01]]),
        )[0]
        check_finite_float(tiny_scales)

        # Large, opposite-signed means.
        large_means = laplace_fusion_mean(
            np.array([100.0]), np.array([0.1]),
            np.array([[-50.0]]), np.array([[0.2]]),
        )[0]
        check_finite_float(large_means)

        # Five 1D components placed symmetrically about a zero-mean base:
        # by symmetry the fused mean should be (numerically) zero.
        symmetric = laplace_fusion_mean(
            np.array([0.0]), np.array([0.5]),
            np.array([[-2.0], [-1.0], [0.0], [1.0], [2.0]]),
            np.array([[0.6], [0.7], [0.8], [0.7], [0.6]]),
        )[0]
        check_finite_float(symmetric)
        self.assertAlmostEqual(symmetric, 0.0, places=8)

    def test_fusion_boundary_and_challenging_cases(self):
        """High-precision 1D references for far-apart means and very small scales.

        Each case must yield a finite float matching its reference to 8
        decimal places. This is a tight tolerance even for these numerically
        hard cases.
        """
        # (mu1, sigma1, mu2, sigma2, expected_result)
        boundary_cases = [
            (80.0, 0.05, -30.0, 1.0, 79.9999999999998),      # smaller scales
            (100.0, 0.15, -40.0, 1.5, 100.01503759398496),   # larger separation
            (-60.0, 0.1, 50.0, 0.15, -59.833060556467295),   # very small scales
            (120.0, 0.3, -80.0, 0.4, 119.66144200627018),    # extreme separation
            (-90.0, 0.25, 70.0, 0.35, -89.70406504064954),   # large opposite means
        ]
        challenging_cases = [
            (50.0, 0.1, -20.0, 2.0, 50.01002506265663),      # moderate contrast
            (75.0, 0.2, 3.0, 3.0, 75.05429864253412),        # large mean, balanced scales
            (-40.0, 0.5, 35.0, 0.6, -39.62499999998836),     # far means, similar scales
            (-30.0, 0.3, 40.0, 0.4, -29.661442006269976),    # far means, similar scales
            (-25.0, 0.2, 25.0, 0.3, -24.761363636363335),    # near symmetric
        ]

        for case_no, (mu1, sigma1, mu2, sigma2, expected) in enumerate(
            boundary_cases + challenging_cases, start=1
        ):
            with self.subTest(case=case_no, mu1=mu1, sigma1=sigma1, mu2=mu2, sigma2=sigma2):
                result = laplace_fusion_mean(
                    np.array([mu1]), np.array([sigma1]),
                    np.array([[mu2]]), np.array([[sigma2]]),
                )[0]

                self.assertIsInstance(result, float)
                self.assertFalse(np.isnan(result))
                self.assertFalse(np.isinf(result))

                self.assertAlmostEqual(
                    result, expected, places=8,
                    msg=f"Case {case_no}: ({mu1}, {sigma1}, {mu2}, {sigma2}) -> expected {expected}, got {result}",
                )

    def test_fusion_broadcast_wrapper_comprehensive(self):
        """Test comprehensive broadcast wrapper functionality from original test_broadcast_wrapper.

        Sub-checks:
          1. 10D: short base / mixture patterns are tiled up to N, and the
             wrapper must match ``laplace_fusion_mean`` on explicitly tiled inputs.
          2. 1800D, equal means (0 and 0): all outputs ~0.
          3. 1800D, mirror-symmetric means with equal scales: all outputs ~0.
          4. 1800D, unequal scales: the mean output is closer to mu_base than mu2.
          5. 1800D, far-apart means: the mean output matches a reference value.
          6. As in 5 but with two identical mixture components: same result as 5.
        """
        # Test 1: 10D with 1D base, 2D mixture (3 components) - pattern repetition.
        N = 10
        mu_base_1d = np.array([0.0, 1.0, 2.0])  # pattern to repeat
        sigma_base_1d = np.array([0.5, 0.6, 0.7, 0.8])  # pattern to repeat
        mu2_2d = np.array([
            [0.5, 1.5],
            [1.0, 0.0],
            [2.0, 3.0]
        ])  # 3 mixture components, 2-column pattern
        sigma2_2d = np.array([
            [0.6, 0.8],
            [0.7, 0.5],
            [0.9, 0.7]
        ])  # 3 mixture components, 2-column pattern

        result_broadcast = laplace_fusion_mean_broadcast(
            N, mu_base_1d, sigma_base_1d, mu2_2d, sigma2_2d
        )

        # Tile each pattern by hand (then truncate to N); the wrapper must agree.
        mu_base_explicit = np.tile(mu_base_1d, int(np.ceil(N / len(mu_base_1d))))[:N]
        sigma_base_explicit = np.tile(sigma_base_1d, int(np.ceil(N / len(sigma_base_1d))))[:N]
        mu2_mat_explicit = np.tile(mu2_2d, (1, int(np.ceil(N / mu2_2d.shape[1]))))[:, :N]
        sigma2_mat_explicit = np.tile(sigma2_2d, (1, int(np.ceil(N / sigma2_2d.shape[1]))))[:, :N]

        result_explicit = laplace_fusion_mean(
            mu_base_explicit, sigma_base_explicit, mu2_mat_explicit, sigma2_mat_explicit
        )
        np.testing.assert_allclose(result_broadcast, result_explicit, rtol=1e-10)

        N = 1800

        # Test 2: equal means (mu_base=0, mu2=0) -> all zeros.
        result = laplace_fusion_mean_broadcast(N, 0.0, 0.5, [0.0], [0.7])
        self.assertLess(np.max(np.abs(result)), 1e-10, "Result should be all zeros")

        # Test 3: symmetric means (mu_base=-10, mu2=10), same sigma -> all zeros.
        result = laplace_fusion_mean_broadcast(N, -10.0, 0.5, [10.0], [0.5])
        self.assertLess(np.max(np.abs(result)), 1e-10,
                        "Result should be all zeros due to symmetry")

        # Test 4: different sigmas (sigma_base=0.1 < sigma2=0.3).
        mu_base_scalar = -10.0
        mu2_scalar = 10.0
        result = laplace_fusion_mean_broadcast(N, mu_base_scalar, 0.1, [mu2_scalar], [0.3])
        mean_result = np.mean(result)
        self.assertLess(abs(mean_result - mu_base_scalar), abs(mean_result - mu2_scalar),
                        "Result should be closer to mu_base (-10) than mu2 (10)")

        # Test 5: extreme means and small sigmas, one mixture component.
        expected_value = -59.83306056
        result_one_component = laplace_fusion_mean_broadcast(N, -60.0, 0.1, [50.0], [0.15])
        mean_result_one = np.mean(result_one_component)
        self.assertLess(abs(mean_result_one - expected_value), 1e-5,
                        f"Result {mean_result_one} should be close to expected value {expected_value}")

        # Test 6: same as Test 5 but with two identical mixture components.
        result_two_components = laplace_fusion_mean_broadcast(
            N, -60.0, 0.1, np.array([50.0, 50.0]), np.array([0.15, 0.15])
        )
        mean_result_two = np.mean(result_two_components)
        self.assertLess(abs(mean_result_two - mean_result_one), 1e-10,
                        "Result with two identical components should match the result with one component")
        self.assertLess(abs(mean_result_two - expected_value), 1e-5,
                        f"Result {mean_result_two} should be close to expected value {expected_value}")

    def test_fusion_broadcast_wrapper_edge_cases(self):
        """Exercise ``laplace_fusion_mean_broadcast`` with awkward broadcast inputs.

        Covers per-dimension patterns of differing lengths, fully scalar inputs
        (which must give a constant output), and a second set of asymmetric
        pattern lengths. Every result must be a finite vector of length N.
        """
        def assert_finite_vector(values, n_dims):
            # Shape first, then NaN, then Inf -- same order as the checks it replaces.
            self.assertEqual(values.shape, (n_dims,))
            self.assertFalse(np.any(np.isnan(values)))
            self.assertFalse(np.any(np.isinf(values)))

        # Patterns of lengths 3 / 4 and 3x2 component patterns over 500 dimensions.
        n_dims = 500
        fused = laplace_fusion_mean_broadcast(
            n_dims,
            np.array([0.0, 0.5, 1.0]),
            np.array([0.4, 0.5, 0.6, 0.7]),
            np.array([[0.2, 0.8], [1.2, 0.3], [0.7, 1.5]]),
            np.array([[0.5, 0.6], [0.7, 0.8], [0.6, 0.5]]),
        )
        assert_finite_vector(fused, n_dims)

        # Pure scalars: every dimension sees identical inputs, so the output is constant.
        fused_scalar = laplace_fusion_mean_broadcast(200, 2.0, 0.5, [3.0], [0.7])
        self.assertLess(np.std(fused_scalar), 1e-10, "All values should be identical for scalar inputs")

        # Asymmetric pattern lengths (2 / 3 and a 1x3 component pattern).
        n_dims = 50
        fused_asym = laplace_fusion_mean_broadcast(
            n_dims,
            np.array([1.0, -1.0]),
            np.array([0.3, 0.7, 0.5]),
            np.array([[2.0, -2.0, 0.0]]),
            np.array([[0.4, 0.6, 0.8]]),
        )
        assert_finite_vector(fused_asym, n_dims)

    def test_abs_expectation_numeric(self):
        """Validate ``moment='abs'`` against high-precision numerical integration (quad).

        Uses the default fusion settings. The reference density includes the
        ``+|x|`` term, i.e. the unit Laplace is divided out:
        ``log p(x) = -|x-mu1|/sigma1 - |x-mu2|/sigma2 + |x| + const``.
        """
        from scipy.integrate import quad

        quad_opts = {"epsabs": 1e-12, "epsrel": 1e-12, "limit": 200}

        def numerical_abs_expectation(mu1, mu2, sigma1, sigma2):
            """Return E[|x|] under the unnormalised fused density above."""
            def unnorm(x):
                log_p = -abs(x - mu1) / sigma1 - abs(x - mu2) / sigma2 + abs(x)
                # Far tails contribute nothing measurable; return an exact 0 there.
                return np.exp(log_p) if log_p > -700 else 0.0

            z, _ = quad(unnorm, -np.inf, np.inf, **quad_opts)
            first_abs, _ = quad(lambda x: abs(x) * unnorm(x), -np.inf, np.inf, **quad_opts)
            return first_abs / z

        test_cases = [
            # (mu1, mu2, sigma1, sigma2)
            (0.0, 0.0, 0.5, 0.5),
            (1.0, -1.0, 0.3, 0.7),
            (-2.0, 2.0, 1.0, 1.0),
        ]
        for mu1, mu2, sigma1, sigma2 in test_cases:
            analytic = laplace_fusion_mean(np.array([mu1]), np.array([sigma1]),
                                           np.array([[mu2]]), np.array([[sigma2]]),
                                           moment='abs', n_jobs=None)[0]
            numeric = numerical_abs_expectation(mu1, mu2, sigma1, sigma2)
            # E[|x|] >= 0, so the reference itself is a valid relative scale.
            rel_err = abs(analytic - numeric) / max(numeric, 1e-12)
            self.assertLess(rel_err, 1e-2, msg=f"Mismatch abs expectation for case {(mu1,mu2,sigma1,sigma2)}")
 
    def test_abs_expectation_reference_values(self):
        """Regression-check ``moment='abs'`` against hard-coded E[|x|] values.

        Each entry is ``(mu1, sigma1, mu2, sigma2, expected)``. The two
        symmetric cases collapse to a zero-centred Laplace, whose E[|x|] equals
        its scale.
        """
        references = (
            (0.0, 0.5, 0.0, 0.5, 1.0/3.0),     # reduces to Laplace(0, 1/3)
            (0.0, 1.0, 0.0, 1.0, 1.0),         # reduces to Laplace(0, 1)
            # Asymmetric case; closed form (5 + 2e) / (6e - 3) = 0.784132684...
            (0.0, 0.5, 1.0, 0.5, 0.78413268),
        )
        for mu1, sigma1, mu2, sigma2, expected in references:
            fused = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array([[mu2]]), np.array([[sigma2]]),
                moment='abs', n_jobs=None,
            )
            self.assertAlmostEqual(
                fused[0], expected, places=6,
                msg=f"Reference abs expectation mismatch for {(mu1,mu2,sigma1,sigma2)}",
            )

    def test_abs_expectation_numeric_no_unit(self):
        """Check ``moment='abs'`` with ``divide_unit_laplace=False`` against quadrature.

        Without the unit-Laplace division the reference log-density is just
        ``-|x-mu1|/sigma1 - |x-mu2|/sigma2`` (no ``+|x|`` term).
        """
        from scipy.integrate import quad

        quad_opts = {"epsabs": 1e-12, "epsrel": 1e-12, "limit": 200}

        def reference_abs(mu1, mu2, sigma1, sigma2):
            def weight(x):
                log_w = -abs(x - mu1)/sigma1 - abs(x - mu2)/sigma2  # no unit Laplace
                return np.exp(log_w) if log_w > -700 else 0.0

            normaliser = quad(weight, -np.inf, np.inf, **quad_opts)[0]
            abs_moment = quad(lambda x: abs(x) * weight(x), -np.inf, np.inf, **quad_opts)[0]
            return abs_moment / normaliser

        cases = (
            # (mu1, mu2, sigma1, sigma2)
            (0.0, 0.0, 0.5, 0.5),
            (1.0, -1.0, 0.3, 0.7),
            (-2.0, 2.0, 1.0, 1.0),
        )
        for mu1, mu2, sigma1, sigma2 in cases:
            fused = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array([[mu2]]), np.array([[sigma2]]),
                divide_unit_laplace=False,
                moment='abs', n_jobs=None,
            )[0]
            expected = reference_abs(mu1, mu2, sigma1, sigma2)
            relative = abs(fused - expected) / max(expected, 1e-12)
            self.assertLess(relative, 1e-2,
                            msg=f"Mismatch abs (no unit Laplace) for {(mu1,mu2,sigma1,sigma2)}")

    def test_relu_expectation_numeric(self):
        """Compare the analytic ReLU moment with ``numerical_1D_fusion_mean(moment='relu')``.

        Relative error is measured against the numerical reference and must be
        below 1 %.
        """
        cases = (
            # (mu1, mu2, sigma1, sigma2)
            (0.0, 0.0, 0.5, 0.5),
            (1.0, -1.0, 0.3, 0.7),
            (-2.0, 2.0, 1.0, 1.0),
        )
        for mu1, mu2, sigma1, sigma2 in cases:
            fused = laplace_fusion_mean(
                np.array([mu1]), np.array([sigma1]),
                np.array([[mu2]]), np.array([[sigma2]]),
                moment='relu', n_jobs=None,
            )
            analytic = fused[0]
            reference = numerical_1D_fusion_mean(mu1, mu2, sigma1, sigma2, moment='relu')

            scale = max(abs(reference), 1e-12)
            self.assertLess(
                abs(analytic - reference) / scale, 1e-2,
                msg=f"Mismatch ReLU expectation for case {(mu1, mu2, sigma1, sigma2)}",
            )

    def test_laplace_moment_torch(self):
        """Test ``laplace_moment`` on a torch Laplace for the 'mean' and 'abs' moments.

        'mean' must equal the location parameter. 'abs' (E[|X|]) is compared
        with scipy quadrature of ``|x| * Laplace(mu, sigma).pdf(x)`` over
        ``[mu - 20*sigma, mu + 20*sigma]``.
        """
        from scipy.stats import laplace
        from scipy.integrate import quad

        test_params = [
            # (mu, sigma)
            (0.0, 1.0),
            (1.5, 0.5),
            (-2.0, 2.0),
            (0.0, 0.1),
            (3.0, 1.0),
        ]
        for mu, sigma in test_params:
            dist = torch.distributions.Laplace(torch.tensor([mu]), torch.tensor([sigma]))

            mean_val = laplace_moment(dist, moment='mean').item()
            self.assertAlmostEqual(mean_val, mu, places=7)

            abs_val = laplace_moment(dist, moment='abs').item()

            def abs_integrand(x):
                return abs(x) * laplace.pdf(x, loc=mu, scale=sigma)

            lower, upper = mu - 20 * sigma, mu + 20 * sigma
            # |x| * pdf has kinks at x=0 and x=mu; pass those interior points to
            # quad so it subdivides there instead of having to discover them.
            kinks = sorted(p for p in {0.0, mu} if lower < p < upper)
            num_abs, _ = quad(abs_integrand, lower, upper, points=kinks or None,
                              epsabs=1e-10, epsrel=1e-10)
            self.assertAlmostEqual(abs_val, num_abs, places=6,
                                   msg=f"E[|X|] mismatch for mu={mu}, sigma={sigma}")

    def test_relu_threshold_expectation(self):
        """Validate the analytic ReLU-threshold moment against numerical quadrature.

        Single base dimension, single component, ``divide_unit_laplace=True``;
        the reference is ``numerical_1D_fusion_relu_threshold``.
        """
        from task_vae.task_posterior import numerical_1D_fusion_relu_threshold

        mu_base = np.array([0.0])
        sigma_base = np.array([1.0])
        mu2_mat = np.array([[0.5]])
        sigma2_mat = np.array([[1.0]])
        threshold_val = 1.0

        analytic = laplace_fusion_mean(
            mu_base,
            sigma_base,
            mu2_mat,
            sigma2_mat,
            divide_unit_laplace=True,
            moment="relu",
            threshold=threshold_val,
        )[0]

        numeric = numerical_1D_fusion_relu_threshold(
            mu_base[0],
            mu2_mat[0, 0],
            sigma_base[0],
            sigma2_mat[0, 0],
            threshold_val,
        )
        # Report both values only on failure instead of printing them on every run.
        self.assertAlmostEqual(
            analytic, numeric, places=6,
            msg=f"ReLU threshold: analytic={analytic:.8f}, numeric={numeric:.8f}",
        )

    def test_laplace_moment_relu_threshold(self):
        """Compare ``laplace_moment(..., moment='relu', threshold=x0)`` with quadrature.

        The reference is E[max(0, X - x0)] for X ~ Laplace(mu, sigma), integrated
        over a window that extends 20 scales past the mean (and past x0 when
        x0 lies above the mean).
        """
        from scipy.integrate import quad

        cases = (
            # (mu, sigma, x0)
            (0.0, 1.0, -1.0),  # threshold below the mean
            (0.0, 1.0, 0.0),   # threshold at the mean
            (0.0, 1.0, 1.0),   # threshold above the mean
            (2.0, 0.5, 1.0),   # threshold below the mean
            (2.0, 0.5, 3.0),   # threshold above the mean
        )
        for mu, sigma, x0 in cases:
            dist = torch.distributions.Laplace(torch.tensor([mu]), torch.tensor([sigma]))
            analytic = laplace_moment(dist, moment='relu', threshold=x0).item()

            def shifted_relu_density(x):
                pdf = (1.0/(2.0*sigma)) * np.exp(-abs(x - mu)/sigma)
                return max(0.0, x - x0) * pdf

            lo = mu - 20*sigma
            hi = mu + 20*sigma + max(0, x0-mu)
            # The Laplace pdf is already normalised, so the integral is the expectation.
            numeric = quad(shifted_relu_density, lo, hi, epsabs=1e-10, epsrel=1e-10)[0]
            self.assertAlmostEqual(analytic, numeric, places=6,
                                   msg=f"Mismatch for mu={mu}, sigma={sigma}, x0={x0}")

    def test_relu_threshold_expectation_multi_component(self):
        """Check multi-component ReLU-threshold fusion against quadrature.

        Each case fuses one base Laplace dimension with several components (one
        row of ``mu2``/``sigma2`` per component) under ``divide_unit_laplace=True``.
        It compares the analytic ReLU-threshold moment with
        ``numerical_1D_fusion_relu_threshold_multi_component`` to 5 places. A
        final case checks that repeating one component three times gives the
        same result as that component alone.
        """
        from task_vae.task_posterior import numerical_1D_fusion_relu_threshold_multi_component

        def analytic_relu(mu_base, sigma_base, mu2_mat, sigma2_mat, threshold_val):
            return laplace_fusion_mean(
                mu_base,
                sigma_base,
                mu2_mat,
                sigma2_mat,
                divide_unit_laplace=True,
                moment="relu",
                threshold=threshold_val,
            )[0]

        quadrature_cases = [
            # (label, mu_base, sigma_base, component means, component scales, threshold)
            ("Two symmetric components", 0.0, 1.0, [-0.5, 0.5], [1.0, 1.0], 1.0),
            ("Three asymmetric components", 1.0, 0.8, [-1.0, 0.0, 2.0], [0.7, 1.2, 0.9], 0.5),
            ("Four components with different scales", 0.0, 1.0,
             [-2.0, -0.5, 0.5, 2.0], [0.5, 0.8, 1.2, 0.6], -0.5),
            ("Threshold below all means", 2.0, 0.5, [1.0, 3.0, 4.0], [0.8, 0.6, 1.0], -1.0),
        ]
        for label, mu_b, sigma_b, mu2_list, sigma2_list, thr in quadrature_cases:
            mu_base = np.array([mu_b])
            sigma_base = np.array([sigma_b])
            mu2_mat = np.array(mu2_list, dtype=float).reshape(-1, 1)
            sigma2_mat = np.array(sigma2_list, dtype=float).reshape(-1, 1)

            analytic = analytic_relu(mu_base, sigma_base, mu2_mat, sigma2_mat, thr)
            numeric = numerical_1D_fusion_relu_threshold_multi_component(
                mu_base[0],
                mu2_mat.flatten(),
                sigma_base[0],
                sigma2_mat.flatten(),
                thr,
            )
            self.assertAlmostEqual(analytic, numeric, places=5,
                                   msg=f"{label}: analytic={analytic:.6f}, numeric={numeric:.6f}")

        # Three copies of one component must match that single component exactly.
        mu_base = np.array([0.0])
        sigma_base = np.array([1.0])
        analytic_multi = analytic_relu(
            mu_base, sigma_base,
            np.array([1.0, 1.0, 1.0]).reshape(-1, 1),
            np.array([0.8, 0.8, 0.8]).reshape(-1, 1),
            0.5,
        )
        analytic_single = analytic_relu(
            mu_base, sigma_base, np.array([[1.0]]), np.array([[0.8]]), 0.5,
        )
        self.assertAlmostEqual(analytic_multi, analytic_single, places=10,
                               msg=f"Multiple identical components should equal single component: "
                                   f"multi={analytic_multi:.6f}, single={analytic_single:.6f}")

    def test_relu_threshold_expectation_multi_component_edge_cases(self):
        """Edge cases for multi-component ReLU-threshold fusion.

        The first three cases (tiny scales, widely separated components, many
        components) are sanity checks only: the result must be a finite,
        non-negative float. The last case checks that an explicit
        ``threshold=0.0`` matches ``moment="relu"`` called without a threshold.
        """
        def fused_relu(mu_b, sigma_b, mu2, sigma2, **threshold_kw):
            # Analytic ReLU moment for one base dimension; one row per component.
            return laplace_fusion_mean(
                np.array([mu_b]),
                np.array([sigma_b]),
                np.array(mu2, dtype=float).reshape(-1, 1),
                np.array(sigma2, dtype=float).reshape(-1, 1),
                divide_unit_laplace=True,
                moment="relu",
                **threshold_kw,
            )[0]

        sanity_cases = [
            # (label, mu_base, sigma_base, component means, component scales, threshold)
            ("very small scales", 0.0, 0.1, [-0.1, 0.1], [0.05, 0.15], 0.0),
            ("large separation", 0.0, 1.0, [-10.0, 10.0], [0.5, 0.5], 0.0),
            ("many components", 0.0, 1.0, [-2.0, -1.0, 0.0, 1.0, 2.0],
             [0.8, 0.9, 1.0, 0.9, 0.8], 0.5),
        ]
        for label, mu_b, sigma_b, mu2, sigma2, thr in sanity_cases:
            with self.subTest(case=label):
                value = fused_relu(mu_b, sigma_b, mu2, sigma2, threshold=thr)
                self.assertIsInstance(value, float)
                self.assertFalse(np.isnan(value))
                self.assertFalse(np.isinf(value))
                self.assertGreaterEqual(value, 0.0)  # a ReLU expectation cannot be negative

        # An explicit zero threshold must reproduce the default ReLU moment. The
        # second call deliberately omits ``threshold`` so the library default
        # is what is being compared against.
        mu2 = [-0.5, 0.5, 1.5]
        sigma2 = [0.6, 1.0, 0.7]
        analytic_threshold = fused_relu(1.0, 0.8, mu2, sigma2, threshold=0.0)
        analytic_relu = fused_relu(1.0, 0.8, mu2, sigma2)

        self.assertAlmostEqual(analytic_threshold, analytic_relu, places=10,
                               msg=f"Zero threshold should equal default ReLU: "
                                   f"threshold={analytic_threshold:.6f}, relu={analytic_relu:.6f}")

# -----------------------------------------------------------------------------
#               NEW TESTS FOR divide_unit_laplace = False OPTION
# -----------------------------------------------------------------------------


class TestTaskPosteriorNoUnitLaplace(unittest.TestCase):
    """Tests for the new behaviour with divide_unit_laplace=False."""

    def _numerical_mean_product_laplace(self, mu1, b1, mu2, b2):
        """Numerically compute the mean of the product of two Laplace pdfs.

        The un-normalized density is g(x) = Lap(mu1,b1) * Lap(mu2,b2).
        The function returns ∫ x g(x) dx / ∫ g(x) dx  using quad integration.
        """
        from scipy.integrate import quad
        import math

        def laplace_pdf(x, mu, b):
            return (1.0 / (2.0 * b)) * math.exp(-abs(x - mu) / b)

        def product_pdf(x):
            return laplace_pdf(x, mu1, b1) * laplace_pdf(x, mu2, b2)

        # Compute denominator and numerator
        denom = quad(product_pdf, -math.inf, math.inf, epsabs=1e-10, epsrel=1e-10)[0]
        num = quad(lambda x: x * product_pdf(x), -math.inf, math.inf, epsabs=1e-10, epsrel=1e-10)[0]
        return num / denom

    def test_symmetry_same_means(self):
        """With a shared location, the fused mean (no unit-Laplace division) is that location."""
        shared_mu = 3.0
        narrow_scale, wide_scale = 0.5, 2.0

        fused = laplace_fusion_mean(
            np.array([shared_mu]),
            np.array([narrow_scale]),
            np.array([[shared_mu]]),
            np.array([[wide_scale]]),
            divide_unit_laplace=False,
        )
        self.assertAlmostEqual(fused[0], shared_mu, places=10)

    def test_analytic_vs_numerical_one_component(self):
        """Analytic one-component fusion (no unit Laplace) matches quadrature to 1e-3 relative."""
        (mu1, sigma1), (mu2, sigma2) = (0.0, 1.0), (2.0, 1.5)

        fused = laplace_fusion_mean(
            np.array([mu1]),
            np.array([sigma1]),
            np.array([[mu2]]),
            np.array([[sigma2]]),
            divide_unit_laplace=False,
        )
        analytic = fused[0]
        reference = self._numerical_mean_product_laplace(mu1, sigma1, mu2, sigma2)

        # Relative error is normalised by the numerical reference.
        scale = max(abs(reference), 1e-10)
        self.assertLess(abs(analytic - reference) / scale, 1e-3)

    def test_component_weights_analytical_vs_numerical(self):
        """Check weighted multi-component fusion against numerical integration.

        For each case, ``laplace_fusion_mean`` with ``component_weights`` is
        compared with ``numerical_1D_fusion_mean_with_weights`` on the same
        inputs and moment. The relative error is measured against the
        numerical reference and must be below 1 %. Finally, a zero-weight
        component is checked to have no effect on the analytic result.
        """
        from task_vae.task_posterior import numerical_1D_fusion_mean_with_weights

        def analytic_mean(mu_base, sigma_base, mu2, sigma2, weights, moment_kw):
            # One base dimension; components are stacked as rows of shape (K, 1).
            return laplace_fusion_mean(
                np.array([mu_base]),
                np.array([sigma_base]),
                mu2.reshape(-1, 1),
                sigma2.reshape(-1, 1),
                n_jobs=None,
                component_weights=weights,
                **moment_kw,
            )[0]

        cases = [
            # (message prefix, mu_base, sigma_base, mu2, sigma2, weights, moment)
            # moment=None -> the ``moment`` keyword is not passed (library default).
            ("", 0.0, 1.0, [-2.0, 2.0], [1.0, 1.0], [3.0, 1.0], None),
            ("", 1.0, 0.8, [-1.0, 0.0, 3.0], [0.6, 1.0, 0.7], [2.0, 1.0, 4.0], None),
            ("Abs moment: ", 0.0, 1.0, [-1.5, 1.5], [0.8, 0.8], [1.0, 2.0], "abs"),
            ("ReLU moment: ", 0.5, 0.7, [-0.5, 1.0, 2.0], [0.6, 0.8, 0.9], [1.0, 3.0, 1.0], "relu"),
            # Middle component has zero weight.
            ("Zero weights: ", 0.0, 1.0, [-2.0, 2.0, 5.0], [1.0, 1.0, 1.0], [1.0, 0.0, 2.0], None),
        ]
        for idx, (prefix, mu_base, sigma_base, mu2, sigma2, weights, moment) in enumerate(cases):
            mu2 = np.asarray(mu2, dtype=float)
            sigma2 = np.asarray(sigma2, dtype=float)
            weights = np.asarray(weights, dtype=float)
            moment_kw = {} if moment is None else {"moment": moment}

            # subTest keeps running the remaining cases if one of them fails.
            with self.subTest(case=idx, moment=moment):
                analytic = analytic_mean(mu_base, sigma_base, mu2, sigma2, weights, moment_kw)
                numerical = numerical_1D_fusion_mean_with_weights(
                    mu_base, mu2, sigma_base, sigma2, weights, **moment_kw
                )
                # Normalise by the numerical reference (not the value under test),
                # consistent with the other analytic-vs-numerical checks.
                rel_diff = abs(analytic - numerical) / max(abs(numerical), 1e-10)
                self.assertLess(rel_diff, 0.01,
                                f"{prefix}Analytical ({analytic:.6f}) vs numerical ({numerical:.6f}) mismatch")

        # A zero-weight component must be ignored: dropping it leaves the result unchanged.
        mu2 = np.array([-2.0, 2.0, 5.0])
        sigma2 = np.array([1.0, 1.0, 1.0])
        weights = np.array([1.0, 0.0, 2.0])
        keep = weights > 0

        with_zero = analytic_mean(0.0, 1.0, mu2, sigma2, weights, {})
        without_zero = analytic_mean(0.0, 1.0, mu2[keep], sigma2[keep], weights[keep], {})
        self.assertAlmostEqual(with_zero, without_zero, places=10,
                               msg="Zero-weight component should be ignored")

    def test_natural_prior_weight(self):
        """Test that natural_prior_weight correctly adds a unit Laplace component.

        The added component is centred at zero, as is ``mu_base`` here, so it
        should pull the fused mean toward zero, and a larger weight should pull
        harder.  Also checks the interaction with ``component_weights`` and that
        non-positive weights are rejected with ``ValueError``.
        """
        mu_base = np.array([0.0])
        sigma_base = np.array([1.0])
        mu2_mat = np.array([[2.0], [-1.0]])
        sigma2_mat = np.array([[0.8], [0.6]])

        # Baseline: no natural prior weight.
        result_no_prior = laplace_fusion_mean(
            mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=None
        )[0]

        # With a unit-weight natural prior.
        result_with_prior = laplace_fusion_mean(
            mu_base, sigma_base, mu2_mat, sigma2_mat,
            n_jobs=None, natural_prior_weight=1.0
        )[0]

        # The result with natural prior should be different AND closer to zero,
        # because the zero-centred unit Laplace component pulls the mean toward zero.
        self.assertNotAlmostEqual(result_no_prior, result_with_prior, places=6)
        self.assertLess(abs(result_with_prior), abs(result_no_prior))

        # Larger natural prior weight should pull the mean even closer to zero.
        result_large_prior = laplace_fusion_mean(
            mu_base, sigma_base, mu2_mat, sigma2_mat,
            n_jobs=None, natural_prior_weight=5.0
        )[0]
        self.assertLess(abs(result_large_prior), abs(result_with_prior))

        # natural_prior_weight must compose with component_weights: re-weighting
        # the existing components should change the result.
        weights = np.array([2.0, 1.0])
        result_weighted_prior = laplace_fusion_mean(
            mu_base, sigma_base, mu2_mat, sigma2_mat,
            n_jobs=None, component_weights=weights, natural_prior_weight=1.0
        )[0]
        self.assertNotAlmostEqual(result_with_prior, result_weighted_prior, places=6)

        # Non-positive natural_prior_weight values are invalid.
        for bad_weight in (0.0, -1.0):
            with self.subTest(natural_prior_weight=bad_weight):
                with self.assertRaises(ValueError):
                    laplace_fusion_mean(
                        mu_base, sigma_base, mu2_mat, sigma2_mat,
                        n_jobs=None, natural_prior_weight=bad_weight
                    )

    def test_natural_prior_weight_analytical_vs_numerical(self):
        """Check ``natural_prior_weight`` against brute-force 1-D integration.

        The closed-form call receives ``natural_prior_weight`` directly, while the
        numerical reference is handed an explicit extra mixture component with
        mean 0, scale 1 and weight ``natural_prior_weight``.  The two must agree
        to within 1% relative error.
        """
        from task_vae.task_posterior import numerical_1D_fusion_mean_with_weights

        def check(mu_base, sigma_base, mu2_array, sigma2_array,
                  natural_prior_weight, component_weights=None):
            # Only forward component_weights when the case specifies them, so the
            # unweighted case exercises the function's own default.
            optional = {} if component_weights is None else {"component_weights": component_weights}

            analytic = laplace_fusion_mean(
                np.array([mu_base]),
                np.array([sigma_base]),
                mu2_array.reshape(-1, 1),
                sigma2_array.reshape(-1, 1),
                n_jobs=None,
                natural_prior_weight=natural_prior_weight,
                **optional,
            )[0]

            # Reference: append the unit Laplace (mean 0, scale 1) as an ordinary
            # weighted component; unweighted components default to weight 1.
            base_weights = np.ones_like(mu2_array) if component_weights is None else component_weights
            numerical = numerical_1D_fusion_mean_with_weights(
                mu_base,
                np.append(mu2_array, 0.0),
                sigma_base,
                np.append(sigma2_array, 1.0),
                np.append(base_weights, natural_prior_weight),
            )

            scale = max(abs(analytic), 1e-10)
            self.assertLess(abs(analytic - numerical) / scale, 0.01,
                            f"Analytical ({analytic:.6f}) vs numerical ({numerical:.6f}) mismatch")

        # Case 1: natural prior added to equally weighted components.
        check(
            mu_base=0.0,
            sigma_base=1.0,
            mu2_array=np.array([-1.5, 2.0]),
            sigma2_array=np.array([0.8, 1.2]),
            natural_prior_weight=2.0,
        )

        # Case 2: natural prior combined with explicit component weights.
        check(
            mu_base=1.0,
            sigma_base=0.7,
            mu2_array=np.array([-0.5, 1.5]),
            sigma2_array=np.array([0.6, 0.9]),
            natural_prior_weight=1.5,
            component_weights=np.array([2.0, 1.0]),
        )

class TestTaskPosteriorParallelization(unittest.TestCase):
    """Test cases for parallelization functionality of laplace_fusion_mean.

    Each test checks that the serial path (``n_jobs=None``) and the parallel
    paths return numerically identical results.  Wall-clock speedup is only
    asserted when the ``RUN_PERF_ASSERTIONS`` environment variable is set to a
    non-empty value other than ``"0"``: process-pool start-up cost and machine
    load make timing comparisons unreliable in ordinary test runs.
    """

    @staticmethod
    def _random_problem(N, M):
        """Draw a random fusion problem with N dimensions and M mixture components.

        Uses NumPy's global RNG so callers control reproducibility with
        ``np.random.seed``.  The draw order (mu_base, sigma_base, mu2_mat,
        sigma2_mat) is fixed, so seeding with 42 reproduces the same data in
        every test.  Means are shrunk by 0.5 and scales are kept >= 0.5 to keep
        values reasonable and avoid divergent tails.

        Returns:
            Tuple ``(mu_base (N,), sigma_base (N,), mu2_mat (M, N), sigma2_mat (M, N))``.
        """
        mu_base = np.random.randn(N) * 0.5
        sigma_base = np.abs(np.random.randn(N)) * 0.5 + 0.5
        mu2_mat = np.random.randn(M, N) * 0.5
        sigma2_mat = np.abs(np.random.randn(M, N)) * 0.5 + 0.5
        return mu_base, sigma_base, mu2_mat, sigma2_mat

    def test_parallel_correctness(self):
        """Test that parallel and serial versions give the same results."""
        N = 50  # dimensions
        M = 3   # mixture components

        np.random.seed(42)
        mu_base, sigma_base, mu2_mat, sigma2_mat = self._random_problem(N, M)

        result_serial = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=None)
        result_parallel = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=4)

        diff = np.abs(result_serial - result_parallel)
        max_diff = np.max(diff)
        mean_diff = np.mean(diff)

        self.assertLess(max_diff, 1e-10, f"Max difference {max_diff:.2e} exceeds tolerance")
        self.assertLess(mean_diff, 1e-10, f"Mean difference {mean_diff:.2e} exceeds tolerance")

        # With return_component_norms=True the function returns a tuple; compare
        # the first three entries element-wise.
        result_serial_full = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, return_component_norms=True, n_jobs=None)
        result_parallel_full = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, return_component_norms=True, n_jobs=4)

        diff_mean = np.abs(result_serial_full[0] - result_parallel_full[0])
        diff_norms = np.abs(result_serial_full[1] - result_parallel_full[1])
        diff_norms_per_dim = np.abs(result_serial_full[2] - result_parallel_full[2])

        self.assertLess(np.max(diff_mean), 1e-10, "Mean results don't match")
        self.assertLess(np.max(diff_norms), 1e-10, "Component norms don't match")
        self.assertLess(np.max(diff_norms_per_dim), 1e-10, "Component norms per dimension don't match")

    def test_broadcast_parallel_correctness(self):
        """Test that broadcast versions work correctly with parallelization."""
        N = 100

        # Scalar base parameters broadcast across N dimensions; 4 components.
        mu_base_1d = 0.0
        sigma_base_1d = 1.0
        mu2_m = np.array([1.0, -1.0, 2.0, -2.0])
        sigma2_m = np.array([0.5, 0.5, 1.0, 1.0])

        result_serial = laplace_fusion_mean_broadcast(N, mu_base_1d, sigma_base_1d, mu2_m, sigma2_m, n_jobs=None)
        result_parallel = laplace_fusion_mean_broadcast(N, mu_base_1d, sigma_base_1d, mu2_m, sigma2_m, n_jobs=4)

        max_diff = np.max(np.abs(result_serial - result_parallel))
        self.assertLess(max_diff, 1e-10, f"Broadcast versions don't match: max_diff={max_diff:.2e}")

    def test_different_n_jobs(self):
        """Test that every supported n_jobs value reproduces the serial result."""
        N = 100
        M = 3

        np.random.seed(42)
        mu_base, sigma_base, mu2_mat, sigma2_mat = self._random_problem(N, M)

        result_serial = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=None)

        for n_jobs in [1, 2, 4, 8, -1]:
            result_parallel = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=n_jobs)
            max_diff = np.max(np.abs(result_serial - result_parallel))
            self.assertLess(max_diff, 1e-10, f"Results don't match for n_jobs={n_jobs}")

    def test_parallel_performance_benchmark(self):
        """Benchmark serial vs parallel versions; results must always match.

        The speedup assertion for large problems runs only when the
        ``RUN_PERF_ASSERTIONS`` environment variable is set (and not ``"0"``),
        because wall-clock timings are not reproducible across machines and
        load conditions.
        """
        import os

        check_speedup = os.environ.get("RUN_PERF_ASSERTIONS", "") not in ("", "0")

        test_cases = [
            (10, 2),    # Small problem
            (50, 3),    # Medium problem
            (100, 5),   # Large problem
            (200, 8),   # Very large problem
        ]

        for N, M in test_cases:
            np.random.seed(42)
            mu_base, sigma_base, mu2_mat, sigma2_mat = self._random_problem(N, M)

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            result_serial = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=None)
            serial_time = time.perf_counter() - start_time

            start_time = time.perf_counter()
            result_parallel = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=4)
            parallel_time = time.perf_counter() - start_time

            max_diff = np.max(np.abs(result_serial - result_parallel))
            self.assertLess(max_diff, 1e-10, f"Results don't match for N={N}, M={M}")

            # Opt-in: for large problems, parallel should be faster.
            if check_speedup and N >= 100:
                speedup = serial_time / max(parallel_time, 1e-12)
                self.assertGreater(speedup, 1.0, f"Parallel not faster for N={N}, M={M}: speedup={speedup:.2f}")

    def test_edge_cases_parallel(self):
        """Test edge cases with parallelization."""
        # Single dimension, single component.
        result_serial = laplace_fusion_mean(np.array([0.5]), np.array([1.0]), np.array([[0.0]]), np.array([[1.0]]), n_jobs=None)
        result_parallel = laplace_fusion_mean(np.array([0.5]), np.array([1.0]), np.array([[0.0]]), np.array([[1.0]]), n_jobs=4)
        self.assertAlmostEqual(result_serial[0], result_parallel[0], places=10)

        # Two dimensions, single component.
        result_serial = laplace_fusion_mean(np.array([0.5, 1.0]), np.array([1.0, 1.0]), np.array([[0.0, 0.5]]), np.array([[1.0, 1.0]]), n_jobs=None)
        result_parallel = laplace_fusion_mean(np.array([0.5, 1.0]), np.array([1.0, 1.0]), np.array([[0.0, 0.5]]), np.array([[1.0, 1.0]]), n_jobs=4)
        np.testing.assert_allclose(result_serial, result_parallel, rtol=1e-10, atol=1e-10)

    def test_automatic_selection(self):
        """Test that serial and parallel agree for both small and large problems."""
        # A single seed drives both problems: the large problem continues the
        # random stream after the small one, exactly as before.
        np.random.seed(42)

        # Small problem
        small = self._random_problem(20, 2)
        result_small_serial = laplace_fusion_mean(*small, n_jobs=None)
        result_small_parallel = laplace_fusion_mean(*small, n_jobs=4)
        np.testing.assert_allclose(result_small_serial, result_small_parallel, rtol=1e-10, atol=1e-10)

        # Large problem
        large = self._random_problem(200, 5)
        result_large_serial = laplace_fusion_mean(*large, n_jobs=None)
        result_large_parallel = laplace_fusion_mean(*large, n_jobs=4)
        np.testing.assert_allclose(result_large_serial, result_large_parallel, rtol=1e-10, atol=1e-10)


# -----------------------------------------------------------------------------
#       TESTS FOR THE GENERALIZED DENOMINATOR (NON-UNIT LAPLACE DIVISOR)
# -----------------------------------------------------------------------------


class TestGeneralizedDenominator(unittest.TestCase):
    """Dividing by a general Laplace(mu_div, sigma_div) instead of the unit Laplace.

    Every case is a 1-D, single-component problem whose closed-form result from
    ``laplace_fusion_mean`` is compared with ``numerical_1D_fusion_mean`` to an
    absolute tolerance of 2e-4.
    """

    def _run_numeric_check(self, mu_base, sigma_base, mu2, sigma2, mu_div, sigma_div, moment="mean"):
        """Assert that closed form and 1-D quadrature agree for one parameter set.

        Scalars are wrapped into the array shapes used for ``laplace_fusion_mean``
        elsewhere in this file: (1,) for base and divisor parameters, (1, 1) for
        the single mixture component.
        """
        from task_vae.task_posterior import numerical_1D_fusion_mean

        closed_form = laplace_fusion_mean(
            np.array([mu_base]),
            np.array([sigma_base]),
            np.array([[mu2]]),
            np.array([[sigma2]]),
            moment=moment,
            mu_div=np.array([mu_div]),
            sigma_div=np.array([sigma_div]),
        )
        quadrature = numerical_1D_fusion_mean(
            mu_base, mu2, sigma_base, sigma2,
            mu_div=mu_div, sigma_div=sigma_div, moment=moment,
        )

        self.assertAlmostEqual(
            closed_form[0], quadrature, delta=2e-4,
            msg="Mismatch between analytic and numerical 1-D check",
        )

    def test_generalized_denominator_mean_and_scale(self):
        """Mean moment with a shifted, wider divisor (mu_div=1.0, sigma_div=2.0)."""
        self._run_numeric_check(mu_base=0.3, sigma_base=1.2, mu2=2.0, sigma2=0.7,
                                mu_div=1.0, sigma_div=2.0, moment="mean")

    def test_generalized_denominator_alternative_parameters(self):
        """Mean moment with a negatively shifted, narrower divisor."""
        self._run_numeric_check(mu_base=1.5, sigma_base=0.9, mu2=-0.7, sigma2=1.3,
                                mu_div=-1.2, sigma_div=0.8, moment="mean")

    def test_generalized_denominator_abs_moment(self):
        """Absolute-value moment with a non-default divisor."""
        self._run_numeric_check(mu_base=0.2, sigma_base=0.6, mu2=1.8, sigma2=1.1,
                                mu_div=0.5, sigma_div=0.7, moment="abs")

    def test_generalized_denominator_relu_moment(self):
        """Check ReLU moment with non-default denominator parameters against numerical integration."""
        self._run_numeric_check(mu_base=0.7, sigma_base=1.1, mu2=-1.4, sigma2=0.9,
                                mu_div=0.2, sigma_div=1.3, moment="relu")


class TestGeneralizedDenominator2D(unittest.TestCase):
    """Multi-dimensional and multi-component checks for ``laplace_fusion_mean``.

    ``test_vector_denominator_two_dimensions`` validates a 2-D problem with a
    per-dimension divisor against the numerical reference
    ``numerical_2D_fusion_mean``.  ``test_component_weights_influence`` (1-D)
    checks the direction in which asymmetric ``component_weights`` move the
    fused mean.
    """

    def test_vector_denominator_two_dimensions(self):
        """Closed form vs numerical reference for a 2-D, 2-component problem."""
        mu_base = np.array([0.3, -0.5])
        sigma_base = np.array([1.0, 0.8])

        # Rows are mixture components, columns are dimensions.
        mu2_mat = np.array([[1.5, -0.2], [-1.0, 2.0]])
        sigma2_mat = np.array([[0.9, 1.1], [0.7, 1.4]])

        mu_div = np.array([0.1, -0.3])
        sigma_div = np.array([1.3, 0.6])

        analytic = laplace_fusion_mean(
            mu_base,
            sigma_base,
            mu2_mat,
            sigma2_mat,
            mu_div=mu_div,
            sigma_div=sigma_div,
        )

        from task_vae.task_posterior import numerical_2D_fusion_mean

        numeric = numerical_2D_fusion_mean(
            mu_base,
            sigma_base,
            mu2_mat,
            sigma2_mat,
            mu_div,
            sigma_div,
        )

        # zip() silently truncates to the shorter input, so a short or empty
        # result would otherwise make the loop below pass vacuously.
        self.assertEqual(len(analytic), len(mu_base),
                         msg="Analytic result must have one entry per dimension")
        self.assertEqual(len(analytic), len(numeric),
                         msg=f"Length mismatch: analytic {len(analytic)} vs numeric {len(numeric)}")

        for a, n in zip(analytic, numeric):
            self.assertAlmostEqual(a, n, delta=2e-2,
                                   msg=f"2-D analytic ({a}) vs numeric ({n}) mismatch")

    def test_component_weights_influence(self):
        """Verify that asymmetric component weights shift the fused mean in the expected direction."""
        mu_base = np.array([0.0])
        sigma_base = np.array([1.0])
        mu2_mat = np.array([[-2.0], [2.0]])
        sigma2_mat = np.array([[1.0], [1.0]])

        # Equal weights on mirror-image components should give a symmetric mean ≈ 0.
        mean_equal = laplace_fusion_mean(mu_base, sigma_base, mu2_mat, sigma2_mat, n_jobs=None)[0]
        self.assertAlmostEqual(mean_equal, 0.0, places=6)

        # Heavier weight on the negative component should shift the mean negative.
        weights = np.array([3.0, 1.0])
        mean_weighted = laplace_fusion_mean(
            mu_base,
            sigma_base,
            mu2_mat,
            sigma2_mat,
            n_jobs=None,
            component_weights=weights,
        )[0]
        self.assertLess(mean_weighted, mean_equal)

class TestAdjustedMixture(unittest.TestCase):
    """Basic sanity checks for the Gaussian `gaussian_mixture_fusion` helper."""

    def _check_mixture_consistency(self, base_mu, base_var, comp_mus, comp_vars):
        """Check the internal consistency of ``gaussian_mixture_fusion`` outputs.

        Arguments are nested Python lists converted with ``torch.tensor``.  The
        callers in this class pass ``base_mu``/``base_var`` with one entry per
        dimension and ``comp_mus``/``comp_vars`` with one row per component.

        Checks that the weights are a probability vector, that ``mixture.mean``
        equals the weighted sum of component means, and that
        ``mixture.variance`` follows the law of total variance.
        """
        mixture, weights, means_prime, var_prime = gaussian_mixture_fusion(
            torch.tensor(base_mu),
            torch.tensor(base_var),
            torch.tensor(comp_mus),
            torch.tensor(comp_vars),
        )

        # 1. Weights sum to one and are positive.
        self.assertAlmostEqual(weights.sum().item(), 1.0, places=6)
        self.assertTrue(torch.all(weights > 0))

        # 2. Mixture mean matches explicit weighted sum of component means.
        expected_mean = (weights.unsqueeze(-1) * means_prime).sum(0)
        torch.testing.assert_close(mixture.mean, expected_mean, rtol=1e-5, atol=1e-7)

        # 3. Variance matches law of total variance formula.
        # total_var = Σ w_i (Σ'_i + (μ'_i - μ)^2)
        centered_sq = (means_prime - expected_mean).pow(2)
        expected_var = (weights.unsqueeze(-1) * (var_prime + centered_sq)).sum(0)
        torch.testing.assert_close(mixture.variance, expected_var, rtol=1e-5, atol=1e-7)

    def test_one_dimensional_symmetric(self):
        """Mirror-image components around a zero base mean (1-D)."""
        base_mu = [0.0]
        base_var = [1.0]
        comp_mus = [[1.0], [-1.0]]
        comp_vars = [[1.0], [1.0]]
        self._check_mixture_consistency(base_mu, base_var, comp_mus, comp_vars)

    def test_two_dimensional_generic(self):
        """Generic 2-D problem with three components."""
        base_mu = [0.5, -0.25]
        base_var = [0.8, 1.2]
        comp_mus = [[1.0, -0.5], [-1.0, 0.75], [0.0, 0.0]]
        comp_vars = [[0.9, 0.7], [1.1, 0.6], [0.5, 1.3]]
        self._check_mixture_consistency(base_mu, base_var, comp_mus, comp_vars)

    def test_numerical_vs_analytical_1d(self):
        """Compare analytical Gaussian fusion mean against numerical quadrature (1-D, 2 components)."""
        import math
        from task_vae.task_posterior import numerical_1D_gaussian_fusion_mean

        base_mu = 0.3
        base_var = 0.7
        # Single source of truth for the component parameters, so the analytic
        # and numerical calls cannot drift apart.
        comp_mu_values = [1.2, -0.8]
        comp_var_values = [0.9, 1.1]

        # Analytical (one row per component, one column per dimension).
        mixture, weights, means_prime, var_prime = gaussian_mixture_fusion(
            torch.tensor([base_mu]),
            torch.tensor([base_var]),
            torch.tensor([[m] for m in comp_mu_values]),
            torch.tensor([[v] for v in comp_var_values]),
        )
        analytic_mean = mixture.mean.item()

        # Numerical reference takes standard deviations rather than variances.
        numerical_mean = numerical_1D_gaussian_fusion_mean(
            mu_base=base_mu,
            mu2_array=comp_mu_values,
            sigma_base=math.sqrt(base_var),
            sigma2_array=[math.sqrt(v) for v in comp_var_values],
        )

        # Diagnostics go into the failure message instead of unconditional prints.
        self.assertAlmostEqual(
            analytic_mean, numerical_mean, places=4,
            msg=f"weights: {weights}, means_prime: {means_prime}, var_prime: {var_prime}",
        )

if __name__ == "__main__":
    # Allow running this module directly as a script via the unittest runner.
    unittest.main()