"""
Unit Tests for Strict SP-B Reduction

These tests specifically verify that:
1. Linear deletion is DELETE-ONLY (no table updates) - NOT elimination
2. Colinear deletion properly updates tables via Eq 4.29/4.30
3. Unary factors are NOT absorbed into variable regions
4. Reconstruction helper works correctly

The key test distinguishes strict SP-B from elimination by checking that
the cover factor's table and scope are UNCHANGED after linear deletion.
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph
from reduction.poset_reduction import (
    from_factor_graph,
    retract_linear,
    retract_colinear,
    reduce_to_core_spb,
    reconstruct_deleted_belief,
    validate_linear_deletion_is_delete_only,
    PosetModel
)


def test_linear_deletion_is_delete_only_not_elimination():
    """
    KEY TEST: Verify linear deletion does NOT modify survivor tables.

    This test FAILS if linear deletion is implemented as elimination.
    This test PASSES if linear deletion is strict SP-B (delete-only).

    Setup:
    - Variables: x1, x2 (binary)
    - Factor f12(x1, x2) with nontrivial table
    - Unary factor u1(x1) with strong bias [100, 1]

    What strict SP-B should do:
    1. u1 is colinear → apply Eq 4.29, multiply u1 into f12 along x1 axis
    2. x1 becomes linear → DELETE-ONLY, f12 unchanged

    What elimination would do:
    2. x1 is linear → marginalize x1 out of f12, changing f12's scope and table

    We detect elimination by checking if f12's scope/table changed after step 2.

    Returns:
        True when every check passes. Any failing check raises AssertionError
        so that pytest (which ignores return values) reports the failure; the
        boolean return is kept for the ``__main__`` summary runner.
    """
    print("=" * 70)
    print("TEST: Linear Deletion is Delete-Only (Not Elimination)")
    print("=" * 70)

    # Build factor graph
    fg = FactorGraph("test_linear")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Binary factor
    f12_table = np.array([[2.0, 1.0],
                          [1.0, 2.0]])
    fg.add_factor(Factor("f12", [x1, x2], f12_table))

    # Strong unary bias on x1
    u1_table = np.array([100.0, 1.0])
    fg.add_factor(Factor("u1", [x1], u1_table))

    print(f"\nFactor graph: {fg}")
    print(f"  f12 table:\n{f12_table}")
    print(f"  u1 table: {u1_table}")

    # Convert to poset
    poset = from_factor_graph(fg)

    print("\nInitial poset:")
    print(f"  Variables: {poset.variables}")
    print(f"  Factors: {poset.factors}")

    # Verify unary factor is a factor region (NOT absorbed)
    assert "fac:u1" in poset.factors, "FAIL: Unary factor u1 was absorbed into variable region!"
    print("  ✓ Unary factor u1 is a factor region (not absorbed)")

    # Variable regions must carry the trivial potential ψ ≡ 1 (all ones).
    x1_region = poset.regions["var:x1"]
    x2_region = poset.regions["var:x2"]
    assert np.allclose(x1_region.table, np.ones(2)), "FAIL: x1 table is not uniform!"
    assert np.allclose(x2_region.table, np.ones(2)), "FAIL: x2 table is not uniform!"
    print("  ✓ Variable regions have uniform tables (ψ ≡ 1)")

    # Step 1: u1 is colinear (unary factor) → apply colinear deletion
    assert poset.is_colinear("fac:u1"), "u1 should be colinear"
    print("\n--- Step 1: Colinear deletion of u1 ---")

    f12_before_colinear = poset.regions["fac:f12"].table.copy()
    step1 = retract_colinear(poset, "fac:u1")
    f12_after_colinear = poset.regions["fac:f12"].table.copy()

    print(f"  {step1}")
    print(f"  f12 table before: \n{f12_before_colinear}")
    print(f"  f12 table after:  \n{f12_after_colinear}")

    # f12 should be updated by multiplying u1 along x1 axis (Eq 4.29);
    # f12 was built with scope [x1, x2], so x1 is axis 0.
    expected_f12 = f12_table * u1_table.reshape(-1, 1)
    print(f"  Expected f12:     \n{expected_f12}")

    assert np.allclose(f12_after_colinear, expected_f12), \
        "FAIL: Colinear deletion did not correctly update f12!"
    print("  ✓ Colinear deletion correctly updated f12 via Eq 4.29")

    # Step 2: x1 is now linear → apply linear deletion (DELETE-ONLY!)
    assert poset.is_linear("var:x1"), "x1 should be linear after u1 removal"
    print("\n--- Step 2: Linear deletion of x1 ---")

    # Snapshot f12 before the deletion. The scope is copied into a fresh list:
    # keeping a reference to the live container would let an in-place mutation
    # by retract_linear change "before" too, so the scope check could never fail.
    f12_before_linear = poset.regions["fac:f12"].table.copy()
    f12_scope_before = list(poset.regions["fac:f12"].scope)

    step2 = retract_linear(poset, "var:x1")

    f12_after_linear = poset.regions["fac:f12"].table.copy()
    f12_scope_after = list(poset.regions["fac:f12"].scope)

    print(f"  {step2}")
    print(f"  f12 scope before: {f12_scope_before}")
    print(f"  f12 scope after:  {f12_scope_after}")
    print(f"  f12 table before:\n{f12_before_linear}")
    print(f"  f12 table after:\n{f12_after_linear}")

    # === KEY ASSERTIONS ===
    # These FAIL if elimination was performed, PASS if delete-only

    # 1. Scope should be UNCHANGED
    scope_unchanged = (f12_scope_after == f12_scope_before)
    print(f"\n  Scope unchanged: {scope_unchanged}")

    # 2. Table should be UNCHANGED
    table_unchanged = np.allclose(f12_after_linear, f12_before_linear)
    print(f"  Table unchanged: {table_unchanged}")

    # 3. Use validation helper
    is_delete_only = validate_linear_deletion_is_delete_only(step2, poset)
    print(f"  Validation helper: {is_delete_only}")

    # 4. x1 should no longer exist as a variable region
    x1_deleted = "var:x1" not in poset.variables
    print(f"  x1 region deleted: {x1_deleted}")

    # 5. But x1 should still be in f12's scope (this is the "intermediate not a FG" phenomenon)
    x1_in_scope = "x1" in f12_scope_after
    print(f"  x1 still in f12 scope: {x1_in_scope}")

    all_passed = scope_unchanged and table_unchanged and is_delete_only and x1_deleted and x1_in_scope

    if all_passed:
        print("\n✓ TEST PASSED: Linear deletion is strict SP-B (delete-only)")
        print("  - Scope unchanged")
        print("  - Table unchanged")
        print("  - Variable region deleted")
        print("  - Variable still referenced in factor scope (intermediate object)")
    else:
        print("\n✗ TEST FAILED: Linear deletion appears to be elimination!")
        if not scope_unchanged:
            print("  - Scope was modified (elimination shrinks scope)")
        if not table_unchanged:
            print("  - Table was modified (elimination marginalizes)")
        if not is_delete_only:
            print("  - Validation helper rejected the deletion step")
        if not x1_deleted:
            print("  - Variable region var:x1 was not deleted")
        if not x1_in_scope:
            print("  - x1 was dropped from f12's scope")

    # pytest ignores return values, so the failure must be raised to be seen.
    assert all_passed, "Linear deletion is not strict SP-B delete-only (see output above)"
    return all_passed


def test_unary_factors_not_absorbed():
    """
    Test that unary factors remain as factor regions, not absorbed into variables.

    Builds x1, x2 with unary factors u1, u2 and a pairwise factor f12. After
    conversion to a poset it checks two things:
    - ``fac:u1`` and ``fac:u2`` are factor regions.
    - The variable regions ``var:x1`` and ``var:x2`` keep all-ones tables.

    Returns:
        True when all checks pass. Raises AssertionError otherwise, so that
        pytest (which ignores return values) reports the failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Unary Factors Not Absorbed")
    print("=" * 70)

    fg = FactorGraph("test_unary")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Unary factors
    fg.add_factor(Factor("u1", [x1], np.array([0.8, 0.2])))
    fg.add_factor(Factor("u2", [x2], np.array([0.3, 0.7])))

    # Binary factor
    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))

    poset = from_factor_graph(fg)

    print("\nPoset structure:")
    print(f"  Variables: {poset.variables}")
    print(f"  Factors: {poset.factors}")

    # Check unary factors exist as factor regions
    has_u1 = "fac:u1" in poset.factors
    has_u2 = "fac:u2" in poset.factors

    # Check variable tables are uniform (ψ ≡ 1): absorbing a unary factor
    # would make these non-uniform.
    x1_uniform = np.allclose(poset.regions["var:x1"].table, np.ones(2))
    x2_uniform = np.allclose(poset.regions["var:x2"].table, np.ones(2))

    all_passed = has_u1 and has_u2 and x1_uniform and x2_uniform

    print(f"\n  u1 is factor region: {has_u1}")
    print(f"  u2 is factor region: {has_u2}")
    print(f"  x1 table uniform: {x1_uniform}")
    print(f"  x2 table uniform: {x2_uniform}")

    if all_passed:
        print("\n✓ TEST PASSED: Unary factors are factor regions")
    else:
        print("\n✗ TEST FAILED: Unary factors were absorbed")

    # pytest ignores return values, so the failure must be raised to be seen.
    assert all_passed, (
        f"Unary factors absorbed: has_u1={has_u1}, has_u2={has_u2}, "
        f"x1_uniform={x1_uniform}, x2_uniform={x2_uniform}"
    )
    return all_passed


