#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import argparse
import subprocess
import time
import math
import importlib
import os, sys
import numpy as np
import itertools

import torch
import random
from torch.utils.data import DataLoader
from gpt2_beam import beam
from gpt2_decode import decode_func
from eval.e2e.measure_scores import evaluating
torch.set_printoptions(threshold=100000)

from gpu import (
    add_gpu_params, 
    parse_gpu,
    distributed_opt,
    distributed_gather,
    average_model,
    distributed_sync, 
    cleanup
)
from optimizer import (
    create_adam_optimizer, 
    create_optimizer_scheduler, 
    add_optimizer_params, 
    create_adam_optimizer_from_args
)

from data_utils import FT_Dataset
from model_ours import GPT2LMModel_ours
from model import GPT2Config, GPT2LMModel
from exp_utils import create_exp_dir

import loralib as lora

# Command-line interface for the GPT-2 fine-tuning script.
# GPU/distributed and optimizer options are registered by the project helpers
# below; everything else is declared in this file.
parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')

add_gpu_params(parser)
add_optimizer_params(parser)

# ---- data ------------------------------------------------------------------
parser.add_argument('--train_data', required=True, help='location of training data corpus')

# parser.add_argument('--valid_data', required=True, help='location of validation data corpus')

parser.add_argument('--test_data', required=True, help='location of test data corpus')

parser.add_argument('--data_name', default='webnlg_challenge_2017', help='data name')

parser.add_argument('--train_batch_size', type=int, default=8, help='training batch size')

parser.add_argument('--valid_batch_size', type=int, default=4, help='validation batch size')

# ---- method / federated setup ----------------------------------------------
parser.add_argument('--method', type=str, default='ours', help='the method [lora, maml, hetlora, ours]')

parser.add_argument('--heterogeneity', type=float, default=0.3, help='The heterogeneity of data')

# NOTE(review): the original help text for --gamma and --lamb was copy-pasted
# from --heterogeneity. Their exact role is not visible in this file
# (presumably consumed by the methods.* modules) -- confirm and expand.
parser.add_argument('--gamma', type=float, default=0.99, help='gamma hyperparameter')

parser.add_argument('--lamb', type=float, default=0.001, help='lamb hyperparameter')

parser.add_argument('--rank_max', type=int, default=12, help='The max rank')

parser.add_argument('--rank_min', type=int, default=2, help='The minimal rank')

# ---- training --------------------------------------------------------------
parser.add_argument('--grad_acc', type=int, default=1, help='gradient accumulation steps')

parser.add_argument('--clip', type=float, default=0.0, help='gradient clip')

parser.add_argument('--seq_len', type=int, default=512, help='number of tokens to predict.')

parser.add_argument('--model_card', default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                    help='model names')

parser.add_argument('--init_checkpoint', default=None, help='pretrained checkpoint path')

parser.add_argument('--fp16', action='store_true', help='train model with fp16')

parser.add_argument('--com_interval', type=int, default=10, help='communication interval')

parser.add_argument('--com_rounds', type=int, default=150, help='The total communication rounds')

parser.add_argument('--log_interval', type=int, default=10, help='log interval')

parser.add_argument('--eval_interval', type=int, default=2000, help='eval interval')

parser.add_argument('--save_interval', type=int, default=500, help='save interval')

parser.add_argument('--save_path', type=str, default='./outputs', help='save path')

parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                    help='working folder.')

# ---- LoRA / objective ------------------------------------------------------
parser.add_argument('--lora_dim', type=int, default=4, help='lora attn dimension')

parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')

parser.add_argument('--obj', default='clm', choices=['jlm', 'clm'],
                    help='language model training objective')

parser.add_argument('--lora_dropout', default=0.0, type=float,
                    help='dropout probability for lora layers')

parser.add_argument('--label_smooth', default=0.0, type=float, help='label smoothing')

parser.add_argument('--roll_interval', type=int, default=-1, help='rolling interval')

parser.add_argument('--roll_lr', type=float, default=0.00001, help='rolling learning rate')

parser.add_argument('--lr_in', type=float, default=0.00001, help='inner learning rate')

parser.add_argument('--roll_step', type=int, default=100, help='rolling step')

