"""Utility helpers shared by CVRP baseline methods."""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np

from heupsro.problems.cvrp.evolution.shared.solver.solve_instance import (
    instance_to_solver_inputs,
    compute_total_distance
)


@dataclass(frozen=True)
class BaselineResult:
    """Immutable result record returned by every baseline solver.

    Fields appear in declaration order, which is also the positional order
    accepted by the dataclass-generated ``__init__``.
    """

    # One list of node IDs per vehicle route; node 0 denotes the depot.
    routes: List[List[int]]
    # Total travel distance of ``routes``.
    cost: float
    # Optional solver-specific metadata; defaults to ``None``.
    extras: dict | None = None


def routes_to_single_route(routes: List[List[int]]) -> List[int]:
    """
    Convert multiple routes to a single route representation.

    Consecutive routes are separated by exactly one depot (0) node.
    Example: [[0, 1, 2, 0], [0, 3, 4, 0]] -> [0, 1, 2, 0, 3, 4, 0]

    Leading/trailing depots of each input route are optional, so routes with
    or without explicit depot endpoints produce the same result. Routes that
    contain no customers (e.g. ``[]`` or ``[0, 0]``) are skipped, so no empty
    trips (``0, 0``) appear in the output.

    Args:
        routes: List of routes, each route normally starts and ends at depot (0)

    Returns:
        Single route list starting and ending at the depot, with depot
        separators between routes. ``[0]`` when there are no customers.
    """
    single_route = [0]
    for route in routes:
        # Trim depot endpoints so a shared separator is not duplicated.
        start, end = 0, len(route)
        while start < end and route[start] == 0:
            start += 1
        while end > start and route[end - 1] == 0:
            end -= 1
        if start == end:
            continue  # route visits no customers
        single_route.extend(route[start:end])
        single_route.append(0)

    return single_route


def compute_route_cost(route: List[int], distance_matrix: np.ndarray) -> float:
    """
    Return the total travel distance along ``route``.

    This is a thin wrapper that delegates to the shared solver's
    ``compute_total_distance``, so every baseline prices routes the same way.

    Args:
        route: Sequence of node IDs to traverse (may include several depots)
        distance_matrix: Distance matrix (n+1) x (n+1)

    Returns:
        Total distance as reported by ``compute_total_distance``
    """
    total_distance = compute_total_distance(route, distance_matrix)
    return total_distance


def compute_routes_cost(routes: List[List[int]], distance_matrix: np.ndarray) -> float:
    """
    Sum the travel distance of every route in ``routes``.

    Each route is priced with ``compute_route_cost``; an empty list of routes
    costs ``0.0``.

    Args:
        routes: Routes to price, each a list of node IDs
        distance_matrix: Distance matrix (n+1) x (n+1)

    Returns:
        Combined distance of all routes
    """
    per_route_costs = (compute_route_cost(r, distance_matrix) for r in routes)
    # Start from 0.0 so the accumulation order and type match a float running total.
    return sum(per_route_costs, 0.0)


