"""
ICML-Ready Plotting for SP-B Reduction Benchmarks

Publication-quality figures designed for quick comprehension:
1. Headline Gains - Big numbers showing key wins
2. Simulation Speedup - Clean scatter (log-log when the data span >10x)
3. Tendril Scaling - Median + IQR showing convergence
4. CRN Trajectory Comparison - Large, readable, blue/orange
Appendix: Size Reduction Scatter - original vs reduced variables/species

Design principles:
- No internal titles (LaTeX captions carry narrative)
- Consistent font sizes (labels: 12-13, ticks: 10-11, legend: 10-11)
- Minimal gridlines (alpha=0.2 or disabled)
- Color-blind friendly palette (blue/orange/gray)
- tight_layout or constrained_layout
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Dict, Any, Optional, Tuple
import csv
import os
import sys

# Put the directory *above* this file's folder at the front of sys.path so the
# project-level modules imported inside plot_crn_trajectory_comparison
# (core, crn, reduction, inference) resolve when this script is run directly.
# NOTE(review): assumes this file lives one level below the project root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# =============================================================================
# Color Palette (color-blind friendly)
# =============================================================================
# These hex codes are matplotlib's default "tab10" colors. Each color keeps one
# role so the same quantity looks the same in every figure.
BLUE = "#1f77b4"      # Original graph/CRN; also speedup markers and numbers
ORANGE = "#ff7f0e"    # Reduced graph/CRN; also headline reduction/diff numbers
RED = "#d62728"       # Thresholds/reference; flags a large marginal difference
GRAY = "#7f7f7f"      # Diagonals/annotations, captions, "N/A" placeholders
BLACK = "#000000"     # Target lines (e.g. the BP target marginal)

# Font sizes (matplotlib points)
LABEL_SIZE = 12          # axis labels
TICK_SIZE = 10           # tick labels and small in-plot time annotations
LEGEND_SIZE = 10         # legend text and the tendril asymptote label
BIG_NUMBER_SIZE = 48     # headline numbers in the headline-gains figure
METRIC_LABEL_SIZE = 13   # captions printed under the headline numbers


def load_results_from_csv(filename: str) -> List[Dict[str, Any]]:
    """Load benchmark results from a CSV file.

    Every column except ``name`` is parsed as a number. Integer-looking
    values become ``int`` and other numeric values become ``float``.
    Empty, missing, non-numeric, or non-finite (``inf``/``nan``) values
    become ``float('nan')``, so downstream filters such as ``x > 0`` and
    ``np.isfinite(x)`` drop them uniformly.

    Args:
        filename: Path to the CSV file; the first row must be a header.

    Returns:
        One dict per data row, keyed by the header column names.
    """
    def _parse_number(raw: Any) -> Any:
        # Missing cells arrive as None (short rows) or '' (empty fields).
        if raw is None or raw == '':
            return float('nan')
        try:
            return int(raw)
        except (ValueError, TypeError):
            pass
        try:
            value = float(raw)
        except (ValueError, TypeError):
            return float('nan')
        # Infinities would produce infinite axis limits / speedups downstream,
        # so they are treated exactly like missing data.
        return value if np.isfinite(value) else float('nan')

    results = []
    # newline='' is what the csv module expects so quoted fields that contain
    # newlines are parsed correctly.
    with open(filename, 'r', newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            for key, raw in row.items():
                if key != 'name':
                    row[key] = _parse_number(raw)
            results.append(row)
    return results


def categorize_results(results: List[Dict]) -> Dict[str, List[Dict]]:
    """Bucket benchmark rows by graph family, judged by the ``name`` prefix.

    Families are tried in the order chain, tree, loopy, grid, random. A row
    goes into the first family its ``name`` starts with. Rows that match no
    family are dropped. Every family key is present in the result, even
    when its list is empty.
    """
    families = ('chain', 'tree', 'loopy', 'grid', 'random')
    buckets: Dict[str, List[Dict]] = {family: [] for family in families}

    for row in results:
        label = row['name']
        matched = next((fam for fam in families if label.startswith(fam)), None)
        if matched is not None:
            buckets[matched].append(row)

    return buckets


# =============================================================================
# Figure 1: Headline Gains (Big Numbers Only)
# =============================================================================

def _headline_positive_ratios(rows: List[Dict], num_key: str,
                              den_key: str) -> List[float]:
    """Collect finite, positive ``row[num_key] / row[den_key]`` where both values are > 0."""
    ratios = []
    for r in rows:
        num = r.get(num_key, 0)
        den = r.get(den_key, 0)
        # NaN compares False against 0, so NaN timings are skipped here too.
        if num > 0 and den > 0:
            ratio = num / den
            if np.isfinite(ratio) and ratio > 0:
                ratios.append(ratio)
    return ratios


def _draw_headline_metric(ax, value_text: str, caption: str, color: str,
                          value_size: float):
    """Render one headline panel: a big bold value above a gray caption, no axes."""
    ax.text(0.5, 0.55, value_text,
            fontsize=value_size, ha='center', va='center',
            color=color, fontweight='bold')
    ax.text(0.5, 0.15, caption,
            fontsize=METRIC_LABEL_SIZE, ha='center', va='center', color=GRAY)
    ax.axis('off')


def plot_headline_gains(results: List[Dict], output_dir: str):
    """
    Create clean headline gains figure with just 4 large numbers.

    Shows (left to right):
    - Average variable reduction, ``1 - mean(reduced_vars / orig_vars)`` (%)
    - Median compile speedup, ``orig_compile_time / reduced_compile_time`` (x)
    - Median simulation speedup, ``orig_sim_time / reduced_sim_time`` (x)
    - Max marginal difference (``marginal_max_diff`` column)

    Only rows with ``orig_vars > 0`` contribute. NaN or missing values are
    skipped per metric. A metric with no usable rows is shown as a gray
    "N/A" instead of a made-up default (0%, 1.0x, or a perfect difference).
    Measured values are always shown, including slowdowns below 1x.

    Writes ``headline_gains.png`` (200 dpi) and ``headline_gains.pdf`` to
    ``output_dir``.
    """
    na_text = "N/A"
    valid = [r for r in results if r.get('orig_vars', 0) > 0]

    # Variable reduction: skip rows whose reduced count is missing/NaN so a
    # single bad row cannot turn the headline into "nan%".
    var_ratios = []
    for r in valid:
        ratio = r.get('reduced_vars', float('nan')) / r['orig_vars']
        if np.isfinite(ratio):
            var_ratios.append(ratio)

    compile_speedups = _headline_positive_ratios(
        valid, 'orig_compile_time', 'reduced_compile_time')
    sim_speedups = _headline_positive_ratios(
        valid, 'orig_sim_time', 'reduced_sim_time')

    # Correctness (marginal differences)
    diffs = [d for d in (r.get('marginal_max_diff', float('nan')) for r in valid)
             if np.isfinite(d)]

    # Create figure: 1x4 for headline numbers
    fig, axes = plt.subplots(1, 4, figsize=(14, 3.5), constrained_layout=True)

    # Metric 1: Variable Reduction
    if var_ratios:
        text, color = f"{(1 - np.mean(var_ratios)) * 100:.0f}%", ORANGE
    else:
        text, color = na_text, GRAY
    _draw_headline_metric(axes[0], text, "Avg Variable\nReduction",
                          color, BIG_NUMBER_SIZE)

    # Metric 2: Compile Speedup
    if compile_speedups:
        text, color = f"{np.median(compile_speedups):.1f}×", BLUE
    else:
        text, color = na_text, GRAY
    _draw_headline_metric(axes[1], text, "Median Compile\nSpeedup",
                          color, BIG_NUMBER_SIZE)

    # Metric 3: Simulation Speedup. "N/A" means no data; a measured slowdown
    # is reported as-is rather than being hidden behind "N/A".
    if sim_speedups:
        text, color = f"{np.median(sim_speedups):.1f}×", BLUE
    else:
        text, color = na_text, GRAY
    _draw_headline_metric(axes[2], text, "Median Sim\nSpeedup",
                          color, BIG_NUMBER_SIZE)

    # Metric 4: Correctness. Without data, do not claim a perfect difference.
    if not diffs:
        text, color = na_text, GRAY
    else:
        max_diff = max(diffs)
        if max_diff < 1e-10:
            text, color = "<10⁻¹⁰", ORANGE
        elif max_diff < 1e-6:
            text, color = "<10⁻⁶", ORANGE
        else:
            text = f"{max_diff:.0e}"
            color = ORANGE if max_diff < 0.01 else RED
    # Slightly smaller font: these strings are wider than the other numbers.
    _draw_headline_metric(axes[3], text, "Max Marginal\nDifference",
                          color, BIG_NUMBER_SIZE - 8)

    # Save
    fig.savefig(os.path.join(output_dir, 'headline_gains.png'), dpi=200,
                facecolor='white', edgecolor='none')
    fig.savefig(os.path.join(output_dir, 'headline_gains.pdf'),
                facecolor='white', edgecolor='none')
    plt.close(fig)
    print("Saved: headline_gains.png/pdf")


# =============================================================================
# Figure 2: Simulation Speedup (Single Clean Scatter)
# =============================================================================

def plot_simulation_speedup(results: List[Dict], output_dir: str):
    """
    Clean single-panel scatter: original vs reduced simulation time.

    Each point is one benchmark row with strictly positive, finite
    ``orig_sim_time`` and ``reduced_sim_time``. Points below the dashed
    y = x line were sped up by the reduction. Both axes switch to log scale
    when the plotted range spans more than one order of magnitude.

    Writes ``simulation_speedup.png`` (200 dpi) and ``simulation_speedup.pdf``
    to ``output_dir``. If no row qualifies, prints a message and returns
    without saving.
    """
    pairs = []
    for r in results:
        t_orig = r.get('orig_sim_time', 0)
        t_red = r.get('reduced_sim_time', 0)
        # Infinite timings would make the diagonal and axis limits infinite
        # (and break the log-scale test), so require finite as well as positive.
        if t_orig > 0 and t_red > 0 and np.isfinite(t_orig) and np.isfinite(t_red):
            pairs.append((t_orig, t_red))

    if not pairs:
        print("No valid simulation data for speedup plot")
        return

    orig = np.array([p[0] for p in pairs], dtype=float)
    red = np.array([p[1] for p in pairs], dtype=float)

    fig, ax = plt.subplots(1, 1, figsize=(5.8, 4.5), constrained_layout=True)

    # Scatter plot
    ax.scatter(orig, red, alpha=0.7, s=40, color=BLUE, edgecolors='white', linewidth=0.5)

    # Diagonal reference line (no speedup), padded slightly past the data
    min_val = min(orig.min(), red.min()) * 0.8
    max_val = max(orig.max(), red.max()) * 1.2
    ax.plot([min_val, max_val], [min_val, max_val],
            linestyle='--', color=GRAY, linewidth=1.5, label='No speedup')

    # Log-log scale if range spans more than 1 order of magnitude
    if max_val / min_val > 10:
        ax.set_xscale('log')
        ax.set_yscale('log')

    # Labels
    ax.set_xlabel('Original simulation time (s)', fontsize=LABEL_SIZE)
    ax.set_ylabel('Reduced simulation time (s)', fontsize=LABEL_SIZE)
    ax.tick_params(axis='both', labelsize=TICK_SIZE)

    # Light grid
    ax.grid(True, alpha=0.2, linestyle='-', linewidth=0.5)

    # Legend
    ax.legend(fontsize=LEGEND_SIZE, loc='upper left')

    # Save
    fig.savefig(os.path.join(output_dir, 'simulation_speedup.png'), dpi=200,
                facecolor='white', edgecolor='none')
    fig.savefig(os.path.join(output_dir, 'simulation_speedup.pdf'),
                facecolor='white', edgecolor='none')
    plt.close(fig)
    print("Saved: simulation_speedup.png/pdf")


# =============================================================================
# Figure 3: Tendril Scaling (Median + IQR)
# =============================================================================

def _tendril_length_from_name(name: str) -> Optional[int]:
    """Return N from the first ``t<N>`` token of an underscore-separated name, else None.

    Example: ``"loopy_c3_t5"`` -> 5.
    """
    for token in name.split('_'):
        # isdecimal rather than isdigit: int() rejects digit-like chars such as '²'.
        if token.startswith('t') and token[1:].isdecimal():
            return int(token[1:])
    return None


def _tendril_speedups_by_length(rows: List[Dict], orig_key: str,
                                reduced_key: str) -> Dict[int, List[float]]:
    """Group finite, positive ``row[orig_key] / row[reduced_key]`` ratios by tendril length.

    Rows are skipped if their name has no parsable tendril length, or if
    either timing is missing or not positive.
    """
    grouped: Dict[int, List[float]] = {}
    for r in rows:
        t_len = _tendril_length_from_name(r['name'])
        if t_len is None:
            continue
        t_orig = r.get(orig_key, 0)
        t_red = r.get(reduced_key, 0)
        if t_orig > 0 and t_red > 0:
            speedup = t_orig / t_red
            if np.isfinite(speedup) and speedup > 0:
                grouped.setdefault(t_len, []).append(speedup)
    return grouped


def plot_tendril_speedup_scaling(results: List[Dict], output_dir: str):
    """
    Show how speedup scales with tendril length.

    Only rows whose name starts with ``loopy`` are used. The tendril length
    is parsed from the first ``t<digits>`` token of the name (e.g.
    ``loopy_c3_t5`` -> 5). The plot shows the median speedup per length with
    an interquartile (25-75%) band.

    Simulation-time speedups are used when any loopy row has them. Otherwise
    compile-time speedups are used, and the y-label says so.

    Writes ``tendril_speedup_scaling.png`` (200 dpi) and ``.pdf`` to
    ``output_dir``. If there is nothing to plot, prints a message and returns.
    """
    loopy = [r for r in results if r['name'].startswith('loopy')]

    if not loopy:
        print("No loopy results for tendril scaling plot")
        return

    # Prefer simulation speedups; fall back to compile speedups only when no
    # loopy row has usable simulation timings.
    tendril_data = _tendril_speedups_by_length(loopy, 'orig_sim_time', 'reduced_sim_time')
    ylabel = 'Speedup factor'
    if not tendril_data:
        tendril_data = _tendril_speedups_by_length(
            loopy, 'orig_compile_time', 'reduced_compile_time')
        ylabel = 'Compile speedup factor'

    if not tendril_data:
        print("No valid tendril speedup data")
        return

    # Median and IQR for each tendril length
    t_lens = sorted(tendril_data)
    medians = [np.median(tendril_data[t]) for t in t_lens]
    q25 = [np.percentile(tendril_data[t], 25) for t in t_lens]
    q75 = [np.percentile(tendril_data[t], 75) for t in t_lens]

    fig, ax = plt.subplots(1, 1, figsize=(5.8, 4.2), constrained_layout=True)

    # Plot median line with markers
    ax.plot(t_lens, medians, 'o-', color=BLUE, linewidth=2, markersize=8, label='Median speedup')

    # IQR band
    ax.fill_between(t_lens, q25, q75, alpha=0.2, color=BLUE, label='IQR')

    # Reference line at 1x
    ax.axhline(y=1.0, color=GRAY, linestyle='--', linewidth=1, alpha=0.7)

    # Asymptote annotation: the median at the longest tendril is taken as the
    # plateau estimate.
    if len(medians) > 1:
        asymptote = medians[-1]
        ax.axhline(y=asymptote, color=ORANGE, linestyle=':', linewidth=1.5, alpha=0.8)
        # Offset in points (not data units) so the label's placement does not
        # depend on the y-scale, and it sits left of the last point rather than
        # past the right edge of the data.
        ax.annotate(f'≈{asymptote:.1f}×', xy=(t_lens[-1], asymptote),
                    xytext=(-6, 6), textcoords='offset points',
                    ha='right', va='bottom',
                    fontsize=LEGEND_SIZE, color=ORANGE)

    # Labels
    ax.set_xlabel('Tendril length', fontsize=LABEL_SIZE)
    ax.set_ylabel(ylabel, fontsize=LABEL_SIZE)
    ax.tick_params(axis='both', labelsize=TICK_SIZE)

    # Light grid
    ax.grid(True, alpha=0.2, linestyle='-', linewidth=0.5)

    # Legend
    ax.legend(fontsize=LEGEND_SIZE, loc='lower right')

    # Save
    fig.savefig(os.path.join(output_dir, 'tendril_speedup_scaling.png'), dpi=200,
                facecolor='white', edgecolor='none')
    fig.savefig(os.path.join(output_dir, 'tendril_speedup_scaling.pdf'),
                facecolor='white', edgecolor='none')
    plt.close(fig)
    print("Saved: tendril_speedup_scaling.png/pdf")


# =============================================================================
# Figure 4: CRN Trajectory Comparison (Large, Readable)
# =============================================================================

def _crn_first_state_probability(sim, var_name: str):
    """Return P(var_name = first state) at each point of ``sim.times``.

    Uses ``sim.get_marginal_trajectory`` when the simulation result has it,
    taking column 0 as the first state. Otherwise the probability is rebuilt
    from the ``Marginal_<var>_1`` / ``Marginal_<var>_2`` species
    concentrations. Missing species count as 0, and time points where both
    are 0 default to 0.5.

    NOTE(review): assumes ``sim.concentrations`` maps species name to a
    per-time sequence aligned with ``sim.times`` — confirm against crn.
    """
    if hasattr(sim, 'get_marginal_trajectory'):
        return sim.get_marginal_trajectory(var_name)[:, 0]

    n_points = len(sim.times)
    conc = sim.concentrations

    def species(name: str) -> np.ndarray:
        # Only the two marginal species are needed, so index them directly
        # instead of copying every species at every time step.
        if name in conc:
            return np.asarray(conc[name], dtype=float)[:n_points]
        return np.zeros(n_points)

    m0 = species(f'Marginal_{var_name}_1')
    m1 = species(f'Marginal_{var_name}_2')
    total = m0 + m1
    prob = np.full(n_points, 0.5)
    np.divide(m0, total, out=prob, where=total > 0)
    return prob


def _crn_first_time_within(times, values, target: float, tol: float):
    """Return the first time whose value is strictly within ``tol`` of ``target``, else None."""
    for t, v in zip(times, values):
        if abs(v - target) < tol:
            return t
    return None


def plot_crn_trajectory_comparison(output_dir: str, sim_time: float = 5000):
    """
    Create large, readable CRN trajectory comparison.

    Builds a fixed 3-variable binary chain (x1 - x2 - x3, with unary factors
    on x1 and x3). It compiles the chain to a CRN both before and after SP-B
    reduction, simulates each for ``sim_time`` units, and plots P(x3 = 0)
    over time:
    - original in blue, reduced in orange;
    - the BP marginal as a dashed black target line;
    - dotted verticals at the first time each trajectory comes within an
      absolute 0.01 of the target.

    Writes ``napp_reduction_comparison_clean.png`` (200 dpi) and ``.pdf``
    to ``output_dir``.

    Raises:
        ImportError: if the project modules (core, crn, reduction, inference)
            cannot be imported.
        RuntimeError: if the reduced poset cannot be converted back into a
            factor graph.
    """
    from core import Variable, Factor, FactorGraph
    from crn import compile_factor_graph_to_crn, simulate_crn
    from reduction import from_factor_graph, reduce_to_core_spb, to_factor_graph_if_possible
    from inference import run_bp

    # Create test graph (chain with endpoints)
    fg = FactorGraph('trajectory_test')
    x1 = fg.add_variable(Variable('x1', [0, 1]))
    x2 = fg.add_variable(Variable('x2', [0, 1]))
    x3 = fg.add_variable(Variable('x3', [0, 1]))

    fg.add_factor(Factor('f12', [x1, x2], np.array([[0.9, 0.1], [0.2, 0.8]])))
    fg.add_factor(Factor('f23', [x2, x3], np.array([[0.7, 0.3], [0.4, 0.6]])))
    fg.add_factor(Factor('u1', [x1], np.array([0.8, 0.2])))
    fg.add_factor(Factor('u3', [x3], np.array([0.3, 0.7])))

    # BP target: entry 0 of x3's marginal, plotted as P(x3 = 0)
    bp_result = run_bp(fg)
    target = bp_result.get_marginal('x3')[0]

    # Compile original CRN
    orig_crn = compile_factor_graph_to_crn(fg, kappa_r=0.02, kappa_prod=50.0)

    # Reduce and compile (reduce_to_core_spb is called for its effect on poset;
    # its return value is not used)
    poset = from_factor_graph(fg)
    reduce_to_core_spb(poset)
    reduced_fg = to_factor_graph_if_possible(poset)
    if reduced_fg is None:
        # NOTE(review): assumes the converter signals "not possible" with None.
        raise RuntimeError("Reduced poset could not be converted back to a factor graph")
    reduced_crn = compile_factor_graph_to_crn(reduced_fg, kappa_r=0.02, kappa_prod=50.0)

    # Simulate both
    orig_sim = simulate_crn(orig_crn, t_end=sim_time, n_points=300)
    reduced_sim = simulate_crn(reduced_crn, t_end=sim_time, n_points=300)

    t_orig = orig_sim.times
    t_red = reduced_sim.times
    p_orig = _crn_first_state_probability(orig_sim, 'x3')
    p_red = _crn_first_state_probability(reduced_sim, 'x3')

    # Create the figure only after all fallible extraction work is done
    fig, ax = plt.subplots(1, 1, figsize=(7, 5), constrained_layout=True)

    # Plot trajectories
    ax.plot(t_orig, p_orig, color=BLUE, linewidth=2.5,
            label=f'Original ({len(orig_crn.species)} species)', alpha=0.9)
    ax.plot(t_red, p_red, color=ORANGE, linewidth=2.5,
            label=f'Reduced ({len(reduced_crn.species)} species)', alpha=0.9)

    # Target marginal
    ax.axhline(y=target, color=BLACK, linestyle='--', linewidth=1.5,
               label=f'BP target: {target:.3f}')

    # Convergence time = first sample within an ABSOLUTE tolerance of 0.01 of
    # the target probability (not a 1% relative tolerance).
    eps = 0.01
    t_conv_orig = _crn_first_time_within(t_orig, p_orig, target, eps)
    t_conv_red = _crn_first_time_within(t_red, p_red, target, eps)

    # Compare against None: a trajectory already on target at t = 0 must
    # still be annotated (0.0 is falsy).
    if t_conv_orig is not None:
        ax.axvline(x=t_conv_orig, color=BLUE, linestyle=':', linewidth=1, alpha=0.7)
        ax.annotate(f't={t_conv_orig:.0f}', xy=(t_conv_orig, 0.52),
                    fontsize=TICK_SIZE, color=BLUE, ha='left')

    if t_conv_red is not None:
        ax.axvline(x=t_conv_red, color=ORANGE, linestyle=':', linewidth=1, alpha=0.7)
        ax.annotate(f't={t_conv_red:.0f}', xy=(t_conv_red, 0.48),
                    fontsize=TICK_SIZE, color=ORANGE, ha='left')

    # Labels
    ax.set_xlabel('Time (simulation units)', fontsize=LABEL_SIZE)
    ax.set_ylabel('P(x₃ = 0)', fontsize=LABEL_SIZE)
    ax.tick_params(axis='both', labelsize=TICK_SIZE)

    # Light grid
    ax.grid(True, alpha=0.2, linestyle='-', linewidth=0.5)

    # Legend
    ax.legend(fontsize=LEGEND_SIZE, loc='upper right')

    # Axis limits: the y-range (and the 0.52 / 0.48 label heights above) were
    # chosen for this particular hard-coded graph; revisit if the factors change.
    ax.set_xlim(0, sim_time)
    ax.set_ylim(0.35, 0.55)

    # Save
    fig.savefig(os.path.join(output_dir, 'napp_reduction_comparison_clean.png'), dpi=200,
                facecolor='white', edgecolor='none')
    fig.savefig(os.path.join(output_dir, 'napp_reduction_comparison_clean.pdf'),
                facecolor='white', edgecolor='none')
    plt.close(fig)
    print("Saved: napp_reduction_comparison_clean.png/pdf")


# =============================================================================
# Appendix Figure: Size Reduction Scatter (for supplementary)
# =============================================================================

def _size_reduction_pairs(rows: List[Dict], orig_key: str,
                          red_key: str) -> List[Tuple[float, float]]:
    """Return (original, reduced) pairs where original > 0 and both values are finite."""
    pairs = []
    for r in rows:
        original = r.get(orig_key, 0)
        reduced = r.get(red_key, float('nan'))
        if original > 0 and np.isfinite(original) and np.isfinite(reduced):
            pairs.append((original, reduced))
    return pairs


def _size_scatter_panel(ax, pairs: List[Tuple[float, float]], color: str,
                        xlabel: str, ylabel: str):
    """Draw one original-vs-reduced scatter with a y = x reference line on ``ax``.

    When ``pairs`` is empty the panel shows a "No data" note. Without this,
    ``max()`` of an empty list would raise and abort the whole plotting run.
    """
    ax.set_xlabel(xlabel, fontsize=LABEL_SIZE)
    ax.set_ylabel(ylabel, fontsize=LABEL_SIZE)
    ax.tick_params(axis='both', labelsize=TICK_SIZE)
    ax.grid(True, alpha=0.2)

    if not pairs:
        ax.text(0.5, 0.5, "No data", transform=ax.transAxes,
                ha='center', va='center', fontsize=LABEL_SIZE, color=GRAY)
        return

    orig = [p[0] for p in pairs]
    red = [p[1] for p in pairs]
    ax.scatter(orig, red, alpha=0.7, s=40, color=color, edgecolors='white', linewidth=0.5)

    # Square limits from 0 so the y = x diagonal sits at 45 degrees
    max_val = max(max(orig), max(red)) * 1.1
    ax.plot([0, max_val], [0, max_val], '--', color=GRAY, linewidth=1.5)
    ax.set_xlim(0, max_val)
    ax.set_ylim(0, max_val)


def plot_size_reduction_scatter(results: List[Dict], output_dir: str):
    """
    Clean scatter plot showing original vs reduced variables/species.
    For appendix/supplementary material.

    Left panel: ``orig_vars`` vs ``reduced_vars``. Right panel:
    ``orig_species`` vs ``reduced_species``. Both use only rows with
    ``orig_vars > 0``. Rows with a missing or NaN value in a panel's columns
    are skipped for that panel. Writes ``size_reduction_scatter.png``
    (200 dpi) and ``.pdf`` to ``output_dir``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(11, 4.5), constrained_layout=True)

    valid = [r for r in results if r.get('orig_vars', 0) > 0]

    # Variables scatter
    _size_scatter_panel(axes[0],
                        _size_reduction_pairs(valid, 'orig_vars', 'reduced_vars'),
                        BLUE, 'Original variables', 'Reduced variables')

    # Species scatter
    _size_scatter_panel(axes[1],
                        _size_reduction_pairs(valid, 'orig_species', 'reduced_species'),
                        ORANGE, 'Original CRN species', 'Reduced CRN species')

    # Save
    fig.savefig(os.path.join(output_dir, 'size_reduction_scatter.png'), dpi=200,
                facecolor='white', edgecolor='none')
    fig.savefig(os.path.join(output_dir, 'size_reduction_scatter.pdf'),
                facecolor='white', edgecolor='none')
    plt.close(fig)
    print("Saved: size_reduction_scatter.png/pdf")


