#!/usr/bin/env python3
"""Official testing module for PSRO experiments."""

import os
import json
import time
import pickle
import types
import sys
import warnings
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
import numpy as np
from joblib import Parallel, delayed

# Make the repository root importable: it sits three directory levels above
# this file, so climb three parents from the file's absolute path and put the
# result at the front of the module search path.
project_root = os.path.normpath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)
sys.path.insert(0, project_root)

# Note: single-instance evaluation is routed through evaluate_solvers_on_instances_batch (the shared batch_eval helpers) rather than TSPGLSSolverProblem.


def load_test_data(test_data_dir: str, sizes: List[int], max_instances_per_size: int = 10) -> Dict[int, Dict[str, Any]]:
    """Load official TSP test data from pickle files.

    For each requested size the file ``TSP{size}.pkl`` is read from
    ``test_data_dir``. Missing or unreadable files are reported and skipped,
    so the returned dict only contains the sizes that loaded successfully.

    When a file holds more than ``max_instances_per_size`` instances, a
    reproducible subset is drawn with a private ``random.Random(42)``
    generator. This yields exactly the indices that ``random.seed(42)``
    followed by ``random.sample`` would, but leaves the process-global
    ``random`` state untouched, so callers' own randomness is not reset.

    Args:
        test_data_dir: Directory containing TSP{20,50,100}.pkl files
        sizes: List of TSP sizes to load (e.g., [20, 50, 100])
        max_instances_per_size: Maximum number of instances to randomly sample per size

    Returns:
        Dict mapping size to test data:
        {size: {'coords': [...], 'distances': [...], 'costs': [...], 'n_instances': int}}
    """
    import random

    test_data = {}

    for size in sizes:
        pkl_path = os.path.join(test_data_dir, f"TSP{size}.pkl")
        if not os.path.exists(pkl_path):
            print(f"Warning: Test data file not found: {pkl_path}")
            continue

        try:
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this at trusted, locally produced test-data files.
            with open(pkl_path, 'rb') as f:
                data = pickle.load(f)

            # Extract data from the pickle file
            coords = data['coordinate']
            distances = data['distance_matrix']
            costs = data['cost']

            # Randomly sample instances if there are too many
            n_total = len(coords)
            if n_total > max_instances_per_size:
                # Fixed seed for reproducibility, on a local generator so the
                # global RNG is not reseeded as a side effect.
                rng = random.Random(42)
                indices = rng.sample(range(n_total), max_instances_per_size)
                coords = [coords[i] for i in indices]
                distances = [distances[i] for i in indices]
                costs = [costs[i] for i in indices]
                print(f"Loaded TSP{size}: randomly sampled {max_instances_per_size} from {n_total} instances")
            else:
                print(f"Loaded TSP{size}: {n_total} instances")

            test_data[size] = {
                'coords': coords,
                'distances': distances,
                'costs': costs,
                'n_instances': len(coords)
            }

        except Exception as e:
            # Best-effort loading: report and move on to the next size.
            print(f"Error loading {pkl_path}: {e}")
            continue

    return test_data


def load_solver_code(solver_code: str) -> types.ModuleType:
    """Execute *solver_code* inside a fresh module and return that module.

    Warnings emitted while the code runs are silenced. After a successful
    ``exec`` the module is registered in ``sys.modules`` under the fixed name
    ``"heuristic_module"``, replacing whatever was registered there before.

    NOTE(review): the string is run with ``exec``; only pass trusted solver code.

    Args:
        solver_code: Python source defining the solver heuristic.

    Returns:
        The populated module object.
    """
    module = types.ModuleType("heuristic_module")
    namespace = module.__dict__

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        exec(solver_code, namespace)
        # Registration only happens if exec did not raise.
        sys.modules[module.__name__] = module

    return module


def _evaluate_solver_parallel(args):
    """Unpack one joblib task tuple and run evaluate_solver_on_instance on it.

    ``args`` is ``(solver_code, coords, distances, optimal_cost, time_limit,
    ite_max, perturbation_moves)``; the result is returned unchanged.
    """
    (code, instance_coords, instance_distances, best_known,
     limit, max_iters, n_moves) = args
    return evaluate_solver_on_instance(
        code, instance_coords, instance_distances, best_known,
        limit, max_iters, n_moves,
    )


