import argparse
import os

import torch
import numpy as np
import math
import random
from tqdm import tqdm
from str2bool import str2bool

from datetime import datetime

from utils import (
    occupy_mem_new,
    save_hparams,
    RedditKsDataset_v0,
    RobertaBatcher,
    get_batch_loader,
)
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.optimization import (
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup
)


def recall_metric(scores, labels):
    """Compute Recall@1/2/5/10 over grouped candidate scores.

    ``scores`` and ``labels`` are parallel flat sequences. A new group starts
    at index 0 and at every later position whose label equals 1, so each group
    is expected to look like ``[gold, other, other, ...]`` with the gold
    candidate first. Within a group, candidates are ranked by descending
    score. The sort is stable, so ties keep their original order, which means
    the gold candidate (group index 0) wins ties.

    Args:
        scores: per-candidate scores, same length as ``labels``.
        labels: per-candidate labels; a value of 1 marks the gold candidate
            that opens each group.

    Returns:
        Tuple ``(r@1, r@2, r@5, r@10)``. Each value is the fraction of groups
        whose gold candidate ranks within the top k. All zeros when
        ``labels`` is empty.

    Raises:
        ValueError: if the first group does not start with a label of 1.
    """
    num_items = len(labels)
    if num_items == 0:
        # Nothing to rank: report zero recall instead of failing on an empty group.
        print('hypothesis: ', 0)
        return 0., 0., 0., 0.

    # Group boundaries: index 0 plus every later index labelled 1.
    starts = [0] + [i for i in range(1, num_items) if labels[i] == 1]
    ends = starts[1:] + [num_items]
    print('hypothesis: ', len(starts))

    cutoffs = (1, 2, 5, 10)
    hits = [0] * len(cutoffs)
    for begin, end in zip(starts, ends):
        score, label = scores[begin:end], labels[begin:end]
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if not label[0] == 1:
            raise ValueError(
                'group starting at index {} does not begin with a gold label (1)'.format(begin))
        # Candidate indices ordered by descending score (stable: ties keep input order).
        rank = sorted(range(len(score)), key=lambda x: score[x], reverse=True)
        for j, k in enumerate(cutoffs):
            if 0 in rank[:k]:
                hits[j] += 1

    num_groups = len(starts)
    return tuple(h / num_groups for h in hits)


def _build_scheduler(optimizer, schedule, warmup_steps, total_steps):
    """Return the transformers LR scheduler selected by ``schedule``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        schedule: one of 'linear', 'cosine', 'cosine_restart', 'polynomial',
            'constant', 'constant_warmup'.
        warmup_steps: number of warmup steps (unused by 'constant').
        total_steps: scheduler horizon in optimizer steps (unused by the
            two constant schedules).

    Raises:
        ValueError: if ``schedule`` is not a supported name.
    """
    if schedule == 'linear':
        return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    if schedule == 'cosine':
        return get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    if schedule == 'cosine_restart':
        return get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, num_cycles=30)
    if schedule == 'polynomial':
        return get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    if schedule == 'constant':
        return get_constant_schedule(optimizer)
    if schedule == 'constant_warmup':
        return get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
    # Previously an unknown name left ``scheduler`` unbound and crashed later with NameError.
    raise ValueError('Unknown schedule: {!r}'.format(schedule))


