"""
Belief Propagation Examples and Tests

This file demonstrates belief propagation on various factor graphs
and compares BP results with exact marginals.
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph, build_factor_graph
from inference import BeliefPropagation, run_bp


def test_simple_chain():
    """
    Test BP on a simple chain: x1 -- f12 -- x2, with unary priors f1 and f2.

    This is a tree, so BP should give exact results. The BP marginal of
    every variable is compared against ``fg.compute_marginal_exact``.

    Returns:
        Tuple ``(fg, result)``: the factor graph that was built and the
        object returned by ``BeliefPropagation.run``.

    Raises:
        AssertionError: if the largest absolute difference between a BP
            marginal and the exact marginal is not below 1e-6 (this
            includes the case where the difference is NaN).
    """
    print("=" * 60)
    print("Test: Simple Chain (x1 -- f12 -- x2)")
    print("=" * 60)

    # Build a simple chain
    fg = FactorGraph("SimpleChain")

    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Prior on x1: prefer 0
    fg.add_factor(Factor("f1", [x1], np.array([0.8, 0.2])))

    # Prior on x2: prefer 1
    fg.add_factor(Factor("f2", [x2], np.array([0.3, 0.7])))

    # Pairwise: prefer same values (correlation)
    fg.add_factor(Factor("f12", [x1, x2], np.array([
        [2.0, 0.5],  # x1=0
        [0.5, 2.0]   # x1=1
    ])))

    print(f"\nFactor Graph: {fg}")
    print(f"Is tree: {fg.is_tree()}")

    # Run BP
    bp = BeliefPropagation(fg)
    result = bp.run(max_iterations=50, tolerance=1e-10)

    print(f"\nBP Result: {result}")

    # Compare with exact marginals
    print("\nComparison with exact marginals:")
    print(f"{'Variable':<10} {'BP Marginal':<25} {'Exact Marginal':<25} {'Max Diff':<10}")
    print("-" * 70)

    max_error = 0.0
    for var in fg.variables:
        bp_marginal = result.get_marginal(var.name)
        exact_marginal = fg.compute_marginal_exact(var)
        diff = np.max(np.abs(bp_marginal - exact_marginal))
        # np.maximum propagates NaN. The builtin max() keeps the previous
        # value when `diff` is NaN, which would let a NaN marginal slip
        # through the assertion below.
        max_error = np.maximum(max_error, diff)

        print(f"{var.name:<10} {str(np.round(bp_marginal, 6)):<25} "
              f"{str(np.round(exact_marginal, 6)):<25} {diff:.2e}")

    print(f"\nMaximum error: {max_error:.2e}")
    # A NaN max_error compares False here, so the assertion fails as intended.
    assert max_error < 1e-6, "BP should be exact for trees!"
    print("✓ Test passed: BP matches exact marginals for tree")

    return fg, result


def test_four_variable_tree():
    """
    Test BP on the 4-variable tree from Napp & Adams Figure 2a.

    The graph is the chain x1 -- x2 -- x3 -- x4 (x3 has three states, the
    others two), with one unary factor per variable and one pairwise factor
    per edge. BP marginals are compared against
    ``fg.compute_marginal_exact``.

    Returns:
        Tuple ``(fg, result)``: the factor graph that was built and the
        object returned by ``run_bp``.

    Raises:
        AssertionError: if the largest absolute difference between a BP
            marginal and the exact marginal is not below 1e-5 (this
            includes the case where the difference is NaN).
    """
    print("\n" + "=" * 60)
    print("Test: Four-Variable Tree (Napp & Adams Fig 2a)")
    print("=" * 60)

    # Factor tables from Figure 2c
    # Unary tables: length matches the number of states of the variable.
    psi1 = np.array([1.0, 0.1])
    psi2 = np.array([1.0, 0.1])
    psi3 = np.array([2.0, 1.0, 1.0])
    psi7 = np.array([1.0, 1.0])

    # Pairwise table over (x1, x2): 2 x 2.
    psi4 = np.array([
        [1.0, 0.1],
        [0.1, 3.0]
    ])

    # Pairwise table over (x2, x3): 2 x 3.
    psi5 = np.array([
        [0.1, 2.0, 0.1],
        [3.0, 0.1, 1.0]
    ])

    # Pairwise table over (x3, x4): 3 x 2.
    psi6 = np.array([
        [0.1, 0.1],
        [1.0, 0.1],
        [0.5, 0.5]
    ])

    fg = build_factor_graph(
        variables_spec={
            "x1": [0, 1],
            "x2": [0, 1],
            "x3": [0, 1, 2],
            "x4": [0, 1],
        },
        factors_spec={
            "psi1": (["x1"], psi1),
            "psi2": (["x2"], psi2),
            "psi3": (["x3"], psi3),
            "psi7": (["x4"], psi7),
            "psi4": (["x1", "x2"], psi4),
            "psi5": (["x2", "x3"], psi5),
            "psi6": (["x3", "x4"], psi6),
        },
        name="FourVarTree"
    )

    print(f"\nFactor Graph: {fg}")
    print(f"Is tree: {fg.is_tree()}")

    # Run BP
    result = run_bp(fg, max_iterations=100, tolerance=1e-10)

    print(f"\nBP Result: {result}")

    # Compare with exact
    print("\nComparison with exact marginals:")
    print(f"{'Variable':<10} {'BP Marginal':<30} {'Exact Marginal':<30}")
    print("-" * 70)

    max_error = 0.0
    for var in fg.variables:
        bp_marginal = result.get_marginal(var.name)
        exact_marginal = fg.compute_marginal_exact(var)
        diff = np.max(np.abs(bp_marginal - exact_marginal))
        # np.maximum propagates NaN. The builtin max() keeps the previous
        # value when `diff` is NaN, which would let a NaN marginal slip
        # through the assertion below.
        max_error = np.maximum(max_error, diff)

        print(f"{var.name:<10} {str(np.round(bp_marginal, 4)):<30} "
              f"{str(np.round(exact_marginal, 4)):<30}")

    print(f"\nMaximum error: {max_error:.2e}")
    # A NaN max_error compares False here, so the assertion fails as intended.
    assert max_error < 1e-5, "BP should be exact for trees!"
    print("✓ Test passed: BP matches exact marginals for 4-variable tree")

    return fg, result


def test_loopy_triangle():
    """
    Demonstrate BP on a cyclic graph: three binary variables in a triangle.

    Every variable gets a uniform unary factor and every pair is joined by
    the same anti-correlating table, so the graph contains a loop and BP is
    only approximate. BP is run once without damping and once (after
    ``bp.reset()``) with damping 0.5; the damped marginals are printed next
    to the exact ones. Nothing is asserted.

    Returns:
        Tuple ``(fg, result_damped)``: the factor graph and the result of
        the damped run.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: Loopy Triangle Graph")
    print(banner)

    fg = FactorGraph("Triangle")

    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    # Uniform priors, one per variable (a fresh array for each factor).
    for prior_name, node in (("f1", x1), ("f2", x2), ("f3", x3)):
        fg.add_factor(Factor(prior_name, [node], np.array([1.0, 1.0])))

    # Anti-correlating pairwise table (prefers different values).
    anti_corr = np.array([
        [0.1, 1.0],
        [1.0, 0.1]
    ])

    # One pairwise factor per triangle edge; each gets its own copy.
    edges = (
        ("f12", [x1, x2]),
        ("f23", [x2, x3]),
        ("f13", [x1, x3]),
    )
    for edge_name, scope in edges:
        fg.add_factor(Factor(edge_name, scope, anti_corr.copy()))

    print(f"\nFactor Graph: {fg}")
    print(f"Is tree: {fg.is_tree()}")  # Should be False

    bp = BeliefPropagation(fg)

    # First pass: no damping - may oscillate.
    print("\n--- Without damping ---")
    result_no_damp = bp.run(max_iterations=100, tolerance=1e-6, damping=0.0)
    print(f"BP Result: {result_no_damp}")

    # Second pass on the same solver after a reset: damped - should converge.
    print("\n--- With damping=0.5 ---")
    bp.reset()
    result_damped = bp.run(max_iterations=100, tolerance=1e-6, damping=0.5)
    print(f"BP Result: {result_damped}")

    # Side-by-side table of damped BP vs exact marginals.
    print("\nComparison (using damped result):")
    print(f"{'Variable':<10} {'BP Marginal':<25} {'Exact Marginal':<25} {'Diff':<10}")
    print("-" * 70)

    for var in fg.variables:
        approx = result_damped.get_marginal(var.name)
        exact = fg.compute_marginal_exact(var)
        gap = np.max(np.abs(approx - exact))

        approx_text = str(np.round(approx, 4))
        exact_text = str(np.round(exact, 4))
        print(f"{var.name:<10} {approx_text:<25} {exact_text:<25} {gap:.4f}")

    print("\nNote: For loopy graphs, BP gives approximate results.")
    print("The exact marginals for this symmetric frustrated system are uniform [0.5, 0.5]")

    return fg, result_damped