def _evaluate_nash_parallel(args):
    """Unpack one joblib task tuple and run evaluate_nash_mixture_on_instance on it.

    ``args`` is ``(solver_pool, sigma_h, coords, distances, optimal_cost,
    time_limit, ite_max, perturbation_moves)``; the result is returned unchanged.
    """
    (pool, weights, instance_coords, instance_distances, best_known,
     limit, max_iters, n_moves) = args
    return evaluate_nash_mixture_on_instance(
        pool, weights, instance_coords, instance_distances, best_known,
        limit, max_iters, n_moves,
    )


def evaluate_solver_on_instance(
    solver_code: str,
    coords: np.ndarray,
    distance_matrix: np.ndarray,
    optimal_cost: Optional[float],
    time_limit: int = 60,
    ite_max: int = 1000,
    perturbation_moves: int = 1
) -> Tuple[float, float]:
    """Evaluate a single solver on a single TSP instance.

    This is a convenience wrapper for backward compatibility: it calls
    evaluate_solvers_on_instances_batch with one solver and one dataset that
    contains one instance. For batch evaluation, call that function directly.

    Args:
        solver_code: Solver code string
        coords: Instance coordinates
        distance_matrix: Distance matrix. Not forwarded: the batch evaluator
            only receives ``coords``. Kept for signature compatibility.
        optimal_cost: Optimal cost (if available)
        time_limit: Time limit for solving
        ite_max: Maximum iterations
        perturbation_moves: Perturbation moves

    Returns:
        The ``(cost, gap_percentage)`` pair the batch evaluator produced for
        solver 0 on dataset 0. ``cost`` is reconstructed from the gap and
        ``optimal_cost``, so it is None when no usable optimal cost is given.
        A failed evaluation yields ``(1e10, 100.0)``.
    """
    results = evaluate_solvers_on_instances_batch(
        solver_codes=[solver_code],
        instances_list=[[coords]],  # Single dataset with single instance
        optimal_costs_list=[[optimal_cost]],  # None is passed through as "unknown"
        time_limit=time_limit,
        ite_max=ite_max,
        perturbation_moves=perturbation_moves,
        n_jobs=1  # Single instance, no need for parallel
    )
    # The batch API keys results by (solver_idx, dataset_idx); a bare 0 is
    # never a key and would raise KeyError.
    return results[(0, 0)]


