import numpy as np
import copy
from collections import defaultdict
import random
from deap import tools
# Import the necessary functions from the existing island-model module
from island_model import integrate_simplified_individuals, perform_migration

class AdaptiveSimplificationScheduler:
    """
    Manages adaptive scheduling of LLM simplification operations.

    The scheduler records per-generation population metrics, decides when a
    simplification pass should run, and adjusts its own trigger thresholds
    according to how previous simplifications scored. Fitness is treated as a
    value to minimise: the best individual is the one with the lowest first
    fitness value.
    """

    # Minimum average success score for past interventions to count as successful
    _SUCCESS_THRESHOLD = 0.2

    # trigger name -> (threshold key, direction that makes the trigger MORE sensitive).
    # Bloat and diversity loss fire when a value exceeds the threshold, so a
    # lower threshold is more sensitive (-1). A plateau fires when values are
    # below the threshold, so a higher threshold is more sensitive (+1).
    _TRIGGER_ADJUSTMENTS = {
        "bloat_detected": ("complexity_growth_rate", -1),
        "diversity_loss": ("diversity_drop_threshold", -1),
        "fitness_plateau": ("plateau_threshold", +1),
    }

    # Allowed (lower, upper) range for each adaptive threshold
    _THRESHOLD_BOUNDS = {
        "complexity_growth_rate": (0.01, 0.2),
        "diversity_drop_threshold": (0.05, 0.3),
        "plateau_threshold": (0.001, 0.02),
    }

    def __init__(self, initial_thresholds=None):
        """
        Initialize the scheduler with default or provided thresholds.

        Args:
            initial_thresholds: Dictionary of initial threshold values; any
                key given here overrides the corresponding default
        """
        # Default thresholds
        self.thresholds = {
            "min_interval": 3,          # Minimum generations between simplifications
            "max_interval": 10,         # Maximum generations without simplification
            "complexity_growth_rate": 0.05,  # Expression growth rate that triggers simplification
            "min_improvement_rate": 0.01,    # Minimum acceptable fitness improvement rate
            "diversity_drop_threshold": 0.1,  # Diversity drop that triggers action
            "plateau_threshold": 0.005       # Threshold for detecting plateau
        }

        # Update with provided thresholds if any
        if initial_thresholds:
            self.thresholds.update(initial_thresholds)

        # Tracking variables
        self.metrics_history = []
        self.simplification_history = []
        self.last_simplification_gen = 0

        # Snapshot of the islands seen by the previous update_state() call
        self.previous_islands = None

    @staticmethod
    def _flatten(islands):
        """Return all individuals of all islands as one list; reject an empty population."""
        individuals = [ind for pop in islands for ind in pop]
        if not individuals:
            raise ValueError("cannot compute metrics for an empty population")
        return individuals

    @staticmethod
    def _relative_decrease(previous, current):
        """Return the decrease from previous to current relative to |previous| (0 if previous is 0)."""
        if previous == 0:
            return 0
        # abs() keeps the sign meaningful when the minimised value is negative
        return (previous - current) / abs(previous)

    @staticmethod
    def _extract_operators(expression):
        """Return the names that appear directly before a '(' in an expression string."""
        operators = []
        # Every fragment except the last is immediately followed by '(' in the
        # original string, so its trailing token is the name being called.
        for fragment in expression.split('(')[:-1]:
            tokens = fragment.replace(',', ' ').replace(')', ' ').split()
            if tokens:
                operators.append(tokens[-1])
        return operators

    def calculate_metrics(self, islands, logbooks):
        """
        Calculate current evolutionary metrics for decision making.

        Args:
            islands: List of island populations
            logbooks: List of island logbooks (currently not used by the
                calculation)

        Returns:
            Dictionary of metrics with the keys "structural_diversity",
            "avg_expression_size", "avg_fitness", "best_fitness",
            "improvement_rate", "complexity_growth" and "operator_diversity"

        Raises:
            ValueError: If the islands contain no individuals at all
        """
        metrics = {}

        # Combine all individuals across islands
        all_individuals = self._flatten(islands)
        fitness_values = [ind.fitness.values[0] for ind in all_individuals]

        # Structural diversity: proportion of unique expressions
        unique_structures = {str(ind) for ind in all_individuals}
        metrics["structural_diversity"] = len(unique_structures) / len(all_individuals)

        # Average expression size
        metrics["avg_expression_size"] = np.mean([len(ind) for ind in all_individuals])

        # Fitness metrics; the best is the minimum over the whole population, so
        # an individual empty island does not break the computation
        metrics["avg_fitness"] = np.mean(fitness_values)
        metrics["best_fitness"] = min(fitness_values)

        # Fitness improvement rate relative to the previously recorded best
        if self.metrics_history:
            prev_best = self.metrics_history[-1]["best_fitness"]
            improvement = self._relative_decrease(prev_best, metrics["best_fitness"])
            metrics["improvement_rate"] = max(0, improvement)  # Only positive improvements
        else:
            metrics["improvement_rate"] = 0

        # Complexity growth relative to the previous snapshot, if there is one
        metrics["complexity_growth"] = 0
        if self.previous_islands is not None:
            prev_sizes = [len(ind) for pop in self.previous_islands for ind in pop]
            prev_avg_size = np.mean(prev_sizes) if prev_sizes else 0
            if prev_avg_size > 0:
                metrics["complexity_growth"] = (
                    (metrics["avg_expression_size"] - prev_avg_size) / prev_avg_size
                )

        # Operator diversity: unique operator names over all operator occurrences
        # (assumes the "name(arg, ...)" string format of the individuals)
        operators = []
        for ind in all_individuals:
            operators.extend(self._extract_operators(str(ind)))
        metrics["operator_diversity"] = len(set(operators)) / len(operators) if operators else 0

        return metrics

    def should_trigger_simplification(self, current_gen):
        """
        Decide whether to trigger LLM simplification in this generation.

        Args:
            current_gen: Current generation number

        Returns:
            tuple: (should_simplify, trigger_reasons)
        """
        since_last = current_gen - self.last_simplification_gen

        # Enforce minimum interval between simplifications
        if since_last < self.thresholds["min_interval"]:
            return False, []

        triggers = []

        # Only make decisions when we have enough history
        if len(self.metrics_history) >= 3:
            recent = self.metrics_history[-3:]
            oldest, latest = recent[0], recent[-1]

            # Detect bloat: complexity increasing without fitness improvement
            fitness_improving = latest["improvement_rate"] > self.thresholds["min_improvement_rate"]
            complexity_growing = latest["complexity_growth"] > self.thresholds["complexity_growth_rate"]
            if complexity_growing and not fitness_improving:
                triggers.append("bloat_detected")

            # Detect diversity loss across the three most recent records
            diversity_drop = oldest["structural_diversity"] - latest["structural_diversity"]
            if diversity_drop > self.thresholds["diversity_drop_threshold"]:
                triggers.append("diversity_loss")

            # Detect fitness plateau
            plateau = all(m["improvement_rate"] < self.thresholds["plateau_threshold"]
                          for m in recent)
            if plateau and current_gen > 5:  # Don't trigger plateau too early
                triggers.append("fitness_plateau")

        # Always trigger if max interval reached
        if since_last >= self.thresholds["max_interval"]:
            triggers.append("max_interval_reached")

        return len(triggers) > 0, triggers

    def evaluate_simplification_success(self, original_islands, new_islands):
        """
        Evaluate how successful the simplification was.

        Args:
            original_islands: Islands before simplification
            new_islands: Islands after simplification

        Returns:
            Dictionary with "fitness_improvement", "complexity_reduction" and
            the weighted "success_score"

        Raises:
            ValueError: If either set of islands contains no individuals
        """
        original = self._flatten(original_islands)
        updated = self._flatten(new_islands)

        # Relative decrease of the best (lowest) fitness
        orig_best = min(ind.fitness.values[0] for ind in original)
        new_best = min(ind.fitness.values[0] for ind in updated)
        fitness_improvement = self._relative_decrease(orig_best, new_best)

        # Relative decrease of the average expression size
        orig_complexity = np.mean([len(ind) for ind in original])
        new_complexity = np.mean([len(ind) for ind in updated])
        complexity_reduction = self._relative_decrease(orig_complexity, new_complexity)

        return {
            "fitness_improvement": fitness_improvement,
            "complexity_reduction": complexity_reduction,
            # Weighted success score
            "success_score": 0.7 * fitness_improvement + 0.3 * complexity_reduction,
        }

    def update_thresholds(self, learning_rate=0.1):
        """
        Update thresholds based on simplification outcomes.

        Looks at up to the five most recent simplification events, groups
        their success scores by trigger, and for every known trigger with at
        least two scores makes that trigger more sensitive when the average
        score exceeds the success threshold, and less sensitive otherwise.

        Args:
            learning_rate: How quickly to adjust thresholds

        Returns:
            None (updates internal state)
        """
        # Skip if not enough history
        if len(self.simplification_history) < 2:
            return

        # Group the success scores of recent simplifications by trigger
        by_trigger = defaultdict(list)
        for event in self.simplification_history[-5:]:
            score = event["success_metrics"]["success_score"]
            for trigger in event["triggers"]:
                by_trigger[trigger].append(score)

        # Update thresholds for each trigger
        for trigger, scores in by_trigger.items():
            if trigger not in self._TRIGGER_ADJUSTMENTS or len(scores) < 2:
                continue
            key, more_sensitive = self._TRIGGER_ADJUSTMENTS[trigger]
            successful = np.mean(scores) > self._SUCCESS_THRESHOLD
            direction = more_sensitive if successful else -more_sensitive
            self.thresholds[key] *= (1 + direction * learning_rate)

        # Ensure thresholds stay within reasonable bounds
        for key, (lower, upper) in self._THRESHOLD_BOUNDS.items():
            self.thresholds[key] = max(lower, min(upper, self.thresholds[key]))

    def update_state(self, islands, logbooks, current_gen):
        """
        Update the scheduler's internal state with current evolution data.

        Args:
            islands: Current island populations
            logbooks: Current logbooks
            current_gen: Current generation number (currently not used)

        Returns:
            dict: Current metrics
        """
        # Store current metrics
        current_metrics = self.calculate_metrics(islands, logbooks)
        self.metrics_history.append(current_metrics)

        # Store current islands state (for next comparison)
        # We use a shallow copy of each island's list but don't copy individuals
        # This is to avoid excessive memory usage
        self.previous_islands = [pop[:] for pop in islands]

        return current_metrics