def test_colinear_deletion_updates_tables():
    """
    Test that colinear deletion properly updates survivor tables.

    Deletes the colinear unary factor u1 and checks two things:
    - f12's table equals the original f12 multiplied by u1 along the x1 axis
      (Eq 4.29).
    - The ``fac:u1`` region is no longer among the poset's factors.

    Returns:
        True when all checks pass. Raises AssertionError otherwise, so that
        pytest (which ignores return values) reports the failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Colinear Deletion Updates Tables")
    print("=" * 70)

    fg = FactorGraph("test_colinear")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Binary factor
    f12_table = np.array([[0.9, 0.1], [0.2, 0.8]])
    fg.add_factor(Factor("f12", [x1, x2], f12_table))

    # Unary factor on x1
    u1_table = np.array([2.0, 0.5])
    fg.add_factor(Factor("u1", [x1], u1_table))

    poset = from_factor_graph(fg)

    print("\nBefore colinear deletion:")
    print(f"  f12 table:\n{poset.regions['fac:f12'].table}")

    # Apply colinear deletion
    step = retract_colinear(poset, "fac:u1")

    print(f"\n{step}")
    print("\nAfter colinear deletion:")
    print(f"  f12 table:\n{poset.regions['fac:f12'].table}")

    # Expected: f12 multiplied by u1 along x1 axis (Eq 4.29); f12 was built
    # with scope [x1, x2], so x1 is axis 0.
    expected = f12_table * u1_table.reshape(-1, 1)
    print(f"\nExpected (f12 * u1 along x1):\n{expected}")

    table_correct = np.allclose(poset.regions["fac:f12"].table, expected)
    # A deletion must also remove the colinear factor region itself.
    u1_removed = "fac:u1" not in poset.factors

    all_passed = table_correct and u1_removed

    if all_passed:
        print("\n✓ TEST PASSED: Colinear deletion correctly updated f12")
    else:
        if not table_correct:
            print("\n✗ TEST FAILED: Table update incorrect")
        if not u1_removed:
            print("\n✗ TEST FAILED: fac:u1 still present after colinear deletion")

    # pytest ignores return values, so the failure must be raised to be seen.
    assert all_passed, (
        f"Colinear deletion failed: table_correct={table_correct}, "
        f"u1_removed={u1_removed}"
    )
    return all_passed


def test_reconstruction_helper():
    """
    Test the belief reconstruction helper for linear deletions.

    The test runs in three steps:
    - Linearly delete x1 from the single-factor graph f12(x1, x2).
    - Feed a hand-made normalized cover belief Q_f12 to
      ``reconstruct_deleted_belief``.
    - Check that the result equals the normalized marginal of Q_f12 over x1.

    Returns:
        True when the reconstruction matches. Raises AssertionError otherwise,
        so that pytest (which ignores return values) reports the failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Reconstruction Helper")
    print("=" * 70)

    fg = FactorGraph("test_recon")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Binary factor
    f12_table = np.array([[0.9, 0.1], [0.2, 0.8]])
    fg.add_factor(Factor("f12", [x1, x2], f12_table))

    poset = from_factor_graph(fg)

    # x1 is linear (only one factor above it). Check the precondition
    # explicitly so a setup problem is not misreported as a bad reconstruction.
    assert poset.is_linear("var:x1"), "x1 should be linear (only f12 above it)"
    step = retract_linear(poset, "var:x1")

    print(f"\n{step}")

    # Simulate a belief on the cover factor
    cover_belief = np.array([[0.6, 0.1], [0.1, 0.2]])  # Unnormalized
    cover_belief = cover_belief / np.sum(cover_belief)  # Normalize

    print(f"\nCover belief (Q_f12):\n{cover_belief}")

    # Reconstruct deleted belief
    reconstructed = reconstruct_deleted_belief(
        step,
        cover_belief,
        step.details['cover_scope_before'],
        {"x1": 2, "x2": 2}
    )

    print(f"\nReconstructed Q_x1: {reconstructed}")

    # Should be marginal of cover belief over x1: f12 was built with scope
    # [x1, x2], so summing out axis 1 (x2) leaves the x1 marginal.
    expected = np.sum(cover_belief, axis=1)
    expected = expected / np.sum(expected)

    print(f"Expected (marginal): {expected}")

    recon_correct = np.allclose(reconstructed, expected)

    if recon_correct:
        print("\n✓ TEST PASSED: Reconstruction helper works correctly")
    else:
        print("\n✗ TEST FAILED: Reconstruction incorrect")

    # pytest ignores return values, so the failure must be raised to be seen.
    assert recon_correct, f"Reconstructed {reconstructed}, expected {expected}"
    return recon_correct