def evaluate_solvers_on_instances_batch(
    solver_codes: List[str],
    instances_list: List[List[np.ndarray]],
    optimal_costs_list: List[List[Optional[float]]],
    time_limit: int = 60,
    ite_max: int = 1000,
    perturbation_moves: int = 1,
    n_jobs: int = -1,
    backend: str = "threading",
    prefer: str = "threads",
    use_gap: bool = True,
    debug_mode: bool = False,
    timeout: Optional[float] = None
) -> Dict[Tuple[int, int], Tuple[float, float]]:
    """Evaluate multiple solvers on multiple datasets using batch evaluation.

    Directly builds tasks from existing solvers and instances, without calling
    prepare_tasks. Each dataset is treated as a generator (its index is used as
    the generator id), and datasets may have different numbers of instances.

    Args:
        solver_codes: List of solver code strings
        instances_list: List of datasets, each dataset is a list of instances (coords arrays)
        optimal_costs_list: List of datasets, each dataset is a list of optimal costs
            (entries may be None). Missing datasets or entries are treated as None.
        time_limit: Time limit for solving
        ite_max: Maximum iterations
        perturbation_moves: Perturbation moves
        n_jobs: Number of parallel jobs
        backend: Parallel backend (forwarded to batch_evaluate_tasks)
        prefer: Parallel preference (forwarded to batch_evaluate_tasks)
        use_gap: Currently unused; accepted for API compatibility.
        debug_mode: Forwarded to batch_evaluate_tasks.
        timeout: Forwarded to batch_evaluate_tasks.

    Returns:
        Dictionary mapping (solver_idx, dataset_idx) -> (avg_cost, gap) for every
        solver/dataset pair. ``gap`` is the dataset-level mean gap reported by
        batch_evaluate_tasks. ``avg_cost`` is reconstructed per instance as
        ``opt_cost * (1 + gap / 100)`` and averaged; it is None when the dataset
        has no usable optimal costs. Pairs without a valid result (missing,
        >= 1e8, or a failure anywhere in the pipeline) get (1e10, 100.0).
    """
    from heupsro.problems.tsp_gls.evolution.shared.batch_eval import batch_evaluate_tasks, evaluate_single_solver_instance

    def _optimal_cost(dataset_idx: int, inst_idx: int) -> Optional[float]:
        """Return the optimal cost for one instance, or None when it was not supplied."""
        if dataset_idx >= len(optimal_costs_list):
            return None
        opt_costs = optimal_costs_list[dataset_idx]
        return opt_costs[inst_idx] if inst_idx < len(opt_costs) else None

    try:
        n_solvers = len(solver_codes)
        n_datasets = len(instances_list)

        # Step 1: compile each solver into its own module. Solvers that fail to
        # compile are recorded as None so their indices stay aligned.
        solver_modules = []
        for solver_id, solver_code in enumerate(solver_codes):
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    heuristic_module = types.ModuleType(f"heuristic_module_{solver_id}")
                    exec(solver_code, heuristic_module.__dict__)
                    solver_modules.append(heuristic_module)
            except Exception as e:
                print(f"Warning: Error compiling solver {solver_id}: {e}")
                solver_modules.append(None)

        # Step 2: tasks (solver_id, generator_id, coords, optimal_cost, weight, time_limit, ite_max, perturbation_moves, heuristic_module)
        all_tasks = []
        for solver_id, heuristic_module in enumerate(solver_modules):
            if heuristic_module is None:
                continue

            for dataset_idx, instances in enumerate(instances_list):
                generator_id = dataset_idx
                for inst_idx, coords in enumerate(instances):
                    all_tasks.append((
                        solver_id, generator_id, coords, _optimal_cost(dataset_idx, inst_idx), 1.0,
                        time_limit, ite_max, perturbation_moves,
                        heuristic_module
                    ))

        # Step 3: batch evaluate all tasks
        # results_dict maps (solver_id, generator_id) -> mean_gap (averaged over instances in that dataset)
        results_dict = batch_evaluate_tasks(
            tasks=all_tasks,
            evaluate_fn=evaluate_single_solver_instance,
            n_jobs=n_jobs,
            backend=backend,
            prefer=prefer,
            debug_mode=debug_mode,
            track_time=False,
            time_key="solver",
            timeout=timeout
        )

        # Step 4: aggregate results
        solver_dataset_results = {}
        for solver_id in range(n_solvers):
            for dataset_idx in range(n_datasets):
                gap = results_dict.get((solver_id, dataset_idx), 1e9)

                if gap < 1e8:  # Valid result
                    costs = []
                    for inst_idx in range(len(instances_list[dataset_idx])):
                        opt_cost = _optimal_cost(dataset_idx, inst_idx)
                        if opt_cost is not None and not np.isnan(opt_cost):
                            # Per-instance cost approximated from the dataset-mean gap.
                            costs.append(opt_cost * (1 + gap / 100.0))

                    avg_cost = float(np.mean(costs)) if costs else None
                    solver_dataset_results[(solver_id, dataset_idx)] = (avg_cost, float(gap))
                else:
                    solver_dataset_results[(solver_id, dataset_idx)] = (1e10, 100.0)

        return solver_dataset_results

    except Exception as e:
        print(f"Error in batch evaluation: {e}")
        import traceback
        traceback.print_exc()
        # Return error results
        return {(s, d): (1e10, 100.0) for s in range(len(solver_codes)) for d in range(len(instances_list))}