# =============================================================================
# Main Generation Function
# =============================================================================

def generate_icml_plots(csv_file: str, output_dir: str):
    """Produce every ICML figure from one benchmark CSV.

    Creates ``output_dir`` if needed, loads the CSV, then draws the main-paper
    figures followed by the appendix scatter, printing progress along the way.
    A failure in the CRN trajectory figure (which needs the project's
    simulation modules) is reported as a warning instead of aborting the run.
    """
    os.makedirs(output_dir, exist_ok=True)
    divider = "-" * 40

    print(f"\nGenerating ICML-ready plots from {csv_file}")
    print(f"Output directory: {output_dir}\n")

    results = load_results_from_csv(csv_file)
    print(f"Loaded {len(results)} benchmark results\n")

    print("Generating main paper figures:")
    print(divider)

    # Figures 1-3 are driven purely by the loaded results
    for make_figure in (plot_headline_gains,
                        plot_simulation_speedup,
                        plot_tendril_speedup_scaling):
        make_figure(results, output_dir)

    # Figure 4 runs its own CRN simulation; treat its failure as non-fatal
    try:
        plot_crn_trajectory_comparison(output_dir)
    except Exception as err:
        print(f"Warning: Could not generate trajectory plot: {err}")

    print("\nGenerating appendix figures:")
    print(divider)
    plot_size_reduction_scatter(results, output_dir)

    print(f"\nAll ICML figures saved to {output_dir}")


if __name__ == "__main__":
    import argparse

    # The defaults keep the original author's project layout; pass paths on the
    # command line to run elsewhere.
    parser = argparse.ArgumentParser(
        description="Generate ICML-ready figures from SP-B reduction benchmark results.")
    parser.add_argument(
        "csv_file", nargs="?",
        default="/home/mauwork/factor_graph_project/results/benchmark_results.csv",
        help="benchmark results CSV (default: %(default)s)")
    parser.add_argument(
        "output_dir", nargs="?",
        default="/home/mauwork/factor_graph_project/results/plots_icml",
        help="directory for the generated figures (default: %(default)s)")
    args = parser.parse_args()

    if os.path.isfile(args.csv_file):
        generate_icml_plots(args.csv_file, args.output_dir)
    else:
        print(f"CSV file not found: {args.csv_file}")
        print("Run benchmark_runner.py first to generate results.")
        # Non-zero exit so calling scripts/pipelines notice the missing input.
        sys.exit(1)
