import torch
import time
import os
import math
import sys
import shutil
import argparse
from rich_argparse_plus import RichHelpFormatterPlus
from tqdm import tqdm
from pathlib import Path

sys.path.append(f"./utils/")
sys.path.append(f"./scripts/")
from utils import tsplib_collections
from utils import get_dist_matrix
from utils import calculate_tour_length
from augmentation import Augmentation
from model import TSP_net_general_test_version as Model


def load_tsplib(path):
    """Read the node coordinates of a TSPLIB instance file.

    Coordinate rows are taken from the ``NODE_COORD_SECTION`` or, when that
    header is absent, from the ``DISPLAY_DATA_SECTION``. They run until the
    ``EOF`` marker, or until the end of the file when there is no marker.

    Args:
        path: Path of the ``.tsp`` file to read.

    Returns:
        A list with one entry per node; each entry is the list of floats that
        follow the node index on that row. ``None`` is returned when the file
        contains neither coordinate section.
    """
    with open(path, 'r', encoding='utf8') as file:
        contents = [line.strip() for line in file]

    # Locate the first row after the coordinate header, preferring
    # NODE_COORD_SECTION over DISPLAY_DATA_SECTION.
    for start_identifier in ('NODE_COORD_SECTION', 'DISPLAY_DATA_SECTION'):
        if start_identifier in contents:
            start = contents.index(start_identifier) + 1
            break
    else:
        return None

    end_identifier = "EOF"
    end = contents.index(end_identifier) if end_identifier in contents else len(contents)

    tsp_instance = []
    for node_info in contents[start:end]:
        # Split on any whitespace run so tab-separated or column-aligned rows
        # parse the same way as single-space separated ones.
        fields = node_info.split()
        if not fields:
            # Blank rows carry no node; skipping them keeps the result rectangular.
            continue
        # The first field is the node index, the remaining ones are coordinates.
        tsp_instance.append([float(value) for value in fields[1:]])
    return tsp_instance


