# Based on https://github.com/wouterkool/attention-learn-to-route/blob/master/problems/vrp/vrp_baseline.py

import argparse
import torch
import pickle
import os
import logging
import csv
import time
import numpy as np
from subprocess import check_call

SCALING_FACTOR = 100000


def calc_vrp_cost(depot, loc, tour):
    """Return the total Euclidean length of a CVRP solution.

    Args:
        depot: (x, y) coordinates of the depot (node 0).
        loc: Customer coordinates; customer ``k`` (1-based) is ``loc[k - 1]``.
        tour: Flat sequence of node indices, 0 for the depot and
            1..len(loc) for customers. A leading and trailing depot visit is
            added implicitly, so the path is depot -> tour... -> depot.

    Returns:
        Sum of straight-line distances between consecutive nodes of the path.

    Raises:
        AssertionError: if the largest ``len(loc)`` entries of ``tour`` are
            not exactly 1..len(loc) (i.e. some customer is missing/duplicated).
    """
    num_customers = len(loc)
    assert (np.sort(tour)[-num_customers:] == np.arange(1, num_customers + 1)).all(), \
        "All nodes must be visited once!"

    # Row 0 is the depot, row k is customer k.
    coords = np.vstack([np.asarray(depot)[np.newaxis, :], np.asarray(loc)])
    path = coords[np.concatenate(([0], tour, [0]))]
    legs = np.diff(path, axis=0)
    return np.linalg.norm(legs, axis=-1).sum()


def write_lkh_par(filename, parameters):
    """Write an LKH parameter file.

    A fixed set of defaults is written first (in this order: SPECIAL,
    MAX_TRIALS, RUNS, TRACE_LEVEL, SEED). Entries from ``parameters`` override
    defaults with the same key in place; new keys are appended afterwards.
    A value of ``None`` writes the key alone as a flag line, otherwise the
    line is ``KEY = value``.
    """
    merged = dict(SPECIAL=None, MAX_TRIALS=10000, RUNS=10, TRACE_LEVEL=1, SEED=0)
    merged.update(parameters)

    lines = [
        "{}".format(key) if value is None else "{} = {}".format(key, value)
        for key, value in merged.items()
    ]
    with open(filename, 'w') as handle:
        handle.writelines(line + "\n" for line in lines)


def read_vrplib(filename, n):
    """Parse an LKH tour file and return the customer visit order.

    The header is scanned for ``DIMENSION`` until ``TOUR_SECTION``; the
    integers after it are read until the ``-1`` terminator. Nodes are then
    shifted to 0-based indices, and any index greater than ``n`` (the number
    of customers) is mapped to 0, i.e. treated as a depot copy.

    Args:
        filename: Path of the tour file written by LKH.
        n: Number of customers in the instance.

    Returns:
        list[int]: The tour without its leading depot, using 0 for depot
        visits and 1..n for customers.
    """
    dimension = 0
    visits = []
    with open(filename, 'r') as handle:
        lines = iter(handle)
        # Header phase: stop right after the TOUR_SECTION marker.
        for raw in lines:
            if raw.startswith("DIMENSION"):
                dimension = int(raw.split(" ")[-1])
            if raw.startswith("TOUR_SECTION"):
                break
        # Tour phase: one node id per line, terminated by -1.
        for raw in lines:
            node = int(raw)
            if node == -1:
                break
            visits.append(node)

    assert len(visits) == dimension
    # LKH numbers the depot 1; convert to 0-based indices.
    nodes = np.asarray(visits).astype(int) - 1
    # Ids beyond the customer range are extra depot copies.
    nodes = np.where(nodes > n, 0, nodes)
    assert nodes[0] == 0  # Tour should start with depot
    assert nodes[-1] != 0  # Tour should not end with depot
    return nodes[1:].tolist()