# ---- evaluation / generation -----------------------------------------------
parser.add_argument('--eval_epoch', type=int, default=1, help='eval per number of epochs')

parser.add_argument('--eval_len', type=int, default=256,
                    help='evaluation length')

parser.add_argument('--eval_ratio', type=float, default=0.0,
                    help='evaluation data proportion')

parser.add_argument('--min_length', type=int, default=0,
                    help='minimum generation length')

parser.add_argument('--beam', type=int, default=10, help='beam search size')

parser.add_argument('--length_penalty', type=float, default=1.0, help='length penalty')

parser.add_argument('--no_repeat_ngram_size', type=int, default=4, help='no_repeat_ngram_size')

parser.add_argument('--repetition_penalty', type=float, default=1.0, help='repetition_penalty')

# With action='append', ids given on the command line are appended to the
# default list rather than replacing it.
parser.add_argument('--eos_token_id', action='append', type=int, default=[628],
                    help='eos token id')

parser.add_argument('--output_file', type=str, default='beam_prediction',
                    help='output file name')

parser.add_argument('--output_pred_file', type=str, default='eval/GenerationEval/data/hypothesis_webnlg',
                    help='prediction output file name')

parser.add_argument('--output_ref_file', type=str, default='eval/GenerationEval/data/references_webnlg',
                    help='reference file name')

parser.add_argument('--vocab', type=str, default='./vocab', help='vocab path')

parser.add_argument('--sample_file', default=None, type=str, help='ft sample file')
parser.add_argument('--input_file', default=None, type=str, help='ft input file')

parser.add_argument('--ref_unique_file', default=None, type=str, help='reference unique id file')

parser.add_argument('--ref_type', default='e2e', choices=['e2e', 'webnlg', 'dart'],
                    help='e2e style reference type; webnlg style reference type.')
parser.add_argument('--ref_num', default=4, type=int, help='number of references.')

parser.add_argument('--tokenize', action='store_true', help='')
parser.add_argument('--lower', action='store_true', help='')

parser.add_argument('--filter', default='all', choices=['all', 'seen', 'unseen'],
                    help='for webnlg only, filter categories that are seen during training, unseen, or all')

# ---- scoring-script options ------------------------------------------------
parser.add_argument('-l', '--sent-level', '--seg-level', '--sentence-level', '--segment-level',
                    type=str, help='Output segment-level scores in a TSV format to the given file?',
                    default=None)
parser.add_argument('-s', '--src-file', type=str,
                    help='Source file -- if given, system output '
                         'should be a TSV with source & output columns, source is checked for integrity',
                    default=None)
parser.add_argument('--python', action='store_true',
                    help='Use Python implementation of MTEval instead of Perl?')
# The original literals were concatenated without a separating space
# ("in aTSV table?"); the space is restored here.
parser.add_argument('-t', '--table', action='store_true',
                    help='Print out results as a line in a TSV table?')
parser.add_argument('--header', action='store_true', help='Print TSV table header?')
# parser.add_argument("-R", "--reference", help="reference translation", required=True)
# parser.add_argument("-H", "--hypothesis", help="hypothesis translation", required=True)
parser.add_argument("-lng", "--language", help="evaluated language", default='en')
parser.add_argument("-nr", "--num_refs", help="number of references", type=int, default=6)
parser.add_argument("-m", "--metrics", help="evaluation metrics to be computed",
                    default='bleu,meteor,ter,rouge_l')
parser.add_argument("-nc", "--ncorder", help="chrF metric: character n-gram order (default=6)", type=int, default=6)
parser.add_argument("-nw", "--nworder", help="chrF metric: word n-gram order (default=2)", type=int, default=2)
parser.add_argument("-b", "--beta", help="chrF metric: beta parameter (default=2)", type=float, default=2.0)

# Dump every parsed argument to stdout (only on the rank-0 process).
def print_args(args):
    """Print every attribute of ``args`` between two banner lines.

    Output is produced only when ``args.rank == 0``; for any other rank the
    function returns without printing anything.

    Args:
        args: namespace-like object exposing a ``rank`` attribute; all of its
            instance attributes are listed as ``- name : value`` lines.
    """
    if args.rank != 0:
        return

    banner = '=' * 100
    print(banner)
    for name, value in vars(args).items():
        print(f'        - {name} : {value}')
    print(banner)