def evaluate_nash_mixture_on_instance(
    solver_pool: List[Any],
    sigma_h: np.ndarray,
    coords: np.ndarray,
    distance_matrix: np.ndarray,
    optimal_cost: Optional[float],
    time_limit: int = 60,
    ite_max: int = 1000,
    perturbation_moves: int = 1,
    weight_threshold: float = 0.01
) -> Tuple[float, float]:
    """Score a solver mixture on one TSP instance as a weighted expected cost.

    Only solvers whose weight in ``sigma_h`` exceeds ``weight_threshold`` are
    kept (solvers beyond the end of ``sigma_h`` are ignored). Their weights are
    renormalised to sum to one, each kept solver is run once via
    evaluate_solver_on_instance, and the costs are combined with those weights.

    Args:
        solver_pool: Solver objects exposing a ``code`` attribute.
        sigma_h: Mixture weights, indexed like ``solver_pool``.
        coords: Instance coordinates.
        distance_matrix: Distance matrix, forwarded to each solver evaluation.
        optimal_cost: Optimal cost (if available).
        time_limit: Time limit for solving.
        ite_max: Maximum iterations.
        perturbation_moves: Perturbation moves.
        weight_threshold: Solvers at or below this weight are skipped.

    Returns:
        ``(expected_cost, gap_percentage)``. The gap is None unless
        ``optimal_cost`` is a positive, non-NaN number. If no solver passes
        the threshold, or anything raises, ``(1e10, 100.0)`` is returned
        (``(1e10, None)`` when ``optimal_cost`` is None).
    """
    try:
        selected = [
            (solver, sigma_h[idx])
            for idx, solver in enumerate(solver_pool)
            if idx < len(sigma_h) and sigma_h[idx] > weight_threshold
        ]

        if len(selected) == 0:
            print("Warning: No solvers with significant weights found")
            return 1e10, (100.0 if optimal_cost is not None else None)

        weight_sum = sum(weight for _, weight in selected)
        shares = [weight / weight_sum for _, weight in selected]

        solver_costs = []
        for solver, _weight in selected:
            solver_cost, _gap = evaluate_solver_on_instance(
                solver.code, coords, distance_matrix, optimal_cost,
                time_limit, ite_max, perturbation_moves
            )
            solver_costs.append(solver_cost)

        expected_cost = sum(share * c for share, c in zip(shares, solver_costs))

        has_reference = (
            optimal_cost is not None
            and not np.isnan(optimal_cost)
            and optimal_cost > 0
        )
        gap = (expected_cost / optimal_cost - 1) * 100 if has_reference else None

        return expected_cost, gap

    except Exception as e:
        print(f"Error evaluating Nash mixture on instance: {e}")
        return 1e10, (100.0 if optimal_cost is not None else None)


