import elkai
import torch
import time

def solve_tsp_instance_by_LKH3(dist_matrix, border=1000000, runs=10):
    """
    Solve a TSP instance, given by its distance matrix, with LKH3 (via elkai).

    elkai works on integer distances, so the matrix is first rescaled so that its
    largest entry equals ``border`` and is then truncated to int.

    :param dist_matrix: a (size, size) tensor, the distance matrix for the tsp instance
    :param border: the maximum of the scaled distance values
    :param runs: number of LKH3 runs (repetitions) passed to elkai
    :return: a (size, ) tensor, the solution tour
    """
    n = dist_matrix.size(0)
    if n <= 2:
        # With at most two nodes every visiting order forms the same closed tour,
        # so the solver is unnecessary (and the scaling below would divide by a
        # zero maximum for a single node).
        return torch.arange(n)
    max_dist = dist_matrix.max()
    # If all nodes coincide the matrix is all zeros: scaling by border / 0 would
    # produce inf * 0 = NaN entries, so leave the (zero) distances unscaled.
    amp = border / max_dist if max_dist > 0 else 1.0
    scaled = amp * dist_matrix
    tour = elkai.solve_int_matrix(scaled.int().tolist(), runs=runs)
    assert len(tour) == n
    return torch.tensor(tour)

def get_dist_matrix(instance):
    """
    Compute the pairwise Euclidean distance matrix of a set of points.

    :param instance: a (size, dim) tensor of node coordinates
    :return: a (size, size) tensor whose (i, j) entry is ||instance[j] - instance[i]||_2
    """
    # Broadcasting (1, size, dim) against (size, 1, dim) yields every pairwise
    # difference without materialising two repeated (size, size, dim) copies.
    diffs = instance.unsqueeze(0) - instance.unsqueeze(1)
    return torch.norm(diffs, p=2, dim=-1)

def check_tsp_solution_validity(tour):
    """
    Check that a tour visits every node exactly once.

    A valid tour is a permutation of the node indices 0 .. len(tour) - 1, i.e. a
    non-overlapping sequential indexing of existing nodes (this already ensures
    the absence of sub-tours).

    :param tour: the 1-D solution tour to be checked
    :return: True for a valid tour; False otherwise
    """
    n = tour.size(0)
    # Indices are 0-based, so the largest existing node index is n - 1.
    in_range = bool(((tour >= 0) & (tour < n)).all().item())
    return in_range and tour.size() == tour.unique().size()

def calculate_tour_length_by_dist_matrix(dist_matrix, tours):
    """
    Compute the closed-tour length of one or several tours on a single instance.

    Useful to evaluate one/multiple solutions on one (not-extremely-huge) instance:
    all edge costs are gathered from the full distance matrix in one indexing step.

    :param dist_matrix: a (size, size) tensor, the distance matrix of the instance
    :param tours: a (size, ) tour or a (k, size) batch of tours
    :return: a (k, ) tensor of tour lengths (k == 1 for a single tour), including
        the edge from the last node back to the first
    """
    batched = tours.unsqueeze(0) if tours.dim() == 1 else tours
    # Successor of each position; rolling wraps the last node back to the first.
    successors = batched.roll(shifts=-1, dims=1)
    edge_costs = dist_matrix[batched, successors]
    return edge_costs.sum(dim=1)

def avg_list(list_object):
    """Return the arithmetic mean of ``list_object``, or 0 when it is empty."""
    count = len(list_object)
    if count == 0:
        return 0
    return sum(list_object) / count

def create_tsp_baselines_by_LKH3(num, size, runs=10, border=1000000):
    """
    Benchmark LKH3 on randomly generated TSP instances and print a summary.

    Each instance has ``size`` nodes with coordinates drawn from ``torch.rand``
    (uniform in [0, 1)^2). Every tour returned by the solver is checked for
    validity and evaluated on the unscaled Euclidean distance matrix.

    :param num: number of random instances to solve
    :param size: number of nodes per instance
    :param runs: number of LKH3 runs per instance
    :param border: maximum scaled (integer) distance handed to the solver
    :return: (average tour length, average solve time in seconds)
    """
    tsp_instances = torch.rand(num, size, 2)

    tour_lens = []
    elapsed_times = []
    for count, instance in enumerate(tsp_instances, start=1):
        # perf_counter is monotonic and meant for measuring short durations.
        start_time = time.perf_counter()
        dist_matrix = get_dist_matrix(instance)
        tour = solve_tsp_instance_by_LKH3(dist_matrix, border=border, runs=runs)
        elapsed_time = time.perf_counter() - start_time

        assert check_tsp_solution_validity(tour)
        tour_len = calculate_tour_length_by_dist_matrix(dist_matrix, tour)

        tour_lens.append(tour_len.item())
        elapsed_times.append(elapsed_time)

        if count % 10 == 0:
            print(f"{count} instances solved")

    avg_len = avg_list(tour_lens)
    avg_time = avg_list(elapsed_times)
    print("LKH3 Inference Finished!")
    print(f"[*] Summary           : LKH3 solver with {runs} runs")
    print(f"[*] Instances number  : {num}")
    print(f"[*] Nodes number      : {size}")
    print(f"[*] Average length    : {avg_len}")
    print(f"[*] Average time (s)  : {avg_time}")
    print("\n" * 5)
    return avg_len, avg_time

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, so that importing this
    # module for its helpers does not launch 1000 LKH3 solves.
    create_tsp_baselines_by_LKH3(1000, 50)