def main(args):
    """Train a RoBERTa sequence classifier and periodically evaluate Recall@k.

    Training runs for ``args.num_steps`` optimizer steps. Each step accumulates
    gradients over ``args.accum_steps`` micro-batches. Every
    ``args.valid_every`` steps the model scores the seen-test split, the
    per-candidate scores are written under ``args.log/args.exp_name``, and
    Recall@1/2/5/10 is printed.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``args.schedule`` names an unsupported LR schedule.
    """
    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Select which GPU(s) to use.
    occupy_mem_new(args.gpu_list.split(','), ratio=args.gpu_ratio, num_devices=args.n_device)

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    # Output directory for models and summaries.
    out_dir = os.path.join(args.log, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory.
    checkpoint_dir = os.path.join(out_dir, 'checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build datasets and loaders.
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset begin... | %s " % time_str)

    train_dataset = RedditKsDataset_v0(args.train_file)
    test_seen_dataset = RedditKsDataset_v0(args.test_seen_file)
    test_unseen_dataset = RedditKsDataset_v0(args.test_unseen_file)
    test_cmudog_dataset = RedditKsDataset_v0(args.test_cmudog_file)

    collate_fn = RedditKsDataset_v0.collate_fn
    # train_step pulls batches with next(), so this loader is presumably an endless iterator — TODO confirm.
    train_loader = get_batch_loader(train_dataset, collate_fn=collate_fn, batch_size=args.batch_size, is_test=False)
    test_loaders = {
        'test_seen': get_batch_loader(test_seen_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, is_test=True),
        'test_unseen': get_batch_loader(test_unseen_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, is_test=True),
        'test_cmudog': get_batch_loader(test_cmudog_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size, is_test=True),
    }

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset end... | %s " % time_str)

    batcher = RobertaBatcher(args.block_size, args.knowledge_trunc, args.text_trunc, args.bert_config, args.cuda)

    if args.label_smooth:
        from model.modeling_roberta import RobertaForSequenceClassification
    else:
        from transformers import RobertaForSequenceClassification

    model = RobertaForSequenceClassification.from_pretrained(args.bert_config)
    # The batcher's tokenizer may add tokens, so resize the embedding table to match it.
    model.resize_token_embeddings(len(batcher.tokenizer))
    if args.cuda:
        model.cuda()

    # Exclude biases and LayerNorm weights from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)

    # Scheduler horizon in optimizer steps: examples / (micro-batch * accumulation) per epoch.
    # Rounded up to an int and kept >= 1, so the decaying schedules get an
    # integer horizon and never divide by zero.
    # NOTE(review): the training loop below runs ``args.num_steps`` steps regardless
    # of this horizon; confirm that stepping past ``total_steps`` is intended.
    total_steps = max(1, math.ceil(args.num_epochs * len(train_dataset) / (args.batch_size * args.accum_steps)))
    scheduler = _build_scheduler(optimizer, args.schedule, args.warmup_steps, total_steps)

    def train_step(global_step):
        """Run one optimizer step accumulated over ``args.accum_steps`` micro-batches."""
        model.train()
        ks_loss_total = 0.0

        for _ in range(args.accum_steps):
            query_list, candidate_list, label_list = next(train_loader)
            # Each query string is split on blank lines into a list of segments for the batcher.
            query_list = [q.split('\n\n') for q in query_list]

            fwd_args = batcher(query_list, candidate_list, label_list, training=True)
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = model(return_dict=True, **fwd_args)['loss'] / args.accum_steps
            loss.backward()
            ks_loss_total += loss.item()

        grad_norm = torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.clip)
        if grad_norm >= 1e2:
            print('WARNING : Exploding Gradients {:.2f}'.format(grad_norm))
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % args.print_every == 0 and global_step != 0:
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            # get_last_lr() replaces the deprecated get_lr() for reading the current LR.
            print("Step: %d \t| ks_loss: %.3f \t| lr: %.8f \t| %s" % (
                global_step, ks_loss_total, scheduler.get_last_lr()[0], time_str
            ))

    def dev_step(split, global_step):
        """Score every candidate of ``split``, dump the scores, and report Recall@k.

        Writes ``<split>-decoded-iter-<global_step>.txt`` (one ``label<TAB>score``
        line per candidate) to ``out_dir``.

        Returns:
            dict with keys 'r_at_1', 'r_at_2', 'r_at_5', 'r_at_10'.

        Raises:
            ValueError: if ``split`` is not a known test split.
        """
        if split not in test_loaders:
            raise ValueError('Unknown split: {!r}'.format(split))
        test_loader = test_loaders[split]

        model.eval()
        labels, scores = [], []
        count = 0
        with torch.no_grad():
            for query_list, candidate_list, label_list in test_loader:
                query_list = [q.split('\n\n') for q in query_list]

                fwd_args = batcher(query_list, candidate_list, training=False)
                logits = model(return_dict=True, **fwd_args)['logits']
                labels.extend(label_list)
                # Column 1 is used as the candidate score; assumes a 2-way head whose index 1 is "relevant" — TODO confirm.
                scores.extend(logits[:, 1].tolist())

                count += 1
                if count % 1000 == 0:
                    print(count)

        with open(os.path.join(out_dir, '{}-decoded-iter-{}.txt'.format(split, global_step)), 'w', encoding='utf-8') as f:
            for label, score in zip(labels, scores):
                f.write('{}\t{}\n'.format(label, score))

        time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("**********************************")
        print("{} results..........".format(split))
        print("Step: %d \t|  %s" % (global_step, time_str))

        r1, r2, r5, r10 = recall_metric(scores, labels)
        print("RECALL-1/2/5/10: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(r1, r2, r5, r10))
        print("**********************************")

        return {'r_at_1': r1, 'r_at_2': r2, 'r_at_5': r5, 'r_at_10': r10}

    best_r1 = 0.
    for step in range(1, args.num_steps + 1):
        train_step(step)
        if step % args.valid_every != 0:
            continue

        test_seen_results = dev_step("test_seen", step)
        # Evaluation on the other splits is currently disabled:
        # test_unseen_results = dev_step("test_unseen", step)
        # test_cmudog_results = dev_step("test_cmudog", step)

        if test_seen_results["r_at_1"] > best_r1:
            best_r1 = test_seen_results["r_at_1"]

            # NOTE(review): checkpoint saving is disabled; original logic kept for re-enabling.
            # save_path = '{}-best'.format(checkpoint_prefix)
            # os.makedirs(save_path, exist_ok=True)
            # model_to_save = model.module if hasattr(model, "module") else model
            # model_to_save.save_pretrained(save_path)

            # Report the new best without claiming a save that did not happen.
            print("New best R@1 {:.4f} at step {} (checkpoint saving disabled, target: {})\n".format(
                best_r1, step, checkpoint_prefix))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Pre-training for Knowledge-Grounded Conversation'
    )

    # files
    parser.add_argument('--train_file', type=str, default='')
    parser.add_argument('--test_seen_file', type=str, default='')
    parser.add_argument('--test_unseen_file', type=str, default='')
    parser.add_argument('--test_cmudog_file', type=str, default='')

    # training scheme
    parser.add_argument('--batch_size', type=int, default=2,
                        help='training micro-batch size')
    parser.add_argument('--eval_batch_size', type=int, default=2)
    parser.add_argument('--num_steps', type=int, default=1000000,
                        help='number of optimizer steps to run')
    parser.add_argument('--accum_steps', type=int, default=32,
                        help='micro-batches accumulated per optimizer step')
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--clip', type=float, default=2.0,
                        help='max gradient norm for clipping')
    # choices= rejects typos at parse time instead of failing later in main().
    parser.add_argument('--schedule', type=str, default='linear',
                        choices=['linear', 'cosine', 'cosine_restart', 'polynomial',
                                 'constant', 'constant_warmup'],
                        help='learning-rate schedule')
    parser.add_argument('--label_smooth', type=str2bool, default=False,
                        help='use model.modeling_roberta.RobertaForSequenceClassification '
                             'instead of the transformers one')

    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--warmup_steps', type=int, default=5000,
                        help='learning-rate warmup steps')
    parser.add_argument('--num_epochs', type=int, default=3,
                        help='epochs used to size the learning-rate schedule horizon')

    parser.add_argument('--print_every', type=int, default=10,
                        help='log training loss every N optimizer steps')
    parser.add_argument('--valid_every', type=int, default=1,
                        help='evaluate on the seen-test split every N optimizer steps')

    # save
    parser.add_argument('--exp_name', type=str, default='0601_test',
                        help='run name; outputs go to <log>/<exp_name>')
    parser.add_argument('--log', type=str, default='wizard_of_wikipedia/log',
                        help='root output directory')
    parser.add_argument('--seed', type=int, default=42)

    # model
    # NOTE(review): this default names a BERT checkpoint, but main() loads it with
    # RobertaForSequenceClassification — confirm the intended default.
    parser.add_argument('--bert_config', type=str, default='/home2/xxx/Data/pretrain-models/bert_base_uncased')
    parser.add_argument('--pretrain_file', type=str, default='')
    parser.add_argument('--block_size', type=int, default=256)
    parser.add_argument('--knowledge_trunc', type=int, default=64)
    parser.add_argument('--text_trunc', type=int, default=128)

    # gpu
    parser.add_argument('--gpu_list', type=str, default='0')
    parser.add_argument('--gpu_ratio', type=float, default=0.85)
    parser.add_argument('--n_device', type=int, default=7)
    parser.add_argument('--no_cuda', type=str2bool, default=False)

    args = parser.parse_args()

    # Seed every RNG in use for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)