def write_vrplib(filename, depot, loc, demand, capacity, grid_size, name="problem"):
    """Write a CVRP instance to ``filename`` in VRPLIB (TSPLIB-style) format.

    Node 1 is the depot; nodes 2..len(loc)+1 are the customers in ``loc``
    order. The depot gets demand 0.

    Coordinates are divided by ``grid_size``, multiplied by ``SCALING_FACTOR``
    and converted with ``int(v + 0.5)``, because VRPlib does not take floats.

    ``loc`` and ``demand`` are concatenated onto Python lists with ``+``, so
    they are expected to be lists (not arrays).
    """
    def scaled(value):
        # VRPlib does not take floats
        return int(value / grid_size * SCALING_FACTOR + 0.5)

    with open(filename, 'w') as f:
        header = (
            ("NAME", name),
            ("TYPE", "CVRP"),
            ("DIMENSION", len(loc) + 1),
            ("EDGE_WEIGHT_TYPE", "EUC_2D"),
            ("CAPACITY", capacity),
        )
        lines = ["{} : {}".format(key, value) for key, value in header]

        lines.append("NODE_COORD_SECTION")
        for node_id, (x, y) in enumerate([depot] + loc, start=1):
            lines.append("{}\t{}\t{}".format(node_id, scaled(x), scaled(y)))

        lines.append("DEMAND_SECTION")
        for node_id, d in enumerate([0] + demand, start=1):
            lines.append("{}\t{}".format(node_id, d))

        lines.extend(["DEPOT_SECTION", "1", "-1", "EOF"])
        f.write("\n".join(lines) + "\n")


def reformat_tours(tours_list):
    """
    Reformats the LKH output to our standard format.

    Args:
        tours_list (list[int]): Flat node sequence as returned by read_vrplib,
                                e.g. [c_1, c_2, 0, c_3, ..., c_n], where c_i are
                                customer indices and 0s separate routes.

    Returns:
        list[list[int]]: One list per route, each wrapped in depot visits,
        e.g. [[0, c_1, c_2, 0], [0, c_3, ..., c_n, 0]]. Empty routes (leading,
        trailing or consecutive 0s) are dropped, and an input with no
        customers yields [] instead of raising.
    """
    routes = []
    current = []

    for node in tours_list:
        if node == 0:
            # Depot visit closes the current route, if it has any customers.
            if current:
                routes.append([0] + current + [0])
                current = []
        else:
            current.append(node)

    # The last route is not followed by a depot in the flat list.
    if current:
        routes.append([0] + current + [0])

    return routes


def solve(instance, directory, name, runs=1, grid_size=1, executable=None):
    """Solve one CVRP instance with LKH and return its cost, routes and runtime.

    A VRPLIB problem file and an LKH parameter file are written into
    ``directory``, LKH is run with its stdout/stderr redirected to a ``.log``
    file there, and the resulting tour file is parsed. Temp file names depend
    only on ``name`` and ``runs``, so calls with the same ``name`` overwrite
    each other's files.

    Args:
        instance: ``(depot, loc, demand, capacity)`` tuple.
        directory: Directory for the temporary .vrp/.par/.tour/.log files.
        name: Base name for the temp files and the VRPLIB ``NAME`` field.
        runs: Value passed as LKH's ``RUNS`` parameter.
        grid_size: Coordinate divisor forwarded to write_vrplib.
        executable: LKH executable path, joined onto this script's directory
            (an absolute path is used as-is). Defaults to ``opts.executable``,
            which only exists when this file is run as a script.

    Returns:
        ``(cost, routes, duration)``: cost computed by calc_vrp_cost on the
        unscaled coordinates, routes from reformat_tours, and the wall-clock
        seconds spent in the LKH call. ``None`` if any step failed; the error
        is logged with its traceback.
    """
    if executable is None:
        # Fall back to the command-line option parsed in the __main__ block.
        executable = opts.executable
    executable = os.path.join(os.path.dirname(__file__), executable)

    depot, loc, demand, capacity = instance

    stem = os.path.join(directory, "{}.lkh{}".format(name, runs))
    problem_filename = stem + ".vrp"
    tour_filename = stem + ".tour"
    param_filename = stem + ".par"
    log_filename = stem + ".log"

    try:
        write_vrplib(problem_filename, depot, loc, demand,
                     capacity, grid_size, name=name)

        # SEED overrides write_lkh_par's default of 0.
        params = {"PROBLEM_FILE": problem_filename,
                  "OUTPUT_TOUR_FILE": tour_filename, "RUNS": runs, "SEED": 1234}
        write_lkh_par(param_filename, params)

        with open(log_filename, 'w') as f:
            start = time.time()
            check_call([executable, param_filename], stdout=f, stderr=f)
            duration = time.time() - start

        tour_nodes = read_vrplib(tour_filename, n=len(demand))

        return calc_vrp_cost(depot, loc, tour_nodes), reformat_tours(tour_nodes), duration

    except Exception as e:
        # Per-instance boundary: keep the batch going, but keep the traceback.
        logging.exception("Error in instance {}: {}".format(name, e))
        return None


