"""
Napp-Style Trajectory Plots for CRN Simulations

Creates publication-quality plots showing:
- Belief species concentration trajectories over time
- Colored areas representing probability mass for each state
- Dashed horizontal lines marking the exact marginal distributions (drawn at
  their cumulative sums, matching the bottom-up stacking of the areas)
- White area at top representing unassigned probability mass

Following the style of Figure 3 in Napp & Adams (2013).
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from typing import List, Dict, Optional, Tuple, Any
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import FactorGraph
from crn import (
    ChemicalReactionNetwork,
    compile_factor_graph_to_crn,
    simulate_crn,
)
from crn.crn_simulator import SimulationResult


def get_belief_species_for_variable(crn: ChemicalReactionNetwork,
                                     var_name: str) -> Dict[int, str]:
    """
    Get marginal belief species names for a variable.

    Belief species are named ``Marginal_<var_name>_<k>``. A species is only
    accepted when the *entire* suffix after ``Marginal_<var_name>_`` is a
    decimal integer. Species of another variable whose name extends this one
    are therefore not picked up (e.g. ``Marginal_x_1_2`` is state 2 of
    ``x_1``, not of ``x``). Non-numeric suffixes are ignored instead of
    raising. The ``k=0`` species (unassigned mass) is excluded.

    Args:
        crn: Network whose ``species`` collection of names is scanned.
        var_name: Variable name to look up.

    Returns:
        Dict mapping state index k (k >= 1) -> species name; empty if the
        variable has no belief species.
    """
    prefix = f'Marginal_{var_name}_'
    species_map = {}
    for name in crn.species:
        if not name.startswith(prefix):
            continue
        suffix = name[len(prefix):]
        # isdecimal() rather than int() in a try: int() accepts "1_2" as 12,
        # which would misattribute another variable's species to this one.
        if not suffix.isdecimal():
            continue
        k = int(suffix)
        if k > 0:  # Skip k=0 (unassigned)
            species_map[k] = name
    return species_map


def extract_belief_trajectories(result: SimulationResult,
                                crn: ChemicalReactionNetwork,
                                var_name: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """
    Extract belief trajectories for a variable from simulation results.

    Each state's concentration is divided by the total marginal mass at that
    time point. The total is the sum over all belief species plus the
    unassigned ``Marginal_<var_name>_0`` species. Beliefs and the unassigned
    fraction therefore sum to 1 wherever the total is above the 1e-10 floor.

    Args:
        result: Simulation result; ``result.times`` and the
            ``result.concentrations`` mapping (species name -> trajectory)
            are read. Species absent from the mapping are treated as zero.
        crn: Network searched for the variable's belief species.
        var_name: Name of the variable to extract.

    Returns:
        times: Array of time points
        beliefs: Array of shape (n_times, n_states) with normalized beliefs,
            columns ordered by increasing state index
        unassigned_frac: Array of unassigned fraction over time
        state_names: List of state names, displayed 0-indexed ("k=0", ...)

    Raises:
        ValueError: If the CRN has no belief species for ``var_name``.
    """
    species_map = get_belief_species_for_variable(crn, var_name)

    if not species_map:
        raise ValueError(f"No marginal species found for variable {var_name}")

    times = np.asarray(result.times)
    n_times = len(times)
    sorted_states = sorted(species_map)

    raw_beliefs = np.zeros((n_times, len(sorted_states)))
    state_names = []

    # The column comes from the sorted position rather than ``k - 1``, so a
    # gap in the species indices cannot index past the end of the array.
    for col, k in enumerate(sorted_states):
        sp_name = species_map[k]
        if sp_name in result.concentrations:
            raw_beliefs[:, col] = result.concentrations[sp_name]
        state_names.append(f"k={k-1}")  # Species index k is 1-based; display 0-based

    unassigned_name = f'Marginal_{var_name}_0'
    if unassigned_name in result.concentrations:
        unassigned = np.asarray(result.concentrations[unassigned_name], dtype=float)
    else:
        unassigned = np.zeros(n_times)

    # Normalize by the total (assigned + unassigned) mass; the floor avoids
    # division by zero when everything is depleted.
    total = np.maximum(raw_beliefs.sum(axis=1) + unassigned, 1e-10)
    beliefs = raw_beliefs / total[:, np.newaxis]
    unassigned_frac = unassigned / total

    return times, beliefs, unassigned_frac, state_names


def plot_napp_style_trajectory(result: SimulationResult,
                               crn: ChemicalReactionNetwork,
                               var_name: str,
                               exact_marginal: Optional[np.ndarray] = None,
                               ax: Optional[plt.Axes] = None,
                               title: Optional[str] = None,
                               colors: Optional[List[str]] = None,
                               show_unassigned: bool = True) -> plt.Axes:
    """
    Create a Napp-style stacked area plot for belief trajectories.

    Normalized beliefs are stacked bottom-up in increasing state order. When
    ``show_unassigned`` is set, the unassigned fraction is stacked on top as
    a white band with a light-gray outline. Exact marginals, if given, are
    drawn as dashed horizontal lines at their cumulative sums, matching the
    bottom-up stacking order.

    Args:
        result: Simulation result from simulate_crn
        crn: The chemical reaction network
        var_name: Name of variable to plot
        exact_marginal: Optional exact marginal distribution for comparison
        ax: Optional matplotlib axes to plot on; a new 8x4 figure is created
            when omitted
        title: Optional title for the plot (defaults to ``P(<var_name>)``)
        colors: Optional colors for each state. A palette shorter than the
            number of states is cycled. ``None`` or an empty palette selects
            the default (Set2 for up to 8 states, tab20 otherwise).
        show_unassigned: Whether to show unassigned mass as white area

    Returns:
        The matplotlib axes object

    Raises:
        ValueError: If the CRN has no belief species for ``var_name``.
    """
    times, beliefs, unassigned_frac, state_names = extract_belief_trajectories(
        result, crn, var_name)

    n_states = beliefs.shape[1]

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 4))

    if colors is None or len(colors) == 0:
        cmap = plt.cm.Set2 if n_states <= 8 else plt.cm.tab20
        colors = [cmap(i / n_states) for i in range(n_states)]
    else:
        # Cycle a short user palette instead of failing with IndexError.
        colors = [colors[i % len(colors)] for i in range(n_states)]

    # One layer per state, plus the unassigned fraction last. That last layer
    # stays zero when hidden, so the stack then tops out at the assigned mass.
    layers = np.zeros((len(times), n_states + 1))
    layers[:, :n_states] = beliefs
    if show_unassigned:
        layers[:, n_states] = unassigned_frac
    y_cumsum = np.cumsum(layers, axis=1)

    for i in range(n_states):
        lower = 0 if i == 0 else y_cumsum[:, i - 1]
        ax.fill_between(times, lower, y_cumsum[:, i],
                        color=colors[i], alpha=0.8, label=state_names[i])

    if show_unassigned:
        # ``color`` sets both face and edge colour, so combining it with
        # ``edgecolor`` depends on keyword-processing order; ``facecolor``
        # expresses the intended white fill with a light-gray outline.
        ax.fill_between(times, y_cumsum[:, n_states - 1], y_cumsum[:, n_states],
                        facecolor='white', edgecolor='lightgray', alpha=1.0)

    if exact_marginal is not None:
        for boundary in np.cumsum(np.asarray(exact_marginal, dtype=float)):
            ax.axhline(y=boundary, color='black', linestyle='--', linewidth=1.5, alpha=0.7)

    ax.set_xlim(times[0], times[-1])
    ax.set_ylim(0, 1.0)
    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Probability')
    ax.set_title(title if title else f'P({var_name})')

    handles = [Patch(facecolor=color, alpha=0.8, label=name)
               for color, name in zip(colors, state_names)]
    if exact_marginal is not None:
        handles.append(plt.Line2D([0], [0], color='black', linestyle='--',
                                  label='Exact marginal'))
    ax.legend(handles=handles, loc='upper right', fontsize=8)

    return ax


def plot_multiple_variables(result: SimulationResult,
                           crn: ChemicalReactionNetwork,
                           var_names: List[str],
                           exact_marginals: Optional[Dict[str, np.ndarray]] = None,
                           figsize: Optional[Tuple[int, int]] = None,
                           suptitle: Optional[str] = None) -> plt.Figure:
    """
    Create Napp-style plots for multiple variables in a single figure.

    One subplot per variable is laid out in a single row, in the order of
    ``var_names``.

    Args:
        result: Simulation result
        crn: Chemical reaction network
        var_names: List of variable names to plot (must be non-empty)
        exact_marginals: Optional dict mapping var name -> exact marginal;
            variables missing from it are plotted without exact lines
        figsize: Optional figure size (defaults to 5 inches per variable by 4)
        suptitle: Optional super title

    Returns:
        The matplotlib figure

    Raises:
        ValueError: If ``var_names`` is empty, or if a variable has no belief
            species in ``crn``.
    """
    n_vars = len(var_names)
    if n_vars == 0:
        raise ValueError("var_names must contain at least one variable name")

    if figsize is None:
        figsize = (5 * n_vars, 4)

    # squeeze=False always yields a 2-D axes array, so the single-variable
    # case needs no special handling.
    fig, axes = plt.subplots(1, n_vars, figsize=figsize, squeeze=False)

    for ax, var_name in zip(axes[0], var_names):
        exact = exact_marginals.get(var_name) if exact_marginals else None
        plot_napp_style_trajectory(result, crn, var_name,
                                   exact_marginal=exact, ax=ax)

    if suptitle:
        fig.suptitle(suptitle, fontsize=14)

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig


def _surviving_marginal_variables(crn: ChemicalReactionNetwork,
                                  preferred_order: List[str]) -> List[str]:
    """
    Return variables that still own a belief species in ``crn``.

    Species are parsed as ``Marginal_<var>_<k>``, splitting on the *last*
    underscore so variable names may themselves contain underscores. Only
    ``k >= 1`` counts, because the ``k=0`` (unassigned) species alone cannot
    be plotted. Variables listed in ``preferred_order`` come first, in that
    order; any others follow alphabetically.
    """
    found = set()
    for name in crn.species:
        if not name.startswith('Marginal_'):
            continue
        var, sep, state = name[len('Marginal_'):].rpartition('_')
        if sep and var and state.isdecimal() and int(state) > 0:
            found.add(var)
    ordered = [v for v in preferred_order if v in found]
    ordered += sorted(found.difference(ordered))
    return ordered


def plot_reduction_comparison(fg: FactorGraph,
                              bp_marginals: Dict[str, np.ndarray],
                              t_end: float = 5000,
                              kappa_r: float = 0.02,
                              figsize: Tuple[int, int] = (14, 8),
                              output_path: Optional[str] = None) -> plt.Figure:
    """
    Create a comparison figure showing:
    - Original CRN trajectories (top row, one panel per factor-graph variable)
    - Reduced CRN trajectories (bottom row, one panel per surviving variable,
      ordered like the top row)
    - Exact BP marginals (dashed lines in every panel where available)

    Args:
        fg: Factor graph
        bp_marginals: Dict of exact BP marginals for each variable
        t_end: Simulation end time
        kappa_r: Recycling rate
        figsize: Figure size
        output_path: Optional path to save figure

    Returns:
        The matplotlib figure
    """
    from crn import reduce_crn_to_core

    # Compile and simulate original CRN
    crn_orig = compile_factor_graph_to_crn(fg, kappa_r=kappa_r)
    result_orig = simulate_crn(crn_orig, t_end=t_end, n_points=200)

    var_names = [v.name for v in fg.variables]
    n_vars = len(var_names)

    crn_reduced, _steps = reduce_crn_to_core(crn_orig, copy=True, mode='structural')
    surviving_vars = _surviving_marginal_variables(crn_reduced, var_names)

    # At least one column so the "trivial" message always has an axes.
    # squeeze=False keeps ``axes`` 2-D for any column count.
    n_cols = max(n_vars, len(surviving_vars), 1)
    fig, axes = plt.subplots(2, n_cols, figsize=figsize, squeeze=False)

    # Top row: Original CRN
    for i, var_name in enumerate(var_names):
        exact = bp_marginals.get(var_name)
        plot_napp_style_trajectory(result_orig, crn_orig, var_name,
                                   exact_marginal=exact, ax=axes[0, i],
                                   title=f'Original: P({var_name})')

    for i in range(n_vars, n_cols):
        axes[0, i].axis('off')

    # Bottom row: Reduced CRN
    if len(surviving_vars) > 0 and len(crn_reduced.reactions) > 0:
        try:
            result_reduced = simulate_crn(crn_reduced, t_end=t_end, n_points=200)

            for i, var_name in enumerate(surviving_vars):
                exact = bp_marginals.get(var_name)
                plot_napp_style_trajectory(result_reduced, crn_reduced, var_name,
                                           exact_marginal=exact, ax=axes[1, i],
                                           title=f'Reduced: P({var_name})')
        except Exception as e:
            # Deliberate best-effort: report the failure inside the figure so
            # the original-CRN row is still produced and saved.
            axes[1, 0].text(0.5, 0.5, f'Simulation error:\n{e}',
                            ha='center', va='center', transform=axes[1, 0].transAxes)
    else:
        axes[1, 0].text(0.5, 0.5, 'Reduced to trivial\n(no message passing needed)',
                        ha='center', va='center', transform=axes[1, 0].transAxes,
                        fontsize=12)

    for i in range(max(1, len(surviving_vars)), n_cols):
        axes[1, i].axis('off')

    fig.suptitle(f'CRN Trajectories: Original ({len(crn_orig.species)} species) vs '
                 f'Reduced ({len(crn_reduced.species)} species)', fontsize=14)

    # Operate on this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {output_path}")

    return fig


def plot_kappa_r_comparison(fg: FactorGraph,
                            bp_marginals: Dict[str, np.ndarray],
                            kappa_r_values: Optional[List[float]] = None,
                            t_end: float = 3000,
                            var_to_plot: Optional[str] = None,
                            output_path: Optional[str] = None,
                            fast_threshold: float = 0.1) -> plt.Figure:
    """
    Recreate Napp Figure 3 style: compare slow vs fast κ_r.

    Shows the tradeoff between speed and accuracy:
    - Higher κ_r = faster dynamics but more unassigned mass
    - Lower κ_r = slower dynamics but more accurate (less unassigned)

    Args:
        fg: Factor graph
        bp_marginals: Exact BP marginals
        kappa_r_values: κ_r values to compare, one panel each
            (defaults to [0.1, 0.01])
        t_end: Simulation end time
        var_to_plot: Variable to plot (defaults to first)
        output_path: Optional path to save
        fast_threshold: Panels with κ_r >= this value are annotated "Fast",
            the rest "Slow" (default 0.1)

    Returns:
        The matplotlib figure

    Raises:
        ValueError: If ``kappa_r_values`` is empty.
    """
    # None default instead of a shared mutable list literal.
    if kappa_r_values is None:
        kappa_r_values = [0.1, 0.01]
    kappa_r_values = list(kappa_r_values)
    if not kappa_r_values:
        raise ValueError("kappa_r_values must contain at least one value")

    if var_to_plot is None:
        var_to_plot = fg.variables[0].name

    n_kappa = len(kappa_r_values)
    # squeeze=False keeps a 2-D axes array even for a single κ_r value.
    fig, axes = plt.subplots(n_kappa, 1, figsize=(10, 3 * n_kappa), squeeze=False)

    exact = bp_marginals.get(var_to_plot)

    for ax, kappa_r in zip(axes[:, 0], kappa_r_values):
        crn = compile_factor_graph_to_crn(fg, kappa_r=kappa_r)
        result = simulate_crn(crn, t_end=t_end, n_points=200)

        plot_napp_style_trajectory(result, crn, var_to_plot,
                                   exact_marginal=exact, ax=ax,
                                   title=f'κ_r = {kappa_r}')

        if kappa_r >= fast_threshold:
            label = 'Fast (more unassigned)'
        else:
            label = 'Slow (more accurate)'
        ax.annotate(label, xy=(0.02, 0.98),
                    xycoords='axes fraction', va='top', fontsize=9)

    fig.suptitle(f'Effect of κ_r on P({var_to_plot})', fontsize=14)
    # Operate on this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {output_path}")

    return fig


if __name__ == "__main__":
    # Demo: build a small three-variable chain factor graph, compute exact BP
    # marginals, and write Napp-style trajectory plots for it.
    from core import Variable, Factor, FactorGraph
    from inference import run_bp

    # Output goes to <project root>/results/plots. The project root is the
    # same directory this module puts on sys.path at import time, which keeps
    # the demo portable instead of tied to one machine's absolute path.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    plots_dir = os.path.join(project_root, "results", "plots")

    print("Creating Napp-style trajectory plots...")

    # Create a simple chain
    fg = FactorGraph("chain")
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    fg.add_factor(Factor("f12", [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))
    fg.add_factor(Factor("f23", [x2, x3], np.array([[0.7, 0.3], [0.4, 0.6]])))
    fg.add_factor(Factor("u1", [x1], np.array([0.8, 0.2])))
    fg.add_factor(Factor("u3", [x3], np.array([0.3, 0.7])))

    # Get exact marginals
    bp_result = run_bp(fg)
    bp_marginals = {v.name: bp_result.get_marginal(v.name) for v in fg.variables}

    print("BP Marginals:")
    for var, marg in bp_marginals.items():
        print(f"  P({var}) = {marg}")

    os.makedirs(plots_dir, exist_ok=True)

    # Single variable trajectory
    crn = compile_factor_graph_to_crn(fg, kappa_r=0.02)
    result = simulate_crn(crn, t_end=5000, n_points=200)

    fig, ax = plt.subplots(figsize=(10, 4))
    plot_napp_style_trajectory(result, crn, "x2", exact_marginal=bp_marginals["x2"], ax=ax)
    fig.savefig(os.path.join(plots_dir, "napp_single_var.png"), dpi=150)
    plt.close(fig)
    print("Saved: napp_single_var.png")

    # Multiple variables
    fig = plot_multiple_variables(result, crn, ["x1", "x2", "x3"],
                                  exact_marginals=bp_marginals,
                                  suptitle="Chain Factor Graph: Belief Trajectories")
    fig.savefig(os.path.join(plots_dir, "napp_multi_var.png"), dpi=150)
    plt.close(fig)
    print("Saved: napp_multi_var.png")

    # κ_r comparison
    fig = plot_kappa_r_comparison(fg, bp_marginals, kappa_r_values=[0.1, 0.01],
                                  var_to_plot="x2",
                                  output_path=os.path.join(plots_dir, "napp_kappa_comparison.png"))
    plt.close(fig)

    # Reduction comparison
    fig = plot_reduction_comparison(fg, bp_marginals, t_end=5000,
                                    output_path=os.path.join(plots_dir, "napp_reduction_comparison.png"))
    plt.close(fig)

    print(f"\nAll plots saved to {plots_dir}/")
