"""
Benchmark Runner for SP-B Reduction + CRN Compilation

Runs comprehensive benchmarks to measure:
1. Reduction size gains (variables, factors, edges, CRN size)
2. Runtime speedups (compile time, simulation time)
3. Correctness preservation (marginal comparison)

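Correctness is validated at multiple levels:
- BP(original FG) vs BP(reduced FG)
- BP(original FG) vs CRN simulation (original)
- BP(reduced FG) vs CRN simulation (reduced)
- CRN simulation (original) vs CRN simulation (reduced)

"Route B" additionally reduces the compiled original CRN directly, guided by
the factor-graph reduction steps, and compares its simulation against
BP(original FG), BP(reduced FG) and CRN simulation (reduced).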
"""

import numpy as np
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph
from inference import run_bp
from reduction.poset_reduction import (
    from_factor_graph, 
    to_factor_graph_if_possible,
    reduce_to_core_spb,
)
from crn import compile_factor_graph_to_crn, simulate_crn 
from crn.crn_reduction import reduce_crn_guided

from benchmarks.graph_generators import (
    generate_chain,
    generate_binary_tree,
    generate_loopy_core_with_tendrils,
    generate_grid_with_pruned_leaves,
    generate_random_with_planted_core,
    compute_graph_stats,
)


# === Helper functions for marginal comparison ===

def _normalize(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """
    Rescale a (possibly unnormalized) probability vector so it sums to 1.

    Negative entries are clipped to zero before summing. When the remaining
    mass is below ``eps`` the input is treated as degenerate and a uniform
    vector of the same length is returned instead.
    """
    vec = np.maximum(np.asarray(p, dtype=float), 0.0)
    mass = vec.sum()
    # Degenerate input (no usable positive mass) falls back to uniform.
    return np.ones_like(vec) / len(vec) if mass < eps else vec / mass
def crn_marginals_from_sim(sim, var_names: List[str]) -> Dict[str, np.ndarray]:
    """
    Build marginals dict {var -> prob vector} from simulation concentrations.

    Species named ``Marginal_<var>_<k>`` (``k`` a positive decimal integer;
    states are numbered from 1) are read at their final timepoint, ordered
    by ``k`` and normalized with ``_normalize``. The state index is the text
    after the LAST underscore, so variable names containing underscores
    (e.g. ``x_1``) are matched correctly.

    Args:
        sim: Simulation result exposing a ``concentrations`` mapping
            {species_name -> trajectory}. ``None``, or an object without that
            attribute, yields an empty dict.
        var_names: Variables to extract.

    Returns:
        {var -> normalized marginal} for every requested variable with at
        least one matching species; other variables are omitted. An empty
        trajectory contributes NaN for its state.
    """
    out: Dict[str, np.ndarray] = {}
    if sim is None or not hasattr(sim, "concentrations"):
        return out

    prefix = "Marginal_"
    wanted = set(var_names)
    # Single pass over all species (previously one full scan per variable).
    per_var: Dict[str, List[Tuple[int, float]]] = {}
    for sname, traj in sim.concentrations.items():
        if not sname.startswith(prefix):
            continue
        var, sep, suffix = sname[len(prefix):].rpartition("_")
        # isdecimal (not isdigit): isdigit accepts characters such as
        # superscript digits that int() cannot parse.
        if not sep or var not in wanted or not suffix.isdecimal():
            continue
        k = int(suffix)
        if k <= 0:
            continue
        arr = np.asarray(traj, dtype=float)
        per_var.setdefault(var, []).append((k, float(arr[-1]) if arr.size else np.nan))

    # Iterate var_names (not per_var) so output order follows the request.
    for v in var_names:
        entries = per_var.get(v)
        if entries:
            entries.sort(key=lambda e: e[0])
            out[v] = _normalize(np.array([val for _, val in entries], dtype=float))
    return out
def max_marginal_diff_between_dicts(
    A: Dict[str, np.ndarray], 
    B: Dict[str, np.ndarray], 
    keys: List[str]
) -> Tuple[float, int]:
    """
    Largest absolute per-state difference between two sets of marginals.

    For every key present in both ``A`` and ``B``, both vectors are
    normalized with ``_normalize`` and compared element-wise. A key is
    skipped (and not counted) when it is missing from either dict, when the
    vectors differ in length, when an entry cannot be turned into a 1-D
    float vector, or when the resulting difference is not finite (e.g. a NaN
    marginal from a failed simulation).

    Returns:
        ``(max_diff, n_compared)``; ``max_diff`` is NaN when nothing was
        compared.
    """
    max_diff = 0.0
    n_compared = 0

    for key in keys:
        if key not in A or key not in B:
            continue
        try:
            p_a = _normalize(A[key])
            p_b = _normalize(B[key])
            if len(p_a) != len(p_b):
                continue
            diff = float(np.max(np.abs(p_a - p_b)))
        except (TypeError, ValueError):
            # Malformed entry (None/scalar -> TypeError on len, empty or
            # non-broadcastable arrays -> ValueError): best-effort skip.
            continue
        # max(0.0, nan) == 0.0 in Python, so a NaN diff would otherwise be
        # counted as a comparison showing perfect agreement.
        if not np.isfinite(diff):
            continue
        max_diff = max(max_diff, diff)
        n_compared += 1

    return (max_diff if n_compared > 0 else float('nan'), n_compared)


@dataclass
class BenchmarkResult:
    """
    Everything measured by one benchmark run on a single factor graph.

    Conventions: counts default to 0, wall-clock timings (seconds) to 0.0,
    correctness differences to NaN (meaning "comparison not performed"), and
    boolean flags to False. Ratios and speedups are derived properties.
    """
    name: str

    # --- Original factor graph and its compiled CRN ---
    orig_n_vars: int = 0
    orig_n_factors: int = 0
    orig_n_edges: int = 0
    orig_n_species: int = 0
    orig_n_reactions: int = 0

    # --- Reduced factor graph and its compiled CRN (Route A) ---
    reduced_n_vars: int = 0
    reduced_n_factors: int = 0
    reduced_n_edges: int = 0
    reduced_n_species: int = 0
    reduced_n_reactions: int = 0

    # --- SP-B reduction steps, total and per step type ---
    n_reduction_steps: int = 0
    n_linear_steps: int = 0
    n_colinear_steps: int = 0

    # --- Wall-clock timings in seconds ---
    reduction_time: float = 0.0
    orig_compile_time: float = 0.0
    reduced_compile_time: float = 0.0
    orig_sim_time: float = 0.0
    reduced_sim_time: float = 0.0
    orig_bp_time: float = 0.0
    reduced_bp_time: float = 0.0

    # --- BP(orig) vs BP(reduced) and BP convergence ---
    marginal_max_diff: float = float('nan')
    bp_converged_orig: bool = False
    bp_converged_reduced: bool = False

    # --- CRN-level comparisons (Route A) ---
    crn_vs_bp_orig_max_diff: float = float('nan')
    crn_vs_bp_reduced_max_diff: float = float('nan')
    crn_orig_vs_reduced_max_diff: float = float('nan')

    # --- Route B: compile first, then induced reduction of the CRN ---
    induced_n_species: int = 0
    induced_n_reactions: int = 0
    induced_reduce_time: float = 0.0
    induced_sim_time: float = 0.0

    # --- Route B comparisons ---
    induced_crn_vs_bp_orig_max_diff: float = float('nan')
    induced_crn_vs_bp_reduced_max_diff: float = float('nan')
    crn_reduced_vs_induced_max_diff: float = float('nan')

    # --- How many variables each comparison actually covered ---
    n_survivor_vars: int = 0
    n_bp_compared: int = 0                       # check 1
    n_crn_orig_compared: int = 0                 # check 2
    n_crn_reduced_compared: int = 0              # check 3
    n_crn_induced_compared: int = 0              # check 5
    n_crn_crn_compared: int = 0                  # check 4
    n_induced_vs_bp_reduced_compared: int = 0    # check 6
    n_reduced_vs_induced_compared: int = 0       # check 7

    # --- Whether marginal species could be read from each simulation ---
    has_crn_marginals_orig: bool = False
    has_crn_marginals_reduced: bool = False
    has_crn_marginals_induced: bool = False

    @staticmethod
    def _ratio_or(numerator: float, denominator: float, fallback: float) -> float:
        """Return numerator / denominator, or ``fallback`` when the denominator is 0."""
        return fallback if denominator == 0 else numerator / denominator

    @property
    def var_reduction_ratio(self) -> float:
        """Variables kept (reduced / original); 1.0 when the original has none."""
        return self._ratio_or(self.reduced_n_vars, self.orig_n_vars, 1.0)

    @property
    def species_reduction_ratio(self) -> float:
        """Species kept (reduced / original); 1.0 when the original has none."""
        return self._ratio_or(self.reduced_n_species, self.orig_n_species, 1.0)

    @property
    def compile_speedup(self) -> float:
        """Original / reduced compile time; +inf when the reduced time is 0."""
        return self._ratio_or(self.orig_compile_time, self.reduced_compile_time, float('inf'))

    @property
    def sim_speedup(self) -> float:
        """Original / reduced simulation time; +inf when the reduced time is 0."""
        return self._ratio_or(self.orig_sim_time, self.reduced_sim_time, float('inf'))


def _bp_marginals(bp_result, var_names: List[str]) -> Dict[str, np.ndarray]:
    """Return {var -> marginal} from a BP result, omitting vars with no marginal."""
    margs: Dict[str, np.ndarray] = {}
    if bp_result is None:
        return margs
    for vname in var_names:
        marg = bp_result.get_marginal(vname)  # query once per variable
        if marg is not None:
            margs[vname] = marg
    return margs


def _timed_simulate(crn, label: str, sim_time: float, sim_points: int,
                    verbose: bool) -> Tuple[Any, float]:
    """Simulate ``crn`` and time it; return ``(sim, seconds)`` or ``(None, 0.0)`` on failure."""
    t0 = time.perf_counter()
    try:
        sim = simulate_crn(crn, t_end=sim_time, n_points=sim_points)
    except Exception as e:
        if verbose:
            print(f"  Error simulating {label} CRN: {e}")
        return None, 0.0
    return sim, time.perf_counter() - t0


def _run_correctness_checks(result: BenchmarkResult, survivor_names: List[str],
                            orig_bp, reduced_bp,
                            orig_sim, reduced_sim, induced_sim) -> None:
    """Fill the seven marginal-comparison fields (and has_crn_* flags) of ``result`` in place."""
    bp_orig_ok = bool(result.bp_converged_orig) and orig_bp is not None
    bp_red_ok = bool(result.bp_converged_reduced) and reduced_bp is not None

    # Extract each marginal set exactly once and reuse it across checks.
    bp_orig = _bp_marginals(orig_bp, survivor_names) if bp_orig_ok else {}
    bp_red = _bp_marginals(reduced_bp, survivor_names) if bp_red_ok else {}
    crn_orig = crn_marginals_from_sim(orig_sim, survivor_names)
    crn_red = crn_marginals_from_sim(reduced_sim, survivor_names)
    crn_ind = crn_marginals_from_sim(induced_sim, survivor_names)

    # Record whether each existing simulation exposed marginal species.
    if orig_sim is not None:
        result.has_crn_marginals_orig = len(crn_orig) > 0
    if reduced_sim is not None:
        result.has_crn_marginals_reduced = len(crn_red) > 0
    if induced_sim is not None:
        result.has_crn_marginals_induced = len(crn_ind) > 0

    def compare(a, b):
        return max_marginal_diff_between_dicts(a, b, survivor_names)

    # Check 1: BP(orig) vs BP(reduced)
    if bp_orig_ok and bp_red_ok:
        result.marginal_max_diff, result.n_bp_compared = compare(bp_orig, bp_red)
    # Check 2: CRN(orig) vs BP(orig)
    if orig_sim is not None and bp_orig_ok:
        result.crn_vs_bp_orig_max_diff, result.n_crn_orig_compared = compare(crn_orig, bp_orig)
    # Check 3: CRN(reduced) vs BP(reduced)
    if reduced_sim is not None and bp_red_ok:
        result.crn_vs_bp_reduced_max_diff, result.n_crn_reduced_compared = compare(crn_red, bp_red)
    # Check 4: CRN(orig) vs CRN(reduced)
    if orig_sim is not None and reduced_sim is not None:
        result.crn_orig_vs_reduced_max_diff, result.n_crn_crn_compared = compare(crn_orig, crn_red)
    # Check 5: CRN(induced) vs BP(orig)
    if induced_sim is not None and bp_orig_ok:
        (result.induced_crn_vs_bp_orig_max_diff,
         result.n_crn_induced_compared) = compare(crn_ind, bp_orig)
    # Check 6: CRN(induced) vs BP(reduced)
    if induced_sim is not None and bp_red_ok:
        (result.induced_crn_vs_bp_reduced_max_diff,
         result.n_induced_vs_bp_reduced_compared) = compare(crn_ind, bp_red)
    # Check 7: CRN(reduced, Route A) vs CRN(induced, Route B)
    if reduced_sim is not None and induced_sim is not None:
        (result.crn_reduced_vs_induced_max_diff,
         result.n_reduced_vs_induced_compared) = compare(crn_red, crn_ind)


def run_single_benchmark(fg: FactorGraph, name: str,
                         sim_time: float = 5000,
                         sim_points: int = 200,
                         verbose: bool = False,
                         max_sim_species: int = 5000) -> BenchmarkResult:
    """
    Run a complete benchmark on a single factor graph.

    Pipeline: compile the original FG to a CRN and run BP on it; apply SP-B
    reduction; rebuild and compile the reduced FG (Route A); derive an
    induced-reduced CRN from the original CRN guided by the same reduction
    steps (Route B); run BP on the reduced FG; simulate the CRNs when small
    enough; compare marginals on the variables that survive reduction.

    Correctness checks:
      1. BP(orig) vs BP(reduced)
      2. CRN(orig) vs BP(orig)
      3. CRN(reduced) vs BP(reduced)
      4. CRN(orig) vs CRN(reduced)
      5. CRN(induced) vs BP(orig)
      6. CRN(induced) vs BP(reduced)
      7. CRN(reduced) vs CRN(induced)

    Args:
        fg: Factor graph to benchmark.
        name: Label stored in the result.
        sim_time: End time passed to ``simulate_crn``.
        sim_points: Number of output points passed to ``simulate_crn``.
        verbose: Print progress and error messages.
        max_sim_species: CRNs with more species than this are not simulated.
            The original and reduced CRNs are simulated only if both fit; the
            induced CRN is additionally gated on its own size.

    Returns:
        A ``BenchmarkResult``. Errors never propagate: a failing stage leaves
        its fields at their defaults, and an unexpected error returns
        whatever had been filled in so far.
    """
    result = BenchmarkResult(name=name)
    # True once Route B's fields hold their final values (success or recorded
    # failure), so the catch-all handler below does not overwrite them.
    route_b_recorded = False
    # Keep both compilations and both BP runs on identical settings.
    crn_kwargs = {"kappa_r": 0.02, "kappa_prod": 50.0}
    bp_kwargs = {"tolerance": 1e-8, "max_iterations": 500, "damping": 0.3}
    try:
        # === Original graph stats ===
        orig_stats = compute_graph_stats(fg)
        result.orig_n_vars = orig_stats['n_variables']
        result.orig_n_factors = orig_stats['n_factors']
        result.orig_n_edges = orig_stats['n_directed_edges']

        if verbose:
            print(f"  Original: {result.orig_n_vars} vars, {result.orig_n_factors} factors")

        orig_bp = reduced_bp = None
        orig_sim = reduced_sim = induced_sim = None
        induced_crn = None

        # === Compile original CRN (nothing else is meaningful without it) ===
        t0 = time.perf_counter()
        try:
            orig_crn = compile_factor_graph_to_crn(fg, **crn_kwargs)
            result.orig_compile_time = time.perf_counter() - t0
            result.orig_n_species = len(orig_crn.species)
            result.orig_n_reactions = len(orig_crn.reactions)
        except Exception as e:
            if verbose:
                print(f"  Error compiling original CRN: {e}")
            return result

        # === Run BP on original ===
        t0 = time.perf_counter()
        try:
            orig_bp = run_bp(fg, **bp_kwargs)
            result.orig_bp_time = time.perf_counter() - t0
            result.bp_converged_orig = orig_bp.converged
        except Exception as e:
            if verbose:
                print(f"  Error running BP on original: {e}")

        # === SP-B Reduction ===
        t0 = time.perf_counter()
        poset = from_factor_graph(fg)
        steps = reduce_to_core_spb(poset)
        result.reduction_time = time.perf_counter() - t0

        result.n_reduction_steps = len(steps)
        result.n_linear_steps = sum(1 for s in steps if s.step_type == 'linear')
        result.n_colinear_steps = sum(1 for s in steps if s.step_type == 'colinear')

        if verbose:
            print(f"  Reduction: {result.n_reduction_steps} steps "
                  f"({result.n_linear_steps} linear, {result.n_colinear_steps} colinear)")

        # === Convert to reduced factor graph ===
        reduced_fg = to_factor_graph_if_possible(poset)

        if reduced_fg is None or reduced_fg.num_variables == 0:
            if verbose:
                print("  Reduced to trivial (no variables)")
            result.reduced_n_vars = 0
            result.reduced_n_factors = 0
            result.reduced_n_species = 0
            result.reduced_n_reactions = 0
            return result

        reduced_stats = compute_graph_stats(reduced_fg)
        result.reduced_n_vars = reduced_stats['n_variables']
        result.reduced_n_factors = reduced_stats['n_factors']
        result.reduced_n_edges = reduced_stats['n_directed_edges']

        if verbose:
            print(f"  Reduced: {result.reduced_n_vars} vars, {result.reduced_n_factors} factors")

        # === Survivor variables: the only ones comparable across graphs ===
        orig_var_names = {v.name for v in fg.variables}
        survivor_names = [v.name for v in reduced_fg.variables if v.name in orig_var_names]
        result.n_survivor_vars = len(survivor_names)

        # === Compile reduced CRN (Route A) ===
        t0 = time.perf_counter()
        try:
            reduced_crn = compile_factor_graph_to_crn(reduced_fg, **crn_kwargs)
            result.reduced_compile_time = time.perf_counter() - t0
            result.reduced_n_species = len(reduced_crn.species)
            result.reduced_n_reactions = len(reduced_crn.reactions)
        except Exception as e:
            if verbose:
                print(f"  Error compiling reduced CRN: {e}")
            return result

        # === Route B: induced reduction on the FULL CRN guided by FG steps ===
        t0 = time.perf_counter()
        try:
            induced_crn, _induced_steps = reduce_crn_guided(orig_crn, steps, copy=True)
            result.induced_reduce_time = time.perf_counter() - t0
            result.induced_n_species = len(induced_crn.species)
            result.induced_n_reactions = len(induced_crn.reactions)
        except Exception as e:
            # Do NOT abort the benchmark; just skip Route B.
            if verbose:
                print(f"  Error induced-reducing CRN (guided): {e}")
            induced_crn = None
            # Record a consistent "Route B not available" state
            result.induced_reduce_time = float("nan")
            result.induced_n_species = 0
            result.induced_n_reactions = 0
        route_b_recorded = True

        # === Run BP on reduced ===
        t0 = time.perf_counter()
        try:
            reduced_bp = run_bp(reduced_fg, **bp_kwargs)
            result.reduced_bp_time = time.perf_counter() - t0
            result.bp_converged_reduced = reduced_bp.converged
        except Exception as e:
            if verbose:
                print(f"  Error running BP on reduced: {e}")

        # === Simulate CRNs (only when small enough) ===
        if (result.orig_n_species <= max_sim_species
                and result.reduced_n_species <= max_sim_species):
            # Route B is gated on its own size so an oversized induced CRN
            # does not also suppress the Route A simulations.
            if (induced_crn is not None and len(induced_crn.reactions) > 0
                    and result.induced_n_species <= max_sim_species):
                induced_sim, result.induced_sim_time = _timed_simulate(
                    induced_crn, "induced", sim_time, sim_points, verbose)
            orig_sim, result.orig_sim_time = _timed_simulate(
                orig_crn, "original", sim_time, sim_points, verbose)
            reduced_sim, result.reduced_sim_time = _timed_simulate(
                reduced_crn, "reduced", sim_time, sim_points, verbose)

        # === Correctness checks 1-7 ===
        _run_correctness_checks(result, survivor_names, orig_bp, reduced_bp,
                                orig_sim, reduced_sim, induced_sim)
        return result
    except Exception as e:
        if verbose:
            print(f"[run_single_benchmark] ERROR on {name}: {e}")
        if not route_b_recorded:
            # Failure happened before Route B ran: mark it unavailable. If it
            # already ran, keep its (valid) results instead of wiping them.
            result.induced_n_species = 0
            result.induced_n_reactions = 0
            result.induced_reduce_time = float('nan')
        return result


def run_benchmark_suite(verbose: bool = True) -> List[BenchmarkResult]:
    """
    Run every benchmark family and collect one result per generated graph.

    Families, in order: chains, binary trees, loopy cores with tendrils,
    grids with pruned leaves, and random graphs with a planted core. When
    ``verbose`` is set, a banner precedes each family and a label precedes
    each graph; ``verbose`` is also forwarded to each single benchmark.
    """
    collected: List[BenchmarkResult] = []

    def banner(title: str) -> None:
        if verbose:
            print("\n" + "=" * 60)
            print(title)
            print("=" * 60)

    def bench(label: str, bench_name: str, generator, *gen_args, **gen_kwargs) -> None:
        # Label is printed before the graph is generated, then the benchmark runs.
        if verbose:
            print(label)
        graph = generator(*gen_args, **gen_kwargs)
        collected.append(run_single_benchmark(graph, bench_name, verbose=verbose))

    banner("CHAIN BENCHMARKS")
    for n in (5, 10, 20, 50, 100):
        bench(f"\nChain({n}):", f"chain_{n}", generate_chain, n)

    banner("BINARY TREE BENCHMARKS")
    for depth in (3, 4, 5, 6):
        bench(f"\nTree(depth={depth}):", f"tree_d{depth}", generate_binary_tree, depth)

    banner("LOOPY CORE + TENDRILS BENCHMARKS")
    for core_size in (3, 4, 5):
        for tendril_len in (1, 3, 5, 10):
            bench(f"\nLoopyCore({core_size})+Tendrils({tendril_len}):",
                  f"loopy_c{core_size}_t{tendril_len}",
                  generate_loopy_core_with_tendrils, core_size, tendril_len)

    banner("GRID BENCHMARKS")
    for size in (3, 4, 5, 6):
        bench(f"\nGrid({size}x{size}):", f"grid_{size}x{size}",
              generate_grid_with_pruned_leaves, size, size, prune_fraction=0.5)

    banner("RANDOM + PLANTED CORE BENCHMARKS")
    for n_total in (15, 25, 40):
        planted = max(3, n_total // 5)
        bench(f"\nRandom({n_total}, core={planted}):",
              f"random_{n_total}_c{planted}",
              generate_random_with_planted_core, n_total, planted)

    return collected


def print_summary_table(results: List["BenchmarkResult"]):
    """
    Print a per-benchmark summary table followed by aggregate statistics.

    Sizes and times are shown as ``orig→reduced``. The difference columns
    show BP-vs-BP and CRN-vs-CRN max marginal differences; "N/A" marks a
    comparison that was not performed (NaN). Aggregates ignore benchmarks
    whose original graph had no variables, and any aggregate with no
    contributing benchmark is printed as "N/A" (instead of a NaN/0 value).
    """
    def fmt_diff(x: float) -> str:
        if not np.isfinite(x):
            return "N/A"
        return f"{x:.2e}" if x > 1e-10 else "<1e-10"

    print("\n" + "=" * 120)
    print("BENCHMARK SUMMARY")
    print("=" * 120)

    print(f"\n{'Name':<25} {'Vars':<12} {'Species':<15} {'Compile(s)':<15} {'Sim(s)':<15} {'BP Diff':<12} {'CRN Diff':<12}")
    print(f"{'':<25} {'Orig→Red':<12} {'Orig→Red':<15} {'Orig→Red':<15} {'Orig→Red':<15} {'(BP↔BP)':<12} {'(CRN↔CRN)':<12}")
    print("-" * 120)

    for r in results:
        vars_str = f"{r.orig_n_vars}→{r.reduced_n_vars}"
        species_str = f"{r.orig_n_species}→{r.reduced_n_species}"
        compile_str = f"{r.orig_compile_time:.3f}→{r.reduced_compile_time:.3f}"
        if r.orig_sim_time > 0 and r.reduced_sim_time > 0:
            sim_str = f"{r.orig_sim_time:.2f}→{r.reduced_sim_time:.2f}"
        else:
            sim_str = "N/A"
        bp_diff_str = fmt_diff(r.marginal_max_diff)
        crn_diff_str = fmt_diff(r.crn_orig_vs_reduced_max_diff)
        print(f"{r.name:<25} {vars_str:<12} {species_str:<15} {compile_str:<15} {sim_str:<15} {bp_diff_str:<12} {crn_diff_str:<12}")

    # === Aggregate statistics ===
    print("\n" + "-" * 120)
    print("AGGREGATE STATISTICS:")

    valid_results = [r for r in results if r.orig_n_vars > 0]

    def mean_or_none(values):
        # Avoid np.mean([]) (RuntimeWarning + NaN) when nothing contributes.
        return float(np.mean(values)) if values else None

    def fmt(value, spec: str, suffix: str = "") -> str:
        return "N/A" if value is None else f"{value:{spec}}{suffix}"

    avg_var_ratio = mean_or_none([r.var_reduction_ratio for r in valid_results])
    avg_species_ratio = mean_or_none([r.species_reduction_ratio for r in valid_results
                                      if r.orig_n_species > 0])
    avg_compile_speedup = mean_or_none([r.compile_speedup for r in valid_results
                                        if r.orig_compile_time > 0 and r.reduced_compile_time > 0])
    # Both sim times must be positive: a failed original sim (time 0) paired
    # with a successful reduced sim would otherwise contribute a bogus 0x.
    avg_sim_speedup = mean_or_none([r.sim_speedup for r in valid_results
                                    if r.orig_sim_time > 0 and r.reduced_sim_time > 0])

    # The ratios are reduced/original, i.e. the fraction RETAINED (lower = more reduction).
    print(f"  Average variable ratio (reduced/orig): {fmt(avg_var_ratio, '.1%')}")
    print(f"  Average species ratio (reduced/orig): {fmt(avg_species_ratio, '.1%')}")
    print(f"  Average compile speedup: {fmt(avg_compile_speedup, '.2f', 'x')}")
    print(f"  Average simulation speedup: {fmt(avg_sim_speedup, '.2f', 'x')}")

    def finite(attr: str) -> List[float]:
        return [getattr(r, attr) for r in valid_results if np.isfinite(getattr(r, attr))]

    def report(label: str, values: List[float], missing: str = "N/A") -> None:
        if values:
            print(f"    {label:<28} max={max(values):.2e}, mean={np.mean(values):.2e} ({len(values)} cases)")
        else:
            print(f"    {label:<28} {missing}")

    print("\n  Correctness metrics (max differences):")
    report("BP(orig) vs BP(reduced):", finite("marginal_max_diff"), "N/A (no converged cases)")
    report("CRN(orig) vs BP(orig):", finite("crn_vs_bp_orig_max_diff"))
    report("CRN(reduced) vs BP(reduced):", finite("crn_vs_bp_reduced_max_diff"))
    report("CRN(orig) vs CRN(reduced):", finite("crn_orig_vs_reduced_max_diff"))
    # Route B (induced-reduced CRN) metrics.
    report("CRN(induced) vs BP(orig):", finite("induced_crn_vs_bp_orig_max_diff"))
    report("CRN(induced) vs BP(reduced):", finite("induced_crn_vs_bp_reduced_max_diff"))
    report("CRN(reduced) vs CRN(induced):", finite("crn_reduced_vs_induced_max_diff"))


def save_results_to_csv(results: List["BenchmarkResult"], filename: str):
    """
    Save benchmark results to CSV for plotting.

    Writes one header row and one row per result. Headers and values come
    from a single column table, so they cannot drift out of sync.

    Cell conventions:
      * unavailable (non-finite) correctness diffs and a NaN
        ``induced_reduce_time`` (Route B failed) are written as empty cells;
      * infinite speedups are written as -1;
      * BP convergence flags are written as 1/0.

    Args:
        results: Benchmark results (any objects exposing the same attributes).
        filename: Output path; overwritten if it already exists.
    """
    import csv

    def blank_if_nonfinite(x):
        return x if np.isfinite(x) else ''

    def inf_as_minus_one(x):
        return -1 if np.isinf(x) else x

    def as_flag(b):
        return 1 if b else 0

    def raw(x):
        return x

    # (CSV header, attribute on the result, cell converter), in output order.
    columns = [
        ('name', 'name', raw),
        ('orig_vars', 'orig_n_vars', raw),
        ('reduced_vars', 'reduced_n_vars', raw),
        ('orig_factors', 'orig_n_factors', raw),
        ('reduced_factors', 'reduced_n_factors', raw),
        ('orig_edges', 'orig_n_edges', raw),
        ('reduced_edges', 'reduced_n_edges', raw),
        ('orig_species', 'orig_n_species', raw),
        ('reduced_species', 'reduced_n_species', raw),
        ('orig_reactions', 'orig_n_reactions', raw),
        ('reduced_reactions', 'reduced_n_reactions', raw),
        ('n_reduction_steps', 'n_reduction_steps', raw),
        ('reduction_time', 'reduction_time', raw),
        ('orig_compile_time', 'orig_compile_time', raw),
        ('reduced_compile_time', 'reduced_compile_time', raw),
        ('orig_sim_time', 'orig_sim_time', raw),
        ('reduced_sim_time', 'reduced_sim_time', raw),
        ('orig_bp_time', 'orig_bp_time', raw),
        ('reduced_bp_time', 'reduced_bp_time', raw),
        ('marginal_max_diff', 'marginal_max_diff', blank_if_nonfinite),
        ('var_reduction_ratio', 'var_reduction_ratio', raw),
        ('species_reduction_ratio', 'species_reduction_ratio', raw),
        ('compile_speedup', 'compile_speedup', inf_as_minus_one),
        ('sim_speedup', 'sim_speedup', inf_as_minus_one),
        # CRN-level correctness
        ('bp_converged_orig', 'bp_converged_orig', as_flag),
        ('bp_converged_reduced', 'bp_converged_reduced', as_flag),
        ('crn_vs_bp_orig_max_diff', 'crn_vs_bp_orig_max_diff', blank_if_nonfinite),
        ('crn_vs_bp_reduced_max_diff', 'crn_vs_bp_reduced_max_diff', blank_if_nonfinite),
        ('crn_orig_vs_reduced_max_diff', 'crn_orig_vs_reduced_max_diff', blank_if_nonfinite),
        # Route B
        ('induced_species', 'induced_n_species', raw),
        ('induced_reactions', 'induced_n_reactions', raw),
        # NaN when Route B failed: blank like every other unavailable metric.
        ('induced_reduce_time', 'induced_reduce_time', blank_if_nonfinite),
        ('induced_sim_time', 'induced_sim_time', raw),
        ('induced_crn_vs_bp_orig_max_diff', 'induced_crn_vs_bp_orig_max_diff', blank_if_nonfinite),
        ('induced_crn_vs_bp_reduced_max_diff', 'induced_crn_vs_bp_reduced_max_diff', blank_if_nonfinite),
        ('crn_reduced_vs_induced_max_diff', 'crn_reduced_vs_induced_max_diff', blank_if_nonfinite),
        # Comparability bookkeeping
        ('n_survivor_vars', 'n_survivor_vars', raw),
        ('n_bp_compared', 'n_bp_compared', raw),
        ('n_crn_orig_compared', 'n_crn_orig_compared', raw),
        ('n_crn_reduced_compared', 'n_crn_reduced_compared', raw),
        ('n_crn_induced_compared', 'n_crn_induced_compared', raw),
        ('n_crn_crn_compared', 'n_crn_crn_compared', raw),
        ('n_induced_vs_bp_reduced_compared', 'n_induced_vs_bp_reduced_compared', raw),
        ('n_reduced_vs_induced_compared', 'n_reduced_vs_induced_compared', raw),
    ]

    with open(filename, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow([header for header, _, _ in columns])
        for r in results:
            writer.writerow([convert(getattr(r, attr)) for _, attr, convert in columns])

    print(f"\nResults saved to {filename}")


if __name__ == "__main__":
    print("Running SP-B Reduction Benchmarks")
    print("=" * 60)

    results = run_benchmark_suite(verbose=True)
    print_summary_table(results)

    # Save to CSV under <project root>/results. The project root is the
    # parent of this file's directory (the same directory inserted into
    # sys.path at import time), instead of a machine-specific absolute path.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    results_dir = os.path.join(project_root, "results")
    os.makedirs(results_dir, exist_ok=True)
    save_results_to_csv(results, os.path.join(results_dir, "benchmark_results.csv"))
