import glob
import sys

import torch
from torch.optim import Adam
from transformers import AutoTokenizer

sys.path.append('../')
import argparse
import os
from os.path import exists

from utils import get_data_path
from dataloader import Seq2SeqPipe
from model.model import GeneratorBaseline, CLGenerator
from model.metrics import Loss
from callback import SaveCkptCallback, lrCallback
from fastNLP import DistTrainer, get_local_rank
import torch.distributed as dist
from transformers import Adafactor, AdamW, get_linear_schedule_with_warmup


def str2bool(v):
    """Convert a command-line style truthy/falsy token to a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: A string such as ``'yes'``/``'no'``, ``'true'``/``'false'``,
            ``'t'``/``'f'``, ``'y'``/``'n'``, ``'1'``/``'0'`` (case-insensitive,
            surrounding whitespace ignored), or a value that is already a
            ``bool``.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is a string that is not one of
            the recognised tokens.
    """
    # Values that are already booleans (e.g. when the function is called
    # programmatically) have no ``.lower()``; pass them through unchanged.
    if isinstance(v, bool):
        return v
    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def configure_training(args):
    """Extract the GPU ids and a summary of training hyper-parameters.

    Args:
        args: Namespace providing ``gpus`` (comma-separated GPU ids) and the
            attributes ``beam_size``, ``batch_size``, ``accum_count``,
            ``margin``, ``warmup_steps``, ``n_epochs`` and ``save_steps``.

    Returns:
        A ``(devices, params)`` tuple: the GPU ids as a list of ints, and a
        dict mapping each hyper-parameter name above to its value on ``args``.
    """
    gpu_ids = list(map(int, args.gpus.split(',')))
    # Tuple order fixes the key order of the returned dict.
    hyper_keys = (
        'beam_size',
        'batch_size',
        'accum_count',
        'margin',
        'warmup_steps',
        'n_epochs',
        'save_steps',
    )
    hyper_params = {key: getattr(args, key) for key in hyper_keys}
    return gpu_ids, hyper_params


