import argparse
import math
import operator
import os
import random
import time
from collections import OrderedDict, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import sklearn.model_selection
from deap import algorithms, base, creator, tools, gp
from scipy.stats import pearsonr

# Import our modules
from island_model import run_gp_with_island_model, integrate_simplified_individuals
from llm_simplifier import BatchLLMSimplifier
from adaptive_scheduling import AdaptiveSimplificationScheduler, run_gp_with_island_model_adaptive
from datasets import get_dataset, get_all_dataset_names

def protected_div(left, right):
    """Divide ``left`` by ``right``, yielding 1.0 when the divisor is nearly zero.

    A divisor whose magnitude is below 1e-4 is treated as zero so that
    evolved expressions never raise ZeroDivisionError.
    """
    divisor_is_tiny = abs(right) < 1e-4
    return 1.0 if divisor_is_tiny else left / right

def protected_exp(x):
    """Return e**x with the exponent capped at 10 so the result cannot overflow."""
    # Same selection rule as min(x, 10): the cap is used only when it is
    # strictly smaller than x.
    exponent = 10 if 10 < x else x
    return math.exp(exponent)

def eval_symbolic_regression(individual, points, pset, target_values):
    """
    Evaluate an individual on a regression task.

    The individual is compiled once, called on every entry of ``points`` and
    scored with the mean squared error against ``target_values``.

    The worst possible fitness (``inf``) is returned when the expression
    raises on any point, when there are no points to evaluate, or when the
    resulting error is NaN. Failing points are deliberately not skipped:
    averaging only over the points an expression survives would reward
    expressions that are undefined on most of the data, and such an
    individual would later fail when used for prediction.

    Args:
        individual: GP individual to evaluate
        points: Input points for evaluation. Each point is either a scalar
            (one input variable) or a sequence holding one value per input
            variable.
        pset: Primitive set
        target_values: Pre-computed target values, paired positionally with
            ``points``

    Returns:
        tuple: (mean_squared_error,)
    """
    worst_fitness = (float('inf'),)

    # Transform the tree expression to a callable function
    func = gp.compile(expr=individual, pset=pset)

    try:
        squared_errors = []
        for point, target in zip(points, target_values):
            # Sequence-like points (including length-1 rows of a 2-D input)
            # are unpacked into positional arguments; scalars are passed
            # as the single argument.
            args = tuple(point) if np.ndim(point) > 0 else (point,)
            pred = func(*args)
            squared_errors.append((pred - target) ** 2)

        # If we have no points at all, return a poor fitness
        if not squared_errors:
            return worst_fitness

        mse = np.mean(squared_errors)
    except Exception:
        # Evolved expressions can fail in arbitrary ways (math domain errors,
        # overflow, ...); any failure makes the individual invalid.
        return worst_fitness

    # NaN does not order under comparison, so it must not reach selection.
    if np.isnan(mse):
        return worst_fitness

    # Return mean squared error as fitness (lower is better)
    return (mse,)

def _random_constant():
    """Draw one ephemeral constant uniformly from [-5.0, 5.0]."""
    return random.uniform(-5.0, 5.0)