def main(args):
    """Run the trained model on every TSPLIB instance and export the results.

    For each ``name -> optimal length`` entry of ``tsplib_collections`` the
    instance ``<data_root>/<name>.tsp`` is loaded, rescaled, augmented and fed
    to the model. The shortest returned tour is reported against the optimum,
    and three text files are written below ``<save_root>/tsplib/``:

    * ``instances/<name>.txt`` - raw coordinates followed by the optimal length,
    * ``heatmap/<name>.txt``   - the row-normalised edge heatmap,
    * ``init_sol/<name>.txt``  - the best tour found by the model.

    An existing ``<save_root>/tsplib/`` directory is deleted first.

    Args:
        args: Parsed command-line namespace as returned by ``parse()``.
    """
    if torch.cuda.is_available() and args.device >= 0:
        device = torch.device("cuda")
        print('GPU name: {:s}'.format(torch.cuda.get_device_name(0)))
    else:
        # Without this branch `device` stayed undefined on CPU-only machines
        # and when `--device -1` was requested.
        device = torch.device("cpu")

    print(f"Experiment using {device}")

    # Preparation
    print("Trying loading models...")
    model = Model(args.dim_input_nodes, args.dim_emb, args.dim_ff, args.nb_layers_global_encoder,
                  args.nb_layers_local_encoder, args.nb_layers_decoder, args.nb_heads, args.local_k,
                  args.batchnorm)
    model_path = args.model_path
    # map_location lets a checkpoint saved on a GPU be restored on the device
    # selected above (including the CPU).
    load_params = torch.load(model_path, map_location=device)
    model.load_state_dict(load_params['model_baseline'])
    model.to(device)
    model.eval()
    print("Successfully loading the model")

    # create output directories (any previous export is discarded)
    dir_name = Path(args.save_root).joinpath("tsplib/")
    if dir_name.exists():
        shutil.rmtree(dir_name)
    instances_dir_name = dir_name.joinpath("instances/")
    heatmap_dir_name = dir_name.joinpath("heatmap/")
    init_sol_dir_name = dir_name.joinpath("init_sol/")
    dir_name.mkdir(parents=True)
    instances_dir_name.mkdir()
    heatmap_dir_name.mkdir()
    init_sol_dir_name.mkdir()

    for name, opt_len in tsplib_collections.items():
        path = Path(args.data_root).joinpath(f"{name}.tsp")
        coords = load_tsplib(path)
        if coords is None:
            # load_tsplib signals "no coordinate section" with None; report it
            # instead of failing inside torch.tensor.
            print(f"TSP-{name}, no coordinate section found in {path}, skipped")
            continue
        tsp_instance = torch.tensor(coords).float().to(device)
        # Shift every axis so that its minimum is 0, then divide by the global
        # maximum: the instance ends up inside the unit square with its aspect
        # ratio preserved. (Adding the minimum, as before, moved the points
        # away from the origin instead.)
        scaled_instance = tsp_instance - tsp_instance.min(0).values
        scaled_instance = scaled_instance / scaled_instance.max()
        size = tsp_instance.shape[0]

        # parameters settings: fewer augmentations for larger instances
        if size <= 50:
            k_aug = 64
        elif size <= 200:
            k_aug = 32
        else:
            k_aug = 16

        # largest number of augmented copies pushed through the model at once
        if size >= 4000:
            max_bs = 2
        elif size >= 3000:
            max_bs = 4
        elif size >= 1500:
            max_bs = 8
        else:
            max_bs = k_aug

        repeat_num = max(1, k_aug // max_bs)
        # weight of one tour edge, chosen so that k_aug * hm_scale stays at 64
        hm_scale = math.ceil(64 / k_aug)

        # preprocessing
        dist_matrix = get_dist_matrix(scaled_instance)
        # a large diagonal keeps a node out of its own nearest-neighbour list
        dist_matrix += 2 * torch.eye(size, device=device)
        augment_module = Augmentation()
        repeated_instances = scaled_instance.unsqueeze(0).repeat((k_aug, 1, 1))
        augmented_instances = augment_module.aug_for_train('mixture', repeated_instances, k_aug)

        # model inference, in chunks of at most max_bs augmented copies
        model_tours = []
        with torch.no_grad():
            for i in range(repeat_num):
                tours, _ = model(augmented_instances[max_bs * i:max_bs * (i + 1), :, :], deterministic=True)
                model_tours.append(tours)
        model_tours = torch.cat(model_tours, dim=0)

        # select the best tour
        model_lens = calculate_tour_length(dist_matrix, model_tours)
        best_tour_index = model_lens.argmin()
        best_tour = model_tours[best_tour_index]

        # initialization (the local_k nearest neighbours of each node get 1,
        # every other entry 0); the heatmap lives on the same device as the
        # index tensors used to fill it
        _, top_index = dist_matrix.topk(args.local_k, largest=False, dim=-1)
        heatmap = torch.zeros((size, size), device=device)
        heatmap.scatter_(1, top_index, 1.0)

        # augmented solution to heatmap: every tour from every chunk adds its
        # edges in both directions (previously only the last chunk was used)
        model_tours_left_shift = torch.roll(model_tours, shifts=-1, dims=1)
        for tour, next_nodes in zip(model_tours, model_tours_left_shift):
            heatmap[tour, next_nodes] += hm_scale
            heatmap[next_nodes, tour] += hm_scale

        # Categorical normalises each row to sum to 1
        heatmap = torch.distributions.Categorical(heatmap).probs

        # stats, measured on the unscaled coordinates
        true_dist_matrix = get_dist_matrix(tsp_instance)
        best_len = calculate_tour_length(true_dist_matrix, best_tour)
        best_len = round(best_len.item())
        gap = best_len / opt_len - 1
        print(f"TSP-{name}, Model inference length {best_len}, gap {gap * 100:.4f}%, optimal {opt_len}")

        # exports: tensors are converted to Python lists once per row instead
        # of calling .item() for every single element
        instances_path = instances_dir_name.joinpath(f"{name}.txt")
        with open(instances_path, 'w', encoding='utf8') as write_file:
            fields = [f"{x} {y}" for x, y in tsp_instance.tolist()]
            fields.append("optimal")
            fields.append(f"{opt_len}")
            write_file.write(" ".join(fields))

        heatmap_file = heatmap_dir_name.joinpath(f"{name}.txt")
        with open(heatmap_file, 'w', encoding='utf8') as file:
            file.write(f"{size}\n")
            # one write per row keeps memory bounded for large instances
            for row in heatmap.cpu():
                file.write("".join(f" {val}" for val in row.tolist()))
            file.write("\n")

        init_sol_file = init_sol_dir_name.joinpath(f"{name}.txt")
        with open(init_sol_file, 'w', encoding='utf8') as file:
            file.write(" ".join(str(node) for node in best_tour.tolist()))


def parse(argv=None):
    """Build the command-line interface, parse it and optionally echo the result.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        The populated ``argparse.Namespace``. Unless ``--no-print-param`` is
        given, every ``key = value`` pair is printed first, followed by a
        separator line and a blank line.
    """
    RichHelpFormatterPlus.choose_theme("prince")
    parser = argparse.ArgumentParser(
        description="TS4 model phase evaluation (TS3) for TSPLIB.",
        formatter_class=RichHelpFormatterPlus,
    )

    # general hyperparameters (training values)
    general_args = parser.add_argument_group("General Hyperparameters")
    general_args.add_argument("--nb-nodes", type=int, default=50,
                              help="The size of each tsp model size.")
    general_args.add_argument("--dim-emb", type=int, default=128,
                              help="The dimension of node embeddings.")
    general_args.add_argument("--nb-layers-global-encoder", type=int, default=4,
                              help="The number of global Encoder layers.")
    general_args.add_argument("--nb-layers-local-encoder", type=int, default=6,
                              help="The number of local Encoder layers.")
    general_args.add_argument("--nb-layers-decoder", type=int, default=2,
                              help="The number of Decoder layers.")
    general_args.add_argument("--nb-epochs", type=int, default=1500,
                              help="The number of total training epochs.")
    general_args.add_argument("--dim-input-nodes", type=int, default=2,
                              help="The feature number of each node.")
    general_args.add_argument("--nb-heads", type=int, default=8,
                              help="The number of attention heads.")
    general_args.add_argument("--dim-ff", type=int, default=512,
                              help="The dimension of feed-forward networks.")
    # The previous help text was a bare "."; this option is accepted but is
    # not read anywhere in this script.
    general_args.add_argument("--global-k", type=int, default=20,
                              help="Global k hyperparameter (training value; not read by this script).")
    general_args.add_argument("--batchnorm", action="store_true", default=False,
                              help="Use batchnorm or layernorm otherwise.")
    general_args.add_argument("--local-k", type=int, default=12,
                              help="The number of knn neighbors.")
    general_args.add_argument("--no-print-param", action="store_true",
                              help="Do not print the parameter information in log files.")
    general_args.add_argument("--device", type=int, default=0,
                              help="GPU device for training. -1 for CPU.")

    # customized hyperparameters (preferred default values)
    customized_args = parser.add_argument_group("Customized Hyperparameters")
    customized_args.add_argument("--data-root", type=str, default="./data/tsplib/",
                                 help="Path to TSP instances.")
    # The help text used to repeat the --data-root description.
    customized_args.add_argument("--save-root", type=str, default="./MCTS/data/",
                                 help="Path where the exported results are saved.")

    # typical hyperparameters (values for research)
    typical_args = parser.add_argument_group("TYPICAL HYPERPARAMETERS")
    typical_args.add_argument("--model-path", type=str, default="./models/checkpoint_23-08-17--04-49-43-n50-gpu0.pkl",
                              help="Trained model path.")

    args = parser.parse_args(argv)

    if not args.no_print_param:
        for key, value in vars(args).items():
            print(f"{key} = {value}")
        print("=" * 20)
        print()

    return args


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the resulting
    # namespace straight to the evaluation routine.
    main(parse())
