import torch
import numpy as np
from pathlib import Path


#######################
# TSP utility functions
#######################

def parse_tsplib_name(tsplib_name):
    """
    Split a TSPLIB instance name into its alphabetic and numeric parts.

    All letters are concatenated into the first element and all digits into
    the second, which is converted to ``int`` (e.g. ``"kroA100"`` ->
    ``("kroA", 100)``).

    :param tsplib_name: instance name such as ``"eil51"``
    :return: a ``(letters, number)`` tuple
    :raises ValueError: if the name contains no digits
    """
    letters = "".join(ch for ch in tsplib_name if ch.isalpha())
    digits = "".join(ch for ch in tsplib_name if ch.isdigit())
    return letters, int(digits)


def read_tsplib_file(file_path):
    """
    Read a TSPLIB ``.tsp`` file and return its node coordinates and name.

    Header lines of the form ``KEY : VALUE`` are collected until the first
    non-blank line without a colon (normally ``NODE_COORD_SECTION``), which
    is consumed. Every following non-blank line up to ``EOF`` is parsed as
    ``<index> <x> <y>``.

    :param file_path: path to the TSPLIB file to read
    :return: a tuple ``(nodes, name)`` where ``nodes`` is a list of
        ``[x, y]`` float pairs in file order and ``name`` is the value of the
        ``NAME`` header field
    :raises KeyError: if the file has no ``NAME`` header field
    :raises ValueError: if a coordinate line does not have exactly three
        whitespace-separated fields
    """
    properties = {}
    nodes = []
    reading_properties = True

    with open(file_path, "r", encoding="utf8") as read_file:
        for raw_line in read_file:
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines anywhere instead of stopping at the first one.
                continue

            if reading_properties:
                if ":" in line:
                    # Split on the first colon only: values such as COMMENT may
                    # themselves contain colons.
                    key, val = (part.strip() for part in line.split(":", 1))
                    properties[key] = val
                else:
                    # The first colon-less line ends the header and is not a node.
                    reading_properties = False
                continue

            if line.startswith("EOF"):
                break
            if line.startswith("NODE_COORD_SECTION"):
                continue

            # Split on any whitespace so tab- or multi-space-separated files parse.
            _, x, y = line.split()
            nodes.append([float(x), float(y)])

    return nodes, properties["NAME"]


def choose_bsz(size):
    """
    Pick a batch size for a problem of the given size.

    Larger problems get smaller batches: ``size <= 200`` -> 64,
    ``<= 1000`` -> 32, ``<= 5000`` -> 16, otherwise 4.

    :param size: problem size (compared against the thresholds above)
    :return: the batch size as an ``int``
    """
    thresholds = ((200, 64), (1000, 32), (5000, 16))
    for upper_bound, batch_size in thresholds:
        if size <= upper_bound:
            return batch_size
    return 4


def load_tsplib_file(root, tsplib_name):
    """
    Load ``<root>/tsplib/<tsplib_name>.tsp`` as a coordinate tensor.

    :param root: dataset root directory, as a ``pathlib.Path`` or ``str``
    :param tsplib_name: instance name without extension, e.g. ``"eil51"``
    :return: ``(instance, name)`` where ``instance`` is a tensor built from
        the parsed ``[x, y]`` node list (shape (n, 2) for a non-empty file)
        and ``name`` is the NAME field from the file header
    """
    # Path() accepts both str and Path, so callers may pass either for root.
    file_path = Path(root) / "tsplib" / f"{tsplib_name}.tsp"
    nodes, name = read_tsplib_file(file_path)
    return torch.tensor(nodes), name

def avg_list(list_object):
    """
    Return the arithmetic mean of a sized sequence, or 0 when it is empty.

    :param list_object: a sequence supporting ``len`` and ``sum``
    :return: ``sum(list_object) / len(list_object)``, or ``0`` for an empty input
    """
    count = len(list_object)
    if count == 0:
        return 0
    return sum(list_object) / count


