"""TSP-specific evaluation functions for PSRO.

This module contains TSP-specific evaluation logic that can be used
by the PSRO controller. Functions are pure functions that receive
dependencies as parameters to avoid tight coupling.
"""

from __future__ import annotations

from typing import List, Dict, Optional, Any, Tuple
import numpy as np

from ...core.config import HeuPSROConfig


def compute_gap_with_oracle(
    code: str,
    instances: List[np.ndarray],
    n_cities: int,
    gap_oracle: str,
    gap_oracle_timeout: int,
    tsp_solver_time_limit: int,
    tsp_solver_max_iterations: int,
    tsp_solver_perturbation_moves: int,
    utility_cache: Optional[Dict[str, Any]] = None,
    config: Optional[HeuPSROConfig] = None
) -> float:
    """
    Evaluate one solver on a list of TSP instances against an oracle.

    The solver source is executed into a fresh module object, every instance
    is solved by the oracle and then by ``solve_instance`` with that module,
    and the per-instance values returned by ``solve_instance`` are averaged.

    Args:
        code: Solver code string; executed with ``exec`` (see NOTE below).
        instances: TSP instance list (each item is passed to
            ``coords_to_matrix`` and to the oracle as coordinates).
        n_cities: number of cities; only used to build the default config.
        gap_oracle: Oracle type ("lkh3", "concorde", "none").
        gap_oracle_timeout: Oracle timeout (seconds).
        tsp_solver_time_limit: solver time limit.
        tsp_solver_max_iterations: maximum iterations.
        tsp_solver_perturbation_moves: perturbation moves.
        utility_cache: optional dictionary; when given, the mean gap is
            stored in it. This function only writes the cache, it never
            reads from it.
        config: optional config object; a ``TSPGLSConfig`` is built from the
            scalar arguments when omitted.

    Returns:
        Mean of the per-instance values returned by ``solve_instance``
        (presumably a gap percentage - confirm against ``solve_instance``).
        ``float('inf')`` is returned when ``gap_oracle == "none"``, when
        ``instances`` is empty, when the mean is NaN/infinite, or when any
        exception is raised during evaluation.

    Raises:
        Errors from the project imports or from building the default config
        propagate: both happen before the guarded evaluation section.
    """
    import types
    from .evolution.shared.gls.gls_run import solve_instance
    from .evolution.shared.eval_utils import coords_to_matrix

    if config is None:
        from .config import TSPGLSConfig
        config = TSPGLSConfig(
            n_cities=n_cities,
            oracle_type=gap_oracle,
            oracle_timeout=gap_oracle_timeout,
            instance_solver_time_limit=tsp_solver_time_limit,
            tsp_solver_max_iterations=tsp_solver_max_iterations,
            tsp_solver_perturbation_moves=tsp_solver_perturbation_moves,
            eoh_eval_n_instances=len(instances),
            generator_use_gap=(gap_oracle != "none"),
        )

    try:
        # NOTE: exec runs arbitrary Python. Only pass solver code that is
        # trusted or already sandboxed by the caller.
        heuristic_module = types.ModuleType("heuristic_module")
        exec(code, heuristic_module.__dict__)

        # Without an oracle there is no reference cost, so no gap can be
        # computed. The code has still been executed above, so a solver that
        # fails to load is reported through the except branch.
        if gap_oracle == "none":
            return float('inf')

        from .oracle import create_tsp_oracle
        # `config` is always set at this point (built above when omitted), so
        # the oracle is always created from it.
        oracle = create_tsp_oracle(
            config=config,
            oracle_type=gap_oracle,
            oracle_timeout=gap_oracle_timeout,
        )

        time_limit = int(tsp_solver_time_limit)
        max_iterations = int(tsp_solver_max_iterations)
        perturbation_moves = int(tsp_solver_perturbation_moves)

        gaps = []
        for coords in instances:
            dis_matrix = coords_to_matrix(coords)
            opt_cost = float(oracle.solve_exact(coords))

            gaps.append(
                solve_instance(
                    -1,
                    opt_cost,
                    dis_matrix,
                    coords,
                    time_limit,
                    max_iterations,
                    perturbation_moves,
                    heuristic_module,
                )
            )

        mean_gap = float(np.mean(gaps)) if gaps else float('inf')
        if not np.isfinite(mean_gap):
            # NaN and +/-inf are all reported as "worst possible".
            mean_gap = float('inf')

        if utility_cache is not None:
            # NOTE(review): weak key - only the first 10 characters of the
            # code are used, and str() of large numpy arrays is abbreviated,
            # so distinct inputs can collide. Format kept unchanged because
            # other code may look entries up under this exact key.
            cache_key = f"{code[:10]}_{hash(str(instances))}"
            utility_cache[f"gap_{cache_key}"] = mean_gap

        return mean_gap

    except Exception as e:
        # Deliberate best-effort: a broken solver or oracle failure scores
        # as the worst possible value instead of aborting the caller.
        print(f"      Evaluation failed: {e}")
        return float('inf')