def run_official_tests(
    controller,
    test_data_dir: str,
    sizes: Optional[List[int]] = None,
    targets: Optional[List[str]] = None,
    time_limit: int = 60,
    ite_max: int = 1000,
    perturbation_moves: int = 1,
    max_instances_per_size: int = 10,
    n_jobs: int = -1,
) -> Dict[str, Any]:
    """Run official tests on TSP instances after a PSRO iteration.

    Args:
        controller: HeuPSROController instance. Attributes read here:
            ``iteration``, ``cfg``, ``meta``, ``pools.solver_pool``,
            ``out_dir`` and, when present, ``file_manager``.
        test_data_dir: Directory containing TSP test data
        sizes: TSP sizes to test; defaults to [20, 50, 100].
        targets: Evaluations to run. Recognised values are "latest", "nash",
            "pool" and "nash_history"; defaults to ["latest", "nash"].
        time_limit: Time limit per instance
        ite_max: Maximum iterations
        perturbation_moves: Perturbation moves
        max_instances_per_size: Maximum instances to randomly sample per size
        n_jobs: Number of parallel jobs (-1 for all cores)

    Returns:
        Dict with "iteration", "timestamp" and "results" (keyed "TSP{size}"),
        or {} when no test data could be loaded. The dict is also written to
        ``<controller.out_dir>/round_<iteration>_testing.json``.
    """
    # Resolve defaults here instead of using shared mutable default lists.
    if sizes is None:
        sizes = [20, 50, 100]
    if targets is None:
        targets = ["latest", "nash"]

    print(f"🧪 Running official tests for iteration {controller.iteration}")

    # Load test data
    test_data = load_test_data(test_data_dir, sizes, max_instances_per_size)
    if not test_data:
        print("No test data loaded, skipping tests")
        return {}

    # Get current meta-game solution using configured solver
    meta_solver = None
    try:
        # Use the same meta-game solver as iterate_one_round
        meta_solver = getattr(controller.cfg, 'meta_game_solver', 'ne')

        if meta_solver == "alpha_rank":
            alpha = getattr(controller.cfg, 'alpha_rank_alpha', 15.0)
            num_iters = getattr(controller.cfg, 'alpha_rank_num_iters', 10_000)
            tol = getattr(controller.cfg, 'alpha_rank_tol', 1e-10)
            sigma_h, sigma_g = controller.meta.solve_alpha_rank(
                alpha=alpha,
                num_iters=num_iters,
                tol=tol
            )
            solver_info = "Alpha-Rank"
        else:
            sigma_h, sigma_g = controller.meta.solve_ne()
            solver_info = "NE"

        print(f" {solver_info} weights: σ_H={len(sigma_h)} solvers, σ_G={len(sigma_g)} generators")
    except Exception as e:
        solver_info = meta_solver if meta_solver is not None else "NE"
        print(f"Error computing {solver_info}: {e}")
        sigma_h, sigma_g = None, None

    # Get latest solver code from file_manager
    latest_solver_code = controller.file_manager.get_latest_solver_code() if hasattr(controller, 'file_manager') else None
    if latest_solver_code is None:
        print("Warning: No latest solver code available")
        latest_solver_code = ""

    # Initialize results structure
    results = {
        "iteration": controller.iteration,
        "timestamp": datetime.now().isoformat(),
        "results": {}
    }

    def _instances(size_key):
        """Return (coords, distances, optimal_cost) triples for every instance of size_key."""
        entry = test_data[size_key]
        return zip(entry['coords'], entry['distances'], entry['costs'])

    def _run_parallel(task_fn, task_args):
        """Apply task_fn to each argument tuple using the joblib threading pool."""
        return Parallel(n_jobs=n_jobs, backend="threading", prefer="threads")(
            delayed(task_fn)(args) for args in task_args
        )

    def _code_tasks(solver_code, size_key):
        """Build one _evaluate_solver_parallel argument tuple per instance of size_key."""
        return [
            (solver_code, coords, distances, optimal_cost, time_limit, ite_max, perturbation_moves)
            for coords, distances, optimal_cost in _instances(size_key)
        ]

    def _sigma_tasks(sigma_h_local, size_key):
        """Build one _evaluate_nash_parallel argument tuple per instance of size_key."""
        return [
            (controller.pools.solver_pool, sigma_h_local, coords, distances, optimal_cost, time_limit, ite_max, perturbation_moves)
            for coords, distances, optimal_cost in _instances(size_key)
        ]

    def _summarize(outcomes, size_key, start):
        """Reduce per-instance (cost, gap) pairs to the summary dict stored in results."""
        costs = [cost for cost, _ in outcomes]
        gaps = [gap for _, gap in outcomes if gap is not None]
        return {
            # float() keeps values JSON-serializable for any numpy scalar dtype.
            "mean_cost": float(np.mean(costs)),
            "mean_gap": float(np.mean(gaps)) if gaps else None,
            "total_time": time.time() - start,
            "n_instances": len(test_data[size_key]['coords']),
        }

    def _eval_sigma_on_size(sigma_h_local, size_key):
        """Evaluate the solver mixture sigma_h_local on every instance of size_key."""
        start = time.time()
        outcomes = _run_parallel(_evaluate_nash_parallel, _sigma_tasks(sigma_h_local, size_key))
        return _summarize(outcomes, size_key, start)

    def _fmt_gap(gap):
        """Format a mean gap for logging; 'n/a' when no instance reported a gap."""
        return f"{gap:.2f}%" if gap is not None else "n/a"

    def _round_files(psro_dir):
        """Return (round, filename) pairs for nash_mixture_round_<N>.json files, sorted by round."""
        prefix, suffix = "nash_mixture_round_", ".json"
        found = []
        for fn in os.listdir(psro_dir):
            if not (fn.startswith(prefix) and fn.endswith(suffix)):
                continue
            # Slice out <N> directly; splitting on "_" would put the literal
            # "round" at index 2 and make int() fail.
            round_str = fn[len(prefix):-len(suffix)]
            if round_str.isdecimal():
                found.append((int(round_str), fn))
        return sorted(found)

    def _eval_nash_history(size_key):
        """Evaluate every saved round's solver mixture on size_key, keyed by str(round)."""
        # Prefer history file; fallback to scanning round files
        psro_dir = controller.file_manager.paths.psro_results_dir if hasattr(controller, 'file_manager') else controller.out_dir
        history_path = os.path.join(psro_dir, "nash_mixture_history.json")
        if os.path.exists(history_path):
            with open(history_path, 'r', encoding="utf-8") as f:
                history_data = json.load(f)
            entries = history_data if isinstance(history_data, list) else history_data.get("history", [])
            print(f"    Testing Nash history from {len(entries)} rounds via history file...")
            rounds = [(entry.get("iteration"), entry) for entry in entries]
        else:
            print("    History file not found; scanning per-round mixture files...")
            rounds = []
            for it, fn in _round_files(psro_dir):
                with open(os.path.join(psro_dir, fn), 'r', encoding="utf-8") as f:
                    rounds.append((it, json.load(f)))

        round_results = {}
        for it, record in rounds:
            sigma_h_hist = np.array(record.get("solver_mixture", {}).get("sigma_h", []), dtype=float)
            if sigma_h_hist.size == 0:
                continue
            round_results[str(it)] = _eval_sigma_on_size(sigma_h_hist, size_key)
        return round_results

    # Test each size
    for size in sizes:
        if size not in test_data:
            continue

        n_instances = test_data[size]['n_instances']
        costs_list = test_data[size]['costs']
        print(f"  Testing TSP{size} ({n_instances} instances)")

        size_results = {}

        # Test latest solver
        if "latest" in targets and latest_solver_code:
            print("    Testing latest solver...")
            start_time = time.time()
            try:
                outcomes = _run_parallel(_evaluate_solver_parallel, _code_tasks(latest_solver_code, size))
                summary = _summarize(outcomes, size, start_time)
                size_results["latest_solver"] = summary

                latest_costs = [cost for cost, _ in outcomes]
                print(f"      Latest solver: cost={summary['mean_cost']:.2f}, gap={_fmt_gap(summary['mean_gap'])}, time={summary['total_time']:.2f}s")
                print(f"        Raw costs: {latest_costs}")
                print(f"        Optimal costs: {[costs_list[i] for i in range(len(latest_costs))]}")
            except Exception as e:
                print(f"      Error testing latest solver: {e}")
                size_results["latest_solver"] = {
                    "error": str(e),
                    "total_time": time.time() - start_time,
                    "n_instances": n_instances
                }

        # Test Nash mixture (current round)
        if "nash" in targets and sigma_h is not None and controller.pools.solver_pool:
            print("    Testing Nash mixture...")
            start_time = time.time()
            try:
                summary = _eval_sigma_on_size(sigma_h, size)
                size_results["nash_mixture"] = summary
                print(f"      Nash mixture: cost={summary['mean_cost']:.2f}, gap={_fmt_gap(summary['mean_gap'])}, time={summary['total_time']:.2f}s")
            except Exception as e:
                print(f"      Error testing Nash mixture: {e}")
                size_results["nash_mixture"] = {
                    "error": str(e),
                    "total_time": time.time() - start_time,
                    "n_instances": n_instances
                }

        # Test all solvers in pool, individually
        if "pool" in targets and controller.pools.solver_pool:
            print("    Testing all solvers in pool (individual)...")
            start_time = time.time()
            try:
                per_solver_results = []
                for solver_idx, solver in enumerate(controller.pools.solver_pool):
                    outcomes = _run_parallel(_evaluate_solver_parallel, _code_tasks(solver.code, size))
                    summary = _summarize(outcomes, size, start_time)
                    per_solver_results.append({
                        "solver_index": solver_idx,
                        "program_id": getattr(solver, 'program_id', None),
                        "algorithm": getattr(solver, 'algorithm', None),
                        "mean_cost": summary["mean_cost"],
                        "mean_gap": summary["mean_gap"],
                        "n_instances": summary["n_instances"]
                    })
                size_results["solver_pool"] = {
                    "results": per_solver_results,
                    "total_time": time.time() - start_time,
                    "n_solvers": len(controller.pools.solver_pool)
                }
                print(f"      Pool tested: {len(per_solver_results)} solvers")
            except Exception as e:
                print(f"      Error testing solver pool: {e}")
                size_results["solver_pool"] = {
                    "error": str(e),
                    "total_time": time.time() - start_time
                }

        # Test Nash mixtures from all saved rounds (history)
        if "nash_history" in targets and controller.pools.solver_pool:
            try:
                size_results["nash_history"] = _eval_nash_history(size)
            except Exception as e:
                print(f"      Error testing Nash history: {e}")
                size_results["nash_history"] = {"error": str(e)}

        results["results"][f"TSP{size}"] = size_results

    # Save results to file
    results_file = os.path.join(controller.out_dir, f"round_{controller.iteration}_testing.json")
    try:
        with open(results_file, 'w', encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f" Test results saved to: {results_file}")
    except Exception as e:
        print(f"Error saving results: {e}")

    return results
