"""CVRP instance solving with step-by-step construction heuristic."""

from __future__ import annotations

import numpy as np
from typing import List, Callable, Dict, Tuple, Optional


def nearest_neighbor_fallback(
    current_node: int,
    feasible_arr: np.ndarray,
    demands: np.ndarray,
    distance_matrix: np.ndarray
) -> int:
    """
    Pick the feasible node closest to ``current_node`` (greedy fallback).

    Args:
        current_node: Node the vehicle currently stands on
        feasible_arr: Candidate node IDs the vehicle may move to
        demands: Demands array (ignored; present so all selectors share one signature)
        distance_matrix: Pairwise distance matrix

    Returns:
        The closest candidate node ID (first one on ties), or 0 (the depot)
        when there are no candidates.
    """
    if not len(feasible_arr):
        # Nothing reachable: signal a return to the depot.
        return 0

    row = distance_matrix[current_node]
    best_position = int(np.argmin(row[feasible_arr]))
    return int(feasible_arr[best_position])


def compute_total_distance(route: List[int], distance_matrix: np.ndarray) -> float:
    """
    Sum the lengths of consecutive legs along ``route``.

    Depot separators (0) may appear several times, so a multi-trip solution is
    measured as one continuous walk. The caller is expected to close the route
    at the depot.

    Args:
        route: Ordered node IDs, possibly containing several depot visits
        distance_matrix: (n+1) x (n+1) distance matrix

    Returns:
        Total travelled distance as a Python float (0.0 for routes with fewer
        than two nodes)
    """
    if len(route) < 2:
        return 0.0

    total = 0.0
    # Walk the legs in order so the floating-point accumulation order is fixed.
    for origin, destination in zip(route[:-1], route[1:]):
        total += distance_matrix[origin, destination]
    return float(total)


def _select_next_node(
    select_func: Callable,
    fallback_select: Callable,
    current_node: int,
    feasible_arr: np.ndarray,
    rest_capacity: float,
    demands: np.ndarray,
    distance_matrix: np.ndarray,
) -> int:
    """Return ``select_func``'s choice, or ``fallback_select``'s when that choice is unusable.

    A choice is unusable when ``select_func`` raises, returns a non-finite value,
    returns something ``int()`` cannot convert, or returns a negative number.
    Whether a non-negative choice is an actual feasible customer is checked by
    the caller. Exceptions raised by ``fallback_select`` propagate.
    """
    try:
        # User heuristics often divide by distances/loads; silence numpy FP warnings.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
            candidate = select_func(
                current_node, 0, feasible_arr, rest_capacity, demands, distance_matrix
            )
        if np.isfinite(candidate):
            candidate = int(candidate)
            if candidate >= 0:
                return candidate
    except Exception:
        # Deliberate best-effort: a misbehaving heuristic must not abort construction.
        pass
    return fallback_select(current_node, feasible_arr, demands, distance_matrix)


def route_construct(
    distance_matrix: np.ndarray,
    demands: np.ndarray,
    vehicle_capacity: float,
    select_func: Callable,
    fallback_select: Callable
) -> List[int]:
    """
    Construct route using step-by-step select heuristic.

    At every step the customers whose demand fits the remaining capacity are
    offered to ``select_func(current_node, 0, feasible_arr, rest_capacity,
    demands, distance_matrix)``. Returning 0 sends the vehicle back to the depot.
    Unusable or infeasible answers are replaced by
    ``fallback_select(current_node, feasible_arr, demands, distance_matrix)``.
    After three consecutive depot requests, a customer is forced via the fallback.
    If the main phase runs out of steps, a second phase finishes the remaining
    customers with the fallback alone.

    Customers whose demand exceeds ``vehicle_capacity`` can never be served.
    They are left out of the returned route.

    Args:
        distance_matrix: Distance matrix (n+1) x (n+1), index 0 is depot
        demands: Demands array (length n+1), index 0 is depot (demand=0)
        vehicle_capacity: Vehicle capacity
        select_func: Select function from solver module
        fallback_select: Fallback select function (e.g., nearest neighbor)

    Returns:
        Route as list of node IDs (starts and ends at depot 0)
    """
    n_plus_1 = len(demands)  # nodes: 0..n
    unvisited = set(range(1, n_plus_1))  # Customer nodes: 1..n

    route = [0]  # Start at depot
    current_node = 0
    current_load = 0.0
    consecutive_depot = 0
    steps = 0
    max_steps = 10 * (n_plus_1 + 1)  # Anti-stuck guard

    # ---- Phase 1: heuristic-driven construction -------------------------------
    while unvisited and steps < max_steps:
        steps += 1
        rest = vehicle_capacity - current_load
        feasible = [u for u in unvisited if demands[u] <= rest]

        if not feasible:
            # Close current trip and restart from the depot with an empty vehicle.
            if route[-1] != 0:
                route.append(0)
            current_node = 0
            current_load = 0.0
            consecutive_depot = 0
            if rest >= vehicle_capacity:
                # The vehicle was already at least as empty as a fresh one, so no
                # remaining customer can ever fit; further steps cannot change the route.
                return route
            continue

        feasible_arr = np.array(feasible, dtype=np.int64)
        feasible_set = set(feasible)  # feasible is a subset of unvisited

        next_node = _select_next_node(
            select_func, fallback_select, current_node, feasible_arr,
            float(rest), demands, distance_matrix,
        )

        # Handle depot return
        if next_node == 0:
            consecutive_depot += 1
            if route[-1] != 0:
                route.append(0)
            current_node = 0
            current_load = 0.0
            if consecutive_depot < 3:
                continue
            # Anti-stuck: too many depot returns -> force a customer via the fallback.
            next_node = fallback_select(current_node, feasible_arr, demands, distance_matrix)

        # Validate a customer choice; one fallback retry, otherwise go back to depot.
        if next_node != 0 and next_node not in feasible_set:
            next_node = fallback_select(current_node, feasible_arr, demands, distance_matrix)
            # A fallback answer of 0 (depot) is handled here too; previously it fell
            # through to the visit code and crashed on unvisited.remove(0).
            if next_node == 0 or next_node not in feasible_set:
                if route[-1] != 0:
                    route.append(0)
                current_node = 0
                current_load = 0.0
                consecutive_depot += 1
                continue

        if next_node == 0:
            # Forced pick also chose the depot; we are already there.
            continue

        # Visit customer node (int() normalizes e.g. numpy/float ids from the fallback).
        node = int(next_node)
        route.append(node)
        current_load += float(demands[node])
        unvisited.remove(node)
        current_node = node
        consecutive_depot = 0

    # Close the open trip AND reset the vehicle state so the completion phase
    # starts from the depot with full capacity.
    if route[-1] != 0:
        route.append(0)
    current_node = 0
    current_load = 0.0

    # ---- Phase 2: fallback-only completion (only reached if phase 1 ran out of steps)
    while unvisited and steps < max_steps * 2:
        steps += 1
        rest = vehicle_capacity - current_load
        feasible = [u for u in unvisited if demands[u] <= rest]

        if not feasible:
            if route[-1] != 0:
                route.append(0)
            current_node = 0
            current_load = 0.0
            if rest >= vehicle_capacity:
                break  # nothing left can ever fit
            continue

        feasible_arr = np.array(feasible, dtype=np.int64)
        next_node = fallback_select(current_node, feasible_arr, demands, distance_matrix)

        # Require feasibility (not just "unvisited") so capacity is never exceeded.
        if next_node != 0 and next_node in set(feasible):
            node = int(next_node)
            route.append(node)
            current_load += float(demands[node])
            unvisited.remove(node)
            current_node = node
        else:
            if route[-1] != 0:
                route.append(0)
            current_node = 0
            current_load = 0.0

    # Final depot return
    if route[-1] != 0:
        route.append(0)

    return route


