import argparse
import os

import torch
import numpy as np
import math
import random
from tqdm import tqdm
from str2bool import str2bool

from datetime import datetime

from utils import (
    occupy_mem_new,
    save_hparams,
    RedditDataset,
    WizardDataset_v2,
    BartFullBatcher,
    get_batch_loader,
)
from metrics import (
    bleu_metric,
    distinct_metric,
    f1_metric
)

from transformers import AdamW, Adafactor
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.optimization import (
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup
)
from model.modeling_full import BartForConditionalGeneration

def ids_to_clean_text(tokenizer, generated_ids):
    """Decode every row of ``generated_ids`` into a text string.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...,
            clean_up_tokenization_spaces=...)``.
        generated_ids: Indexable batch whose first dimension (``size(0)``)
            enumerates sequences; each row must support ``tolist()``.

    Returns:
        A list with one decoded string per row, in row order. Special tokens
        are skipped and tokenization spaces are left untouched.
    """
    return [
        tokenizer.decode(
            generated_ids[row].tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        for row in range(generated_ids.size(0))
    ]

def parse_segments(generated_ids, predicts_z, predicts_m):
    """Collect segment statistics over a batch of generated sequences.

    A segment ends at position ``j`` when the following token is the
    end-of-sequence id or when ``predicts_m[i][j] == 1``. A segment is counted
    as "context" when ``predicts_z[i][j] == 0`` at its last position and as
    "knowledge" otherwise. Scanning of a sequence stops as soon as the next
    token is the padding id. Sequences that contain no end-of-sequence id are
    treated as if their last token were one.

    Args:
        generated_ids: Batch of token ids supporting ``tolist()``.
        predicts_z: Per-position segment-type predictions (``tolist()``-able).
        predicts_m: Per-position boundary predictions (``tolist()``-able).

    Returns:
        Tuple ``(new_words, new_segments, new_context, new_knowledge,
        context_first, num_context_words, num_knowledge_words)``.
    """
    # Presumably BART's </s> and <pad> token ids -- confirm against the tokenizer.
    end_id, pad_id = 2, 1

    token_rows = generated_ids.tolist()
    z_rows = predicts_z.tolist()
    m_rows = predicts_m.tolist()

    word_total = 0
    segment_total = 0
    context_total = 0
    knowledge_total = 0
    context_first = 0
    context_word_total = 0
    knowledge_word_total = 0

    for row_idx, tokens in enumerate(token_rows):
        z_row = z_rows[row_idx]
        m_row = m_rows[row_idx]

        if z_row[0] == 0:
            context_first += 1

        # Guarantee a terminator so the final segment is always closed.
        if end_id not in tokens:
            tokens[-1] = end_id

        seg_start = 0
        for pos, upcoming in enumerate(tokens[1:]):
            if upcoming == pad_id:
                break

            if upcoming == end_id or m_row[pos] == 1:
                seg_len = pos - seg_start + 1
                if z_row[pos] == 0:
                    context_total += 1
                    context_word_total += seg_len
                else:
                    knowledge_total += 1
                    knowledge_word_total += seg_len
                seg_start = pos + 1
                segment_total += 1
            word_total += 1

    return (
        word_total,
        segment_total,
        context_total,
        knowledge_total,
        context_first,
        context_word_total,
        knowledge_word_total,
    )

def main(args):
    """Train the BART model and keep the checkpoint with the best test-seen F1.

    The function prints the hyper-parameters, prepares the output and
    checkpoint directories under ``os.path.join(args.log, args.exp_name)``,
    builds the training (``RedditDataset``) and evaluation
    (``WizardDataset_v2``) loaders, creates two optimizers/schedulers
    (parameters whose name contains ``'model'`` use ``args.lr``; all other
    parameters use ``args.lr2``), and then runs ``args.num_steps`` optimizer
    steps. Every ``args.valid_every`` steps the model is evaluated and, when
    F1 improves, saved to ``<checkpoint_dir>/model-best``.

    Args:
        args: Parsed command-line namespace (see the argument parser in the
            ``__main__`` section of this file).

    Raises:
        ValueError: If ``args.opt`` or ``args.schedule`` is not one of the
            supported choices.
    """
    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Select which GPU(s) to use
    occupy_mem_new(args.gpu_list.split(','), ratio=args.gpu_ratio, num_devices=args.n_device)

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    # Output directory for models and summaries
    out_dir = os.path.join(args.log, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory
    checkpoint_dir = os.path.join(out_dir, 'checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build dataset
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset begin... | %s " % time_str)

    train_dataset = RedditDataset(args.train_file)
    test_seen_dataset = WizardDataset_v2(args.test_seen_file)

    train_loader = get_batch_loader(train_dataset, collate_fn=RedditDataset.collate_fn, batch_size=args.batch_size, is_test=False)
    test_seen_loader = get_batch_loader(test_seen_dataset, collate_fn=WizardDataset_v2.collate_fn, batch_size=args.eval_batch_size, is_test=True)

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset end... | %s " % time_str)

    batcher = BartFullBatcher(
        args.min_segment_len, args.split_threshold, args.know_threshold, args.copy_threshold, args.stop_words_path,
        args.stop_words_size, args.merge, args.percentage, args.drop_know, args.infilling, args.mlm, args.random_know,
        args.bleu_percent, args.reverse, args.use_sep, args.mix_know, args.add_prefix_space, args.test_knowledge_truncate, args.test_knowledge_num,
        args.max_source_len, args.max_target_len, args.text_trunc, args.knowledge_trunc, args.bart_config, args.full_knowledge_attn, args.cuda
    )
    model = BartForConditionalGeneration.from_pretrained(args.bart_config)
    # The batcher's tokenizer may hold more tokens than the pretrained vocabulary.
    model.resize_token_embeddings(len(batcher.tokenizer))
    if args.cuda:
        model.cuda()

    # Parameters whose name contains 'model' go to the first optimizer (with
    # weight decay disabled for biases and LayerNorm weights); everything else
    # goes to the second optimizer with its own learning rate.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if 'model' in n and not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if 'model' in n and any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    # NOTE(review): this is a float derived from num_epochs, while the training
    # loop below runs for args.num_steps iterations -- confirm the two agree.
    total_steps = args.num_epochs * (len(train_dataset) / (args.batch_size * args.accum_steps))
    if args.opt == 'adamw':
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
        optimizer2 = AdamW([p for n, p in model.named_parameters() if 'model' not in n], lr=args.lr2, eps=args.adam_epsilon)
    elif args.opt == 'adafactor':
        optimizer = Adafactor(optimizer_grouped_parameters, lr=args.lr, beta1=None, relative_step=False, scale_parameter=False, warmup_init=False)
        optimizer2 = Adafactor([p for n, p in model.named_parameters() if 'model' not in n], lr=args.lr2, beta1=None, relative_step=False, scale_parameter=False, warmup_init=False)
    else:
        # Fail here rather than with a NameError at the first optimizer use.
        raise ValueError("Unsupported optimizer: {}".format(args.opt))

    if args.schedule == 'linear':
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
        scheduler2 = get_linear_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
    elif args.schedule == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
        scheduler2 = get_cosine_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
    elif args.schedule == 'cosine_restart':
        # NOTE(review): scheduler2 does not pass num_cycles=30 like scheduler
        # does -- confirm whether the asymmetry is intentional.
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps, num_cycles=30)
        scheduler2 = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
    elif args.schedule == 'polynomial':
        scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
        scheduler2 = get_constant_schedule(optimizer2)
    elif args.schedule == 'constant':
        scheduler = get_constant_schedule(optimizer)
        scheduler2 = get_constant_schedule(optimizer2)
    elif args.schedule == 'constant_warmup':
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
        scheduler2 = get_constant_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_steps)
    else:
        # Fail here rather than with a NameError at the first scheduler use.
        raise ValueError("Unsupported schedule: {}".format(args.schedule))

    def train_step(global_step):
        """Run one optimizer update accumulated over ``args.accum_steps`` batches.

        Each batch's losses are divided by ``args.accum_steps`` before the
        backward pass so the accumulated gradient is an average. Gradients are
        clipped to ``args.clip`` before both optimizers and schedulers step.
        """
        lm_loss_total, kl_abs_total, kl_mask_total = 0.0, 0.0, 0.0
        model.train()
        for _ in range(args.accum_steps):
            history_list, knowledge_list, tree_list, check_list = next(train_loader)

            fwd_args = batcher(history_list, knowledge_list, tree_list, check_list, training=True)
            lm_loss, kl_abs_loss, kl_mask_loss = model(**fwd_args)
            lm_loss = lm_loss / args.accum_steps
            kl_abs_loss = kl_abs_loss / args.accum_steps
            kl_mask_loss = kl_mask_loss / args.accum_steps
            loss = lm_loss + kl_abs_loss + kl_mask_loss
            loss.backward()

            lm_loss_total += lm_loss.item()
            kl_abs_total += kl_abs_loss.item()
            kl_mask_total += kl_mask_loss.item()

        torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.clip)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        optimizer2.step()
        scheduler2.step()
        optimizer2.zero_grad()

        if global_step % args.print_every == 0 and global_step != 0:
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            # get_last_lr() is the supported accessor outside scheduler.step();
            # get_lr() emits a warning when called from here.
            print("Step: %d \t| lm_loss: %.3f \t| kl_abs_loss: %.3f \t| kl_mask_loss: %.3f \t| lr: %.8f \t| %s" % (
                global_step, lm_loss_total, kl_abs_total, kl_mask_total, scheduler.get_last_lr()[0], time_str
            ))

    def dev_step(split, global_step):
        """Generate responses for ``split``, dump them to disk and report metrics.

        Args:
            split: Name of the evaluation split; only ``'test_seen'`` is supported.
            global_step: Current training step, used in the output file name
                and in the printed report.

        Returns:
            Dict with keys ``f1``, ``bleu1``..``bleu4``, ``distinct1`` and
            ``distinct2``.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        if split == 'test_seen':
            test_loader = test_seen_loader
        else:
            raise ValueError("Unsupported split: {}".format(split))
        model.eval()

        test_hyp, test_ref = [], []
        count = 0

        num_words, num_segments, num_context, num_knowledge = 0, 0, 0, 0
        context_first = 0
        num_context_words, num_knowledge_words = 0, 0
        with torch.no_grad():
            for history_list, knowledge_list, response_list in test_loader:
                dec_args = batcher(history_list, knowledge_list, training=False)
                dec_args['max_length'] = args.max_length
                dec_args['min_length'] = args.min_length
                dec_args['num_beams'] = args.num_beams
                dec_args['repetition_penalty'] = args.repetition_penalty
                dec_args['no_repeat_ngram_size'] = args.no_repeat_ngram_size
                dec_args['do_sample'] = False
                generated_ids, predicts_z, predicts_m = model.generate(**dec_args)
                predict_list = ids_to_clean_text(batcher.tokenizer, generated_ids)
                new_words, new_segments, new_context, new_knowledge, new_context_first, new_context_words, new_knowledge_words = parse_segments(generated_ids, predicts_z, predicts_m)
                num_words += new_words
                num_segments += new_segments
                num_context += new_context
                num_knowledge += new_knowledge
                context_first += new_context_first
                num_context_words += new_context_words
                num_knowledge_words += new_knowledge_words

                test_hyp.extend(predict_list)
                test_ref.extend(response_list)

                count += 1
                if count % 1000 == 0:
                    print(count)

        with open(os.path.join(out_dir, '{}-decoded-iter-{}.txt'.format(split, global_step)), 'w', encoding='utf-8') as f:
            for _hyp, _ref in zip(test_hyp, test_ref):
                f.write("{} ||| {}\n".format(_hyp, _ref))

        b1, b2, b3, b4 = bleu_metric(test_hyp, test_ref)
        d1, d2 = distinct_metric(test_hyp)
        f1 = f1_metric(test_hyp, test_ref)

        # Guard the averages so an empty split or a run without any detected
        # segment cannot raise ZeroDivisionError; non-empty results are unchanged.
        num_hyp = max(len(test_hyp), 1)
        avg_segment_len = num_words * 1. / num_segments if num_segments else 0.0

        time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("**********************************")
        print("{} results..........".format(split))
        print('hypothesis: ', len(test_hyp))
        print("Step: %d \t| %s" % (global_step, time_str))
        print('Avg segment/context/knowledge: {:.4f}/{:.4f}/{:.4f}'.format(num_segments * 1. / num_hyp, num_context * 1. / num_hyp, num_knowledge * 1. / num_hyp))
        print('Context first: {:.4f}'.format(context_first * 1. / num_hyp))
        print('Avg segment/context/knowledge length: {:.4f}/{:.4f}/{:.4f}'.format(
            avg_segment_len, num_context_words * 1. / (num_context + 0.00001), num_knowledge_words * 1. / (num_knowledge + 0.00001)))
        print("BLEU-1/2/3/4: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(b1, b2, b3, b4))
        print("Distinct-1/2: {:.4f}/{:.4f}".format(d1, d2))
        print("F1: {:.4f}".format(f1))
        print("**********************************")

        return {'f1': f1, 'bleu1': b1, 'bleu2': b2, 'bleu3': b3, 'bleu4': b4, 'distinct1': d1, 'distinct2': d2}

    best_f1 = 0.
    for i in range(args.num_steps):
        train_step(i + 1)
        if (i + 1) % args.valid_every == 0:
            test_seen_results = dev_step("test_seen", i + 1)
            if test_seen_results["f1"] > best_f1:
                best_f1 = test_seen_results["f1"]

                save_path = "{}-best".format(checkpoint_prefix)
                os.makedirs(save_path, exist_ok=True)
                # Unwrap a parallel wrapper (if any) before saving.
                model_to_save = model.module if hasattr(model, "module") else model
                model_to_save.save_pretrained(save_path)

                # Report the directory that was actually written.
                print("Saved model checkpoint to {}\n".format(save_path))


if __name__ == '__main__':
    # Every command-line option as a (flag, type converter, default) triple.
    # Options are registered with the parser in exactly the order listed here.
    cli_options = (
        ('--model_type', str, ''),

        # files
        ('--train_file', str, ''),
        ('--test_seen_file', str, ''),

        # training scheme
        ('--batch_size', int, 2),
        ('--eval_batch_size', int, 2),
        ('--num_steps', int, 1000000),
        ('--accum_steps', int, 32),
        ('--lr', float, 5e-5),
        ('--lr2', float, 5e-5),
        ('--clip', float, 2.0),
        ('--schedule', str, 'linear'),
        ('--opt', str, 'adamw'),

        ('--weight_decay', float, 0.01),
        ('--adam_epsilon', float, 1e-8),
        ('--warmup_steps', int, 5000),
        ('--num_epochs', int, 3),

        ('--print_every', int, 10),
        ('--valid_every', int, 1),

        # save
        ('--exp_name', str, '0601_test'),
        ('--log', str, 'wizard_of_wikipedia/log'),
        ('--seed', int, 42),

        # model
        ('--bart_config', str, ''),

        ('--min_segment_len', int, 6),
        ('--split_threshold', float, 0.2),
        ('--know_threshold', float, 0.4),
        ('--copy_threshold', float, 1.0),
        ('--stop_words_path', str, 'debug_data/stop_words.txt'),
        ('--stop_words_size', int, 200),
        ('--merge', str2bool, False),
        ('--percentage', float, 0.125),
        ('--drop_know', float, 0.0),
        ('--infilling', str2bool, False),
        ('--mlm', float, 0.0),
        ('--random_know', float, 0.0),
        ('--bleu_percent', float, 0.125),
        ('--reverse', str2bool, False),
        ('--use_sep', str2bool, False),
        ('--mix_know', float, 0.0),
        ('--add_prefix_space', str2bool, True),
        ('--test_knowledge_truncate', int, 64),
        ('--test_knowledge_num', int, 64),

        ('--max_source_len', int, 256),
        ('--max_target_len', int, 256),
        ('--knowledge_trunc', int, 64),
        ('--text_trunc', int, 128),
        ('--full_knowledge_attn', str2bool, False),

        # generation
        ('--max_length', int, 40),
        ('--min_length', int, 25),
        ('--num_beams', int, 4),
        ('--repetition_penalty', float, 1.2),
        ('--no_repeat_ngram_size', int, 3),

        # gpu
        ('--gpu_list', str, '0'),
        ('--gpu_ratio', float, 0.85),
        ('--n_device', int, 7),
        ('--no_cuda', str2bool, False),
    )

    parser = argparse.ArgumentParser(
        description='Pre-training for Knowledge-Grounded Conversation'
    )
    for flag, converter, default_value in cli_options:
        parser.add_argument(flag, type=converter, default=default_value)

    args = parser.parse_args()

    # Seed every random number generator in use from the single --seed value
    # and request deterministic cuDNN kernels.
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

    main(args)