def validate_solution(
    routes: List[List[int]],
    demands: np.ndarray,
    vehicle_capacity: float,
    distance_matrix: np.ndarray
) -> Tuple[bool, str]:
    """
    Validate a CVRP solution.

    Checks, per non-empty route and in this order:
      1. the route starts and ends at the depot (0);
      2. every node ID lies in ``[0, len(demands))``;
      3. the summed customer demand does not exceed ``vehicle_capacity``;
      4. no customer is visited more than once across all routes.
    Finally checks that every customer ``1..len(demands)-1`` was visited.
    Empty routes are ignored.

    Args:
        routes: List of routes
        demands: Demands array (length n+1), index 0 is depot (demand=0)
        vehicle_capacity: Vehicle capacity
        distance_matrix: Distance matrix (currently unused; accepted for
            interface compatibility)

    Returns:
        (is_valid, error_message) -- ``(True, "")`` when valid, otherwise
        ``False`` with a message describing the first violation found.
    """
    n_plus_1 = len(demands)
    visited = set()

    for route_idx, route in enumerate(routes):
        # len() instead of truthiness so array-typed routes also work.
        if len(route) == 0:
            continue

        # Check route starts and ends at depot
        if route[0] != 0:
            return False, f"Route {route_idx} does not start at depot"
        if route[-1] != 0:
            return False, f"Route {route_idx} does not end at depot"

        # Range-check node IDs BEFORE using them to index ``demands``:
        # an out-of-range ID would otherwise raise IndexError, and a
        # negative ID would silently read another customer's demand.
        for node in route:
            if node < 0 or node >= n_plus_1:
                return False, f"Route {route_idx} contains invalid node: {node}"

        # Check capacity constraint
        route_demand = sum(demands[node] for node in route if node != 0)
        if route_demand > vehicle_capacity:
            return False, f"Route {route_idx} exceeds capacity: {route_demand} > {vehicle_capacity}"

        # Check customers are not visited twice (depot may repeat)
        for node in route:
            if node == 0:
                continue
            if node in visited:
                return False, f"Route {route_idx} contains duplicate node: {node}"
            visited.add(node)

    # Check all customers are visited (visited is a subset after the range check)
    all_customers = set(range(1, n_plus_1))
    missing = all_customers - visited
    if missing:
        return False, f"Missing customers: {missing}"

    return True, ""


def _split_trips(route: List[int]) -> List[List[int]]:
    """Split ``route`` at depot (0) nodes into non-empty lists of customers."""
    trips: List[List[int]] = []
    current: List[int] = []
    for node in route:
        if node == 0:
            if current:
                trips.append(current)
            current = []
        else:
            current.append(node)
    if current:
        trips.append(current)
    return trips


def _two_opt_trip(
    customers: List[int],
    distance_matrix: np.ndarray,
    max_iterations: int
) -> List[int]:
    """Run first-improvement 2-opt on one depot-to-depot trip's customers."""
    customers = list(customers)
    n = len(customers)
    if n < 2:
        return customers

    best_cost = compute_route_cost([0] + customers + [0], distance_matrix)
    for _ in range(max_iterations):
        improved = False
        # Reverse customers[start:end+1]; start may be 0 and end may be n-1,
        # so moves that exchange the depot edges are also considered.
        for start in range(n - 1):
            for end in range(start + 1, n):
                candidate = (
                    customers[:start]
                    + customers[start:end + 1][::-1]
                    + customers[end + 1:]
                )
                # Full re-pricing (not an edge delta) keeps this correct for
                # asymmetric distance matrices.
                cost = compute_route_cost([0] + candidate + [0], distance_matrix)
                if cost < best_cost:
                    customers, best_cost = candidate, cost
                    improved = True
                    break
            if improved:
                break
        if not improved:
            break
    return customers


def two_opt_improve(
    route: List[int],
    distance_matrix: np.ndarray,
    max_iterations: int = 100
) -> List[int]:
    """
    Apply 2-opt local search improvement to a route.

    The route is split at depot (0) nodes into trips and each trip is improved
    independently, so a multi-trip route (as produced by
    ``routes_to_single_route``) is never merged into a single trip. A move
    reverses a contiguous run of customers within one trip, including runs
    adjacent to the depot. Moves are accepted first-improvement style: the
    scan restarts after each accepted move.

    Args:
        route: Route as list of node IDs (should start and end at depot 0;
            internal depots separate trips)
        distance_matrix: Distance matrix
        max_iterations: Maximum number of improvement scans per trip

    Returns:
        Improved route starting and ending at the depot. The input list itself
        is returned when it is too short to improve (at most 3 nodes or at
        most one customer).
    """
    if len(route) <= 3:  # Need at least depot + 2 nodes + depot
        return route

    trips = _split_trips(route)
    if sum(len(trip) for trip in trips) <= 1:
        return route

    improved_route = [0]
    for trip in trips:
        improved_route.extend(_two_opt_trip(trip, distance_matrix, max_iterations))
        improved_route.append(0)
    return improved_route

