"""
Belief Propagation Module

Implements the sum-product algorithm (belief propagation) for factor graphs.

Based on:
- Napp & Adams (2013): Equations 3-5 for message updates
- Sergeant-Perthuis & Boitel: Section 2.3-2.5 for the general BP algorithm

Message Types (following Napp & Adams notation):
- S^(j→n)_k : Sum message from factor j to variable n, component k
- P^(n→j)_k : Product message from variable n to factor j, component k

The algorithm computes marginal distributions via message passing:
- For trees: exact marginals
- For loopy graphs: approximate marginals (loopy BP)
"""

from typing import List, Dict, Tuple, Set, Optional, Union
from dataclasses import dataclass, field
from itertools import product
import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph


@dataclass
class Message:
    """
    One directed message exchanged during belief propagation.

    The payload is a vector holding one entry per value of ``variable``.

    Attributes:
        source: Name of the node that sends the message (factor or variable)
        target: Name of the node that receives the message (variable or factor)
        variable: Variable whose values index the entries of ``values``
        values: Message vector; its length must equal ``variable.cardinality``
        message_type: 'sum' (factor→variable) or 'product' (variable→factor)
    """
    source: str
    target: str
    variable: Variable
    values: np.ndarray
    message_type: str  # 'sum' or 'product'

    def __post_init__(self):
        """Reject a payload whose length differs from the variable's cardinality.

        Raises:
            ValueError: If ``len(values) != variable.cardinality``.
        """
        expected_length = self.variable.cardinality
        if len(self.values) == expected_length:
            return
        raise ValueError(
            f"Message length {len(self.values)} doesn't match "
            f"variable {self.variable.name} cardinality {self.variable.cardinality}"
        )

    def normalized(self) -> np.ndarray:
        """Return the message rescaled to sum to 1.

        When the entries do not add up to a positive number the stored
        vector itself is returned unchanged.
        """
        mass = np.sum(self.values)
        if not mass > 0:
            return self.values
        return self.values / mass

    def log_values(self) -> np.ndarray:
        """Return the element-wise log of the message.

        A tiny offset is added first so that zero entries do not produce -inf.
        """
        offset = 1e-300
        return np.log(self.values + offset)

    def copy(self) -> 'Message':
        """Return a new Message with the same metadata and a copied value vector."""
        return Message(
            source=self.source,
            target=self.target,
            variable=self.variable,
            values=self.values.copy(),
            message_type=self.message_type,
        )

    def __repr__(self):
        rounded = np.round(self.values, 4)
        return f"Message({self.source}→{self.target}, {self.message_type}, {rounded})"


@dataclass
class BPResult:
    """
    Outcome of a belief propagation run.

    Attributes:
        marginals: Marginal distribution of each variable, keyed by variable name
        messages: Messages at the end of the run, keyed by (source, target) names
        converged: True when the run stopped because it met the tolerance
        iterations: Number of iterations that were run
        max_diff: Largest message change observed in the final iteration
    """
    marginals: Dict[str, np.ndarray]
    messages: Dict[Tuple[str, str], Message]
    converged: bool
    iterations: int
    max_diff: float

    def get_marginal(self, var_name: str) -> np.ndarray:
        """Look up the marginal of the variable called ``var_name``.

        Raises:
            KeyError: If no marginal is stored under that name.
        """
        marginal = self.marginals[var_name]
        return marginal

    def __repr__(self):
        if self.converged:
            status = "converged"
        else:
            status = "not converged"
        return f"BPResult({status}, {self.iterations} iterations, max_diff={self.max_diff:.6f})"


