from __future__ import annotations

import numpy as np
import hashlib
from typing import List, Optional, Dict, Tuple

from ..shared.batch_eval import batch_evaluate_tasks, evaluate_single_solver_instance, prepare_tasks


class TSPGLSSolverProblem:
    """Evaluate heuristic code using local GLS with Strategy C (separate evaluate + weighted sum).

    Uses unified prepare_tasks function from shared module.

    The instances and oracle costs returned by ``prepare_tasks`` are cached on
    this object and reused by later ``evaluate_batch`` calls for as long as the
    generator/oracle configuration (see ``_build_cache_key``) is unchanged.
    """

    def __init__(
        self,
        config,
        generator_codes: Optional[List[str]] = None,
        generator_ids: Optional[List[int]] = None,
        generator_weights: Optional[np.ndarray] = None,
        n_cities: Optional[int] = None,
        n_inst_eva: Optional[int] = None,
    ):
        """Initialize TSP GLS Solver Problem.

        Args:
            config: HeuPSROConfig (or a compatible config object). Optional
                settings are read with ``getattr`` defaults; ``eoh_eval_n_instances``
                is read directly when ``n_inst_eva`` is not given.
            generator_codes: Generator code strings.
            generator_ids: Generator IDs, aligned index-by-index with the weights.
            generator_weights: Generator weights used for the fitness weighted sum.
            n_cities: Number of cities.
            n_inst_eva: Number of instances; falls back to
                ``config.eoh_eval_n_instances`` when None or 0.
        """
        self.config = config
        # Use ``is not None`` rather than ``or``. The truth value of a
        # multi-element numpy array raises, and a one-element array such as
        # np.array([0]) is falsy, so it would silently be replaced by [].
        self.generator_codes = generator_codes if generator_codes is not None else []
        self.generator_ids = generator_ids if generator_ids is not None else []
        self.generator_weights = generator_weights if generator_weights is not None else np.array([])
        self.n_cities = n_cities

        # Get parameters from config, allowing an optional override.
        self.n_inst_eva = n_inst_eva or config.eoh_eval_n_instances
        # Safely get TSP-specific configuration parameters
        # (compatible with both HeuPSROConfig and TSPGLSConfig).
        self.time_limit = getattr(config, 'instance_solver_time_limit', 60)
        self.ite_max = getattr(config, 'tsp_solver_max_iterations', 1000)
        self.perturbation_moves = getattr(config, 'tsp_solver_perturbation_moves', 1)
        self.use_gap = getattr(config, 'generator_use_gap', True)
        self.gap_oracle = getattr(config, 'oracle_type', 'concorde')
        self.gap_oracle_timeout = getattr(config, 'oracle_timeout', 30)
        self.lkh_runs = getattr(config, 'lkh_runs', None)
        self.debug_mode = getattr(config, 'debug_mode', False)
        self.parallel_backend = getattr(config, 'parallel_backend', 'loky')
        self.parallel_prefer = getattr(config, 'parallel_prefer', 'processes')
        self.parallel_n_jobs = getattr(config, 'eval_n_jobs', -1)

        # Cache for instances and oracle costs (within the same PSRO round).
        self._cached_instances = None  # instances per generator, as returned by prepare_tasks
        self._cached_oracle_costs = None  # oracle costs, as returned by prepare_tasks
        self._cache_key = None  # MD5 hex digest of the generator/oracle configuration

        # Add the prompts attribute that EoH might expect.
        from .prompts import GetPrompts
        self.prompts = GetPrompts()

    def set_evolution_context(self, context: Optional[str] = None, enabled: bool = True) -> None:
        """Set the evolution context used for the PSRO-level task description.

        Does nothing if no ``prompts`` object is attached.

        Args:
            context: Context string describing the mixed generator distribution.
            enabled: Whether to use the context in prompts.
        """
        if hasattr(self, 'prompts') and self.prompts is not None:
            self.prompts.set_context(context, enabled)

    def evaluate(self, code_string: str, n_instances: Optional[int] = None) -> float:
        """Evaluate a single solver (backward compatibility).

        Args:
            code_string: Solver code string.
            n_instances: Accepted for backward compatibility but ignored; the
                instance count is always ``self.n_inst_eva``.

        Returns:
            The fitness value of the solver (see ``evaluate_batch``).
        """
        return self.evaluate_batch([code_string])[0]

    def evaluate_batch(self, code_strings: List[str]) -> List[float]:
        """Evaluate multiple solver codes in one batch.

        Process:
        1. Call the shared ``prepare_tasks`` to build every task, reusing cached
           instances and oracle costs when the configuration is unchanged.
        2. Call the shared batch evaluation, which averages results per
           (solver_id, generator_id) over instances.
        3. Combine the per-generator means into one weighted sum per solver.

        Args:
            code_strings: List of solver code strings [code_0, ..., code_n-1].

        Returns:
            List of fitness values [fitness_0, ..., fitness_n-1], one per
            solver in input order.
        """
        n_solvers = len(code_strings)
        solver_ids = list(range(n_solvers))

        # Step 1: build the tasks (possibly from cache).
        cache_key = self._build_cache_key()
        all_tasks = self._prepare_all_tasks(code_strings, solver_ids, cache_key)

        # Step 2: shared batch evaluation.
        # Timeout = task count x (solver limit + oracle limit + buffer) x safety factor.
        timeout_per_task = self.time_limit + self.gap_oracle_timeout + 10
        batch_timeout = len(all_tasks) * timeout_per_task * 1.5
        results_dict = batch_evaluate_tasks(
            tasks=all_tasks,
            evaluate_fn=evaluate_single_solver_instance,
            n_jobs=self.parallel_n_jobs,
            backend=self.parallel_backend,
            prefer=self.parallel_prefer,
            timeout=batch_timeout,
            debug_mode=self.debug_mode,
            track_time=True,
            time_key="solver",
            task_batch_size=getattr(self.config, 'batch_eval_task_batch_size', None)
        )

        # Step 3: weighted sum over generators.
        return self._aggregate_fitness(results_dict, n_solvers)

    def _build_cache_key(self) -> str:
        """Return an MD5 hex digest identifying the current instance/oracle configuration."""
        # np.asarray + ravel accepts lists, 0-d arrays and N-d arrays alike.
        weights = np.asarray(self.generator_weights).ravel()
        cache_key_data = (
            tuple(self.generator_codes),
            tuple(self.generator_ids),
            tuple(weights.tolist()),
            self.n_inst_eva,
            self.n_cities,
            self.use_gap,
            self.gap_oracle,
            # These are passed to prepare_tasks alongside the oracle settings.
            # Including them means a later change to either attribute forces
            # regeneration instead of reusing previously computed oracle costs.
            self.gap_oracle_timeout,
            self.lkh_runs,
        )
        return hashlib.md5(str(cache_key_data).encode()).hexdigest()

    def _prepare_all_tasks(self, code_strings: List[str], solver_ids: List[int], cache_key: str):
        """Build all evaluation tasks, reusing or refreshing the instance/oracle cache."""
        use_cache = (
            self._cache_key == cache_key
            and self._cached_instances is not None
            and (not self.use_gap or self._cached_oracle_costs is not None)
        )

        if self.debug_mode:
            if use_cache:
                print(f"[SolverProb] use cached instances and oracle (cache_key={cache_key[:8]}...)")
            else:
                print(f"[SolverProb] Generating new instances and oracle (cache_key={cache_key[:8]}...)")

        optimal_parallel_n_jobs = getattr(
            self.config,
            'optimal_parallel_n_jobs',
            getattr(self.config, 'oracle_parallel_n_jobs', -1),
        )
        all_tasks, all_coords_by_gen, oracle_costs_dict = prepare_tasks(
            solver_codes=code_strings,
            solver_ids=solver_ids,
            generator_codes=self.generator_codes,
            generator_ids=self.generator_ids,
            generator_weights=self.generator_weights,
            n_instances=self.n_inst_eva,
            n_cities=self.n_cities,
            time_limit=self.time_limit,
            ite_max=self.ite_max,
            perturbation_moves=self.perturbation_moves,
            use_gap=self.use_gap,
            gap_oracle=self.gap_oracle,  # keep the parameter name expected by prepare_tasks
            oracle_timeout=self.gap_oracle_timeout,
            lkh_runs=self.lkh_runs,
            optimal_parallel_n_jobs=optimal_parallel_n_jobs,
            debug_mode=self.debug_mode,
            # With None, prepare_tasks is expected to generate fresh instances and oracle costs.
            cached_instances=self._cached_instances if use_cache else None,
            cached_oracle_costs=self._cached_oracle_costs if use_cache else None,
            lkh3_path=getattr(self.config, 'lkh3_path', None),
            concorde_path=getattr(self.config, 'concorde_path', None)
        )

        if not use_cache:
            # Remember freshly generated data for subsequent calls.
            self._cached_instances = all_coords_by_gen
            self._cached_oracle_costs = oracle_costs_dict
            self._cache_key = cache_key

        return all_tasks

    def _aggregate_fitness(self, results_dict: Dict[Tuple[int, int], float], n_solvers: int) -> List[float]:
        """Combine per-(solver, generator) mean values into one weighted sum per solver.

        ``results_dict`` maps (solver_id, generator_id) to the mean over
        instances. A missing entry counts as 1e9, and a generator with no
        matching weight gets weight 0.0.
        NOTE(review): with no generators, every solver receives 0.0.
        """
        weights = np.asarray(self.generator_weights).ravel()
        n_weights = len(weights)
        solver_fitnesses = []
        for solver_id in range(n_solvers):
            weighted_sum = 0.0
            for gen_idx, gen_id in enumerate(self.generator_ids):
                mean_gap = results_dict.get((solver_id, gen_id), 1e9)
                weight = weights[gen_idx] if gen_idx < n_weights else 0.0
                weighted_sum += weight * mean_gap
            solver_fitnesses.append(float(weighted_sum))
        return solver_fitnesses