if __name__ == '__main__':
    # Script entry point: parse arguments, build datasets/loaders and the
    # GPT-2 model, then hand training over to the selected `methods.<name>`
    # module.
    args = parser.parse_args()
    parse_gpu(args)
    if args.method == 'hetlora':
        # Heterogeneous LoRA: the rank grows linearly with the process rank,
        # from rank_min (rank 0) towards rank_max.
        args.lora_dim = args.rank_min + int(
            (args.rank_max - args.rank_min) * args.rank / args.world_size
        )
    print_args(args)

    # Give every (method, rank) pair its own prediction/reference files.
    args.output_pred_file = args.output_pred_file + f'_{args.method}' + f'_{args.rank}'
    args.output_ref_file = args.output_ref_file + f'_{args.method}' + f'_{args.rank}'

    if args.fp16:
        try:
            from apex import amp
        except Exception:
            # `warnings` is not imported at file level; import it here so a
            # missing apex produces a warning instead of a NameError.
            import warnings
            warnings.warn('Could not import amp, apex may not be installed')

    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)

    if args.rank == 0:
        args.logging = create_exp_dir(args.work_dir)

    train_data = FT_Dataset(args, "train")

    valid_data = FT_Dataset(args, "eval")

    valid_data_text = valid_data.sample_text

    train_low_data = FT_Dataset(args, None)

    train_loader = DataLoader(
        train_data, batch_size=args.train_batch_size, num_workers=0,
        shuffle=False, pin_memory=False, drop_last=True
    )

    valid_loader = DataLoader(
        valid_data, batch_size=args.valid_batch_size, num_workers=0,
        shuffle=False, pin_memory=False, drop_last=False
    )

    # Half the training batch size, but never below 1: DataLoader rejects a
    # batch size of 0, which the plain halving produced for
    # --train_batch_size 1.
    train_low_loader = DataLoader(
        train_low_data, batch_size=max(1, args.train_batch_size // 2), num_workers=0,
        shuffle=True, pin_memory=False, drop_last=False
    )

    # Output location: <save_path>/<method>/lr_..._<timestamp>
    date = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))
    save_direct = os.path.join(args.save_path, f'{args.method}')
    # exist_ok avoids the check-then-create race when several ranks start
    # at the same time.
    os.makedirs(save_direct, exist_ok=True)
    file_name = f'lr_{args.lr}_lr_in_{args.lr_in}_het{args.heterogeneity}_{date}'
    args.save_path = os.path.join(save_direct, file_name)

    # Architecture sizes per model card; the LoRA settings are shared.
    # argparse `choices` guarantees args.model_card is one of these keys.
    model_dims = {
        'gpt2.sm': dict(n_embd=768, n_layer=12, n_head=12),
        'gpt2.md': dict(n_embd=1024, n_layer=24, n_head=16),
        'gpt2.lg': dict(n_embd=1280, n_layer=36, n_head=20),
    }
    config = GPT2Config(
        lora_attn_dim=args.lora_dim,
        lora_attn_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        **model_dims[args.model_card],
    )

    lm_net = GPT2LMModel(config) if args.method != 'ours' else GPT2LMModel_ours(config)
    if args.init_checkpoint is not None:
        print('loading model pretrained weight.')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        lm_net.load_weight(torch.load(args.init_checkpoint))

    lm_net = lm_net.cuda()
    start_time = time.time()
    # mark the trainable parameters
    if args.lora_dim > 0:
        lora.mark_only_lora_as_trainable(lm_net)
    distributed_sync(args)
    if args.max_step is None:
        # Ceiling division of the total number of batches across all workers.
        args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
        print('set max_step:', args.max_step)

    # The training procedure lives in methods/<method>.py, which must expose
    # a `Model` class.
    method = importlib.import_module('methods.' + args.method)
    model = method.Model(args, lm_net)
    if args.method == 'ours' or args.method == 'maml':
        # These two methods additionally receive the `train_low_loader`.
        model.train(train_loader, valid_loader, valid_data_text, train_low_loader)
    else:
        model.train(train_loader, valid_loader, valid_data_text)
    cleanup(args)

