"""BP Online-specific evaluation functions for PSRO.

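This module contains BP Online-specific evaluation logic that can be used
by the PSRO controller. Functions receive their dependencies (config, caches,
solver code) as parameters to avoid tight coupling; note that they are not
strictly pure, since they exec solver code and may print diagnostics.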
"""

from __future__ import annotations

from typing import List, Dict, Optional, Any, Tuple, Union, Union
import numpy as np
import types
import warnings

from ...core.config import HeuPSROConfig

# Try to import scipy.special.erf for numpy.erf compatibility
try:
    from scipy.special import erf as scipy_erf
    HAS_SCIPY_ERF = True
except ImportError:
    HAS_SCIPY_ERF = False
    # Fallback: use math.erf if available (Python 3.7+)
    try:
        from math import erf as math_erf
        scipy_erf = math_erf
    except ImportError:
        scipy_erf = None


def prepare_exec_namespace() -> Dict[str, Any]:
    """
    Build a fresh globals dict for ``exec()``-ing solver code.

    The namespace always exposes ``numpy`` and ``np``. When an ``erf``
    implementation was found at import time, it also exposes a bare ``erf``
    name. In that case ``numpy``/``np`` are replaced by a stand-in module that
    carries numpy's public attributes plus ``erf``. Recent NumPy versions do
    not provide ``erf``, and building a copy avoids mutating the real numpy
    module.

    The ``erf`` used is ``scipy.special.erf`` when SciPy is importable. Without
    SciPy it falls back to ``math.erf``, which only accepts scalars; array
    arguments then raise ``TypeError``.

    Returns:
        Dictionary suitable as the ``globals`` argument of ``exec()``.
    """
    namespace: Dict[str, Any] = {
        '__builtins__': __builtins__,
        'numpy': np,
        'np': np,
    }

    # Only check that *some* erf implementation was found. Also requiring
    # HAS_SCIPY_ERF would make the math.erf fallback selected at import time
    # unreachable.
    if scipy_erf is None:
        return namespace

    # Expose erf as a standalone name so code can call erf(x) directly.
    # A plain function (not a wrapper class) keeps the namespace picklable.
    namespace['erf'] = scipy_erf

    try:
        # Build a stand-in "numpy" module with all public numpy attributes
        # plus erf, instead of modifying the process-wide numpy module.
        numpy_module = types.ModuleType('numpy')
        for attr_name in dir(np):
            if attr_name.startswith('_'):
                continue
            try:
                setattr(numpy_module, attr_name, getattr(np, attr_name))
            except (TypeError, AttributeError):
                # Some names listed by dir() may not be retrievable; skip them.
                pass
        numpy_module.erf = scipy_erf
        namespace['numpy'] = numpy_module
        namespace['np'] = numpy_module
    except Exception:
        # Best effort: erf remains available as a bare name in the namespace.
        pass

    return namespace


def online_binpack(items: np.ndarray, bins: np.ndarray, alg) -> Tuple[List[List[float]], np.ndarray]:
    """
    Pack ``items`` online (in iteration order) into ``bins`` using ``alg.score``.

    For each item, ``alg.score(item, valid_bins)`` is called with the remaining
    capacities of the bins that can still hold the item. The item goes into the
    bin with the highest score (the first such bin on ties). Some score outputs
    are replaced by -1e30, which makes the choice fall back to the first fitting
    bin:

    - outputs whose shape differs from ``valid_bins``;
    - NaN or +/-inf entries;
    - any exception raised by ``score``.

    Args:
        items: Item sizes, packed in iteration order.
        bins: Bin capacities. Not modified; a float copy is packed instead.
        alg: Object (typically a module) exposing a callable
            ``score(item, bins)``.

    Returns:
        Tuple of (packing, bins_packed):
        - packing: one list of packed item sizes per non-empty bin, in bin order
        - bins_packed: remaining capacity of every bin (float copy, including
          bins that received no item)

    Raises:
        AttributeError: If ``alg`` has no callable ``score`` attribute.
    """
    # The score function does not change between items, so validate it once
    # instead of on every iteration.
    score_fn = getattr(alg, 'score', None)
    if not callable(score_fn):
        alg_name = getattr(alg, '__name__', type(alg).__name__)
        raise AttributeError(f"Module '{alg_name}' has no attribute 'score'")

    # Single float copy: the caller's array stays untouched, and score
    # functions always receive float capacities (avoids casting errors).
    bins = np.array(bins, dtype=float)

    # packing[i] collects the items placed into bin i.
    packing = [[] for _ in bins]

    for item in items:
        valid_bin_indices = np.nonzero((bins - item) >= 0)[0]
        if len(valid_bin_indices) == 0:
            # The item fits in no bin and is dropped from the packing.
            continue

        # Fancy indexing returns a copy, so score() cannot mutate `bins`.
        valid_bins = bins[valid_bin_indices]

        try:
            with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
                priorities = score_fn(float(item), valid_bins)

            priorities = np.asarray(priorities, dtype=float)

            if priorities.shape != valid_bins.shape:
                # Malformed output: give every bin the worst score (first fit).
                priorities = np.full(valid_bins.shape, -1e30, dtype=float)
            else:
                # Treat NaN and +/-inf as the worst possible score.
                priorities = np.nan_to_num(priorities, nan=-1e30, posinf=-1e30, neginf=-1e30)
        except Exception:
            priorities = np.full(valid_bins.shape, -1e30, dtype=float)

        best_bin = valid_bin_indices[np.argmax(priorities)]
        bins[best_bin] -= item
        packing[best_bin].append(item)

    # Drop bins that never received an item.
    packing = [bin_items for bin_items in packing if bin_items]
    return packing, bins


def evaluate_solver_on_instance(
    code: str,
    instance: Dict,
    lb: Optional[float] = None
) -> Tuple[float, float]:
    """
    Compile ``code`` and evaluate its ``score`` heuristic on one BP Online instance.

    Args:
        code: Solver source code; must define a callable ``score(item, bins)``.
        instance: Dict with 'items', 'capacity' and 'num_items'.
        lb: Lower bound on the number of bins. This function does NOT compute
            it. The gap is only computed when ``lb`` is given and positive.

    Returns:
        Tuple of (num_bins, gap). ``num_bins`` is the number of non-empty bins
        used. ``gap`` is ``(num_bins / lb - 1) * 100``, or ``inf`` when no
        usable ``lb`` is given. Both values are ``inf`` if compiling or running
        the solver raises; the error is printed, not re-raised.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            heuristic_module = types.ModuleType("heuristic_module")

            # Namespace with numpy (and erf compatibility) preloaded, plus the
            # module's own dunder attributes.
            exec_namespace = prepare_exec_namespace()
            exec_namespace.update(heuristic_module.__dict__)

            # NOTE: exec() runs arbitrary code in this process; only pass
            # trusted or externally sandboxed solver code.
            exec(code, exec_namespace)
            heuristic_module.__dict__.update(exec_namespace)

            if not callable(getattr(heuristic_module, 'score', None)):
                raise AttributeError("Module 'heuristic_module' has no attribute 'score'")

            capacity = float(instance['capacity'])
            items = np.array(instance['items'], dtype=float)
            num_items = int(instance['num_items'])

            # One empty bin per item guarantees every item fits somewhere
            # (provided no item exceeds capacity). Use at least len(items)
            # bins so a short/stale 'num_items' cannot silently drop items
            # and under-report the bin count.
            bins = np.full(max(num_items, len(items)), capacity, dtype=float)

            packing, _ = online_binpack(items, bins, heuristic_module)

            # online_binpack already drops empty bins, so this is the number
            # of bins actually used.
            num_bins = len(packing)
            if lb is not None and lb > 0:
                gap = (num_bins / lb - 1.0) * 100.0
            else:
                gap = float('inf')

            return float(num_bins), float(gap)

    except Exception as e:
        print(f"      Evaluation failed: {e}")
        return float('inf'), float('inf')


def compute_gap_with_oracle(
    code: str,
    instances: List[Dict],
    capacity: int,
    oracle_type: str = "lb",
    oracle_timeout: int = 0,
    utility_cache: Optional[Dict[str, Any]] = None,
    config: Optional[HeuPSROConfig] = None
) -> float:
    """
    Score a solver by its mean percentage gap to oracle lower bounds.

    The oracle's ``solve_exact`` supplies a lower bound for every instance.
    The solver is then evaluated on that instance via
    ``evaluate_solver_on_instance``. Instances whose gap is not finite are
    left out of the mean; this covers solver failures and missing or
    non-positive lower bounds. The flow mirrors the tsp_gls evaluation, with a
    BP Online-specific per-instance evaluation.

    Args:
        code: Solver source code.
        instances: BP Online instances, each a dict with 'items', 'capacity'
            and 'num_items'.
        capacity: Bin capacity (not used by this function).
        oracle_type: Oracle type; "lb" computes lower bounds, "none" skips
            evaluation entirely.
        oracle_timeout: Oracle timeout in seconds (not needed for BP Online;
            not used by this function).
        utility_cache: Optional dict; when given, the mean gap is stored in it.
        config: Optional config forwarded to the oracle factory.

    Returns:
        Mean gap in percent. Returns ``inf`` in any of these cases:
        ``oracle_type`` is "none", no instance produced a finite gap, or an
        exception was raised.
    """
    from .oracle import create_bp_online_oracle

    try:
        # Without an oracle there is no lower bound to compare against.
        if oracle_type == "none":
            return float('inf')

        # An explicit oracle_type overrides the one stored in config.
        oracle_kwargs: Dict[str, Any] = {'oracle_type': oracle_type}
        if config is not None:
            oracle_kwargs['config'] = config
        oracle = create_bp_online_oracle(**oracle_kwargs)

        finite_gaps: List[float] = []
        for inst in instances:
            bound = oracle.solve_exact(inst)
            # solve_exact may return (value, status) or just the value.
            lower_bound = bound[0] if isinstance(bound, tuple) else bound

            _, inst_gap = evaluate_solver_on_instance(code, inst, lower_bound)
            if np.isfinite(inst_gap):
                finite_gaps.append(inst_gap)

        if not finite_gaps:
            mean_gap = float('inf')
        else:
            mean_gap = float(np.mean(finite_gaps))
            if not np.isfinite(mean_gap):
                mean_gap = float('inf')

        if utility_cache is not None:
            # NOTE(review): only the first 10 characters of `code` enter the
            # key, so different solvers sharing a prefix overwrite each
            # other's entry. Consider hashing the full code if no reader
            # depends on this exact key format.
            cache_key = f"{code[:10]}_{hash(str(instances))}"
            utility_cache[f"gap_{cache_key}"] = mean_gap

        return mean_gap

    except Exception as e:
        print(f"      Evaluation failed: {e}")
        return float('inf')


def _num_bins_and_gap(
    lb: Optional[float],
    gap_val: Any,
) -> Tuple[Optional[float], Optional[float]]:
    """Convert a batch-evaluated gap into (num_bins, gap), or (None, None) if unusable."""
    # A missing or non-finite lower bound, a missing/non-finite gap, or a gap
    # >= 1e8 all count as "no valid measurement". The 1e8 threshold is
    # presumably a failure penalty from the batch evaluator — confirm.
    if lb is None or gap_val is None:
        return None, None
    lb_f = float(lb)
    gap_f = float(gap_val)
    if not np.isfinite(lb_f) or not np.isfinite(gap_f) or gap_f >= 1e8:
        return None, None
    # Inverse of gap = (num_bins / lb - 1) * 100.
    return lb_f * (1.0 + gap_f / 100.0), gap_f


def evaluate_solvers_on_instances_with_lbs(
    solver_codes: Union[str, List[str]],
    instances: Union[Dict, List[Dict]],
    lower_bounds: List[Optional[float]],
    config: HeuPSROConfig,
    return_format: str = "per_solver"  # "per_solver" | "raw"
) -> Union[Dict[int, Tuple[Optional[float], Optional[float]]],
           Dict[Tuple[int, int], Tuple[Optional[float], Optional[float]]]]:
    """
    Evaluate multiple solvers on multiple instances with known lower bounds.

    Each instance index is used as the batch evaluator's generator id, so
    results stay per instance (keyed by ``(solver_idx, inst_idx)``) instead of
    being averaged inside the batch evaluation. Solvers that fail to compile or
    lack a callable ``score`` are not evaluated. In "per_solver" mode they
    report ``(None, None)``; in "raw" mode they have no entries.

    Args:
        solver_codes: A single solver code string or a list of them.
        instances: A single instance dict or a list of them. Each dict contains
            'items', 'capacity' and 'num_items'.
        lower_bounds: Lower bounds aligned with ``instances``. ``None``, NaN or
            an infinite value means the instance has no usable lower bound.
        config: HeuPSROConfig; time limits, parallelism and debug settings are
            read from it, with defaults for missing attributes.
        return_format:
            - "per_solver": return {solver_idx: (mean_num_bins, mean_gap)},
              averaged over instances with a valid measurement
            - "raw": return {(solver_idx, inst_idx): (num_bins, gap)}
            Any other value is treated as "per_solver".

    Returns:
        See ``return_format``. Unusable measurements are ``(None, None)``.

    Raises:
        ValueError: If ``lower_bounds`` and ``instances`` differ in length.
    """
    from .evolution.shared.batch_eval import batch_evaluate_tasks, evaluate_single_solver_instance

    solver_list: List[str] = [solver_codes] if isinstance(solver_codes, str) else list(solver_codes)
    inst_list: List[Dict] = [instances] if isinstance(instances, dict) else list(instances)
    if len(lower_bounds) != len(inst_list):
        raise ValueError(f"lower_bounds length ({len(lower_bounds)}) must match instances length ({len(inst_list)})")

    time_limit = getattr(config, 'instance_solver_time_limit', 5)
    oracle_timeout = getattr(config, 'oracle_timeout', 0)
    debug_mode = getattr(config, 'debug_mode', False)

    # Compile each solver once; None marks solvers that cannot be evaluated.
    solver_modules: List[Optional[types.ModuleType]] = []
    for idx, code in enumerate(solver_list):
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                m = types.ModuleType(f"heuristic_module_{idx}")

                # Namespace with numpy (and erf compatibility) preloaded.
                exec_namespace = prepare_exec_namespace()
                exec_namespace.update(m.__dict__)

                # NOTE: exec() runs arbitrary code in this process; only pass
                # trusted or externally sandboxed solver code.
                exec(code, exec_namespace)
                m.__dict__.update(exec_namespace)

                if not callable(getattr(m, 'score', None)):
                    if debug_mode:
                        print(f"      Warning: Solver {idx} compiled but has no 'score' function")
                    solver_modules.append(None)
                else:
                    solver_modules.append(m)
        except Exception as e:
            if debug_mode:
                print(f"      Warning: Failed to compile solver {idx}: {e}")
            solver_modules.append(None)

    # One task per (solver, instance); inst_idx doubles as the generator id
    # so per-instance results are not averaged together.
    tasks = []
    for s_idx, mod in enumerate(solver_modules):
        if mod is None:
            continue
        for inst_idx, instance in enumerate(inst_list):
            tasks.append((
                s_idx,
                inst_idx,  # generator id
                instance,
                lower_bounds[inst_idx],
                1.0,  # weight
                time_limit,
                mod
            ))

    if not tasks:
        if return_format == "raw":
            return {}
        return {i: (None, None) for i in range(len(solver_list))}

    # Generous overall budget: per-task limit plus slack, scaled by 1.5.
    timeout_per_task = time_limit + oracle_timeout + 10
    batch_timeout = len(tasks) * timeout_per_task * 1.5
    results = batch_evaluate_tasks(
        tasks=tasks,
        evaluate_fn=evaluate_single_solver_instance,
        n_jobs=getattr(config, 'eval_n_jobs', -1),
        backend=getattr(config, 'parallel_backend', 'loky'),
        prefer=getattr(config, 'parallel_prefer', 'processes'),
        timeout=batch_timeout,
        debug_mode=debug_mode,
        track_time=False,
        time_key="solver",
        task_batch_size=getattr(config, 'batch_eval_task_batch_size', None)
    )

    if return_format == "raw":
        return {
            (s_id, inst_idx): _num_bins_and_gap(lower_bounds[inst_idx], gap_val)
            for (s_id, inst_idx), gap_val in results.items()
        }

    per_solver: Dict[int, Tuple[Optional[float], Optional[float]]] = {}
    for s_id in range(len(solver_list)):
        num_bins_list: List[float] = []
        gaps: List[float] = []
        for inst_idx in range(len(inst_list)):
            key = (s_id, inst_idx)
            if key not in results:
                continue
            num_bins, gap = _num_bins_and_gap(lower_bounds[inst_idx], results[key])
            if num_bins is not None:
                num_bins_list.append(num_bins)
                gaps.append(gap)
        mean_num_bins = float(np.mean(num_bins_list)) if num_bins_list else None
        mean_gap = float(np.mean(gaps)) if gaps else None
        per_solver[s_id] = (mean_num_bins, mean_gap)
    return per_solver