def use_pkl_saved_problems(dataset_path, nb_instances):
    """Load CVRP instances from a pickle file.

    Each dataset entry is indexed positionally as
    ``(depot, customer locations, demands, capacity)``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load dataset
    files from a trusted source.

    Args:
        dataset_path: Path of the ``.pkl`` file.
        nb_instances: How many instances the caller wants; a negative value
            means all of them. Values larger than the dataset are clamped.

    Returns:
        ``(depot_xy, node_xy, demand, capacity, nb_customers, nb_instances)``.
        The four lists cover the whole dataset; ``nb_customers`` is taken from
        the first instance.

    Raises:
        ValueError: if the dataset contains no instances.
    """
    with open(dataset_path, 'rb') as f:
        dataset = pickle.load(f)

    dataset_size = len(dataset)
    if dataset_size == 0:
        raise ValueError("Dataset {} contains no instances".format(dataset_path))
    nb_customers = len(dataset[0][1])

    if nb_instances < 0:
        nb_instances = dataset_size
    else:
        # Never let the caller index past the end of the dataset.
        nb_instances = min(nb_instances, dataset_size)

    depot_xy = [entry[0] for entry in dataset]
    node_xy = [entry[1] for entry in dataset]
    demand = [entry[2] for entry in dataset]
    capacity = [entry[3] for entry in dataset]

    return depot_xy, node_xy, demand, capacity, nb_customers, nb_instances


