"""
Factor Graph Core Module

This module provides the fundamental data structures for representing factor graphs:
- Variable: A random variable with a finite set of possible values
- Factor: A potential function over a subset of variables
- FactorGraph: The complete factor graph structure

Based on the formalism in Napp & Adams (2013) and Sergeant-Perthuis & Boitel.
"""

from typing import List, Dict, Tuple, Set, Optional, Union, Callable
from itertools import product
import numpy as np
from dataclasses import dataclass, field


@dataclass
class Variable:
    """
    Discrete random variable node of a factor graph.

    Identity is by name only: ``__eq__`` and ``__hash__`` look at ``name``
    and ignore ``values``, so two Variable objects carrying the same name
    are interchangeable as dictionary keys and set members.

    Attributes:
        name: Unique identifier for the variable
        values: The distinct values (events) the variable can take; the
            position of a value in this list is its index

    Example:
        x1 = Variable("x1", [0, 1])        # Binary variable
        x2 = Variable("x2", [0, 1, 2])     # Ternary variable
    """
    name: str
    values: List[int]

    def __post_init__(self):
        # An empty domain is rejected before looking for repeated entries.
        n_values = len(self.values)
        if n_values == 0:
            raise ValueError(f"Variable {self.name} must have at least one value")
        distinct = set(self.values)
        if n_values != len(distinct):
            raise ValueError(f"Variable {self.name} has duplicate values")

    def __repr__(self):
        return f"Variable({self.name}, {self.values})"

    def __eq__(self, other):
        # Anything that is not a Variable compares unequal.
        if not isinstance(other, Variable):
            return False
        return other.name == self.name

    def __hash__(self):
        # Kept consistent with __eq__: hash the name alone.
        return hash(self.name)

    @property
    def cardinality(self) -> int:
        """Size of the variable's domain (K_n in the papers)"""
        domain = self.values
        return len(domain)

    def index_to_value(self, index: int) -> int:
        """Return the value stored at position ``index`` of the values list"""
        domain = self.values
        return domain[index]

    def value_to_index(self, value: int) -> int:
        """
        Return the position of ``value`` in the values list.

        Raises:
            ValueError: if ``value`` is not one of the variable's values
        """
        domain = self.values
        return domain.index(value)


