"""
Factor Graph Reduction Module

Implements deformation retractions of factor graphs following
Sergeant-Perthuis & Boitel's theory.

Key Results from the Paper:
- Proposition 5: Retracting a linear point induces bijection on critical points
- Proposition 6-7: Retracting a colinear point with updated Hamiltonians preserves critical points
- Theorem 1: The core of a factor graph has isomorphic critical points of Bethe free energy

Since critical points of Bethe free energy ↔ fixed points of BP,
the reductions preserve BP marginals when done correctly.

Linear Points (Variables):
    A variable i is linear if there exists a factor a (called i↑) such that
    every factor containing i also contains all variables that i↑ contains.
    Rule: ψ'_c = ψ_c for all surviving c (just restrict/marginalize)
    
Colinear Points (Factors):
    A factor a is colinear if there exists another factor b (called a↓) such that
    scope(a) ⊋ scope(b), i.e. the scope of a strictly contains the scope of b.
    
    Two cases for the update rule:
    Case 1 (Eq 4.29): If a↓ is LINEAR in B = A\\{a}
        Let b↑ be the unique upper cover of a↓ in B
        ψ'_{b↑}(k) = ψ_{b↑}(k) · [Σ_{z: z|_{a↓}=x} ψ_a(z)] / ψ_{a↓}(x)
        where x = k|_{a↓}
        
    Case 2 (Eq 4.30): If a↓ is NOT LINEAR in B
        ψ'_{a↓}(x) = ψ_{a↓}(x)^{1-1/c_B(a↓)} · [Σ_{z: z|_{a↓}=x} ψ_a(z)]^{1/c_B(a↓)}
        where c_B(a↓) is the counting number in B
"""

from typing import List, Dict, Tuple, Set, Optional, Union
from dataclasses import dataclass, field
from itertools import product
import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph, build_factor_graph, Poset, PosetElement


@dataclass
class ReductionStep:
    """
    One entry of the reduction log: which element was retracted, and onto what.
    """
    step_type: str  # kind of retraction: 'linear' or 'colinear'
    removed_element: str  # name of the element taken out of the graph
    target_element: str  # name of the element it was retracted onto
    element_type: str  # 'variable' or 'factor'
    case: Optional[str] = None  # colinear steps only: 'eq_4.29' or 'eq_4.30'
    details: Optional[Dict] = None  # optional extra bookkeeping for the step

    def __repr__(self):
        # The case tag is shown only when one was recorded.
        suffix = ""
        if self.case:
            suffix = f" ({self.case})"
        return "ReductionStep({}{}, remove {} '{}' → '{}')".format(
            self.step_type,
            suffix,
            self.element_type,
            self.removed_element,
            self.target_element,
        )


@dataclass
class ReductionResult:
    """
    Outcome of a reduction run: the graphs before and after, plus the step log.
    """
    original_graph: FactorGraph  # graph the reducer started from
    reduced_graph: FactorGraph  # graph after the recorded steps
    reduction_steps: List[ReductionStep]  # steps in the order they were applied
    is_core: bool  # whether the reduced graph is reported as a core

    @property
    def num_reductions(self) -> int:
        """Number of recorded reduction steps."""
        return len(self.reduction_steps)

    def __repr__(self):
        before = self.original_graph
        after = self.reduced_graph
        label = "partially reduced" if not self.is_core else "core"
        return "ReductionResult({} steps, {}, {}→{} vars, {}→{} factors)".format(
            self.num_reductions,
            label,
            before.num_variables,
            after.num_variables,
            before.num_factors,
            after.num_factors,
        )


