import torch
import numpy as np
import sys
import argparse
from rich_argparse_plus import RichHelpFormatterPlus
from tqdm import tqdm
from pprint import pprint
from collections import defaultdict

sys.path.append(f"./utils/")
from utils import get_dist_matrix
from utils import load_tsp_instances


def main(args):
    """Tally where optimal-tour neighbours fall in each node's distance ranking.

    For every instance loaded from ``args.path``, and for every node of that
    instance, the node's predecessor and successor in the optimal tour are
    located in the node's distance-sorted list of all nodes. The position
    (rank) of each is counted, and the histogram ``{rank: count}`` is printed
    in ascending rank order.

    Args:
        args: Parsed command-line namespace; only ``args.path`` is read.
    """
    tsp_instances, opt_tours, _opt_lens, size, num = load_tsp_instances(args.path)
    counter = defaultdict(int)

    for index in tqdm(range(num)):
        tsp_instance = tsp_instances[index]
        opt_tour = np.asarray(opt_tours[index])

        # Assumes get_dist_matrix yields a square (n, n) matrix for one
        # instance -- the original per-row indexing relied on the same thing.
        dist_matrix = get_dist_matrix(torch.tensor(tsp_instance)).numpy()

        # Row i of ``order`` lists node indices sorted by distance from node i.
        # Argsorting each row again inverts that permutation, so
        # ``rank[i, k]`` is the position of node k in node i's sorted list.
        # This replaces two argsort + linear-search passes per node.
        # NOTE(review): if the diagonal is zero, rank 0 is presumably the node
        # itself and the nearest other node is rank 1 -- confirm in utils.
        order = np.argsort(dist_matrix, axis=1)
        rank = np.argsort(order, axis=1)

        # Visit one tour position per node; neighbours wrap modulo ``size``.
        positions = np.arange(len(tsp_instance))
        curr_nodes = opt_tour[positions]
        prev_nodes = opt_tour[(positions - 1) % size]
        next_nodes = opt_tour[(positions + 1) % size]

        neighbour_ranks = np.concatenate(
            (rank[curr_nodes, prev_nodes], rank[curr_nodes, next_nodes])
        )
        for r in neighbour_ranks:
            # Plain ints keep the printed keys readable (no ``np.int64(...)``).
            counter[int(r)] += 1

    print(f"Nearest node selection for TSP{size}")
    pprint(dict(sorted(counter.items())))


def parse():
    """Build the CLI parser, parse ``sys.argv``, and optionally echo the options.

    Unless ``--no-print-param`` is given, every parsed option is printed as
    ``key = value``, followed by a separator line and a blank line.

    Returns:
        argparse.Namespace with ``no_print_param`` (bool) and ``path`` (str).
    """
    RichHelpFormatterPlus.choose_theme("prince")
    parser = argparse.ArgumentParser(
        description="KNN statistics for TSP.",
        formatter_class=RichHelpFormatterPlus,
    )

    # general hyperparameters (training values)
    general_args = parser.add_argument_group("General Hyperparameters")
    general_args.add_argument("--no-print-param", action="store_true",
                              help="Do not print the parameter information in log files.")

    # typical hyperparameters (values for research)
    typical_args = parser.add_argument_group("TYPICAL HYPERPARAMETERS")
    # main() passes this straight to load_tsp_instances; requiring it gives a
    # clear usage error instead of forwarding None downstream.
    typical_args.add_argument("--path", type=str, required=True,
                              help="Path for tsp instances.")

    args = parser.parse_args()

    if not args.no_print_param:
        for key, value in vars(args).items():
            print(f"{key} = {value}")
        print("=" * 20)
        print()

    return args


if __name__ == "__main__":
    # Script entry point: read CLI options, then run the statistics pass.
    cli_args = parse()
    main(cli_args)