@dataclass
class Factor:
    """
    A factor (potential function) in a factor graph.

    A factor ψ_j(x^j) is a non-negative function over a subset of variables.
    The factor table stores ψ_j(k^j) for all configurations k^j in K^j.

    Attributes:
        name: Unique identifier for the factor
        variables: List of variables this factor depends on (neighbors in the graph)
        table: The factor values as a numpy array, indexed by variable configurations.
            If omitted, the factor is initialised to all ones. Any array-like is
            accepted; non-floating input (e.g. integer arrays or nested lists) is
            converted to a float array so that fractional potentials can be stored.

    The table is stored as a |K_1| x |K_2| x ... x |K_m| array where the variables
    are ordered as in self.variables.

    Equality and hashing are by ``name`` only.

    Raises:
        ValueError: if the table shape does not match the variable
            cardinalities, or if the table contains negative entries
    """
    name: str
    variables: List[Variable]
    table: Optional[np.ndarray] = field(default=None, repr=False)

    def __post_init__(self):
        expected_shape = tuple(v.cardinality for v in self.variables)

        if self.table is None:
            # Initialize with uniform factor (all ones)
            self.table = np.ones(expected_shape)
        else:
            # Accept any array-like. Floating arrays are kept as-is (no copy);
            # everything else is converted to float, otherwise set_value()
            # would silently truncate fractional values on integer tables.
            table = np.asarray(self.table)
            if not np.issubdtype(table.dtype, np.floating):
                table = table.astype(float)
            self.table = table

            # Validate table shape
            if self.table.shape != expected_shape:
                raise ValueError(
                    f"Factor {self.name} table shape {self.table.shape} "
                    f"doesn't match variable cardinalities {expected_shape}"
                )

        # Ensure non-negative values (for standard factor graphs)
        if np.any(self.table < 0):
            raise ValueError(f"Factor {self.name} has negative values")

    @property
    def scope(self) -> Set[Variable]:
        """Set of variables this factor depends on"""
        return set(self.variables)

    @property
    def arity(self) -> int:
        """Number of variables this factor depends on"""
        return len(self.variables)

    def _indices_for(self, assignment: Dict[Variable, int]) -> tuple:
        """Translate an assignment into a table index, one entry per variable."""
        return tuple(
            var.value_to_index(assignment[var])
            for var in self.variables
        )

    def get_value(self, assignment: Dict[Variable, int]) -> float:
        """
        Get the factor value for a given assignment of variables.

        Args:
            assignment: Dictionary mapping variables to their values. It may
                contain extra variables; only those in the factor's scope
                are read.

        Returns:
            The factor value ψ(x^j = assignment)

        Raises:
            KeyError: if a variable of the factor is missing from assignment
            ValueError: if an assigned value is not in the variable's values
        """
        return self.table[self._indices_for(assignment)]

    def set_value(self, assignment: Dict[Variable, int], value: float):
        """
        Set the factor value for a given assignment.

        Args:
            assignment: Dictionary mapping variables to their values
            value: New non-negative factor value

        Raises:
            ValueError: if value is negative, or an assigned value is not in
                the variable's values
            KeyError: if a variable of the factor is missing from assignment
        """
        if value < 0:
            raise ValueError("Factor values must be non-negative")
        self.table[self._indices_for(assignment)] = value

    def iter_configurations(self):
        """
        Iterate over all configurations of the factor's variables.

        Yields:
            Tuples of (assignment_dict, indices_tuple) for each configuration
        """
        ranges = [range(v.cardinality) for v in self.variables]
        for indices in product(*ranges):
            assignment = {
                var: var.index_to_value(idx)
                for var, idx in zip(self.variables, indices)
            }
            yield assignment, indices

    def marginalize(self, keep_vars: List[Variable]) -> 'Factor':
        """
        Marginalize (sum out) variables not in keep_vars.

        The axes of the returned table follow the order of keep_vars, which
        may differ from the order in self.variables.

        Args:
            keep_vars: Variables to keep after marginalization

        Returns:
            New factor over keep_vars, named "<name>_marg". Its table never
            shares memory with this factor's table.

        Raises:
            ValueError: if keep_vars contains a variable outside this
                factor's scope, or the same variable more than once
        """
        # Axis of each kept variable, in the order the caller asked for.
        keep_axes: List[int] = []
        for var in keep_vars:
            try:
                axis = self.variables.index(var)
            except ValueError:
                raise ValueError(
                    f"Variable {var.name} is not in the scope of "
                    f"factor {self.name}"
                ) from None
            if axis in keep_axes:
                raise ValueError(
                    f"Variable {var.name} appears more than once in keep_vars"
                )
            keep_axes.append(axis)

        sum_axes = [i for i in range(len(self.variables)) if i not in keep_axes]

        # Always reorder so the kept axes come first, in keep_vars order.
        # Doing this even when nothing is summed out keeps the table aligned
        # with keep_vars when the caller merely permutes the variables.
        reordered = np.transpose(self.table, keep_axes + sum_axes)

        if sum_axes:
            trailing = tuple(range(len(keep_axes), reordered.ndim))
            new_table = reordered.sum(axis=trailing)
        else:
            # np.transpose returns a view; copy so the result is independent.
            new_table = reordered.copy()

        return Factor(f"{self.name}_marg", list(keep_vars), new_table)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, Factor):
            return self.name == other.name
        return False

    def __repr__(self):
        var_names = [v.name for v in self.variables]
        return f"Factor({self.name}, vars={var_names})"