class BeliefPropagation:
    """
    Sum-product belief propagation for factor graphs.

    Implements the message passing equations from Napp & Adams:

    Sum messages (factor j → variable n), Equation 3:
        S^(j→n)_k = Σ_{k^j: k^j_n = k} ψ_j(x^j = k^j) ∏_{n' ∈ ne(j)\\n} P^(n'→j)_{k^j_{n'}}

    Product messages (variable n → factor j), Equation 4:
        P^(n→j)_k = ∏_{j' ∈ ne(n)\\j} S^(j'→n)_k

    Marginals, Equation 5:
        P(x_n = k) ∝ ∏_{j ∈ ne(n)} S^(j→n)_k

    Messages live in ``self.messages``, keyed by ``(source name, target name)``.
    """

    def __init__(self, factor_graph: FactorGraph):
        """
        Initialize BP for a factor graph.

        Args:
            factor_graph: The factor graph to perform inference on
        """
        self.fg = factor_graph
        self.messages: Dict[Tuple[str, str], Message] = {}
        self._initialize_messages()

    def _initialize_messages(self, init_type: str = 'uniform'):
        """
        (Re)create one message per direction for every factor/variable edge.

        Args:
            init_type: 'uniform' fills every message with ones; any other
                value draws each entry uniformly from [0.1, 1.1).
        """
        self.messages = {}

        def initial_values(cardinality):
            if init_type == 'uniform':
                return np.ones(cardinality)
            return np.random.rand(cardinality) + 0.1

        for factor in self.fg.factors:
            for var in factor.variables:
                # Factor → Variable message (sum message)
                self.messages[(factor.name, var.name)] = Message(
                    source=factor.name,
                    target=var.name,
                    variable=var,
                    values=initial_values(var.cardinality),
                    message_type='sum'
                )

                # Variable → Factor message (product message)
                self.messages[(var.name, factor.name)] = Message(
                    source=var.name,
                    target=factor.name,
                    variable=var,
                    values=initial_values(var.cardinality),
                    message_type='product'
                )

    def _compute_sum_message(self, factor: Factor, target_var: Variable) -> np.ndarray:
        """
        Compute sum message from factor to variable.

        S^(j→n)_k = Σ_{k^j: k^j_n = k} ψ_j(x^j = k^j) ∏_{n' ∈ ne(j)\\n} P^(n'→j)_{k^j_{n'}}

        This is Equation 3 from Napp & Adams, and Equation 2.11 from Sergeant-Perthuis & Boitel.

        Args:
            factor: The factor sending the message
            target_var: The variable receiving the message

        Returns:
            Unnormalized message vector of length target_var.cardinality
        """
        result = np.zeros(target_var.cardinality)

        # Position of target_var in the factor's variable list
        target_idx = factor.variables.index(target_var)

        # Other variables in this factor (ne(j) \ n)
        other_vars = [v for v in factor.variables if v != target_var]

        if not other_vars:
            # Unary factor: the message is just the factor's own values
            for k in range(target_var.cardinality):
                indices = [0] * len(factor.variables)
                indices[target_idx] = k
                result[k] = factor.table[tuple(indices)]
            return result

        # Incoming product messages P^(n'→j), aligned with other_vars.
        # Nothing in this method modifies them, so they are fetched once.
        incoming = [self.messages[(v.name, factor.name)].values for v in other_vars]
        other_ranges = [range(v.cardinality) for v in other_vars]
        is_target = [v == target_var for v in factor.variables]

        for k in range(target_var.cardinality):
            total = 0.0
            # Sum over all joint configurations of the other variables
            for other_indices in product(*other_ranges):
                # Interleave k with the other indices in the factor's own variable order
                remaining = iter(other_indices)
                full_indices = tuple(
                    k if hit else next(remaining) for hit in is_target
                )

                # ∏_{n' ∈ ne(j)\n} P^(n'→j)_{k^j_{n'}}
                msg_product = 1.0
                for msg_values, idx in zip(incoming, other_indices):
                    msg_product *= msg_values[idx]

                # ψ_j(x^j = k^j) weighted by the incoming messages
                total += factor.table[full_indices] * msg_product
            result[k] = total

        return result

    def _compute_product_message(self, source_var: Variable, target_factor: Factor) -> np.ndarray:
        """
        Compute product message from variable to factor.

        P^(n→j)_k = ∏_{j' ∈ ne(n)\\j} S^(j'→n)_k

        This is Equation 4 from Napp & Adams.

        Args:
            source_var: The variable sending the message
            target_factor: The factor receiving the message

        Returns:
            Unnormalized message vector of length source_var.cardinality
            (all ones when the variable has no other neighboring factor)
        """
        result = np.ones(source_var.cardinality)

        # Multiply the sum messages from every neighboring factor except the target (ne(n) \ j)
        for factor in self.fg.neighbors_of_variable(source_var):
            if factor != target_factor:
                result *= self.messages[(factor.name, source_var.name)].values

        return result

    def _compute_belief(self, var: Variable) -> np.ndarray:
        """
        Compute the belief (marginal) for a variable.

        P(x_n = k) ∝ ∏_{j ∈ ne(n)} S^(j→n)_k

        This is Equation 5 from Napp & Adams.

        Args:
            var: The variable to compute the marginal for

        Returns:
            Marginal distribution, normalized to sum to 1 whenever the
            product of incoming messages has positive total mass
        """
        belief = np.ones(var.cardinality)

        # Multiply all incoming sum messages
        for factor in self.fg.neighbors_of_variable(var):
            belief *= self.messages[(factor.name, var.name)].values

        # Normalize
        total = np.sum(belief)
        if total > 0:
            belief /= total

        return belief

    def _damp(self, key: Tuple[str, str], new_values: np.ndarray, damping: float) -> np.ndarray:
        """Blend a freshly computed message with the stored one: (1-damping)*new + damping*old."""
        if damping > 0:
            return (1 - damping) * new_values + damping * self.messages[key].values
        return new_values

    @staticmethod
    def _normalize_message(values: np.ndarray) -> np.ndarray:
        """Scale a message to sum to 1, falling back to uniform if its sum is not finite and positive."""
        total = float(np.sum(values))
        if np.isfinite(total) and total > 0:
            return values / total
        return np.ones_like(values, dtype=float) / len(values)

    def _commit_normalized(self, key: Tuple[str, str], new_values: np.ndarray, damping: float) -> float:
        """Damp, normalize and store one message immediately; return its largest entry change."""
        # The normalizer is always derived from the message being stored,
        # never from a previously processed one.
        new_values = self._normalize_message(self._damp(key, new_values, damping))
        diff = np.max(np.abs(new_values - self.messages[key].values))
        self.messages[key].values = new_values
        return diff

    def _sweep_synchronous(self, damping: float) -> float:
        """Recompute every message from the current ones, then store them all; return the largest change."""
        staged = {}

        # Factor → Variable messages
        for factor in self.fg.factors:
            for var in factor.variables:
                key = (factor.name, var.name)
                staged[key] = self._damp(
                    key, self._compute_sum_message(factor, var), damping
                )

        # Variable → Factor messages
        for var in self.fg.variables:
            for factor in self.fg.neighbors_of_variable(var):
                key = (var.name, factor.name)
                staged[key] = self._damp(
                    key, self._compute_product_message(var, factor), damping
                )

        # Nothing has been stored yet, so every diff is measured against the
        # messages the sweep started from.
        max_diff = 0.0
        for key, new_values in staged.items():
            diff = np.max(np.abs(new_values - self.messages[key].values))
            max_diff = max(max_diff, diff)
            self.messages[key].values = new_values
        return max_diff

    def _sweep_asynchronous(self, damping: float) -> float:
        """Update messages one at a time, each seeing the latest values; return the largest change."""
        max_diff = 0.0

        # Factor → Variable messages
        for factor in self.fg.factors:
            for var in factor.variables:
                diff = self._commit_normalized(
                    (factor.name, var.name),
                    self._compute_sum_message(factor, var),
                    damping,
                )
                max_diff = max(max_diff, diff)

        # Variable → Factor messages
        for var in self.fg.variables:
            for factor in self.fg.neighbors_of_variable(var):
                diff = self._commit_normalized(
                    (var.name, factor.name),
                    self._compute_product_message(var, factor),
                    damping,
                )
                max_diff = max(max_diff, diff)

        return max_diff

    def _build_result(self, converged: bool, iterations: int, max_diff: float) -> BPResult:
        """Package the current marginals and a snapshot of the messages into a BPResult."""
        marginals = self._compute_all_marginals()
        # Copy each Message so that later runs, which overwrite message
        # values in place, cannot alter a result that was already returned.
        snapshot = {key: msg.copy() for key, msg in self.messages.items()}
        return BPResult(
            marginals=marginals,
            messages=snapshot,
            converged=converged,
            iterations=iterations,
            max_diff=max_diff
        )

    def run(self,
            max_iterations: int = 100,
            tolerance: float = 1e-6,
            damping: float = 0.0,
            update_order: str = 'synchronous') -> BPResult:
        """
        Run belief propagation until convergence or max iterations.

        Synchronous sweeps store the raw messages of Equations 3-4, whereas
        asynchronous sweeps rescale every message to sum to 1 before storing
        it. Marginals are normalized in both modes.

        Args:
            max_iterations: Maximum number of iterations. With a value <= 0
                no sweep is run and the result reports ``iterations=0`` and
                ``max_diff=inf``.
            tolerance: Convergence threshold (max message change)
            damping: Damping factor in [0, 1). New message = (1-damping)*new + damping*old
                     Higher damping can help convergence in loopy graphs.
            update_order: 'synchronous' (all messages at once); any other
                value selects asynchronous (one message at a time) updates

        Returns:
            BPResult with marginals and convergence info
        """
        if update_order == 'synchronous':
            sweep = self._sweep_synchronous
        else:
            sweep = self._sweep_asynchronous

        # Defined up front so the non-converged return is valid even when the loop never runs.
        max_diff = float('inf')
        iterations_run = 0

        for iterations_run in range(1, max_iterations + 1):
            max_diff = sweep(damping)
            if max_diff < tolerance:
                return self._build_result(True, iterations_run, max_diff)

        # Did not converge
        return self._build_result(False, iterations_run, max_diff)

    def _compute_all_marginals(self) -> Dict[str, np.ndarray]:
        """Compute marginals for all variables, keyed by variable name"""
        return {var.name: self._compute_belief(var) for var in self.fg.variables}

    def get_message(self, source: str, target: str) -> Message:
        """Get the message sent from node `source` to node `target` (KeyError if absent)"""
        return self.messages[(source, target)]

    def get_all_sum_messages(self) -> Dict[Tuple[str, str], Message]:
        """Get all factor→variable (sum) messages"""
        return {k: v for k, v in self.messages.items() if v.message_type == 'sum'}

    def get_all_product_messages(self) -> Dict[Tuple[str, str], Message]:
        """Get all variable→factor (product) messages"""
        return {k: v for k, v in self.messages.items() if v.message_type == 'product'}

    def reset(self, init_type: str = 'uniform'):
        """Discard all messages and re-create them with the given initialization"""
        self._initialize_messages(init_type)


def run_bp(factor_graph: FactorGraph,
           max_iterations: int = 100,
           tolerance: float = 1e-6,
           damping: float = 0.0,
           update_order: str = 'synchronous') -> BPResult:
    """
    Convenience function to run BP on a factor graph.

    Builds a fresh BeliefPropagation instance for the graph and runs it once.

    Args:
        factor_graph: The factor graph
        max_iterations: Maximum iterations
        tolerance: Convergence tolerance
        damping: Damping factor for loopy graphs
        update_order: Message schedule forwarded to ``BeliefPropagation.run``
            ('synchronous' by default, matching the previous behavior)

    Returns:
        BPResult with marginals
    """
    return BeliefPropagation(factor_graph).run(
        max_iterations=max_iterations,
        tolerance=tolerance,
        damping=damping,
        update_order=update_order,
    )
