import argparse
import os

import torch
import numpy as np
import random
from str2bool import str2bool

from datetime import datetime

from utils import occupy_mem_new, save_hparams, get_batch_loader
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup
)

from datasets import load_dataset
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW

# Learning-rate schedule factories keyed by the --schedule name. Each takes
# (optimizer, warmup_steps, total_steps) and returns a transformers scheduler.
_SCHEDULE_FACTORIES = {
    'linear': lambda opt, warmup, total: get_linear_schedule_with_warmup(
        opt, num_warmup_steps=warmup, num_training_steps=total),
    'cosine': lambda opt, warmup, total: get_cosine_schedule_with_warmup(
        opt, num_warmup_steps=warmup, num_training_steps=total),
    'cosine_restart': lambda opt, warmup, total: get_cosine_with_hard_restarts_schedule_with_warmup(
        opt, num_warmup_steps=warmup, num_training_steps=total, num_cycles=30),
    'polynomial': lambda opt, warmup, total: get_polynomial_decay_schedule_with_warmup(
        opt, num_warmup_steps=warmup, num_training_steps=total),
    'constant': lambda opt, warmup, total: get_constant_schedule(opt),
    'constant_warmup': lambda opt, warmup, total: get_constant_schedule_with_warmup(
        opt, num_warmup_steps=warmup),
}

# Schedules whose learning-rate curve depends on the total number of steps.
_HORIZON_SCHEDULES = ('linear', 'cosine', 'cosine_restart', 'polynomial')


def _load_sst2_split(tokenizer, split, block_size):
    """Load one SST-2 split, tokenize it to a fixed length and expose torch tensors.

    Args:
        tokenizer: tokenizer applied (batched) to the ``sentence`` column.
        split: dataset split name, e.g. ``'train'`` or ``'validation'``.
        block_size: max token length; sequences are truncated/padded to it.

    Returns:
        The tokenized dataset formatted to yield ``input_ids``,
        ``attention_mask`` and ``label`` as torch tensors.
    """
    dataset = load_dataset('datasets/glue/glue.py', 'sst2', split=split)
    dataset = dataset.map(
        lambda e: tokenizer(e['sentence'], truncation=True, padding='max_length', max_length=block_size),
        batched=True,
    )
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    return dataset


def _grouped_parameters(model, weight_decay):
    """Split model parameters into two AdamW parameter groups.

    Parameters whose name contains ``bias`` or ``LayerNorm.weight`` get no
    weight decay; every other parameter uses ``weight_decay``.
    """
    no_decay = ("bias", "LayerNorm.weight")
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            undecayed.append(param)
        else:
            decayed.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]