class FactorGraphReducer:
    """
    Reduces factor graphs via poset deformation retractions.
    
    The reduction preserves critical points of the Bethe free energy,
    which means BP marginals are preserved.
    """
    
    def __init__(self, factor_graph: FactorGraph):
        self.original_fg = factor_graph
        self.current_fg = factor_graph.copy()
        self.reduction_steps: List[ReductionStep] = []
    
    def _build_poset(self) -> Poset:
        """Build the poset A(H) for the current factor graph."""
        elements = []
        relations = []
        
        for var in self.current_fg.variables:
            elements.append(('var', var.name))
        
        for factor in self.current_fg.factors:
            elements.append(('fac', factor.name))
            for var in factor.variables:
                relations.append((('var', var.name), ('fac', factor.name)))
        
        return Poset(elements, relations)
    
    def find_linear_variables(self) -> List[Tuple[Variable, Factor]]:
        """
        Find all linear variables in the current factor graph.
        
        A variable i is linear if there exists a factor a↑ such that
        for all factors f containing i: scope(a↑) \\ {i} ⊆ scope(f) \\ {i}
        """
        linear_vars = []
        
        for var in self.current_fg.variables:
            factors_with_var = self.current_fg.neighbors_of_variable(var)
            
            if not factors_with_var:
                continue
            
            for candidate_factor in factors_with_var:
                candidate_others = set(v for v in candidate_factor.variables if v != var)
                
                is_linear = True
                for other_factor in factors_with_var:
                    if other_factor == candidate_factor:
                        continue
                    other_others = set(v for v in other_factor.variables if v != var)
                    
                    if not candidate_others.issubset(other_others):
                        is_linear = False
                        break
                
                if is_linear:
                    linear_vars.append((var, candidate_factor))
                    break
        
        return linear_vars
    
    def find_colinear_factors(self) -> List[Tuple[Factor, Factor]]:
        """
        Find all colinear factors in the current factor graph.
        
        A factor a is colinear if there exists factor a↓ such that
        scope(a) ⊃ scope(a↓) (strict superset).
        """
        colinear_factors = []
        factors = self.current_fg.factors
        
        for factor_a in factors:
            scope_a = set(factor_a.variables)
            
            for factor_b in factors:
                if factor_a == factor_b:
                    continue
                    
                scope_b = set(factor_b.variables)
                
                # a is colinear to b if scope(a) ⊃ scope(b)
                if scope_a.issuperset(scope_b) and scope_a != scope_b:
                    colinear_factors.append((factor_a, factor_b))
                    break
        
        return colinear_factors
    
    def _is_linear_in_reduced_poset(self, element_name: str, 
                                      excluded_factor: str) -> Tuple[bool, Optional[str]]:
        """
        Check if an element would be linear in the poset after removing a factor.
        
        For factor graphs: check if a variable has a unique factor containing it
        after we remove the excluded_factor.
        
        Returns:
            (is_linear, upper_cover_name) - the unique upper cover if linear
        """
        # Find the variable
        var = None
        for v in self.current_fg.variables:
            if v.name == element_name:
                var = v
                break
        
        if var is None:
            # It's a factor, check differently
            # For factors in our poset (maximal elements), they can't be linear
            # unless there's a factor above them, which doesn't happen
            return False, None
        
        # Get factors containing this variable, excluding the one being removed
        factors_with_var = [f for f in self.current_fg.neighbors_of_variable(var)
                          if f.name != excluded_factor]
        
        if len(factors_with_var) == 1:
            # Unique factor - this variable is linear
            return True, factors_with_var[0].name
        elif len(factors_with_var) == 0:
            # Variable has no factors - degenerate case
            return False, None
        else:
            # Multiple factors - check if one dominates
            for candidate in factors_with_var:
                candidate_others = set(v for v in candidate.variables if v != var)
                
                is_dominated = True
                for other in factors_with_var:
                    if other == candidate:
                        continue
                    other_others = set(v for v in other.variables if v != var)
                    if not candidate_others.issubset(other_others):
                        is_dominated = False
                        break
                
                if is_dominated:
                    return True, candidate.name
            
            return False, None
    
    def _compute_counting_number_in_reduced(self, element_name: str, 
                                             excluded_factor: str) -> int:
        """
        Compute the counting number c_B(element) in the reduced poset B = A \\ {excluded_factor}.
        
        For factor graphs with chains of length 1:
        - For variables: c(i) = 1 - |{factors containing i}|
        - For factors (maximal): c(a) = 1
        """
        # Find if element is a variable or factor
        var = None
        for v in self.current_fg.variables:
            if v.name == element_name:
                var = v
                break
        
        if var is not None:
            # It's a variable
            # Count factors containing it in the reduced graph
            num_factors = len([f for f in self.current_fg.neighbors_of_variable(var)
                              if f.name != excluded_factor])
            return 1 - num_factors
        else:
            # It's a factor - maximal element has c(a) = 1
            return 1
    
    def retract_linear_variable(self, var: Variable, target_factor: Factor) -> FactorGraph:
        """
        Retract a linear variable from the factor graph.
        
        Rule: ψ'_c = ψ_c for all surviving c
        We marginalize out the variable from factors containing it.
        """
        new_fg = FactorGraph(f"{self.current_fg.name}_reduced")
        
        # Add all variables except the one being removed
        var_map = {}
        for v in self.current_fg.variables:
            if v != var:
                new_var = Variable(v.name, v.values.copy())
                new_fg.add_variable(new_var)
                var_map[v.name] = new_var
        
        # Collect factors containing the variable and those without
        factors_with_var = [f for f in self.current_fg.factors if var in f.variables]
        factors_without_var = [f for f in self.current_fg.factors if var not in f.variables]
        
        # Add factors that don't contain the variable (unchanged)
        for factor in factors_without_var:
            new_vars = [var_map[v.name] for v in factor.variables]
            new_fg.add_factor(Factor(factor.name, new_vars, factor.table.copy()))
        
        # For factors containing the variable, combine and marginalize
        if factors_with_var:
            other_vars_set = set()
            for f in factors_with_var:
                for v in f.variables:
                    if v != var:
                        other_vars_set.add(v)
            other_vars = sorted(list(other_vars_set), key=lambda v: v.name)
            
            if other_vars:
                new_vars = [var_map[v.name] for v in other_vars]
                new_shape = tuple(v.cardinality for v in other_vars)
                new_table = np.zeros(new_shape)
                
                other_ranges = [range(v.cardinality) for v in other_vars]
                for other_indices in product(*other_ranges):
                    other_assignment = {
                        v: v.index_to_value(idx)
                        for v, idx in zip(other_vars, other_indices)
                    }
                    
                    total = 0.0
                    for var_idx in range(var.cardinality):
                        var_value = var.index_to_value(var_idx)
                        full_assignment = other_assignment.copy()
                        full_assignment[var] = var_value
                        
                        product_val = 1.0
                        for f in factors_with_var:
                            f_assignment = {v: full_assignment[v] for v in f.variables}
                            product_val *= f.get_value(f_assignment)
                        
                        total += product_val
                    
                    new_table[other_indices] = total
                
                combined_name = f"combined_{'_'.join(f.name for f in factors_with_var)}"
                new_fg.add_factor(Factor(combined_name, new_vars, new_table))
        
        self.reduction_steps.append(ReductionStep(
            step_type='linear',
            removed_element=var.name,
            target_element=target_factor.name,
            element_type='variable'
        ))
        
        self.current_fg = new_fg
        return new_fg
    
    def retract_colinear_factor(self, factor_a: Factor, factor_a_down: Factor) -> FactorGraph:
        """
        Retract a colinear factor from the factor graph.
        
        factor_a: The colinear factor being removed (larger scope)
        factor_a_down: The factor it retracts to (smaller scope, a↓)
        
        Two cases based on whether a↓ is linear in B = A \\ {a}:
        
        Case 1 (Eq 4.29): a↓ is linear in B
            Let b↑ be unique upper cover of a↓ in B
            ψ'_{b↑}(k) = ψ_{b↑}(k) · [Σ_z ψ_a(z)] / ψ_{a↓}(x)  where x = k|_{a↓}
            
        Case 2 (Eq 4.30): a↓ is NOT linear in B
            ψ'_{a↓}(x) = ψ_{a↓}(x)^{1-1/c} · [Σ_z ψ_a(z)]^{1/c}  where c = c_B(a↓)
        """
        new_fg = FactorGraph(f"{self.current_fg.name}_reduced")
        
        # Copy all variables
        var_map = {}
        for v in self.current_fg.variables:
            new_var = Variable(v.name, v.values.copy())
            new_fg.add_variable(new_var)
            var_map[v.name] = new_var
        
        # Check if a↓ is linear in B = A \\ {a}
        # For factor graphs, a↓ is a unary factor on some variable(s)
        # It's linear in B if that variable has exactly one other factor containing it
        
        # Actually, we need to check: is a↓ (as an element) linear in the reduced poset?
        # In factor graph terms: are the variables of a↓ each contained in exactly one
        # other factor (besides a)?
        
        # For unary factor a↓ on variable x: 
        # a↓ is linear in B if x is contained in exactly one factor in B besides a↓
        
        is_linear, b_up_name = self._check_colinear_case(factor_a, factor_a_down)
        
        # Precompute: for each x in E_{a↓}, compute Σ_{z: z|_{a↓}=x} ψ_a(z)
        sum_psi_a = self._compute_marginalized_sum(factor_a, factor_a_down)
        
        if is_linear and b_up_name is not None:
            # Case 1: Eq 4.29
            b_up = self.current_fg.get_factor(b_up_name)
            
            # ψ'_{b↑}(k) = ψ_{b↑}(k) · [Σ_z ψ_a(z)] / ψ_{a↓}(x)
            updated_b_up_table = self._apply_eq_4_29(
                b_up, factor_a_down, sum_psi_a
            )
            
            # Add all factors
            for f in self.current_fg.factors:
                if f == factor_a:
                    continue  # Skip removed factor
                elif f == b_up:
                    new_vars = [var_map[v.name] for v in f.variables]
                    new_fg.add_factor(Factor(f.name, new_vars, updated_b_up_table))
                else:
                    new_vars = [var_map[v.name] for v in f.variables]
                    new_fg.add_factor(Factor(f.name, new_vars, f.table.copy()))
            
            case = 'eq_4.29'
            details = {'b_up': b_up_name}
        else:
            # Case 2: Eq 4.30
            c_B = self._compute_counting_number_in_reduced(
                factor_a_down.name, factor_a.name
            )
            
            # Handle edge case where c_B = 0
            if c_B == 0:
                # Degenerate case - just use marginalization
                updated_a_down_table = sum_psi_a
            else:
                # ψ'_{a↓}(x) = ψ_{a↓}(x)^{1-1/c} · [Σ_z ψ_a(z)]^{1/c}
                updated_a_down_table = self._apply_eq_4_30(
                    factor_a_down, sum_psi_a, c_B
                )
            
            # Add all factors
            for f in self.current_fg.factors:
                if f == factor_a:
                    continue  # Skip removed factor
                elif f == factor_a_down:
                    new_vars = [var_map[v.name] for v in f.variables]
                    new_fg.add_factor(Factor(f.name, new_vars, updated_a_down_table))
                else:
                    new_vars = [var_map[v.name] for v in f.variables]
                    new_fg.add_factor(Factor(f.name, new_vars, f.table.copy()))
            
            case = 'eq_4.30'
            details = {'c_B': c_B}
        
        self.reduction_steps.append(ReductionStep(
            step_type='colinear',
            removed_element=factor_a.name,
            target_element=factor_a_down.name,
            element_type='factor',
            case=case,
            details=details
        ))
        
        self.current_fg = new_fg
        return new_fg
    
    def _check_colinear_case(self, factor_a: Factor, factor_a_down: Factor) -> Tuple[bool, Optional[str]]:
        """
        Determine which colinear reduction case applies.
        
        Case 1 (Eq 4.29): a↓ becomes linear after removing a
            - This happens when variables of a↓ are each in exactly one other factor
            - Returns (True, name of b↑)
            
        Case 2 (Eq 4.30): a↓ is not linear after removing a
            - Returns (False, None)
        """
        # Get variables in a↓
        a_down_vars = set(factor_a_down.variables)
        
        # For each variable in a↓, find factors containing it (excluding a and a↓)
        other_factors_per_var = {}
        for var in factor_a_down.variables:
            factors_with_var = [f for f in self.current_fg.neighbors_of_variable(var)
                               if f != factor_a and f != factor_a_down]
            other_factors_per_var[var.name] = factors_with_var
        
        # For a↓ to be linear in B, there must be a unique factor b↑ that contains
        # all variables of a↓ (and possibly more)
        
        # Find candidate b↑: factors that contain ALL variables of a↓
        candidate_b_ups = None
        for var in factor_a_down.variables:
            var_factors = set(f.name for f in other_factors_per_var[var.name])
            if candidate_b_ups is None:
                candidate_b_ups = var_factors
            else:
                candidate_b_ups = candidate_b_ups.intersection(var_factors)
        
        if candidate_b_ups is None or len(candidate_b_ups) == 0:
            # No factor contains all variables of a↓ → Case 2
            return False, None
        
        if len(candidate_b_ups) == 1:
            # Unique b↑ → Case 1
            return True, list(candidate_b_ups)[0]
        
        # Multiple candidates - need to find if one is minimal (dominates others)
        # For now, just pick one (the theory says there should be a unique one if linear)
        # Actually if there are multiple, a↓ is not linear
        return False, None
    
    def _compute_marginalized_sum(self, factor_a: Factor, factor_a_down: Factor) -> np.ndarray:
        """
        Compute Σ_{z ∈ E_a : z|_{a↓} = x} ψ_a(z) for each x ∈ E_{a↓}.
        
        This marginalizes factor_a over variables not in factor_a_down.
        """
        # Variables to marginalize (in a but not in a↓)
        a_down_var_set = set(factor_a_down.variables)
        marginalize_vars = [v for v in factor_a.variables if v not in a_down_var_set]
        
        result_shape = tuple(v.cardinality for v in factor_a_down.variables)
        result = np.zeros(result_shape)
        
        # For each x in E_{a↓}
        a_down_ranges = [range(v.cardinality) for v in factor_a_down.variables]
        
        for a_down_indices in product(*a_down_ranges):
            x_assignment = {
                v: v.index_to_value(idx)
                for v, idx in zip(factor_a_down.variables, a_down_indices)
            }
            
            # Sum over z where z|_{a↓} = x
            total = 0.0
            
            if marginalize_vars:
                marg_ranges = [range(v.cardinality) for v in marginalize_vars]
                for marg_indices in product(*marg_ranges):
                    z_assignment = x_assignment.copy()
                    for v, idx in zip(marginalize_vars, marg_indices):
                        z_assignment[v] = v.index_to_value(idx)
                    
                    total += factor_a.get_value(z_assignment)
            else:
                # No variables to marginalize
                total = factor_a.get_value(x_assignment)
            
            result[a_down_indices] = total
        
        return result
    
    def _apply_eq_4_29(self, b_up: Factor, a_down: Factor, 
                        sum_psi_a: np.ndarray) -> np.ndarray:
        """
        Apply Eq 4.29: ψ'_{b↑}(k) = ψ_{b↑}(k) · [Σ_z ψ_a(z)] / ψ_{a↓}(x)
        where x = k|_{a↓}
        """
        result = b_up.table.copy()
        
        # Get mapping from b_up variables to a_down variables
        a_down_var_names = [v.name for v in a_down.variables]
        
        # For each k in E_{b↑}
        b_up_ranges = [range(v.cardinality) for v in b_up.variables]
        
        for b_up_indices in product(*b_up_ranges):
            k_assignment = {
                v: v.index_to_value(idx)
                for v, idx in zip(b_up.variables, b_up_indices)
            }
            
            # Project k to x = k|_{a↓}
            a_down_indices = tuple(
                a_down.variables[i].value_to_index(k_assignment[a_down.variables[i]])
                for i in range(len(a_down.variables))
            )
            
            # Get values
            psi_b_up_k = b_up.table[b_up_indices]
            sum_psi_a_x = sum_psi_a[a_down_indices]
            psi_a_down_x = a_down.table[a_down_indices]
            
            # Avoid division by zero
            if psi_a_down_x > 1e-300:
                result[b_up_indices] = psi_b_up_k * sum_psi_a_x / psi_a_down_x
            else:
                result[b_up_indices] = psi_b_up_k * sum_psi_a_x
        
        return result
    
    def _apply_eq_4_30(self, a_down: Factor, sum_psi_a: np.ndarray, 
                        c_B: int) -> np.ndarray:
        """
        Apply Eq 4.30: ψ'_{a↓}(x) = ψ_{a↓}(x)^{1-1/c} · [Σ_z ψ_a(z)]^{1/c}
        """
        if c_B == 0:
            return sum_psi_a.copy()
        
        exponent_original = 1.0 - 1.0 / c_B
        exponent_sum = 1.0 / c_B
        
        # Element-wise computation
        result = np.power(a_down.table, exponent_original) * np.power(sum_psi_a, exponent_sum)
        
        return result
    
    def reduce_once(self) -> Optional[ReductionStep]:
        """Perform a single reduction step if possible."""
        # Try linear reduction first
        linear_vars = self.find_linear_variables()
        if linear_vars:
            var, target = linear_vars[0]
            self.retract_linear_variable(var, target)
            return self.reduction_steps[-1]
        
        # Try colinear reduction
        colinear_factors = self.find_colinear_factors()
        if colinear_factors:
            factor, target = colinear_factors[0]
            self.retract_colinear_factor(factor, target)
            return self.reduction_steps[-1]
        
        return None
    
    def reduce_to_core(self) -> ReductionResult:
        """Reduce the factor graph to its core."""
        while True:
            step = self.reduce_once()
            if step is None:
                break
        
        return ReductionResult(
            original_graph=self.original_fg,
            reduced_graph=self.current_fg,
            reduction_steps=self.reduction_steps,
            is_core=True
        )
    
    def reduce_n_steps(self, n: int) -> ReductionResult:
        """Perform at most n reduction steps."""
        for _ in range(n):
            step = self.reduce_once()
            if step is None:
                break
        
        is_core = (self.find_linear_variables() == [] and 
                   self.find_colinear_factors() == [])
        
        return ReductionResult(
            original_graph=self.original_fg,
            reduced_graph=self.current_fg,
            reduction_steps=self.reduction_steps,
            is_core=is_core
        )
    
    def get_current_graph(self) -> FactorGraph:
        """Get the current (possibly partially reduced) factor graph"""
        return self.current_fg
    
    def reset(self):
        """Reset to original factor graph"""
        self.current_fg = self.original_fg.copy()
        self.reduction_steps = []


def reduce_factor_graph(fg: FactorGraph, to_core: bool = True) -> ReductionResult:
    """
    Convenience function to reduce a factor graph.

    Args:
        fg: the factor graph to reduce; the reducer works on a copy.
        to_core: when True, reduce until no step applies; when False,
            perform at most one reduction step.

    Returns:
        The reduction result. On the single-step path ``is_core`` reflects
        whether any reduction is still possible afterwards, rather than
        always being False.
    """
    reducer = FactorGraphReducer(fg)
    if to_core:
        return reducer.reduce_to_core()
    # reduce_n_steps determines is_core from the graph it ends up with.
    return reducer.reduce_n_steps(1)
