"""CVRP-specific evaluation functions for PSRO.

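This module contains CVRP-specific evaluation logic that can be used
by the PSRO controller. The functions are module-level and receive their
dependencies (config, oracle settings, caches) as parameters to avoid
tight coupling. Note that they are not side-effect free: they emit
warnings, may print diagnostics, and may fill in missing fields on the
instance dicts they are given.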
"""

from __future__ import annotations

from typing import List, Dict, Optional, Any, Tuple, Union
import numpy as np
import types
import warnings

from ...core.config import HeuPSROConfig
from .evolution.shared.solver.solve_instance import solve_instance, nearest_neighbor_fallback


def prepare_exec_namespace() -> Dict[str, Any]:
    """
    Build the globals dictionary handed to exec() for solver code.

    A new dictionary is created on every call, so names defined by one
    executed code string never leak into the next one.

    Returns:
        Dictionary exposing the interpreter builtins plus numpy under
        both of its usual names ('numpy' and 'np').
    """
    # Keyword order fixes the insertion order: '__builtins__', 'numpy', 'np'.
    return dict(__builtins__=__builtins__, numpy=np, np=np)


def evaluate_solver_on_instance(
    code: str,
    instance: Dict,
    oracle_value: Optional[float],
    time_limit: Optional[float] = None,
    max_iterations: int = 1000,
    max_stagnation: int = 10,
    debug_mode: bool = False
) -> Tuple[float, float]:
    """
    Evaluate solver code on a single CVRP instance.

    The code string is executed in a fresh namespace, its ``select``
    function is passed to ``solve_instance`` (with
    ``nearest_neighbor_fallback`` as the fallback selector), and the
    resulting cost is compared against ``oracle_value``.

    This function never raises for ordinary errors: every failure is
    reported as a ``UserWarning`` and mapped to ``(inf, inf)``. That covers
    code that fails to execute, code without a ``select`` name, a solver
    that raises, a run that exceeds ``time_limit``, and a non-finite or
    negative cost.

    Args:
        code: Solver code string (must define select function)
        instance: CVRP instance dict with 'depot', 'customers', 'vehicle_capacity'
        oracle_value: Reference solution cost (for gap calculation)
        time_limit: Time limit in seconds (optional). It is forwarded to the
            solver and also enforced afterwards on the measured run time.
        max_iterations: Not used by this function; kept for compatibility
        max_stagnation: Not used by this function; kept for compatibility
        debug_mode: Not used by this function; kept for compatibility

    Returns:
        Tuple of (solution_cost, gap_percentage). The gap is ``inf`` when
        ``oracle_value`` is None or not positive, and ``0.0`` (with a
        warning) when the solver beats the oracle value.
    """
    import time

    failed = (float('inf'), float('inf'))

    try:
        # NOTE(security): the code string runs in-process with the full
        # builtins available; only evaluate code from a trusted source.
        solver_module = types.ModuleType("solver_module")
        exec_namespace = prepare_exec_namespace()
        exec_namespace.update(solver_module.__dict__)
        exec(code, exec_namespace)
        solver_module.__dict__.update(exec_namespace)

        if not hasattr(solver_module, 'select'):
            warnings.warn(
                "CVRP solver evaluation: Code does not define 'select' function",
                UserWarning
            )
            return failed

        select_func = solver_module.select

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can.
        started = time.perf_counter()
        try:
            # The route is not needed for scoring; only the cost is used.
            solution_cost, _ = solve_instance(
                instance=instance,
                select_func=select_func,
                fallback_select=nearest_neighbor_fallback,
                time_limit=time_limit
            )
        except Exception as e:
            warnings.warn(
                f"CVRP solver evaluation: Error running solver: {e}",
                UserWarning
            )
            return failed
        elapsed_time = time.perf_counter() - started

        # A run that overshoots the limit is rejected even if it produced a cost.
        if time_limit is not None and elapsed_time > time_limit:
            warnings.warn(
                f"CVRP solver evaluation: Solver exceeded time limit ({elapsed_time:.2f}s > {time_limit}s)",
                UserWarning
            )
            return failed

        if not np.isfinite(solution_cost) or solution_cost < 0:
            warnings.warn(
                f"CVRP solver evaluation: Solver returned invalid cost: {solution_cost}",
                UserWarning
            )
            return failed

        # Without a positive reference cost no relative gap can be computed.
        if oracle_value is None or not oracle_value > 0:
            return float(solution_cost), float('inf')

        if solution_cost < oracle_value:
            # Beating the reference means the reference was not optimal;
            # clamp to zero instead of reporting a negative gap.
            warnings.warn(
                f"CVRP evaluation: solution_cost ({solution_cost:.2f}) < oracle_value ({oracle_value:.2f}). "
                f"This suggests the oracle may not be optimal. Setting gap to 0.0.",
                UserWarning
            )
            gap = 0.0
        else:
            gap = (solution_cost - oracle_value) / abs(oracle_value) * 100.0

        return float(solution_cost), float(gap)

    except Exception as e:
        # Catch-all boundary: code that fails to compile/execute ends up here.
        warnings.warn(
            f"CVRP solver evaluation: Evaluation failed: {e}",
            UserWarning
        )
        return failed


def _ensure_instance_fields(
    instance: Dict,
    vehicle_capacity: int,
    num_vehicles: Optional[int],
) -> None:
    """Fill in missing capacity / fleet size / integer distance matrix on ``instance`` (in place)."""
    if 'vehicle_capacity' not in instance:
        instance['vehicle_capacity'] = vehicle_capacity
    if 'num_vehicles' not in instance and num_vehicles is not None:
        instance['num_vehicles'] = num_vehicles

    # Ensure instance has distance_matrix_int for oracle consistency
    if 'distance_matrix_int' not in instance:
        from .evolution.shared.solver.solve_instance import instance_to_solver_inputs
        instance_to_solver_inputs(instance)


def _usable_oracle_cost(
    oracle: Any,
    instance: Dict,
    idx: int,
    oracle_type: str,
    debug_mode: bool,
) -> Optional[float]:
    """Query the oracle for instance ``idx``; return its cost, or None (after warning) if unusable."""
    try:
        oracle_result = oracle.solve_oracle(instance)
        if oracle_result is None:
            if debug_mode:
                print(f"      [CVRP Evaluation] Instance {idx}: Oracle returned None")
            warnings.warn(
                f"CVRP evaluation: Oracle returned None for instance {idx}. "
                f"Oracle type: {oracle_type}",
                UserWarning
            )
            oracle_value = None
        else:
            oracle_value = oracle_result.get("cost")
            status = oracle_result.get("status", "UNKNOWN")
            solver_name = oracle_result.get("solver", "unknown")
            if debug_mode:
                print(f"      [CVRP Evaluation] Instance {idx}: Oracle ({solver_name}) status={status}, cost={oracle_value}")
            # A non-FEASIBLE status is only reported; the cost is still used
            # if it passes the validity check below.
            if status != "FEASIBLE":
                warnings.warn(
                    f"CVRP evaluation: Oracle ({solver_name}) returned status '{status}' for instance {idx}. "
                    f"Cost: {oracle_value}",
                    UserWarning
                )
    except Exception as e:
        import traceback
        if debug_mode:
            print(f"      [CVRP Evaluation] Instance {idx}: Oracle exception: {type(e).__name__}: {e}")
            traceback.print_exc()
        warnings.warn(
            f"CVRP evaluation: Failed to compute oracle for instance {idx}: {type(e).__name__}: {e}. "
            f"Oracle type: {oracle_type}. Skipping.",
            UserWarning
        )
        return None

    # A reference cost must be a finite, strictly positive number.
    if oracle_value is None or not np.isfinite(oracle_value) or oracle_value <= 0:
        num_customers = len(instance.get('customers', []))
        if debug_mode:
            print(f"      [CVRP Evaluation] Instance {idx}: Invalid oracle_value={oracle_value}, num_customers={num_customers}")
        warnings.warn(
            f"CVRP evaluation: Instance {idx} has invalid oracle_value={oracle_value} "
            f"(num_customers={num_customers}). Skipping.",
            UserWarning
        )
        return None

    return oracle_value


def compute_gap_with_oracle(
    code: str,
    instances: List[Dict],
    vehicle_capacity: int,
    num_vehicles: Optional[int],
    oracle_type: str = "ortools",
    oracle_timeout: int = 60,
    utility_cache: Optional[Dict[str, Any]] = None,
    config: Optional[HeuPSROConfig] = None,
    time_limit: Optional[float] = None,
    max_iterations: int = 1000,
    max_stagnation: int = 10
) -> float:
    """
    Evaluate solver on instances and compute average gap.

    For each instance the oracle supplies a reference cost, the solver code
    is run through ``evaluate_solver_on_instance``, and the finite gaps are
    averaged. Instances whose oracle call fails or yields an unusable cost,
    and instances on which the solver produces a non-finite gap, are
    skipped with a ``UserWarning`` and do not contribute to the mean.

    Side effects: missing 'vehicle_capacity', 'num_vehicles' and
    'distance_matrix_int' entries are filled in on the given instance dicts.

    Args:
        code: Solver code string
        instances: List of CVRP instance dicts
        vehicle_capacity: Vehicle capacity used for instances lacking one
        num_vehicles: Number of vehicles used for instances lacking one (optional)
        oracle_type: Oracle type, forwarded to ``create_cvrp_oracle``;
            "none" disables evaluation and returns ``inf`` immediately
        oracle_timeout: Oracle timeout (seconds); not used by this function
        utility_cache: Optional cache dictionary; the mean gap is written
            to it (this function never reads from it)
        config: Optional config object
        time_limit: Solver time limit (seconds); when None, the config's
            ``instance_solver_time_limit`` is used if available
        max_iterations: Forwarded to ``evaluate_solver_on_instance``
        max_stagnation: Forwarded to ``evaluate_solver_on_instance``

    Returns:
        Average gap percentage, or ``inf`` when no instance produced a
        finite gap or when evaluation failed as a whole.
    """
    # Oracle disabled: there is no reference cost, hence no gap. Checked
    # before the oracle import so a disabled oracle never loads that module.
    if oracle_type == "none":
        return float('inf')

    from .oracle import create_cvrp_oracle

    try:
        if config is not None:
            oracle = create_cvrp_oracle(
                config=config,
                oracle_type=oracle_type,
            )
        else:
            oracle = create_cvrp_oracle(oracle_type=oracle_type)

        debug_mode = getattr(config, 'debug_mode', False) if config else False
        if debug_mode:
            print(f"      [CVRP Evaluation] Oracle type: {oracle_type}")
            if oracle_type == 'lkh3':
                lkh3_path = getattr(config, 'lkh3_path', None) if config else None
                import os
                print(f"      [CVRP Evaluation] LKH3 path: {lkh3_path}")
                if lkh3_path:
                    print(f"      [CVRP Evaluation] LKH3 exists: {os.path.exists(lkh3_path)}")

        # The effective solver time limit is the same for every instance,
        # so resolve it once instead of once per loop iteration.
        solver_time_limit = time_limit
        if solver_time_limit is None and config is not None:
            solver_time_limit = getattr(config, 'instance_solver_time_limit', None)

        gaps = []
        invalid_instances = 0  # counts oracle-side failures only
        for idx, instance in enumerate(instances):
            _ensure_instance_fields(instance, vehicle_capacity, num_vehicles)

            oracle_value = _usable_oracle_cost(
                oracle, instance, idx, oracle_type, debug_mode
            )
            if oracle_value is None:
                invalid_instances += 1
                continue

            try:
                solution_cost, gap = evaluate_solver_on_instance(
                    code=code,
                    instance=instance,
                    oracle_value=oracle_value,
                    time_limit=solver_time_limit,
                    max_iterations=max_iterations,
                    max_stagnation=max_stagnation,
                    debug_mode=False
                )
            except Exception as e:
                warnings.warn(
                    f"CVRP evaluation: Failed to evaluate solver on instance {idx}: {e}. Skipping.",
                    UserWarning
                )
                continue

            if np.isfinite(gap):
                gaps.append(gap)
            else:
                warnings.warn(
                    f"CVRP evaluation: Instance {idx} has non-finite gap={gap}, "
                    f"solution_cost={solution_cost:.2f}, oracle_value={oracle_value:.2f}. Skipping.",
                    UserWarning
                )

        if invalid_instances == len(instances):
            if debug_mode:
                print(f"      [CVRP Evaluation] All {len(instances)} instances were invalid!")
            warnings.warn(
                f"CVRP evaluation: All {len(instances)} instances were invalid. "
                f"Returning gap=inf.",
                UserWarning
            )

        mean_gap = float(np.mean(gaps)) if gaps else float('inf')
        if np.isnan(mean_gap) or np.isinf(mean_gap):
            mean_gap = float('inf')

        if debug_mode:
            print(f"      [CVRP Evaluation] Summary: {len(gaps)} valid gaps, {invalid_instances} invalid instances")
            if gaps:
                print(f"      [CVRP Evaluation] Mean gap: {mean_gap:.4f}%")
            else:
                print("      [CVRP Evaluation] No valid gaps, returning inf")

        if utility_cache is not None:
            # NOTE(review): the key uses only the first 10 characters of the
            # code, so solvers sharing a prefix overwrite each other's entry,
            # and hash() of a str varies between interpreter processes. The
            # format is kept as-is because callers may rebuild this key.
            cache_key = f"{code[:10]}_{hash(str(instances))}"
            utility_cache[f"gap_{cache_key}"] = mean_gap

        return mean_gap

    except Exception as e:
        # Boundary handler: any unexpected error degrades to "no usable gap".
        warnings.warn(
            f"CVRP evaluation: Evaluation failed: {e}",
            UserWarning
        )
        return float('inf')


def evaluate_solvers_on_instances_with_optcosts(
    solver_codes: Union[str, List[str]],
    instances: Union[Dict, List[Dict]],
    optimal_costs: List[Optional[float]],
    config: HeuPSROConfig,
    return_format: str = "per_solver"  # "per_solver" | "raw"
) -> Union[Dict[int, Optional[float]],
           Dict[Tuple[int, int], Optional[float]]]:
    """
    Evaluate several solvers on several instances using caller-supplied optimal costs.

    No oracle is consulted: ``optimal_costs[i]`` is used as the reference
    cost of instance ``i``. Every (solver, instance) pair becomes one batch
    task whose generator id is the instance index, so the batch evaluator
    reports each instance separately instead of averaging across them.

    Side effects: prints progress information, and adds
    'distance_matrix_int' to instance dicts that lack it.

    Args:
        solver_codes: One solver code string or a list of them
        instances: One CVRP instance dict or a list of them
        optimal_costs: Reference cost per instance, same length as instances
        config: Config object; read via getattr with defaults for
            'instance_solver_time_limit' (30), 'ls_max_iterations' (1000),
            'ls_max_stagnation' (10), 'eval_n_jobs' (-1),
            'parallel_backend' ('loky'), 'parallel_prefer' ('processes')
            and 'debug_mode' (False). A time limit of None is treated as
            "no limit" and disables the overall batch timeout.
        return_format: "raw" returns {(solver_idx, instance_idx): gap} as
            produced by the batch evaluator; any other value returns the
            per-solver mean

    Returns:
        For "raw": mapping from (solver index, instance index) to gap;
        solvers that failed to compile have no entries. Otherwise: mapping
        from solver index to the mean of its finite gaps, or None when the
        solver has no finite gap (including solvers that failed to compile).

    Raises:
        ValueError: if optimal_costs and instances differ in length
    """
    # standardize input
    solver_list: List[str] = [solver_codes] if isinstance(solver_codes, str) else list(solver_codes)
    inst_list: List[Dict] = [instances] if isinstance(instances, dict) else list(instances)
    if len(optimal_costs) != len(inst_list):
        raise ValueError(f"optimal_costs length ({len(optimal_costs)}) must match instances length ({len(inst_list)})")

    # precompile solvers; a solver that fails to execute is kept as None so
    # that solver indices stay aligned with solver_list
    # NOTE(security): solver code runs in-process via exec(); trusted input only.
    solver_modules: List[Optional[types.ModuleType]] = []
    for idx, code in enumerate(solver_list):
        try:
            m = types.ModuleType(f"heuristic_module_{idx}")
            exec_namespace = prepare_exec_namespace()
            exec_namespace.update(m.__dict__)
            exec(code, exec_namespace)
            m.__dict__.update(exec_namespace)
            solver_modules.append(m)
        except Exception as e:
            warnings.warn(f"Failed to compile solver {idx}: {e}", UserWarning)
            solver_modules.append(None)

    # safely get CVRP specific configuration parameters
    time_limit = getattr(config, 'instance_solver_time_limit', 30)
    max_iterations = getattr(config, 'ls_max_iterations', 1000)
    max_stagnation = getattr(config, 'ls_max_stagnation', 10)

    # report how many usable optimal values were supplied
    provided_optimal_count = sum(1 for opt in optimal_costs if opt is not None and np.isfinite(opt))
    total_instances = len(inst_list)

    if provided_optimal_count > 0:
        print(f"[Evaluation] Using provided optimal_costs: {provided_optimal_count}/{total_instances} instances have optimal values")
        print("[Evaluation] No oracle computation needed - directly using provided optimal_costs")
    else:
        print("[Evaluation] Warning: No valid optimal_costs provided for any instance")

    # ensure all instances have distance_matrix_int
    for inst in inst_list:
        if 'distance_matrix_int' not in inst:
            from .evolution.shared.solver.solve_instance import instance_to_solver_inputs
            instance_to_solver_inputs(inst)

    # construct tasks: use inst_idx as generator_id, avoid being averaged in batch evaluation
    tasks = []
    for s_idx, mod in enumerate(solver_modules):
        if mod is None:
            continue
        for inst_idx, instance in enumerate(inst_list):
            tasks.append((
                s_idx,
                inst_idx,  # key: use different gid for each instance
                instance,
                optimal_costs[inst_idx],  # provided optimal value, no oracle
                1.0,  # weight
                time_limit,
                mod,
                max_iterations,
                max_stagnation,
                None  # random_seed
            ))

    if not tasks:
        if return_format == "raw":
            return {}
        return {i: None for i in range(len(solver_list))}

    # Imported only once there is work to do, so input validation and the
    # "nothing to run" path do not depend on the batch machinery.
    from .evolution.shared.batch_eval import batch_evaluate_tasks, evaluate_single_solver_instance
    from joblib.parallel import effective_n_jobs

    # Overall timeout: per-task limit plus 10s slack, with a 1.5x margin.
    # Without a per-task limit no meaningful overall timeout can be derived.
    if time_limit is None:
        batch_timeout = None
    else:
        batch_timeout = len(tasks) * (time_limit + 10) * 1.5

    n_jobs = getattr(config, 'eval_n_jobs', -1)
    backend = getattr(config, 'parallel_backend', 'loky')
    prefer = getattr(config, 'parallel_prefer', 'processes')
    debug_mode = getattr(config, 'debug_mode', False)

    print("[Evaluation] Preparing batch evaluation:")
    print(f"   - Total tasks: {len(tasks)} (={len(solver_list)} solvers × {len(inst_list)} instances)")
    print(f"   - Time limit per task: {time_limit}s")
    print(f"   - Parallel config: n_jobs={n_jobs}, backend={backend}, prefer={prefer}")

    # joblib resolves negative n_jobs itself (-1 means all CPUs), so let it
    # report the real worker count; only None/0 fall back to a single worker.
    actual_workers = effective_n_jobs(n_jobs) if n_jobs else 1
    print(f"   - Actual workers: {actual_workers} (requested n_jobs={n_jobs})")
    if time_limit is None:
        print("   - Estimated total time: unknown (no per-task time limit)")
    else:
        estimated_time = len(tasks) * time_limit / max(1, actual_workers)
        print(f"   - Estimated total time: ~{estimated_time:.1f}s")

    results = batch_evaluate_tasks(
        tasks=tasks,
        evaluate_fn=evaluate_single_solver_instance,
        n_jobs=n_jobs,
        backend=backend,
        prefer=prefer,
        timeout=batch_timeout,
        debug_mode=debug_mode,
    )

    # results is {(solver_id, generator_id): gap}, where generator_id is inst_idx
    if return_format == "raw":
        return {(s_id, g_id): gap for (s_id, g_id), gap in results.items()}

    # aggregate by solver, keeping only finite gaps
    solver_gaps: Dict[int, List[float]] = {}
    for (s_id, _g_id), gap in results.items():
        finite_gaps = solver_gaps.setdefault(s_id, [])
        if gap is not None and np.isfinite(gap):
            finite_gaps.append(gap)

    # compute mean gap for each solver (None when it has no finite gap)
    per_solver_results: Dict[int, Optional[float]] = {}
    for s_idx in range(len(solver_list)):
        finite_gaps = solver_gaps.get(s_idx)
        if finite_gaps:
            mean_gap = float(np.mean(finite_gaps))
            per_solver_results[s_idx] = mean_gap if np.isfinite(mean_gap) else None
        else:
            per_solver_results[s_idx] = None

    return per_solver_results