def train_model(args):
    """Run distributed training of the baseline or contrastive generator.

    Initialises an NCCL process group, builds the tokenizer, dataset, model
    and Adafactor optimizer from ``args``, then trains with fastNLP's
    ``DistTrainer``.

    Args:
        args: Parsed command-line namespace. ``pad_id``, ``eos_id`` and
            ``bos_id`` are overwritten in place with the tokenizer's ids.

    Raises:
        ValueError: If ``args.lr`` is unset and no optimizer checkpoint will
            be loaded to supply one.
    """
    # Adafactor is created with relative_step=False, so it needs an explicit
    # learning rate unless one is restored from the codet5 optimizer
    # checkpoint below. Fail before any expensive setup instead of at the
    # first optimizer step.
    if args.lr is None and args.PTM != "codet5":
        raise ValueError("--lr must be set when no optimizer checkpoint is loaded")

    # Initialise the distributed process group.
    dist.init_process_group(backend="nccl")

    # ---- Load data ----
    # Every process except local rank 0 waits here so that local rank 0 runs
    # first (data processing, pretrained weight download, ...) and saves its
    # cache; the matching barrier for local rank 0 is further down.
    if get_local_rank() != 0:
        dist.barrier()

    # load summarization datasets
    data_paths = get_data_path("train", args.PTM, args.dataset)
    if args.PTM == "codet5":
        tokenize_name = "Salesforce/codet5-base"
    elif args.PTM == "t5":
        tokenize_name = "t5-small"
    else:
        tokenize_name = "google/pegasus-xsum"
    tokenizer = AutoTokenizer.from_pretrained(tokenize_name)
    # Replace the CLI placeholders with the tokenizer's special-token ids.
    args.pad_id = tokenizer.pad_token_id
    args.eos_id = tokenizer.eos_token_id
    args.bos_id = tokenizer.bos_token_id
    datasets = Seq2SeqPipe(args).process_from_file(data_paths)
    print('Information of dataset is:')
    print(datasets)
    train_set = datasets.datasets['train']

    if args.baseline:
        print("=" * 10, "initialize baseline model...", "=" * 10)
        model = GeneratorBaseline(args.PTM, args.model_name, args.pad_id, args.scratch)
    else:
        print("=" * 10, "initialize contrastive model...", "=" * 10)
        model = CLGenerator(args.PTM, args.model_name, args.pad_id, args)

    for name in data_paths:
        assert exists(data_paths[name])
    # exist_ok avoids a check-then-create race between processes.
    os.makedirs(args.save_path, exist_ok=True)

    # Local rank 0 has finished the setup above; release the waiting ranks.
    if get_local_rank() == 0:
        dist.barrier()

    # The parsed device list is not used here; only the summary dict is.
    _, train_params = configure_training(args)

    optimizer = Adafactor(
        model.parameters(),
        lr=args.lr,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    if args.PTM == "codet5":
        optim_path = os.path.join(args.model_name, "optim.pt")
        # Deserialize onto the CPU so each rank does not first materialise
        # the state on the device it was saved from; load_state_dict then
        # moves the tensors to each parameter's device.
        optim_state = torch.load(optim_path, map_location="cpu")
        # NOTE: load_state_dict also restores the saved param_groups, so the
        # checkpoint's hyper-parameters (including lr) replace those above.
        optimizer.load_state_dict(optim_state)
        print("=" * 20, "load optim from", optim_path, "=" * 20)

    callbacks_master = [SaveCkptCallback(args)]
    criterion = Loss()

    trainer = DistTrainer(train_data=train_set, model=model, optimizer=optimizer,
                          loss=criterion, batch_size_per_gpu=args.batch_size,
                          update_every=args.accum_count, n_epochs=args.n_epochs,
                          print_every=10,
                          save_path=args.save_path, callbacks_master=callbacks_master)

    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()


if __name__ == '__main__':
    # Command-line entry point: build the parser, parse, and start training.
    parser = argparse.ArgumentParser(description='training/testing of MatchSum')

    # --- run setup ---
    parser.add_argument(
        '--save_path',
        type=str,
        required=True,
        help='root of the model',
    )
    parser.add_argument(
        '--gpus',
        type=str,
        default="0,1,2,3",
        help='available gpus for training(separated by commas)',
    )
    parser.add_argument('--version', default="base")

    # --- optimisation hyper-parameters ---
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='the training batch size',
    )
    parser.add_argument(
        '--accum_count',
        type=int,
        default=1,
        help='number of updates steps to accumulate before performing a backward/update pass',
    )
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument(
        '--margin',
        type=float,
        default=0.01,
        help='parameter for MarginRankingLoss',
    )
    parser.add_argument(
        '--warmup_steps',
        type=int,
        default=10000,
        help='warm up steps for training',
    )
    parser.add_argument(
        '--n_epochs',
        type=int,
        default=50,
        help='total number of training epochs',
    )
    parser.add_argument('--label_smoothing', type=float, default=0.0)
    parser.add_argument(
        '--save_steps',
        type=int,
        default=2000,
        help='number of update steps for validation and saving checkpoint',
    )

    # --- data and model selection ---
    parser.add_argument('--dataset', default="wmt16")
    parser.add_argument('--baseline', type=str2bool)
    parser.add_argument('--PTM', default="t5")
    parser.add_argument('--scratch', type=str2bool, default=False)
    parser.add_argument('--model_name', default="google/pegasus-xsum")
    parser.add_argument('--max_src_len', type=int, default=512)
    parser.add_argument('--max_tgt_len', type=int, default=128)

    # --- inference parameters ---
    parser.add_argument('--min_length', type=int, default=5)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--beam_size', type=int, default=12)
    parser.add_argument('--early_stop', type=str2bool, default=True)
    parser.add_argument('--no_repeat_ngram', type=int, default=4)
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--diversity_pen', type=float, default=1.0)
    parser.add_argument('--length_pen', type=float, default=2.0)
    parser.add_argument(
        '--margin_func',
        default="ranking",
        choices=["const", "ranking"],
    )
    parser.add_argument('--max_sample_num', type=int, default=16)
    parser.add_argument('--n_gram', type=int, default=2)

    # --- no need to set: train_model overwrites these from the tokenizer ---
    parser.add_argument('--pad_id', type=int, default=0)
    parser.add_argument('--eos_id', type=int, default=1)
    parser.add_argument('--bos_id', type=int, default=None)

    # parse_known_args ignores unrecognised flags instead of erroring
    # (presumably so launcher-injected options do not abort parsing - confirm).
    args, _ = parser.parse_known_args()
    train_model(args)
