"""
Tests for Direct CRN-to-CRN Reduction

Verifies that SP-B retractions can be applied directly to CRNs,
modifying species, reactions, and rate constants in place.

Key tests:
1. Linear retraction deletes species/reactions without changing rates
2. Colinear retraction deletes + updates survivor rates
3. Full reduction matches recompiled CRN structure
4. Steady states match between both approaches
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph
from inference import run_bp
from crn import (
    compile_factor_graph_to_crn, 
    simulate_crn,
    CRNReducer,
    reduce_crn_to_core,
    compare_crn_structures,
)
from reduction import from_factor_graph, reduce_to_core_spb, to_factor_graph_if_possible


def _rates_by_description(crn, *needles):
    """Return ``{description: rate_constant}`` for matching reactions of *crn*.

    A reaction matches when its ``description`` contains every string in
    *needles*. If several matching reactions share a description, the one
    appearing last in ``crn.reactions`` wins.
    """
    return {
        rxn.description: rxn.rate_constant
        for rxn in crn.reactions
        if all(needle in rxn.description for needle in needles)
    }


def test_linear_retraction_deletes_only():
    """
    Test that linear retraction on CRN deletes species/reactions
    but does NOT modify rate constants of surviving reactions.

    The unary factor ``u1`` is first removed by colinear retraction so that
    ``x1`` is left attached only to ``f12``. Then ``x1`` is retracted as a
    linear variable. Every 'Sum msg f12→x2' reaction that still exists
    afterwards must keep its rate constant. Reactions deleted by the
    retraction are not treated as failures.

    Returns:
        True if the linear retraction was exercised and no surviving rate
        changed; False if ``x1`` never became linear or any surviving rate
        changed.
    """
    print("=" * 70)
    print("TEST: Linear Retraction Deletes Only (No Rate Changes)")
    print("=" * 70)

    # Simple chain: x1 -- f12 -- x2 with unary at x1
    fg = FactorGraph("chain")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    f12_table = np.array([[0.9, 0.1], [0.2, 0.8]])
    fg.add_factor(Factor("f12", [x1, x2], f12_table))
    fg.add_factor(Factor("u1", [x1], np.array([0.7, 0.3])))

    # Compile to CRN
    crn = compile_factor_graph_to_crn(fg, kappa_r=0.02, kappa_prod=50.0)

    print(f"\nOriginal CRN: {len(crn.species)} species, {len(crn.reactions)} reactions")

    # Informational only: number of f12/x2 reactions before any reduction.
    survivor_rates_before = _rates_by_description(crn, 'f12', 'x2')
    print(f"Survivor reaction rates before: {len(survivor_rates_before)}")

    reducer = CRNReducer(crn)

    # First remove the unary factor (colinear) so x1 is left with only f12.
    print(f"\nColinear factors: {reducer.get_colinear_factors()}")
    print(f"Linear variables: {reducer.get_linear_variables()}")

    if reducer.is_colinear("u1"):
        step1 = reducer.retract_colinear("u1")
        print(f"\n{step1}")

    # Now x1 should be linear
    print(f"\nAfter colinear: Linear variables: {reducer.get_linear_variables()}")

    passed = False
    if reducer.is_linear("x1"):
        # Capture rates right before linear deletion
        rates_before_linear = _rates_by_description(crn, 'Sum msg f12→x2')

        step2 = reducer.retract_linear("x1")
        print(f"\n{step2}")

        rates_after_linear = _rates_by_description(crn, 'Sum msg f12→x2')

        # Rates should be UNCHANGED by linear deletion. A reaction that is
        # absent afterwards was deleted, which linear retraction may do.
        passed = True
        compared = 0
        for desc, rate_before in rates_before_linear.items():
            rate_after = rates_after_linear.get(desc)
            if rate_after is None:
                continue
            compared += 1
            if not np.isclose(rate_before, rate_after):
                print(f"  Rate changed! {desc}: {rate_before} → {rate_after}")
                passed = False

        print(f"  Surviving reactions compared: {compared}")
        if passed:
            print("\n✓ Linear deletion did NOT change survivor reaction rates")
        else:
            print("\n✗ Linear deletion changed rates (BUG!)")
    else:
        # Without this branch failing, the test would "pass" having checked nothing.
        print("\n✗ x1 is not linear after the colinear step; linear retraction not exercised")

    print(f"\nReduced CRN: {len(crn.species)} species, {len(crn.reactions)} reactions")

    return passed


def _sum_msg_rates(crn, factor, var):
    """Collect rate constants of the 'Sum msg {factor}→{var}' reactions of *crn*.

    Returns a dict mapping the index ``k``, parsed from a product species
    named ``S_{factor}→{var}[k]``, to the list of rate constants of every such
    reaction producing that species, in ``crn.reactions`` order. A list is
    kept rather than a single value so that, when several reactions feed the
    same message species, all of them are checked instead of only the last.
    """
    tag = f"Sum msg {factor}→{var}"
    prefix = f"S_{factor}→{var}["
    rates = {}
    for rxn in crn.reactions:
        if tag not in rxn.description:
            continue
        for sp in rxn.products:
            if sp.startswith(prefix):
                k = int(sp.split('[')[1].rstrip(']'))
                rates.setdefault(k, []).append(rxn.rate_constant)
    return rates


def test_colinear_retraction_updates_rates():
    """
    Test that colinear retraction updates survivor reaction rates
    according to Eq 4.29.

    The unary factor ``u1`` on ``x1`` is retracted. Every
    'Sum msg f12→x1' reaction producing ``S_f12→x1[k]`` must then have its
    rate constant multiplied by ``u1_table[k-1]``. Species indices are
    treated as 1-based, as in the original check. An index outside
    ``1..len(u1_table)`` is reported as a failure rather than silently
    wrapping to a negative index.

    Returns:
        True only if at least one such reaction was found, each index had
        the same number of reactions before and after, and every rate was
        rescaled as expected.
    """
    print("\n" + "=" * 70)
    print("TEST: Colinear Retraction Updates Rates (Eq 4.29)")
    print("=" * 70)

    # x1 with two factors: unary u1 and binary f12
    fg = FactorGraph("test")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    f12_table = np.array([[2.0, 1.0], [1.0, 3.0]])
    u1_table = np.array([10.0, 1.0])  # Strong bias

    fg.add_factor(Factor("f12", [x1, x2], f12_table))
    fg.add_factor(Factor("u1", [x1], u1_table))

    crn = compile_factor_graph_to_crn(fg, kappa_r=0.02, kappa_prod=50.0)

    print(f"\nOriginal CRN: {len(crn.species)} species, {len(crn.reactions)} reactions")

    # f12→x1 sum message rates before reduction
    rates_before = _sum_msg_rates(crn, "f12", "x1")
    for k, k_rates in rates_before.items():
        for rate in k_rates:
            print(f"  Before: f12→x1[{k}] rate = {rate}")

    if not rates_before:
        # Nothing to compare; passing here would be vacuous.
        print("\n✗ No 'Sum msg f12→x1' reactions found; nothing to check")
        return False

    # Apply colinear retraction
    reducer = CRNReducer(crn)
    step = reducer.retract_colinear("u1", removed_table=u1_table)
    print(f"\n{step}")

    # After removing u1, the f12→x1 reactions should have their rates
    # multiplied by u1's values.
    rates_after = _sum_msg_rates(crn, "f12", "x1")
    for k, k_rates in rates_after.items():
        for rate in k_rates:
            print(f"  After:  f12→x1[{k}] rate = {rate}")

    print("\nExpected rate updates (multiply by u1):")
    all_correct = True
    for k, before_list in rates_before.items():
        if not 1 <= k <= len(u1_table):
            # u1_table[k - 1] would silently wrap for k == 0.
            print(f"  k={k}: index outside 1..{len(u1_table)} for u1 ✗")
            all_correct = False
            continue
        weight = u1_table[k - 1]
        after_list = rates_after.get(k, [])
        if len(after_list) != len(before_list):
            print(f"  k={k}: {len(before_list)} reaction(s) before, {len(after_list)} after ✗")
            all_correct = False
            continue
        # Reactions are compared pairwise in crn.reactions order.
        for rate_before, actual in zip(before_list, after_list):
            expected = rate_before * weight
            match = bool(np.isclose(expected, actual))
            print(f"  k={k}: {rate_before} * {weight} = {expected}, got {actual} {'✓' if match else '✗'}")
            all_correct = all_correct and match

    if all_correct:
        print("\n✓ Colinear retraction correctly updated rates via Eq 4.29")
    else:
        print("\n✗ Rate updates incorrect")

    return all_correct


def _crn_size(crn):
    """Return ``(species count, reaction count)`` for *crn*."""
    return len(crn.species), len(crn.reactions)


def test_full_reduction_matches_recompile():
    """
    Test that direct CRN reduction produces same structure as
    FG reduce → recompile.

    Path A compiles the factor graph to a CRN and reduces that CRN directly.
    Path B reduces the factor graph (via its SP-B poset) and recompiles the
    result. The two resulting CRNs are compared by species and reaction
    counts.

    Returns:
        False if both paths produced a CRN and their sizes differ. True if
        the sizes match, or if path B produced no non-trivial factor graph
        to recompile (in which case no comparison is possible).
    """
    print("\n" + "=" * 70)
    print("TEST: Full CRN Reduction Matches Recompile")
    print("=" * 70)

    # Chain: x1 -- f12 -- x2 -- f23 -- x3 with unaries at endpoints
    fg = FactorGraph("chain")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.8, 0.2], [0.3, 0.7]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.6, 0.4], [0.5, 0.5]])))
    fg.add_factor(Factor("u1", [x1], np.array([0.7, 0.3])))
    fg.add_factor(Factor("u3", [x3], np.array([0.4, 0.6])))

    print(f"\nOriginal FG: {fg.num_variables} vars, {fg.num_factors} factors")

    # PATH A: FG → CRN → reduce CRN directly
    crn_direct = compile_factor_graph_to_crn(fg, kappa_r=0.02, kappa_prod=50.0)
    print("\nPATH A: Direct CRN reduction")
    n_species, n_reactions = _crn_size(crn_direct)
    print(f"  Original CRN: {n_species} species, {n_reactions} reactions")

    reduced_crn, steps = reduce_crn_to_core(crn_direct, copy=True)
    print(f"  Reduction steps: {len(steps)}")
    for s in steps:
        print(f"    {s}")
    direct_size = _crn_size(reduced_crn)
    print(f"  Reduced CRN: {direct_size[0]} species, {direct_size[1]} reactions")

    # PATH B: FG → reduce FG → recompile CRN
    print("\nPATH B: FG reduce → recompile")
    poset = from_factor_graph(fg)
    fg_steps = reduce_to_core_spb(poset)
    reduced_fg = to_factor_graph_if_possible(poset)

    print(f"  FG reduction steps: {len(fg_steps)}")

    crn_recompiled = None
    # Explicit None check: do not depend on how FactorGraph defines truthiness.
    if reduced_fg is not None and reduced_fg.num_variables > 0:
        crn_recompiled = compile_factor_graph_to_crn(reduced_fg, kappa_r=0.02, kappa_prod=50.0)
        print(f"  Reduced FG: {reduced_fg.num_variables} vars, {reduced_fg.num_factors} factors")
        n_species, n_reactions = _crn_size(crn_recompiled)
        print(f"  Recompiled CRN: {n_species} species, {n_reactions} reactions")
    else:
        print("  Reduced to trivial FG")

    # Compare structures
    print("\nComparison:")
    print(f"  Direct reduced: {direct_size[0]} species, {direct_size[1]} reactions")
    if crn_recompiled is None:
        print("  No recompiled CRN to compare against; structure check skipped")
        return True

    recompiled_size = _crn_size(crn_recompiled)
    print(f"  Recompiled:     {recompiled_size[0]} species, {recompiled_size[1]} reactions")

    # NOTE(review): compare_crn_structures is imported by this file but its
    # signature is not visible here, so only sizes are compared; consider
    # switching to it for a name-level comparison.
    match = direct_size == recompiled_size
    print("  ✓ Sizes match" if match else "  ✗ Sizes differ")
    return match


def test_steady_states_match():
    """
    Smoke-test steady states for a 3-variable chain with unaries at both ends.

    Prints, for side-by-side inspection:
    - BP marginals (ground truth)
    - Marginals of the original compiled CRN after simulation
    - Size of the directly reduced CRN, plus its surviving species when it
      still has species and reactions (the reduced CRN is simulated first)

    No numerical comparison is performed, and the function always returns
    True. An exception while simulating or listing the reduced CRN is
    printed rather than raised.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST: Steady States Match")
    print(banner)

    # Chain x1 -- f12 -- x2 -- f23 -- x3 with unaries on both endpoints
    fg = FactorGraph("chain")
    xs = {name: fg.add_variable(Variable(name, [0, 1])) for name in ("x1", "x2", "x3")}

    factor_specs = [
        ("f12", ("x1", "x2"), [[0.9, 0.1], [0.2, 0.8]]),
        ("f23", ("x2", "x3"), [[0.7, 0.3], [0.4, 0.6]]),
        ("u1", ("x1",), [0.8, 0.2]),
        ("u3", ("x3",), [0.3, 0.7]),
    ]
    for fname, scope, table in factor_specs:
        fg.add_factor(Factor(fname, [xs[v] for v in scope], np.array(table)))

    def report_marginals(header, result):
        # Print one "P(var) = ..." line per variable, in fg.variables order.
        print(header)
        found = {}
        for var in fg.variables:
            p = result.get_marginal(var.name)
            found[var.name] = p
            print(f"  P({var.name}) = {p}")
        return found

    # Ground truth from belief propagation
    report_marginals("\nBP marginals:", run_bp(fg))

    # Steady state of the unreduced CRN
    crn_orig = compile_factor_graph_to_crn(fg, kappa_r=0.02, kappa_prod=50.0)
    sim_orig = simulate_crn(crn_orig, t_end=5000, n_points=100)
    report_marginals("\nOriginal CRN marginals:", sim_orig)

    # Directly reduced copy of the CRN
    reduced_crn, _ = reduce_crn_to_core(crn_orig, copy=True)
    n_species = len(reduced_crn.species)
    n_reactions = len(reduced_crn.reactions)
    print(f"\nReduced CRN: {n_species} species, {n_reactions} reactions")

    if not (n_species and n_reactions):
        print("\nReduced CRN is trivial (no reactions)")
        return True

    try:
        simulate_crn(reduced_crn, t_end=5000, n_points=100)
        print("\nReduced CRN surviving species:")
        for name in sorted(reduced_crn.species.keys()):
            print(f"  {name}")
    except Exception as e:
        print(f"\nCould not simulate reduced CRN: {e}")
    return True


if __name__ == "__main__":
    import traceback

    print("Direct CRN Reduction Tests\n")

    tests = [
        ("Linear Retraction Deletes Only", test_linear_retraction_deletes_only),
        ("Colinear Retraction Updates Rates", test_colinear_retraction_updates_rates),
        ("Full Reduction Matches Recompile", test_full_reduction_matches_recompile),
        ("Steady States Match", test_steady_states_match),
    ]

    results = []
    for name, test_fn in tests:
        try:
            passed = bool(test_fn())
        except Exception:
            # Report the crash but keep going so the remaining tests still run
            # and the summary is always printed.
            traceback.print_exc()
            passed = False
        results.append((name, passed))

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"  {name}: {status}")

    # Non-zero exit status so shells / CI notice failing tests.
    sys.exit(0 if all(passed for _, passed in results) else 1)