class FactorGraph:
    """
    A factor graph representation.

    A factor graph G = (V, F, E) consists of:
    - V: Variable nodes
    - F: Factor nodes
    - E: Edges connecting factors to variables they depend on

    This corresponds to a hypergraph H = (I, A) where:
    - I = variable indices
    - A = hyperedges (each factor's scope)

    The associated poset A(H) has elements I ∪ A with i ≤ a iff i ∈ a.

    Variables and factors are kept in insertion order, each in its own
    name-keyed dictionary. Names are unique within a kind, but a variable
    and a factor are allowed to share a name.
    """

    def __init__(self, name: str = "FactorGraph"):
        self.name = name
        self._variables: Dict[str, Variable] = {}
        self._factors: Dict[str, Factor] = {}

    def add_variable(self, variable: Variable) -> Variable:
        """
        Add a variable to the factor graph.

        Returns:
            The variable that was added

        Raises:
            ValueError: if a variable with the same name already exists
        """
        if variable.name in self._variables:
            raise ValueError(f"Variable {variable.name} already exists")
        self._variables[variable.name] = variable
        return variable

    def add_variables(self, *variables: Variable) -> List[Variable]:
        """Add multiple variables, in order; see add_variable"""
        return [self.add_variable(v) for v in variables]

    def add_factor(self, factor: Factor) -> Factor:
        """
        Add a factor to the factor graph.
        All variables in the factor's scope must already be in the graph.

        Returns:
            The factor that was added

        Raises:
            ValueError: if a factor with the same name already exists, if a
                scope variable is not in the graph, or if a scope variable
                has a different set of values than the graph's variable of
                the same name
        """
        if factor.name in self._factors:
            raise ValueError(f"Factor {factor.name} already exists")

        for var in factor.variables:
            if var.name not in self._variables:
                raise ValueError(
                    f"Variable {var.name} in factor {factor.name} "
                    "not found in graph"
                )
            # Variables are matched by name, so a factor built on a
            # same-named variable with another domain would be summed over
            # the wrong set of values during exact inference. Value order
            # may differ, because each object indexes by its own list.
            registered = self._variables[var.name]
            if set(registered.values) != set(var.values):
                raise ValueError(
                    f"Variable {var.name} in factor {factor.name} has values "
                    f"{var.values}, but the graph's variable has values "
                    f"{registered.values}"
                )

        self._factors[factor.name] = factor
        return factor

    def add_factors(self, *factors: Factor) -> List[Factor]:
        """Add multiple factors, in order; see add_factor"""
        return [self.add_factor(f) for f in factors]

    def get_variable(self, name: str) -> Variable:
        """Get a variable by name (KeyError if unknown)"""
        return self._variables[name]

    def get_factor(self, name: str) -> Factor:
        """Get a factor by name (KeyError if unknown)"""
        return self._factors[name]

    @property
    def variables(self) -> List[Variable]:
        """List of all variables, in insertion order"""
        return list(self._variables.values())

    @property
    def factors(self) -> List[Factor]:
        """List of all factors, in insertion order"""
        return list(self._factors.values())

    @property
    def num_variables(self) -> int:
        """Number of variable nodes"""
        return len(self._variables)

    @property
    def num_factors(self) -> int:
        """Number of factor nodes"""
        return len(self._factors)

    def neighbors_of_variable(self, var: Union[Variable, str]) -> List[Factor]:
        """Get all factors connected to a variable (N(i) in the papers)"""
        if isinstance(var, str):
            var = self._variables[var]
        return [f for f in self._factors.values() if var in f.variables]

    def neighbors_of_factor(self, factor: Union[Factor, str]) -> List[Variable]:
        """Get all variables connected to a factor (N(a) = ne(j) in the papers)"""
        if isinstance(factor, str):
            factor = self._factors[factor]
        return list(factor.variables)

    def get_edges(self) -> List[Tuple[str, str]]:
        """Get all edges as (factor_name, variable_name) pairs"""
        return [
            (factor.name, var.name)
            for factor in self._factors.values()
            for var in factor.variables
        ]

    def is_tree(self) -> bool:
        """
        Check if the factor graph is a tree (connected and acyclic).
        A factor graph is a tree if there's exactly one path between any two nodes.

        A graph without any factors is reported as a tree regardless of how
        many variables it holds.

        Returns:
            True if the bipartite variable/factor graph is a tree
        """
        if not self._factors:
            return True

        # Nodes are tagged ('var', name) / ('fac', name), as in to_poset(),
        # so a variable and a factor sharing a name remain distinct nodes.
        adjacency = {('var', name): set() for name in self._variables}
        for factor in self._factors.values():
            fac_node = ('fac', factor.name)
            adjacency[fac_node] = set()
            for var in factor.variables:
                var_node = ('var', var.name)
                adjacency[fac_node].add(var_node)
                adjacency[var_node].add(fac_node)

        # Traverse from an arbitrary node. Reaching an already visited node
        # through an edge other than the one we arrived by means a cycle.
        start = next(iter(adjacency))
        visited = {start}
        stack = [(start, None)]  # (node, parent)

        while stack:
            node, parent = stack.pop()
            for neighbor in adjacency[node]:
                if neighbor == parent:
                    continue
                if neighbor in visited:
                    return False  # Cycle detected
                visited.add(neighbor)
                stack.append((neighbor, node))

        # Acyclic so far; it is a tree only if every node was reached.
        return len(visited) == len(adjacency)

    def compute_joint_unnormalized(self, assignment: Dict[Variable, int]) -> float:
        """
        Compute the unnormalized joint probability for a full assignment.
        P(x) ∝ ∏_j ψ_j(x^j)

        Returns 1.0 for a graph without factors (empty product).
        """
        result = 1.0
        for factor in self._factors.values():
            result *= factor.get_value(assignment)
        return result

    def compute_partition_function(self) -> float:
        """
        Compute the partition function Z = ∑_x ∏_j ψ_j(x^j)
        Warning: This is exponential in the number of variables!
        """
        Z = 0.0
        for assignment, _ in self._iter_all_assignments():
            Z += self.compute_joint_unnormalized(assignment)
        return Z

    def compute_marginal_exact(self, var: Union[Variable, str]) -> np.ndarray:
        """
        Compute exact marginal distribution for a variable.
        Warning: This is exponential in the number of variables!

        Args:
            var: The variable, or its name, to compute the marginal of

        Returns:
            Array of probabilities for each value of the variable

        Raises:
            KeyError: if the variable is not part of the graph
            ValueError: if the unnormalized mass is zero or not finite, so
                the marginal cannot be normalized
        """
        if isinstance(var, str):
            var = self._variables[var]

        marginal = np.zeros(var.cardinality)

        for assignment, _ in self._iter_all_assignments():
            weight = self.compute_joint_unnormalized(assignment)
            marginal[var.value_to_index(assignment[var])] += weight

        # Normalize, refusing to divide by a zero / non-finite total, which
        # would otherwise yield an array of NaNs.
        total = marginal.sum()
        if not (np.isfinite(total) and total > 0):
            raise ValueError(
                f"Cannot normalize marginal of {var.name}: "
                f"total unnormalized mass is {total}"
            )
        marginal /= total
        return marginal

    def _iter_all_assignments(self):
        """
        Iterate over all possible assignments to all variables.

        Yields:
            Tuples of (assignment_dict, indices_tuple), one per joint
            configuration
        """
        vars_list = list(self._variables.values())
        ranges = [range(v.cardinality) for v in vars_list]

        for indices in product(*ranges):
            assignment = {
                var: var.index_to_value(idx)
                for var, idx in zip(vars_list, indices)
            }
            yield assignment, indices

    def to_poset(self) -> 'Poset':
        """
        Convert the factor graph to its associated poset A(H).

        The poset has elements I ∪ A with i ≤ a iff variable i is in factor a's scope.
        Elements are tagged tuples: ('var', name) for variables and
        ('fac', name) for factors.
        """
        # Imported here rather than at module level.
        from core.poset import Poset

        elements = []
        relations = []

        # Add variable nodes (minimal elements)
        for var in self._variables.values():
            elements.append(('var', var.name))

        # Add factor nodes (maximal elements)
        for factor in self._factors.values():
            elements.append(('fac', factor.name))
            # Add relations: var < factor for each var in factor's scope
            for var in factor.variables:
                relations.append((('var', var.name), ('fac', factor.name)))

        return Poset(elements, relations)

    def copy(self) -> 'FactorGraph':
        """
        Create a deep copy of the factor graph.

        The copy is named "<name>_copy" and owns fresh Variable and Factor
        objects (value lists and tables are copied).
        """
        new_graph = FactorGraph(self.name + "_copy")

        # Copy variables
        for var in self._variables.values():
            new_graph.add_variable(Variable(var.name, list(var.values)))

        # Copy factors with references to new variables
        for factor in self._factors.values():
            new_vars = [new_graph.get_variable(v.name) for v in factor.variables]
            new_graph.add_factor(Factor(factor.name, new_vars, factor.table.copy()))

        return new_graph

    def __repr__(self):
        return (f"FactorGraph({self.name}, "
                f"variables={[v.name for v in self.variables]}, "
                f"factors={[f.name for f in self.factors]})")


