import torch
import elkai
import numpy as np
import sys
import argparse
from rich_argparse_plus import RichHelpFormatterPlus
from tqdm import tqdm
from pprint import pprint
from collections import defaultdict

sys.path.append(f"./utils/")
sys.path.append(f"./scripts/")
sys.path.append(f"./generator/")
from utils import get_dist_matrix
from utils import load_tsp_instances
from utils import calculate_tour_length
from augmentation import Augmentation
from generate_random_tsp import solve_use_int_elkai


def main(args):
    """Report how much a small coordinate perturbation degrades TSP tours.

    Instances are loaded from ``args.path``, passed through
    ``Augmentation.scale`` and then ``Augmentation.noise``.  Every noised
    instance is solved with ``solve_use_int_elkai``; the tour found on the
    noised instance is then measured on the distance matrix of the un-noised
    (scaled) instance and compared with the length of the stored reference
    tour on that same matrix.  The relative gap of each instance and the
    average gap over all instances are printed to stdout.

    Args:
        args: Parsed command-line namespace; ``path``, ``noise`` and ``runs``
            are read from it.
    """
    tsp_instances, opt_tours, _opt_lens, size, num = load_tsp_instances(args.path)

    # Without instances the average below would divide by zero, so stop early
    # with an explicit message instead of a ZeroDivisionError.
    if num <= 0:
        print(f"No TSP instances loaded from {args.path}")
        return

    tsp_instances = torch.tensor(tsp_instances).float()
    augment_module = Augmentation()
    tsp_instances, _ = augment_module.scale(tsp_instances)
    noised_tsp_instances = augment_module.noise(tsp_instances, noise=args.noise)

    runs = args.runs
    # Target value for the largest entry of the amplified distance matrix.
    amp_border = 1000000
    gaps = []

    for index in range(num):
        # Distances of the clean instance: every tour is evaluated on these.
        dist_matrix = get_dist_matrix(tsp_instances[index])

        # Solve on the perturbed instance.
        aug_dist_matrix = get_dist_matrix(noised_tsp_instances[index])
        # amp * max(distance) == amp_border; presumably the solver works on
        # integer distances and needs this rescaling -- confirm against
        # solve_use_int_elkai.
        amp = amp_border / aug_dist_matrix.max()
        aug_sol = solve_use_int_elkai(aug_dist_matrix, amp=amp, runs=runs)

        aug_len = calculate_tour_length(dist_matrix, torch.tensor(aug_sol))
        opt_len = calculate_tour_length(dist_matrix, torch.tensor(opt_tours[index]))

        # Relative excess length of the perturbed-instance tour.
        gap = aug_len / opt_len - 1
        gaps.append(gap)
        print(f"{gap.item() * 100:.7f}%")

    print(f"Noisy perturbation {args.noise} on TSP{size}")
    print(f"Avg solution gap: {(sum(gaps) / len(gaps)).item() * 100:.2f}%")


def parse():
    """Build the command-line parser, parse ``sys.argv`` and return the result.

    Unless ``--no-print-param`` is given, every parsed option is echoed to
    stdout as ``key = value`` followed by a separator line and a blank line.

    Returns:
        argparse.Namespace: with attributes ``runs``, ``no_print_param``,
        ``path`` and ``noise``.
    """
    RichHelpFormatterPlus.choose_theme("prince")
    parser = argparse.ArgumentParser(
        description="Augmentation statistics for TSP.",
        formatter_class=RichHelpFormatterPlus,
    )

    # general hyperparameters (training values)
    general_args = parser.add_argument_group("General Hyperparameters")
    general_args.add_argument("--runs", type=int, default=100,
                              help="Runs of LKH algorithm for each instance.")
    general_args.add_argument("--no-print-param", action="store_true",
                              help="Do not print the parameter information in log files.")

    # typical hyperparameters (values for research)
    typical_args = parser.add_argument_group("TYPICAL HYPERPARAMETERS")
    # The instance file has no sensible default, so let argparse reject a
    # missing --path with a usage error instead of passing None downstream.
    typical_args.add_argument("--path", type=str, required=True,
                              help="Path for tsp instances.")
    typical_args.add_argument("--noise", type=float, default=1e-5,
                              help="Noise added in augmentation.")

    args = parser.parse_args()

    if not args.no_print_param:
        for key, value in vars(args).items():
            print(f"{key} = {value}")
        print("=" * 20)
        print()

    return args


# Script entry point: parse the command line, then run the evaluation.
if __name__ == "__main__":
    main(parse())
