# Based on https://github.com/ai4co/routefinder/blob/main/routefinder/baselines/pyvrp.py

import argparse
import torch
import os
import logging
import csv

from scipy.spatial import distance
import numpy as np

from pyvrp import Client, Depot, ProblemData, VehicleType, solve as _solve
from pyvrp.stop import MaxRuntime
from torch import Tensor


SCALING_FACTOR = 1000

def scale(data, scaling_factor):
    """
    Scales and rounds data to integers so PyVRP can handle it.

    Parameters
    ----------
    data
        Values to scale: a NumPy array, or anything ``np.asarray`` accepts
        (Python scalar, list, tuple).
    scaling_factor
        Multiplier applied to ``data`` before rounding.

    Returns
    -------
    int or np.ndarray
        A plain Python ``int`` when the input holds exactly one element,
        otherwise an integer array with the same shape as ``data``. Entries
        equal to ``+inf`` are replaced by ``np.iinfo(np.int32).max``.
    """
    # np.asarray lets plain scalars and sequences through: without it a list
    # would be repeated by ``*`` and a Python float has no ``.round()`` method.
    array = (np.asarray(data) * scaling_factor).round()
    # Swap +inf for a large finite sentinel before the integer cast, which is
    # not meaningful for infinities.
    array = np.where(array == np.inf, np.iinfo(np.int32).max, array)
    array = array.astype(int)

    # A single element is handed back as a scalar rather than a 1-element array.
    if array.size == 1:
        return array.item()

    return array

def solve(instance, max_runtime):
    """
    Solves a single instance with PyVRP under a runtime limit.

    Parameters
    ----------
    instance
        The instance to solve, in the tuple layout expected by
        ``instance2data``.
    max_runtime
        Runtime limit handed to PyVRP's ``MaxRuntime`` stopping criterion.

    Returns
    -------
    tuple
        ``(cost, routes, runtime, num_iterations)``: the cost reported by
        PyVRP divided by ``SCALING_FACTOR``, the routes of the best solution
        as lists of visited indices wrapped in a leading and trailing ``0``,
        and the runtime and iteration count reported by PyVRP.
    """
    problem = instance2data(instance)
    outcome = _solve(problem, MaxRuntime(max_runtime))

    # Bracket every route of the best solution with 0 at both ends.
    routes = []
    for route in outcome.best.routes():
        routes.append([0, *route.visits(), 0])

    # Undo the integer scaling applied when the problem data was built.
    unscaled_cost = outcome.cost() / SCALING_FACTOR

    return unscaled_cost, routes, outcome.runtime, outcome.num_iterations


def instance2data(instance):
    """
    Converts an instance to a ProblemData instance.

    Parameters
    ----------
    instance
        A 6-tuple ``(depot_xy, node_xy, node_demand, capacity, node_tw,
        node_service)``. ``depot_xy`` and ``node_xy`` are 2-D coordinate
        arrays with one row per location (only the first row of ``depot_xy``
        is used for the depot), ``node_tw`` holds an ``(early, late)`` pair
        per client, and ``node_demand`` / ``node_service`` hold one value per
        client.

    Returns
    -------
    ProblemData
        The ProblemData instance, with one depot, one client per row of
        ``node_xy`` and a single vehicle type.
    """
    depot_xy, node_xy, node_demand, capacity, node_tw, node_service = instance

    # Depot row(s) first, then the clients in input order.
    locs = np.append(depot_xy, node_xy, axis=0)
    cost_matrix = distance.cdist(locs, locs, 'euclidean')

    num_locs = locs.shape[0]
    node_tw = scale(node_tw, SCALING_FACTOR)
    # scale() collapses a size-1 array to a bare int; restore the 1-D shape so
    # a single-client instance can still be indexed per client below.
    node_service = np.atleast_1d(scale(node_service, SCALING_FACTOR))
    matrix = scale(cost_matrix, SCALING_FACTOR)
    depot_xy = scale(depot_xy, SCALING_FACTOR)
    node_xy = scale(node_xy, SCALING_FACTOR)

    depot = Depot(
        x=depot_xy[0][0],
        y=depot_xy[0][1],
    )

    # NOTE(review): demand (and capacity below) are passed through unscaled,
    # unlike coordinates, time windows and service durations - confirm they
    # are already in the integer units PyVRP expects.
    clients = [
        Client(
            x=node_xy[idx][0],
            y=node_xy[idx][1],
            tw_early=node_tw[idx][0],
            tw_late=node_tw[idx][1],
            delivery=node_demand[idx],
            service_duration=node_service[idx],
        )
        for idx in range(num_locs - 1)
    ]

    # NOTE(review): no time window is set on the vehicle type (it was
    # previously commented out here) - confirm this is intended.
    vehicle_type = VehicleType(
        num_available=num_locs - 1,  # one vehicle per client
        capacity=capacity,
    )

    # The same scaled Euclidean matrix serves as both distance and duration.
    return ProblemData(clients, [depot], [vehicle_type], [matrix], [matrix])

