"""
CRN Simulator Module

Simulates chemical reaction networks using mass-action kinetics,
following the ODE formulation in Napp & Adams (2013).

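Mass Action Kinetics (Eq. 9):
d[Z_m]/dt = Σ_q κ_q · Π_{m'} [Z_{m'}]^{r^q_{m'}} · (p^q_m - r^q_m)

where κ_q is the rate constant of reaction q, and r^q_m and p^q_m are the
stoichiometric coefficients of species Z_m as a reactant and as a product
of reaction q, respectively.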

At steady state, the concentrations of marginal belief species
correspond to the BP marginals (up to normalization).
"""

from typing import Dict, List, Tuple, Optional, Callable
from dataclasses import dataclass, field
import numpy as np
from scipy.integrate import solve_ivp
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from .crn_compiler import ChemicalReactionNetwork, Species, Reaction


@dataclass
class SimulationResult:
    """
    Result of CRN simulation.

    Attributes:
        times: Array of time points
        concentrations: Dict mapping species name to concentration trajectory
        final_concentrations: Dict of final concentrations
        marginals: Dict mapping variable name to marginal distribution
        converged: Whether simulation reached steady state
    """
    times: np.ndarray
    concentrations: Dict[str, np.ndarray]
    final_concentrations: Dict[str, float]
    marginals: Dict[str, np.ndarray]
    converged: bool

    def get_concentration_at(self, species_name: str, time_idx: int = -1) -> float:
        """
        Return the concentration of ``species_name`` at output index ``time_idx``.

        Args:
            species_name: Key into ``concentrations``.
            time_idx: Index into that species' trajectory (default: last point).

        Raises:
            KeyError: If ``species_name`` has no recorded trajectory.
            IndexError: If ``time_idx`` is out of range for the trajectory.
        """
        return self.concentrations[species_name][time_idx]

    def get_marginal(self, var_name: str) -> np.ndarray:
        """Return the marginal for ``var_name``, or an empty array if there is none."""
        return self.marginals.get(var_name, np.array([]))

    def summary(self) -> str:
        """
        Return a human-readable, multi-line summary of the simulation results.

        Safe to call when ``times`` is empty; the final time is then shown
        as ``n/a`` instead of raising ``IndexError``.
        """
        final_time = f"{self.times[-1]:.2f}" if len(self.times) else "n/a"
        lines = [
            "SimulationResult:",
            f"  Time points: {len(self.times)}",
            f"  Final time: {final_time}",
            f"  Converged: {self.converged}",
            "  Marginals:",
        ]
        lines.extend(
            f"    P({var}) = {np.round(marg, 4)}"
            for var, marg in self.marginals.items()
        )
        return "\n".join(lines)


class CRNSimulator:
    """
    Simulates a CRN using mass-action kinetics ODEs.

    The simulator:
    1. Builds the ODE system from the reaction network
    2. Integrates using scipy's solve_ivp
    3. Extracts marginal beliefs from steady-state concentrations

    The network is compiled once into dense NumPy arrays (rate constants,
    reactant orders, net stoichiometry). Each right-hand-side evaluation is
    therefore a few vectorized operations rather than a Python loop over
    reactions.
    """

    def __init__(self, crn: ChemicalReactionNetwork):
        """
        Initialize the simulator.

        Args:
            crn: The chemical reaction network to simulate
        """
        self.crn = crn

        # Build species index mapping (fixes the ordering of the state vector)
        self.species_names = list(crn.species.keys())
        self.species_idx = {name: i for i, name in enumerate(self.species_names)}
        self.n_species = len(self.species_names)

        # Precompute reaction data for efficient ODE evaluation
        self._precompute_reaction_data()

    def _precompute_reaction_data(self):
        """
        Precompute reaction data for fast ODE evaluation.

        Builds, from ``self.crn.reactions``:
          * ``reaction_data``: one dict per reaction with keys 'rate',
            'reactant_indices', 'reactant_coeffs' and 'net_change'. It is kept
            for callers that inspect individual reactions.
          * ``_rate_constants``: shape (n_reactions,), the rate constants.
          * ``_reactant_orders``: shape (n_reactions, n_species), the reactant
            coefficients that enter the rate law. Only positive coefficients
            are recorded; every other entry is 0.
          * ``_net_change``: shape (n_reactions, n_species), products minus
            reactants for each species.
        """
        reactions = list(self.crn.reactions)
        n_reactions = len(reactions)

        self.reaction_data = []
        self._rate_constants = np.zeros(n_reactions)
        self._reactant_orders = np.zeros((n_reactions, self.n_species))
        self._net_change = np.zeros((n_reactions, self.n_species))

        for q, reaction in enumerate(reactions):
            # Reactant indices and coefficients that appear in the rate law
            reactant_indices = []
            reactant_coeffs = []
            for name, coeff in reaction.reactants.items():
                if coeff > 0:
                    idx = self.species_idx[name]
                    reactant_indices.append(idx)
                    reactant_coeffs.append(coeff)
                    self._reactant_orders[q, idx] = coeff

            # Compute net change (p - r) for each species
            net_change = np.zeros(self.n_species)
            for name, coeff in reaction.products.items():
                net_change[self.species_idx[name]] += coeff
            for name, coeff in reaction.reactants.items():
                net_change[self.species_idx[name]] -= coeff

            self._rate_constants[q] = reaction.rate_constant
            self._net_change[q] = net_change

            self.reaction_data.append({
                'rate': reaction.rate_constant,
                'reactant_indices': reactant_indices,
                'reactant_coeffs': reactant_coeffs,
                'net_change': net_change
            })

    def _ode_rhs(self, t: float, y: np.ndarray) -> np.ndarray:
        """
        Compute the right-hand side of the ODE system.

        d[Z_m]/dt = Σ_q κ_q · Π_{m'} [Z_{m'}]^{r^q_{m'}} · (p^q_m - r^q_m)

        Args:
            t: Time (unused but required by solve_ivp)
            y: Current concentrations, ordered like ``self.species_names``

        Returns:
            Time derivatives of concentrations
        """
        # Clamp small negative undershoots from the integrator to zero
        conc = np.maximum(y, 0.0)
        # Reaction propensities κ_q · Π_m [Z_m]^{r^q_m}. Non-reactants have
        # order 0, and 0.0 ** 0 == 1.0, so they contribute a factor of 1.
        propensities = self._rate_constants * np.prod(
            conc ** self._reactant_orders, axis=1
        )
        # Σ_q propensity_q · (p^q - r^q)
        return propensities @ self._net_change

    def simulate(self, t_end: float = 1000.0, n_points: int = 1000,
                 method: str = 'LSODA', rtol: float = 1e-6,
                 atol: float = 1e-9) -> SimulationResult:
        """
        Simulate the CRN to steady state.

        Args:
            t_end: End time for simulation
            n_points: Number of output time points
            method: ODE solver method ('LSODA', 'RK45', 'BDF', etc.)
            rtol: Relative tolerance
            atol: Absolute tolerance (also used as the absolute floor of
                the steady-state convergence check)

        Returns:
            SimulationResult with trajectories and marginals. ``converged``
            is False whenever the solver reports failure.

        Raises:
            RuntimeError: If the solver returns no output points at all.
        """
        # Get initial concentrations
        y0 = np.array([self.crn.species[name].initial_concentration
                       for name in self.species_names], dtype=float)

        # Time points for output
        t_eval = np.linspace(0, t_end, n_points)

        # Solve ODE system
        sol = solve_ivp(
            self._ode_rhs,
            [0, t_end],
            y0,
            method=method,
            t_eval=t_eval,
            rtol=rtol,
            atol=atol
        )

        # If integration fails before reaching any t_eval point, there is no
        # data to report; fail loudly instead of with an opaque IndexError.
        if len(sol.t) == 0:
            raise RuntimeError(
                f"CRN integration produced no output points: {sol.message}"
            )

        # Extract results
        times = sol.t
        concentrations = {name: sol.y[i] for i, name in enumerate(self.species_names)}
        final_concentrations = {name: sol.y[i, -1] for i, name in enumerate(self.species_names)}

        # Compute marginals from final concentrations
        marginals = self._compute_marginals(final_concentrations)

        # A failed integration stops early, so its last window is not a
        # steady state even if it happens to look flat.
        converged = bool(sol.success) and self._check_convergence(sol.y, abs_tol=atol)

        return SimulationResult(
            times=times,
            concentrations=concentrations,
            final_concentrations=final_concentrations,
            marginals=marginals,
            converged=converged
        )

    def _compute_marginals(self, concentrations: Dict[str, float]) -> Dict[str, np.ndarray]:
        """
        Compute normalized marginal distributions from concentrations.

        For each variable, the marginal P(x=k) is proportional to
        the concentration of the marginal species P^n_k. Entries with k == 0
        are excluded. Negative concentrations are clipped to 0 before
        normalizing. If a variable's total is at most 1e-10, it gets a
        uniform distribution.
        """
        # Group (k, concentration) pairs by variable
        var_marginals = {}
        for (var_name, k), species_name in self.crn.marginal_species.items():
            if k == 0:
                continue  # Skip unassigned
            var_marginals.setdefault(var_name, []).append((k, concentrations[species_name]))

        # Normalize each variable's marginal
        marginals = {}
        for var_name, values in var_marginals.items():
            values.sort(key=lambda item: item[0])  # Sort by k
            # Solver undershoot can leave tiny negative values; a probability
            # mass cannot be negative.
            concs = np.clip(np.array([c for _, c in values], dtype=float), 0.0, None)

            total = concs.sum()
            if total > 1e-10:
                marginals[var_name] = concs / total
            else:
                # Uniform if all (effectively) zero
                marginals[var_name] = np.ones(len(concs)) / len(concs)

        return marginals

    def _check_convergence(self, y: np.ndarray, window_frac: float = 0.1,
                           tol: float = 1e-4, abs_tol: float = 1e-9) -> bool:
        """
        Check if the simulation has converged to steady state.

        Compares the mean of the final window of output points to the mean of
        the window before it. Each species must satisfy
        ``|final - prev| <= tol * |final| + abs_tol``. The absolute term keeps
        species that sit at (or decay toward) zero from being judged
        unconverged because of solver-level noise.

        Args:
            y: Trajectories, shape (n_species, n_points)
            window_frac: Fraction of points per window (at least 10 points)
            tol: Relative tolerance
            abs_tol: Absolute tolerance; ``simulate`` passes the solver's atol

        Returns:
            True if every species is within tolerance; False if there are too
            few points for two windows.
        """
        n_points = y.shape[1]
        window_size = max(int(n_points * window_frac), 10)

        if n_points < 2 * window_size:
            return False

        # Average concentrations in last two windows
        final_window = y[:, -window_size:].mean(axis=1)
        prev_window = y[:, -2*window_size:-window_size].mean(axis=1)

        change = np.abs(final_window - prev_window)
        return bool(np.all(change <= tol * np.abs(final_window) + abs_tol))

    def get_belief_trajectories(self, result: SimulationResult,
                                var_name: str) -> Dict[int, np.ndarray]:
        """
        Get the belief trajectories for a specific variable.

        Returns dict mapping k (k > 0) to concentration trajectory of P^n_k.
        """
        trajectories = {}

        for (vn, k), species_name in self.crn.marginal_species.items():
            if vn == var_name and k > 0:
                trajectories[k] = result.concentrations[species_name]

        return trajectories


def simulate_crn(crn: ChemicalReactionNetwork, t_end: float = 1000.0,
                 n_points: int = 1000, **kwargs) -> SimulationResult:
    """
    Build a CRNSimulator for ``crn`` and run it in one call.

    Args:
        crn: Network to integrate.
        t_end: Final integration time.
        n_points: How many evenly spaced output samples to record.
        **kwargs: Forwarded unchanged to CRNSimulator.simulate
            (for example ``method``, ``rtol``, ``atol``).

    Returns:
        The SimulationResult produced by CRNSimulator.simulate.
    """
    return CRNSimulator(crn).simulate(t_end=t_end, n_points=n_points, **kwargs)
