"""
Unit tests for CRN-level correctness validation.

Verifies that:
1. CRN simulations match BP marginals
2. Reduced CRN matches original CRN on surviving variables
3. The ``_normalize`` helper returns valid distributions, including edge cases
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph
from inference import run_bp
from crn import compile_factor_graph_to_crn, simulate_crn
from reduction import from_factor_graph, reduce_to_core_spb, to_factor_graph_if_possible
from benchmarks.benchmark_runner import max_marginal_diff_between_dicts, _normalize


def test_crn_vs_bp_chain_5():
    """
    Regression test: CRN simulation marginals agree with BP on a 5-variable chain.

    Builds a chain x0..x4 of binary variables with seeded random pairwise
    potentials and unary factors on both endpoints, runs BP, simulates the
    compiled CRN, and requires the largest marginal difference to stay
    below 1e-6.
    """
    chain_length = 5
    graph = FactorGraph("test_chain")
    nodes = [
        graph.add_variable(Variable(f"x{idx}", [0, 1]))
        for idx in range(chain_length)
    ]

    # Seeded pairwise potentials between neighbours; the +0.1 offset keeps
    # every table entry strictly positive.
    np.random.seed(42)
    for idx, (left, right) in enumerate(zip(nodes, nodes[1:])):
        potential = np.random.rand(2, 2) + 0.1
        graph.add_factor(Factor(f"f{idx}_{idx+1}", [left, right], potential))

    # Unary factors on the two ends of the chain.
    graph.add_factor(Factor("u0", [nodes[0]], np.array([0.7, 0.3])))
    graph.add_factor(Factor("u4", [nodes[-1]], np.array([0.4, 0.6])))

    # Reference marginals from belief propagation.
    bp_result = run_bp(graph, tolerance=1e-10, max_iterations=1000)
    assert bp_result.converged, "BP should converge on tree"

    # Marginals from simulating the compiled reaction network.
    network = compile_factor_graph_to_crn(graph, kappa_r=0.02, kappa_prod=50.0)
    sim_result = simulate_crn(network, t_end=5000, n_points=200)

    names = [var.name for var in graph.variables]
    reference = {name: bp_result.get_marginal(name) for name in names}
    max_diff = max_marginal_diff_between_dicts(
        reference, sim_result.marginals, names
    )

    print(f"CRN vs BP max diff: {max_diff:.2e}")
    assert max_diff < 1e-6, f"CRN should match BP within 1e-6, got {max_diff}"


def test_reduction_preserves_crn_marginals():
    """
    Check that SP-B reduction leaves CRN marginals of surviving variables intact.

    The same 5-variable chain is simulated twice as a CRN: once as built and
    once after ``reduce_to_core_spb``. For every variable still present in
    the reduced factor graph, the two sets of marginals must differ by less
    than 1e-6.
    """
    chain_length = 5
    graph = FactorGraph("test_reduction")
    nodes = [
        graph.add_variable(Variable(f"x{idx}", [0, 1]))
        for idx in range(chain_length)
    ]

    # Seeded pairwise potentials between neighbours, offset by 0.1.
    np.random.seed(123)
    for idx, (left, right) in enumerate(zip(nodes, nodes[1:])):
        potential = np.random.rand(2, 2) + 0.1
        graph.add_factor(Factor(f"f{idx}_{idx+1}", [left, right], potential))

    # Unary factors on the two ends of the chain.
    graph.add_factor(Factor("u0", [nodes[0]], np.array([0.8, 0.2])))
    graph.add_factor(Factor("u4", [nodes[-1]], np.array([0.3, 0.7])))

    # Baseline: simulate the CRN of the unreduced graph first.
    baseline_crn = compile_factor_graph_to_crn(graph, kappa_r=0.02, kappa_prod=50.0)
    baseline_sim = simulate_crn(baseline_crn, t_end=5000, n_points=200)

    # Reduce via the poset representation; the return value of the reduction
    # call is not used, only the poset passed to it.
    poset = from_factor_graph(graph)
    reduce_to_core_spb(poset)
    reduced_graph = to_factor_graph_if_possible(poset)

    assert reduced_graph is not None, "Should reduce to non-trivial graph"

    reduced_crn = compile_factor_graph_to_crn(
        reduced_graph, kappa_r=0.02, kappa_prod=50.0
    )
    reduced_sim = simulate_crn(reduced_crn, t_end=5000, n_points=200)

    # Only variables that are still in the reduced graph can be compared.
    survivors = [var.name for var in reduced_graph.variables]
    max_diff = max_marginal_diff_between_dicts(
        baseline_sim.marginals,
        reduced_sim.marginals,
        survivors,
    )

    print(f"Original vs Reduced CRN max diff: {max_diff:.2e}")
    assert max_diff < 1e-6, (
        f"Reduced CRN should match original within 1e-6, got {max_diff}"
    )


def test_normalize_function():
    """Exercise ``_normalize`` on normalized, scaled, negative and all-zero input."""
    # Input that already sums to one should come back unchanged.
    already_normalized = np.array([0.3, 0.7])
    assert np.allclose(_normalize(already_normalized), already_normalized)

    # Unnormalized positive weights are rescaled to sum to one.
    weights = np.array([3, 7])
    assert np.allclose(_normalize(weights), [0.3, 0.7])

    # A negative entry must not survive, and the result must still sum to one.
    from_negative = _normalize(np.array([-0.1, 0.6, 0.5]))
    assert np.all(from_negative >= 0)
    assert np.isclose(from_negative.sum(), 1.0)

    # All-zero input is expected to fall back to the uniform distribution.
    from_zeros = _normalize(np.array([0.0, 0.0]))
    assert np.allclose(from_zeros, [0.5, 0.5])


def run_all_tests(tests=None):
    """
    Run unit tests one after another and print a pass/fail summary.

    Args:
        tests: Optional iterable of ``(name, callable)`` pairs. Each callable
            is invoked with no arguments. When omitted, the three tests
            defined in this module are run in their usual order.

    Returns:
        True if every test passed, False if any test failed an assertion or
        raised another exception.
    """
    if tests is None:
        # Resolved lazily so callers can supply their own subset instead.
        tests = [
            ("test_normalize_function", test_normalize_function),
            ("test_crn_vs_bp_chain_5", test_crn_vs_bp_chain_5),
            ("test_reduction_preserves_crn_marginals", test_reduction_preserves_crn_marginals),
        ]

    print("=" * 60)
    print("Running CRN Correctness Unit Tests")
    print("=" * 60)

    passed = 0
    failed = 0

    for name, test_fn in tests:
        print(f"\n{name}...", end=" ")
        try:
            test_fn()
        except AssertionError as e:
            print(f"FAILED: {e}")
            failed += 1
        except Exception as e:
            # Top-level boundary: report the problem and continue with the
            # remaining tests. The exception type is included because str(e)
            # alone can be empty or ambiguous (e.g. KeyError shows only the key).
            print(f"ERROR: {type(e).__name__}: {e}")
            failed += 1
        else:
            print("PASSED")
            passed += 1

    print("\n" + "=" * 60)
    print(f"Results: {passed} passed, {failed} failed")
    print("=" * 60)

    return failed == 0


if __name__ == "__main__":
    # Exit with status 0 when every test passed and 1 otherwise. sys.exit is
    # used instead of the ``exit`` helper injected by the site module, which
    # is meant for interactive sessions and is missing when site is not loaded.
    success = run_all_tests()
    sys.exit(0 if success else 1)
