"""CVRP-specific initialization logic."""

from __future__ import annotations

import numpy as np


def get_nearest_neighbor_solver_code() -> str:
    """
    Return source code for a nearest-neighbor ``select`` function for CVRP.

    The returned string defines ``select(current_node, depot, unvisited_nodes,
    rest_capacity, demands, distance_matrix) -> int``, intended for
    step-by-step route construction. Among ``unvisited_nodes`` whose demand
    fits in ``rest_capacity``, it picks the one closest to ``current_node``
    according to ``distance_matrix``. It returns 0 (the depot) when there is
    no such node.
    """
    source_lines = [
        "import numpy as np",
        "def select(current_node: int, depot: int, unvisited_nodes: np.ndarray, rest_capacity: float, demands: np.ndarray, distance_matrix: np.ndarray) -> int:",
        '    """Select next node using nearest neighbor heuristic.',
        "    ",
        "    Args:",
        "        current_node: ID of the current node (0=depot, 1..n=customers)",
        "        depot: ID of the depot (always 0)",
        "        unvisited_nodes: Array of IDs of unvisited feasible nodes (customer nodes: 1..n)",
        "        rest_capacity: Remaining capacity of the vehicle (float)",
        "        demands: Demands array (length n+1), index 0 is depot (demand=0)",
        "        distance_matrix: Distance matrix (n+1) x (n+1), index 0 is depot",
        "    ",
        "    Returns:",
        "        ID of the next node to visit (0=depot, 1..n=customer, must be in unvisited_nodes or 0)",
        '    """',
        "    if len(unvisited_nodes) == 0:",
        "        return 0  # Return to depot if no feasible nodes",
        "    ",
        "    # Filter nodes that satisfy capacity constraint (should already be filtered, but double-check)",
        "    feasible = [node for node in unvisited_nodes if demands[node] <= rest_capacity]",
        "    ",
        "    if len(feasible) == 0:",
        "        return 0  # Return to depot if no feasible nodes",
        "    ",
        "    # Select nearest feasible node",
        "    distances = distance_matrix[current_node, feasible]",
        "    nearest_idx = np.argmin(distances)",
        "    next_node = int(feasible[nearest_idx])",
        "    ",
        "    return next_node",
    ]
    # Every source line, including the last, is newline-terminated.
    return "\n".join(source_lines) + "\n"


