import matplotlib.pyplot as plt
import numpy as np

def _mark_simplification_events(simplification_gens):
    """Draw a faint dashed red vertical line on the current axes for each generation."""
    for gen in simplification_gens:
        plt.axvline(x=gen, color='r', linestyle='--', alpha=0.3)


def plot_adaptive_scheduling_metrics(scheduler, logbooks):
    """
    Plot metrics and decisions from adaptive scheduling.

    Builds a 2x2 figure (per-island minimum fitness, structural diversity,
    average expression size, improvement rate), marks every recorded
    simplification event with a vertical line in each panel, saves the figure
    to 'adaptive_scheduling_metrics.png' and shows it.

    If the scheduler is falsy or has an empty ``metrics_history``, a message
    is printed and nothing is plotted.

    Args:
        scheduler: AdaptiveSimplificationScheduler instance with history. The
            attributes read here are ``metrics_history`` (sequence of dicts
            with "structural_diversity", "avg_expression_size" and optionally
            "improvement_rate"), ``simplification_history`` (iterable of dicts
            with a "generation" key) and ``thresholds`` (mapping containing
            "diversity_drop_threshold" and "plateau_threshold").
        logbooks: List of DEAP logbooks from islands; each must support
            ``select("gen")`` and ``select("min")``.
    """
    if not scheduler or not scheduler.metrics_history:
        print("No adaptive scheduling data to plot")
        return

    metrics = scheduler.metrics_history

    # Extract generations where simplification occurred
    simplification_gens = [event["generation"] for event in scheduler.simplification_history]

    # Extract metrics over time; the x-axis is the index into metrics_history
    generations = list(range(len(metrics)))
    diversity = [m["structural_diversity"] for m in metrics]
    complexity = [m["avg_expression_size"] for m in metrics]
    # improvement_rate may be absent from a record; treat it as 0 in that case
    improvement = [m.get("improvement_rate", 0) for m in metrics]

    plt.figure(figsize=(12, 8))

    # Plot 1: Minimum fitness across islands
    plt.subplot(2, 2, 1)
    for i, logbook in enumerate(logbooks):
        plt.plot(logbook.select("gen"), logbook.select("min"), label=f"Island {i}")
    _mark_simplification_events(simplification_gens)
    plt.title('Fitness Evolution with Simplification Events')
    plt.xlabel('Generation')
    plt.ylabel('Fitness (MSE)')
    # Without any island curve there is no labelled artist, and legend() would
    # only emit a warning.
    if logbooks:
        plt.legend()

    # Plot 2: Diversity over time
    plt.subplot(2, 2, 2)
    plt.plot(generations, diversity)
    plt.axhline(
        y=scheduler.thresholds["diversity_drop_threshold"],
        color='r',
        linestyle='--',
        label='Threshold',
    )
    _mark_simplification_events(simplification_gens)
    plt.title('Population Diversity Over Time')
    plt.xlabel('Generation')
    plt.ylabel('Structural Diversity')
    # The threshold line carries a label, so a legend is required to show it.
    plt.legend()

    # Plot 3: Complexity over time
    plt.subplot(2, 2, 3)
    plt.plot(generations, complexity)
    _mark_simplification_events(simplification_gens)
    plt.title('Expression Complexity Over Time')
    plt.xlabel('Generation')
    plt.ylabel('Avg Expression Size')

    # Plot 4: Improvement rate with plateau threshold
    plt.subplot(2, 2, 4)
    plt.plot(generations, improvement)
    plt.axhline(
        y=scheduler.thresholds["plateau_threshold"],
        color='r',
        linestyle='--',
        label='Plateau Threshold',
    )
    _mark_simplification_events(simplification_gens)
    plt.title('Improvement Rate Over Time')
    plt.xlabel('Generation')
    plt.ylabel('Fitness Improvement Rate')
    plt.legend()

    plt.tight_layout()
    plt.savefig('adaptive_scheduling_metrics.png')
    plt.show()

def plot_simplification_triggers(scheduler):
    """
    Draw a bar chart of how often each trigger appears in the simplification history.

    The chart is written to 'simplification_triggers.png' and then shown. If
    the scheduler is falsy or its ``simplification_history`` is empty, a
    message is printed and nothing is plotted.

    Args:
        scheduler: AdaptiveSimplificationScheduler instance with history; each
            entry of ``simplification_history`` is indexed with "triggers" and
            the result is iterated to obtain trigger names.
    """
    history = scheduler.simplification_history if scheduler else None
    if not history:
        print("No simplification history to plot")
        return

    # Tally every trigger name; bars appear in first-seen order
    tally = {}
    for record in history:
        for name in record["triggers"]:
            if name in tally:
                tally[name] += 1
            else:
                tally[name] = 1

    labels = list(tally)
    heights = [tally[label] for label in labels]

    # Render the distribution
    plt.figure(figsize=(10, 6))
    plt.bar(labels, heights)
    plt.title('Simplification Trigger Distribution')
    plt.xlabel('Trigger Type')
    plt.ylabel('Count')
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.savefig('simplification_triggers.png')
    plt.show()

def plot_threshold_evolution(scheduler):
    """
    Plot the scheduler's current threshold values as a bar chart.

    Only the thresholds held by the scheduler at call time are drawn (titled
    'Final Adaptive Thresholds'); no per-generation threshold history is read
    here. The chart is written to 'adaptive_thresholds.png' and then shown.
    If the scheduler is falsy or its ``simplification_history`` is empty, a
    message is printed and nothing is plotted.

    Args:
        scheduler: AdaptiveSimplificationScheduler instance with history; its
            ``thresholds`` mapping supplies the bar names and values.
    """
    if not (scheduler and scheduler.simplification_history):
        print("No simplification history to plot")
        return

    # Plotting real threshold evolution would require the scheduler to record
    # a threshold history; only the final values are shown for now.
    final_thresholds = scheduler.thresholds
    bar_names = list(final_thresholds)
    bar_values = [final_thresholds[key] for key in bar_names]

    plt.figure(figsize=(10, 6))
    plt.bar(bar_names, bar_values)
    plt.title('Final Adaptive Thresholds')
    plt.xlabel('Threshold')
    plt.ylabel('Value')
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.savefig('adaptive_thresholds.png')
    plt.show()