#!/usr/bin/env python3
"""
EoH evolution for TSP GLS solver heuristics
"""

from __future__ import annotations

import os
import sys
from typing import Callable, List, Dict, Any, Tuple
import numpy as np

# Make two directories importable: the project root (five directory levels
# above this file) and the eoh/src directory inside it, which provides the
# `eoh` module imported below.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *(('..',) * 5))
)
_eoh_src = os.path.join(_project_root, 'eoh', 'src')
# Each directory is pushed to the front of sys.path only when it is not
# already listed; eoh/src is handled last so it ends up ahead of the root.
for _search_dir in (_project_root, _eoh_src):
    if _search_dir not in sys.path:
        sys.path.insert(0, _search_dir)
del _search_dir

from eoh import EVOL
from eoh.utils.getParas import Paras
from .prob import TSPGLSSolverProblem

def _load_top_candidates(
    pops_dir: str,
    top_k: int,
) -> List[Tuple[str, str, Dict[str, Any], float]]:
    """Return the best entries of the newest population file in *pops_dir*.

    Only files named exactly ``population_generation_<N>.json`` are
    considered; the one with the largest ``N`` is loaded. Individuals without
    a non-empty ``code`` field are ignored, the rest are ordered by ascending
    ``objective`` (a missing or null objective ranks as ``inf``), and at most
    *top_k* of them are returned.

    Args:
        pops_dir: Directory expected to contain the population JSON files.
        top_k: Maximum number of candidates to return; values <= 0 yield [].

    Returns:
        A list of ``(code, algorithm, params, score)`` tuples, where
        ``params`` is always an empty dict. Empty when the directory, a
        matching file, or a usable individual is missing.
    """
    import json
    import re

    if not os.path.isdir(pops_dir):
        return []

    # Strict match so unrelated files cannot break the integer conversion.
    name_pattern = re.compile(r"population_generation_(\d+)\.json")
    generations = []
    for file_name in os.listdir(pops_dir):
        match = name_pattern.fullmatch(file_name)
        if match:
            generations.append((int(match.group(1)), file_name))
    if not generations:
        return []

    _, newest_file = max(generations)
    with open(os.path.join(pops_dir, newest_file), 'r', encoding='utf-8') as f:
        population = json.load(f)
    if not population:
        return []

    def objective_of(individual):
        # JSON null (or an absent key) cannot be compared with floats, so it
        # is ranked as the worst possible value.
        value = individual.get('objective')
        return float('inf') if value is None else value

    # Lower objective is better (minimization); sorted() is stable, so ties
    # keep their order from the file.
    usable = sorted(
        (individual for individual in population if individual.get('code')),
        key=objective_of,
    )
    return [
        (
            individual['code'],
            individual.get('algorithm', 'Unknown algorithm'),
            {},
            objective_of(individual),
        )
        for individual in usable[:max(top_k, 0)]
    ]


