"""
Plotting Code for SP-B Reduction Benchmarks

Generates publication-quality plots showing:
1. Reduction size gains (variables, factors, species, reactions)
2. Runtime speedups (compile time, simulation time)
3. Correctness preservation (marginal comparison)
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Dict, Any, Optional, Tuple
import csv
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Try to import seaborn for nicer plots
try:
    import seaborn as sns
    sns.set_style("whitegrid")
    HAS_SEABORN = True
except ImportError:
    HAS_SEABORN = False

from benchmarks.benchmark_runner import BenchmarkResult


def _to_float(x):
    """Coerce a raw CSV cell into a float.

    Ints and floats are converted directly. ``None``, blank text, the
    placeholders "nan"/"none"/"n/a"/"na" (any letter case) and text that
    ``float()`` rejects all come back as NaN.
    """
    nan = float('nan')
    if x is None:
        return nan
    if isinstance(x, (int, float)):
        return float(x)

    text = str(x).strip()
    if not text or text.lower() in ("nan", "none", "n/a", "na"):
        return nan

    try:
        value = float(text)
    except ValueError:
        # Text that is not a number is treated the same as a missing cell.
        value = nan
    return value

def _to_int(x):
    """Parse a CSV value to int; return None if missing or non-numeric.

    Ints are returned unchanged. ``None``, blank text and the placeholders
    "nan"/"none"/"n/a"/"na" (any letter case) give ``None``. Integer text is
    parsed exactly; other numeric text such as "1.0" or "1e3" goes through
    ``float`` and is truncated toward zero. Text that is not a number, NaN
    and infinities give ``None`` rather than raising.
    """
    if x is None:
        return None
    if isinstance(x, int):
        return x
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none", "n/a", "na"}:
        return None
    try:
        # Exact path first: a float round-trip would corrupt integers
        # above 2**53.
        return int(s)
    except ValueError:
        pass
    try:
        return int(float(s))  # handles "1.0"
    except (ValueError, OverflowError):
        # ValueError: non-numeric text or NaN; OverflowError: +/-infinity.
        return None

def load_results_from_csv(filename: str) -> List[Dict[str, Any]]:
    """Load benchmark results from a CSV file with robust type coercion.

    Args:
        filename: Path of a CSV file with a header row.

    Returns:
        One dict per data row. ``"name"`` is always a string (empty when the
        cell is missing). Known count/flag columns are parsed with
        ``_to_int`` (``None`` when missing); every other column is parsed
        with ``_to_float`` (NaN when missing or non-numeric).
    """
    # Count and 0/1-flag columns. Built once per call, not once per row.
    int_keys = frozenset({
        "orig_vars", "reduced_vars", "orig_factors", "reduced_factors",
        "orig_edges", "reduced_edges", "orig_species", "reduced_species",
        "orig_reactions", "reduced_reactions", "n_reduction_steps",
        "bp_converged_orig", "bp_converged_reduced",
    })

    results: List[Dict[str, Any]] = []
    # newline='' hands newline handling to the csv module, as its
    # documentation requires for file objects.
    with open(filename, 'r', newline='') as f:
        for row in csv.DictReader(f):
            # A row shorter than the header leaves "name" as None; normalise
            # it so callers can rely on string methods.
            out: Dict[str, Any] = {"name": row.get("name") or ""}

            for k, v in row.items():
                # DictReader collects surplus cells of an over-long row under
                # the key None; they belong to no column, so drop them.
                if k is None or k == "name":
                    continue
                out[k] = _to_int(v) if k in int_keys else _to_float(v)

            results.append(out)
    return results

def categorize_results(results: List[Dict]) -> Dict[str, List[Dict]]:
    """Bucket result records by the graph-family prefix of their name.

    The returned dict always has the keys 'chain', 'tree', 'loopy', 'grid'
    and 'random', in that order. A record is appended to the bucket whose
    key its ``'name'`` starts with; records matching none are left out.
    """
    families = ('chain', 'tree', 'loopy', 'grid', 'random')
    buckets = {family: [] for family in families}

    for record in results:
        label = record['name']
        # First matching prefix wins, in the same order as the keys above.
        match = next((fam for fam in families if label.startswith(fam)), None)
        if match is not None:
            buckets[match].append(record)

    return buckets


# =============================================================================
# Plot Set 1: Reduction Size Gains
# =============================================================================

def plot_size_reduction_bars(results: List[Dict], output_dir: str):
    """
    Plot bar charts showing before/after for variables, factors, species, reactions.

    Draws a 2x2 grid with one original/reduced bar pair per result and a red
    "N%↓" label above each pair, then writes ``size_reduction_bars.png`` and
    ``size_reduction_bars.pdf`` into ``output_dir``.

    A count that is ``None`` or absent from a record is drawn as a gap and
    gets no percentage label, instead of raising on the comparison.

    Args:
        results: Result records; each needs a ``'name'`` string.
        output_dir: Existing directory that receives the two image files.
    """
    def _count(record, key):
        # Missing counts become NaN: bar() leaves a gap and NaN fails the
        # "> 0" test below, so no label is computed from it.
        value = record.get(key)
        return float('nan') if value is None else float(value)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    names = [r['name'][:15] for r in results]  # Truncate long names
    x = np.arange(len(names))
    width = 0.35

    metrics = [
        ('orig_vars', 'reduced_vars', '# Variables', axes[0, 0]),
        ('orig_factors', 'reduced_factors', '# Factors', axes[0, 1]),
        ('orig_species', 'reduced_species', '# CRN Species', axes[1, 0]),
        ('orig_reactions', 'reduced_reactions', '# CRN Reactions', axes[1, 1]),
    ]

    colors = ['#3498db', '#2ecc71']  # Blue for original, green for reduced

    for orig_key, red_key, title, ax in metrics:
        orig_vals = [_count(r, orig_key) for r in results]
        red_vals = [_count(r, red_key) for r in results]

        ax.bar(x - width/2, orig_vals, width, label='Original', color=colors[0], alpha=0.8)
        ax.bar(x + width/2, red_vals, width, label='Reduced', color=colors[1], alpha=0.8)

        ax.set_ylabel('Count')
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha='right', fontsize=8)
        ax.legend()

        # Add reduction percentage labels
        for i, (o, r) in enumerate(zip(orig_vals, red_vals)):
            if o > 0 and np.isfinite(r):
                pct = (1 - r/o) * 100
                ax.annotate(f'{pct:.0f}%↓', xy=(i, max(o, r)),
                            ha='center', va='bottom', fontsize=7, color='red')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'size_reduction_bars.png'), dpi=150)
    fig.savefig(os.path.join(output_dir, 'size_reduction_bars.pdf'))
    plt.close(fig)
    print("Saved: size_reduction_bars.png/pdf")


def plot_reduction_ratio_by_family(results: List[Dict], output_dir: str):
    """
    Plot reduction ratios grouped by graph family.

    Two box plots side by side, variables on the left and CRN species on the
    right, with one box per non-empty family from ``categorize_results``.
    A ratio is ``reduced / original``, or 1 when the original count is not
    positive. Writes ``reduction_ratio_by_family.png`` and ``.pdf`` into
    ``output_dir``.

    Records whose counts are ``None`` or absent are skipped for that panel.
    If no record belongs to a known family, a message is printed and nothing
    is written.

    Args:
        results: Result records with ``'name'`` and the count columns.
        output_dir: Existing directory that receives the two image files.
    """
    def _ratios(records, reduced_key, orig_key):
        # reduced/original for each record that has both counts.
        out = []
        for r in records:
            reduced, orig = r.get(reduced_key), r.get(orig_key)
            if reduced is None or orig is None:
                continue
            out.append(reduced / orig if orig > 0 else 1)
        return out

    families = [(family, recs)
                for family, recs in categorize_results(results).items()
                if recs]
    if not families:
        # boxplot() cannot draw zero groups; bail out before making a figure.
        print("No results match a known graph family")
        return

    family_names = [family.capitalize() for family, _ in families]
    colors = plt.cm.Set3(np.linspace(0, 1, len(family_names)))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    panels = [
        (axes[0], 'reduced_vars', 'orig_vars',
         'Variable Reduction by Graph Family'),
        (axes[1], 'reduced_species', 'orig_species',
         'CRN Species Reduction by Graph Family'),
    ]

    for ax, reduced_key, orig_key, title in panels:
        family_data = [_ratios(recs, reduced_key, orig_key) for _, recs in families]

        bp = ax.boxplot(family_data, patch_artist=True)
        # Tick labels are set separately: boxplot's `labels` keyword is
        # deprecated since Matplotlib 3.9 and its replacement does not exist
        # in older releases.
        ax.set_xticklabels(family_names)
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        ax.axhline(y=1.0, color='red', linestyle='--', alpha=0.5, label='No reduction')
        ax.set_ylabel('Reduction Ratio (lower is better)')
        ax.set_title(title)
        ax.set_ylim(0, 1.1)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'reduction_ratio_by_family.png'), dpi=150)
    fig.savefig(os.path.join(output_dir, 'reduction_ratio_by_family.pdf'))
    plt.close(fig)
    print("Saved: reduction_ratio_by_family.png/pdf")


# =============================================================================
# Plot Set 2: Runtime Speedups
# =============================================================================

def plot_compile_time_comparison(results: List[Dict], output_dir: str):
    """
    Plot compilation time before vs after reduction.

    Left panel: scatter of original vs reduced compile time in milliseconds,
    with a dashed diagonal marking "no speedup". Right panel: histogram of
    the speedup factor ``original / reduced`` with its median marked.
    Writes ``compile_time_comparison.png`` and ``.pdf`` into ``output_dir``.

    Only records whose two compile times are both greater than zero are
    used. If none qualify, a message is printed and nothing is written.

    Args:
        results: Result records with ``'orig_compile_time'`` and
            ``'reduced_compile_time'``.
        output_dir: Existing directory that receives the two image files.
    """
    # Filter results with valid timing (NaN fails the comparison as well)
    valid = [r for r in results if r['orig_compile_time'] > 0 and r['reduced_compile_time'] > 0]

    if not valid:
        # max() and median below are undefined on empty data; mirror the
        # guard used by plot_simulation_speedup.
        print("No valid compile timing data")
        return

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Scatter plot: original vs reduced
    ax = axes[0]
    orig = [r['orig_compile_time'] * 1000 for r in valid]  # Convert to ms
    red = [r['reduced_compile_time'] * 1000 for r in valid]

    ax.scatter(orig, red, alpha=0.7, s=50)

    # Add diagonal line (no speedup)
    max_val = max(max(orig), max(red)) * 1.1
    ax.plot([0, max_val], [0, max_val], 'r--', alpha=0.5, label='No speedup')

    ax.set_xlabel('Original Compile Time (ms)')
    ax.set_ylabel('Reduced Compile Time (ms)')
    ax.set_title('CRN Compilation Time')
    ax.legend()
    ax.set_xlim(0, max_val)
    ax.set_ylim(0, max_val)

    # Speedup histogram (the `valid` filter already guarantees a positive
    # denominator)
    ax = axes[1]
    speedups = [r['orig_compile_time'] / r['reduced_compile_time'] for r in valid]
    median_speedup = np.median(speedups)

    ax.hist(speedups, bins=20, alpha=0.7, color='#3498db', edgecolor='black')
    ax.axvline(x=1.0, color='red', linestyle='--', label='No speedup')
    ax.axvline(x=median_speedup, color='green', linestyle='-',
               label=f'Median: {median_speedup:.1f}x')

    ax.set_xlabel('Speedup Factor')
    ax.set_ylabel('Count')
    ax.set_title('Compilation Speedup Distribution')
    ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'compile_time_comparison.png'), dpi=150)
    fig.savefig(os.path.join(output_dir, 'compile_time_comparison.pdf'))
    plt.close(fig)
    print("Saved: compile_time_comparison.png/pdf")


def plot_simulation_speedup(results: List[Dict], output_dir: str):
    """
    Plot how much faster the reduced CRN simulates.

    Uses only records whose original and reduced simulation times are both
    positive; when there are none, prints a notice and returns without
    creating files. Otherwise draws a before/after time scatter and a
    speedup-versus-original-species scatter, and saves them as
    ``simulation_speedup.png`` and ``simulation_speedup.pdf`` in
    ``output_dir``.
    """
    timed = []
    for rec in results:
        if rec['orig_sim_time'] > 0 and rec['reduced_sim_time'] > 0:
            timed.append(rec)

    if len(timed) == 0:
        print("No valid simulation timing data")
        return

    fig, (time_ax, speedup_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # --- Left: reduced time against original time -------------------------
    before = [rec['orig_sim_time'] for rec in timed]
    after = [rec['reduced_sim_time'] for rec in timed]

    time_ax.scatter(before, after, alpha=0.7, s=50)
    diag_end = max(before + after) * 1.1
    # Points under this diagonal simulated faster after reduction.
    time_ax.plot([0, diag_end], [0, diag_end], 'r--', alpha=0.5, label='No speedup')

    time_ax.set_xlabel('Original Simulation Time (s)')
    time_ax.set_ylabel('Reduced Simulation Time (s)')
    time_ax.set_title('CRN Simulation Time')
    time_ax.legend()

    # --- Right: speedup factor against original species count -------------
    n_species = [rec['orig_species'] for rec in timed]
    factors = [rec['orig_sim_time'] / rec['reduced_sim_time'] for rec in timed]

    speedup_ax.scatter(n_species, factors, alpha=0.7, s=50)
    speedup_ax.axhline(y=1.0, color='red', linestyle='--', alpha=0.5, label='No speedup')

    speedup_ax.set_xlabel('Original # Species')
    speedup_ax.set_ylabel('Simulation Speedup')
    speedup_ax.set_title('Simulation Speedup vs Graph Size')
    speedup_ax.legend()

    plt.tight_layout()
    for filename, save_kwargs in (('simulation_speedup.png', {'dpi': 150}),
                                  ('simulation_speedup.pdf', {})):
        plt.savefig(os.path.join(output_dir, filename), **save_kwargs)
    plt.close()
    print("Saved: simulation_speedup.png/pdf")


def plot_tendril_length_scaling(results: List[Dict], output_dir: str):
    """
    Plot reduction ratio and speedup against tendril length for loopy graphs.

    Works on the 'loopy' bucket from ``categorize_results`` and prints a
    notice and returns when it is empty. Core size and tendril length are
    read from ``cN`` / ``tN`` tokens in the underscore-separated name; a
    record without a ``cN`` token is not plotted, and a missing ``tN`` token
    counts as length 0. One line per core size is drawn on each panel and
    the figure is saved as ``tendril_scaling.png`` and ``tendril_scaling.pdf``
    in ``output_dir``.
    """
    loopy = categorize_results(results)['loopy']
    if len(loopy) == 0:
        print("No loopy graph results")
        return

    def _tagged_int(name, tag):
        # Value of the first token shaped <tag><digits>, or None if absent.
        for token in name.split('_'):
            if token.startswith(tag) and token[1:].isdigit():
                return int(token[1:])
        return None

    def _tendril(name):
        length = _tagged_int(name, 't')
        return 0 if length is None else length

    def _var_ratio(rec):
        if rec['orig_vars'] > 0:
            return rec['reduced_vars'] / rec['orig_vars']
        return 1

    def _speedup(rec):
        # Simulation timing is preferred; compile timing is the fallback and
        # 1.0 stands in when neither pair of timings is positive.
        if rec['orig_sim_time'] > 0 and rec['reduced_sim_time'] > 0:
            return rec['orig_sim_time'] / rec['reduced_sim_time']
        if rec['orig_compile_time'] > 0 and rec['reduced_compile_time'] > 0:
            return rec['orig_compile_time'] / rec['reduced_compile_time']
        return 1.0

    fig, (ratio_ax, speed_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Bucket the loopy records by core size.
    by_core = {}
    for rec in loopy:
        core = _tagged_int(rec['name'], 'c')
        if core is not None:
            by_core.setdefault(core, []).append(rec)
    ordered_cores = sorted(by_core)

    # Left panel: variable reduction ratio per tendril length.
    for core in ordered_cores:
        members = by_core[core]
        lengths = [_tendril(rec['name']) for rec in members]
        ratios = [_var_ratio(rec) for rec in members]
        xs, ys = zip(*sorted(zip(lengths, ratios)))  # ascending tendril length
        ratio_ax.plot(xs, ys, 'o-', label=f'Core size {core}', markersize=8)

    ratio_ax.set_xlabel('Tendril Length')
    ratio_ax.set_ylabel('Variable Reduction Ratio')
    ratio_ax.set_title('Reduction vs Tendril Length')
    ratio_ax.legend()
    ratio_ax.set_ylim(0, 1.1)

    # Right panel: speedup factor per tendril length.
    for core in ordered_cores:
        members = by_core[core]
        lengths = [_tendril(rec['name']) for rec in members]
        factors = [_speedup(rec) for rec in members]
        xs, ys = zip(*sorted(zip(lengths, factors)))
        speed_ax.plot(xs, ys, 'o-', label=f'Core size {core}', markersize=8)

    speed_ax.set_xlabel('Tendril Length')
    speed_ax.set_ylabel('Speedup Factor')
    speed_ax.set_title('Speedup vs Tendril Length')
    speed_ax.legend()
    speed_ax.axhline(y=1.0, color='red', linestyle='--', alpha=0.5)

    plt.tight_layout()
    base = os.path.join(output_dir, 'tendril_scaling')
    plt.savefig(base + '.png', dpi=150)
    plt.savefig(base + '.pdf')
    plt.close()
    print("Saved: tendril_scaling.png/pdf")


# =============================================================================
# Plot Set 3: Correctness Preservation
# =============================================================================

def plot_marginal_differences(results: List[Dict], output_dir: str):
    """
    Plot marginal differences to show correctness preservation.

    Left panel: histogram of ``log10(marginal_max_diff)`` over records with a
    finite difference and both ``bp_converged_*`` flags equal to 1;
    non-positive differences are placed at -16. Right panel: per-family box
    plot (log y-axis) of every finite difference, with the 1% line marked.
    Either panel stays empty when it has no data. Writes
    ``marginal_differences.png`` and ``.pdf`` into ``output_dir``.

    Args:
        results: Result records with ``'name'`` and ``'marginal_max_diff'``.
        output_dir: Existing directory that receives the two image files.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Histogram of marginal differences
    ax = axes[0]
    diffs = [r['marginal_max_diff'] for r in results
             if np.isfinite(r.get('marginal_max_diff', np.nan))
             and r.get('bp_converged_orig', 0) == 1
             and r.get('bp_converged_reduced', 0) == 1]
    if diffs:
        # Use log scale for small differences; log10 is undefined at or below
        # zero, so those values are pinned to -16.
        log_diffs = [np.log10(d) if d > 0 else -16 for d in diffs]

        ax.hist(log_diffs, bins=20, alpha=0.7, color='#2ecc71', edgecolor='black')
        ax.set_xlabel('log₁₀(Max Marginal Difference)')
        ax.set_ylabel('Count')
        ax.set_title('Marginal Preservation Quality')
        ax.axvline(x=-2, color='red', linestyle='--', label='1% threshold')
        ax.legend()

    # By category
    ax = axes[1]
    categories = categorize_results(results)

    cat_data = []
    cat_names = []
    for family, family_results in categories.items():
        if family_results:
            diffs = [r['marginal_max_diff'] for r in family_results
                     if np.isfinite(r.get('marginal_max_diff', np.nan))]
            if diffs:
                cat_data.append(diffs)
                cat_names.append(family.capitalize())

    if cat_data:
        bp = ax.boxplot(cat_data, patch_artist=True)
        # Tick labels are set separately: boxplot's `labels` keyword is
        # deprecated since Matplotlib 3.9 and its replacement does not exist
        # in older releases.
        ax.set_xticklabels(cat_names)
        colors = plt.cm.Set3(np.linspace(0, 1, len(cat_names)))
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        ax.set_ylabel('Max Marginal Difference')
        ax.set_title('Correctness by Graph Family')
        ax.set_yscale('log')
        ax.axhline(y=0.01, color='red', linestyle='--', alpha=0.5, label='1% threshold')
        ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'marginal_differences.png'), dpi=150)
    fig.savefig(os.path.join(output_dir, 'marginal_differences.pdf'))
    plt.close(fig)
    print("Saved: marginal_differences.png/pdf")