def setup_gp_system(num_inputs=1):
    """
    Set up the DEAP GP framework.

    (Re)creates the ``FitnessMin`` and ``Individual`` creator types, builds a
    primitive set with arithmetic, trigonometric and protected operators plus
    a random ephemeral constant, and registers tree generation, selection,
    crossover and mutation operators on a fresh toolbox. Crossover and
    mutation are limited to trees of height 17. No ``evaluate`` function is
    registered here; the caller must add one.

    Args:
        num_inputs: Number of input variables. A single input is named ``x``;
            otherwise the inputs are named ``x1`` ... ``xN``.

    Returns:
        tuple: (toolbox, pset)
    """
    # Clear any existing types to avoid conflicts when re-running
    for type_name in ("FitnessMin", "Individual"):
        if type_name in creator.__dict__:
            delattr(creator, type_name)

    # Define fitness and individual types (single objective, minimised)
    creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
    creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

    # Define primitive set
    pset = gp.PrimitiveSet("MAIN", num_inputs)
    pset.addPrimitive(operator.add, 2, name="add")
    pset.addPrimitive(operator.sub, 2, name="sub")
    pset.addPrimitive(operator.mul, 2, name="mul")
    pset.addPrimitive(protected_div, 2, name="div")
    pset.addPrimitive(operator.neg, 1, name="neg")
    pset.addPrimitive(math.sin, 1, name="sin")
    pset.addPrimitive(math.cos, 1, name="cos")
    pset.addPrimitive(protected_exp, 1, name="exp")

    # Add ephemeral constants. The generator is a named module-level function
    # rather than a lambda created on each call: it keeps one identity across
    # repeated calls of this function (DEAP releases that register ephemerals
    # globally by name reject a second registration with a different function
    # object) and, unlike a lambda, it can be pickled.
    pset.addEphemeralConstant("rand101", _random_constant)

    # Rename the arguments
    if num_inputs == 1:
        pset.renameArguments(ARG0='x')
    else:
        pset.renameArguments(**{f'ARG{i}': f'x{i + 1}' for i in range(num_inputs)})

    # Create toolbox
    toolbox = base.Toolbox()

    # Register tree generation functions
    toolbox.register("expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=4)
    toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    # Register genetic operators
    toolbox.register("select", tools.selTournament, tournsize=3)
    toolbox.register("mate", gp.cxOnePoint)
    toolbox.register("expr_mut", gp.genFull, min_=0, max_=2)
    toolbox.register("mutate", gp.mutUniform, expr=toolbox.expr_mut, pset=pset)

    # Set limits on tree size
    height_limit = gp.staticLimit(key=operator.attrgetter("height"), max_value=17)
    toolbox.decorate("mate", height_limit)
    toolbox.decorate("mutate", height_limit)

    return toolbox, pset

def _build_island_strategies(num_islands, pop_per_island):
    """Return the per-island configuration dict, cycling through three presets."""
    presets = (
        # (simplification focus, replacement strategy, mutation rate, crossover rate)
        ("generalization", "worst", 0.2, 0.6),
        ("simplicity", "random", 0.1, 0.7),
        ("balance", "tournament", 0.15, 0.65),
    )
    top_n = max(3, pop_per_island // 5)

    strategies = {}
    for i in range(num_islands):
        focus, strategy, mutation_rate, crossover_rate = presets[i % 3]
        strategies[i] = {
            "simplification_focus": focus,
            "replacement_strategy": strategy,
            "mutation_rate": mutation_rate,
            "crossover_rate": crossover_rate,
            "top_n": top_n
        }
    return strategies


def _print_adaptive_summary(scheduler):
    """Print thresholds, simplification count, trigger breakdown and mean success score."""
    print("\n=== Adaptive Scheduling Summary ===")
    print(f"Final thresholds: {scheduler.thresholds}")
    simplification_count = len(scheduler.simplification_history)
    print(f"Total simplifications: {simplification_count}")
    if simplification_count > 0:
        triggers = {}
        for event in scheduler.simplification_history:
            for trigger in event["triggers"]:
                triggers[trigger] = triggers.get(trigger, 0) + 1
        print(f"Trigger breakdown: {triggers}")

        avg_success = np.mean([event["success_metrics"]["success_score"]
                               for event in scheduler.simplification_history])
        print(f"Average success score: {avg_success:.4f}")


def _plot_experiment_results(target_func_name, best, num_variables, x_test_sorted,
                             y_actual_sorted, y_pred, logbooks, scheduler):
    """Show the adaptive-scheduling plots (if any), the fit plot and the per-island fitness curves."""
    # The adaptive plots need a scheduler; the fixed-interval run has none.
    if scheduler is not None:
        from adaptive_visualization import (
            plot_adaptive_scheduling_metrics,
            plot_simplification_triggers,
            plot_threshold_evolution
        )

        plot_adaptive_scheduling_metrics(scheduler, logbooks)
        plot_simplification_triggers(scheduler)
        plot_threshold_evolution(scheduler)

    plt.figure(figsize=(10, 6))
    if num_variables == 1:
        # For univariate functions: actual and predicted curves over x
        plt.plot(x_test_sorted, y_actual_sorted, 'b-', label='Actual')
        plt.plot(x_test_sorted, y_pred, 'r--', label='Predicted')
        plt.title(f'Symbolic Regression - {target_func_name}\n{best}')
        plt.xlabel('x')
        plt.ylabel('f(x)')
        plt.legend()
    else:
        # For multivariate functions: actual-vs-predicted scatter with the
        # identity line as reference
        plt.scatter(y_actual_sorted, y_pred, alpha=0.5)
        min_val = min(min(y_actual_sorted), min(y_pred))
        max_val = max(max(y_actual_sorted), max(y_pred))
        plt.plot([min_val, max_val], [min_val, max_val], 'r--')
        plt.title(f'Actual vs Predicted - {target_func_name}\n{best}')
        plt.xlabel('Actual Values')
        plt.ylabel('Predicted Values')
    plt.grid(True)
    plt.show()

    # Plot evolution of best fitness across islands
    plt.figure(figsize=(12, 6))
    for i, logbook in enumerate(logbooks):
        plt.plot(logbook.select("gen"), logbook.select("min"), label=f"Island {i}")

    plt.title('Evolution of Minimum Fitness Across Islands')
    plt.xlabel('Generation')
    plt.ylabel('Fitness (MSE)')
    plt.legend()
    plt.grid(True)
    plt.show()


def run_island_experiment(target_func_name,
                          api_key,
                          num_islands=3,
                          pop_per_island=100,
                          num_generations=50,
                          num_points=1000,
                          migration_interval=5,
                          migration_rate=0.1,
                          simplification_interval=3,
                          random_seed=42,
                          display_plots=True,
                          use_adaptive=False,
                          learning_rate=0.1,
                          llm_provider="openai",
                          llm_model="gpt-4",
                          same_prompt=False):
    """
    Run symbolic regression with island model and LLM simplification.

    Loads the named dataset, holds out 20% of it as a test set, evolves
    expressions on the remaining 80% with either the fixed-interval or the
    adaptive island runner, and scores the best individual on the held-out
    data.

    Args:
        target_func_name: Dataset name passed to ``get_dataset``
        api_key: API key handed to ``BatchLLMSimplifier``
        num_islands: Number of islands
        pop_per_island: Population size of each island
        num_generations: Number of generations to evolve
        num_points: Number of data points requested from the dataset
        migration_interval: Passed to the island runner
        migration_rate: Passed to the island runner
        simplification_interval: Passed to the fixed-interval runner only
        random_seed: Seeds ``random``, ``numpy.random``, the dataset and the
            train/test split
        display_plots: Show matplotlib figures when True
        use_adaptive: Use the adaptive runner instead of the fixed-interval one
        learning_rate: Only echoed in the log output.
            NOTE(review): it is not forwarded to
            ``AdaptiveSimplificationScheduler`` - confirm whether it should be.
        llm_provider: Provider name handed to ``BatchLLMSimplifier``
        llm_model: Model name handed to ``BatchLLMSimplifier``
        same_prompt: Forwarded to the adaptive runner only.
            NOTE(review): the fixed-interval runner does not receive it.

    Returns:
        tuple: (best, best_func, mse, corr, scheduler) - the best individual,
        its compiled callable, test-set MSE, test-set Pearson correlation and
        the scheduler returned by the adaptive runner (None for the
        fixed-interval runner).
    """
    # Set random seed
    random.seed(random_seed)
    np.random.seed(random_seed)

    # Get dataset
    X, y = get_dataset(target_func_name, num_points=num_points, random_seed=random_seed)

    # Get number of variables from input shape
    num_variables = X.shape[1] if len(X.shape) > 1 else 1

    # Split into train/test sets
    x_train, x_test, y_train, y_actual = sklearn.model_selection.train_test_split(
        X, y, test_size=0.2, random_state=random_seed
    )

    # Sort the test data by x for better visualization if univariate
    if num_variables == 1:
        x_test_sorted, y_actual_sorted = zip(*sorted(zip(x_test, y_actual)))
    else:
        x_test_sorted, y_actual_sorted = x_test, y_actual

    # Set up DEAP toolbox
    toolbox, pset = setup_gp_system(num_inputs=num_variables)

    # Register the fitness evaluation function
    toolbox.register("evaluate", eval_symbolic_regression,
                     points=x_train, pset=pset, target_values=y_train)

    # Set up statistics
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)

    # Initialize hall of fame
    hof = tools.HallOfFame(5)

    # Create LLM simplifier with provider configuration
    simplifier = BatchLLMSimplifier(
        api_key,
        model=llm_model,
        provider=llm_provider
    )

    # Set problem description for context
    if num_variables == 1:
        context_desc = (f"Symbolic regression for finding a mathematical expression "
                        f"that fits data from the {target_func_name} dataset.")
    else:
        context_desc = (f"Multivariate symbolic regression for finding a mathematical "
                        f"expression that fits data from the {target_func_name} dataset "
                        f"with {num_variables} input variables.")

    simplifier.set_problem_context(context_desc)

    # Define different strategies for each island
    island_strategies = _build_island_strategies(num_islands, pop_per_island)

    print(f"\n=== Running Island Model Experiment on {target_func_name} ===")
    print(f"Number of islands: {num_islands}")
    print(f"Population per island: {pop_per_island}")
    print(f"Total population: {num_islands * pop_per_island}")
    print(f"Generations: {num_generations}")
    print(f"Migration interval: {migration_interval}")

    # Run island model
    if use_adaptive:
        print(f"Using adaptive simplification scheduling (learning rate: {learning_rate})")
        # Initialize adaptive scheduler
        adaptive_scheduler = AdaptiveSimplificationScheduler()

        # Run island model with adaptive scheduling
        islands, logbooks, best, scheduler = run_gp_with_island_model_adaptive(
            toolbox=toolbox,
            pset=pset,
            num_islands=num_islands,
            pop_per_island=pop_per_island,
            ngen=num_generations,
            stats=stats,
            simplifier=simplifier,
            migration_interval=migration_interval,
            migration_rate=migration_rate,
            island_strategies=island_strategies,
            hall_of_fame=hof,
            verbose=True,
            adaptive_scheduler=adaptive_scheduler,
            same_prompt=same_prompt
        )

        # Log adaptive scheduling info
        _print_adaptive_summary(scheduler)
    else:
        print(f"Using fixed simplification interval: {simplification_interval}")
        # Run the original island model
        islands, logbooks, best = run_gp_with_island_model(
            toolbox=toolbox,
            pset=pset,
            num_islands=num_islands,
            pop_per_island=pop_per_island,
            ngen=num_generations,
            stats=stats,
            simplifier=simplifier,
            migration_interval=migration_interval,
            migration_rate=migration_rate,
            simplification_interval=simplification_interval,
            island_strategies=island_strategies,
            hall_of_fame=hof,
            verbose=True
        )
        scheduler = None

    # Extract and compile the best individual
    best_func = gp.compile(best, pset)

    # Generate test predictions. Every point is unpacked into positional
    # arguments: a scalar becomes one argument, a row becomes one argument
    # per variable (this also covers a 2-D input with a single column).
    y_pred = [best_func(*np.atleast_1d(x)) for x in x_test_sorted]

    # Calculate MSE and correlation
    mse = np.mean([(a - p) ** 2 for a, p in zip(y_actual_sorted, y_pred)])
    corr, _ = pearsonr(y_actual_sorted, y_pred)

    print("\n=== Results ===")
    print(f"Best individual: {best}")
    print(f"Length: {len(best)}")
    print(f"Test MSE: {mse:.4e}")
    print(f"Correlation: {corr:.4f}")

    if display_plots:
        _plot_experiment_results(target_func_name, best, num_variables, x_test_sorted,
                                 y_actual_sorted, y_pred, logbooks, scheduler)

    return best, best_func, mse, corr, scheduler

def _append_lines(path, lines):
    """Append ``lines`` to the text file at ``path``."""
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(lines)


def batch_experiment(target_func_names, api_key, num_runs=3, use_adaptive=False, experiment_name="", **kwargs):
    """
    Run multiple experiments and collect statistics.

    Every function in ``target_func_names`` is run ``num_runs`` times through
    ``run_island_experiment`` with seeds 42, 43, ... A run that raises is
    reported on stdout and left out of the statistics. The summary of each
    function is appended to a timestamped text file in the ``results``
    directory (created if missing) as soon as that function is finished.

    Args:
        target_func_names: Iterable of dataset names to run
        api_key: API key forwarded to ``run_island_experiment``
        num_runs: Number of runs per function
        use_adaptive: Forwarded to ``run_island_experiment``; also enables
            collection of the adaptive-scheduling statistics
        experiment_name: Label embedded in the results file name
        **kwargs: Forwarded unchanged to ``run_island_experiment``.
            ``llm_provider`` and ``llm_model`` are additionally read for the
            report and the file name; when omitted, the defaults declared by
            ``run_island_experiment`` ("openai" / "gpt-4") are shown.

    Returns:
        dict: Maps each function name that had at least one successful run to
        a summary with keys ``avg_mse``, ``std_mse``, ``avg_corr``,
        ``avg_length``, ``best_mse``, ``best_expr`` and, for adaptive runs,
        ``avg_simplifications`` plus optionally ``avg_success_score`` and
        ``trigger_counts``.
    """
    results = {}

    # Mirror run_island_experiment's defaults so the report names the model
    # that is actually used when the caller does not pass one.
    llm_provider = kwargs.get("llm_provider", "openai")
    llm_model = kwargs.get("llm_model", "gpt-4")

    os.makedirs("results", exist_ok=True)
    filename = f'results/results_{experiment_name}_{time.strftime("%Y%m%d_%H%M%S")}_{llm_model.replace("/", "_")}.txt'

    # Initialize results file with header. UTF-8 is requested explicitly
    # because the report contains the "±" character.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("ISLAND MODEL EXPERIMENT RESULTS\n")
        f.write("==============================\n\n")
        f.write(f"Using adaptive scheduling: {use_adaptive}\n\n")

    for func_name in target_func_names:
        print(f"\n\n{'='*50}")
        print(f"RUNNING EXPERIMENTS FOR {func_name.upper()}")
        print(f"{'='*50}")

        func_results = defaultdict(list)
        adaptive_stats = defaultdict(list) if use_adaptive else None

        for run in range(num_runs):
            print(f"\nRun {run+1}/{num_runs} for {func_name}")

            try:
                # Run with a different random seed each time
                seed = 42 + run
                best, best_func, mse, corr, scheduler = run_island_experiment(
                    target_func_name=func_name,
                    api_key=api_key,
                    random_seed=seed,
                    use_adaptive=use_adaptive,
                    **kwargs
                )

                # Record results
                func_results["best_expressions"].append(str(best))
                func_results["expression_lengths"].append(len(best))
                func_results["mse"].append(mse)
                func_results["correlation"].append(corr)

                # Record adaptive stats if applicable
                if use_adaptive and scheduler:
                    history = scheduler.simplification_history
                    adaptive_stats["simplification_count"].append(len(history))
                    if history:
                        adaptive_stats["success_scores"].append(
                            np.mean([e["success_metrics"]["success_score"] for e in history])
                        )
                        # Count trigger types
                        trigger_counts = {}
                        for event in history:
                            for trigger in event["triggers"]:
                                trigger_counts[trigger] = trigger_counts.get(trigger, 0) + 1
                        adaptive_stats["triggers"].append(trigger_counts)

                print(f"Run {run+1} completed - MSE: {mse:.4e}, Corr: {corr:.4f}")

            except Exception as e:
                # Best effort: one failing run must not abort the whole batch.
                print(f"Error in run {run+1}: {e}")

        mse_values = func_results["mse"]
        if not mse_values:
            print(f"No valid results for {func_name}")
            # Log failed function in results file
            _append_lines(filename, [
                f"Function: {func_name}\n",
                "No valid results obtained\n\n",
                "------------------------------\n\n",
            ])
            continue

        # Calculate statistics. The best MSE and the best expression are
        # taken from the same run index so they always belong together.
        best_idx = int(np.argmin(mse_values))
        summary = {
            "avg_mse": np.mean(mse_values),
            "std_mse": np.std(mse_values),
            "avg_corr": np.mean(func_results["correlation"]),
            "avg_length": np.mean(func_results["expression_lengths"]),
            "best_mse": mse_values[best_idx],
            "best_expr": func_results["best_expressions"][best_idx]
        }
        results[func_name] = summary

        # Add adaptive statistics if applicable
        has_adaptive_stats = bool(use_adaptive and adaptive_stats)
        if has_adaptive_stats:
            summary["avg_simplifications"] = np.mean(adaptive_stats["simplification_count"])
            if adaptive_stats["success_scores"]:
                summary["avg_success_score"] = np.mean(adaptive_stats["success_scores"])

            # Aggregate trigger counts over all runs
            if adaptive_stats["triggers"]:
                all_triggers = {}
                for t_dict in adaptive_stats["triggers"]:
                    for trigger, count in t_dict.items():
                        all_triggers[trigger] = all_triggers.get(trigger, 0) + count
                summary["trigger_counts"] = all_triggers

        print(f"\nResults for {func_name}:")
        print(f"Average MSE: {summary['avg_mse']:.4e} ± {summary['std_mse']:.4e}")
        print(f"Average correlation: {summary['avg_corr']:.4f}")
        print(f"Average expression length: {summary['avg_length']:.1f}")
        print(f"Best expression: {summary['best_expr']}")
        print(f"Best MSE: {summary['best_mse']:.4e}")

        # Append results for this function to the file
        report = [
            f"LLM Provider: {llm_provider}\n",
            f"LLM Model: {llm_model}\n",
            f"Function: {func_name}\n",
            f"Average MSE: {summary['avg_mse']:.4e} ± {summary['std_mse']:.4e}\n",
            f"Average correlation: {summary['avg_corr']:.4f}\n",
            f"Average expression length: {summary['avg_length']:.1f}\n",
            f"Best expression: {summary['best_expr']}\n",
            f"Best MSE: {summary['best_mse']:.4e}\n",
        ]
        if has_adaptive_stats:
            report.append(f"Average simplifications: {summary.get('avg_simplifications', 0):.1f}\n")
            if "avg_success_score" in summary:
                report.append(f"Average success score: {summary['avg_success_score']:.4f}\n")
            if "trigger_counts" in summary:
                report.append(f"Trigger counts: {summary['trigger_counts']}\n")
        report.append("\n------------------------------\n\n")
        _append_lines(filename, report)

    return results

def main():
    """Main function to parse arguments and run experiments.

    With ``--batch`` every dataset returned by ``get_all_dataset_names`` for
    the chosen ``--function_type`` is run ``--num_runs`` times through
    ``batch_experiment``; otherwise a single experiment is run on
    ``--function``.
    """
    parser = argparse.ArgumentParser(description='Symbolic Regression with Island Model and LLM Simplification')
    parser.add_argument('--api_key', type=str, required=True, help='OpenAI API key')
    parser.add_argument('--function', type=str, default='korns_11', help='Target function name')
    parser.add_argument('--islands', type=int, default=3, help='Number of islands')
    parser.add_argument('--pop_per_island', type=int, default=50, help='Population size per island')
    parser.add_argument('--generations', type=int, default=50, help='Number of generations')
    parser.add_argument('--migration_interval', type=int, default=5, help='Migration interval')
    parser.add_argument('--simplification_interval', type=int, default=5, help='Simplification interval')
    parser.add_argument('--batch', action='store_true', help='Run batch experiments on multiple functions')
    parser.add_argument('--num_runs', type=int, default=3, help='Number of runs for batch experiments')
    parser.add_argument('--no_plot', action='store_true', help='Disable plotting')
    parser.add_argument('--adaptive', action='store_true', help='Use adaptive simplification scheduling')
    parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate for adaptive scheduling')
    parser.add_argument('--llm_provider', type=str, default='openai', choices=['openai', 'openrouter'],
                      help='LLM provider to use')
    parser.add_argument('--llm_model', type=str, default='gpt-4o-mini',
                      help='Model identifier for the chosen provider')
    parser.add_argument('--experiment_name', type=str, default='',
                      help='Name of the experiment')
    parser.add_argument('--same_prompt', action='store_true', help='Use the same prompt for all functions')
    parser.add_argument('--function_type', type=str, default='synthetic', choices=['synthetic', 'real', 'all'],
                      help='Type of functions to run')
    args = parser.parse_args()

    # Settings shared by the batch and the single-run entry points.
    common_settings = dict(
        api_key=args.api_key,
        use_adaptive=args.adaptive,
        learning_rate=args.learning_rate,
        num_islands=args.islands,
        pop_per_island=args.pop_per_island,
        num_generations=args.generations,
        migration_interval=args.migration_interval,
        simplification_interval=args.simplification_interval,
        display_plots=not args.no_plot,
        llm_provider=args.llm_provider,
        llm_model=args.llm_model,
        same_prompt=args.same_prompt,
    )

    if args.batch:
        # Get all available dataset names
        available_datasets = get_all_dataset_names(args.function_type)
        batch_experiment(
            target_func_names=available_datasets,
            num_runs=args.num_runs,
            experiment_name=args.experiment_name,
            **common_settings
        )

        # batch_experiment writes a timestamped file below results/.
        print("\nExperiment results saved to the 'results/' directory")

    else:
        # experiment_name only labels the batch results file;
        # run_island_experiment has no such parameter, so it is not passed.
        run_island_experiment(
            target_func_name=args.function,
            **common_settings
        )

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()