def normalize_tsp_to_unit_board(tsp_instance):
    """
    Normalize a TSP instance to the [0, 1]^2 unit board.

    The points are shifted so that the minimum x and the minimum y are both 0,
    so some point lies on x=0 and some point on y=0. They are then divided by
    the larger of the x/y extents. The aspect ratio is preserved and the
    longer side spans exactly [0, 1].

    If all points coincide (e.g. a single node), the extent is 0. Every point
    is then mapped to the origin instead of producing NaN from 0/0.

    :param tsp_instance: a (tsp_size, 2) tensor
    :return: a new (tsp_size, 2) tensor, the normalized tsp instance
    """
    mins = tsp_instance.min(dim=0).values
    maxs = tsp_instance.max(dim=0).values
    # One uniform scale for both axes keeps the aspect ratio.
    scale = (maxs - mins).max()
    # Guard the degenerate case: dividing the all-zero shift by 1 keeps it at 0.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    return (tsp_instance - mins) / scale


def normalize_nodes_to_unit_board(nodes):
    """
    Normalize a node set to the [0, 1]^2 unit board.

    This wraps ``normalize_tsp_to_unit_board`` and also accepts a plain
    nested list of ``[x, y]`` pairs (such as the node list returned by
    ``read_tsplib_file``). The list is converted to a tensor first.

    :param nodes: a (tsp_size, 2) tensor or nested list of coordinates
    :return: a (tsp_size, 2) tensor, the normalized nodes
    """
    # torch.as_tensor returns tensor inputs unchanged, so existing callers are unaffected.
    return normalize_tsp_to_unit_board(torch.as_tensor(nodes))


def get_dist_matrix(instance):
    """
    Compute the pairwise Euclidean distance matrix of a point set.

    Broadcasting replaces two ``repeat``-ed copies of the input. This saves
    memory on large instances and also allows an optional leading batch
    dimension.

    :param instance: a (size, dim) tensor of coordinates, or a batched
        (..., size, dim) tensor
    :return: a (size, size) tensor, or (..., size, size) for batched input,
        whose [i, j] entry is the L2 distance between points i and j
    """
    # (..., 1, size, dim) - (..., size, 1, dim) -> (..., size, size, dim)
    diff = instance.unsqueeze(-3) - instance.unsqueeze(-2)
    return torch.norm(diff, p=2, dim=-1)


def calculate_tour_length_by_dist_matrix(dist_matrix, tours):
    """
    Sum the edge weights of one or more closed tours on a single instance.

    This is useful for evaluating one or several solutions on one
    (not extremely large) instance whose full distance matrix fits in memory.

    :param dist_matrix: a (n, n) tensor of pairwise distances
    :param tours: a (n,) tour or a (k, n) batch of tours of node indices
    :return: a (k,) tensor of tour lengths, or (1,) for a single tour; the
        closing edge from the last node back to the first is included
    """
    batched_tours = tours[None, :] if tours.dim() == 1 else tours
    # Each node's successor along the tour, wrapping around to the start.
    successors = torch.cat((batched_tours[:, 1:], batched_tours[:, :1]), dim=1)
    edge_lengths = dist_matrix[batched_tours, successors]
    return edge_lengths.sum(dim=1)


def calculate_tour_length_by_instances(instances, tours):
    """
    Evaluate a batch of closed tours directly from node coordinates.

    Unlike ``calculate_tour_length_by_dist_matrix``, this builds no (n, n)
    distance matrix, so memory stays linear in the instance size.

    :param instances: a (batch, n, dim) tensor of coordinates, or a single
        (n, dim) instance shared by every tour
    :param tours: a (batch, n) integer tensor of node indices, or a single
        (n,) tour
    :return: a (batch,) tensor of tour lengths; the closing edge from the
        last node back to the first is included
    :raises ValueError: if the batch sizes of ``instances`` and ``tours`` differ
    """
    if tours.dim() == 1:
        tours = tours.unsqueeze(0)
    if instances.dim() == 2:
        # Share one instance across all tours without copying it.
        instances = instances.unsqueeze(0).expand(tours.size(0), -1, -1)
    if instances.size(0) != tours.size(0):
        raise ValueError(
            f"batch size mismatch: {instances.size(0)} instances vs {tours.size(0)} tours"
        )
    # Gather coordinates in visiting order: (batch, n, dim).
    index = tours.long().unsqueeze(-1).expand(-1, -1, instances.size(-1))
    ordered = torch.gather(instances, 1, index)
    # Pair every node with its successor; rolling by one closes the cycle.
    successors = torch.roll(ordered, shifts=-1, dims=1)
    return torch.norm(successors - ordered, p=2, dim=-1).sum(dim=1)