# Convenience function for building factor graphs
def build_factor_graph(
    variables_spec: Dict[str, List[int]],
    factors_spec: Dict[str, Tuple[List[str], Optional[np.ndarray]]],
    name: str = "FactorGraph"
) -> FactorGraph:
    """
    Assemble a FactorGraph from plain dictionaries.

    Variables are added first, in the iteration order of variables_spec;
    factors follow in the iteration order of factors_spec, each one wired
    to the graph's own Variable objects looked up by name.

    Args:
        variables_spec: Dict mapping variable names to their possible values
            Example: {"x1": [0, 1], "x2": [0, 1, 2]}
        factors_spec: Dict mapping factor names to (variable_names, table) tuples.
            The table must have one axis per listed variable, sized by that
            variable's number of values.
            Example (for the variables above):
                {"f1": (["x1", "x2"], np.array([[1, 0.1, 0.1], [0.1, 1, 0.1]]))}
            If table is None, initializes to all ones.
        name: Name for the factor graph

    Returns:
        Constructed FactorGraph

    Raises:
        KeyError: if a factor refers to a variable name missing from
            variables_spec
    """
    graph = FactorGraph(name)

    # Register every variable before any factor refers to it.
    for var_name in variables_spec:
        graph.add_variable(Variable(var_name, variables_spec[var_name]))

    # Resolve each factor's scope names to the graph's variables.
    for factor_name, spec in factors_spec.items():
        scope_names, table = spec
        scope = list(map(graph.get_variable, scope_names))
        graph.add_factor(Factor(factor_name, scope, table))

    return graph
