import torch
import time
import sys
import argparse
from rich_argparse_plus import RichHelpFormatterPlus
from pathlib import Path

sys.path.append(f"./utils/")
sys.path.append(f"./scripts/")
from augmentation import Augmentation
from model import TSP_net_general_test_version as Model

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


def calculate_tour_length_batch(x, tour):
    """Return the closed-tour Euclidean length of every tour in a batch.

    Args:
        x: Node coordinates, indexed as ``x[batch, node, :]``
            (the callers in this file pass shape ``(bsz, size, dim)``).
        tour: Integer node indices giving the visiting order, shape
            ``(bsz, nb_nodes)``; must be usable for indexing ``x``.

    Returns:
        Tensor of shape ``(bsz,)`` on ``x.device`` holding, for each
        instance, the sum of the distances between consecutive tour nodes
        plus the closing edge from the last node back to the first. The
        result is computed under ``torch.no_grad()`` and never carries
        gradients. A single-node tour has length 0.
    """
    bsz = x.shape[0]
    with torch.no_grad():
        # Gather the coordinates in visiting order: (bsz, nb_nodes, dim).
        batch_idx = torch.arange(bsz, device=x.device).unsqueeze(1)
        ordered = x[batch_idx, tour]
        # Rolling by one pairs every node with its predecessor; position 0 is
        # paired with the last node, which is exactly the closing edge.
        predecessors = torch.roll(ordered, shifts=1, dims=1)
        edge_lengths = torch.sum((ordered - predecessors) ** 2, dim=-1) ** 0.5
        # Accumulate into a default-dtype buffer (as the former node-by-node
        # loop did) so the output dtype does not depend on x's dtype.
        L = torch.zeros(bsz, device=x.device)
        L += edge_lengths.sum(dim=1)
    return L


def _build_model(args, device):
    """Instantiate the TSP network from the CLI hyperparameters and move it to ``device``."""
    net = Model(args.dim_input_nodes, args.dim_emb, args.dim_ff,
                args.nb_layers_global_encoder, args.nb_layers_local_encoder,
                args.nb_layers_decoder, args.nb_heads, args.local_k, batchnorm=args.batchnorm)
    return net.to(device)


def _sample_augmented_batch(args, aug_module, device):
    """Draw random instances, repeat each one ``aug_num`` times and augment the copies.

    Returns ``(x_repeat, x_aug)``: the repeated original coordinates, shape
    ``(bsz, size, dim_input_nodes)``, and the output of
    ``aug_module.aug_for_train`` applied to them.
    """
    nb_instances = args.bsz // args.aug_num
    x = torch.rand(nb_instances, args.size, args.dim_input_nodes, device=device)
    # Consecutive rows of x_repeat are the aug_num copies of one instance.
    x_repeat = x.unsqueeze(1).repeat((1, args.aug_num, 1, 1)).view((args.bsz, args.size, args.dim_input_nodes))
    x_aug = aug_module.aug_for_train(args.aug_type, x_repeat, args.aug_num)
    return x_repeat, x_aug


