import random
import numpy as np
from deap import tools, gp, creator
import copy

def _default_island_strategies(num_islands, pop_per_island):
    """Return the default per-island strategy dicts, keyed by island index.

    Islands cycle through three replacement strategies and three LLM
    simplification focuses; the mutation rate rises and the crossover rate
    falls by 0.05 per island index. With many islands the rates leave [0, 1];
    they are only compared against ``random.random()``, so they saturate at
    "never" / "always" rather than erroring.
    """
    return {
        i: {
            "replacement_strategy": ["tournament", "worst", "random"][i % 3],
            "mutation_rate": 0.1 + (i * 0.05),
            "crossover_rate": 0.7 - (i * 0.05),
            "top_n": max(3, pop_per_island // 10),
            "simplification_focus": ["generalization", "simplicity", "balance"][i % 3],
        }
        for i in range(num_islands)
    }


def _evaluate_invalid(individuals, toolbox):
    """Evaluate every individual whose fitness is invalid; return how many were evaluated."""
    invalid_ind = [ind for ind in individuals if not ind.fitness.valid]
    fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)
    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fitness.values = fit
    return len(invalid_ind)


def _vary_offspring(offspring, toolbox, strategy):
    """Apply pairwise crossover, then per-individual mutation, to ``offspring`` in place.

    Rates come from ``strategy["crossover_rate"]`` / ``strategy["mutation_rate"]``.
    Modified individuals have their fitness deleted so they get re-evaluated.
    """
    for j in range(1, len(offspring), 2):
        if random.random() < strategy["crossover_rate"]:
            offspring[j - 1], offspring[j] = toolbox.mate(offspring[j - 1], offspring[j])
            del offspring[j - 1].fitness.values, offspring[j].fitness.values
    for j in range(len(offspring)):
        if random.random() < strategy["mutation_rate"]:
            offspring[j], = toolbox.mutate(offspring[j])
            del offspring[j].fitness.values


def _is_due(gen, interval):
    """Return True when an every-``interval``-generations event fires at ``gen``.

    A falsy or non-positive interval disables the event instead of raising
    ZeroDivisionError.
    """
    return bool(interval) and interval > 0 and gen % interval == 0