def test_reduction_ordering():
    """
    Test that colinear deletions are preferred when both options are available.

    Note: A factor can become colinear AFTER a linear deletion, so we can't
    require all colinear steps to come before all linear steps. Instead, we
    verify that when BOTH options are available, colinear is chosen first.

    Setup: x1 and x2 each have a unary factor plus the shared factor f12.
    Initially only colinear factors exist. Once u1 is removed, x1 becomes
    linear while u2 is still colinear, so both options are available. A
    colinear-preferring reducer must therefore start with at least as many
    colinear steps as there were initially colinear factors.

    Returns:
        True when the ordering is correct. Raises AssertionError otherwise
        (including when the setup precondition does not hold), so that pytest
        (which ignores return values) reports the failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Reduction Ordering (Colinear Preferred)")
    print("=" * 70)

    fg = FactorGraph("test_order")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    # Both x1 and x2 have unary factors
    fg.add_factor(Factor("u1", [x1], np.array([0.8, 0.2])))
    fg.add_factor(Factor("u2", [x2], np.array([0.3, 0.7])))

    # Binary factor
    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))

    poset = from_factor_graph(fg)

    initial_colinear = poset.get_colinear_factors()
    initial_linear = poset.get_linear_variables()

    print("\nInitial state:")
    print(f"  Colinear factors: {sorted(initial_colinear)}")
    print(f"  Linear variables: {sorted(initial_linear)}")

    # Precondition: there are colinear factors but no linear variables (each
    # variable has 2 factors above it). If this does not hold the ordering
    # check below is meaningless, so fail instead of passing vacuously.
    assert len(initial_colinear) > 0 and len(initial_linear) == 0, (
        "Setup precondition violated: expected colinear factors and no linear "
        f"variables, got colinear={sorted(initial_colinear)}, "
        f"linear={sorted(initial_linear)}"
    )

    # Run reduction
    steps = reduce_to_core_spb(poset)

    print("\nReduction steps:")
    for i, step in enumerate(steps):
        print(f"  {i+1}. {step}")

    # Count the run of colinear steps at the start of the reduction.
    leading_colinear = 0
    for step in steps:
        if step.step_type != 'colinear':
            break
        leading_colinear += 1

    ordering_good = leading_colinear >= len(initial_colinear)

    if ordering_good:
        print("\n✓ TEST PASSED: Colinear deletions applied when available")
    else:
        print("\n✗ TEST FAILED: Ordering incorrect")

    # pytest ignores return values, so the failure must be raised to be seen.
    assert ordering_good, (
        f"Only {leading_colinear} leading colinear step(s); expected at least "
        f"{len(initial_colinear)}"
    )
    return ordering_good


def test_core_reduction_on_tree():
    """
    Test that SP-B reduction on a tree performs only delete-only linear steps.

    Builds the chain x1 -- f12 -- x2 -- f23 -- x3 with unary factors on x1
    and x3, then reduces it to its core with ``reduce_to_core_spb``. Checks:
    - At least one linear deletion occurred. Once u1 and u3 are removed, x1
      and x3 each have a single factor above them, so an empty set of linear
      steps would make the check below pass vacuously.
    - Every linear step passes ``validate_linear_deletion_is_delete_only``.

    BP-belief preservation on the reduced core is NOT checked here.

    Returns:
        True when all checks pass. Raises AssertionError otherwise, so that
        pytest (which ignores return values) reports the failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Core Reduction on Tree Preserves Structure")
    print("=" * 70)

    # Build a simple tree: x1 -- f12 -- x2 -- f23 -- x3
    fg = FactorGraph("tree")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.8, 0.2], [0.3, 0.7]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.6, 0.4], [0.5, 0.5]])))

    # Add unary factors
    fg.add_factor(Factor("u1", [x1], np.array([0.7, 0.3])))
    fg.add_factor(Factor("u3", [x3], np.array([0.4, 0.6])))

    poset = from_factor_graph(fg)

    print("\nInitial poset:")
    print(f"  Variables: {sorted(poset.variables)}")
    print(f"  Factors: {sorted(poset.factors)}")

    # Reduce to core
    steps = reduce_to_core_spb(poset)

    print(f"\nReduction steps ({len(steps)}):")
    for step in steps:
        print(f"  {step}")

    print("\nFinal core:")
    print(f"  Variables: {sorted(poset.variables)}")
    print(f"  Factors: {sorted(poset.factors)}")

    # On a tree, the core should be very small (possibly empty or single element)
    # since all variables eventually become linear.
    linear_steps = [s for s in steps if s.step_type == 'linear']
    has_linear_steps = len(linear_steps) > 0

    # Verify all linear deletions were delete-only.
    # NOTE(review): each step is validated against the FINAL (fully reduced)
    # poset, not the poset as it was right after that step; confirm the
    # validator only relies on data recorded in the step itself.
    all_delete_only = all(validate_linear_deletion_is_delete_only(s, poset) for s in linear_steps)

    all_passed = has_linear_steps and all_delete_only

    if all_passed:
        print("\n✓ TEST PASSED: All linear deletions were delete-only")
    else:
        if not has_linear_steps:
            print("\n✗ TEST FAILED: No linear deletions were performed on the tree")
        if not all_delete_only:
            print("\n✗ TEST FAILED: Some linear deletions modified survivor tables")

    # pytest ignores return values, so the failure must be raised to be seen.
    assert all_passed, (
        f"Tree reduction failed: linear_steps={len(linear_steps)}, "
        f"all_delete_only={all_delete_only}"
    )
    return all_passed