def instance_to_solver_inputs(instance: Dict) -> Tuple[np.ndarray, np.ndarray, float]:
    """
    Convert instance dict to solver input format.

    Fixed conversion rules:
    - coords[0] = depot_coord
    - coords[1:] = customers_coord
    - demands[0] = 0 (depot)
    - demands[1:] = customer_demands
    - Build distance_matrix (n+1) x (n+1) of Euclidean distances as float
    - Also stores distance_matrix_int (distances * distance_scale, truncated
      toward zero, int64) and distance_scale in ``instance`` for oracle
      consistency. Note: this mutates the passed dict.

    Args:
        instance: Dict with 'depot', 'customers', 'vehicle_capacity'; each
            customer is a dict with 'coords' and 'demand'

    Returns:
        (distance_matrix, demands, vehicle_capacity)

    Raises:
        ValueError: if any coordinate is non-finite, so no integer distance
            matrix can be formed.
    """
    depot = np.array(instance['depot'])
    customers = instance['customers']
    vehicle_capacity = float(instance['vehicle_capacity'])

    n_plus_1 = len(customers) + 1

    # Build coordinates array (row 0 = depot)
    coords = np.zeros((n_plus_1, 2))
    coords[0] = depot
    for i, c in enumerate(customers, start=1):
        coords[i] = np.array(c['coords'])

    # Build demands array (depot demand stays 0)
    demands = np.zeros(n_plus_1)
    for i, c in enumerate(customers, start=1):
        demands[i] = float(c['demand'])

    # Pairwise Euclidean distances in one vectorized pass (was an O(n^2) Python loop).
    diff = coords[:, np.newaxis, :] - coords[np.newaxis, :, :]
    distance_matrix = np.linalg.norm(diff, axis=-1)

    # Note: distance_matrix_int is only used for oracle consistency (parsing tour file)
    # Scale 1000.0 (0.001 precision) is sufficient for CVRP
    distance_scale = 1000.0  # Scale factor for integer distance matrix

    if not np.all(np.isfinite(distance_matrix)):
        raise ValueError("coordinates must be finite to build the integer distance matrix")

    # astype truncates toward zero, matching int() on these non-negative values,
    # so oracle and solver keep using the same quantized distances.
    distance_matrix_int = (distance_matrix * distance_scale).astype(np.int64)

    # Store in instance for oracle to use
    instance['distance_matrix_int'] = distance_matrix_int
    instance['distance_scale'] = distance_scale

    return distance_matrix, demands, vehicle_capacity


def solve_instance(
    instance: Dict,
    select_func: Callable,
    fallback_select: Optional[Callable] = None,
    time_limit: Optional[float] = None
) -> Tuple[float, List[int]]:
    """
    Build a CVRP solution for ``instance`` with a step-by-step select heuristic.

    Pipeline: convert the instance to arrays, construct the route with
    ``select_func`` (backed by a fallback selector), then measure its length.

    Args:
        instance: CVRP instance dict with 'depot', 'customers', 'vehicle_capacity'
        select_func: Select function from solver module
        fallback_select: Selector used when ``select_func`` fails; when None,
            nearest_neighbor_fallback is used
        time_limit: Accepted for interface compatibility; currently ignored

    Returns:
        (solution_cost, route) where solution_cost is the total distance as a
        float and route is the list of visited node IDs
    """
    backup_selector = nearest_neighbor_fallback if fallback_select is None else fallback_select

    distance_matrix, demands, vehicle_capacity = instance_to_solver_inputs(instance)
    route = route_construct(
        distance_matrix, demands, vehicle_capacity, select_func, backup_selector
    )
    cost = compute_total_distance(route, distance_matrix)
    return float(cost), route