def load_tsp_instances(path):
    """
    Load TSP instances together with their reference tours and tour lengths.

    Each non-blank line of the file must have the form
    ``x1,y1 x2,y2 ... | t1 t2 ... | length``. Blank lines are skipped.

    :param path: path to the dataset file (``str`` or ``pathlib.Path``)
    :return: ``(tsp_instances, opt_tours, opt_lens, size, num)`` where
        ``tsp_instances`` is a (num, size, 2) float array, ``opt_tours`` a
        (num, tour_len) int array, ``opt_lens`` a (num,) float array,
        ``size`` the number of nodes per instance and ``num`` the number of
        instances (all instances are expected to have the same size)
    :raises RuntimeError: if ``path`` does not exist or contains no instances
    """
    path = Path(path)
    if not path.exists():
        raise RuntimeError(f"{path} does not exist.")

    tsp_instance_list = []
    opt_tour_list = []
    opt_len_list = []

    with open(path, "r", encoding="utf8") as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # e.g. a trailing empty line at the end of the file
                continue
            tsp_instance_string, opt_tour_string, opt_len_string = line.split(" | ")

            # split() without arguments tolerates repeated whitespace.
            node_fields = [node.split(",") for node in tsp_instance_string.split()]
            tsp_instance_list.append(
                np.array([[float(fields[0]), float(fields[1])] for fields in node_fields])
            )
            opt_tour_list.append(np.array([int(x) for x in opt_tour_string.split()]))
            opt_len_list.append(float(opt_len_string))

    if not tsp_instance_list:
        raise RuntimeError(f"{path} contains no TSP instances.")

    tsp_instances = np.array(tsp_instance_list)
    opt_tours = np.array(opt_tour_list)
    opt_lens = np.array(opt_len_list)

    num, size = tsp_instances.shape[:2]
    return tsp_instances, opt_tours, opt_lens, size, num



#####################
# TSPLIB information
#####################

# Mapping from TSPLIB instance name (presumably the same names passed as
# `tsplib_name` to load_tsplib_file, e.g. "eil51") to an integer reference
# tour length for that instance.
# NOTE(review): the values appear to be the published optimal tour lengths
# from TSPLIB; confirm before using them to report optimality gaps.
tsplib_collections = {
    'eil51': 426,
    'berlin52': 7542,
    'st70': 675,
    'pr76': 108159,
    'eil76': 538,
    'rat99': 1211,
    'kroA100': 21282,
    'kroE100': 22068,
    'kroB100': 22141,
    'rd100': 7910,
    'kroD100': 21294,
    'kroC100': 20749,
    'eil101': 629,
    'lin105': 14379,
    'pr107': 44303,
    'pr124': 59030,
    'bier127': 118282,
    'ch130': 6110,
    'pr136': 96772,
    'pr144': 58537,
    'kroA150': 26524,
    'kroB150': 26130,
    'ch150': 6528,
    'pr152': 73682,
    'u159': 42080,
    'rat195': 2323,
    'd198': 15780,
    'kroA200': 29368,
    'kroB200': 29437,
    'tsp225': 3916,
    'ts225': 126643,
    'pr226': 80369,
    'gil262': 2378,
    'pr264': 49135,
    'a280': 2579,
    'pr299': 48191,
    'lin318': 42029,
    'rd400': 15281,
    'fl417': 11861,
    'pr439': 107217,
    'pcb442': 50778,
    'd493': 35002,
    'u574': 36905,
    'rat575': 6773,
    'p654': 34643,
    'd657': 48912,
    'u724': 41910,
    'rat783': 8806,
    'pr1002': 259045,
    'u1060': 224094,
    'vm1084': 239297,
    'pcb1173': 56892,
    'd1291': 50801,
    'rl1304': 252948,
    'rl1323': 270199,
    'nrw1379': 56638,
    'fl1400': 20127,
    'u1432': 152970,
    'fl1577': 22249,
    'd1655': 62128,
    'vm1748': 336556,
    'u1817': 57201,
    'rl1889': 316536,
    'd2103': 80450,
    'u2152': 64253,
    'u2319': 234256,
    'pr2392': 378032,
    'pcb3038': 137694,
    'fl3795': 28772,
    'fnl4461': 182566,
    'rl5915': 565530,
    'rl5934': 556045,
    'rl11849': 923288,
    'usa13509': 19982859,
    'brd14051': 469385,
    'd15112': 1573084,
    'd18512': 645238
}
