"""
Train the ESIM model on the preprocessed MultiNLI dataset.
"""
# Aurelien Coet, 2019.

import os
import argparse
import pickle
import torch
import json

import matplotlib.pyplot as plt
import torch.nn as nn

from torch.utils.data import DataLoader

from esim.data import NLIDataset
from esim.model import ESIM
from utils import train, validate


def main(train_file,
         valid_files,
         embeddings_file,
         target_dir,
         hidden_size=300,
         dropout=0.5,
         num_classes=3,
         epochs=64,
         batch_size=32,
         lr=0.0004,
         patience=5,
         max_grad_norm=10.0,
         checkpoint=None):
    """
    Train the ESIM model on the MultiNLI dataset.

    The model is validated after every epoch on both the matched and the
    mismatched validation sets. The accuracy on the *matched* set drives the
    learning rate scheduler, the early stopping and the selection of the
    best model, which is written to 'best.pth.tar' in target_dir. Once
    training ends, the loss curves are plotted with matplotlib.

    Args:
        train_file: A path to some preprocessed data that must be used
            to train the model.
        valid_files: A dict containing the paths to the preprocessed matched
            and mismatched datasets that must be used to validate the model,
            under the keys "matched" and "mismatched".
        embeddings_file: A path to some preprocessed word embeddings that
            must be used to initialise the model.
        target_dir: The path to a directory where the trained model must
            be saved. It is created if it doesn't exist.
        hidden_size: The size of the hidden layers in the model. Defaults
            to 300.
        dropout: The dropout rate to use in the model. Defaults to 0.5.
        num_classes: The number of classes in the output of the model.
            Defaults to 3.
        epochs: The maximum number of epochs for training. Defaults to 64.
        batch_size: The size of the batches for training. Defaults to 32.
        lr: The learning rate for the optimizer. Defaults to 0.0004.
        patience: The patience to use for early stopping. Defaults to 5.
        max_grad_norm: The maximum gradient norm, passed on to the training
            function. Defaults to 10.0.
        checkpoint: A checkpoint from which to continue training. If None,
            training starts from scratch. Defaults to None.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(20 * "=", " Preparing for training ", 20 * "=")

    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(target_dir, exist_ok=True)

    # -------------------- Data loading ------------------- #
    # NOTE: the data and embeddings files are unpickled, so they must come
    # from a trusted source (pickle can execute arbitrary code on load).
    print("\t* Loading training data...")
    with open(train_file, 'rb') as pkl:
        train_data = NLIDataset(pickle.load(pkl))

    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading validation data...")
    with open(os.path.normpath(valid_files["matched"]), 'rb') as pkl:
        matched_valid_data = NLIDataset(pickle.load(pkl))

    with open(os.path.normpath(valid_files["mismatched"]), 'rb') as pkl:
        mismatched_valid_data = NLIDataset(pickle.load(pkl))

    matched_valid_loader = DataLoader(matched_valid_data,
                                      shuffle=False,
                                      batch_size=batch_size)
    mismatched_valid_loader = DataLoader(mismatched_valid_data,
                                         shuffle=False,
                                         batch_size=batch_size)

    # -------------------- Model definition ------------------- #
    print('\t* Building model...')
    with open(embeddings_file, 'rb') as pkl:
        embeddings = torch.tensor(pickle.load(pkl), dtype=torch.float)\
                     .to(device)

    # The vocabulary size and the embedding dimension are taken from the
    # shape of the pretrained embedding matrix.
    model = ESIM(embeddings.shape[0],
                 embeddings.shape[1],
                 hidden_size,
                 embeddings=embeddings,
                 dropout=dropout,
                 num_classes=num_classes,
                 device=device).to(device)

    # -------------------- Preparation for training  ------------------- #
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # The scheduler monitors an accuracy (mode='max') and halves the
    # learning rate as soon as it stops improving (patience=0).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           factor=0.5,
                                                           patience=0)

    best_score = 0.0
    start_epoch = 1

    # Data for loss curves plot.
    epochs_count = []
    train_losses = []
    matched_valid_losses = []
    mismatched_valid_losses = []

    # Continuing training from a checkpoint if one was given as argument.
    if checkpoint:
        # map_location makes it possible to resume on a device different
        # from the one the checkpoint was saved on (e.g. GPU -> CPU).
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        best_score = checkpoint['best_score']

        print("\t* Training will continue on existing model from epoch {}..."
              .format(start_epoch))

        model.load_state_dict(checkpoint['model'])
        # 'best.pth.tar' checkpoints are saved without the optimizer state,
        # so it is only restored when present.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("\t* No optimizer state found in the checkpoint, "
                  "a fresh optimizer will be used...")
        epochs_count = checkpoint['epochs_count']
        train_losses = checkpoint['train_losses']
        matched_valid_losses = checkpoint['match_valid_losses']
        mismatched_valid_losses = checkpoint['mismatch_valid_losses']

    # Compute loss and accuracy before starting (or resuming) training.
    _, valid_loss, valid_accuracy = validate(model,
                                             matched_valid_loader,
                                             criterion)
    print("\t* Validation loss before training on matched data: {:.4f}, accuracy: {:.4f}%"
          .format(valid_loss, (valid_accuracy*100)))

    _, valid_loss, valid_accuracy = validate(model,
                                             mismatched_valid_loader,
                                             criterion)
    print("\t* Validation loss before training on mismatched data: {:.4f}, accuracy: {:.4f}%"
          .format(valid_loss, (valid_accuracy*100)))

    # -------------------- Training epochs ------------------- #
    print("\n",
          20 * "=",
          "Training ESIM model on device: {}".format(device),
          20 * "=")

    patience_counter = 0
    for epoch in range(start_epoch, epochs+1):
        epochs_count.append(epoch)

        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model,
                                                       train_loader,
                                                       optimizer,
                                                       criterion,
                                                       epoch,
                                                       max_grad_norm)

        train_losses.append(epoch_loss)
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))

        print("* Validation for epoch {} on matched data:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = validate(model,
                                                          matched_valid_loader,
                                                          criterion)
        matched_valid_losses.append(epoch_loss)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))

        # The mismatched accuracy is kept in its own variable so that
        # 'epoch_accuracy' still holds the matched accuracy below.
        print("* Validation for epoch {} on mismatched data:".format(epoch))
        epoch_time, epoch_loss, mis_epoch_accuracy = validate(model,
                                                              mismatched_valid_loader,
                                                              criterion)
        mismatched_valid_losses.append(epoch_loss)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (mis_epoch_accuracy*100)))

        # Update the optimizer's learning rate with the scheduler, based on
        # the accuracy on the matched validation set.
        scheduler.step(epoch_accuracy)

        # Early stopping on (matched) validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            # Save the best model. The optimizer is not saved to avoid having
            # a checkpoint file that is too heavy to be shared. To resume
            # training with the optimizer state, enable the per-epoch
            # 'esim_*.pth.tar' checkpoints below and use those instead.
            torch.save({'epoch': epoch,
                        'model': model.state_dict(),
                        'best_score': best_score,
                        'epochs_count': epochs_count,
                        'train_losses': train_losses,
                        'match_valid_losses': matched_valid_losses,
                        'mismatch_valid_losses': mismatched_valid_losses},
                       os.path.join(target_dir, "best.pth.tar"))

        # # Save the model at each epoch.
        # torch.save({'epoch': epoch,
        #             'model': model.state_dict(),
        #             'best_score': best_score,
        #             'optimizer': optimizer.state_dict(),
        #             'epochs_count': epochs_count,
        #             'train_losses': train_losses,
        #             'match_valid_losses': matched_valid_losses,
        #             'mismatch_valid_losses': mismatched_valid_losses},
        #            os.path.join(target_dir, "esim_{}.pth.tar".format(epoch)))

        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break

    # Plotting of the loss curves for the train and validation sets.
    plt.figure()
    plt.plot(epochs_count, train_losses, '-r')
    plt.plot(epochs_count, matched_valid_losses, '-b')
    plt.plot(epochs_count, mismatched_valid_losses, '-g')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training loss',
                'Validation loss (matched set)',
                'Validation loss (mismatched set)'])
    plt.title('Cross entropy loss')
    plt.show()


if __name__ == "__main__":
    # Command line interface: a json configuration file describes the
    # training run, and an optional checkpoint allows resuming it.
    cli = argparse.ArgumentParser(
        description='Train the ESIM model on MultiNLI'
    )
    cli.add_argument('--config',
                     default="../config/training/mnli_training.json",
                     help='Path to a json configuration file')
    cli.add_argument('--checkpoint',
                     default=None,
                     help='path to a checkpoint file to resume training')
    options = cli.parse_args()

    config_path = os.path.normpath(options.config)
    with open(config_path, 'r') as handle:
        settings = json.load(handle)

    # Every setting is read from the configuration file; the paths are
    # normalised here, except for the validation paths which are passed on
    # as they are in the "valid_data" entry.
    main(train_file=os.path.normpath(settings["train_data"]),
         valid_files=settings["valid_data"],
         embeddings_file=os.path.normpath(settings["embeddings"]),
         target_dir=os.path.normpath(settings["target_dir"]),
         hidden_size=settings["hidden_size"],
         dropout=settings["dropout"],
         num_classes=settings["num_classes"],
         epochs=settings["epochs"],
         batch_size=settings["batch_size"],
         lr=settings["lr"],
         patience=settings["patience"],
         max_grad_norm=settings["max_gradient_norm"],
         checkpoint=options.checkpoint)