import types
import numpy as np
from typing import List, Dict, Tuple, Optional, Union
from .evolution.shared.batch_eval import batch_evaluate_tasks, evaluate_single_solver_instance
from ...core.config import HeuPSROConfig

def _cost_and_gap(
    opt: Optional[float],
    gap_val: Optional[float]
) -> Optional[Tuple[float, float]]:
    """Turn an optimal cost and a gap percentage into ``(cost, gap)``.

    Returns ``None`` when the pair is unusable: the optimal cost is missing
    or NaN, or the gap is missing, non-finite, or ``>= 1e8`` (presumably a
    failure sentinel produced by the evaluator - confirm).
    """
    if opt is None or np.isnan(opt):
        return None
    if gap_val is None or not np.isfinite(gap_val) or gap_val >= 1e8:
        return None
    gap = float(gap_val)
    # The gap is a percentage relative to the optimal cost.
    return float(opt) * (1.0 + gap / 100.0), gap


def evaluate_solvers_on_instances_with_optcosts(
    solver_codes: Union[str, List[str]],
    instances: Union[np.ndarray, List[np.ndarray]],
    optimal_costs: List[Optional[float]],
    config: HeuPSROConfig,
    return_format: str = "per_solver"  # "per_solver" | "raw"
) -> Union[Dict[int, Tuple[Optional[float], Optional[float]]],
           Dict[Tuple[int, int], Tuple[Optional[float], Optional[float]]]]:
    """
    Evaluate multiple solvers on multiple instances with known optimal costs.

    Each (solver, instance) pair becomes one task for
    ``batch_evaluate_tasks``. The instance index is used as the generator id
    so that results are not averaged inside the batch evaluation. Costs are
    reconstructed from the returned gap as ``opt * (1 + gap / 100)``.

    Args:
        solver_codes: one solver source string or a list of them. Each is
            executed with ``exec`` into its own module object.
        instances: a single ``np.ndarray`` (treated as one instance) or a
            list of instances.
        optimal_costs: one optimal cost per instance; ``None``/NaN marks an
            instance whose results are unusable.
        config: configuration object. Solver limits, ``oracle_timeout``,
            ``eval_n_jobs`` and ``batch_eval_task_batch_size`` are read with
            defaults; ``parallel_backend``, ``parallel_prefer`` and
            ``debug_mode`` are required once there is at least one task.
        return_format: ``"raw"`` for per-pair results; any other value
            yields per-solver means.

    Returns:
        ``"raw"``: ``{(solver_id, inst_idx): (cost, gap)}`` for every key in
        the batch results, with ``(None, None)`` for unusable entries.
        Otherwise: ``{solver_id: (mean_cost, mean_gap)}`` over the usable
        entries, with ``None`` means when a solver has none.

    Raises:
        ValueError: if ``optimal_costs`` and the instances differ in length.
        Exception: anything raised while executing a solver's code
            propagates to the caller.
    """
    solver_list: List[str] = [solver_codes] if isinstance(solver_codes, str) else list(solver_codes)
    inst_list: List[np.ndarray] = [instances] if isinstance(instances, np.ndarray) else list(instances)
    if len(optimal_costs) != len(inst_list):
        raise ValueError(f"optimal_costs length ({len(optimal_costs)}) must match instances length ({len(inst_list)})")

    # NOTE: exec runs arbitrary Python. Only pass solver code that is trusted
    # or already sandboxed by the caller.
    solver_modules: List[types.ModuleType] = []
    for idx, code in enumerate(solver_list):
        module = types.ModuleType(f"heuristic_module_{idx}")
        exec(code, module.__dict__)
        solver_modules.append(module)

    tsp_solver_time_limit = getattr(config, 'instance_solver_time_limit', 60)
    tsp_solver_max_iterations = getattr(config, 'tsp_solver_max_iterations', 1000)
    tsp_solver_perturbation_moves = getattr(config, 'tsp_solver_perturbation_moves', 1)
    gap_oracle_timeout = getattr(config, 'oracle_timeout', 30)

    # Task layout expected by evaluate_single_solver_instance:
    # (solver_id, generator_id, coords, optimal_cost, 1.0, time_limit,
    #  max_iterations, perturbation_moves, module). The meaning of the
    # constant 1.0 is defined by the evaluator and is not visible here.
    tasks = [
        (
            s_idx,
            inst_idx,
            coords,
            optimal_costs[inst_idx],
            1.0,
            tsp_solver_time_limit,
            tsp_solver_max_iterations,
            tsp_solver_perturbation_moves,
            module,
        )
        for s_idx, module in enumerate(solver_modules)
        for inst_idx, coords in enumerate(inst_list)
    ]

    if not tasks:
        if return_format == "raw":
            return {}
        return {i: (None, None) for i in range(len(solver_list))}

    timeout_per_task = tsp_solver_time_limit + gap_oracle_timeout + 10
    batch_timeout = len(tasks) * timeout_per_task * 1.5

    n_solvers = len(solver_list)
    n_instances = len(inst_list)
    n_jobs = getattr(config, 'eval_n_jobs', -1)

    if n_jobs == -1:
        try:
            from joblib import effective_n_jobs
            effective_workers = effective_n_jobs(-1)
        except Exception:
            # Only used for the logged estimate below, so any failure falls
            # back to the CPU count instead of aborting the evaluation.
            import os
            effective_workers = os.cpu_count() or 1
    else:
        effective_workers = n_jobs

    estimated_time_serial = len(tasks) * timeout_per_task / 60.0  # minutes (for reference)
    if effective_workers is not None and effective_workers > 0:
        estimated_time_parallel = (len(tasks) / effective_workers) * timeout_per_task / 60.0  # minutes
    else:
        # Unknown or non-positive worker count: report the serial figure.
        estimated_time_parallel = estimated_time_serial

    print(f"   Batch evaluating {n_solvers} solvers × {n_instances} instances = {len(tasks)} tasks")
    print(f"     Estimated time: ~{estimated_time_parallel:.1f} minutes (parallel, {effective_workers} workers)")
    print(f"     Serial estimate: ~{estimated_time_serial:.1f} minutes (if sequential)")
    print(f"     Timeout: {batch_timeout/60:.1f} min | Using {n_jobs} workers ({config.parallel_backend}/{config.parallel_prefer})")

    results = batch_evaluate_tasks(
        tasks=tasks,
        evaluate_fn=evaluate_single_solver_instance,
        n_jobs=n_jobs,
        backend=config.parallel_backend,
        prefer=config.parallel_prefer,
        timeout=batch_timeout,
        debug_mode=config.debug_mode,
        track_time=False,
        time_key="solver",
        task_batch_size=getattr(config, 'batch_eval_task_batch_size', None)
    )

    print(f"   Batch evaluation complete: {len(results)} results")

    # `results` is used as a mapping {(solver_id, inst_idx): gap_percent}.
    if return_format == "raw":
        raw: Dict[Tuple[int, int], Tuple[Optional[float], Optional[float]]] = {}
        for (s_id, inst_idx), gap_val in results.items():
            pair = _cost_and_gap(optimal_costs[inst_idx], gap_val)
            raw[(s_id, inst_idx)] = pair if pair is not None else (None, None)
        return raw

    per_solver: Dict[int, Tuple[Optional[float], Optional[float]]] = {}
    for s_id in range(n_solvers):
        costs, gaps = [], []
        for inst_idx, opt in enumerate(optimal_costs):
            key = (s_id, inst_idx)
            if key not in results:
                continue
            pair = _cost_and_gap(opt, results[key])
            if pair is not None:
                costs.append(pair[0])
                gaps.append(pair[1])
        mean_cost = float(np.mean(costs)) if costs else None
        mean_gap = float(np.mean(gaps)) if gaps else None
        per_solver[s_id] = (mean_cost, mean_gap)
    return per_solver