import numpy as np
import sys
from ortools.constraint_solver import routing_enums_pb2, pywrapcp


def create_data_model(distances, demands, capacity, problem_size, factor=10000):
    """Build the dictionary describing one routing instance.

    Args:
        distances: Array-like supporting ``* factor`` and ``.astype`` (e.g. a
            NumPy array) holding the pairwise distances between nodes.
        demands: Array-like supporting ``.astype`` holding one demand per node.
        capacity: Capacity assigned to every vehicle.
        problem_size: Size of the problem; used only to derive the fleet size.
        factor: Multiplier applied to the distances before they are converted
            to integers.

    Returns:
        A dict with the keys ``"distance_matrix"`` (nested lists of ints),
        ``"demands"`` (list of ints), ``"num_vehicles"``,
        ``"vehicle_capacities"`` (one entry per vehicle) and ``"depot"``
        (always node 0).
    """
    # NOTE: just make sure there are enough vehicles wrt the number of
    # customers. Clamp to one so that very small problems (problem_size <= 2,
    # where the rounding yields 0) still get a vehicle.
    num_vehicles = max(1, round(problem_size / 5.))

    return {
        # astype(int) truncates, so scaled distances lose their fractional part.
        "distance_matrix": (distances * factor).astype(int).tolist(),
        "demands": demands.astype(int).tolist(),
        "num_vehicles": num_vehicles,
        "vehicle_capacities": [capacity] * num_vehicles,
        "depot": 0,
    }


def print_solution(data, manager, routing, solution):
    """Print the objective, every used vehicle's route, and the totals.

    For each vehicle that ``routing`` reports as used in ``solution``, the
    route is walked from its start index to its end index. The demand of each
    visited node (taken from ``data["demands"]``) is accumulated into the
    route load, and the arc costs reported by ``routing`` are accumulated into
    the route distance. Unused vehicles are skipped entirely.
    """
    print(f"Objective: {solution.ObjectiveValue()}")

    distance_sum = 0
    load_sum = 0
    for vehicle in range(data["num_vehicles"]):
        if not routing.IsVehicleUsed(solution, vehicle):
            continue

        segments = [f"Route for vehicle {vehicle}:\n"]
        dist = 0
        load = 0
        current = routing.Start(vehicle)
        while not routing.IsEnd(current):
            node = manager.IndexToNode(current)
            load += data["demands"][node]
            segments.append(f" {node} Load({load}) -> ")
            following = solution.Value(routing.NextVar(current))
            dist += routing.GetArcCostForVehicle(current, following, vehicle)
            current = following

        # The end node closes the route; its demand is not added to the load.
        segments.append(f" {manager.IndexToNode(current)} Load({load})\n")
        segments.append(f"Distance of the route: {dist}m\n")
        segments.append(f"Load of the route: {load}\n")
        print("".join(segments))

        distance_sum += dist
        load_sum += load

    print(f"Total distance of all routes: {distance_sum}m")
    print(f"Total load of all routes: {load_sum}")


def _solve_instance(distances, demands, capacity, problem_size, factor,
                    time_limit_seconds=1):
    """Solve a single capacitated routing instance with OR-Tools.

    Args:
        distances: Pairwise distance matrix of the instance.
        demands: Per-node demands of the instance.
        capacity: Capacity of every vehicle.
        problem_size: Problem size forwarded to ``create_data_model``.
        factor: Scaling applied to the distances before integer conversion.
        time_limit_seconds: Search time limit handed to the solver.

    Returns:
        The solver's objective value (sum of the scaled integer arc costs),
        or ``None`` when the solver returned no solution.
    """
    model = create_data_model(distances, demands, capacity, problem_size, factor)

    # Create the routing index manager and the routing model.
    manager = pywrapcp.RoutingIndexManager(
        len(model["distance_matrix"]), model["num_vehicles"], model["depot"]
    )
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        """Returns the distance between the two nodes."""
        # Convert from routing variable Index to distance matrix NodeIndex.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return model["distance_matrix"][from_node][to_node]

    # The arc cost of every vehicle is the (scaled) travel distance.
    transit_callback_index = routing.RegisterTransitCallback(distance_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

    def demand_callback(from_index):
        """Returns the demand of the node."""
        # Convert from routing variable Index to demands NodeIndex.
        from_node = manager.IndexToNode(from_index)
        return model["demands"][from_node]

    # Add the capacity constraint.
    demand_callback_index = routing.RegisterUnaryTransitCallback(demand_callback)
    routing.AddDimensionWithVehicleCapacity(
        demand_callback_index,
        0,  # null capacity slack
        model["vehicle_capacities"],  # vehicle maximum capacities
        True,  # start cumul to zero
        "Capacity",
    )

    # First-solution heuristic followed by guided local search.
    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = (
        routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
    )
    search_parameters.local_search_metaheuristic = (
        routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
    )
    search_parameters.time_limit.FromSeconds(time_limit_seconds)

    solution = routing.SolveWithParameters(search_parameters)
    if not solution:
        return None
    return solution.ObjectiveValue()


if __name__ == '__main__':

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <problem_size>")

    factor = 10000
    capacity = 50

    dirname = 'dataset_mcts'
    problem_size = int(sys.argv[1])

    raw_data = np.load(f'{dirname}/test{problem_size}_dataset.npy')
    n_instance = raw_data.shape[0]

    total_value = 0
    n_solved = 0
    for i in range(n_instance):
        # Per node, the last two features are read as coordinates and
        # feature 0 as the demand. The distance matrix is built per instance
        # so only one matrix is alive at a time.
        coordinates = raw_data[i, :, -2:]
        distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
        demands = raw_data[i, :, 0]

        opt = _solve_instance(distances, demands, capacity, problem_size, factor)
        if opt is None:
            print(f'No solution for instance #{i+1}/{n_instance}! Skip to the next instance...')
            continue

        total_value += opt
        n_solved += 1
        print(f"{i+1}/{n_instance}: OPT = {opt / factor}", flush=True)

    # Average over the instances that were actually solved: skipped instances
    # contribute nothing to total_value, so dividing by n_instance would bias
    # the mean downwards (and would divide by zero for an empty dataset).
    if n_solved:
        print(f"problem size = {problem_size}: OPT = {total_value / n_solved / factor}")
    else:
        print(f"problem size = {problem_size}: no instance could be solved")