# Modified run_gp_with_island_model function with adaptive scheduling
def run_gp_with_island_model_adaptive(toolbox, pset, num_islands, pop_per_island, ngen, stats, 
                                   simplifier, migration_interval=10, migration_rate=0.1,
                                   island_strategies=None, hall_of_fame=None, verbose=True,
                                   adaptive_scheduler=None, same_prompt=False):
    """
    Run GP evolution with an island model and adaptive LLM simplification.

    Each generation every island is bred independently (selection, crossover,
    mutation, evaluation). The adaptive scheduler is then updated and asked
    whether a simplification pass should run; if so, the top individuals of
    each island are sent to the simplifier and the results are integrated
    back into that island. Migration and the global hall of fame update
    follow.

    Args:
        toolbox: DEAP toolbox
        pset: Primitive set
        num_islands: Number of islands (subpopulations)
        pop_per_island: Size of each island population
        ngen: Number of generations
        stats: Statistics object
        simplifier: BatchLLMSimplifier instance
        migration_interval: How often to migrate individuals between islands;
            a falsy value (0 or None) disables migration
        migration_rate: Proportion of population to migrate
        island_strategies: Dict of strategies for each island (can be None for default)
        hall_of_fame: Hall of fame object; updated with the initial
            populations and after every generation
        verbose: Whether to print progress
        adaptive_scheduler: AdaptiveSimplificationScheduler instance (if None, creates new one)
        same_prompt: Forwarded unchanged to simplifier.batch_simplify

    Returns:
        tuple: (final_populations, logbooks, best_individual, scheduler)
    """
    # Initialize logbooks for statistics
    logbooks = [tools.Logbook() for _ in range(num_islands)]
    for logbook in logbooks:
        logbook.header = ['gen', 'nevals'] + (stats.fields if stats else [])

    # Set up different strategies for each island if not provided
    if island_strategies is None:
        island_strategies = _default_island_strategies(num_islands, pop_per_island)

    # Create independent island populations
    islands = [toolbox.population(n=pop_per_island) for _ in range(num_islands)]

    # Initialize adaptive scheduler if not provided
    if adaptive_scheduler is None:
        adaptive_scheduler = AdaptiveSimplificationScheduler()

    # Evaluate all initial populations
    for i, pop in enumerate(islands):
        nevals = _evaluate_invalid(toolbox, pop)

        # Record initial statistics
        record = stats.compile(pop) if stats else {}
        logbooks[i].record(gen=0, nevals=nevals, **record)
        if verbose:
            print(f"Island {i} - Initial stats: {logbooks[i].stream}")

    # The global hall of fame should also see generation 0, otherwise a good
    # initial individual is never recorded (and nothing is recorded if ngen == 0)
    if hall_of_fame is not None:
        for pop in islands:
            hall_of_fame.update(pop)

    # Update scheduler with initial state
    adaptive_scheduler.update_state(islands, logbooks, 0)

    # Begin evolution
    for gen in range(1, ngen + 1):
        if verbose:
            print(f"\n--- Generation {gen} ---")

        # Process each island
        for i, pop in enumerate(islands):
            offspring, nevals = _breed_offspring(toolbox, pop, island_strategies[i])

            # Replace island population with offspring (in place, keeping the list object)
            islands[i][:] = offspring

            # Record statistics
            record = stats.compile(islands[i]) if stats else {}
            logbooks[i].record(gen=gen, nevals=nevals, **record)
            if verbose:
                print(f"Island {i} - Gen {gen}: {logbooks[i].stream}")

        # Update the adaptive scheduler with current state
        adaptive_scheduler.update_state(islands, logbooks, gen)

        # Get decision on whether to perform simplification
        should_simplify, triggers = adaptive_scheduler.should_trigger_simplification(gen)

        # LLM-guided simplification based on adaptive decision
        if should_simplify:
            if verbose:
                print(f"Adaptive simplification triggered at gen {gen}. Reasons: {triggers}")

            # Store pre-simplification islands for evaluation
            pre_simplification_islands = [pop[:] for pop in islands]

            # Perform simplification on each island
            for i, pop in enumerate(islands):
                # Get island-specific parameters
                strategy = island_strategies[i]
                simplification_focus = strategy.get("simplification_focus", "balance")
                top_n = strategy.get("top_n", pop_per_island // 10)

                # Select top individuals for simplification.
                # NOTE(review): duplicate expressions are not filtered out
                # before being sent to the simplifier - confirm this is intended.
                top_individuals = tools.selBest(pop, top_n)

                # Nothing selected (e.g. top_n == 0): skip instead of calling
                # the simplifier with a batch size of 0
                if not top_individuals:
                    continue

                if verbose:
                    print(f"Island {i} - Simplifying {len(top_individuals)} individuals with focus on {simplification_focus}")

                # Simplify with island-specific focus. This is best effort: a
                # failure on one island is reported and the run continues.
                try:
                    simplified = simplifier.batch_simplify(
                        top_individuals, 
                        pset, 
                        toolbox, 
                        batch_size=min(len(top_individuals), 10),
                        simplification_focus=simplification_focus,
                        same_prompt=same_prompt
                    )

                    # Integrate simplified individuals using island-specific strategy
                    if simplified:
                        islands[i][:] = integrate_simplified_individuals(
                            islands[i], simplified, 
                            toolbox, strategy["replacement_strategy"], 
                            replacement_rate=0.1
                        )
                except Exception as e:
                    print(f"Error during simplification: {e}")

            # Evaluate success of this simplification event
            success_metrics = adaptive_scheduler.evaluate_simplification_success(
                pre_simplification_islands, islands)

            # Record the simplification event
            adaptive_scheduler.simplification_history.append({
                "generation": gen,
                "triggers": triggers,
                "success_metrics": success_metrics
            })

            # Update adaptive thresholds based on outcome
            adaptive_scheduler.update_thresholds()

            # Update the "last simplification" counter
            adaptive_scheduler.last_simplification_gen = gen

            if verbose:
                print(f"Simplification success score: {success_metrics['success_score']:.4f}")
                print(f"Updated thresholds: {adaptive_scheduler.thresholds}")

        # Migration between islands (skipped entirely when the interval is 0/None,
        # which would otherwise raise ZeroDivisionError in the modulo)
        if migration_interval and gen % migration_interval == 0:
            perform_migration(islands, migration_rate, toolbox)
            if verbose:
                print(f"Migration performed at generation {gen}")

        # Update global hall of fame
        if hall_of_fame is not None:
            for pop in islands:
                hall_of_fame.update(pop)

    # Find best individual across all islands
    all_individuals = [ind for pop in islands for ind in pop]
    best_individual = tools.selBest(all_individuals, 1)[0]

    return islands, logbooks, best_individual, adaptive_scheduler


def _default_island_strategies(num_islands, pop_per_island):
    """Build the default, deliberately varied strategy dict for each island index."""
    replacement_strategies = ["tournament", "worst", "random"]
    focuses = ["generalization", "simplicity", "balance"]
    return {
        i: {
            "replacement_strategy": replacement_strategies[i % 3],
            # Rates vary per island; clamped so they remain valid
            # probabilities when there are many islands
            "mutation_rate": min(1.0, 0.1 + (i * 0.05)),
            "crossover_rate": max(0.0, 0.7 - (i * 0.05)),
            "top_n": max(3, pop_per_island // 10),
            "simplification_focus": focuses[i % 3],
        } for i in range(num_islands)
    }


def _evaluate_invalid(toolbox, individuals):
    """Evaluate every individual whose fitness is invalid; return how many were evaluated."""
    invalid_ind = [ind for ind in individuals if not ind.fitness.valid]
    fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)
    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fitness.values = fit
    return len(invalid_ind)


def _breed_offspring(toolbox, pop, strategy):
    """Select, clone, cross over, mutate and evaluate one island; return (offspring, nevals)."""
    # Select and clone the next generation individuals
    offspring = list(map(toolbox.clone, toolbox.select(pop, len(pop))))

    # Apply crossover to consecutive pairs with the island-specific rate
    for j in range(1, len(offspring), 2):
        if random.random() < strategy["crossover_rate"]:
            offspring[j-1], offspring[j] = toolbox.mate(offspring[j-1], offspring[j])
            # Children changed, so their fitness must be recomputed
            del offspring[j-1].fitness.values, offspring[j].fitness.values

    # Apply mutation with the island-specific rate
    for j in range(len(offspring)):
        if random.random() < strategy["mutation_rate"]:
            offspring[j], = toolbox.mutate(offspring[j])
            del offspring[j].fitness.values

    # Evaluate new individuals
    nevals = _evaluate_invalid(toolbox, offspring)
    return offspring, nevals