def main(args):
    """Fine-tune RoBERTa for sequence classification on GLUE SST-2.

    Runs ``args.num_steps`` optimizer steps. Each step accumulates gradients
    over ``args.accum_steps`` micro-batches. The model is evaluated on the
    SST-2 validation split every ``args.valid_every`` steps, and the
    best-accuracy model is saved to
    ``<args.log>/<args.exp_name>/checkpoints/model-best``.

    Args:
        args: parsed command-line namespace. ``args.cuda`` is set on it here.

    Raises:
        ValueError: if ``args.schedule`` is not a supported schedule name.
    """
    import math  # local: only used to round the scheduler step count

    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Fail fast on an unknown --schedule. Otherwise `scheduler` would be
    # unbound and the run would crash only after datasets/model are loaded.
    if args.schedule not in _SCHEDULE_FACTORIES:
        raise ValueError('Unknown schedule {!r}; expected one of {}'.format(
            args.schedule, sorted(_SCHEDULE_FACTORIES)))

    # Select which GPU(s) to use
    occupy_mem_new(args.gpu_list.split(','), ratio=args.gpu_ratio, num_devices=args.n_device)

    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device('cuda' if args.cuda else 'cpu')

    # Output directory for models and summaries
    out_dir = os.path.join(args.log, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory
    checkpoint_dir = os.path.join(out_dir, 'checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build dataset
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset begin... | %s " % time_str)

    # NOTE(review): RoBERTa's byte-level BPE tokenizer is case-sensitive;
    # confirm that do_lower_case actually has an effect here.
    tokenizer = RobertaTokenizer.from_pretrained(args.bert_config, do_lower_case=True)

    train_dataset = _load_sst2_split(tokenizer, 'train', args.block_size)
    test_dataset = _load_sst2_split(tokenizer, 'validation', args.block_size)

    train_loader = get_batch_loader(train_dataset, collate_fn=None, batch_size=args.batch_size, is_test=False)
    test_loader = get_batch_loader(test_dataset, collate_fn=None, batch_size=args.eval_batch_size, is_test=True)

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset end... | %s " % time_str)

    model = RobertaForSequenceClassification.from_pretrained(args.bert_config)
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    optimizer = AdamW(_grouped_parameters(model, args.weight_decay), lr=args.lr, eps=args.adam_epsilon)

    # Each optimizer (and scheduler) step consumes batch_size * accum_steps
    # examples. Round up so the scheduler gets an integer horizon.
    total_steps = math.ceil(args.num_epochs * len(train_dataset) / (args.batch_size * args.accum_steps))

    if args.schedule in _HORIZON_SCHEDULES:
        # The loop below runs args.num_steps optimizer steps independently of
        # num_epochs, so the loop length and the schedule horizon can disagree.
        if args.num_steps > total_steps:
            print('WARNING : num_steps ({}) exceeds the {}-step schedule horizon derived from '
                  'num_epochs; the {} schedule will end before training does'.format(
                      args.num_steps, total_steps, args.schedule))
        if args.warmup_steps >= total_steps:
            print('WARNING : warmup_steps ({}) >= schedule horizon ({} steps); '
                  'the learning rate will not follow a warmup-then-decay shape'.format(
                      args.warmup_steps, total_steps))

    scheduler = _SCHEDULE_FACTORIES[args.schedule](optimizer, args.warmup_steps, total_steps)

    def train_step(global_step):
        """Run one optimizer step over ``args.accum_steps`` micro-batches."""
        model.train()
        loss_total = 0.0

        for _ in range(args.accum_steps):
            # Assumes get_batch_loader returns an iterator for training
            # (is_test=False) that can be advanced with next(); TODO confirm.
            batch = next(train_loader)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            loss = model(return_dict=True, input_ids=input_ids, attention_mask=attention_mask, labels=labels)['loss']
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = loss / args.accum_steps
            loss.backward()
            loss_total += loss.item()

        grad_norm = torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.clip)
        if grad_norm >= 1e2:
            print('WARNING : Exploding Gradients {:.2f}'.format(grad_norm))
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % args.print_every == 0 and global_step != 0:
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print("Step: %d \t| loss: %.3f \t| lr: %.8f \t| %s" % (
                global_step, loss_total, scheduler.get_last_lr()[0], time_str
            ))

    def dev_step(split, global_step):
        """Evaluate on the held-out loader and return ``{'acc': accuracy}``."""
        model.eval()
        labels, predicts = [], []
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                logits = model(return_dict=True, input_ids=input_ids, attention_mask=attention_mask)['logits']
                labels.extend(batch['label'].tolist())
                predicts.extend(torch.argmax(logits, dim=1).tolist())
        # Guard the empty case: np.mean([]) warns and returns NaN.
        acc = float(np.mean([a == b for a, b in zip(labels, predicts)])) if labels else 0.0

        time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("**********************************")
        print("{} results..........".format(split))
        print("Step: %d \t|  %s" % (global_step, time_str))
        print("Acc: {:.4f}".format(acc))
        print("**********************************")
        return {'acc': acc}

    best_acc = 0.
    for step in range(1, args.num_steps + 1):
        train_step(step)
        if step % args.valid_every == 0:
            test_seen_results = dev_step("test", step)
            if test_seen_results['acc'] > best_acc:
                best_acc = test_seen_results["acc"]

                save_path = '{}-best'.format(checkpoint_prefix)
                os.makedirs(save_path, exist_ok=True)
                # Unwrap (Distributed)DataParallel if the model was wrapped.
                model_to_save = model.module if hasattr(model, "module") else model
                model_to_save.save_pretrained(save_path)

                print("Saved model checkpoint to {}\n".format(save_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Fine-tune RoBERTa for sequence classification on GLUE SST-2'
    )

    # training scheme
    parser.add_argument('--batch_size', type=int, default=2,
                        help='training micro-batch size (one of accum_steps per optimizer step)')
    parser.add_argument('--eval_batch_size', type=int, default=2,
                        help='batch size for the validation loader')
    parser.add_argument('--num_steps', type=int, default=1000000,
                        help='number of optimizer steps the training loop runs')
    parser.add_argument('--accum_steps', type=int, default=32,
                        help='micro-batches accumulated per optimizer step')
    parser.add_argument('--lr', type=float, default=5e-5, help='peak learning rate')
    parser.add_argument('--clip', type=float, default=2.0, help='max gradient norm for clipping')
    # Validate here so a typo fails immediately instead of deep inside main().
    parser.add_argument('--schedule', type=str, default='linear',
                        choices=['linear', 'cosine', 'cosine_restart', 'polynomial',
                                 'constant', 'constant_warmup'],
                        help='learning-rate schedule')

    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='AdamW weight decay (not applied to biases / LayerNorm weights)')
    parser.add_argument('--adam_epsilon', type=float, default=1e-8, help='AdamW epsilon')
    parser.add_argument('--warmup_steps', type=int, default=5000, help='learning-rate warmup steps')
    parser.add_argument('--num_epochs', type=int, default=3,
                        help='epochs used to size the learning-rate schedule horizon')

    parser.add_argument('--print_every', type=int, default=10, help='log training loss every N steps')
    parser.add_argument('--valid_every', type=int, default=1, help='evaluate every N steps')

    # save
    parser.add_argument('--exp_name', type=str, default='0601_test',
                        help='experiment name; outputs go to <log>/<exp_name>')
    parser.add_argument('--log', type=str, default='wizard_of_wikipedia/log',
                        help='root directory for experiment outputs')
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    # model
    # NOTE(review): the default points at a BERT checkpoint, but main() loads
    # it with RoBERTa classes — confirm the intended default path.
    parser.add_argument('--bert_config', type=str, default='/home2/xxx/Data/pretrain-models/bert_base_uncased',
                        help='pretrained model/tokenizer name or path passed to from_pretrained')
    parser.add_argument('--block_size', type=int, default=256,
                        help='max token length; inputs are truncated/padded to it')

    # gpu
    parser.add_argument('--gpu_list', type=str, default='0',
                        help='comma-separated GPU ids passed to occupy_mem_new')
    parser.add_argument('--gpu_ratio', type=float, default=0.85, help='ratio passed to occupy_mem_new')
    parser.add_argument('--n_device', type=int, default=7, help='num_devices passed to occupy_mem_new')
    parser.add_argument('--no_cuda', type=str2bool, default=False, help='force CPU even if CUDA is available')

    args = parser.parse_args()

    # Seed every RNG in use for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)