def plot_summary_dashboard(results: List[Dict], output_dir: str):
    """
    Create a summary dashboard with key metrics.

    Top row: average variable reduction, average compile speedup and the
    largest finite marginal difference as big numbers. Bottom row: scatter
    of original vs reduced variable counts, histogram of compile speedups,
    and a bar chart of the mean marginal difference per graph family.
    Writes ``summary_dashboard.png`` and ``.pdf`` into ``output_dir``.

    Only records with a positive ``orig_vars`` and a present ``reduced_vars``
    count are summarised. If none qualify, a message is printed and nothing
    is written.

    Args:
        results: Result records as produced by ``load_results_from_csv``.
        output_dir: Existing directory that receives the two image files.
    """
    # Counts may be None when the CSV cell was empty; such rows cannot
    # contribute a ratio and would make the comparison raise.
    valid = [r for r in results
             if r.get('orig_vars') is not None
             and r.get('reduced_vars') is not None
             and r['orig_vars'] > 0]

    if not valid:
        # Every statistic and the scatter's max() below need at least one row.
        print("No valid results for summary dashboard")
        return

    # Calculate summary statistics
    avg_var_reduction = np.mean([r['reduced_vars'] / r['orig_vars'] for r in valid])

    # Require both timings to be positive, the same rule as
    # plot_compile_time_comparison. NaN fails the test, so it cannot reach
    # hist(), which rejects non-finite data.
    compile_speedups = [r['orig_compile_time'] / r['reduced_compile_time']
                        for r in valid
                        if r['orig_compile_time'] > 0 and r['reduced_compile_time'] > 0]
    avg_compile_speedup = np.mean(compile_speedups) if compile_speedups else 1

    finite_diffs = [r['marginal_max_diff'] for r in valid
                    if np.isfinite(r.get('marginal_max_diff', np.nan))]
    max_diff = max(finite_diffs) if finite_diffs else np.nan

    fig = plt.figure(figsize=(16, 10))

    # Create subplots
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Top row: Key metrics as big numbers
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.text(0.5, 0.5, f"{(1-avg_var_reduction)*100:.0f}%",
             fontsize=48, ha='center', va='center', color='#2ecc71')
    ax1.text(0.5, 0.15, "Avg Variable\nReduction", fontsize=14, ha='center', va='center')
    ax1.axis('off')

    ax2 = fig.add_subplot(gs[0, 1])
    ax2.text(0.5, 0.5, f"{avg_compile_speedup:.1f}x",
             fontsize=48, ha='center', va='center', color='#3498db')
    ax2.text(0.5, 0.15, "Avg Compile\nSpeedup", fontsize=14, ha='center', va='center')
    ax2.axis('off')

    ax3 = fig.add_subplot(gs[0, 2])
    if np.isfinite(max_diff):
        label = f"{max_diff:.0e}"
    else:
        label = "N/A"
    # Green only when a finite maximum exists and is under the 1% threshold.
    ax3.text(0.5, 0.5, label, fontsize=36, ha='center', va='center',
             color='#2ecc71' if np.isfinite(max_diff) and max_diff < 0.01 else '#e74c3c')
    ax3.text(0.5, 0.15, "Max Marginal\nDifference", fontsize=14, ha='center', va='center')
    ax3.axis('off')

    # Bottom row: Charts
    # Reduction scatter
    ax4 = fig.add_subplot(gs[1, 0])
    orig_vars = [r['orig_vars'] for r in valid]
    red_vars = [r['reduced_vars'] for r in valid]
    ax4.scatter(orig_vars, red_vars, alpha=0.7, s=50, c='#3498db')
    max_v = max(max(orig_vars), max(red_vars)) * 1.1
    ax4.plot([0, max_v], [0, max_v], 'r--', alpha=0.5)
    ax4.set_xlabel('Original Variables')
    ax4.set_ylabel('Reduced Variables')
    ax4.set_title('Variable Reduction')

    # Speedup distribution
    ax5 = fig.add_subplot(gs[1, 1])
    if compile_speedups:
        median_speedup = np.median(compile_speedups)
        ax5.hist(compile_speedups, bins=15, alpha=0.7, color='#3498db', edgecolor='black')
        ax5.axvline(x=median_speedup, color='green', linestyle='-',
                    label=f'Median: {median_speedup:.1f}x')
        ax5.axvline(x=1.0, color='red', linestyle='--', alpha=0.5)
        ax5.set_xlabel('Speedup Factor')
        ax5.set_ylabel('Count')
        ax5.set_title('Compile Speedup Distribution')
        ax5.legend()

    # Correctness by family
    ax6 = fig.add_subplot(gs[1, 2])
    categories = categorize_results(results)
    cat_names = []
    cat_diffs = []
    for family, family_results in categories.items():
        if family_results:
            # Finite and non-negative: the finiteness rule matches
            # finite_diffs above, so one infinite entry cannot turn the
            # family mean into inf.
            diffs = [r['marginal_max_diff'] for r in family_results
                     if np.isfinite(r.get('marginal_max_diff', np.nan))
                     and r['marginal_max_diff'] >= 0]
            if diffs:
                cat_names.append(family[:5].capitalize())
                cat_diffs.append(np.mean(diffs))

    if cat_names:
        ax6.bar(cat_names, cat_diffs, color='#2ecc71', alpha=0.8)
        ax6.axhline(y=0.01, color='red', linestyle='--', alpha=0.5)
        ax6.set_ylabel('Avg Max Marginal Diff')
        ax6.set_title('Correctness by Family')
        ax6.set_yscale('log')

    fig.suptitle('SP-B Reduction + CRN Compilation: Summary Dashboard', fontsize=16, y=0.98)

    fig.savefig(os.path.join(output_dir, 'summary_dashboard.png'), dpi=150)
    fig.savefig(os.path.join(output_dir, 'summary_dashboard.pdf'))
    plt.close(fig)
    print("Saved: summary_dashboard.png/pdf")