def run_solver_evolution_with_population(
    config,
    generator_codes: List[str] | None = None,
    generator_ids: List[int] | None = None,
    generator_weights: np.ndarray | None = None,
    n_cities: int | None = None,
    output_dir: str | None = None,
    continue_from_generation: int = 0,
    evolution_context: str | None = None,
    eoh_eval_n_instances: int | None = None,
) -> Tuple[str, str, Dict[str, Any], float] | List[Tuple[str, str, Dict[str, Any], float]]:
    """Run EoH evolution to evolve a solver heuristic against given generators.

    After the run, the evolution object is stored in the module-level name
    ``_last_evolution`` so other code can inspect it (the original note says
    a bridge reads time information from it).

    Args:
        config: Configuration object (documented as HeuPSROConfig). Supplies
            the LLM, evolution and evaluation settings passed to ``Paras``,
            plus ``psro_use_top_k`` / ``psro_top_k`` which select how many
            results are returned.
        generator_codes: Generator code strings (required).
        generator_ids: Generator IDs (required).
        generator_weights: Generator weights (required).
        n_cities: Number of cities; falls back to ``config.n_cities`` when
            falsy.
        output_dir: Output directory; defaults to ``./results/solver_eoh``.
        continue_from_generation: When > 0, resume from the saved population
            of that generation and run ``solver_n_pop`` further generations.
        evolution_context: Optional context string handed to the problem;
            the context is disabled when this is None.
        eoh_eval_n_instances: Number of evaluation instances; falls back to
            ``config.eoh_eval_n_instances`` when falsy.

    Returns:
        A single ``(code, algorithm, params, score)`` tuple when only one
        result is requested, otherwise a list of such tuples (best first).
        If no usable individual is found in the saved populations, a built-in
        fallback heuristic is evaluated and returned in the same shape.

    Raises:
        ValueError: If any of the three generator arguments is None.
        FileNotFoundError: If ``continue_from_generation`` > 0 but the
            corresponding population file does not exist.
    """
    have_generators = (
        generator_codes is not None
        and generator_ids is not None
        and generator_weights is not None
    )
    if not have_generators:
        raise ValueError("generator_codes, generator_ids, and generator_weights must be provided")

    # Number of results to return, derived from config.
    use_top_k = getattr(config, 'psro_use_top_k', False)
    top_k = getattr(config, 'psro_top_k', 1)
    return_top_k = top_k if use_top_k else 1

    # Use values from config, allow optional override.
    eval_n_instances = eoh_eval_n_instances or config.eoh_eval_n_instances

    problem = TSPGLSSolverProblem(
        config=config,
        generator_codes=generator_codes,
        generator_ids=generator_ids,
        generator_weights=generator_weights,
        n_cities=n_cities or config.n_cities,
        n_inst_eva=eval_n_instances,
    )
    problem.set_evolution_context(
        evolution_context,
        enabled=evolution_context is not None,
    )

    # Resolve the output location once so the resume path below uses the
    # same directory as the evolution run, even when output_dir is None.
    resolved_output_dir = (
        os.path.abspath(output_dir) if output_dir else os.path.abspath("./results/solver_eoh")
    )

    # Get solver_n_pop with backward compatibility (older configs use n_pop).
    solver_n_pop = getattr(config, 'solver_n_pop', getattr(config, 'n_pop', 2))

    # Set parameters - all from config.
    paras = Paras()
    paras.set_paras(
        method="eoh",
        problem=problem,
        llm_api_endpoint=config.llm_api_endpoint,
        llm_api_key=config.llm_api_key,
        llm_model=config.llm_model,
        llm_use_local=config.llm_use_local,
        ec_operators=config.ec_operators,
        ec_pop_size=config.pop_size,
        ec_n_pop=solver_n_pop,
        ec_m=config.ec_m,
        exp_n_proc=config.exp_n_proc,
        exp_debug_mode=config.exp_debug_mode,
        eva_numba_decorator=config.eva_numba_decorator,
        eva_timeout=config.eoh_framework_timeout,
        exp_output_path=resolved_output_dir,
        management=config.eoh_management_strategy,
        llm_use_async=getattr(config, 'llm_use_async', True),
        llm_max_concurrent_requests=getattr(config, 'llm_max_concurrent_requests', 10),
        llm_rate_limit_per_minute=getattr(config, 'llm_rate_limit_per_minute', 60),
        llm_temperature=getattr(config, 'llm_temperature', 0.7),
        llm_top_p=getattr(config, 'llm_top_p', 0.95),
        diversity_threshold=getattr(config, 'eoh_diversity_threshold', 0.8),
        objective_precision=getattr(config, 'eoh_objective_precision', 1),
        max_per_objective=getattr(config, 'eoh_max_per_objective', 1)
    )

    if continue_from_generation > 0:
        # Resume from the population saved for the requested generation.
        continue_path = os.path.join(
            resolved_output_dir,
            "results",
            "pops",
            f"population_generation_{continue_from_generation}.json",
        )
        if not os.path.exists(continue_path):
            # A missing previous generation is an error, not a fresh start.
            raise FileNotFoundError(f"Previous generation {continue_from_generation} not found at {continue_path}")
        paras.exp_use_continue = True
        paras.exp_continue_path = continue_path
        paras.exp_continue_id = continue_from_generation
        # Extend the generation budget so solver_n_pop more generations run.
        paras.ec_n_pop = continue_from_generation + solver_n_pop
    else:
        # First run: no saved population is used.
        paras.exp_use_continue = False
        paras.ec_n_pop = solver_n_pop

    # Initialize and run evolution.
    evolution = EVOL(paras)
    evolution.run()

    # Expose the evolution object at module level on every path (previously
    # this only happened when the fallback solver was used). The return value
    # format is not affected.
    globals()['_last_evolution'] = evolution

    # Collect the best individuals from the last saved generation.
    pops_dir = os.path.join(paras.exp_output_path, "results", "pops")
    candidates = _load_top_candidates(pops_dir, return_top_k)
    if candidates:
        # Single tuple for backward compatibility, or list if k > 1.
        return candidates[0] if return_top_k == 1 else candidates

    # Fallback: return a simple heuristic if no results found.
    print("    Warning: No EoH results found, using fallback solver")
    best_code = (
        "def update_edge_distance(edge_distance, local_opt_tour, edge_n_used):\n"
        "    # Fallback heuristic: penalize frequently used edges\n"
        "    updated = edge_distance.copy()\n"
        "    penalty = 0.1 * edge_n_used\n"
        "    updated = updated + penalty\n"
        "    return updated\n"
    )

    # The fallback has no stored objective, so it is evaluated here.
    fallback_score = problem.evaluate(best_code)
    print(f"    Fallback solver score: {fallback_score:.4f}")
    # Uniform format: (code, algorithm, params, score).
    result = (best_code, "Fallback solver", {}, fallback_score)

    if return_top_k == 1:
        return result
    return [result]  # Return as list for consistency

