# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

"""
Implementation of a new method for fine-tuning transformer models, which we call
Information Gain Filtration (IGF), on the WikiText data set, and comparison of the results
with the standard fine-tuning method.

Steps followed in the code:

1) Generate an objective dataset of pairs (X, IG(X)), where IG(X) is the informativeness of context X.
Our IG (information gain) model learns to predict the informativeness of a particular
context. Informativeness is the change in the metric between the model's performance on an
objective set before and after seeing that context. For causal language modeling, the
metric is perplexity.

2) A secondary learner is trained to infer a function approximation for IG using the dataset
created in (1).

3) The learner created in (2) is used to inform the fine-tuning process and filter out uninformative samples.

Finally, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering.

"""

# Prerequisite libraries:

import argparse
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler

import joblib
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,
    compute_perplexity,
    generate_datasets,
    load_gpt2,
    recopy_gpt2,
    set_seed,
    train_secondary_learner,
)
from transformers import GPT2LMHeadModel


def generate_n_pairs(
    context_len=32,
    max_steps=10,
    size_objective_set=100,
    min_len=1026,
    trim=True,
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):

    """
    Collecting *n* pairs for training the secondary learner
    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
                    than this will be truncated, sequences shorter will be padded
        max_steps: To calculate training epochs of secondary learner
        size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the training data for secondary learner
        min_len: The minimum length of the article to be used as objective set
        trim: If True truncate the context if it exceeds context length
        data_file: Tokenized data set split for training and evaluation of model
        igf_data_file: file to store (I,IG(X)) paired data set to train secondary learner

    Returns:
        Data stored in igf_data_file

    """
    # generates the same data every time
    set_seed(3)
    # generate train_data and objective_set
    train_data, objective_set = generate_datasets(
        context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
    )
    # keeps model same across runs
    set_seed(4)
    # model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
    # train on GPU if one is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained model
    model = load_gpt2("gpt2").to(device)
    print("computing perplexity on objective set")
    orig_perp = compute_perplexity(model, objective_set, context_len).item()
    print("perplexity on objective set:", orig_perp)

    # collect (X, IG(X)) pairs and save them to igf_data_file
    collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)

    # clean up, delete model and data we don't need anymore
    del model, train_data, objective_set
    torch.cuda.empty_cache()


def training_secondary_learner(
    secondary_learner_train_data,
    secondary_learner_max_epochs=15,
    secondary_learner_batch_size=128,
    eval_freq=100,
    igf_model_path="igf_model.pt",
):
    """
    Train the secondary learner

    Args:
        secondary_learner_train_data: Data set of (X, IG(X)) pairs used to train the secondary learner, where X is a context and IG(X) is its measured informativeness
        secondary_learner_max_epochs: Number of epochs to train the secondary learner
        secondary_learner_batch_size: Batch size used to train the secondary learner
        eval_freq: the secondary learner is evaluated every eval_freq training steps
        igf_model_path: path at which to store the trained secondary learner

    Returns:
        Trained secondary learner
    """

    set_seed(42)

    # Load pre-trained model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Initialize secondary learner to use embedding weights of model
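    # (the secondary learner takes a context's token ids as input and predicts IG(X))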
    secondary_learner = SecondaryLearner(model)

    # Train secondary learner
    secondary_learner = train_secondary_learner(
        secondary_learner,
        secondary_learner_train_data,
        max_epochs=secondary_learner_max_epochs,
        batch_size=secondary_learner_batch_size,
        eval_freq=eval_freq,
        igf_model_path=igf_model_path,
    )

    del model, secondary_learner_train_data
    torch.cuda.empty_cache()

    return secondary_learner


def finetune(
    model,
    train_dataset,
    test_dataset,
    context_len=32,
    max_steps=1000,
    batch_size=16,
    threshold=1.0,
    recopy_model=recopy_gpt2,
    secondary_learner=None,
    eval_interval=10,
    finetuned_model_name="gpt2_finetuned.pt",
):
    """
    Fine-tune the GPT-2 model with IGF if secondary_learner is not None; otherwise perform standard fine-tuning.

    Args:
        model: pre-trained GPT-2 model
        train_dataset: Data set to train GPT-2 model
        test_dataset: Evaluate GPT-2 model
        context_len: The maximum total input sequence length after tokenization. Sequences longer
                    than this will be truncated, sequences shorter will be padded
        max_steps: used to calculate the number of training epochs
        batch_size: Batch size to train GPT-2 model
        threshold: The threshold value used by the secondary learner to filter the train_data and allow only
                    informative data as input to the model
        recopy_model: function that resets the model to the original pretrained GPT-2 weights and
                        returns it together with a fresh optimizer and scheduler
        secondary_learner: if not None, IGF is used for fine-tuning; otherwise standard fine-tuning is performed
        eval_interval: number of batches between evaluations of the model's perplexity on the test set
        finetuned_model_name: name of the file in which the final fine-tuned GPT-2 model is saved

    Returns:
        Fine-tuned GPT-2 model

    """

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler)

    num_train_epochs = max_steps // (len(train_dataset)) + 1
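    # (num_train_epochs is the number of passes over the data needed to reach roughly max_steps optimizer steps)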
    global_step = 0
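    # reusable buffer that holds the current window of context_len token ids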
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)

    model.train()
    if secondary_learner is not None:
        secondary_learner.to(device)
        secondary_learner.eval()
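        # the secondary learner is only used for inference during fine-tuning, so keep it in eval mode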
    contexts = []
    examples = 0

    observed_qs = []
    test_perps = []
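    # observed_qs and test_perps record the predicted IG(X) values and the measured test perplexities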

    # Compute the performance of the transformer model at the beginning
    real_perp = compute_perplexity(model, test_dataset, context_len)
    test_perps.append(real_perp)
    print("Test perplexity, step", global_step, ":", real_perp)
    for epoch in range(int(num_train_epochs)):
        for step, example in enumerate(train_dataloader):
            torch.cuda.empty_cache()
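            # sample a random window of context_len tokens from the current article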
            start = random.randint(0, example.size(2) - context_len - 1)
            context[0, :] = example[0, 0, start : start + context_len]
            lm_optimizer.zero_grad()
            outputs = model(context, labels=context)
            do_backprop = True

            if secondary_learner is not None:
                predicted_q = secondary_learner(context.clone().detach().unsqueeze(0))[0].item()
                observed_qs.append(float(predicted_q))

                # Here we implement the simple non-constant threshold for the predicted IG(X) value
                # We will decay the selectivity of our secondary learner filter from
                # 1 standard deviation above average to 1 below average after 10 batches.

                if global_step == 10:
                    threshold = -1
                if predicted_q < threshold:
                    do_backprop = False

            # If we passed the filter, add the context to the batch!
            if do_backprop:
                contexts.append(np.array(context.cpu()))
                lm_loss = outputs[0]
                lm_loss.backward()
                examples += 1

            del outputs

            # Once the batch is filled with enough contexts, backprop on the batch.
            if examples == batch_size:
                torch.cuda.empty_cache()
                examples = 0
                # Do LM backprop
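                # clip gradients before the optimizer step to keep updates stable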
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                lm_optimizer.step()
                lm_scheduler.step()  # Update learning rate schedule
                global_step += 1
                # Compute the performance of the transformer model at this batch
                if global_step % eval_interval == 0:
                    real_perp = compute_perplexity(model, test_dataset, context_len)
                    test_perps.append(real_perp)

                    print("Test perplexity, step", global_step, ":", real_perp)
            # Break out of the loop after 60 batches
            if max_steps > 0 and global_step > 60:
                break
        if max_steps > 0 and global_step > 60:
            break

    # save finetuned transformer model
    torch.save(model.state_dict(), finetuned_model_name)
    torch.cuda.empty_cache()
    # Do some cleaning up so we can reinitialize for the next run of this function
    del lm_optimizer
    del lm_scheduler
    return model


def main():
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default="data/tokenized_stories_train_wikitext103.jbl",
        help=(
            "A jbl file containing tokenized data which can be split as objective dataset, "
            "train_dataset and test_dataset."
        ),
    )

    parser.add_argument(
        "--igf_data_file",
        type=str,
        default="igf_context_pairs.jbl",
        help="A jbl file in which to store the context and information gain pairs used to train the secondary learner.",
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )

    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
    )

    parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")

    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="batch size of training data for secondary learner",
    )

    parser.add_argument(
        "--batch_size", default=16, type=int, help="batch size of training data of language model(gpt2) "
    )

    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help=(
            "number of batches between evaluations of the model's perplexity on the test set "
            "during fine-tuning"
        ),
    )

    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
    )

    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
    )

    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
    )

    parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")

    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help=(
            "The threshold value used by secondary learner to filter the train_data and allow only"
            " informative data as input to the model"
        ),
    )

    parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")

    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        help="function used to reset the model to the original pretrained GPT-2 weights",
    )

    # parse the command-line arguments defined above
    args = parser.parse_args()

    # function calls
    # Collecting *n* pairs of context and information gain (X, IG(X)) for training the secondary learner
    generate_n_pairs(
        context_len=args.context_len,
        max_steps=10,  # pair collection uses a short schedule, independent of --max_steps
        size_objective_set=args.size_objective_set,
        min_len=args.min_len,
        trim=args.trim,
        data_file=args.data_file,
        igf_data_file=args.igf_data_file,
    )

    # Load train data for secondary learner
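    # note: this path points at a separate (X, IG(X)) pairs file; the pairs generated above
    # into args.igf_data_file could be used here instead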
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=args.secondary_learner_max_epochs,
        secondary_learner_batch_size=args.secondary_learner_batch_size,
        eval_freq=args.eval_freq,
        igf_model_path="igf_model.pt",
    )

    # load pretrained gpt2 model
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    set_seed(args.seed if args.seed is not None else 42)

    # Generate train and test data to train and evaluate gpt2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=args.context_len, file=args.data_file, number=args.number, min_len=args.min_len, trim=args.trim
    )

    # fine-tuning of the gpt2 model using igf (Information Gain Filtration)
    os.makedirs(args.output_dir, exist_ok=True)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=args.context_len,
        max_steps=args.max_steps,
        batch_size=args.batch_size,
        threshold=args.threshold,
        recopy_model=recopy_gpt2,
        secondary_learner=secondary_learner,
        eval_interval=args.eval_interval,
        finetuned_model_name=os.path.join(args.output_dir, args.finetuned_model_name),
    )


if __name__ == "__main__":
    main()