def use_saved_problems(dataset_path, nb_instances):
    """Load CVRP instances from a ``torch.save``-d dict of tensors.

    Expected keys: ``depot_xy``, ``node_xy``, ``node_demand``, ``capacity``;
    the first axis of each is the instance (batch) axis.

    Args:
        dataset_path: Path of the ``.pt`` file.
        nb_instances: How many instances the caller wants; a negative value
            means all of them. Values larger than the dataset are clamped.

    Returns:
        ``(depot_xy, node_xy, demand, capacity, nb_customers, nb_instances)``
        where every list has one entry per instance.
    """
    dataset = torch.load(dataset_path, weights_only=True, map_location='cpu')
    dataset_size = dataset['depot_xy'].shape[0]
    nb_customers = len(dataset['node_xy'][0])

    if nb_instances < 0:
        nb_instances = dataset_size
    else:
        # Never let the caller index past the end of the dataset.
        nb_instances = min(nb_instances, dataset_size)

    # reshape instead of squeeze: squeeze() would also drop the batch axis
    # when the dataset holds a single instance, breaking depot_xy[i]/capacity[i].
    depot_arr = np.array(dataset['depot_xy'])
    depot_xy = depot_arr.reshape(depot_arr.shape[0], -1).tolist()
    node_xy = dataset['node_xy'].tolist()
    demand = dataset['node_demand'].tolist()
    capacity = np.array(dataset['capacity']).astype(float).reshape(-1).tolist()
    if len(capacity) == 1 and dataset_size > 1:
        # A single stored capacity is shared by every instance.
        capacity = capacity * dataset_size

    return depot_xy, node_xy, demand, capacity, nb_customers, nb_instances


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "dataset_path", help="Filename of the dataset(s) to evaluate")
    parser.add_argument('--results_dir', default='results',
                        help="Name of results directory")
    parser.add_argument('--executable', default='LKH',
                        help="Name of LKH executable")
    parser.add_argument('--nb_instances', type=int, default=-1,
                        help="Number of instances to process")
    parser.add_argument('--runs', type=int, default=1,
                        help="Number of times to run LKH")

    # Must stay a module-level name: solve() falls back to opts.executable.
    opts = parser.parse_args()

    dataset_basename, _ = os.path.splitext(os.path.basename(opts.dataset_path))

    results_dir = os.path.join(opts.results_dir, dataset_basename)
    os.makedirs(results_dir, exist_ok=True)

    working_dir = os.path.join(results_dir, 'temp')
    os.makedirs(working_dir, exist_ok=True)

    # Check the same path solve() will execute (resolved relative to this
    # script's directory), and raise rather than assert so the check survives -O.
    executable_path = os.path.join(os.path.dirname(__file__), opts.executable)
    if not os.path.isfile(executable_path):
        raise FileNotFoundError(
            "LKH executable not found at {}. Please download it from "
            "http://webhotel4.ruc.dk/~keld/research/LKH-3/ and place it next to "
            "this script.".format(executable_path))

    # Define the path to the log file
    log_file = os.path.join(results_dir, 'log.txt')

    # Log to both the results directory and the console.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

    if opts.dataset_path.endswith(".pkl"):
        depot_xy, node_xy, demand, capacity, nb_customers, nb_instances = use_pkl_saved_problems(
            opts.dataset_path, opts.nb_instances)
    elif opts.dataset_path.endswith('.pt'):
        depot_xy, node_xy, demand, capacity, nb_customers, nb_instances = use_saved_problems(
            opts.dataset_path, opts.nb_instances)
    else:
        raise NotImplementedError(
            "Unsupported dataset format (expected .pkl or .pt): {}".format(opts.dataset_path))

    dataset_costs = []
    dataset_rt = []
    nb_failed = 0
    for i in range(nb_instances):
        instance = depot_xy[i], node_xy[i], demand[i], capacity[i]

        result = solve(instance, working_dir, dataset_basename, runs=opts.runs)
        if result is None:
            # solve() already logged the cause; skip instead of crashing on
            # unpacking None so one bad instance doesn't abort the whole run.
            nb_failed += 1
            logging.warning("Instance {:3d}/{:3d} failed, skipping".format(i + 1, nb_instances))
            continue
        cost, solution_routes, runtime = result
        dataset_costs.append(cost)
        dataset_rt.append(runtime)

        logging.info(
            "Instance {:3d}/{:3d}, score: {:.3f}, running_mean: {:.3f}".format(
                i + 1, nb_instances, cost, np.array(dataset_costs).mean()))

        # Append per instance so partial progress is kept on disk.
        with open(os.path.join(results_dir, "results.csv"), mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([i + 1, cost, runtime])

        with open(os.path.join(results_dir, "solutions.csv"), mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([i + 1, solution_routes])

    logging.info(" *** Test Done *** ")
    if nb_failed:
        logging.warning(" {} instance(s) failed ".format(nb_failed))
    if dataset_costs:
        logging.info(" AVG. COSTS: {:.4f} ".format(np.array(dataset_costs).mean()))
        logging.info(" AVG. RUNTIME: {:.2f} ".format(np.array(dataset_rt).mean()))
    else:
        logging.warning(" No instance was solved successfully ")
