"""
Task 6: Reduction Equivalence Verification

Demonstrates that:
1. Original factor graph → CRN → marginals
2. Reduced factor graph (via SP-B) → CRN → marginals  
3. The marginals match (on surviving variables)

This validates the full pipeline:
- SP-B reductions preserve BP fixed points
- CRN compilation preserves BP marginals
- Therefore: reduced CRN should give same marginals as original
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph
from inference import run_bp
from reduction.poset_reduction import (
    from_factor_graph, 
    to_factor_graph_if_possible,
    reduce_to_core_spb,
    PosetModel
)
from crn import compile_factor_graph_to_crn, simulate_crn


def test_chain_reduction_equivalence():
    """
    Compare BP and CRN marginals on a three-variable chain and run SP-B on it.

    Graph: x1 -- f12 -- x2 -- f23 -- x3, with unary factors f1 (on x1) and
    f3 (on x3).

    The verdict is based only on the original graph: every CRN marginal must
    be within 0.01 (max absolute difference) of the corresponding BP marginal.
    The SP-B reduction is applied afterwards and its steps, the resulting core
    poset and the reconstructed factor graph are printed, but they do not
    influence the returned value.

    Returns:
        True if all CRN marginals agree with BP within tolerance, else False.
    """
    tolerance = 0.01  # max allowed |BP - CRN| over the entries of a marginal

    print("=" * 70)
    print("TEST: Chain Graph Reduction Equivalence")
    print("=" * 70)

    # Build original factor graph
    fg = FactorGraph("chain")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    # Unary factors on the chain ends, pairwise factors along the chain
    fg.add_factor(Factor("f1", [x1], np.array([0.7, 0.3])))
    fg.add_factor(Factor("f3", [x3], np.array([0.4, 0.6])))
    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.8, 0.2], [0.3, 0.7]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.6, 0.4], [0.5, 0.5]])))

    print(f"\nOriginal factor graph: {fg}")

    # 1. Original BP marginals
    bp_result = run_bp(fg, tolerance=1e-10)
    orig_marginals = {v.name: bp_result.get_marginal(v.name) for v in fg.variables}
    names = sorted(orig_marginals)

    print(f"\nOriginal BP marginals:")
    for name in names:
        print(f"  P({name}) = {np.round(orig_marginals[name], 6)}")

    # 2. Original CRN marginals
    print(f"\nCompiling original to CRN...")
    orig_crn = compile_factor_graph_to_crn(fg, kappa_r=0.05, kappa_prod=50.0)
    print(f"  {orig_crn}")

    print(f"Simulating original CRN...")
    orig_crn_result = simulate_crn(orig_crn, t_end=5000, n_points=500)
    # Fetch each CRN marginal once; it is both printed and compared below.
    crn_marginals = {name: orig_crn_result.get_marginal(name) for name in names}

    print(f"\nOriginal CRN marginals:")
    for name in names:
        print(f"  P({name}) = {np.round(crn_marginals[name], 6)}")

    # 3. Reduce the factor graph
    print(f"\n--- Applying SP-B Reduction ---")
    poset = from_factor_graph(fg)
    steps = reduce_to_core_spb(poset)

    print(f"Reduction steps ({len(steps)}):")
    for step in steps:
        print(f"  {step}")

    print(f"\nCore poset: {poset}")

    # 4. Convert core back to factor graph (if possible)
    # NOTE(review): the reduced graph is only printed; its marginals are never
    # compared with the original ones, so this test does not yet check
    # reduction equivalence itself.
    reduced_fg = to_factor_graph_if_possible(poset)
    print(f"Reduced factor graph: {reduced_fg}")

    # 5. Compare marginals
    print(f"\n--- Verification ---")
    all_match = True

    # Check BP vs CRN on original
    print(f"\nOriginal: BP vs CRN")
    for name in names:
        diff = np.max(np.abs(orig_marginals[name] - crn_marginals[name]))
        # A single predicate drives both the glyph and the verdict. A NaN diff
        # (e.g. from a diverged simulation) compares False and is therefore
        # counted as a mismatch instead of slipping through.
        within_tol = bool(diff < tolerance)
        if not within_tol:
            all_match = False
        status = "✓" if within_tol else "✗"
        print(f"  P({name}): diff={diff:.6f} {status}")

    if all_match:
        print(f"\n✓ TEST PASSED: CRN matches BP on original graph")
    else:
        print(f"\n✗ TEST FAILED: CRN differs from BP")

    return all_match


def test_tree_with_reduction():
    """
    Check CRN marginals against exact marginals after each single reduction step.

    Builds the tree x1 -- f12 -- x2 -- f23 -- x3, computes exact marginals on
    the original graph, then repeatedly retracts one linear variable (or, if
    there is none, one colinear factor) until neither kind remains. After each
    step the current poset is converted back to a factor graph; when that
    yields a graph with at least one variable it is compiled to a CRN,
    simulated, and every surviving variable's marginal is compared with the
    exact marginal from the original graph (max absolute difference < 0.05).

    Steps whose graph cannot be compiled or simulated are reported and
    skipped. The test fails if any comparison is out of tolerance, or if no
    comparison could be performed at all.

    Returns:
        True if at least one marginal was verified and all verified marginals
        are within tolerance, else False.
    """
    from reduction.poset_reduction import retract_linear, retract_colinear

    tolerance = 0.05  # max allowed |CRN - exact| over the entries of a marginal

    print("\n" + "=" * 70)
    print("TEST: Tree with Intermediate Reduction Steps")
    print("=" * 70)

    # Build a small tree
    fg = FactorGraph("tree")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.7, 0.3], [0.4, 0.6]])))

    print(f"\nOriginal: {fg}")

    # Original exact marginals: the reference for every later comparison
    exact_marginals = {v.name: fg.compute_marginal_exact(v) for v in fg.variables}
    names = sorted(exact_marginals)

    print(f"\nExact marginals:")
    for name in names:
        print(f"  P({name}) = {np.round(exact_marginals[name], 6)}")

    # Original CRN
    orig_crn = compile_factor_graph_to_crn(fg, kappa_r=0.05, kappa_prod=50.0)
    orig_result = simulate_crn(orig_crn, t_end=3000, n_points=300)

    print(f"\nOriginal CRN marginals:")
    for name in names:
        print(f"  P({name}) = {np.round(orig_result.get_marginal(name), 6)}")

    # Step-by-step reduction with CRN verification
    print(f"\n--- Step-by-step reduction ---")
    poset = from_factor_graph(fg)

    step_num = 0
    all_match = True
    checked = 0  # marginal comparisons actually performed
    skipped = 0  # steps whose graph could not be compiled/simulated

    while True:
        linear_vars = poset.get_linear_variables()
        colinear_facs = poset.get_colinear_factors()

        if not linear_vars and not colinear_facs:
            break

        step_num += 1

        # Linear variables take priority over colinear factors.
        if linear_vars:
            step = retract_linear(poset, linear_vars[0])
        else:
            step = retract_colinear(poset, colinear_facs[0])

        print(f"\nStep {step_num}: {step}")

        # Convert current poset to factor graph
        current_fg = to_factor_graph_if_possible(poset)
        if not (current_fg and current_fg.num_variables > 0):
            continue

        # Best effort: a step that cannot be compiled or simulated is reported
        # and counted, not treated as a crash.
        try:
            current_crn = compile_factor_graph_to_crn(current_fg, kappa_r=0.05, kappa_prod=50.0)
            current_result = simulate_crn(current_crn, t_end=2000, n_points=200)

            # Check marginals for surviving variables
            for v in current_fg.variables:
                if v.name not in exact_marginals:
                    continue
                crn_marg = current_result.get_marginal(v.name)
                exact_marg = exact_marginals[v.name]
                diff = np.max(np.abs(crn_marg - exact_marg))
                # One predicate for glyph and verdict; a NaN diff compares
                # False and therefore counts as a mismatch.
                within_tol = bool(diff < tolerance)
                if not within_tol:
                    all_match = False
                checked += 1
                status = "✓" if within_tol else "✗"
                print(f"  P({v.name}): exact={np.round(exact_marg, 4)}, CRN={np.round(crn_marg, 4)}, diff={diff:.4f} {status}")
        except Exception as e:
            skipped += 1
            print(f"  Could not compile/simulate: {e}")

    print(f"\nVerified {checked} marginal(s); {skipped} step(s) could not be compiled/simulated")

    if checked == 0:
        # Nothing was compared, so claiming a match would be a vacuous pass.
        all_match = False
        print(f"\n✗ TEST FAILED: no marginal could be verified at any reduction step")
    elif all_match:
        print(f"\n✓ TEST PASSED: CRN marginals match exact at all reduction steps")
    else:
        print(f"\n✗ TEST FAILED")

    return all_match


def test_loopy_graph():
    """
    Compare loopy-BP and CRN marginals on a triangle graph.

    The triangle x1 -- x2 -- x3 -- x1 (pairwise factors f12, f23, f13) is
    converted to a poset and its linear variables / colinear factors are
    printed, together with whether both lists are empty ("Is core"). That
    information is printed only and is not part of the verdict.

    The verdict compares damped loopy BP with the simulated CRN: every
    marginal must agree within 0.05 (max absolute difference).

    Returns:
        True if all CRN marginals agree with loopy BP within tolerance,
        else False.
    """
    tolerance = 0.05  # max allowed |BP - CRN| over the entries of a marginal

    print("\n" + "=" * 70)
    print("TEST: Loopy Graph (Triangle - is its own core)")
    print("=" * 70)

    # Triangle graph
    fg = FactorGraph("triangle")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.8, 0.2], [0.3, 0.7]])))
    fg.add_factor(Factor("f13", [x1, x3], np.array([[0.7, 0.3], [0.4, 0.6]])))

    print(f"\nTriangle graph: {fg}")

    # Report whether anything is reducible (nothing reducible => it is a core)
    poset = from_factor_graph(fg)
    linear = poset.get_linear_variables()
    colinear = poset.get_colinear_factors()

    print(f"\nLinear variables: {linear}")
    print(f"Colinear factors: {colinear}")
    print(f"Is core: {len(linear) == 0 and len(colinear) == 0}")

    # BP marginals (loopy BP, damped, with an iteration cap)
    bp_result = run_bp(fg, tolerance=1e-8, max_iterations=500, damping=0.3)

    print(f"\nBP marginals (loopy):")
    for v in fg.variables:
        print(f"  P({v.name}) = {np.round(bp_result.get_marginal(v.name), 6)}")

    # CRN marginals
    crn = compile_factor_graph_to_crn(fg, kappa_r=0.05, kappa_prod=50.0)
    print(f"\n{crn}")

    print(f"\nSimulating CRN...")
    crn_result = simulate_crn(crn, t_end=5000, n_points=500)

    print(f"\nCRN marginals:")
    all_match = True
    for v in fg.variables:
        bp_marg = bp_result.get_marginal(v.name)
        crn_marg = crn_result.get_marginal(v.name)
        diff = np.max(np.abs(bp_marg - crn_marg))
        # One predicate for glyph and verdict; a NaN diff (e.g. from a
        # diverged run) compares False and therefore counts as a mismatch.
        within_tol = bool(diff < tolerance)
        if not within_tol:
            all_match = False
        status = "✓" if within_tol else "✗"
        print(f"  P({v.name}): BP={np.round(bp_marg, 4)}, CRN={np.round(crn_marg, 4)}, diff={diff:.4f} {status}")

    if all_match:
        print(f"\n✓ TEST PASSED: CRN matches loopy BP on triangle")
    else:
        print(f"\n✗ TEST FAILED")

    return all_match


def test_reduction_then_crn_vs_direct_crn():
    """
    Compare CRN marginals of a chain with those of its partially reduced form.

    Two paths are run on the chain x1 -- x2 -- x3 -- x4 (unary factors f1 on
    x1 and f4 on x4):

    1. Original graph -> CRN -> marginals
    2. Original graph -> retract "var:x1" and "var:x4" where the poset reports
       them as linear -> factor graph -> CRN -> marginals

    The marginals of every variable present in the reduced graph must agree
    between the two paths within 0.05 (max absolute difference).

    Returns:
        True if all surviving-variable marginals agree within tolerance.
        False on any mismatch, or if the reduced poset yields no factor graph
        with at least one variable.
    """
    from reduction.poset_reduction import retract_linear

    tolerance = 0.05  # max allowed |direct - reduced| over a marginal's entries

    print("\n" + "=" * 70)
    print("TEST: Direct CRN vs Reduced-then-CRN")
    print("=" * 70)

    # Build a tree that will reduce
    fg = FactorGraph("reducible_tree")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))
    x4 = fg.add_variable(Variable("x4", [0, 1]))

    # Chain: x1 -- x2 -- x3 -- x4
    fg.add_factor(Factor("f1", [x1], np.array([0.6, 0.4])))
    fg.add_factor(Factor("f4", [x4], np.array([0.3, 0.7])))
    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.8, 0.2], [0.25, 0.75]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.7, 0.3], [0.35, 0.65]])))
    fg.add_factor(Factor("f34", [x3, x4], np.array([[0.9, 0.1], [0.15, 0.85]])))

    print(f"\nOriginal: {fg}")

    # Path 1: Direct CRN
    print(f"\n--- Path 1: Direct CRN ---")
    direct_crn = compile_factor_graph_to_crn(fg, kappa_r=0.05, kappa_prod=50.0)
    print(f"Direct CRN: {direct_crn}")

    direct_result = simulate_crn(direct_crn, t_end=5000, n_points=500)

    print(f"Direct CRN marginals:")
    for v in fg.variables:
        print(f"  P({v.name}) = {np.round(direct_result.get_marginal(v.name), 6)}")

    # Path 2: Reduce then CRN
    print(f"\n--- Path 2: Reduce then CRN ---")
    poset = from_factor_graph(fg)

    # Partial reduction: try to retract the two end variables of the chain.
    for leaf in ("var:x1", "var:x4"):
        if poset.is_linear(leaf):
            step = retract_linear(poset, leaf)
            print(f"  {step}")
        else:
            # Make the skip visible; otherwise the comparison below could
            # silently run on a graph that was not reduced at this end.
            print(f"  {leaf} is not linear; not retracted")

    reduced_fg = to_factor_graph_if_possible(poset)
    print(f"\nPartially reduced: {reduced_fg}")

    if not (reduced_fg and reduced_fg.num_variables > 0):
        # Without a reduced graph nothing can be compared; say so explicitly.
        print(f"\n✗ TEST FAILED: no reduced factor graph available to compile")
        return False

    reduced_crn = compile_factor_graph_to_crn(reduced_fg, kappa_r=0.05, kappa_prod=50.0)
    print(f"Reduced CRN: {reduced_crn}")

    reduced_result = simulate_crn(reduced_crn, t_end=5000, n_points=500)

    print(f"\nReduced CRN marginals:")
    for v in reduced_fg.variables:
        print(f"  P({v.name}) = {np.round(reduced_result.get_marginal(v.name), 6)}")

    # Compare on surviving variables
    print(f"\n--- Comparison on surviving variables ---")
    all_match = True
    for v in reduced_fg.variables:
        direct_marg = direct_result.get_marginal(v.name)
        reduced_marg = reduced_result.get_marginal(v.name)
        diff = np.max(np.abs(direct_marg - reduced_marg))
        # One predicate for glyph and verdict; a NaN diff compares False and
        # therefore counts as a mismatch.
        within_tol = bool(diff < tolerance)
        if not within_tol:
            all_match = False
        status = "✓" if within_tol else "✗"
        print(f"  P({v.name}): direct={np.round(direct_marg, 4)}, reduced={np.round(reduced_marg, 4)}, diff={diff:.4f} {status}")

    if all_match:
        print(f"\n✓ TEST PASSED: Reduced CRN matches direct CRN on surviving variables")
    else:
        print(f"\n✗ TEST FAILED")

    return all_match


if __name__ == "__main__":
    # Script entry point: run the four checks in a fixed order, print a
    # summary table, and report the overall outcome through the exit status.

    # Fix NumPy's global seed (whether the simulations draw from it is not
    # visible in this file).
    np.random.seed(42)

    print("Task 6: Reduction Equivalence Verification\n")
    print("Verifying that SP-B reduction + CRN compilation preserves marginals\n")

    tests = [
        ("Chain Reduction", test_chain_reduction_equivalence),
        ("Tree Intermediate Steps", test_tree_with_reduction),
        ("Loopy Graph (Triangle)", test_loopy_graph),
        ("Direct vs Reduced CRN", test_reduction_then_crn_vs_direct_crn),
    ]

    # Every test runs regardless of earlier results so the summary is complete.
    results = [(name, test_fn()) for name, test_fn in tests]

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"  {name}: {status}")

    all_passed = all(passed for _, passed in results)
    print(f"\nOverall: {'✓ ALL TESTS PASSED' if all_passed else '✗ SOME TESTS FAILED'}")

    # Exit non-zero on failure so shells and CI can detect it; previously the
    # script exited 0 even when tests failed.
    if not all_passed:
        sys.exit(1)
