import elkai
import torch
import time
import sys
import argparse
from rich_argparse_plus import RichHelpFormatterPlus
from tqdm import tqdm
from pathlib import Path

sys.path.append(f"./utils/")
from utils import get_dist_matrix
from utils import calculate_tour_length


def solve_use_int_elkai(dist_matrix, amp=1, runs=10):
    """Solve one TSP instance with elkai's integer-cost LKH solver.

    ``elkai.solve_int_matrix`` only accepts integer costs, so the float matrix
    is multiplied by ``amp`` and rounded to the nearest integer first. A large
    ``amp`` keeps that rounding error small relative to the edge lengths.

    Args:
        dist_matrix: Torch tensor of pairwise distances. Assumed square
            ``(N, N)`` as produced by ``get_dist_matrix`` (TODO confirm).
            It may live on any device.
        amp: Scale factor applied before rounding (Python number or 0-d tensor).
        runs: Number of LKH runs, passed straight through to elkai.

    Returns:
        list: The tour returned by elkai with node ``0`` appended, so the
        closing edge back to the start is explicit. This assumes elkai's tour
        starts at node 0; verify against the elkai version in use.
    """
    scaled = amp * dist_matrix
    # Round rather than truncate, which halves the worst-case quantisation
    # error. Then pass plain Python ints from the CPU (a nested list, the
    # input format elkai documents) so the C extension never sees a torch
    # tensor, which might be on CUDA.
    int_matrix = torch.round(scaled).long().cpu().tolist()
    sol = elkai.solve_int_matrix(int_matrix, runs=runs)
    sol.append(0)
    return sol


def _resolve_device(index):
    """Map the integer ``--device`` option to a ``torch.device``.

    A negative index selects the CPU, as the ``--device`` help text promises.
    Any other value selects that CUDA device. ``Tensor.to(-1)`` raises, which
    is why the mapping is done explicitly.
    """
    if index < 0:
        return torch.device("cpu")
    return torch.device("cuda", index)


def _format_record(coords, tour, tour_len):
    """Render one solved instance as a single dataset line.

    Args:
        coords: Sequence of ``(x, y)`` pairs, one per node.
        tour: Sequence of node indices in visiting order.
        tour_len: Tour length as a plain Python number.

    Returns:
        str: ``"x0,y0 x1,y1 ... | i0 i1 ... | length\\n"``.
    """
    tsp_string = " ".join(f"{x},{y}" for x, y in coords)
    opt_string = " ".join(str(node) for node in tour)
    return f"{tsp_string} | {opt_string} | {tour_len}\n"


def main(args):
    """Generate random TSP instances, solve them with elkai and save them.

    Each instance has ``args.size`` nodes sampled with ``torch.rand`` (uniform
    in [0, 1) per coordinate). It is solved with ``solve_use_int_elkai``, and
    one line per instance is written to
    ``<args.root>/tspfarm/TSP<size>.txt`` in the format produced by
    ``_format_record``.

    Args:
        args: Parsed CLI namespace. Reads ``size``, ``num``, ``runs``,
            ``root``, ``append`` and ``device``. A negative ``device``
            selects the CPU.
    """
    tsp_size = args.size
    tsp_number = args.num
    runs = args.runs
    device = _resolve_device(args.device)
    # The longest edge of every instance is scaled to this value before
    # rounding, so elkai's integer costs keep fine resolution.
    amp_border = 1000000

    path = Path(args.root) / "tspfarm" / f"TSP{tsp_size}.txt"
    path.parent.mkdir(parents=True, exist_ok=True)

    total_time = 0.0
    # "w" truncates any previous dataset; --append adds to the end instead.
    mode = "a" if args.append else "w"
    with open(path, mode, encoding="utf8") as file:
        for _ in tqdm(range(tsp_number)):
            # Sample on the CPU and then move the points, so they always come
            # from the CPU RNG whatever --device is.
            tsp_data = torch.rand((tsp_size, 2)).to(device)
            dist_matrix = get_dist_matrix(tsp_data)

            max_dist = dist_matrix.max()
            # Guard against an all-zero matrix (e.g. a single node), which
            # would otherwise make the scale factor inf.
            amp = amp_border / max_dist if max_dist > 0 else 1

            # Only the solver call is timed.
            start_time = time.perf_counter()
            sol = solve_use_int_elkai(dist_matrix, amp=amp, runs=runs)
            total_time += time.perf_counter() - start_time

            tour = torch.tensor(sol, dtype=torch.long, device=device)
            sol_len = calculate_tour_length(dist_matrix, tour)

            # One .tolist() transfer instead of a .item() sync per element.
            file.write(_format_record(tsp_data.tolist(), tour.tolist(), sol_len.item()))
            # Flush per instance so finished work survives an interrupted run.
            file.flush()

    print(f"Successfully generate the {tsp_number} tsp in {total_time / 3600:.2f} hrs.")
    if tsp_number > 0:
        print(f"Each inference costs {total_time / tsp_number} seconds.")


def parse():
    """Build the command-line interface and parse ``sys.argv``.

    Unless ``--no-print-param`` is given, every parsed option is echoed as
    ``key = value``, followed by a separator line and a blank line.

    Returns:
        argparse.Namespace: the parsed options (``runs``, ``device``,
        ``root``, ``append``, ``no_print_param``, ``size``, ``num``).
    """
    RichHelpFormatterPlus.choose_theme("prince")
    parser = argparse.ArgumentParser(
        description="Random TSP generator. For research TS4.",
        formatter_class=RichHelpFormatterPlus,
    )

    # general hyperparameters (training values)
    general_args = parser.add_argument_group("General Hyperparameters")
    general_args.add_argument("--runs", type=int, default=100,
                              help="Runs of LKH algorithm for each instance.")
    general_args.add_argument("--device", type=int, default=0,
                              help="GPU device ID. -1 for CPU only.")

    # customized hyperparameters (preferred default values)
    customized_args = parser.add_argument_group("Customized Hyperparameters")
    customized_args.add_argument("--root", type=str, default="./data/",
                                 help="Path to data directory, where all the datasets will be placed.")
    customized_args.add_argument("--append", action="store_true",
                                 help="Append the data to the existed file.")
    customized_args.add_argument("--no-print-param", action="store_true",
                                 help="Do not print the parameter information in log files.")

    # typical hyperparameters (values for research)
    typical_args = parser.add_argument_group("TYPICAL HYPERPARAMETERS")
    typical_args.add_argument("--size", type=int, default=50,
                              help="Size of TSP instances.")
    typical_args.add_argument("--num", type=int, default=1000,
                              help="Number of TSP instances.")

    args = parser.parse_args()

    # Fail fast with a usage message instead of a traceback mid-run. main()
    # divides by --num for the per-instance time, and divides by the largest
    # pairwise distance, which is presumably 0 for a single-node instance.
    if args.runs < 1:
        parser.error("--runs must be at least 1.")
    if args.num < 1:
        parser.error("--num must be at least 1.")
    if args.size < 2:
        parser.error("--size must be at least 2.")

    if not args.no_print_param:
        for key, value in vars(args).items():
            print(f"{key} = {value}")
        print("=" * 20)
        print()

    return args


if __name__ == '__main__':
    # Parse (and, unless --no-print-param, echo) the CLI options, then run the
    # dataset generation with them.
    main(parse())
