from __future__ import annotations

import numpy as np
from typing import List, Optional

from ..shared.batch_eval import batch_evaluate_tasks, evaluate_single_solver_instance, prepare_tasks

class TSPGLSGeneratorProblem:
    """Fitness problem that scores TSP instance-generator codes against a solver mixture.

    Each candidate generator code is evaluated against every solver heuristic
    in ``heuristics``; the per-solver mean results are combined using the
    normalized mixture weights ``sigma_h``. Because the search minimizes its
    objective while a generator should produce *hard* instances, the weighted
    result is negated before being returned.
    """

    # Fitness assigned to a generator whose evaluation produced no usable
    # result. Positive (i.e. worst under minimization) so that a failing
    # generator can never outrank a working one.
    _FAILURE_FITNESS = 1e9

    def __init__(
        self,
        config,
        heuristics: Optional[List[str]] = None,
        sigma_h: Optional[np.ndarray] = None,
        n_cities: Optional[int] = None,
        n_inst_eva: Optional[int] = None,
    ) -> None:
        """Initialize TSP GLS Generator Problem.

        Args:
            config: HeuPSROConfig. ``eoh_eval_n_instances`` and
                ``generator_use_gap`` are read directly; all other settings
                are read with ``getattr`` and fall back to defaults.
            heuristics: List of solver heuristic codes (defaults to empty).
            sigma_h: Mixture weights over the solvers. Normalized to sum to 1
                when the sum is positive; defaults to ``[1.0]``.
            n_cities: Number of cities, forwarded to task preparation.
            n_inst_eva: Number of instances per evaluation; a falsy value
                falls back to ``config.eoh_eval_n_instances``.
        """
        self.config = config

        self.heuristics = heuristics or []
        weights = np.asarray(sigma_h, dtype=float) if sigma_h is not None else np.array([1.0])
        total = weights.sum()
        # Avoid dividing by zero (or flipping signs) when the weights do not
        # have a positive sum: in that case they are kept as given.
        self.sigma_h = weights / (total if total > 0 else 1)
        self.n_cities = n_cities

        self.n_inst_eva = n_inst_eva or config.eoh_eval_n_instances
        self.time_limit = getattr(config, 'instance_solver_time_limit', 60)
        self.ite_max = getattr(config, 'tsp_solver_max_iterations', 1000)
        self.perturbation_moves = getattr(config, 'tsp_solver_perturbation_moves', 1)
        self.use_gap = config.generator_use_gap
        self.gap_oracle = getattr(config, 'oracle_type', 'concorde')
        self.gap_oracle_timeout = getattr(config, 'oracle_timeout', 30)
        self.debug_mode = getattr(config, 'debug_mode', False)
        self.parallel_backend = getattr(config, 'parallel_backend', 'loky')
        self.parallel_prefer = getattr(config, 'parallel_prefer', 'processes')
        self.parallel_n_jobs = getattr(config, 'eval_n_jobs', -1)

        # Add prompts attribute that EoH might expect.
        # Imported lazily here rather than at module import time.
        from .prompts import GetPrompts
        self.prompts = GetPrompts()

    def set_evolution_context(self, context: Optional[str] = None, enabled: bool = True) -> None:
        """Set evolution context for PSRO-level task description.

        Does nothing when the instance has no ``prompts`` object.

        Args:
            context: Context string describing the mixed solver strategy.
            enabled: Whether to use context in prompts.
        """
        prompts = getattr(self, 'prompts', None)
        if prompts is not None:
            prompts.set_context(context, enabled)

    def evaluate(self, code_string: str, n_instances: Optional[int] = None) -> float:
        """Evaluate a single generator code (backward compatibility).

        Args:
            code_string: Generator code to evaluate.
            n_instances: Number of instances to evaluate on; ``None`` uses
                ``self.n_inst_eva``.

        Returns:
            The fitness of the generator, as computed by ``evaluate_batch``.
        """
        return self.evaluate_batch([code_string], n_instances=n_instances)[0]

    def evaluate_batch(
        self,
        code_strings: List[str],
        n_instances: Optional[int] = None,
    ) -> List[float]:
        """Batch evaluate multiple generator codes.

        Steps:
            1. Call ``prepare_tasks`` to generate all tasks, using
               ``self.heuristics`` as the solver codes.
            2. Run the shared batch evaluation (per the shared module, results
               are already averaged over instances).
            3. Group by generator and weight by ``sigma_h`` (solver weights).

        Args:
            code_strings: List of generator code strings
                ``[code_0, code_1, ..., code_n-1]``.
            n_instances: Number of instances per evaluation; ``None`` uses
                ``self.n_inst_eva``.

        Returns:
            List of fitness values ``[fitness_0, ..., fitness_n-1]``. Each is
            the negated weighted result, or ``_FAILURE_FITNESS`` when the
            generator could not be scored. An empty input yields ``[]``.
        """
        n_generators = len(code_strings)
        if n_generators == 0:
            # Nothing to evaluate; skip task preparation and the worker pool.
            return []

        if n_instances is None:
            n_instances = self.n_inst_eva

        solver_ids = list(range(len(self.heuristics)))
        generator_ids = list(range(n_generators))
        # For generator evaluation each generator weight is 1.0; the final
        # weighting is done by sigma_h during aggregation.
        generator_weights = np.ones(n_generators)

        # Step 1: generate all tasks.
        # Instances cannot be cached here: each generator code produces its
        # own instances, so every call must generate them afresh.
        all_tasks, _, _ = prepare_tasks(
            solver_codes=self.heuristics,
            solver_ids=solver_ids,
            generator_codes=code_strings,
            generator_ids=generator_ids,
            generator_weights=generator_weights,
            n_instances=n_instances,
            n_cities=self.n_cities,
            time_limit=self.time_limit,
            ite_max=self.ite_max,
            perturbation_moves=self.perturbation_moves,
            use_gap=self.use_gap,
            # NOTE(review): the oracle is hard-coded to "lkh3" and ignores
            # self.gap_oracle (config.oracle_type) -- confirm this is intended.
            gap_oracle="lkh3",
            oracle_timeout=self.gap_oracle_timeout,
            lkh_runs=getattr(self.config, 'lkh_runs', None),
            optimal_parallel_n_jobs=getattr(
                self.config,
                'optimal_parallel_n_jobs',
                getattr(self.config, 'oracle_parallel_n_jobs', -1),
            ),
            debug_mode=self.debug_mode,
            cached_instances=None,
            cached_oracle_costs=None,
            lkh3_path=getattr(self.config, 'lkh3_path', None),
            concorde_path=getattr(self.config, 'concorde_path', None),
        )

        # Step 2: shared batch evaluation.
        # Estimated time for each task, then a 1.5x safety factor on the batch.
        timeout_per_task = self.time_limit + self.gap_oracle_timeout + 10
        batch_timeout = len(all_tasks) * timeout_per_task * 1.5
        results_dict = batch_evaluate_tasks(
            tasks=all_tasks,
            evaluate_fn=evaluate_single_solver_instance,
            n_jobs=self.parallel_n_jobs,
            backend=self.parallel_backend,
            prefer=self.parallel_prefer,
            timeout=batch_timeout,
            debug_mode=self.debug_mode,
            track_time=True,
            time_key="generator",
            task_batch_size=getattr(self.config, 'batch_eval_task_batch_size', None),
        )

        # Step 3: group by generator, weight by sigma_h.
        return self._aggregate_fitness(results_dict, n_generators)

    def _aggregate_fitness(self, results_dict: dict, n_generators: int) -> List[float]:
        """Combine per-(solver, generator) results into one fitness per generator.

        Args:
            results_dict: Mapping ``{(solver_id, generator_id): mean_gap}``.
            n_generators: Number of generators (ids ``0 .. n_generators-1``).

        Returns:
            One fitness per generator: ``-(sum_s sigma_h[s] * mean_gap[s, g])``,
            or ``_FAILURE_FITNESS`` if any solver with non-zero weight has a
            missing or non-finite result for that generator.
        """
        n_solvers = len(self.heuristics)
        n_weights = len(self.sigma_h)

        fitnesses = []
        for gen_id in range(n_generators):
            weighted_sum = 0.0
            failed = False
            for solver_id in range(n_solvers):
                # Solvers beyond the end of sigma_h carry no weight.
                weight = float(self.sigma_h[solver_id]) if solver_id < n_weights else 0.0
                if weight == 0.0:
                    continue
                mean_gap = results_dict.get((solver_id, gen_id))
                if mean_gap is None or not np.isfinite(mean_gap):
                    # Previously a missing result defaulted to 1e9 and was then
                    # negated, turning a failed generator into the best-scoring
                    # one under minimization. Keep the penalty positive instead.
                    # NOTE(review): if the shared evaluator reports failures
                    # with a large finite sentinel, that case is not detected
                    # here -- confirm against batch_eval.
                    failed = True
                    break
                weighted_sum += weight * mean_gap

            # For generators we want to maximize tour length/gap (harder
            # instances), but EoH minimizes, so the weighted sum is negated.
            fitnesses.append(self._FAILURE_FITNESS if failed else -float(weighted_sum))

        return fitnesses