def generate_all_plots(csv_file: str, output_dir: str):
    """Load the benchmark CSV and render every plot into ``output_dir``.

    The directory is created when missing. Progress is reported on stdout.
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"\nGenerating plots from {csv_file}")
    print(f"Output directory: {output_dir}\n")

    results = load_results_from_csv(csv_file)
    print(f"Loaded {len(results)} benchmark results\n")

    print("Generating plots...")
    # Run in this order: size gains, runtime speedups, correctness, summary.
    plotters = (
        plot_size_reduction_bars,
        plot_reduction_ratio_by_family,
        plot_compile_time_comparison,
        plot_simulation_speedup,
        plot_tendril_length_scaling,
        plot_marginal_differences,
        plot_summary_dashboard,
    )
    for make_plot in plotters:
        make_plot(results, output_dir)

    print(f"\nAll plots saved to {output_dir}")


if __name__ == "__main__":
    # Usage: python <this script> [csv_file [output_dir]]
    # With no arguments the original fixed locations are used, so existing
    # invocations behave as before.
    cli_args = sys.argv[1:]
    csv_file = (cli_args[0] if len(cli_args) > 0
                else "/home/mauwork/factor_graph_project/results/benchmark_results.csv")
    output_dir = (cli_args[1] if len(cli_args) > 1
                  else "/home/mauwork/factor_graph_project/results/plots")

    if os.path.exists(csv_file):
        generate_all_plots(csv_file, output_dir)
    else:
        print(f"CSV file not found: {csv_file}")
        print("Run benchmark_runner.py first to generate results.")
