import torch
import numpy as np
import math
import time
import sys
import shutil
import argparse
from rich_argparse_plus import RichHelpFormatterPlus
from tqdm import tqdm
from pathlib import Path

sys.path.append(f"./utils/")
sys.path.append(f"./scripts/")
from utils import get_dist_matrix
from utils import calculate_tour_length
from utils import load_tsp_instances
from augmentation import Augmentation
from model import TSP_net_general_test_version as Model


def _select_device(device_index):
    """Return the torch device to evaluate on.

    CUDA is chosen when it is available and ``device_index`` is non-negative;
    otherwise the CPU is used, so ``--device -1`` or a CPU-only host still works.
    """
    if torch.cuda.is_available() and device_index >= 0:
        # NOTE: the index only switches between GPU and CPU; the default CUDA
        # device is used and the name of device 0 is reported.
        print('GPU name: {:s}'.format(torch.cuda.get_device_name(0)))
        return torch.device("cuda")
    return torch.device("cpu")


def _write_mcts_instances(file_path, tsp_instances, opt_tours, num):
    """Write the first ``num`` instances to ``file_path``, one instance per line.

    Each line is ``x1 y1 x2 y2 ... output t1 t2 ... `` where every tour entry is
    written as ``int(entry) + 1``.
    """
    with open(file_path, 'w', encoding='utf8') as write_file:
        for index in range(num):
            coords = "".join(f"{x} {y} " for x, y in tsp_instances[index])
            tour = "".join(f" {int(i) + 1}" for i in opt_tours[index])
            write_file.write(f"{coords}output{tour} \n")


def _build_heatmap(size, top_index, model_tours, hm_scale, device):
    """Build the row-normalised edge heatmap for one instance.

    Every (node, near-neighbour) pair listed in ``top_index`` starts with weight 1,
    then each edge of each tour in ``model_tours`` adds ``hm_scale`` in both
    directions. Rows are normalised to sum to 1 via ``Categorical(...).probs``.
    """
    # Allocate on the same device as the index tensors: indexing a CPU tensor
    # with CUDA indices is not supported by all PyTorch versions.
    heatmap = torch.full((size, size), fill_value=0, device=device)
    node_ids = torch.arange(size, device=device)
    heatmap[node_ids, top_index.to(device).transpose(0, 1)] = 1

    tours = model_tours.to(device)
    # Rolling by -1 pairs every node with its successor, wrapping last -> first.
    successors = torch.roll(tours, shifts=-1, dims=1)
    for tour, successor in zip(tours, successors):
        heatmap[tour, successor] += hm_scale
        heatmap[successor, tour] += hm_scale

    return torch.distributions.Categorical(heatmap).probs


def _write_heatmap(file_path, heatmap, size):
    """Write ``size`` on the first line, then all heatmap values on a single line."""
    # One bulk transfer instead of one .item() call per element.
    rows = heatmap.cpu().tolist()
    with open(file_path, 'w', encoding='utf8') as write_file:
        write_file.write(f"{size}\n")
        write_file.write("".join(f" {val}" for row in rows for val in row))
        write_file.write("\n")


def main(args):
    """Evaluate the model on a TSP data set and export MCTS input files.

    Reads ``TSP{size}.txt`` from ``args.data_root``, recreates
    ``{args.save_root}/tsp{size}/`` (any existing directory is deleted) and writes:

    * ``instances/TSP{size}.txt`` - the instances with their reference tours,
    * ``heatmap/heatmap_{i}.txt`` - one heatmap per instance,
    * ``init_sol/init_sol_{i}.txt`` - the shortest model tour per instance.

    Finally prints the average lengths, the optimality gap and timing statistics.

    Args:
        args: Parsed command-line namespace as produced by ``parse()``.

    Raises:
        ValueError: If the data file contains no instances.
    """
    device = _select_device(args.device)
    print(f"Experiment using {device}")

    # Preparation
    hm_scale = math.ceil(64 / args.aug_size)

    print("Trying loading models...")
    model = Model(args.dim_input_nodes, args.dim_emb, args.dim_ff, args.nb_layers_global_encoder,
                  args.nb_layers_local_encoder, args.nb_layers_decoder, args.nb_heads, args.local_k, args.batchnorm)
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    load_params = torch.load(args.model_path, map_location=device)
    model.load_state_dict(load_params['model_baseline'])
    model.to(device)
    model.eval()
    print("Successfully loading the model")

    # load instance
    size = args.size
    print(f"Evaluating on TSP{size}")
    path = Path(args.data_root).joinpath(f"TSP{size}.txt")
    # `size` is overwritten with the value reported by the loader.
    tsp_instances, opt_tours, opt_lens, size, num = load_tsp_instances(path)
    if num == 0:
        # Fail before the previous outputs are deleted and before dividing by zero.
        raise ValueError(f"No TSP instances loaded from {path}")

    # create output directories (a previous run for the same size is discarded)
    dir_name = Path(args.save_root).joinpath(f"tsp{size}/")
    if dir_name.exists():
        shutil.rmtree(dir_name)
    instances_dir_name = dir_name.joinpath("instances/")
    heatmap_dir_name = dir_name.joinpath("heatmap/")
    init_sol_dir_name = dir_name.joinpath("init_sol/")
    dir_name.mkdir(parents=True)
    for sub_dir in (instances_dir_name, heatmap_dir_name, init_sol_dir_name):
        sub_dir.mkdir()

    # convert the tspfarm data to the MCTS input format
    _write_mcts_instances(instances_dir_name.joinpath(f"TSP{size}.txt"), tsp_instances, opt_tours, num)

    # model inference
    print("Start Model Inference")
    best_lens = []
    gaps = []
    total_time = 0
    # For a quick pre-evaluation, reduce `num` here (e.g. num = 10).
    for index in tqdm(range(num)):
        instance = torch.tensor(tsp_instances[index]).float().to(device)
        opt_len = torch.tensor(opt_lens[index]).float().to(device)

        # The timed region covers the same steps as before: distances, augmentation,
        # model forward pass, heatmap construction and best-tour selection.
        start_time = time.time()
        dist_matrix = get_dist_matrix(instance)
        # Add 2 to the diagonal so the k-smallest search below does not pick a
        # node's zero self-distance.
        dist_matrix += 2 * torch.eye(size, device=device)
        _, top_index = dist_matrix.topk(args.local_k, largest=False, dim=-1)

        augment_module = Augmentation()
        repeated_instances = instance.unsqueeze(0).repeat((args.aug_size, 1, 1))
        augmented_instances = augment_module.aug_for_train('mixture', repeated_instances, args.aug_size)

        with torch.no_grad():
            model_tours, _ = model(augmented_instances, deterministic=True)

        heatmap = _build_heatmap(size, top_index, model_tours, hm_scale, device)

        # Tour lengths are evaluated with the distance matrix of the original instance.
        model_lens = calculate_tour_length(dist_matrix, model_tours)
        best_tour_index = model_lens.argmin()
        best_tour = model_tours[best_tour_index]
        end_time = time.time()
        total_time += end_time - start_time

        best_len = model_lens[best_tour_index]
        best_lens.append(best_len)
        gap = best_len / opt_len - 1
        gaps.append(gap.item())

        _write_heatmap(heatmap_dir_name.joinpath(f"heatmap_{index}.txt"), heatmap, size)

        init_sol_file = init_sol_dir_name.joinpath(f"init_sol_{index}.txt")
        with open(init_sol_file, 'w', encoding='utf8') as write_file:
            write_file.write(" ".join(str(node) for node in best_tour.tolist()))

    # statistics
    avg_model_len = sum(best_lens) / len(best_lens)
    avg_opt_len = sum(opt_lens) / len(opt_lens)
    gaps = np.array(gaps) * 100
    avg_gap = np.mean(gaps)
    std_gap = np.std(gaps)
    print(f"Avg model len {avg_model_len:.4f}, Avg opt len {avg_opt_len:.4f}, "
          f"mean gap {avg_gap:.4f}%, std {std_gap:.4f}%")
    print(f"Using {total_time:.5f} seconds for {num} instances. Each inference cost {total_time / num:.5f} seconds.")