def main(args):
    """Train the TSP model against a deterministic baseline and checkpoint every epoch.

    Each epoch (1) trains ``model_train`` on freshly sampled random instances
    using the loss ``mean((L_train - L_baseline) * nabla_pi)``, (2) evaluates
    both networks deterministically, keeping the best tour per augmentation
    group, (3) copies the trained weights into the baseline when the trained
    model is better by more than ``args.tol``, and (4) saves a checkpoint to
    ``args.save_root``.

    Args:
        args: Parsed command-line namespace as produced by ``parse()``.

    Raises:
        ValueError: If ``bsz`` is not a positive multiple of ``aug_num``, if
            ``nb_batch_eval`` is not positive, or if the requested CUDA
            device index does not exist.
    """
    # Fail fast on settings that would otherwise only blow up mid-training
    # (the repeat/view reshape and the evaluation average below).
    if args.aug_num <= 0 or args.bsz <= 0 or args.bsz % args.aug_num != 0:
        raise ValueError(
            f"--bsz ({args.bsz}) must be a positive multiple of --aug-num ({args.aug_num})."
        )
    if args.nb_batch_eval <= 0:
        raise ValueError(f"--nb-batch-eval must be positive, got {args.nb_batch_eval}.")
    nb_instances = args.bsz // args.aug_num

    # device settings (preferred GPU); honour the requested device index
    if torch.cuda.is_available() and args.device >= 0:
        if args.device >= torch.cuda.device_count():
            raise ValueError(
                f"--device {args.device} requested, but only "
                f"{torch.cuda.device_count()} CUDA device(s) are visible."
            )
        device = torch.device(f"cuda:{args.device}")
        print(f"Detect GPU: {torch.cuda.get_device_name(device)}")
    else:
        device = torch.device("cpu")
    print(f"Using: {device}")

    checkpoint_dir = Path(args.save_root)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # network initialization (two independently initialised networks)
    model_train = _build_model(args, device)
    model_baseline = _build_model(args, device)
    model_baseline.eval()

    optimizer = torch.optim.Adam(model_train.parameters(), lr=args.lr)

    aug_module = Augmentation()

    ######################
    # Main training loop #
    ######################

    start_training_time = time.time()
    for epoch in range(args.nb_epochs):

        ###########################
        # Train model for one epoch
        ###########################

        start = time.time()
        model_train.train()

        for _ in range(args.nb_batch_per_epoch):
            x_repeat, x_aug = _sample_augmented_batch(args, aug_module, device)

            # compute tours for baseline, shape(tour) = (bsz, size)
            with torch.no_grad():
                tour_baseline, _ = model_baseline(x_aug, deterministic=True)

            # compute tours for model, shape(tour) = (bsz, size)
            # NOTE(review): nabla_pi is presumably the summed log-probability
            # of the sampled tour -- confirm against the model implementation.
            tour_train, nabla_pi = model_train(x_aug, deterministic=False)

            # tour lengths are measured on the un-augmented coordinates,
            # shape(L) = (bsz,)
            L_train = calculate_tour_length_batch(x_repeat, tour_train)
            L_baseline = calculate_tour_length_batch(x_repeat, tour_baseline)

            # backpropagation
            loss = torch.mean((L_train - L_baseline) * nabla_pi)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        time_one_epoch = time.time() - start
        time_tot = time.time() - start_training_time

        #######################
        # Update baseline model
        #######################

        model_train.eval()
        mean_tour_length_train = 0
        mean_tour_length_baseline = 0

        # NOTE(review): evaluation batches are sized by args.bsz; args.test_bsz
        # is not read anywhere in this function.
        for _ in range(args.nb_batch_eval):
            x_repeat, x_aug = _sample_augmented_batch(args, aug_module, device)

            with torch.no_grad():
                # compute tours for model and baseline, shape(tour) = (bsz, size)
                tour_train, _ = model_train(x_aug, deterministic=True)
                tour_baseline, _ = model_baseline(x_aug, deterministic=True)

                # get the lengths of the tours on !!x_repeat!!, shape(L) = (bsz,)
                L_train = calculate_tour_length_batch(x_repeat, tour_train)
                L_baseline = calculate_tour_length_batch(x_repeat, tour_baseline)

                # get the best solution for each augmentation group,
                # shape(L) = (bsz // args.aug_num,)
                L_train = torch.min(L_train.view((nb_instances, args.aug_num)), dim=1).values
                L_baseline = torch.min(L_baseline.view((nb_instances, args.aug_num)), dim=1).values

            # Compute the mean tour length
            mean_tour_length_train += L_train.mean().item()
            mean_tour_length_baseline += L_baseline.mean().item()

        mean_tour_length_train = mean_tour_length_train / args.nb_batch_eval
        mean_tour_length_baseline = mean_tour_length_baseline / args.nb_batch_eval

        # update baseline if train model is better (by more than the tolerance)
        update_baseline = mean_tour_length_train + args.tol < mean_tour_length_baseline
        if update_baseline:
            model_baseline.load_state_dict(model_train.state_dict())

        #####################
        # Save Model and Info
        #####################

        epoch_info = f"Epoch: {epoch:d}, epoch time: {time_one_epoch / 60:.3f} min, " \
                     f"tot time: {time_tot / 3600:.3f}h, L_train: {mean_tour_length_train:.3f}, " \
                     f"L_base: {mean_tour_length_baseline:.3f}, update: {update_baseline}"
        print(epoch_info)

        # Saving checkpoint (one file per epoch, zero-padded epoch index)
        checkpoint = {
            'epoch': epoch,
            'time': time_one_epoch,
            'tot_time': time_tot,
            'mean_tour_length_train': mean_tour_length_train,
            'mean_tour_length_baseline': mean_tour_length_baseline,
            'model_baseline': model_baseline.state_dict(),
            'model_train': model_train.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        checkpoint_name = f"checkpoint_{str(epoch).zfill(3)}.pkl"
        checkpoint_path = checkpoint_dir.joinpath(checkpoint_name)

        torch.save(checkpoint, checkpoint_path)

    print("Training ends!")
    print(f"Checkpoint saved at {checkpoint_dir}.")


def parse(argv=None):
    """Build the command-line parser, parse the arguments and optionally echo them.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``. Unless ``--no-print-param`` is
        given, every ``key = value`` pair is printed to stdout first.
    """
    RichHelpFormatterPlus.choose_theme("prince")
    parser = argparse.ArgumentParser(
        description="TS4 model phase evaluation (TS3) for random TSP.",
        formatter_class=RichHelpFormatterPlus,
    )

    # general hyperparameters (training values)
    general_args = parser.add_argument_group("General Hyperparameters")
    # The learning rate is fractional: it must be parsed as float, otherwise
    # any value passed on the command line (e.g. 1e-4) is rejected.
    general_args.add_argument("--lr", type=float, default=1.875e-5,
                              help="Learning rate for training.")
    general_args.add_argument("--bsz", type=int, default=48,
                              help="Batch size for training.")
    general_args.add_argument("--test-bsz", type=int, default=64,
                              help="Batch size for evaluation in training.")
    general_args.add_argument("--dim-emb", type=int, default=128,
                              help="The dimension of node embeddings.")
    general_args.add_argument("--nb-epochs", type=int, default=170,
                              help="The number of total training epochs.")
    general_args.add_argument("--nb-batch-per-epoch", type=int, default=1600,
                              help="The number of training batches for each epoch.")
    general_args.add_argument("--nb-batch-eval", type=int, default=80,
                              help="The number of evaluation batches for each epoch.")
    general_args.add_argument("--dim-input-nodes", type=int, default=2,
                              help="The feature number of each node.")
    general_args.add_argument("--nb-heads", type=int, default=8,
                              help="The number of attention heads.")
    general_args.add_argument("--dim-ff", type=int, default=512,
                              help="The dimension of feed-forward networks.")
    general_args.add_argument("--batchnorm", action="store_true", default=False,
                              help="Use batchnorm layers.")
    general_args.add_argument("--tol", type=float, default=1e-3,
                              help="Tolerance for updating baseline model.")
    general_args.add_argument("--no-print-param", action="store_true",
                              help="Do not print the parameter information in log files.")
    general_args.add_argument("--device", type=int, default=0,
                              help="GPU device for training. -1 for CPU.")

    # customized hyperparameters (preferred default values)
    customized_args = parser.add_argument_group("Customized Hyperparameters")
    customized_args.add_argument("--save-root", type=str, default="./models/",
                                 help="Path to model saving.")
    customized_args.add_argument("--aug-type", type=str, default="mixture",
                                 help="Augmentation type for each TSP instance.")

    # typical hyperparameters (values for research)
    typical_args = parser.add_argument_group("TYPICAL HYPERPARAMETERS")
    typical_args.add_argument("--nb-layers-global-encoder", type=int, default=4,
                              help="The number of global Encoder layers.")
    typical_args.add_argument("--nb-layers-local-encoder", type=int, default=6,
                              help="The number of local Encoder layers.")
    typical_args.add_argument("--nb-layers-decoder", type=int, default=2,
                              help="The number of Decoder layers.")
    typical_args.add_argument("--local-k", type=int, default=12,
                              help="The number of knn neighbors.")
    typical_args.add_argument("--size", type=int, default=50,
                              help="Size of TSP instances.")
    typical_args.add_argument("--aug-num", type=int, default=8,
                              help="Augmentation size for each TSP instance.")

    args = parser.parse_args(argv)

    # Echo the effective configuration so it ends up in the run log.
    if not args.no_print_param:
        for key, value in vars(args).items():
            print(f"{key} = {value}")
        print("=" * 20)
        print()

    return args


if __name__ == "__main__":
    # Script entry point: read the CLI configuration, then run training with it.
    main(parse())
