"""
CRN Compiler Module

Compiles factor graphs into chemical reaction networks following
Napp & Adams (2013) "Message Passing Inference with Chemical Reaction Networks".

The compilation creates:
1. Species for each message component and marginal belief
2. Recycling reactions to return mass to unassigned state
3. Sum message reactions (factor → variable, Eq. 12)
4. Product message reactions (variable → factor, Eq. 16)
5. Marginal belief reactions (Eq. 20)

Key notation from the paper:
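- S^{j→n}_k: Sum message from factor j to variable n, component k
- P^{n→j}_k: Product message from variable n to factor j, component k
- P^n_k: Marginal belief that variable n takes its k-th value
- Component k = 0 holds "unassigned" probability mass; k = 1, ..., K_n index the variable's values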
- κ_r: Recycling rate constant
- κ_prod: Product message rate constant
- ψ_j(k^j): Factor table value (used as rate constant for sum messages)
"""

from typing import List, Dict, Tuple, Set, Optional, Union
from dataclasses import dataclass, field
from itertools import product
import numpy as np
import sys
import os

# Prepend the directory one level above this file's directory to the import
# path so that `from core import ...` below resolves without installing a
# package (presumably `core` lives there — NOTE(review): confirm layout).
# Mutating sys.path at import time is fragile; a package-relative import
# would be preferable if the project layout allows it.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph


@dataclass
class Species:
    """
    A chemical species in the CRN.

    Species are identified by ``name`` alone: equality and hashing ignore
    the other fields, so two ``Species`` with the same name compare equal
    and occupy a single slot in sets and dict keys.

    Attributes:
        name: Unique identifier
        species_type: 'sum_msg', 'prod_msg', 'marginal', or 'unassigned'
        description: Human-readable description
        initial_concentration: Starting concentration
    """
    name: str
    species_type: str
    description: str = ""
    initial_concentration: float = 1.0

    def __hash__(self):
        # Must agree with __eq__, which compares names only.
        return hash(self.name)

    def __eq__(self, other):
        if not isinstance(other, Species):
            # Defer to the other operand (and ultimately identity) instead of
            # claiming inequality ourselves.
            return NotImplemented
        return self.name == other.name

    def __repr__(self):
        return f"Species({self.name})"


@dataclass
class Reaction:
    """
    A chemical reaction in the CRN.

    Format: r1*A + r2*B + ... --κ--> p1*C + p2*D + ...

    Attributes:
        reactants: Dict mapping species name to stoichiometric coefficient
        products: Dict mapping species name to stoichiometric coefficient
        rate_constant: Reaction rate κ
        description: Human-readable description
    """
    reactants: Dict[str, int]
    products: Dict[str, int]
    rate_constant: float
    description: str = ""

    @staticmethod
    def _format_side(side: Dict[str, int]) -> str:
        """Render one side of the reaction, hiding zero coefficients.

        A side with no participating species (e.g. a zero-order source or a
        degradation product) is shown as ``∅`` rather than an empty string.
        """
        terms = [f"{c}*{s}" if c > 1 else s for s, c in side.items() if c > 0]
        return " + ".join(terms) if terms else "∅"

    def __repr__(self):
        lhs = self._format_side(self.reactants)
        rhs = self._format_side(self.products)
        return f"{lhs} --{self.rate_constant:.4g}--> {rhs}"

    def get_order(self) -> int:
        """Get the reaction order (sum of reactant coefficients)."""
        return sum(self.reactants.values())

    def get_catalysts(self) -> Set[str]:
        """
        Get species that appear on both sides (catalysts).

        Only species with a positive coefficient on each side count, matching
        ``__repr__``, which hides zero-coefficient entries. Coefficients are
        not compared, so a species whose count changes (e.g. ``A --> 2*A``)
        is still reported.
        """
        on_lhs = {s for s, c in self.reactants.items() if c > 0}
        on_rhs = {s for s, c in self.products.items() if c > 0}
        return on_lhs & on_rhs


class ChemicalReactionNetwork:
    """
    A chemical reaction network compiled from a factor graph.

    Contains:
    - Species (message components, marginal beliefs)
    - Reactions (recycling, sum messages, product messages, marginal computation)
    - Rate constants (κ_r for recycling, κ_prod for product messages)
    """

    # Lower-cased description prefixes used by summary() to bucket reactions.
    # Matching on the prefix (rather than a substring anywhere) keeps species,
    # factor or variable names such as "sum_f" from skewing the counts.
    _RECYCLING_PREFIXES = ("recycling",)
    _SUM_PREFIXES = ("sum msg", "sum message")
    _PROD_PREFIXES = ("prod msg", "product", "marginal")

    def __init__(self, name: str = "CRN"):
        self.name = name
        self.species: Dict[str, "Species"] = {}
        self.reactions: List["Reaction"] = []

        # Rate constants
        self.kappa_r = 0.1  # Recycling rate
        self.kappa_prod = 50.0  # Product message rate

        # Conservation sets (species whose total concentration is conserved)
        self.conservation_sets: List[Set[str]] = []

        # Mapping from factor graph elements to species names
        self.sum_msg_species: Dict[Tuple[str, str, int], str] = {}  # (factor, var, k) -> species
        self.prod_msg_species: Dict[Tuple[str, str, int], str] = {}  # (var, factor, k) -> species
        self.marginal_species: Dict[Tuple[str, int], str] = {}  # (var, k) -> species

    def add_species(self, species: "Species"):
        """Add a species to the network (replacing any species with the same name)."""
        self.species[species.name] = species

    def add_reaction(self, reaction: "Reaction"):
        """Add a reaction to the network."""
        self.reactions.append(reaction)

    def get_species(self, name: str) -> Optional["Species"]:
        """Get a species by name, or None if it is not present."""
        return self.species.get(name)

    def get_species_names(self) -> List[str]:
        """Get all species names in insertion order."""
        return list(self.species.keys())

    def get_initial_concentrations(self) -> Dict[str, float]:
        """Get initial concentrations for all species."""
        return {name: sp.initial_concentration for name, sp in self.species.items()}

    @classmethod
    def _reaction_category(cls, reaction) -> str:
        """Classify a reaction by its description prefix.

        Returns one of 'recycling', 'sum', 'prod' (product message or
        marginal) or 'other'. Recycling is tested first so that e.g.
        "Recycling Marginal_x[1]" is not mistaken for a marginal reaction.
        """
        desc = reaction.description.lstrip().lower()
        if desc.startswith(cls._RECYCLING_PREFIXES):
            return "recycling"
        if desc.startswith(cls._SUM_PREFIXES):
            return "sum"
        if desc.startswith(cls._PROD_PREFIXES):
            return "prod"
        return "other"

    def summary(self) -> str:
        """Get a human-readable summary of the CRN, including a reaction breakdown."""
        lines = [
            f"ChemicalReactionNetwork: {self.name}",
            f"  Species: {len(self.species)}",
            f"  Reactions: {len(self.reactions)}",
            f"  κ_r (recycling): {self.kappa_r}",
            f"  κ_prod (product): {self.kappa_prod}",
        ]

        # Each reaction lands in exactly one bucket.
        counts = {"recycling": 0, "sum": 0, "prod": 0, "other": 0}
        for reaction in self.reactions:
            counts[self._reaction_category(reaction)] += 1

        lines.append("  Reaction breakdown:")
        lines.append(f"    Recycling: {counts['recycling']}")
        lines.append(f"    Sum message: {counts['sum']}")
        lines.append(f"    Product/Marginal: {counts['prod']}")
        if counts["other"]:
            # Only reactions added by hand with unrecognised descriptions.
            lines.append(f"    Other: {counts['other']}")

        return "\n".join(lines)

    def __repr__(self):
        return f"CRN({self.name}, {len(self.species)} species, {len(self.reactions)} reactions)"


class CRNCompiler:
    """
    Compiles a factor graph into a chemical reaction network.

    Following Napp & Adams (2013), the compilation creates:

    1. Belief species sets:
       - For each sum message S^{j→n}: species S^{j→n}_0, S^{j→n}_1, ..., S^{j→n}_{K_n}
       - For each product message P^{n→j}: species P^{n→j}_0, ..., P^{n→j}_{K_n}
       - For each marginal P^n: species P^n_0, ..., P^n_{K_n}
       - Index 0 represents "unassigned" probability mass
       Each set is recorded in ``crn.conservation_sets``: every reaction
       below moves mass only between members of one set (species from other
       sets appear solely as catalysts), so each set's total is conserved.

    2. Recycling reactions (Eq. 11):
       S^{j→n}_k --κ_r--> S^{j→n}_0  for all k > 0
       (likewise for product-message and marginal species sets)

    3. Sum message reactions (Eq. 12):
       S^{j→n}_0 + Σ_{n'∈ne(j)\\n} P^{n'→j}_{k^j_{n'}} --ψ_j(k^j)-->
       S^{j→n}_k + Σ_{n'∈ne(j)\\n} P^{n'→j}_{k^j_{n'}}

    4. Product message reactions (Eq. 16):
       P^{n→j}_0 + Σ_{j'∈ne(n)\\j} S^{j'→n}_k --κ_prod-->
       P^{n→j}_k + Σ_{j'∈ne(n)\\j} S^{j'→n}_k

    5. Marginal belief reactions (Eq. 20):
       P^n_0 + Σ_{j∈ne(n)} S^{j→n}_k --κ_prod--> P^n_k + Σ_{j∈ne(n)} S^{j→n}_k
    """

    def __init__(self, factor_graph: FactorGraph, kappa_r: float = 0.1,
                 kappa_prod: float = 50.0):
        """
        Initialize the compiler.

        Args:
            factor_graph: The factor graph to compile
            kappa_r: Recycling rate constant (smaller = more accurate but slower)
            kappa_prod: Product message rate constant
        """
        self.fg = factor_graph
        self.kappa_r = kappa_r
        self.kappa_prod = kappa_prod

        self.crn = ChemicalReactionNetwork(f"CRN_{factor_graph.name}")
        self.crn.kappa_r = kappa_r
        self.crn.kappa_prod = kappa_prod

        # Set once compile() has populated self.crn, so that calling it again
        # does not append a second copy of every species set and reaction.
        self._compiled = False

    def compile(self) -> ChemicalReactionNetwork:
        """
        Compile the factor graph to a CRN.

        The work is done once per compiler; later calls return the same
        network object unchanged.

        Returns:
            The compiled ChemicalReactionNetwork

        Raises:
            ValueError: if a factor table entry is negative or not finite
                (it cannot be used as a mass-action rate constant).
        """
        if self._compiled:
            return self.crn

        # Step 1: Create all species
        self._create_species()

        # Step 2: Create recycling reactions
        self._create_recycling_reactions()

        # Step 3: Create sum message reactions
        self._create_sum_message_reactions()

        # Step 4: Create product message reactions
        self._create_product_message_reactions()

        # Step 5: Create marginal belief reactions
        self._create_marginal_reactions()

        self._compiled = True
        return self.crn

    # ------------------------------------------------------------------
    # Species
    # ------------------------------------------------------------------

    def _create_species(self):
        """Create all species for the CRN and record their conservation sets."""

        # Sum messages S^{j→n}_k, one set per edge (factor j, variable n)
        for factor in self.fg.factors:
            for var in factor.variables:
                members = set()
                for k in range(var.cardinality + 1):  # +1 for unassigned (k=0)
                    name = f"S_{factor.name}_to_{var.name}_{k}"
                    self._add_belief_species(
                        name, k, "sum_msg",
                        f"Sum msg {factor.name}→{var.name}, k={k}",
                        var.cardinality,
                    )
                    self.crn.sum_msg_species[(factor.name, var.name, k)] = name
                    members.add(name)
                self.crn.conservation_sets.append(members)

        # Product messages P^{n→j}_k, one set per edge (variable n, factor j)
        for var in self.fg.variables:
            for factor in self.fg.neighbors_of_variable(var):
                members = set()
                for k in range(var.cardinality + 1):
                    name = f"P_{var.name}_to_{factor.name}_{k}"
                    self._add_belief_species(
                        name, k, "prod_msg",
                        f"Prod msg {var.name}→{factor.name}, k={k}",
                        var.cardinality,
                    )
                    self.crn.prod_msg_species[(var.name, factor.name, k)] = name
                    members.add(name)
                self.crn.conservation_sets.append(members)

        # Marginal beliefs P^n_k, one set per variable
        for var in self.fg.variables:
            members = set()
            for k in range(var.cardinality + 1):
                name = f"Marginal_{var.name}_{k}"
                desc = f"P({var.name}={k})" if k > 0 else f"Unassigned P({var.name})"
                self._add_belief_species(name, k, "marginal", desc, var.cardinality)
                self.crn.marginal_species[(var.name, k)] = name
                members.add(name)
            self.crn.conservation_sets.append(members)

    def _add_belief_species(self, name: str, k: int, assigned_type: str,
                            desc: str, cardinality: int):
        """Register one member of a belief species set (k == 0 is unassigned).

        Following Napp & Adams, each set's conserved total is 1: the
        unassigned species starts empty and the K_n assigned species share
        the mass uniformly (1 / K_n each).
        """
        if k == 0:
            species_type, init_conc = "unassigned", 0.0
        else:
            species_type, init_conc = assigned_type, 1.0 / cardinality
        self.crn.add_species(Species(name, species_type, desc, init_conc))

    # ------------------------------------------------------------------
    # Reactions
    # ------------------------------------------------------------------

    def _create_recycling_reactions(self):
        """
        Create recycling reactions (Eq. 11).

        For each belief species set, create reactions that return
        assigned mass to the unassigned state:

        S^{j→n}_k --κ_r--> S^{j→n}_0  for k > 0
        """
        # Recycle sum messages
        for factor in self.fg.factors:
            for var in factor.variables:
                unassigned = self.crn.sum_msg_species[(factor.name, var.name, 0)]
                for k in range(1, var.cardinality + 1):
                    self._add_recycling_reaction(
                        self.crn.sum_msg_species[(factor.name, var.name, k)],
                        unassigned,
                        f"Recycling S_{factor.name}→{var.name}[{k}]",
                    )

        # Recycle product messages
        for var in self.fg.variables:
            for factor in self.fg.neighbors_of_variable(var):
                unassigned = self.crn.prod_msg_species[(var.name, factor.name, 0)]
                for k in range(1, var.cardinality + 1):
                    self._add_recycling_reaction(
                        self.crn.prod_msg_species[(var.name, factor.name, k)],
                        unassigned,
                        f"Recycling P_{var.name}→{factor.name}[{k}]",
                    )

        # Recycle marginal beliefs
        for var in self.fg.variables:
            unassigned = self.crn.marginal_species[(var.name, 0)]
            for k in range(1, var.cardinality + 1):
                self._add_recycling_reaction(
                    self.crn.marginal_species[(var.name, k)],
                    unassigned,
                    f"Recycling Marginal_{var.name}[{k}]",
                )

    def _add_recycling_reaction(self, assigned: str, unassigned: str, desc: str):
        """Add the first-order reaction ``assigned --κ_r--> unassigned``."""
        self.crn.add_reaction(Reaction(
            reactants={assigned: 1},
            products={unassigned: 1},
            rate_constant=self.kappa_r,
            description=desc,
        ))

    def _add_catalyzed_reaction(self, unassigned: str, assigned: str,
                                catalysts: List[str], rate: float, desc: str):
        """Add ``unassigned + Σcatalysts --rate--> assigned + Σcatalysts``.

        Catalysts appear with the same coefficient on both sides, so they are
        not consumed; a catalyst listed twice gets coefficient 2.
        """
        reactants = {unassigned: 1}
        products = {assigned: 1}
        for catalyst in catalysts:
            reactants[catalyst] = reactants.get(catalyst, 0) + 1
            products[catalyst] = products.get(catalyst, 0) + 1
        self.crn.add_reaction(Reaction(
            reactants=reactants,
            products=products,
            rate_constant=rate,
            description=desc,
        ))

    def _create_sum_message_reactions(self):
        """
        Create sum message reactions (Eq. 12).

        For each factor j and variable n in scope(j):
        For each k in {1, ..., K_n}:
        For each k^j in K^j with k^j_n = k:

        S^{j→n}_0 + Σ_{n'∈ne(j)\\n} P^{n'→j}_{k^j_{n'}} --ψ_j(k^j)-->
        S^{j→n}_k + Σ_{n'∈ne(j)\\n} P^{n'→j}_{k^j_{n'}}

        The product message species are catalysts (appear on both sides).
        """
        for factor in self.fg.factors:
            for target_var in factor.variables:
                # Other variables in this factor (ne(j) \ n)
                other_vars = [v for v in factor.variables if v != target_var]
                # itertools.product over zero ranges yields a single empty
                # tuple, so unary factors (no catalysts) need no special case.
                other_ranges = [range(1, v.cardinality + 1) for v in other_vars]

                for k in range(1, target_var.cardinality + 1):
                    for other_config in product(*other_ranges):
                        self._add_sum_reaction(factor, target_var, k,
                                               other_vars, other_config)

    def _add_sum_reaction(self, factor: Factor, target_var: Variable,
                          target_k: int, other_vars: List[Variable],
                          other_config: Tuple[int, ...]):
        """
        Add one sum message reaction for a single factor configuration k^j.

        Args:
            factor: Factor j whose table supplies the rate ψ_j(k^j)
            target_var: Variable n receiving the message
            target_k: 1-based component index of n in k^j
            other_vars: ne(j) \\ n, in the same order as ``other_config``
            other_config: 1-based component indices of ``other_vars`` in k^j

        Raises:
            ValueError: if ψ_j(k^j) is negative or not finite.
        """
        # Species components are 1-based (0 = unassigned); index_to_value
        # takes the 0-based state index.
        full_config = {target_var: target_var.index_to_value(target_k - 1)}
        for v, k in zip(other_vars, other_config):
            full_config[v] = v.index_to_value(k - 1)

        # The factor value becomes a mass-action rate constant.
        rate = float(factor.get_value(full_config))
        if not np.isfinite(rate) or rate < 0:
            raise ValueError(
                f"Factor {factor.name!r} has invalid value {rate!r} for "
                f"{target_var.name}[{target_k}] with config {other_config}; "
                "factor entries must be finite and non-negative to be used "
                "as rate constants"
            )
        if rate == 0:
            return  # Zero table entries contribute no reaction

        # Catalysts: P^{n'→j}_{k'} for each other variable
        catalysts = [self.crn.prod_msg_species[(v.name, factor.name, k)]
                     for v, k in zip(other_vars, other_config)]

        self._add_catalyzed_reaction(
            self.crn.sum_msg_species[(factor.name, target_var.name, 0)],
            self.crn.sum_msg_species[(factor.name, target_var.name, target_k)],
            catalysts,
            rate,
            f"Sum msg {factor.name}→{target_var.name}[{target_k}], config={other_config}",
        )

    def _create_product_message_reactions(self):
        """
        Create product message reactions (Eq. 16).

        For each variable n and factor j in ne(n):
        For each k in {1, ..., K_n}:

        P^{n→j}_0 + Σ_{j'∈ne(n)\\j} S^{j'→n}_k --κ_prod-->
        P^{n→j}_k + Σ_{j'∈ne(n)\\j} S^{j'→n}_k

        All sum message species with the same k value are catalysts.
        """
        for var in self.fg.variables:
            factors_of_var = self.fg.neighbors_of_variable(var)

            for target_factor in factors_of_var:
                # Other factors (ne(n) \ j)
                other_factors = [f for f in factors_of_var if f != target_factor]
                unassigned = self.crn.prod_msg_species[(var.name, target_factor.name, 0)]

                for k in range(1, var.cardinality + 1):
                    self._add_catalyzed_reaction(
                        unassigned,
                        self.crn.prod_msg_species[(var.name, target_factor.name, k)],
                        [self.crn.sum_msg_species[(f.name, var.name, k)]
                         for f in other_factors],
                        self.kappa_prod,
                        f"Prod msg {var.name}→{target_factor.name}[{k}]",
                    )

    def _create_marginal_reactions(self):
        """
        Create marginal belief reactions (Eq. 20).

        For each variable n:
        For each k in {1, ..., K_n}:

        P^n_0 + Σ_{j∈ne(n)} S^{j→n}_k --κ_prod--> P^n_k + Σ_{j∈ne(n)} S^{j→n}_k

        All incoming sum message species with the same k are catalysts.
        """
        for var in self.fg.variables:
            factors_of_var = self.fg.neighbors_of_variable(var)
            unassigned = self.crn.marginal_species[(var.name, 0)]

            for k in range(1, var.cardinality + 1):
                self._add_catalyzed_reaction(
                    unassigned,
                    self.crn.marginal_species[(var.name, k)],
                    [self.crn.sum_msg_species[(f.name, var.name, k)]
                     for f in factors_of_var],
                    self.kappa_prod,
                    f"Marginal {var.name}[{k}]",
                )


def compile_factor_graph_to_crn(fg: FactorGraph, kappa_r: float = 0.1,
                                 kappa_prod: float = 50.0) -> ChemicalReactionNetwork:
    """
    Build a ``CRNCompiler`` for ``fg`` and run it in a single call.

    Args:
        fg: Factor graph to compile
        kappa_r: Recycling rate constant κ_r (smaller = more accurate but slower)
        kappa_prod: Rate constant κ_prod for product-message and marginal reactions

    Returns:
        The ChemicalReactionNetwork produced by ``CRNCompiler.compile()``
    """
    return CRNCompiler(fg, kappa_r=kappa_r, kappa_prod=kappa_prod).compile()