def run_gp_with_island_model(toolbox, pset, num_islands, pop_per_island, ngen, stats,
                             simplifier, migration_interval=10, migration_rate=0.1,
                             simplification_interval=5, island_strategies=None,
                             hall_of_fame=None, verbose=True):
    """
    Run GP evolution with an island model and LLM simplification.

    Each generation, every island independently performs selection, crossover
    and mutation with its own rates and evaluates the new individuals. Every
    ``simplification_interval`` generations, a random sample of the island's
    offspring is passed to ``simplifier.batch_simplify`` and the results are
    spliced back in. Every ``migration_interval`` generations the islands
    exchange individuals along a ring (see ``perform_migration``).

    Args:
        toolbox: DEAP toolbox (uses population, select, clone, mate, mutate,
            evaluate and map)
        pset: Primitive set, forwarded to the simplifier
        num_islands: Number of islands (subpopulations)
        pop_per_island: Size of each island population
        ngen: Number of generations
        stats: Statistics object, or None to skip statistics
        simplifier: BatchLLMSimplifier instance (must provide ``batch_simplify``)
        migration_interval: Migrate every this many generations; 0/None disables migration
        migration_rate: Proportion of population to migrate
        simplification_interval: Simplify every this many generations; 0/None disables it
        island_strategies: Dict mapping island index to a strategy dict with
            "crossover_rate", "mutation_rate", "replacement_strategy" and,
            optionally, "top_n" and "simplification_focus". None uses varied defaults.
        hall_of_fame: Optional global hall of fame, updated with the initial
            populations and after every generation
        verbose: Whether to print progress

    Returns:
        tuple: (final_populations, logbooks, best_individual)
    """
    # Per-island statistics logbooks.
    logbooks = [tools.Logbook() for _ in range(num_islands)]
    for logbook in logbooks:
        logbook.header = ['gen', 'nevals'] + (stats.fields if stats else [])

    if island_strategies is None:
        island_strategies = _default_island_strategies(num_islands, pop_per_island)

    islands = [toolbox.population(n=pop_per_island) for _ in range(num_islands)]

    # Per-island elites. NOTE(review): maintained but not returned to the caller.
    island_hofs = [tools.HallOfFame(3) for _ in range(num_islands)]

    # Evaluate and log the initial populations.
    for i, pop in enumerate(islands):
        nevals = _evaluate_invalid(pop, toolbox)
        island_hofs[i].update(pop)
        record = stats.compile(pop) if stats else {}
        logbooks[i].record(gen=0, nevals=nevals, **record)
        if verbose:
            print(f"Island {i} - Initial stats: {logbooks[i].stream}")

    # Seed the global hall of fame with generation 0; otherwise an initial
    # elite lost to selection/variation before generation 1 is never recorded.
    if hall_of_fame is not None:
        for pop in islands:
            hall_of_fame.update(pop)

    for gen in range(1, ngen + 1):
        if verbose:
            print(f"\n--- Generation {gen} ---")

        for i, pop in enumerate(islands):
            strategy = island_strategies[i]

            # Select, clone and vary the next generation with island-specific rates.
            offspring = list(map(toolbox.clone, toolbox.select(pop, len(pop))))
            _vary_offspring(offspring, toolbox, strategy)
            # Counts GP evaluations only; evaluations triggered by simplification
            # are not included in the logged ``nevals``.
            nevals = _evaluate_invalid(offspring, toolbox)

            if _is_due(gen, simplification_interval):
                simplification_focus = strategy.get("simplification_focus", "balance")
                top_n = strategy.get("top_n", pop_per_island // 10)
                # Despite the "top" naming, candidates are drawn uniformly at
                # random *with replacement* (tools.selRandom); duplicates are possible.
                top_individuals = tools.selRandom(offspring, top_n)

                # Nothing to simplify (e.g. top_n == 0): skip the LLM round-trip.
                if top_individuals:
                    if verbose:
                        print(f"Island {i} - Simplifying {len(top_individuals)} individuals with focus on {simplification_focus}")

                    simplified = simplifier.batch_simplify(
                        top_individuals,
                        pset,
                        toolbox,
                        batch_size=min(len(top_individuals), 20),
                        simplification_focus=simplification_focus
                    )

                    if simplified:
                        offspring = integrate_simplified_individuals(
                            offspring, simplified,
                            toolbox, strategy["replacement_strategy"],
                            replacement_rate=0.4
                        )

            # Replace in place so any outside reference to the island list stays valid.
            islands[i][:] = offspring
            island_hofs[i].update(islands[i])

            record = stats.compile(islands[i]) if stats else {}
            logbooks[i].record(gen=gen, nevals=nevals, **record)
            if verbose:
                print(f"Island {i} - Gen {gen}: {logbooks[i].stream}")

        if _is_due(gen, migration_interval):
            perform_migration(islands, migration_rate, toolbox)
            if verbose:
                print(f"Migration performed at generation {gen}")

        if hall_of_fame is not None:
            for pop in islands:
                hall_of_fame.update(pop)

    # Best individual across all islands (weight-aware via tools.selBest).
    all_individuals = [ind for pop in islands for ind in pop]
    best_individual = tools.selBest(all_individuals, 1)[0]

    return islands, logbooks, best_individual

def perform_migration(islands, migration_rate, toolbox):
    """
    Migrate individuals between islands using a ring topology.

    Island ``i`` sends clones of its best individuals to island
    ``(i + 1) % len(islands)``, where they overwrite that island's worst
    individuals. Islands are modified in place.

    Individuals are ranked by DEAP's weighted fitness (``fitness.wvalues``),
    so "best" is correct for both minimisation and maximisation. For
    multi-objective fitness the ranking is lexicographic over the objectives.

    Args:
        islands: List of island populations (mutable sequences), modified in place
        migration_rate: Proportion of population to migrate. At least one
            migrant is sent, and the count is based on ``len(islands[0])``.
        toolbox: DEAP toolbox, used for ``clone``
    """
    num_islands = len(islands)
    if num_islands < 2:
        return

    migrants_per_island = max(1, int(len(islands[0]) * migration_rate))

    def weighted(ind):
        return ind.fitness.wvalues

    # Snapshot every island's emigrants (as clones) before any replacement, so
    # individuals received this round are not immediately forwarded onward.
    # Each emigrant list goes to exactly one destination, so one clone suffices.
    migrants = [
        [toolbox.clone(ind)
         for ind in sorted(pop, key=weighted, reverse=True)[:migrants_per_island]]
        for pop in islands
    ]

    for source, emigrants in enumerate(migrants):
        dest_pop = islands[(source + 1) % num_islands]
        # Lowest weighted fitness first == worst first.
        worst_first = sorted(range(len(dest_pop)),
                             key=lambda idx: dest_pop[idx].fitness.wvalues)
        # zip() copes with islands of unequal size (fewer slots or fewer migrants).
        for slot, migrant in zip(worst_first, emigrants):
            dest_pop[slot] = migrant

def integrate_simplified_individuals(population, simplified_individuals, toolbox,
                                     replacement_strategy="worst", replacement_rate=0.1,
                                     verbose=False):
    """
    Integrate simplified individuals into the population.

    Up to ``max(1, int(len(population) * replacement_rate))`` slots are
    overwritten. The count never exceeds the population size or the number of
    usable simplified individuals. Individuals are ranked by DEAP's weighted
    fitness (``fitness.wvalues``), so "best"/"worst" are correct for both
    minimisation and maximisation.

    Args:
        population: Current population; not mutated (a shallow copy is returned)
        simplified_individuals: List of simplified individuals to integrate.
            Empty/None entries are skipped without consuming a replacement slot.
        toolbox: DEAP toolbox. ``toolbox.evaluate`` scores any inserted
            simplified individual whose fitness is not yet valid.
        replacement_strategy: Which individuals to replace:
            - "worst" (default)
            - "best"
            - "random"
            - "tournament": losers of size-3 tournaments
            Any other value falls back to "worst".
        replacement_rate: Fraction of the population to potentially replace
        verbose: If True, print each replacement pair and each skipped entry

    Returns:
        A new list with the chosen slots replaced. If ``simplified_individuals``
        is empty, the input list itself is returned.
    """
    if not simplified_individuals:
        return population

    # Drop empty/missing results up front so they don't waste replacement slots.
    candidates = []
    for ind in simplified_individuals:
        if ind:
            candidates.append(ind)
        elif verbose:
            print("Skipping empty or invalid individual")

    # Shallow copy: the caller's list is left untouched, but individuals are
    # shared, not cloned.
    pop = population[:]

    # Clamp to len(pop) so random.sample / tournaments never ask for more
    # indices than exist (empty population, or replacement_rate > 1).
    num_to_replace = min(len(candidates),
                         max(1, int(len(pop) * replacement_rate)),
                         len(pop))
    if num_to_replace == 0:
        return pop

    def weighted(k):
        return pop[k].fitness.wvalues

    if replacement_strategy == "best":
        indices_to_replace = sorted(range(len(pop)), key=weighted, reverse=True)[:num_to_replace]
    elif replacement_strategy == "random":
        indices_to_replace = random.sample(range(len(pop)), num_to_replace)
    elif replacement_strategy == "tournament":
        # Each round runs a tournament among not-yet-chosen indices and takes
        # the loser. Losers are removed from the pool, so every round yields a
        # new index and exactly num_to_replace rounds are needed.
        indices_to_replace = []
        available = list(range(len(pop)))
        for _ in range(num_to_replace):
            participants = random.sample(available, min(3, len(available)))
            loser = min(participants, key=weighted)
            indices_to_replace.append(loser)
            available.remove(loser)
    else:
        # "worst" and any unrecognised strategy: lowest weighted fitness first.
        indices_to_replace = sorted(range(len(pop)), key=weighted)[:num_to_replace]

    for idx, simplified_ind in zip(indices_to_replace, candidates):
        # Ensure inserted individuals carry a fitness, so later fitness-based
        # sorting, statistics and hall-of-fame updates don't break.
        if not simplified_ind.fitness.valid:
            simplified_ind.fitness.values = toolbox.evaluate(simplified_ind)
        if verbose:
            print(f"Simplified individual: {simplified_ind}")
            print(f"Replaced individual: {pop[idx]}")
        pop[idx] = simplified_ind

    return pop