def test_single_variable():
    """
    Simplest case: one three-state variable with a single unary factor.

    Runs ``run_bp`` with its default settings, prints the BP and exact
    marginals, and asserts that they agree to within 1e-10.

    Returns:
        Tuple ``(fg, result)``: the factor graph and the object returned
        by ``run_bp``.

    Raises:
        AssertionError: if the largest absolute difference is not below
            1e-10.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: Single Variable")
    print(banner)

    fg = FactorGraph("SingleVar")

    x = fg.add_variable(Variable("x", [0, 1, 2]))
    unary_table = np.array([1.0, 2.0, 3.0])
    fg.add_factor(Factor("f", [x], unary_table))

    result = run_bp(fg)

    belief = result.get_marginal("x")
    reference = fg.compute_marginal_exact(x)

    print(f"BP marginal:    {np.round(belief, 4)}")
    print(f"Exact marginal: {np.round(reference, 4)}")

    # Largest element-wise deviation between the two distributions.
    gap = np.max(np.abs(belief - reference))
    print(f"Max difference: {gap:.2e}")

    assert gap < 1e-10, "Single variable should be exact!"
    print("✓ Test passed")

    return fg, result


def test_message_inspection():
    """
    Demonstrate message inspection after a BP run.

    Builds a two-variable chain (unary factors f1, f2 and pairwise factor
    f12), runs BP with default settings, then prints every message returned
    by ``bp.get_all_sum_messages()`` and ``bp.get_all_product_messages()``
    followed by the marginals stored on the result. Nothing is asserted.

    Returns:
        Tuple ``(fg, result)``: the factor graph and the object returned
        by ``BeliefPropagation.run``.
    """
    print("\n" + "=" * 60)
    print("Test: Message Inspection")
    print("=" * 60)

    fg = FactorGraph("MsgInspect")

    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    fg.add_factor(Factor("f1", [x1], np.array([0.7, 0.3])))
    fg.add_factor(Factor("f12", [x1, x2], np.array([
        [1.0, 0.2],
        [0.2, 1.0]
    ])))
    fg.add_factor(Factor("f2", [x2], np.array([0.4, 0.6])))

    bp = BeliefPropagation(fg)
    result = bp.run()

    print("\nAll messages after convergence:")
    print("\nSum messages (Factor → Variable):")
    # Only the message objects are printed, so iterate the values directly
    # instead of unpacking an unused key from .items().
    for msg in bp.get_all_sum_messages().values():
        print(f"  {msg}")

    print("\nProduct messages (Variable → Factor):")
    for msg in bp.get_all_product_messages().values():
        print(f"  {msg}")

    print("\nMarginals:")
    for var_name, marginal in result.marginals.items():
        print(f"  P({var_name}) = {np.round(marginal, 4)}")

    return fg, result


def test_napp_adams_example():
    """
    Two-variable example in the layout of Napp & Adams Figure 1.

    The graph has binary variables x1 and x2 and three factors:
    - ψ1(x1): unary on x1
    - ψ2(x2): unary on x2
    - ψ3(x1, x2): pairwise

    The numeric tables are the ones the inline comments attribute to
    Figure 2c. BP and exact marginals are printed for each variable;
    nothing is asserted.

    Returns:
        Tuple ``(fg, result)``: the factor graph and the object returned
        by ``run_bp``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Test: Napp & Adams Figure 1 Example")
    print(rule)

    fg = FactorGraph("NappAdamsFig1")

    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Using factors from Figure 2c for demonstration.
    # Both unary tables are [1, 0.1], i.e. state 0 is preferred.
    unary_x1 = np.array([1.0, 0.1])
    fg.add_factor(Factor("psi1", [x1], unary_x1))

    unary_x2 = np.array([1.0, 0.1])
    fg.add_factor(Factor("psi2", [x2], unary_x2))

    # ψ3(x1, x2): correlation factor
    pairwise = np.array([
        [1.0, 0.1],   # x1=0
        [0.1, 3.0]    # x1=1
    ])
    fg.add_factor(Factor("psi3", [x1, x2], pairwise))

    print(f"\nFactor Graph: {fg}")

    # Run BP with the default iteration limit and a tight tolerance.
    result = run_bp(fg, tolerance=1e-10)

    print(f"\nBP converged in {result.iterations} iterations")

    # Report BP next to the exact marginal for every variable.
    print("\nResults:")
    for var in fg.variables:
        approx = np.round(result.get_marginal(var.name), 4)
        exact = np.round(fg.compute_marginal_exact(var), 4)
        print(f"  P({var.name}): BP={approx}, Exact={exact}")

    return fg, result


if __name__ == "__main__":
    # Script entry point: run every demo in a fixed order, then print a footer.
    print("Running Belief Propagation Tests\n")

    demos = (
        test_single_variable,
        test_simple_chain,
        test_four_variable_tree,
        test_loopy_triangle,
        test_message_inspection,
        test_napp_adams_example,
    )
    for demo in demos:
        # Return values (fg, result) are not needed here.
        demo()

    footer_rule = "=" * 60
    print("\n" + footer_rule)
    print("All tests completed!")
    print(footer_rule)
