"""
Tests for Strict SP-B Reduction on Trees

On trees, the Bethe free energy is exact and its critical points correspond
to the true marginals. Therefore, SP-B reductions should preserve marginals
for surviving variables.

Test strategy:
1. Build small tree factor graphs (hand-specified structure, fixed or random potentials)
2. Compute exact marginals via brute-force enumeration
3. Apply SP-B reductions step by step
4. Verify marginals match for surviving variables at each step
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph
from inference import run_bp
from reduction.poset_reduction import (
    PosetModel, Region, ReductionStep,
    from_factor_graph, to_factor_graph_if_possible,
    reduce_to_core_spb, retract_linear, retract_colinear
)


def compute_exact_marginal(fg: FactorGraph, var_name: str) -> np.ndarray:
    """Return the brute-force marginal of the variable named ``var_name``.

    The first variable in ``fg.variables`` whose ``name`` equals ``var_name``
    is handed to ``fg.compute_marginal_exact``.

    Raises:
        ValueError: if ``fg`` contains no variable with that name.
    """
    target = next((v for v in fg.variables if v.name == var_name), None)
    if target is None:
        raise ValueError(f"Variable {var_name} not found")
    return fg.compute_marginal_exact(target)


def test_simple_chain():
    """
    Test: x1 -- f12 -- x2 -- f23 -- x3

    This is a tree, so:
    - x1 is linear (only in f12)
    - x3 is linear (only in f23)
    - After removing x1, f12 becomes unary on x2 (colinear)

    After every single SP-B step the poset is converted back to a factor
    graph (when possible), and the exact marginal of each surviving variable
    is compared with its marginal in the original graph.

    Returns:
        True if every checked marginal matched within 1e-6 and no marginal
        computation raised; False otherwise.
    """
    tol = 1e-6

    print("=" * 70)
    print("TEST: Simple Chain x1 -- f12 -- x2 -- f23 -- x3")
    print("=" * 70)

    # Build factor graph
    fg = FactorGraph("chain")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    f12_table = np.array([[0.8, 0.2], [0.3, 0.7]])
    f23_table = np.array([[0.6, 0.4], [0.4, 0.6]])

    fg.add_factor(Factor("f12", [x1, x2], f12_table))
    fg.add_factor(Factor("f23", [x2, x3], f23_table))

    print(f"\nOriginal graph: {fg}")

    # Reference marginals of the unreduced graph.
    exact_marginals = {
        "x1": compute_exact_marginal(fg, "x1"),
        "x2": compute_exact_marginal(fg, "x2"),
        "x3": compute_exact_marginal(fg, "x3"),
    }

    print("\nExact marginals:")
    for name, marg in exact_marginals.items():
        print(f"  P({name}) = {np.round(marg, 6)}")

    # Convert to poset model
    poset = from_factor_graph(fg)
    print(f"\nPoset model: {poset}")
    print(f"  Variables: {poset.variables}")
    print(f"  Factors: {poset.factors}")

    # Initial linear/colinear candidates (informational only)
    linear_vars = poset.get_linear_variables()
    colinear_facs = poset.get_colinear_factors()
    print(f"\n  Linear variables: {linear_vars}")
    print(f"  Colinear factors: {colinear_facs}")

    # Step-by-step reduction with verification
    all_passed = True
    step_num = 0

    while True:
        linear_vars = poset.get_linear_variables()
        colinear_facs = poset.get_colinear_factors()

        if not linear_vars and not colinear_facs:
            break

        step_num += 1

        # Linear variables are retracted before colinear factors.
        if linear_vars:
            step = retract_linear(poset, linear_vars[0])
        else:
            step = retract_colinear(poset, colinear_facs[0])
        print(f"\nStep {step_num}: {step}")

        # Convert back and check marginals
        reduced_fg = to_factor_graph_if_possible(poset)

        if reduced_fg and reduced_fg.num_variables > 0:
            print(f"  Remaining variables: {[v.name for v in reduced_fg.variables]}")

            for var in reduced_fg.variables:
                if var.name not in exact_marginals:
                    continue
                orig_exact = np.asarray(exact_marginals[var.name])
                try:
                    reduced_exact = np.asarray(
                        compute_exact_marginal(reduced_fg, var.name)
                    )
                except Exception as e:
                    # A surviving variable whose marginal cannot be computed
                    # is a failure, not something to skip silently.
                    all_passed = False
                    print(f"    P({var.name}): Error computing marginal: {e}")
                    continue

                # Guard against broadcasting hiding a shape mismatch.
                if reduced_exact.shape != orig_exact.shape:
                    diff = float("inf")
                else:
                    diff = float(np.max(np.abs(reduced_exact - orig_exact)))
                # `diff < tol` is False for NaN, so a NaN marginal fails.
                ok = diff < tol
                if not ok:
                    all_passed = False
                status = "✓" if ok else "✗"
                print(f"    P({var.name}): reduced={np.round(reduced_exact, 4)}, "
                      f"orig={np.round(orig_exact, 4)}, diff={diff:.2e} {status}")

    print(f"\nFinal poset: {poset}")
    print(f"  Variables: {poset.variables}")
    print(f"  Factors: {poset.factors}")

    if all_passed:
        print("\n✓ TEST PASSED: All intermediate marginals preserved")
    else:
        print("\n✗ TEST FAILED: Some marginals differ")

    return all_passed


def test_chain_with_unary_factors():
    """
    Test: f1(x1) -- f12 -- f2(x2) -- f23 -- f3(x3)

    Adds unary factors to the chain. The test checks that the BP marginals
    match the brute-force exact marginals on this tree. It then runs
    ``reduce_to_core_spb``. If the resulting core converts back to a factor
    graph, it checks that every surviving variable keeps its original exact
    marginal.

    Returns:
        True if every check passed within 1e-6, False otherwise.
    """
    tol = 1e-6

    print("\n" + "=" * 70)
    print("TEST: Chain with Unary Factors")
    print("=" * 70)

    fg = FactorGraph("chain_unary")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    # Unary factors (priors)
    fg.add_factor(Factor("f1", [x1], np.array([0.7, 0.3])))
    fg.add_factor(Factor("f2", [x2], np.array([0.5, 0.5])))
    fg.add_factor(Factor("f3", [x3], np.array([0.6, 0.4])))

    # Binary factors (pairwise)
    f12_table = np.array([[0.9, 0.1], [0.2, 0.8]])
    f23_table = np.array([[0.7, 0.3], [0.3, 0.7]])

    fg.add_factor(Factor("f12", [x1, x2], f12_table))
    fg.add_factor(Factor("f23", [x2, x3], f23_table))

    print(f"\nOriginal graph: {fg}")

    # Exact marginals
    exact_marginals = {v.name: compute_exact_marginal(fg, v.name) for v in fg.variables}
    print("\nExact marginals:")
    for name, marg in exact_marginals.items():
        print(f"  P({name}) = {np.round(marg, 6)}")

    all_passed = True

    # BP marginals: on a tree BP should reproduce the exact marginals.
    bp_result = run_bp(fg, tolerance=1e-10)
    print("\nBP marginals:")
    for v in fg.variables:
        bp_marg = np.asarray(bp_result.get_marginal(v.name))
        exact = np.asarray(exact_marginals[v.name])
        # Shape check first so subtraction never broadcasts or raises;
        # the `<` comparison makes a NaN difference count as a failure.
        ok = (bp_marg.shape == exact.shape
              and float(np.max(np.abs(bp_marg - exact))) < tol)
        if not ok:
            all_passed = False
        print(f"  P({v.name}) = {np.round(bp_marg, 6)} {'✓' if ok else '✗'}")

    # Convert and reduce
    poset = from_factor_graph(fg)
    print("\nPoset after absorbing unary factors:")
    print(f"  Variables: {poset.variables}")
    print(f"  Factors: {poset.factors}")

    # Show variable tables (should have absorbed unary factors)
    for var_region in poset.variables:
        region = poset.regions[var_region]
        print(f"  ψ_{var_region}: {region.table}")

    # Reduce to core
    steps = reduce_to_core_spb(poset)
    print(f"\nReduction steps ({len(steps)}):")
    for step in steps:
        print(f"  {step}")

    print(f"\nCore: {poset}")

    # Verify final state: surviving variables must keep their marginals.
    reduced_fg = to_factor_graph_if_possible(poset)
    print(f"Converted back: {reduced_fg}")

    if reduced_fg and reduced_fg.num_variables > 0:
        for var in reduced_fg.variables:
            if var.name not in exact_marginals:
                continue
            orig_exact = np.asarray(exact_marginals[var.name])
            try:
                reduced_exact = np.asarray(
                    compute_exact_marginal(reduced_fg, var.name)
                )
            except Exception as e:
                all_passed = False
                print(f"  P({var.name}): Error computing marginal: {e}")
                continue
            ok = (reduced_exact.shape == orig_exact.shape
                  and float(np.max(np.abs(reduced_exact - orig_exact))) < tol)
            if not ok:
                all_passed = False
            print(f"  P({var.name}): {'✓' if ok else '✗'}")

    if all_passed:
        print("\n✓ TEST PASSED")
    else:
        print("\n✗ TEST FAILED")

    return all_passed


def test_star_graph():
    r"""
    Test star graph: x0 at center connected to x1, x2, x3 via binary factors.

    Structure:
        x1
         \
    x2 -- x0 -- x3

    This is a tree with x1, x2, x3 as leaves (linear). x0 also carries a
    unary factor f0. After each SP-B step, the exact marginals of surviving
    variables in the converted-back graph are compared with the originals.

    Returns:
        True if every checked marginal matched within 1e-6 and no marginal
        computation raised; False otherwise.
    """
    # NOTE: the docstring is raw so the "\" in the diagram is kept literally;
    # in a normal string "\<newline>" is a line continuation.
    tol = 1e-6

    print("\n" + "=" * 70)
    print("TEST: Star Graph (4 variables)")
    print("=" * 70)

    fg = FactorGraph("star")
    x0 = fg.add_variable(Variable("x0", [0, 1]))
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    # Add unary on center
    fg.add_factor(Factor("f0", [x0], np.array([0.6, 0.4])))

    # Binary factors connecting leaves to center
    fg.add_factor(Factor("f01", [x0, x1], np.array([[0.8, 0.2], [0.3, 0.7]])))
    fg.add_factor(Factor("f02", [x0, x2], np.array([[0.7, 0.3], [0.4, 0.6]])))
    fg.add_factor(Factor("f03", [x0, x3], np.array([[0.9, 0.1], [0.2, 0.8]])))

    print(f"\nOriginal graph: {fg}")

    # Exact marginals
    exact_marginals = {v.name: compute_exact_marginal(fg, v.name) for v in fg.variables}
    print("\nExact marginals:")
    for name, marg in sorted(exact_marginals.items()):
        print(f"  P({name}) = {np.round(marg, 6)}")

    # Convert and analyze
    poset = from_factor_graph(fg)
    linear_vars = poset.get_linear_variables()
    print(f"\nLinear variables: {linear_vars}")
    # x1, x2, x3 should be linear; x0 is not (connected to 3 factors)

    # Reduce step by step
    all_passed = True
    step_num = 0

    while True:
        linear_vars = poset.get_linear_variables()
        colinear_facs = poset.get_colinear_factors()

        if not linear_vars and not colinear_facs:
            break

        step_num += 1

        if linear_vars:
            step = retract_linear(poset, linear_vars[0])
        else:
            step = retract_colinear(poset, colinear_facs[0])

        print(f"\nStep {step_num}: {step}")

        # Check surviving variable marginals
        reduced_fg = to_factor_graph_if_possible(poset)
        if reduced_fg and reduced_fg.num_variables > 0:
            for var in reduced_fg.variables:
                if var.name not in exact_marginals:
                    continue
                orig_exact = np.asarray(exact_marginals[var.name])
                try:
                    reduced_exact = np.asarray(
                        compute_exact_marginal(reduced_fg, var.name)
                    )
                except Exception as e:
                    # Surface the error and fail, instead of skipping the variable.
                    all_passed = False
                    print(f"  P({var.name}): ✗ error computing marginal: {e}")
                    continue

                # Guard against broadcasting hiding a shape mismatch.
                if reduced_exact.shape != orig_exact.shape:
                    diff = float("inf")
                else:
                    diff = float(np.max(np.abs(reduced_exact - orig_exact)))
                ok = diff < tol  # False for NaN, so NaN counts as a failure
                if not ok:
                    all_passed = False
                status = "✓" if ok else f"✗ diff={diff:.6f}"
                print(f"  P({var.name}): {status}")

    print(f"\nFinal: {poset}")

    if all_passed:
        print("\n✓ TEST PASSED")
    else:
        print("\n✗ TEST FAILED")

    return all_passed


def test_longer_chain():
    """
    Test a longer chain: x1 -- x2 -- x3 -- x4 -- x5

    Every variable gets a random unary factor, and neighbouring variables are
    joined by random pairwise factors. All entries are in [0.1, 1.1), so the
    potentials are strictly positive. The potentials are drawn from the
    global ``np.random`` state, which the ``__main__`` block seeds. After
    ``reduce_to_core_spb``, the core is converted back (when possible) and
    each surviving variable's exact marginal is compared with the original.

    Returns:
        True if every checked marginal matched within 1e-6 and no marginal
        computation raised; False otherwise.
    """
    tol = 1e-6

    print("\n" + "=" * 70)
    print("TEST: Longer Chain (5 variables)")
    print("=" * 70)

    fg = FactorGraph("long_chain")
    chain = []  # avoid shadowing the builtin `vars`
    for i in range(1, 6):
        v = fg.add_variable(Variable(f"x{i}", [0, 1]))
        chain.append(v)
        # Add unary factor
        fg.add_factor(Factor(f"f{i}", [v], np.random.rand(2) + 0.1))

    # Add pairwise factors
    for i in range(4):
        table = np.random.rand(2, 2) + 0.1
        fg.add_factor(Factor(f"f{i+1}{i+2}", [chain[i], chain[i+1]], table))

    print(f"\nOriginal graph: {fg}")

    # Exact marginals
    exact_marginals = {v.name: compute_exact_marginal(fg, v.name) for v in fg.variables}
    print("\nExact marginals:")
    for name in sorted(exact_marginals.keys()):
        print(f"  P({name}) = {np.round(exact_marginals[name], 4)}")

    # Reduce
    poset = from_factor_graph(fg)
    steps = reduce_to_core_spb(poset)

    print(f"\nReduction completed in {len(steps)} steps")
    print(f"Core: {poset}")

    # Verify surviving variables in the core keep their marginals.
    all_passed = True
    reduced_fg = to_factor_graph_if_possible(poset)
    if reduced_fg and reduced_fg.num_variables > 0:
        for var in reduced_fg.variables:
            if var.name not in exact_marginals:
                continue
            orig_exact = np.asarray(exact_marginals[var.name])
            try:
                reduced_exact = np.asarray(
                    compute_exact_marginal(reduced_fg, var.name)
                )
            except Exception as e:
                all_passed = False
                print(f"  P({var.name}): Error computing marginal: {e}")
                continue
            # Shape check first so subtraction never broadcasts; `<` makes
            # a NaN difference count as a failure.
            ok = (reduced_exact.shape == orig_exact.shape
                  and float(np.max(np.abs(reduced_exact - orig_exact))) < tol)
            if not ok:
                all_passed = False
            print(f"  P({var.name}): {'✓' if ok else '✗'}")

    if all_passed:
        print("\n✓ TEST PASSED")
    else:
        print("\n✗ TEST FAILED")

    return all_passed


def test_ternary_factor():
    """Test with a ternary factor to ensure higher-arity works.

    A single random factor f123 over x1, x2, x3 (entries in [0.1, 1.1),
    drawn from the global ``np.random`` state) is reduced step by step.
    After each step, the exact marginals of surviving variables are compared
    with the originals. Only failures are printed.

    Returns:
        True if every checked marginal matched within 1e-6 and no marginal
        computation raised; False otherwise.
    """
    tol = 1e-6

    print("\n" + "=" * 70)
    print("TEST: Tree with Ternary Factor")
    print("=" * 70)

    fg = FactorGraph("ternary")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    # One ternary factor
    f123_table = np.random.rand(2, 2, 2) + 0.1
    fg.add_factor(Factor("f123", [x1, x2, x3], f123_table))

    print(f"\nOriginal graph: {fg}")

    exact_marginals = {v.name: compute_exact_marginal(fg, v.name) for v in fg.variables}
    print("\nExact marginals:")
    for name, marg in sorted(exact_marginals.items()):
        print(f"  P({name}) = {np.round(marg, 4)}")

    # Convert and check
    poset = from_factor_graph(fg)
    print(f"\nPoset: {poset}")
    print(f"  Linear vars: {poset.get_linear_variables()}")  # All should be linear

    # Reduce
    all_passed = True
    while True:
        linear = poset.get_linear_variables()
        colinear = poset.get_colinear_factors()

        if not linear and not colinear:
            break

        if linear:
            step = retract_linear(poset, linear[0])
        else:
            step = retract_colinear(poset, colinear[0])

        print(f"  {step}")

        # Check surviving marginals
        reduced_fg = to_factor_graph_if_possible(poset)
        if reduced_fg and reduced_fg.num_variables > 0:
            for var in reduced_fg.variables:
                if var.name not in exact_marginals:
                    continue
                orig = np.asarray(exact_marginals[var.name])
                try:
                    reduced = np.asarray(compute_exact_marginal(reduced_fg, var.name))
                except Exception as e:
                    # Surface the error and fail, instead of skipping the variable.
                    all_passed = False
                    print(f"    ✗ P({var.name}) error computing marginal: {e}")
                    continue

                # Guard against broadcasting hiding a shape mismatch.
                if reduced.shape != orig.shape:
                    diff = float("inf")
                else:
                    diff = float(np.max(np.abs(reduced - orig)))
                if not diff < tol:  # also true for NaN
                    all_passed = False
                    print(f"    ✗ P({var.name}) diff={diff:.6f}")

    if all_passed:
        print("\n✓ TEST PASSED")
    else:
        print("\n✗ TEST FAILED")

    return all_passed


if __name__ == "__main__":
    # Fixed seed so the random potentials used by the longer-chain and
    # ternary tests are reproducible across runs of this script.
    np.random.seed(42)

    print("Running Strict SP-B Reduction Tests on Trees\n")

    # Tests run in list order, matching the original sequential appends.
    results = [
        ("Simple Chain", test_simple_chain()),
        ("Chain with Unary", test_chain_with_unary_factors()),
        ("Star Graph", test_star_graph()),
        ("Longer Chain", test_longer_chain()),
        ("Ternary Factor", test_ternary_factor()),
    ]

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"  {name}: {status}")

    # Non-zero exit status so CI / shell callers can detect failures.
    sys.exit(0 if all(passed for _, passed in results) else 1)
