"""
Poset Module

Represents partially ordered sets (posets) associated with factor graphs.

For a factor graph with hypergraph H = (I, A):
- The poset A(H) has elements I ∪ A
- The partial order is: i ≤ a iff variable i is in factor a's scope

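This module supports:
- Basic poset operations
- Detection of linear (up-beat) and colinear (down-beat) points
- Deformation retractions for Bethe free energy optimization
- The Möbius function and counting numbers (the Bethe free energy coefficients)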
"""

from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple


@dataclass
class PosetElement:
    """
    A single node of a poset, identified by its kind and its name.

    For factor graphs:
    - element_type='var' marks a variable node
    - element_type='fac' marks a factor node

    An element equals any other PosetElement with the same
    (element_type, name) pair, and also equals that plain tuple. Its hash is
    the hash of the same tuple, so lookups may use either form.
    """
    element_type: str  # 'var' or 'fac'
    name: str

    def __hash__(self):
        identity = (self.element_type, self.name)
        return hash(identity)

    def __eq__(self, other):
        identity = (self.element_type, self.name)
        if isinstance(other, PosetElement):
            return identity == (other.element_type, other.name)
        if isinstance(other, tuple):
            return identity == other
        # Anything else is never equal.
        return False

    def __repr__(self):
        return "({}, {})".format(self.element_type, self.name)


class Poset:
    """
    A finite partially ordered set.

    The poset is stored as:
    - elements: set of elements
    - covers: dict mapping each element to its direct upper covers
    - covered_by: dict mapping each element to its direct lower covers

    The stored relations generate the order. Every order query
    (``is_less_than``, ``_get_all_upper``, ...) works on their transitive
    closure, so redundant (transitively implied) relations are tolerated.

    For factor graph posets, chains have length at most 1:
    - Variables are minimal elements
    - Factors are maximal elements
    - A factor covers a variable iff the variable is in the factor's scope
    """

    def __init__(self, elements: Optional[List[Tuple[str, str]]] = None,
                 relations: Optional[List[Tuple[Tuple[str, str], Tuple[str, str]]]] = None):
        """
        Initialize a poset.

        Args:
            elements: List of (type, name) tuples.
            relations: List of ((type1, name1), (type2, name2)) pairs where
                first < second. Endpoints that are not listed in ``elements``
                are added as elements too, just as ``add_relation`` does.
        """
        self._elements: Set[PosetElement] = set()

        # If a < b with nothing strictly in between, b covers a.
        # _covers[a] holds the upper covers of a; _covered_by[b] holds the
        # lower covers of b. For factor graphs (chains of length <= 1) every
        # relation is a cover relation.
        self._covers: Dict[PosetElement, Set[PosetElement]] = defaultdict(set)
        self._covered_by: Dict[PosetElement, Set[PosetElement]] = defaultdict(set)

        for etype, name in elements or ():
            self._elements.add(PosetElement(etype, name))

        for (t1, n1), (t2, n2) in relations or ():
            # Go through add_relation so that both endpoints are registered as
            # elements (otherwise they would be missing from len(),
            # minimal_elements(), ...).
            self.add_relation(PosetElement(t1, n1), PosetElement(t2, n2))

    @property
    def elements(self) -> Set[PosetElement]:
        """Return a copy of the element set; mutating it does not affect the poset."""
        return self._elements.copy()

    def add_element(self, element: PosetElement):
        """Add an element to the poset."""
        self._elements.add(element)

    def add_relation(self, lower: PosetElement, upper: PosetElement):
        """Add a cover relation lower < upper, registering both endpoints as elements."""
        self._elements.add(lower)
        self._elements.add(upper)
        # setdefault works whether the maps are plain dicts or defaultdicts.
        self._covers.setdefault(lower, set()).add(upper)
        self._covered_by.setdefault(upper, set()).add(lower)

    def upper_covers(self, element: PosetElement) -> Set[PosetElement]:
        """Get elements that directly cover this element (immediate successors)."""
        return set(self._covers.get(element, ()))

    def lower_covers(self, element: PosetElement) -> Set[PosetElement]:
        """Get elements directly covered by this element (immediate predecessors)."""
        return set(self._covered_by.get(element, ()))

    @staticmethod
    def _reachable(start: PosetElement,
                   edges: Dict[PosetElement, Set[PosetElement]]) -> Set[PosetElement]:
        """Return ``start`` plus every element reachable from it along ``edges`` (BFS)."""
        seen = {start}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for nxt in edges.get(current, ()):
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
        return seen

    def is_less_than(self, a: PosetElement, b: PosetElement) -> bool:
        """Check if a < b in the poset (transitive closure of the stored relations)."""
        if a == b:
            return False

        # BFS upward from a, stopping as soon as b is reached.
        visited = {a}
        queue = deque([a])
        while queue:
            current = queue.popleft()
            for upper in self._covers.get(current, ()):
                if upper == b:
                    return True
                if upper not in visited:
                    visited.add(upper)
                    queue.append(upper)
        return False

    def is_less_than_or_equal(self, a: PosetElement, b: PosetElement) -> bool:
        """Check if a <= b."""
        return a == b or self.is_less_than(a, b)

    def is_minimal(self, element: PosetElement) -> bool:
        """Check if element is minimal (has no elements below it)."""
        return not self._covered_by.get(element)

    def is_maximal(self, element: PosetElement) -> bool:
        """Check if element is maximal (has no elements above it)."""
        return not self._covers.get(element)

    def minimal_elements(self) -> Set[PosetElement]:
        """Get all minimal elements."""
        return {e for e in self._elements if self.is_minimal(e)}

    def maximal_elements(self) -> Set[PosetElement]:
        """Get all maximal elements."""
        return {e for e in self._elements if self.is_maximal(e)}

    def is_linear_point(self, element: PosetElement) -> Tuple[bool, Optional[PosetElement]]:
        """
        Check if element is a linear (up-beat) point.

        An element a is linear if there exists a_up > a such that every
        b > a satisfies b >= a_up. That is, the elements strictly above a
        have a minimum, a_up.

        For factor graph posets (variables below the factors containing
        them), a variable is linear exactly when it belongs to a single
        factor's scope, and a_up is that factor. Factors are maximal and
        therefore never linear.

        Returns:
            (is_linear, a_up) where a_up is the element if linear, else None
        """
        strictly_above = self._get_all_upper(element) - {element}
        # a_up must itself lie strictly above a, so only upper covers are
        # candidates. A maximal element has none and is never linear.
        for candidate in self.upper_covers(element):
            if strictly_above <= self._get_all_upper(candidate):
                return True, candidate
        return False, None

    def is_colinear_point(self, element: PosetElement) -> Tuple[bool, Optional[PosetElement]]:
        """
        Check if element is a colinear (down-beat) point.

        An element a is colinear if there exists a_down < a such that every
        b < a satisfies b <= a_down. That is, the elements strictly below a
        have a maximum, a_down.

        For factor graph posets (variables below the factors containing
        them), a factor is colinear exactly when its scope is a single
        variable, and a_down is that variable. Variables are minimal and
        therefore never colinear.

        Returns:
            (is_colinear, a_down) where a_down is the element if colinear, else None
        """
        strictly_below = self._get_all_lower(element) - {element}
        # A minimal element has no lower covers and is never colinear.
        for candidate in self.lower_covers(element):
            if strictly_below <= self._get_all_lower(candidate):
                return True, candidate
        return False, None

    def _get_all_upper(self, element: PosetElement) -> Set[PosetElement]:
        """Get all elements b such that element <= b (transitive closure upward)."""
        return self._reachable(element, self._covers)

    def _get_all_lower(self, element: PosetElement) -> Set[PosetElement]:
        """Get all elements b such that b <= element (transitive closure downward)."""
        return self._reachable(element, self._covered_by)

    def find_linear_points(self) -> List[Tuple[PosetElement, PosetElement]]:
        """Find all linear points and their a_up elements."""
        checks = ((e, self.is_linear_point(e)) for e in self._elements)
        return [(e, a_up) for e, (is_linear, a_up) in checks if is_linear]

    def find_colinear_points(self) -> List[Tuple[PosetElement, PosetElement]]:
        """Find all colinear points and their a_down elements."""
        checks = ((e, self.is_colinear_point(e)) for e in self._elements)
        return [(e, a_down) for e, (is_colinear, a_down) in checks if is_colinear]

    def _without(self, element: PosetElement) -> 'Poset':
        """
        Return the induced subposet on all elements except ``element``.

        Relations l < element < u are kept by linking l to u directly
        whenever the remaining relations do not already place u above l.
        """
        lower = set(self._covered_by.get(element, ()))
        upper = set(self._covers.get(element, ()))

        new_poset = Poset()
        new_poset._elements = {e for e in self._elements if e != element}
        for e, covers in self._covers.items():
            if e != element:
                new_poset._covers[e] = {c for c in covers if c != element}
        for e, covered in self._covered_by.items():
            if e != element:
                new_poset._covered_by[e] = {c for c in covered if c != element}

        # Without this, dropping `element` would also drop the order
        # relations that passed through it.
        for low in lower:
            already_above = new_poset._get_all_upper(low)
            for up in upper - already_above:
                new_poset.add_relation(low, up)

        return new_poset

    def retract_linear_point(self, element: PosetElement) -> 'Poset':
        """
        Retract a linear point from the poset.

        The retraction r_{a_up} sends a to a_up and is identity elsewhere.
        Returns a new poset B = A \\ {a} with the induced order.

        This is a deformation retraction (Proposition 3 in Sergeant-Perthuis & Boitel).

        Raises:
            ValueError: if ``element`` is not a linear point.
        """
        is_linear, _ = self.is_linear_point(element)
        if not is_linear:
            raise ValueError(f"{element} is not a linear point")
        return self._without(element)

    def retract_colinear_point(self, element: PosetElement) -> 'Poset':
        """
        Retract a colinear point from the poset.

        The retraction r_{a_down} sends a to a_down and is identity elsewhere.
        Returns a new poset B = A \\ {a} with the induced order.

        Raises:
            ValueError: if ``element`` is not a colinear point.
        """
        is_colinear, _ = self.is_colinear_point(element)
        if not is_colinear:
            raise ValueError(f"{element} is not a colinear point")
        return self._without(element)

    def compute_core(self) -> Tuple['Poset', List[Tuple[str, PosetElement, PosetElement]]]:
        """
        Compute the core of the poset.

        The core is obtained by repeatedly retracting one linear point (if
        any exists) or else one colinear point, until neither kind remains.

        Returns:
            (core_poset, retraction_sequence) where retraction_sequence is a list of
            ('linear', element, a_up) or ('colinear', element, a_down) tuples
        """
        current = self.copy()
        retractions = []

        while True:
            linear_points = current.find_linear_points()
            if linear_points:
                element, a_up = linear_points[0]
                current = current.retract_linear_point(element)
                retractions.append(('linear', element, a_up))
                continue

            colinear_points = current.find_colinear_points()
            if colinear_points:
                element, a_down = colinear_points[0]
                current = current.retract_colinear_point(element)
                retractions.append(('colinear', element, a_down))
                continue

            return current, retractions

    def _mobius_row(self, a: PosetElement,
                    region: Set[PosetElement]) -> Dict[PosetElement, int]:
        """
        Return {c: mu(a, c)} for every c in ``region``.

        ``region`` must contain ``a`` and consist only of elements >= a. For
        each of its members c, it must also contain the whole interval [a, c].
        """
        strictly_below = {c: (self._get_all_lower(c) & region) - {c} for c in region}
        mu: Dict[PosetElement, int] = {}
        # If d < c, strictly_below[d] is a proper subset of strictly_below[c].
        # Sorting by its size therefore evaluates every d before any c above it.
        for c in sorted(region, key=lambda x: len(strictly_below[x])):
            mu[c] = 1 if c == a else -sum(mu[d] for d in strictly_below[c])
        return mu

    def mobius_function(self, a: PosetElement, b: PosetElement) -> int:
        """
        Compute the Möbius function mu(a, b).

        mu(a, a) = 1
        mu(a, b) = -sum_{a <= c < b} mu(a, c) for a < b
        mu(a, b) = 0 if a is not <= b

        Each mu(a, c) on the interval [a, b] is computed once, instead of by
        repeated recursion.
        """
        if not self.is_less_than_or_equal(a, b):
            return 0
        if a == b:
            return 1
        interval = self._get_all_upper(a) & self._get_all_lower(b)
        return self._mobius_row(a, interval)[b]

    def counting_number(self, a: PosetElement) -> int:
        """
        Compute the counting number c(a) = sum_{b >= a} mu(a, b).

        This appears as coefficients in the Bethe free energy. With variables
        below the factors containing them, it gives c(f) = 1 for a factor f
        and c(i) = 1 - d_i for a variable i contained in d_i factors.
        """
        return sum(self._mobius_row(a, self._get_all_upper(a)).values())

    def copy(self) -> 'Poset':
        """Create an independent copy of this poset."""
        new_poset = Poset()
        new_poset._elements = set(self._elements)
        # Keep defaultdicts so that mutating the copy behaves like the original.
        new_poset._covers = defaultdict(set, {k: set(v) for k, v in self._covers.items()})
        new_poset._covered_by = defaultdict(set, {k: set(v) for k, v in self._covered_by.items()})
        return new_poset

    def chain_length(self) -> int:
        """
        Compute the maximum chain length (number of relations in longest chain).

        For factor graphs, this should be at most 1.
        """
        memo: Dict[PosetElement, int] = {}
        return max((self._longest_chain_from(e, memo) for e in self._elements), default=0)

    def _longest_chain_from(self, element: PosetElement,
                            _memo: Optional[Dict[PosetElement, int]] = None) -> int:
        """
        Find the length of the longest upward chain starting from element.

        ``_memo`` caches results so that shared upper sets are explored once
        rather than once per path.
        """
        if _memo is None:
            _memo = {}
        if element not in _memo:
            covers = self._covers.get(element, ())
            if covers:
                _memo[element] = 1 + max(self._longest_chain_from(c, _memo) for c in covers)
            else:
                _memo[element] = 0
        return _memo[element]

    def __len__(self):
        return len(self._elements)

    def __repr__(self):
        return f"Poset({len(self._elements)} elements, chain_length={self.chain_length()})"

    def __str__(self):
        # Sort everything so that the output does not depend on set order.
        lines = [f"Poset with {len(self._elements)} elements:"]
        lines.append(f"  Minimal: {sorted(str(e) for e in self.minimal_elements())}")
        lines.append(f"  Maximal: {sorted(str(e) for e in self.maximal_elements())}")
        lines.append("  Relations:")
        for e, covers in sorted(self._covers.items(), key=lambda item: str(item[0])):
            if covers:
                lines.append(f"    {e} < {sorted(str(c) for c in covers)}")
        return "\n".join(lines)