def get_savings_solver_code() -> str:
    """
    Get a savings algorithm solver code for CVRP.

    This uses the Clarke-Wright savings heuristic. The returned source defines
    ``solve_cvrp(instance)``, where ``instance`` is a dict with:

    - ``'depot'``: depot coordinates, as a list/tuple or an ``np.ndarray``;
    - ``'customers'``: a list of dicts, each with ``'coords'`` and ``'demand'``;
    - ``'vehicle_capacity'``: the capacity shared by every vehicle.

    ``solve_cvrp`` returns ``{'routes': ..., 'total_distance': ...}``. Each
    route is a list of 0-based indices into ``instance['customers']``, with the
    depot implicit at both ends. ``total_distance`` is the summed
    depot -> route -> depot Euclidean length, as a float.
    """
    return (
        "import numpy as np\n"
        "def solve_cvrp(instance):\n"
        "    \"\"\"Solve CVRP using Clarke-Wright savings algorithm.\"\"\"\n"
        "    # np.asarray accepts the depot either as a plain list or as an array\n"
        "    depot = np.asarray(instance['depot'], dtype=float).reshape(-1)\n"
        "    customers = instance['customers']\n"
        "    vehicle_capacity = instance['vehicle_capacity']\n"
        "    n = len(customers)\n"
        "    if n == 0:\n"
        "        return {'routes': [], 'total_distance': 0.0}\n"
        "    \n"
        "    # Extract customer coordinates and demands\n"
        "    coords = np.asarray([c['coords'] for c in customers], dtype=float).reshape(n, depot.size)\n"
        "    demands = np.array([c['demand'] for c in customers])\n"
        "    \n"
        "    # Distance matrix over [depot, customer 0, ..., customer n-1] (index 0 = depot)\n"
        "    all_coords = np.vstack([depot.reshape(1, -1), coords])\n"
        "    distances = np.linalg.norm(all_coords[:, None, :] - all_coords[None, :, :], axis=-1)\n"
        "    \n"
        "    # Savings s(i, j) = d(0, i) + d(0, j) - d(i, j) for every customer pair\n"
        "    savings = []\n"
        "    for i in range(n):\n"
        "        for j in range(i + 1, n):\n"
        "            s = distances[0, i + 1] + distances[0, j + 1] - distances[i + 1, j + 1]\n"
        "            savings.append((s, i, j))\n"
        "    \n"
        "    # Sort savings in descending order\n"
        "    savings.sort(reverse=True)\n"
        "    \n"
        "    # Initialize routes (each customer starts as a separate route)\n"
        "    routes = [[i] for i in range(n)]\n"
        "    route_loads = [demands[i] for i in range(n)]\n"
        "    \n"
        "    # Merge routes based on savings (largest first); non-positive savings never help\n"
        "    for s, i, j in savings:\n"
        "        if s <= 0:\n"
        "            break\n"
        "        \n"
        "        # Find routes containing i and j\n"
        "        route_i = None\n"
        "        route_j = None\n"
        "        for idx, route in enumerate(routes):\n"
        "            if i in route:\n"
        "                route_i = idx\n"
        "            if j in route:\n"
        "                route_j = idx\n"
        "        \n"
        "        # Skip if same route or if merge would exceed capacity\n"
        "        if route_i == route_j or route_i is None or route_j is None:\n"
        "            continue\n"
        "        \n"
        "        if route_loads[route_i] + route_loads[route_j] > vehicle_capacity:\n"
        "            continue\n"
        "        \n"
        "        route_i_list = routes[route_i]\n"
        "        route_j_list = routes[route_j]\n"
        "        \n"
        "        # Merge routes, orienting each so that i and j become adjacent.\n"
        "        # Both must be route endpoints; an interior i or j matches no case.\n"
        "        if route_i_list[-1] == i and route_j_list[0] == j:\n"
        "            new_route = route_i_list + route_j_list\n"
        "        elif route_i_list[0] == i and route_j_list[-1] == j:\n"
        "            new_route = route_j_list + route_i_list\n"
        "        elif route_i_list[-1] == i and route_j_list[-1] == j:\n"
        "            new_route = route_i_list + route_j_list[::-1]\n"
        "        elif route_i_list[0] == i and route_j_list[0] == j:\n"
        "            new_route = route_i_list[::-1] + route_j_list\n"
        "        else:\n"
        "            continue\n"
        "        \n"
        "        # Update routes\n"
        "        routes[route_i] = new_route\n"
        "        route_loads[route_i] += route_loads[route_j]\n"
        "        routes.pop(route_j)\n"
        "        route_loads.pop(route_j)\n"
        "    \n"
        "    # Calculate total distance\n"
        "    total_distance = 0.0\n"
        "    for route in routes:\n"
        "        if not route:\n"
        "            continue\n"
        "        total_distance += distances[0, route[0] + 1]  # Depot to first customer\n"
        "        for i in range(len(route) - 1):\n"
        "            total_distance += distances[route[i] + 1, route[i + 1] + 1]\n"
        "        total_distance += distances[route[-1] + 1, 0]  # Last customer to depot\n"
        "    \n"
        "    return {\n"
        "        'routes': routes,\n"
        "        'total_distance': float(total_distance)\n"
        "    }\n"
    )