def parse(argv=None):
    """Build the command-line parser, parse the arguments and optionally echo them.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        The parsed ``argparse.Namespace``. Unless ``--no-print-param`` is given,
        every ``key = value`` pair is printed, followed by a separator line.
    """
    RichHelpFormatterPlus.choose_theme("prince")
    parser = argparse.ArgumentParser(
        description="TS4 model phase evaluation (TS3) for random TSP.",
        formatter_class=RichHelpFormatterPlus,
    )

    # general hyperparameters (training values)
    general_args = parser.add_argument_group("General Hyperparameters")
    general_args.add_argument("--dim-emb", type=int, default=128,
                              help="The dimension of node embeddings.")
    general_args.add_argument("--nb-layers-global-encoder", type=int, default=4,
                              help="The number of global Encoder layers.")
    general_args.add_argument("--nb-layers-local-encoder", type=int, default=6,
                              help="The number of local Encoder layers.")
    general_args.add_argument("--nb-layers-decoder", type=int, default=2,
                              help="The number of Decoder layers.")
    general_args.add_argument("--dim-input-nodes", type=int, default=2,
                              help="The feature number of each node.")
    general_args.add_argument("--nb-heads", type=int, default=8,
                              help="The number of attention heads.")
    general_args.add_argument("--dim-ff", type=int, default=512,
                              help="The dimension of feed-forward networks.")
    general_args.add_argument("--global-k", type=int, default=20,
                              help="Global-k setting carried over from the training configuration; "
                                   "not read by this script.")
    general_args.add_argument("--batchnorm", action="store_true", default=False,
                              help="Use batchnorm or layernorm otherwise.")
    general_args.add_argument("--local-k", type=int, default=12,
                              help="The number of knn neighbors.")
    general_args.add_argument("--no-print-param", action="store_true",
                              help="Do not print the parameter information in log files.")
    general_args.add_argument("--device", type=int, default=0,
                              help="Any value >= 0 uses the GPU when CUDA is available. -1 for CPU.")

    # customized hyperparameters (preferred default values)
    customized_args = parser.add_argument_group("Customized Hyperparameters")
    customized_args.add_argument("--data-root", type=str, default="./data/tspfarm/",
                                 help="Path to TSP instances.")
    customized_args.add_argument("--save-root", type=str, default="./MCTS/data/",
                                 help="Root directory for the generated MCTS input files "
                                      "(an existing tsp<size>/ subdirectory is replaced).")

    # typical hyperparameters (values for research)
    typical_args = parser.add_argument_group("TYPICAL HYPERPARAMETERS")
    typical_args.add_argument("--size", type=int, default=50,
                              help="Size of TSP instances.")
    typical_args.add_argument("--aug-size", type=int, default=16,
                              help="Augmentation size for each TSP instance.")
    typical_args.add_argument("--model-path", type=str, default="./models/checkpoint_23-08-24--15-11-53-n50-gpu0.pkl",
                              help="Trained model path.")

    args = parser.parse_args(argv)

    if not args.no_print_param:
        for key, value in vars(args).items():
            print(f"{key} = {value}")
        print("=" * 20)
        print()

    return args


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the evaluation.
    cli_args = parse()
    main(cli_args)
