"""
Adapted from: https://github.com/wouterkool/attention-learn-to-route/blob/master/problems/vrp/vrp_baseline.py

"""

import argparse
import os
import pickle
import numpy as np
from subprocess import run
import time
import torch
import csv

import logging


def use_pkl_saved_problems(filename):
    """Load CVRP instances from a pickle file and assemble batch arrays.

    The pickle is expected to hold an indexable collection of per-instance
    records, each indexed as ``(depot_xy, node_xy, node_demand, capacity)``.

    Args:
        filename: Path to the pickle file.

    Returns:
        A tuple ``(capacity, depot_node_xy, depot_node_demand)``:
        ``capacity`` is a list of Python ints (one per instance),
        ``depot_node_xy`` has shape (batch, problem+1, 2) with the depot in
        row 0, and ``depot_node_demand`` has shape (batch, problem+1) with a
        zero depot demand in column 0.

    Raises:
        AssertionError: If any capacity value is not integral.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # use this loader on datasets from a trusted source.
    with open(filename, 'rb') as pickle_file:
        data = pickle.load(pickle_file)

    dataset_size = len(data)
    records = [data[i] for i in range(dataset_size)]

    depot_data = np.array([record[0] for record in records])[:, None, :]
    # shape: (batch, 1, 2)

    node_data = np.array([record[1] for record in records])
    # shape: (batch, problem, 2)

    demand_data = np.array([record[2] for record in records])
    capacity = [record[3] for record in records]

    def _is_integral(value):
        # NumPy integer scalars (e.g. np.int64) are not subclasses of the
        # builtin int, so they have to be accepted explicitly.
        if isinstance(value, (int, np.integer)):
            return True
        return isinstance(value, (float, np.floating)) and float(value).is_integer()

    # Check that all elements in capacity are integers or can be converted to integers
    assert all(_is_integral(c) for c in capacity), "All capacity values must be integers."

    # Convert all elements to integers
    capacity = [int(c) for c in capacity]

    depot_node_xy = np.concatenate((depot_data, node_data), axis=1)
    # shape: (batch, problem+1, 2)
    depot_demand = np.zeros(shape=(dataset_size, 1), dtype="int")
    # shape: (batch, 1)
    depot_node_demand = np.concatenate((depot_demand, demand_data), axis=1)
    # shape: (batch, problem+1)

    return capacity, depot_node_xy, depot_node_demand


def use_saved_problems(filename):
    """Load CVRP instances from a ``torch.save``-d dictionary.

    The file must contain the keys ``'depot_xy'``, ``'node_xy'``,
    ``'node_demand'`` and ``'capacity'`` (tensors, converted to NumPy here)
    plus ``'grid_size'`` (returned unchanged).

    Args:
        filename: Path to the saved dictionary.

    Returns:
        A tuple ``(capacity, depot_node_xy, depot_node_demand, grid_size)``.
        The depot coordinates are prepended to the node coordinates along
        axis 1, and a zero depot demand is prepended to the node demands
        along axis 1.
    """
    checkpoint = torch.load(filename, map_location='cpu')

    # Convert every tensor field to NumPy in one pass.
    depot_xy, node_xy, node_demand, capacity = (
        checkpoint[key].numpy()
        for key in ('depot_xy', 'node_xy', 'node_demand', 'capacity')
    )
    grid_size = checkpoint['grid_size']

    # The depot carries no demand: one zero per instance.
    batch_size = node_xy.shape[0]
    zero_depot_demand = np.zeros(shape=(batch_size, 1), dtype="int")

    depot_node_xy = np.concatenate((depot_xy, node_xy), axis=1)
    depot_node_demand = np.concatenate((zero_depot_demand, node_demand), axis=1)

    return capacity, depot_node_xy, depot_node_demand, grid_size



def write_vrplib(filename, depot, loc, demand, capacity, grid_size, name="problem"):
    """Write one CVRP instance to ``filename`` in VRPLIB text format.

    Nodes are numbered from 1 in the file, with the depot as node 1 and a
    demand of 0.

    Args:
        filename: Destination path; overwritten if it exists.
        depot: NumPy array holding the depot's (x, y) coordinates.
        loc: Array of customer (x, y) coordinates, one row per customer.
        demand: Customer demands, in the same order as ``loc``.
        capacity: Vehicle capacity; written as ``int(capacity)``.
        grid_size: Coordinate scale of the instance. With ``1`` the
            coordinates are multiplied by 10000 and rounded to integers;
            with ``1000`` they are written unchanged.
        name: Value of the ``NAME`` header field.

    Raises:
        ValueError: If ``grid_size`` is neither 1 nor 1000. Previously this
            case silently produced a file with an empty coordinate section.
    """
    all_xy = np.append(depot[np.newaxis, :], loc, axis=0)

    if grid_size == 1:
        # int(v + 0.5) rounds non-negative values to the nearest integer.
        coord_lines = [
            "{}\t{}\t{}".format(i + 1, int(x / grid_size * 10000 + 0.5), int(y / grid_size * 10000 + 0.5))
            for i, (x, y) in enumerate(all_xy)
        ]
    elif grid_size == 1000:
        # Coordinates are written as-is, without rescaling.
        coord_lines = [
            "{}\t{}\t{}".format(i + 1, x, y)
            for i, (x, y) in enumerate(all_xy)
        ]
    else:
        # Fail before touching the file rather than emit a malformed instance.
        raise ValueError(
            "Unsupported grid_size {!r}; expected 1 or 1000.".format(grid_size)
        )

    # Keep the COMMENT field: the original author noted that HGS throws an
    # error when it is removed (cause not investigated).
    header_fields = (
        ("NAME", name),
        ("COMMENT", "Generated by pickle file."),
        ("TYPE", "CVRP"),
        ("DIMENSION", len(loc) + 1),
        ("EDGE_WEIGHT_TYPE", "EUC_2D"),
        ("CAPACITY", int(capacity)),
    )
    header_lines = ["{} : \t{}\t".format(k, v) for k, v in header_fields]

    # The depot (node 1) gets a demand of 0.
    demand_lines = [
        "{}\t{}".format(i + 1, d)
        for i, d in enumerate(np.append([0], demand))
    ]

    with open(filename, 'w') as f:
        f.write("\n".join(header_lines))
        f.write("\n")
        f.write("NODE_COORD_SECTION\t\t\n")
        f.write("\n".join(coord_lines))
        f.write("\n")
        f.write("DEMAND_SECTION\t\t\n")
        f.write("\n".join(demand_lines))
        f.write("\n")
        f.write("DEPOT_SECTION\n")
        f.write("1\n")
        f.write("-1\n")
        f.write("EOF\n")
        
        
def read_vrplib(filename):
    """Parse the routes out of a solver solution file.

    Only lines starting with ``"Route"`` are considered. The text after the
    first ``":"`` on such a line is split on whitespace and parsed as
    integer node indices; all other lines are ignored.

    Args:
        filename: Path to the solution file.

    Returns:
        A list of tours, one per route line, each a list of ints with a
        trailing ``0`` appended (the return to the depot).
    """
    with open(filename, 'r') as f:
        return [
            [int(token) for token in line.split(":")[1].split()] + [0]
            for line in f
            if line.startswith("Route")
        ]


def calc_vrp_cost(depot, loc, tour):
    """Return the total Euclidean length of a flattened CVRP tour.

    Args:
        depot: (x, y) coordinates of the depot (node 0).
        loc: Customer coordinates, one (x, y) row per customer; customer
            ``k`` (1-based) is row ``k - 1``.
        tour: Flat sequence of node indices in which ``0`` marks a visit to
            the depot. A leading and a trailing depot visit are added here,
            so they need not be present in ``tour``.

    Returns:
        Sum of the Euclidean distances between consecutive nodes of the
        depot-to-depot path.

    Raises:
        AssertionError: If the non-depot entries of ``tour`` are not exactly
            the customers ``1..len(loc)``, each appearing once.
    """
    tour = np.asarray(tour, dtype=int)
    num_customers = len(loc)

    # Compare every non-depot entry against 1..n. Checking only the largest n
    # entries (as before) let duplicate or negative indices slip through.
    visited = np.sort(tour[tour != 0])
    assert np.array_equal(visited, np.arange(1, num_customers + 1)), "All nodes must be visited once!"
    # TODO validate capacity constraints
    loc_with_depot = np.vstack((np.array(depot)[None, :], np.array(loc)))
    # Start and end at the depot regardless of how the tour is delimited.
    path = np.concatenate(([0], tour, [0]))
    sorted_locs = loc_with_depot[path]
    return np.linalg.norm(sorted_locs[1:] - sorted_locs[:-1], axis=-1).sum()




def solve_hgs_log(directory, id, depot, loc, demand, capacity, grid_size, timelimit):
    """Solve one CVRP instance with the ``./hgs`` executable and score it.

    Writes ``<id>.vrp`` into ``directory``, runs ``./hgs`` (resolved relative
    to the current working directory) with stdout and stderr redirected to
    ``<id>.log``, parses the routes from ``<id>.sol``, recomputes the tour
    cost from the original coordinates, and appends that cost to the
    solution file.

    Args:
        directory: Existing directory for the problem, solution and log files.
        id: Instance identifier used as the file stem. (The name shadows the
            builtin but is kept so keyword callers keep working.)
        depot: Depot (x, y) coordinates as a NumPy array.
        loc: Customer coordinates, one (x, y) row per customer.
        demand: Customer demands, in the same order as ``loc``.
        capacity: Vehicle capacity.
        grid_size: Coordinate scale, forwarded to ``write_vrplib``.
        timelimit: Value passed to the solver's ``-t`` option, or ``None`` to
            run the solver without ``-it``/``-t``.

    Returns:
        A tuple ``(tours, calculated_cost, duration, 0)``: the parsed routes
        (each ending in ``0``), the recomputed cost, the wall-clock seconds
        spent in the solver process, and a constant ``0`` placeholder.

    Raises:
        FileNotFoundError: If the solver did not produce a solution file.
    """
    problem_filename = os.path.join(directory, "{}.vrp".format(id))
    output_filename = os.path.join(directory, "{}.sol".format(id))
    log_filename = os.path.join(directory, "{}.log".format(id))

    write_vrplib(problem_filename, depot, loc, demand, capacity, grid_size)

    # Drop any solution left behind by an earlier run in the same directory;
    # otherwise a failed solver call would silently reuse the stale result.
    try:
        os.remove(output_filename)
    except FileNotFoundError:
        pass

    command = ["./hgs", problem_filename, output_filename, "-seed", "0", "-round", "0"]
    if timelimit is not None:
        # Large iteration budget so that the time limit is presumably the
        # binding stopping criterion -- confirm against the HGS CLI docs.
        command += ["-it", str(int(1e7)), "-t", str(timelimit)]

    with open(log_filename, 'w') as f:
        start = time.time()
        run(command, stdout=f, stderr=f)
        duration = time.time() - start

    tours = read_vrplib(output_filename)

    # Flatten the routes into a single depot-delimited tour for scoring.
    calculated_cost = calc_vrp_cost(depot, loc, [c for tour in tours for c in tour])
    with open(output_filename, "a") as out_file:
        out_file.write("Calculated Cost: {cost}".format(cost=calculated_cost))

    return tours, calculated_cost, duration, 0



    
if __name__ == "__main__":
    # Command-line driver: load a dataset, solve each instance with HGS,
    # and record per-instance results plus summary statistics.

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", help="Filename of the dataset(s) to evaluate")
    parser.add_argument('--results_dir', default='results', help="Name of results directory")
    parser.add_argument('--timelimit', type=int, default=60, help="Timelimit per instance in seconds")
    parser.add_argument('--nb_instances', type=int, default=-1, help="Number of instances to process")

    opts = parser.parse_args()

    dataset_basename = os.path.splitext(os.path.basename(opts.dataset_path))[0]

    if opts.dataset_path.endswith(".pkl"):
        capacity, depot_node_xy, depot_node_demand = use_pkl_saved_problems(opts.dataset_path)
        # Pickle datasets carry no grid size; 1 is used for them.
        grid_size = 1
    elif opts.dataset_path.endswith('.pt'):
        capacity, depot_node_xy, depot_node_demand, grid_size = use_saved_problems(opts.dataset_path)
        assert grid_size == 1
    else:
        raise NotImplementedError(
            "Unsupported dataset format (expected .pkl or .pt): {}".format(opts.dataset_path)
        )

    results_dir = os.path.join(opts.results_dir, dataset_basename + '_' + str(opts.timelimit) + 's')
    temp_dir = os.path.join(results_dir, "tmp")

    # A negative count means "all instances"; a count beyond the dataset size
    # is clamped so the run finishes with a summary instead of an IndexError.
    dataset_size = depot_node_xy.shape[0]
    nb_instances = opts.nb_instances
    if nb_instances < 0 or nb_instances > dataset_size:
        nb_instances = dataset_size

    # Logging
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(temp_dir, exist_ok=True)
    # Define the path to the log file
    log_file = os.path.join(results_dir, 'log.txt')

    # Configure the logging
    logging.basicConfig(
        level=logging.INFO,  # Set the logging level
        format='%(asctime)s - %(message)s',  # Log format
        handlers=[
            logging.FileHandler(log_file),  # Write logs to a file
            logging.StreamHandler()  # Optional: Also output logs to the console
        ]
    )

    # Both CSV files are opened in append mode, so rows accumulate across
    # repeated runs that share a results directory.
    results_csv = os.path.join(results_dir, "results.csv")
    solutions_csv = os.path.join(results_dir, "solutions.csv")

    dataset_costs = []
    dataset_rt = []
    dataset_iter = []
    for i in range(nb_instances):
        solution, cost, runtime, nb_iter = solve_hgs_log(
            temp_dir,
            i,
            depot_node_xy[i, 0],
            depot_node_xy[i, 1:],
            depot_node_demand[i, 1:],
            capacity[i],
            grid_size,
            opts.timelimit,
        )
        dataset_costs.append(cost)
        dataset_rt.append(runtime)
        dataset_iter.append(nb_iter)

        # Parsed routes end at the depot; prepend it so each stored route
        # also starts there.
        solution_routes = [[0, *r] for r in solution]

        logging.info(
            "Instance {:3d}/{:3d}, score: {:.3f}, running_mean: {:.3f} iter: {}".format(
                i + 1, nb_instances, cost, np.array(dataset_costs).mean(), nb_iter,
            ))

        with open(results_csv, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([i + 1, cost, runtime, nb_iter])

        with open(solutions_csv, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([i + 1, solution_routes])

    logging.info(" *** Test Done *** ")
    logging.info(" AVG. COSTS: {:.4f} ".format(np.array(dataset_costs).mean()))
    logging.info(" AVG. RUNTIME: {:.2f} ".format(np.array(dataset_rt).mean()))
    logging.info(" AVG. ITERATIONS: {:.0f} ".format(np.array(dataset_iter).mean()))

        