import torch
import numpy as np
from pathlib import Path

#######################
# TSP utility functions
#######################


def get_dist_matrix(instance):
    """Return the pairwise Euclidean distance matrix of a set of points.

    Args:
        instance: Tensor of point coordinates with shape ``(..., n, d)``.
            Any leading dimensions are treated as batch dimensions, so a
            single ``(n, d)`` instance and a batch ``(b, n, d)`` both work.

    Returns:
        Tensor of shape ``(..., n, n)`` whose ``[..., i, j]`` entry is the
        L2 distance between point ``i`` and point ``j``.
    """
    # Broadcasting (..., n, 1, d) against (..., 1, n, d) yields every pairwise
    # difference directly, without materialising two repeated (n, n, d)
    # copies of the input as an explicit ``repeat`` would.
    diffs = instance.unsqueeze(-2) - instance.unsqueeze(-3)
    return torch.norm(diffs, p=2, dim=-1)


def calculate_tour_length(dist_matrix, tours):
    """Return the closed-tour length of each tour under ``dist_matrix``.

    Args:
        dist_matrix: Tensor indexed as ``dist_matrix[a, b]`` for the cost of
            travelling from node ``a`` to node ``b`` (presumably square
            ``(n, n)``).
        tours: Integer tensor of node indices, either a single tour of shape
            ``(k,)`` or a batch of tours of shape ``(batch, k)``.

    Returns:
        Tensor of shape ``(batch,)`` (``(1,)`` for a single tour) holding the
        summed edge cost of each tour, including the closing edge from the
        last node back to the first.
    """
    batched_tours = tours if tours.dim() != 1 else tours.unsqueeze(0)
    # Pair each node with its successor; rolling left wraps the last node
    # around to the first, which closes the cycle.
    next_nodes = batched_tours.roll(-1, dims=1)
    edge_costs = dist_matrix[batched_tours, next_nodes]
    return edge_costs.sum(dim=1)


def _parse_tsp_line(line, path, line_no):
    """Split one non-blank data line into (coordinates, tour, tour length)."""
    fields = line.split(" | ")
    if len(fields) != 3:
        raise ValueError(
            f"{path}:{line_no}: expected 3 ' | '-separated fields, got {len(fields)}"
        )
    tsp_instance_string, opt_tour_string, opt_len_string = fields

    coords = []
    # split() without an argument tolerates repeated whitespace between nodes.
    for node_string in tsp_instance_string.split():
        node = node_string.split(",")
        coords.append([float(node[0]), float(node[1])])

    opt_tour = [int(x) for x in opt_tour_string.split()]
    return coords, opt_tour, float(opt_len_string)


def load_tsp_instances(path):
    """Load TSP instances with their optimal tours and tour lengths from a file.

    Each non-blank line holds three fields separated by ``" | "``: the
    space-separated ``x,y`` node coordinates, the space-separated node indices
    of the optimal tour, and the optimal tour length, e.g.::

        0.1,0.2 0.5,0.9 0.7,0.3 | 0 2 1 | 1.93

    Args:
        path: Path (``str`` or ``Path``) of the file to read.

    Returns:
        Tuple ``(tsp_instances, opt_tours, opt_lens, size, num)`` where
        ``tsp_instances`` is a float array of shape ``(num, size, 2)``,
        ``opt_tours`` is an integer array with one row per instance,
        ``opt_lens`` is a float array of shape ``(num,)``, ``size`` is the
        number of nodes per instance and ``num`` is the number of instances.

    Raises:
        RuntimeError: If ``path`` does not exist or contains no instances.
        ValueError: If a line is malformed, or if instances/tours differ in
            length (NumPy cannot stack ragged rows).
    """
    path = Path(path)
    if not path.exists():
        raise RuntimeError(f"{path} does not exist.")

    tsp_instance_list = []
    opt_tour_list = []
    opt_len_list = []

    # Iterate the file lazily rather than materialising it with readlines().
    with open(path, 'r', encoding='utf8') as file:
        for line_no, raw_line in enumerate(file, start=1):
            text = raw_line.strip()
            if not text:
                # Skip blank lines (e.g. an extra trailing newline) instead of
                # failing on the three-field unpack.
                continue
            coords, opt_tour, opt_len = _parse_tsp_line(text, path, line_no)
            tsp_instance_list.append(coords)
            opt_tour_list.append(opt_tour)
            opt_len_list.append(opt_len)

    if not tsp_instance_list:
        # Without this check an empty file fails later with an opaque
        # IndexError on ``shape[1]``.
        raise RuntimeError(f"{path} contains no TSP instances.")

    tsp_instances = np.array(tsp_instance_list)
    opt_tours = np.array(opt_tour_list)
    opt_lens = np.array(opt_len_list)

    num, size = tsp_instances.shape[:2]

    return tsp_instances, opt_tours, opt_lens, size, num


#####################
# TSPLIB information
#####################

# Reference tour lengths for TSPLIB symmetric TSP instances, keyed by instance
# name. NOTE(review): these appear to be the published TSPLIB optima, which
# TSPLIB defines on integer-rounded distances; lengths computed from unrounded
# Euclidean distances may differ slightly — confirm how callers compare them.
tsplib_collections = {
    # NOTE(review): att48 uses the ATT pseudo-Euclidean metric in TSPLIB
    # (published optimum 10628); 33524 appears to be its optimum under plain
    # Euclidean distance — confirm this is intended.
    "att48": 33524,
    "eil51": 426,
    "berlin52": 7542,
    "st70": 675,
    "eil76": 538,
    "pr76": 108159,
    "rat99": 1211,
    "kroA100": 21282,
    "kroB100": 22141,
    "kroC100": 20749,
    "kroD100": 21294,  # TSPLIB optimum (previously mis-copied from kroC100)
    "kroE100": 22068,
    "rd100": 7910,
    "eil101": 629,
    "lin105": 14379,
    "pr107": 44303,
    "pr124": 59030,
    "bier127": 118282,
    "ch130": 6110,
    "pr136": 96772,
    "pr144": 58537,
    "ch150": 6528,
    "kroA150": 26524,
    "kroB150": 26130,
    "pr152": 73682,
    "u159": 42080,
    "rat195": 2323,
    "d198": 15780,
    "kroA200": 29368,
    "kroB200": 29437,
    "ts225": 126643,
    "tsp225": 3916,
    "pr226": 80369,
    "gil262": 2378,
    "pr264": 49135,
    "a280": 2579,
    "pr299": 48191,
    "lin318": 42029,
    "rd400": 15281,
    "fl417": 11861,
    "pr439": 107217,
    "pcb442": 50778,
    "d493": 35002,
    "u574": 36905,
    "rat575": 6773,
    "p654": 34643,
    "d657": 48912,
    "u724": 41910,
    "rat783": 8806,
    "pr1002": 259045,
    "vm1084": 239297,
    "pcb1173": 56892,
    "d1291": 50801,
    "rl1304": 252948,
    "rl1323": 270199,
    "nrw1379": 56638,
    "fl1400": 20127,
    "u1432": 152970,
    "fl1577": 22249,
    "d1655": 62128,
    "vm1748": 336556,
    "u1817": 57201,
    "rl1889": 316536,
    "d2103": 80450,
    "u2152": 64253,
    "u2319": 234256,
    "pr2392": 378032,
    "pcb3038": 137694,
    "fl3795": 28772,
    "fnl4461": 182566,
    "rl5915": 565530,
    "rl5934": 556045,
}