if __name__ == "__main__":
    # Command-line entry point: solve each instance of a dataset with PyVRP,
    # append per-instance results and routes to CSV files, and log averages.

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", help="Filename of the dataset(s) to evaluate")
    parser.add_argument('--results_dir', default='results', help="Name of results directory")
    parser.add_argument('--timelimit', type=int, default=60, help="Timelimit per instance in seconds")
    parser.add_argument('--nb_instances', type=int, default=-1, help="Number of instances to process")

    opts = parser.parse_args()

    # Results go to <results_dir>/<dataset name without extension>_<timelimit>s
    dataset_basename = os.path.splitext(os.path.basename(opts.dataset_path))[0]

    results_dir = os.path.join(opts.results_dir, dataset_basename + '_' + str(opts.timelimit) + 's')
    os.makedirs(results_dir, exist_ok=True)

    # Define the path to the log file
    log_file = os.path.join(results_dir, 'log.txt')

    # Configure the logging
    logging.basicConfig(
        level=logging.INFO,  # Set the logging level
        format='%(asctime)s - %(message)s',  # Log format
        handlers=[
            logging.FileHandler(log_file),  # Write logs to a file
            logging.StreamHandler()  # Optional: Also output logs to the console
        ]
    )

    # NOTE(review): torch.load unpickles the file, so only point this script
    # at datasets from a trusted source.
    dataset = torch.load(opts.dataset_path)
    depot_xy = dataset['depot_xy'].numpy()
    node_xy = dataset['node_xy'].numpy()
    node_tw = dataset['node_tw'].numpy()
    node_sd = dataset['node_sd'].numpy()
    capacity = dataset['capacity'].numpy()
    node_demand = dataset['node_demand'].numpy()

    # A negative count means "all instances"; a count beyond the dataset size
    # is clamped so the run does not die with an IndexError part-way through.
    available = depot_xy.shape[0]
    nb_instances = opts.nb_instances
    if nb_instances < 0:
        nb_instances = available
    elif nb_instances > available:
        logging.warning(
            "Requested %d instances but the dataset only contains %d; processing %d.",
            nb_instances, available, available)
        nb_instances = available

    # Both CSV files are opened in append mode, so rows accumulate across runs
    # that share the same results directory.
    results_csv = os.path.join(results_dir, "results.csv")
    solutions_csv = os.path.join(results_dir, "solutions.csv")

    dataset_costs = []
    dataset_rt = []
    dataset_iter = []
    for i in range(nb_instances):
        instance = depot_xy[i], node_xy[i], node_demand[i], capacity[i].item(), node_tw[i], node_sd[i]
        cost, solution_routes, runtime, num_iterations = solve(instance, opts.timelimit)
        dataset_costs.append(cost)
        dataset_rt.append(runtime)
        dataset_iter.append(num_iterations)

        logging.info(
            "Instance {:3d}/{:3d}, score: {:.3f}, running_mean: {:.3f} iter: {}".format(
                i + 1, nb_instances, cost, np.array(dataset_costs).mean(), num_iterations))

        # Write after every instance so partial results survive an interrupted run.
        with open(results_csv, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([i+1, cost, runtime, num_iterations])

        with open(solutions_csv, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([i + 1, solution_routes])

    logging.info(" *** Test Done *** ")
    logging.info(" AVG. COSTS: {:.4f} ".format(np.array(dataset_costs).mean()))
    logging.info(" AVG. RUNTIME: {:.2f} ".format(np.array(dataset_rt).mean()))
    logging.info(" AVG. ITERATIONS: {:.0f} ".format(np.array(dataset_iter).mean()))
