# reduction/poset_reduction.py
"""
Strict SP-B Poset Reduction for Factor Graphs

This module implements the SP-B (Sergeant-Perthuis & Boitel) reduction procedure
STRICTLY as specified in their paper "Minima and Critical Points of the Bethe 
Free Energy Are Invariant Under Deformation Retractions of Factor Graphs".

KEY DISTINCTION FROM VARIABLE ELIMINATION:

1. LINEAR DELETION (Prop 5): Delete-only operation
   - Removes a linear variable region (minimal element with exactly one upper cover)
   - **NO MODIFICATION** to survivor tables/scopes
   - The deleted belief is recoverable by marginalizing the cover belief
   - This is NOT the same as variable elimination!

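2. COLINEAR DELETION (Prop 6/7): Table-updating operation
   - Removes a colinear factor region (maximal element with exactly one lower cover)
   - Updates the surviving variable's table: in this factor-graph-restricted
     implementation the factor's table (with any already-deleted scope variables
     summed out, weighted by their recorded potentials) is multiplied into the
     variable region's table; the general Eq 4.29 / Eq 4.30 counting-number
     updates are NOT implemented
   - This is where "information transfer" happens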

The incidence poset of a factor graph has:
- Minimal elements: variable regions (with ψ_i ≡ 1 by default)
- Maximal elements: factor regions (including unary factors!)
- Relations: var < fac iff var ∈ scope(fac)

IMPORTANT: Unary factors are NOT absorbed into variable regions. They remain
as maximal (colinear) elements and are removed via the colinear rule.
"""

from dataclasses import dataclass, field
from typing import Dict, Set, List, Tuple, Optional, Any
import numpy as np
from itertools import product


@dataclass
class Region:
    """One element of the incidence poset: either a variable region or a factor region.

    Fields are listed in constructor (positional) order.
    """

    # Region identifier, e.g. "var:x" or "fac:f".
    name: str
    # Either 'variable' or 'factor'.
    region_type: str
    # Names of the variables this region ranges over, in table-axis order.
    scope: Tuple[str, ...]
    # Potential function ψ over `scope`.
    table: np.ndarray
    # Number of states of each scope variable (var_name -> cardinality);
    # a fresh empty dict per instance when omitted.
    arities: Dict[str, int] = field(default_factory=dict)


@dataclass
class ReductionStep:
    """A single retraction applied to the poset, as stored in its reduction history."""

    # 'linear' or 'colinear'.
    step_type: str
    # Name of the region that was removed.
    removed_element: str
    # Upper cover for a linear step; surviving lower neighbour for a colinear step.
    target_element: str
    # 'prop_5' for linear steps; for colinear steps, the rule label
    # (e.g. 'eq_4.29', 'eq_4.30', 'absorb').
    equation_used: str
    # Step-specific bookkeeping (snapshots, potentials, ...); fresh dict per instance.
    details: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self):
        # Non-linear steps are all rendered as colinear steps.
        if self.step_type != 'linear':
            return (f"ReductionStep(COLINEAR {self.equation_used}: "
                    f"remove '{self.removed_element}' → '{self.target_element}')")
        return (f"ReductionStep(LINEAR delete-only: remove '{self.removed_element}', "
                f"cover='{self.target_element}')")


class PosetModel:
    """
    Incidence poset of a factor graph.

    Elements:
    - Minimal elements are variable regions, named ``var:<variable>``.
    - Maximal elements are factor regions, named ``fac:<factor>``; unary factors
      are kept as factor regions of their own.
    - A variable region lies directly below a factor region iff the variable is
      in the factor's scope (recorded in ``cover_up`` / ``cover_down``).

    Attributes:
        regions: region name -> Region, for every region still in the poset.
        variables: names of the variable regions still in the poset.
        factors: names of the factor regions still in the poset.
        cover_up: variable region name -> factor regions directly above it.
        cover_down: factor region name -> variable regions directly below it.
        reduction_history: retraction steps applied so far, in order.
    """

    def __init__(self):
        self.regions: Dict[str, Region] = {}
        self.variables: Set[str] = set()
        self.factors: Set[str] = set()
        self.cover_up: Dict[str, Set[str]] = {}
        self.cover_down: Dict[str, Set[str]] = {}
        self.reduction_history: List[ReductionStep] = []

    def add_variable_region(self, var_name: str, cardinality: int):
        """Register variable ``var_name`` as the region ``var:<var_name>``.

        Its potential is the all-ones vector of length ``cardinality`` (ψ ≡ 1),
        and it starts with no factor regions above it.
        """
        name = f"var:{var_name}"
        self.regions[name] = Region(
            name=name,
            region_type='variable',
            scope=(var_name,),
            table=np.ones(cardinality),
            arities={var_name: cardinality},
        )
        self.variables.add(name)
        self.cover_up[name] = set()

    def add_factor_region(self, factor_name: str, scope: Tuple[str, ...],
                          table: np.ndarray, arities: Dict[str, int]):
        """Register factor ``factor_name`` (unary or not) as ``fac:<factor_name>``.

        ``table`` and ``arities`` are copied. The factor is linked above every
        scope variable that already has a variable region; scope variables
        without one stay in ``scope`` but get no cover relation.
        """
        name = f"fac:{factor_name}"
        self.regions[name] = Region(
            name=name,
            region_type='factor',
            scope=scope,
            table=table.copy(),
            arities=arities.copy(),
        )
        self.factors.add(name)
        below = self.cover_down[name] = set()
        for var_region in (f"var:{v}" for v in scope):
            if var_region not in self.variables:
                continue
            self.cover_up[var_region].add(name)
            below.add(var_region)

    def is_linear(self, var_region: str) -> bool:
        """Return True iff ``var_region`` is a variable region with exactly one factor above it."""
        return var_region in self.variables and len(self.cover_up[var_region]) == 1

    def is_colinear(self, fac_region: str) -> bool:
        """Return True iff ``fac_region`` is a factor region with exactly one variable below it."""
        return fac_region in self.factors and len(self.cover_down[fac_region]) == 1

    def get_linear_variables(self) -> List[str]:
        """Return every linear variable region (set-iteration order)."""
        return list(filter(self.is_linear, self.variables))

    def get_colinear_factors(self) -> List[str]:
        """Return every colinear factor region (set-iteration order).

        Besides unary factors, this includes factors whose other scope
        variables have already been removed from the poset.
        """
        return list(filter(self.is_colinear, self.factors))

    def compute_counting_number(self, region_name: str) -> int:
        """
        Return the counting number c(α) of a region.

        - factor region: c = 1 (factors are maximal)
        - variable region: c = 1 - (number of factor regions above it)
        - anything else: 0
        """
        if region_name in self.factors:
            return 1
        if region_name not in self.variables:
            return 0
        return 1 - len(self.cover_up[region_name])

    def __repr__(self):
        return "PosetModel(variables={}, factors={})".format(
            list(self.variables), list(self.factors))


def from_factor_graph(fg) -> PosetModel:
    """
    Build the incidence poset of a FactorGraph.

    ``fg`` is read through ``fg.variables`` (objects with ``.name`` and
    ``.cardinality``) and ``fg.factors`` (objects with ``.name``, ``.variables``
    and ``.table``).

    STRICT SP-B REQUIREMENT:
    - Every variable becomes a variable region with ψ ≡ 1 (uniform potential).
    - Every factor becomes a factor region, INCLUDING unary factors; unary
      factors are NOT folded into variable regions here.

    All variable regions are registered before any factor region, so each
    factor can link to all of its scope variables.
    """
    poset = PosetModel()

    for variable in fg.variables:
        poset.add_variable_region(variable.name, variable.cardinality)

    for factor in fg.factors:
        members = factor.variables
        poset.add_factor_region(
            factor.name,
            tuple(member.name for member in members),
            factor.table,
            {member.name: member.cardinality for member in members},
        )

    return poset


def retract_linear(poset: PosetModel, var_region: str) -> ReductionStep:
    """
    STRICT SP-B Linear Deletion (Proposition 5).

    This is a DELETE-ONLY operation:
    - Remove the linear variable region from the poset
    - DO NOT modify any survivor tables or scopes
    - The deleted belief is recoverable by marginalizing the cover

    The deleted variable's table at deletion time is stored in the step's
    details ('var_potential'). A later colinear retraction of the cover reads it
    to sum the variable out with that weight.

    This is NOT variable elimination!

    Args:
        poset: The poset model (modified in place)
        var_region: Name of the linear variable region to remove

    Returns:
        ReductionStep recording the deletion (also appended to
        ``poset.reduction_history``)

    Raises:
        ValueError: if ``var_region`` is not a variable region of ``poset``, or
            does not have exactly one upper cover.
    """
    # Check membership first: an unknown or factor-region name has no cover_up
    # entry, and indexing it while building the message would raise KeyError
    # instead of the documented ValueError.
    if var_region not in poset.variables:
        raise ValueError(f"{var_region} is not a variable region of this poset")
    if not poset.is_linear(var_region):
        raise ValueError(f"{var_region} is not linear (has {len(poset.cover_up[var_region])} covers)")

    # The unique upper cover (a factor region).
    cover = next(iter(poset.cover_up[var_region]))

    # Record what we're about to delete (for reconstruction).
    deleted_region = poset.regions[var_region]
    var_name = deleted_region.scope[0]

    # Snapshot the cover factor BEFORE deletion (for validation).
    cover_obj = poset.regions[cover]
    cover_table_before = cover_obj.table.copy()
    cover_scope_before = cover_obj.scope

    step = ReductionStep(
        step_type='linear',
        removed_element=var_region,
        target_element=cover,
        equation_used='prop_5',
        details={
            'var_name': var_name,
            'cover_factor': cover,
            'cover_table_before': cover_table_before,
            'cover_scope_before': cover_scope_before,
            'deleted_table': deleted_region.table.copy(),
            'var_potential': deleted_region.table.copy(),  # For weighted marginalization in colinear
            'reconstruction': 'marginalize cover belief over deleted variable'
        }
    )

    # === DELETE-ONLY OPERATION ===
    # Only the poset structure changes; the cover keeps its table AND its scope,
    # so the deleted variable stays in the cover's scope as a "phantom".
    poset.cover_down[cover].remove(var_region)
    del poset.cover_up[var_region]
    poset.variables.remove(var_region)
    del poset.regions[var_region]

    # Internal sanity checks (stripped under `python -O`): the cover is untouched.
    assert np.allclose(poset.regions[cover].table, cover_table_before), \
        "BUG: Linear deletion modified cover table!"
    assert poset.regions[cover].scope == cover_scope_before, \
        "BUG: Linear deletion modified cover scope!"

    poset.reduction_history.append(step)

    return step


def retract_colinear(poset: PosetModel, fac_region: str) -> ReductionStep:
    """
    Factor-Graph-Restricted Colinear Deletion.

    In the factor-graph restricted regime, colinear deletion is simple:
    - A colinear factor has exactly one SURVIVING variable below it. It is
      either a unary factor, or a factor whose other scope variables were
      already removed by linear deletion ("phantom" variables: still in the
      factor's scope, no longer in the poset).
    - Phantom variables are summed out of the factor table, weighted by the
      potential each one had when it was deleted (the linear step's
      'var_potential'). A phantom with no recorded potential is summed out
      uniformly.
    - The resulting 1-D table is ALWAYS multiplied into the surviving
      variable's table, regardless of how many other factors that variable has.

    We do NOT implement general SP-B region-poset "Eq 4.30 counting-number" semantics.

    Args:
        poset: The poset model (modified in place)
        fac_region: Name of the colinear factor region to remove

    Returns:
        ReductionStep recording the deletion (also appended to
        ``poset.reduction_history``)

    Raises:
        ValueError: if ``fac_region`` is not a factor region of ``poset``, is
            not colinear, or its marginalized table does not match the size of
            the surviving variable's table.
    """
    # Check membership first: an unknown or variable-region name has no
    # cover_down entry, and indexing it while building the message would raise
    # KeyError instead of the documented ValueError.
    if fac_region not in poset.factors:
        raise ValueError(f"{fac_region} is not a factor region of this poset")
    if not poset.is_colinear(fac_region):
        raise ValueError(f"{fac_region} is not colinear (has {len(poset.cover_down[fac_region])} lower covers)")

    # The unique lower cover (the surviving variable region).
    var_below = next(iter(poset.cover_down[fac_region]))
    var_region = poset.regions[var_below]
    fac_obj = poset.regions[fac_region]

    fac_table = fac_obj.table.copy()
    fac_scope = fac_obj.scope

    # Scope variables that no longer have a variable region in the poset.
    phantom_vars = [v for v in fac_scope if f"var:{v}" not in poset.variables]

    if phantom_vars:
        # Potential of each deleted variable at deletion time, built with one
        # pass over the history. The first linear step recorded for a variable
        # wins; a variable is only deleted once.
        deleted_potentials = {}
        for past in poset.reduction_history:
            if past.step_type == 'linear':
                deleted_potentials.setdefault(past.details.get('var_name'),
                                              past.details.get('var_potential'))

        for phantom_var in phantom_vars:
            axis = fac_scope.index(phantom_var)
            phantom_potential = deleted_potentials.get(phantom_var)

            if phantom_potential is not None:
                # Weighted marginalization: Σ_v ψ_v[v] * fac_table[..., v, ...]
                shape = [1] * fac_table.ndim
                shape[axis] = len(phantom_potential)
                weighted = fac_table * np.asarray(phantom_potential).reshape(shape)
                fac_table = np.sum(weighted, axis=axis)
            else:
                # No recorded potential: fall back to a uniform sum.
                fac_table = np.sum(fac_table, axis=axis)

            # Keep fac_scope aligned with the axes of the shrunken table.
            fac_scope = tuple(v for v in fac_scope if v != phantom_var)

    removed_table = fac_table.flatten()

    # Guard against numpy silently broadcasting a mismatched table (e.g. into a
    # length-1 variable table), which would corrupt the variable's potential.
    if removed_table.shape != var_region.table.shape:
        raise ValueError(
            f"{fac_region}: effective table has {removed_table.size} entries but "
            f"{var_below} has {var_region.table.size} states"
        )

    # FACTOR-GRAPH-RESTRICTED SEMANTICS: absorb into the variable potential.
    var_region.table = var_region.table * removed_table

    step = ReductionStep(
        step_type='colinear',
        removed_element=fac_region,
        target_element=var_below,
        equation_used='absorb',
        details={
            'removed_table': removed_table.copy(),
            'phantom_vars': phantom_vars
        }
    )

    # === REMOVE THE FACTOR REGION ===
    poset.cover_up[var_below].remove(fac_region)
    del poset.cover_down[fac_region]

    poset.factors.remove(fac_region)
    del poset.regions[fac_region]

    poset.reduction_history.append(step)

    return step


def reduce_to_core_spb(poset: PosetModel) -> List[ReductionStep]:
    """
    Reduce the poset to its core by repeated retractions.

    Each iteration performs exactly one retraction:
    1. If any factor region is colinear, retract the lexicographically smallest
       one with ``retract_colinear``, which absorbs its table into the variable
       below.
    2. Otherwise, if any variable region is linear, retract the
       lexicographically smallest one with ``retract_linear`` (delete-only).
    3. Otherwise stop: what remains is the core.

    Colinear retractions always take priority, so unary-factor information is
    absorbed into variable tables before those variables can be deleted.
    Always picking the smallest name makes the sequence deterministic.

    Args:
        poset: The poset model; modified in place.

    Returns:
        The reduction steps taken, in order.
    """
    taken = []

    while True:
        factor_name = min(poset.get_colinear_factors(), default=None)
        if factor_name is not None:
            taken.append(retract_colinear(poset, factor_name))
            continue

        variable_name = min(poset.get_linear_variables(), default=None)
        if variable_name is None:
            # Nothing left to retract: the poset is now its core.
            return taken
        taken.append(retract_linear(poset, variable_name))


def reconstruct_deleted_belief(step: ReductionStep, cover_belief: np.ndarray,
                               cover_scope: Tuple[str, ...],
                               arities: Dict[str, int]) -> np.ndarray:
    """
    Recover the belief of a variable removed by a linear deletion.

    SP-B guarantees for linear deletion that
        Q_b(x) = Σ_{y: y|_b = x} Q_{b↑}(y),
    so the deleted belief is the marginal of the cover's belief.

    Args:
        step: The linear ReductionStep; its details['var_name'] names the
            deleted variable.
        cover_belief: Current belief table of the cover factor, one axis per
            entry of ``cover_scope``.
        cover_scope: Variable names of the cover, in axis order.
        arities: Not used by this function; kept for interface compatibility.

    Returns:
        The marginal over the deleted variable, normalized to sum to 1 when its
        total is positive (otherwise returned unnormalized).

    Raises:
        ValueError: if ``step`` is not a linear step, or the deleted variable
            is not in ``cover_scope``.
    """
    if step.step_type != 'linear':
        raise ValueError("Reconstruction is for linear deletions only")

    target = step.details['var_name']
    if target not in cover_scope:
        raise ValueError(f"Variable {target} not in cover scope {cover_scope}")

    keep_axis = cover_scope.index(target)
    other_axes = tuple(ax for ax, _ in enumerate(cover_scope) if ax != keep_axis)

    marginal = np.sum(cover_belief, axis=other_axes) if other_axes else cover_belief.copy()

    total = np.sum(marginal)
    return marginal / total if total > 0 else marginal


def to_factor_graph_if_possible(poset: PosetModel):
    """
    Convert the poset back to a FactorGraph.

    NOTE: After strict SP-B reductions, the "intermediate objects" may not
    be valid factor graphs! Specifically:
    - Factor scopes may reference variables that have been deleted
    - This is expected and correct per SP-B theory

    This function creates a valid FactorGraph by:
    1. Including only surviving variable regions, in sorted order so the
       output is deterministic.
    2. Emitting each surviving variable region's table as a unary factor
       ``unary_<var>`` when it is not uniform. This is where tables absorbed by
       colinear retractions end up.
    3. For factors whose scope references deleted variables, summing those
       variables out WEIGHTED by the potential each had when it was deleted
       (the linear ReductionStep's 'var_potential'), as retract_colinear does.
       A deleted variable with no recorded potential is summed out uniformly.
       (This is for CRN compilation purposes - the SP-B structure is preserved
       in the poset representation.)
    Factor regions with no surviving variables are dropped; they only
    contribute a constant scale.

    Returns:
        FactorGraph built from the surviving regions. Errors raised by the
        ``core`` constructors propagate; this function never returns None.
    """
    from core import Variable, Factor, FactorGraph

    # Potential of each deleted variable at deletion time. The first linear
    # step recorded for a variable wins, matching retract_colinear's lookup.
    deleted_potentials = {}
    for past in poset.reduction_history:
        if past.step_type == 'linear':
            deleted_potentials.setdefault(past.details.get('var_name'),
                                          past.details.get('var_potential'))

    surviving_regions = sorted(poset.variables)

    # Surviving variable name -> cardinality.
    surviving_vars = {}
    for var_region in surviving_regions:
        region = poset.regions[var_region]
        var_name = region.scope[0]
        surviving_vars[var_name] = region.arities[var_name]

    fg = FactorGraph("from_poset")
    if not surviving_vars:
        # No variables left - return an empty FG.
        return fg

    var_objects = {}
    for var_name, cardinality in surviving_vars.items():
        var_objects[var_name] = fg.add_variable(Variable(var_name, list(range(cardinality))))

    # Variable-region tables become unary factors, skipped when uniform.
    for var_region in surviving_regions:
        region = poset.regions[var_region]
        var_name = region.scope[0]
        if not np.allclose(region.table, np.ones_like(region.table)):
            fg.add_factor(Factor(f"unary_{var_name}", [var_objects[var_name]], region.table))

    for fac_region in sorted(poset.factors):
        region = poset.regions[fac_region]

        surviving_scope = [v for v in region.scope if v in surviving_vars]
        if not surviving_scope:
            continue

        table = region.table
        # Sum deleted variables out from the highest axis down, so each
        # remaining axis keeps its original index in region.scope.
        for axis in reversed(range(len(region.scope))):
            del_var = region.scope[axis]
            if del_var in surviving_vars:
                continue
            potential = deleted_potentials.get(del_var)
            if potential is None:
                table = np.sum(table, axis=axis)
            else:
                shape = [1] * table.ndim
                shape[axis] = len(potential)
                table = np.sum(table * np.asarray(potential).reshape(shape), axis=axis)

        fac_vars = [var_objects[v] for v in surviving_scope]
        # Strip only the leading region prefix; str.replace would also mangle
        # factor names that contain "fac:" elsewhere.
        prefix = 'fac:'
        fac_name = fac_region[len(prefix):] if fac_region.startswith(prefix) else fac_region
        fg.add_factor(Factor(fac_name, fac_vars, table))

    return fg


# ============================================================================
# VALIDATION HELPERS
# ============================================================================

def validate_linear_deletion_is_delete_only(step: ReductionStep,
                                             poset: PosetModel) -> bool:
    """
    Check that a linear deletion left its cover factor untouched.

    Compares the cover region's current scope and table with the snapshot
    saved in ``step.details`` ('cover_scope_before', 'cover_table_before').

    Returns:
        True if ``step`` is not a linear step, if its cover region is no
        longer present in ``poset`` (nothing to compare against), or if both
        scope and table still match the snapshot. False if either differs,
        which would mean the deletion behaved like variable elimination.
    """
    if step.step_type != 'linear' or step.target_element not in poset.regions:
        return True

    current = poset.regions[step.target_element]
    snapshot_table = step.details['cover_table_before']
    snapshot_scope = step.details['cover_scope_before']

    # Scope is compared first; the table comparison only runs when scopes match.
    return current.scope == snapshot_scope and np.allclose(current.table, snapshot_table)