# crn/crn_reduction.py
"""
CRN Reduction Module

Applies SP-B retractions directly to Chemical Reaction Networks,
transforming ODE coefficients in place without recompiling from factor graphs.

This module implements the CRN-level correspondence to SP-B reductions:

1. LINEAR RETRACTION (delete leaf variable):
   - Delete species: S^{f→v}_k, P^{v→f}_k, P^v_k for the removed variable v
   - Delete reactions: All recycling, sum, product, marginal reactions involving v
   - NO rate constant changes to surviving reactions (this is key!)
   
2. COLINEAR RETRACTION (delete unary factor):
   - Delete species: S^{u→v}_k, P^{v→u}_k for the unary factor u
   - Delete reactions: All reactions involving the unary factor
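   - UPDATE rate constants of surviving reactions per Eq 4.29 or 4.30:
     * Eq 4.29: Multiply survivor factor's sum message rates by removed unary values
     * Eq 4.30: Update variable region (affects marginal computation)
     In this module the update is applied by accumulate-then-emit: while v still
     has other (non-canonical) factors, the removed table is stored per variable;
     once v is isolated, the accumulated product becomes the sum-message rates of
     a unary factor ``canonical_{v}`` (created, or rescaled if it already exists).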

The key insight is that SP-B retractions correspond to:
- Linear: Structural pruning of the CRN (no parameter changes)
- Colinear: Parameter updates to surviving reactions

Two reduction modes:
- AGGRESSIVE: Removes all reducible structure, may leave only marginal species
- STRUCTURAL: Stops when core is reached, preserving message-passing structure
"""

from typing import List, Dict, Set, Tuple, Optional, Any
from dataclasses import dataclass, field
from copy import deepcopy
import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from crn.crn_compiler import ChemicalReactionNetwork, Species, Reaction


@dataclass
class CRNReductionStep:
    """
    Immutable-by-convention record describing one CRN reduction step.

    ``step_type`` is 'linear' (a variable was deleted) or 'colinear' (a unary
    factor was absorbed). Only the ``removed_*`` field matching the step type
    is normally populated.
    """
    step_type: str  # 'linear' or 'colinear'
    removed_variable: Optional[str] = None  # set for linear steps
    removed_factor: Optional[str] = None  # set for colinear steps
    deleted_species: List[str] = field(default_factory=list)  # species names removed
    deleted_reactions: int = 0  # number of reactions removed
    updated_reactions: int = 0  # number of reaction rates created/rescaled
    details: Dict[str, Any] = field(default_factory=dict)  # step-specific extras

    def __repr__(self):
        species_count = len(self.deleted_species)
        # Anything that is not a linear step is rendered as a colinear one.
        if self.step_type != 'linear':
            return (
                f"CRNReductionStep(COLINEAR: remove factor '{self.removed_factor}', "
                f"deleted {species_count} species, "
                f"updated {self.updated_reactions} reaction rates)"
            )
        return (
            f"CRNReductionStep(LINEAR: remove var '{self.removed_variable}', "
            f"deleted {species_count} species, {self.deleted_reactions} reactions)"
        )


class CRNReducer:
    """
    Applies SP-B retractions directly to a Chemical Reaction Network.

    This transforms the CRN in place, modifying species, reactions, and rate
    constants to reflect the reduced factor graph structure.

    Variable/factor membership is derived from species names in the formats
    emitted by CRNCompiler:

    - Sum message:     ``S_{factor}_to_{var}_{k}``
    - Product message: ``P_{var}_to_{factor}_{k}``
    - Marginal:        ``Marginal_{var}_{k}``

    Names are always parsed exactly, so a variable ``x`` is never confused with
    ``x_1`` (nor a factor ``f`` with ``f_1``).

    Attributes:
        crn: The network being reduced (mutated in place).
        reduction_history: Every step applied by this reducer, in order.
        accumulated_potentials: Unary tables absorbed by colinear retractions
            whose variable still had other non-canonical factors, keyed by
            variable name; emitted into ``canonical_{var}`` once the variable
            becomes isolated.
        variables, factors, var_to_species, factor_to_species: Structural
            indices rebuilt by ``_parse_crn_structure`` and kept in sync by the
            retraction methods.
    """

    _CANONICAL_PREFIX = 'canonical_'

    def __init__(self, crn: ChemicalReactionNetwork):
        """
        Initialize the reducer with a CRN to reduce.

        Args:
            crn: The CRN to reduce (will be modified in place, so copy first if needed)
        """
        self.crn = crn
        self.reduction_history: List[CRNReductionStep] = []
        # Not touched by _parse_crn_structure, so it survives the re-parse
        # performed on every iteration of the reduction loops.
        self.accumulated_potentials: Dict[str, np.ndarray] = {}

        # Parse the CRN structure to understand variable/factor relationships
        self._parse_crn_structure()

    # ------------------------------------------------------------------
    # Name parsing and structural indices
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_species_name(name: str) -> Optional[Tuple[str, Optional[str], str, str]]:
        """
        Split a species name into ``(kind, factor, var, k)``.

        ``kind`` is 'sum', 'prod' or 'marginal'; ``factor`` is None for
        marginals and ``k`` is the (unconverted) trailing state suffix.
        Returns None for names that match none of the compiler formats.
        """
        if name.startswith('S_') and '_to_' in name:
            # S_{factor}_to_{var}_{k}
            parts = name[2:].split('_to_')
            if len(parts) == 2:
                var_parts = parts[1].rsplit('_', 1)
                if len(var_parts) == 2:
                    return 'sum', parts[0], var_parts[0], var_parts[1]
            return None
        if name.startswith('P_') and '_to_' in name:
            # P_{var}_to_{factor}_{k}
            parts = name[2:].split('_to_')
            if len(parts) == 2:
                factor_parts = parts[1].rsplit('_', 1)
                if len(factor_parts) == 2:
                    return 'prod', factor_parts[0], parts[0], factor_parts[1]
            return None
        if name.startswith('Marginal_'):
            # Marginal_{var}_{k}
            parts = name[len('Marginal_'):].rsplit('_', 1)
            if len(parts) == 2:
                return 'marginal', None, parts[0], parts[1]
        return None

    @classmethod
    def _sum_state_index(cls, name: str, factor: str, var: str) -> Optional[int]:
        """
        Return k if ``name`` is exactly ``S_{factor}_to_{var}_{k}`` with an
        integer k, otherwise None.
        """
        parsed = cls._parse_species_name(name)
        if parsed is None or parsed[0] != 'sum' or parsed[1] != factor or parsed[2] != var:
            return None
        try:
            return int(parsed[3])
        except ValueError:
            return None

    @classmethod
    def _canonical_name(cls, var: str) -> str:
        """Name of the canonical unary factor that carries ``var``'s absorbed potential."""
        return f"{cls._CANONICAL_PREFIX}{var}"

    def _parse_crn_structure(self):
        """
        Parse the CRN to extract variable and factor information.

        Rebuilds from scratch:
        - ``variables`` / ``factors``: every name found in a species name
        - ``var_to_species`` / ``factor_to_species``: which species belong to
          which variable / factor (marginals belong to their variable only)
        """
        self.variables: Set[str] = set()
        self.factors: Set[str] = set()
        self.var_to_species: Dict[str, Set[str]] = {}
        self.factor_to_species: Dict[str, Set[str]] = {}

        for name in self.crn.species:
            parsed = self._parse_species_name(name)
            if parsed is None:
                continue
            _kind, factor, var, _k = parsed
            self.variables.add(var)
            self.var_to_species.setdefault(var, set()).add(name)
            if factor is not None:
                self.factors.add(factor)
                self.factor_to_species.setdefault(factor, set()).add(name)

    def get_variable_factors(self, var: str) -> Set[str]:
        """Get all factors connected to a variable."""
        factors = set()
        for name in self.var_to_species.get(var, set()):
            parsed = self._parse_species_name(name)
            if parsed is not None and parsed[1] is not None:
                factors.add(parsed[1])
        return factors

    def get_factor_variables(self, factor: str) -> Set[str]:
        """Get all variables connected to a factor."""
        variables = set()
        for name in self.factor_to_species.get(factor, set()):
            parsed = self._parse_species_name(name)
            if parsed is not None and parsed[0] in ('sum', 'prod'):
                variables.add(parsed[2])
        return variables

    def is_linear(self, var: str) -> bool:
        """Check if a variable is linear (connected to exactly one factor)."""
        if var not in self.variables:
            return False
        return len(self.get_variable_factors(var)) == 1

    def is_colinear(self, factor: str) -> bool:
        """Check if a factor is colinear (connected to exactly one variable)."""
        if factor not in self.factors:
            return False
        return len(self.get_factor_variables(factor)) == 1

    def get_linear_variables(self) -> List[str]:
        """Get all linear variables."""
        return [v for v in self.variables if self.is_linear(v)]

    def get_colinear_factors(self) -> List[str]:
        """Get all colinear factors (unary factors)."""
        return [f for f in self.factors if self.is_colinear(f)]

    def _retractable_colinear_factors(self) -> List[str]:
        """Sorted colinear factors, excluding canonical factors (the reduced end state)."""
        return sorted(
            f for f in self.get_colinear_factors()
            if not f.startswith(self._CANONICAL_PREFIX)
        )

    def _retractable_linear_variables(self) -> List[str]:
        """Sorted linear variables, excluding those whose only factor is their own canonical."""
        return sorted(
            v for v in self.get_linear_variables()
            if self.get_variable_factors(v) != {self._canonical_name(v)}
        )

    # ------------------------------------------------------------------
    # Low-level CRN editing
    # ------------------------------------------------------------------

    def _species_matching(self, var: Optional[str] = None,
                          factor: Optional[str] = None) -> Set[str]:
        """
        Return CRN species whose parsed variable equals ``var`` and/or whose
        parsed factor equals ``factor`` (a None filter is not applied).
        """
        matched = set()
        for name in self.crn.species:
            parsed = self._parse_species_name(name)
            if parsed is None:
                continue
            _kind, sp_factor, sp_var, _k = parsed
            if var is not None and sp_var != var:
                continue
            if factor is not None and sp_factor != factor:
                continue
            matched.add(name)
        return matched

    def _reaction_indices_touching(self, species: Set[str]) -> List[int]:
        """Indices of reactions that consume or produce any species in ``species``."""
        return [
            i for i, rxn in enumerate(self.crn.reactions)
            if not species.isdisjoint(rxn.reactants) or not species.isdisjoint(rxn.products)
        ]

    def _delete_species_and_reactions(self, species: Set[str],
                                      reaction_indices: List[int]) -> None:
        """Remove the given species and the reactions at the given indices."""
        for sp_name in species:
            self.crn.species.pop(sp_name, None)
        # Highest index first so earlier indices stay valid.
        for i in sorted(reaction_indices, reverse=True):
            del self.crn.reactions[i]

    def _drop_mapping_keys(self, attr: str, position: int, element: str) -> None:
        """
        Delete entries of the CRN mapping ``self.crn.<attr>`` whose tuple key
        holds ``element`` at index ``position``. A missing mapping is skipped.
        """
        mapping = getattr(self.crn, attr, None)
        if not mapping:
            return
        for key in [k for k in mapping if k[position] == element]:
            del mapping[key]

    # ------------------------------------------------------------------
    # Retractions
    # ------------------------------------------------------------------

    def retract_linear(self, var: str) -> CRNReductionStep:
        """
        Apply STRICT SP-B linear retraction: remove a leaf variable from the CRN.

        CRITICAL: This is a DELETE-ONLY operation per SP-B Proposition 5:
        - Delete species and reactions associated with the variable
        - DO NOT modify any surviving reaction rate constants

        The functorial claim in "Reduction of Probabilistic CRNs" paper states:
        "induced map deletes exactly corresponding bundles and leaves recycling
        rates unchanged."

        This is NOT variable elimination (which would marginalize)!

        Args:
            var: Name of the linear variable to remove

        Returns:
            CRNReductionStep recording the deletion

        Raises:
            ValueError: If ``var`` is not linear.
        """
        if not self.is_linear(var):
            raise ValueError(f"Variable {var} is not linear")

        # A linear variable has exactly one neighbouring factor.
        (factor,) = self.get_variable_factors(var)

        # Sum messages into var, product messages out of var, and var's
        # marginals -- matched by exact name parsing.
        species_to_delete = self._species_matching(var=var)
        reactions_to_delete = self._reaction_indices_touching(species_to_delete)

        # Record the step BEFORE deletion.
        step = CRNReductionStep(
            step_type='linear',
            removed_variable=var,
            deleted_species=sorted(species_to_delete),
            deleted_reactions=len(reactions_to_delete),
            updated_reactions=0,  # DELETE-ONLY: no rate updates!
            details={
                'connected_factor': factor,
                'strict_spb': True,
                'note': 'Delete-only per SP-B Prop 5 - NO rate modifications'
            }
        )

        # === DELETE-ONLY OPERATION ===
        self._delete_species_and_reactions(species_to_delete, reactions_to_delete)

        # Update internal tracking
        self.variables.discard(var)
        self.var_to_species.pop(var, None)
        if factor in self.factor_to_species:
            self.factor_to_species[factor] -= species_to_delete

        # Update CRN's mapping dicts
        self._drop_mapping_keys('sum_msg_species', 1, var)    # (factor, var, k)
        self._drop_mapping_keys('prod_msg_species', 0, var)   # (var, factor, k)
        self._drop_mapping_keys('marginal_species', 0, var)   # (var, k)

        self.reduction_history.append(step)
        return step

    def retract_colinear(self, factor: str,
                         removed_table: Optional[np.ndarray] = None) -> CRNReductionStep:
        """
        Apply colinear retraction: remove a unary factor from the CRN.

        STRATEGY: Accumulate-then-emit

        - When deleting a unary factor u from variable x:
          - If x has other non-canonical factors remaining:
            - Store u's table in accumulated_potentials[x]
            - Delete u's species/reactions
          - If x only has this unary factor remaining (becomes isolated):
            - Create canonical_x with accumulated_potentials[x] * u.table
              (or rescale canonical_x's sum-message rates if it already exists)
            - Delete u's species/reactions

        The accumulated potential is stored in self.accumulated_potentials dict
        and is applied when the variable becomes isolated OR when converting
        back to factor graph form.

        Args:
            factor: Name of the colinear (unary) factor to remove
            removed_table: The unary factor's table values (array-like; needed
                          for rate updates). If None, extracts from existing
                          reaction rates.

        Returns:
            CRNReductionStep recording the changes. ``details['canonical_created']``
            is True only when a new canonical factor was built;
            ``details['canonical_updated']`` is True when an existing one was rescaled.

        Raises:
            ValueError: If ``factor`` is not colinear or the table is empty.
        """
        if not self.is_colinear(factor):
            raise ValueError(f"Factor {factor} is not colinear")

        # A colinear factor has exactly one neighbouring variable.
        (var,) = self.get_factor_variables(factor)

        if removed_table is None:
            removed_table = self._extract_unary_table(factor, var)
        # Accept lists/tuples too (e.g. tables serialized in FG step details).
        removed_table = np.asarray(removed_table, dtype=float)
        if removed_table.size == 0:
            raise ValueError(f"Unary table for factor {factor} is empty")

        # Other (non-canonical) factors decide accumulate vs. emit.
        other_factors = [
            f for f in self.get_variable_factors(var)
            if f != factor and not f.startswith(self._CANONICAL_PREFIX)
        ]

        species_to_delete = self._species_matching(factor=factor)
        reactions_to_delete = self._reaction_indices_touching(species_to_delete)

        updated_count = 0
        canonical_factor_name = self._canonical_name(var)
        canonical_created = False

        if other_factors:
            # Not isolated yet: just fold the table into the pending potential.
            prior = self.accumulated_potentials.get(var)
            self.accumulated_potentials[var] = (
                removed_table.copy() if prior is None else prior * removed_table
            )
        else:
            # Variable becomes isolated: emit everything absorbed so far.
            prior = self.accumulated_potentials.pop(var, None)
            effective_table = removed_table.copy() if prior is None else prior * removed_table

            # NOTE: if ``factor`` is the canonical factor itself, its rates are
            # rescaled here and then deleted below along with its species.
            if canonical_factor_name in self.factors:
                updated_count = self._scale_canonical_rates(
                    canonical_factor_name, var, effective_table)
            else:
                updated_count = self._emit_canonical_factor(
                    canonical_factor_name, var, effective_table)
                canonical_created = True

        step = CRNReductionStep(
            step_type='colinear',
            removed_factor=factor,
            deleted_species=sorted(species_to_delete),
            deleted_reactions=len(reactions_to_delete),
            updated_reactions=updated_count,
            details={
                'connected_variable': var,
                'removed_table': removed_table.tolist(),
                'canonical_created': canonical_created,
                'canonical_updated': not other_factors and not canonical_created,
            }
        )

        # Canonical reactions (if any) were appended at the end of the list,
        # so the earlier indices in reactions_to_delete are still valid.
        self._delete_species_and_reactions(species_to_delete, reactions_to_delete)

        # Update tracking
        self.factors.discard(factor)
        self.factor_to_species.pop(factor, None)
        if var in self.var_to_species:
            self.var_to_species[var] -= species_to_delete

        # Keep CRN mapping dicts consistent with the deleted species, as
        # retract_linear does. NOTE(review): canonical species are not added to
        # these maps because their value format is not visible here.
        self._drop_mapping_keys('sum_msg_species', 0, factor)   # (factor, var, k)
        self._drop_mapping_keys('prod_msg_species', 1, factor)  # (var, factor, k)

        self.reduction_history.append(step)
        return step

    def _scale_canonical_rates(self, canonical: str, var: str, table: np.ndarray) -> int:
        """
        Multiply the existing sum-message production rate of ``canonical`` for
        state k by ``table[k-1]`` (states outside the table are left alone).

        Returns the number of reactions rescaled.
        """
        marker = f'Sum msg {canonical}→{var}'
        updated = 0
        for rxn in self.crn.reactions:
            if marker not in rxn.description:
                continue
            for sp_name in rxn.products:
                k = self._sum_state_index(sp_name, canonical, var)
                if k is None:
                    continue
                if 1 <= k <= len(table):
                    rxn.rate_constant *= table[k - 1]
                    updated += 1
                break
        return updated

    def _emit_canonical_factor(self, canonical: str, var: str, table: np.ndarray) -> int:
        """
        Add species and reactions for a new unary factor ``canonical`` on
        ``var`` whose sum-message production rate for state k is ``table[k-1]``.

        Returns the number of sum-message (rate-carrying) reactions created.
        """
        n_states = len(table)
        sum_names = [f'S_{canonical}_to_{var}_{k}' for k in range(n_states + 1)]
        prod_names = [f'P_{var}_to_{canonical}_{k}' for k in range(n_states + 1)]

        # Index 0 is the unassigned pool (starts empty); assigned states start uniform.
        for k, name in enumerate(sum_names):
            self.crn.species[name] = Species(
                name=name,
                species_type='sum_msg',
                description=f'Canonical sum message to {var}',
                initial_concentration=0.0 if k == 0 else 1.0 / n_states
            )
        for k, name in enumerate(prod_names):
            self.crn.species[name] = Species(
                name=name,
                species_type='prod_msg',
                description=f'Canonical product message from {var}',
                initial_concentration=0.0 if k == 0 else 1.0 / n_states
            )

        reactions = self.crn.reactions

        # Recycling: assigned messages return to their unassigned pool.
        for k in range(1, n_states + 1):
            reactions.append(Reaction(
                reactants={sum_names[k]: 1},
                products={sum_names[0]: 1},
                rate_constant=self.crn.kappa_r,
                description=f'Recycling S_{canonical}→{var}[{k}]'
            ))
        for k in range(1, n_states + 1):
            reactions.append(Reaction(
                reactants={prod_names[k]: 1},
                products={prod_names[0]: 1},
                rate_constant=self.crn.kappa_r,
                description=f'Recycling P_{var}→{canonical}[{k}]'
            ))

        # Sum-message production carries the absorbed unary table as rates.
        for k in range(1, n_states + 1):
            reactions.append(Reaction(
                reactants={sum_names[0]: 1},
                products={sum_names[k]: 1},
                rate_constant=table[k - 1],
                description=f'Sum msg {canonical}→{var}[{k}], config=()'
            ))

        # Product message production
        for k in range(1, n_states + 1):
            reactions.append(Reaction(
                reactants={prod_names[0]: 1},
                products={prod_names[k]: 1},
                rate_constant=self.crn.kappa_prod,
                description=f'Prod msg {var}→{canonical}[{k}]'
            ))

        # Marginal reactions, catalysed by the sum message; only added when the
        # variable's marginal species exist.
        marg_unassigned = f'Marginal_{var}_0'
        for k in range(1, n_states + 1):
            marg_assigned = f'Marginal_{var}_{k}'
            if marg_unassigned in self.crn.species and marg_assigned in self.crn.species:
                reactions.append(Reaction(
                    reactants={marg_unassigned: 1, sum_names[k]: 1},
                    products={marg_assigned: 1, sum_names[k]: 1},
                    rate_constant=self.crn.kappa_prod,
                    description=f'Marginal {var}[{k}] from {canonical}'
                ))

        # Track the canonical factor in the structural indices.
        new_species = set(sum_names) | set(prod_names)
        self.factors.add(canonical)
        self.factor_to_species[canonical] = set(new_species)
        if var in self.var_to_species:
            self.var_to_species[var] |= new_species

        return n_states

    def _extract_unary_table(self, factor: str, var: str) -> np.ndarray:
        """
        Extract a unary factor's table values from the CRN reaction rates.

        For a unary factor, the sum message reactions have rates = ψ(k).
        States with no surviving production reaction default to 1.0. The table
        length is the larger of the highest state seen in those reactions and
        the highest state among the factor's ``S_{factor}_to_{var}_{k}``
        species. If neither is found, a binary uniform table is returned.
        """
        marker = f'Sum msg {factor}→{var}'
        table_values: Dict[int, float] = {}
        for rxn in self.crn.reactions:
            if marker not in rxn.description:
                continue
            for sp_name in rxn.products:
                k = self._sum_state_index(sp_name, factor, var)
                if k is not None:
                    table_values[k] = rxn.rate_constant
                    break

        # State count from the species themselves, so a factor whose production
        # reactions were pruned earlier still gets a table of the right length.
        species_states = [self._sum_state_index(name, factor, var) for name in self.crn.species]
        n_states = max([k for k in species_states if k is not None] + [0])

        max_k = max([n_states] + list(table_values))
        if max_k < 1:
            return np.ones(2)  # Default binary uniform

        return np.array([table_values.get(k, 1.0) for k in range(1, max_k + 1)])

    # ------------------------------------------------------------------
    # Reduction drivers
    # ------------------------------------------------------------------

    def reduce_to_core(self, mode: str = 'aggressive') -> List[CRNReductionStep]:
        """
        Reduce the CRN to its core using strict SP-B retractions.

        Args:
            mode: 'aggressive' - removes all reducible structure (may leave only marginals)
                  'structural' - delegates to ``reduce_structural``
                  (case-insensitive)

        Ordering: colinear factors first, then linear variables.
        This matches FG reduction ordering.

        IMPORTANT: A variable connected only to its canonical factor is NOT
        considered linear for deletion purposes - that's the final reduced state.
        Canonical factors are likewise never retracted.

        Returns:
            List of all reduction steps taken

        Raises:
            ValueError: If ``mode`` is not a recognised mode.
        """
        normalized = mode.lower() if isinstance(mode, str) else mode
        if normalized == 'structural':
            return self.reduce_structural()
        if normalized != 'aggressive':
            raise ValueError(
                f"Unknown reduction mode {mode!r}; expected 'aggressive' or 'structural'")

        all_steps = []
        while True:
            # Re-parse structure after changes
            self._parse_crn_structure()

            # First priority: colinear factors (non-canonical)
            colinear = self._retractable_colinear_factors()
            if colinear:
                all_steps.append(self.retract_colinear(colinear[0]))
                continue

            # Second priority: linear variables not held only by their canonical
            linear = self._retractable_linear_variables()
            if linear:
                all_steps.append(self.retract_linear(linear[0]))
                continue

            # No more reductions possible
            return all_steps

    def reduce_structural(self) -> List[CRNReductionStep]:
        """
        Reduce CRN while preserving message-passing structure for core variables.

        This performs the same reductions as FG→reduce→recompile would,
        stopping when only the core remains (no retractable linear variable
        and no retractable colinear factor).

        - Canonical factors and variables held only by their canonical factor
          are never retracted, so absorbed potentials are not lost even when
          several variables remain (e.g. disconnected components).
        - For single-variable case: creates effective unary via final colinear
          retraction and stops.

        Returns:
            List of reduction steps taken
        """
        all_steps = []

        while True:
            self._parse_crn_structure()

            # First priority: colinear factors (non-canonical)
            colinear = self._retractable_colinear_factors()
            if colinear:
                factor = colinear[0]
                (var,) = self.get_factor_variables(factor)
                # Single variable with a single factor: this retraction
                # creates the effective unary and finishes the reduction.
                is_final = (len(self.variables) == 1
                            and len(self.get_variable_factors(var)) == 1)
                all_steps.append(self.retract_colinear(factor))
                if is_final:
                    break
                continue

            # Second priority: linear variables
            linear = self._retractable_linear_variables()
            if linear:
                all_steps.append(self.retract_linear(linear[0]))
                continue

            # No reductions possible: we're at the core
            break

        return all_steps


def reduce_crn_to_core(crn: ChemicalReactionNetwork, 
                       copy: bool = True,
                       mode: str = 'aggressive') -> Tuple[ChemicalReactionNetwork, List[CRNReductionStep]]:
    """
    Convenience function to reduce a CRN to its core.
    
    Args:
        crn: The CRN to reduce
        copy: If True, work on a deep copy; if False, modify in place
        mode: 'aggressive' - removes all reducible structure
              'structural' - preserves message-passing for core
              (case-insensitive, so 'AGGRESSIVE' / 'STRUCTURAL' also work)
        
    Returns:
        Tuple of (reduced CRN, list of reduction steps)

    Raises:
        ValueError: If ``mode`` is not a recognised mode. Checked before any
            copying or mutation, so ``crn`` is untouched in that case.
    """
    normalized_mode = mode.lower() if isinstance(mode, str) else mode
    if normalized_mode not in ('aggressive', 'structural'):
        # Previously an unknown/typoed mode silently fell back to aggressive.
        raise ValueError(
            f"Unknown reduction mode {mode!r}; expected 'aggressive' or 'structural'")

    if copy:
        crn = deepcopy(crn)
    
    reducer = CRNReducer(crn)
    
    if normalized_mode == 'structural':
        steps = reducer.reduce_structural()
    else:
        steps = reducer.reduce_to_core()
    
    return crn, steps


def reduce_crn_guided(crn: ChemicalReactionNetwork,
                      fg_steps: List,  # ReductionSteps from FG reduction
                      copy: bool = True) -> Tuple[ChemicalReactionNetwork, List[CRNReductionStep]]:
    """
    Reduce CRN guided by FG reduction steps.
    
    This ensures CRN reduction uses the same tables as FG reduction,
    which is critical for commutation when factors have phantom variables
    that were deleted earlier.

    Each FG step is read as follows:
    - 'colinear': ``removed_element`` names the factor (an optional leading
      ``'fac:'`` tag is stripped), ``details['removed_table']`` supplies its
      table. Factors no longer present in the CRN are skipped.
    - 'linear': ``details['var_name']`` names the variable. Variables no
      longer present in the CRN are skipped.
    Other step types are ignored.
    
    Args:
        crn: The CRN to reduce
        fg_steps: List of ReductionStep from FG reduction
        copy: If True, work on a copy
        
    Returns:
        Tuple of (reduced CRN, list of CRN reduction steps)

    Raises:
        ValueError: Propagated from the reducer when a named factor/variable
            exists in the CRN but is not colinear/linear there.
    """
    fac_tag = 'fac:'

    if copy:
        crn = deepcopy(crn)
    
    reducer = CRNReducer(crn)
    crn_steps = []
    
    for fg_step in fg_steps:
        reducer._parse_crn_structure()
        
        if fg_step.step_type == 'colinear':
            # Absorb unary factor - use the table from FG step.
            # Strip only the leading tag (str.replace would strip every occurrence).
            factor_name = fg_step.removed_element
            if factor_name.startswith(fac_tag):
                factor_name = factor_name[len(fac_tag):]

            removed_table = fg_step.details.get('removed_table')
            if removed_table is not None:
                removed_table = np.array(removed_table)
            
            # Check if factor still exists in CRN
            if factor_name in reducer.factors:
                crn_steps.append(reducer.retract_colinear(factor_name, removed_table))
                
        elif fg_step.step_type == 'linear':
            # Delete leaf variable
            var_name = fg_step.details.get('var_name')
            
            if var_name and var_name in reducer.variables:
                crn_steps.append(reducer.retract_linear(var_name))
    
    return crn, crn_steps


def reduce_crn_structural(crn: ChemicalReactionNetwork,
                          copy: bool = True) -> Tuple[ChemicalReactionNetwork, List[CRNReductionStep]]:
    """
    Structural-mode shortcut for ``reduce_crn_to_core``.

    Intended to yield the same CRN as: FG → SP-B reduce → recompile, keeping
    the message-passing bundles of the surviving core.

    Args:
        crn: The CRN to reduce
        copy: If True, the reduction runs on a copy and ``crn`` is untouched

    Returns:
        Tuple of (reduced CRN, list of reduction steps)
    """
    reduced_crn, steps = reduce_crn_to_core(crn, copy=copy, mode='structural')
    return reduced_crn, steps


def compare_crn_structures(crn1: ChemicalReactionNetwork, 
                           crn2: ChemicalReactionNetwork) -> Dict[str, Any]:
    """
    Summarize structural differences between two CRNs (debugging aid).

    Returns:
        Dict with species/reaction counts for each CRN and ``species_diff``,
        the set of species names present in exactly one of the two networks.
    """
    names_first = set(crn1.species.keys())
    names_second = set(crn2.species.keys())

    summary: Dict[str, Any] = {}
    summary['crn1_species'] = len(crn1.species)
    summary['crn2_species'] = len(crn2.species)
    summary['crn1_reactions'] = len(crn1.reactions)
    summary['crn2_reactions'] = len(crn2.reactions)
    summary['species_diff'] = names_first.symmetric_difference(names_second)
    return summary