import numpy as np
import torch
import pickle
import os
import csv
import vrplib


def cal_cost(routes, nodes, depot=0):
    """Return the total Euclidean length of a set of depot-based routes.

    Every route is closed at the depot, i.e. its length is
    depot -> route[0] -> ... -> route[-1] -> depot.

    Args:
        routes: iterable of routes; each route is a sequence of node indices
            into ``nodes[0]`` that does not list the depot itself.
        nodes: indexable whose first element holds the node coordinates;
            only ``nodes[0]`` is read (e.g. a (1, n_nodes, 2) tensor/array).
        depot: index of the depot inside ``nodes[0]`` (default 0).

    Returns:
        The summed length (``0.`` when no route visits any node).
    """
    coords = nodes[0]
    cost = 0.
    for route in routes:
        # An empty route travels nowhere; skip it instead of indexing into it.
        if len(route) == 0:
            continue
        stops = [depot, *route, depot]
        for prev, cur in zip(stops, stops[1:]):
            cost += np.linalg.norm(coords[cur] - coords[prev])
    return cost

if __name__ == "__main__":
    path = "total/"
    file_list = os.listdir(path)
    dataset = []
    opts = []
    for file in file_list:
        if '.sol' in file:
            continue
        file = file[:-4]
        instance_file = path + '/' + file + '.vrp'
        solution_file = path + '/' + file + '.sol'
        
        solution = vrplib.read_solution(solution_file)
        # optimal = solution['cost']
        routes = solution['routes']

        instance = vrplib.read_instance(instance_file)
        node_coord = torch.tensor(instance['node_coord']).unsqueeze(0)
        min_x = torch.min(node_coord[:, :, 0], 1)[0]
        min_y = torch.min(node_coord[:, :, 1], 1)[0]
        max_x = torch.max(node_coord[:, :, 0], 1)[0]
        max_y = torch.max(node_coord[:, :, 1], 1)[0]
        scaled_depot_node_x = (node_coord[:, :, 0] - min_x) / (max_x - min_x)
        scaled_depot_node_y = (node_coord[:, :, 1] - min_y) / (max_y - min_y)
        # scaled_depot_node_x = node_coord[:, :, 0]
        # scaled_depot_node_y = node_coord[:, :, 1]
        depot_node_xy = torch.cat((scaled_depot_node_x[:, :, None]
                                        , scaled_depot_node_y[:, :, None]), dim=2)
        optimal = cal_cost(routes, depot_node_xy)
        # print(optimal, solution['cost'])
        depot = depot_node_xy[:, instance['depot'], :]
        node_xy = depot_node_xy[:, instance['depot'][0]:, :]
        demand = torch.tensor(instance['demand']).unsqueeze(0)
        demand = demand / instance['capacity']
        data = {
            'loc': node_xy,
            # Uniform 1 - 9, scaled by capacities
            'demand': demand,
            'depot': depot
        }
        dataset.append(data)
        opts.append(optimal)

    with open("dataset.pkl", 'wb') as f:
        pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)

    save_dir_name = "results/result_opt.txt"
    with open(save_dir_name, 'w') as f:
        csv_writer = csv.writer(f)
        for i, opt in enumerate(opts):
            csv_writer.writerow([i, opt])