def get_random_instance_generator_code() -> str:
    """
    Return Python code that generates CVRP instances using random customer locations and demands.

    The returned source defines
    ``generate_instances(seeds, num_customers, vehicle_capacity)``. It builds
    one instance per seed, drawing from ``np.random.default_rng(seed)`` so the
    output is deterministic for a given seed. It returns a list of dicts with:

    - ``'depot'``: ``[x, y]`` drawn uniformly from [0, 100) x [0, 100);
    - ``'customers'``: ``num_customers`` dicts. Each has ``'coords'`` (``[x, y]``,
      uniform in the same area) and an integer ``'demand'`` in
      ``[1, max(1, vehicle_capacity // 3)]``;
    - ``'vehicle_capacity'``: the value passed in, unchanged.
    """
    return (
        "import numpy as np\n"
        "def generate_instances(seeds, num_customers, vehicle_capacity):\n"
        "    instances = []\n"
        "    for seed in seeds:\n"
        "        rng = np.random.default_rng(seed)\n"
        "        \n"
        "        # Generate depot location (uniformly at random in the same 100x100 area as customers)\n"
        "        depot = rng.uniform(0, 100, size=2)\n"
        "        \n"
        "        # Generate customer locations\n"
        "        customer_coords = rng.uniform(0, 100, size=(num_customers, 2))\n"
        "        \n"
        "        # Generate integer customer demands in [1, max(1, vehicle_capacity // 3)];\n"
        "        # int() keeps the bound integral even if vehicle_capacity is a float\n"
        "        max_demand = int(max(1, vehicle_capacity // 3))\n"
        "        demands = rng.integers(1, max_demand + 1, size=num_customers)\n"
        "        \n"
        "        customers = []\n"
        "        for i in range(num_customers):\n"
        "            customers.append({\n"
        "                'coords': customer_coords[i].tolist(),\n"
        "                'demand': int(demands[i])\n"
        "            })\n"
        "        \n"
        "        instance = {\n"
        "            'depot': depot.tolist(),\n"
        "            'customers': customers,\n"
        "            'vehicle_capacity': vehicle_capacity\n"
        "        }\n"
        "        instances.append(instance)\n"
        "    return instances\n"
    )


def initialize_cvrp_strategies(controller) -> None:
    """
    Initialize CVRP-specific strategy pools.

    Registers two seed strategies with ``controller.pools``:

    - solver ``"h0"``: the nearest-neighbor ``select`` code returned by
      ``get_nearest_neighbor_solver_code()``;
    - generator ``"g0"``: the random instance generator code returned by
      ``get_random_instance_generator_code()``. It is parameterized with
      ``controller.cfg.num_customers`` and ``controller.cfg.vehicle_capacity``
      and tagged with ``controller.cfg.min_simple_ratio`` as ``"min_ratio"``.

    Progress messages are printed to stdout.

    Args:
        controller: HeuPSROController instance. Must expose ``pools``
            (providing ``add_solver`` / ``add_generator``) and ``cfg``
            (providing ``num_customers``, ``vehicle_capacity`` and
            ``min_simple_ratio``).
    """
    print(" Initializing HeuPSRO for CVRP...")

    # Create basic solver using nearest neighbor heuristic. The pool index
    # returned by add_solver is not needed here, so it is not bound.
    basic_solver = get_nearest_neighbor_solver_code()
    controller.pools.add_solver(
        "h0", basic_solver, "Nearest neighbor solver", {}, {"source": "init"}
    )

    # Create random instance generator, parameterized from the controller config
    num_customers = controller.cfg.num_customers
    vehicle_capacity = controller.cfg.vehicle_capacity
    random_generator_code = get_random_instance_generator_code()
    controller.pools.add_generator(
        "g0",
        random_generator_code,
        "Random instance generator",
        {"num_customers": num_customers, "vehicle_capacity": vehicle_capacity},
        {"min_ratio": controller.cfg.min_simple_ratio, "source": "init"},
    )

    print("   Initialized with nearest neighbor solver and random instance generator")




