if __name__ == "__main__":
    import traceback

    print("Strict SP-B Reduction Unit Tests\n")
    print("These tests verify the implementation matches SP-B theory exactly.\n")

    # (summary label, test function) pairs, executed in this order.
    tests = [
        ("Linear Deletion is Delete-Only",
         test_linear_deletion_is_delete_only_not_elimination),
        ("Unary Factors Not Absorbed", test_unary_factors_not_absorbed),
        ("Colinear Deletion Updates Tables", test_colinear_deletion_updates_tables),
        ("Reconstruction Helper", test_reconstruction_helper),
        ("Reduction Ordering", test_reduction_ordering),
        ("Core Reduction on Tree", test_core_reduction_on_tree),
    ]

    results = []
    for name, test_fn in tests:
        # A test signals failure either by returning False or by raising
        # (e.g. a failed assert). Catch at this top-level boundary so one
        # failing test is reported with its traceback without aborting the
        # remaining tests and the summary. A None return (pytest style, no
        # exception) counts as a pass.
        try:
            outcome = test_fn()
        except Exception:
            traceback.print_exc()
            passed = False
        else:
            passed = outcome is None or bool(outcome)
        results.append((name, passed))

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"  {name}: {status}")

    all_passed = all(p for _, p in results)
    print(f"\nOverall: {'✓ ALL TESTS PASSED' if all_passed else '✗ SOME TESTS FAILED'}")

    # Non-zero exit status lets CI / shell callers detect failures.
    sys.exit(0 if all_passed else 1)
