import argparse
import os

import torch
import numpy as np
import math
import random
from tqdm import tqdm
from str2bool import str2bool

from datetime import datetime

from utils import (
    occupy_mem_new,
    save_hparams,
    WizardDataset_v2,
    BartFullBatcher,
    get_batch_loader,
)
from metrics import (
    bleu_metric,
    distinct_metric,
    f1_metric
)

from model.modeling_full import BartForConditionalGeneration

def ids_to_clean_text(tokenizer, generated_ids):
    """Decode every row of ``generated_ids`` into a plain string.

    Args:
        tokenizer: object exposing ``decode(ids, skip_special_tokens=...,
            clean_up_tokenization_spaces=...)``.
        generated_ids: tensor whose first dimension indexes sequences; each
            row is converted with ``.tolist()`` before decoding.

    Returns:
        A list with one decoded string per row, special tokens removed and
        tokenization spaces left untouched.
    """
    return [
        tokenizer.decode(
            generated_ids[row_idx].tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        for row_idx in range(generated_ids.size(0))
    ]

def parse_segments(generated_ids, predicts_z, predicts_m, *, eos_token_id=2, pad_token_id=1):
    """Count segment statistics over a batch of generated sequences.

    Each row of ``generated_ids`` is scanned from position 1 onward: step ``j``
    is treated as having produced token ``j + 1``. A segment ends at step ``j``
    when the next token is EOS or ``predicts_m[i][j] == 1``. The segment is
    attributed to the context when ``predicts_z[i][j] == 0`` and to the
    knowledge otherwise. Scanning of a row stops at the first PAD token.

    Args:
        generated_ids: 2-D tensor-like of token ids (anything whose
            ``.tolist()`` yields a list of rows). The input is not modified.
        predicts_z: per-step source predictions, indexed ``[i][j]``
            (0 = context). Presumably aligned with generation steps;
            TODO confirm its exact length relative to ``generated_ids``.
        predicts_m: per-step boundary flags, indexed ``[i][j]`` (1 = boundary).
        eos_token_id: id that terminates a sequence (default 2, as before).
        pad_token_id: id that marks padding after a finished sequence
            (default 1, as before).

    Returns:
        Tuple ``(new_words, new_segments, new_context, new_knowledge,
        context_first, num_context_words, num_knowledge_words)`` summed over
        the batch. ``context_first`` counts rows whose first step is
        predicted as context.
    """
    # .tolist() already builds fresh nested lists, so editing rows below never
    # touches the caller's tensor.
    rows = generated_ids.tolist()
    z_rows = predicts_z.tolist()
    m_rows = predicts_m.tolist()

    new_words = new_segments = new_context = new_knowledge = 0
    context_first = 0
    num_context_words = num_knowledge_words = 0

    for i, row in enumerate(rows):
        z_row = z_rows[i]
        m_row = m_rows[i]

        if z_row[0] == 0:
            context_first += 1

        # Force a trailing EOS so the last segment of an unfinished sequence
        # is still closed. Only positions 1+ are searched: the scan below never
        # inspects position 0 as a generated token. Position 0 is presumably
        # the decoder start token, which for BART is EOS itself and would
        # otherwise disable this guard.
        if row and eos_token_id not in row[1:]:
            row[-1] = eos_token_id

        start = 0
        for j in range(len(row) - 1):
            next_token = row[j + 1]
            if next_token == pad_token_id:
                break

            if next_token == eos_token_id or m_row[j] == 1:
                seg_len = j - start + 1
                if z_row[j] == 0:
                    new_context += 1
                    num_context_words += seg_len
                else:
                    new_knowledge += 1
                    num_knowledge_words += seg_len
                start = j + 1
                new_segments += 1
            new_words += 1

    return new_words, new_segments, new_context, new_knowledge, context_first, num_context_words, num_knowledge_words

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def main(args):
    """Load a fine-tuned checkpoint and evaluate it on the three test splits.

    Prints all hyperparameters and saves them under ``<log>/<exp_name>``. For
    each of ``test_seen``, ``test_unseen`` and ``cmudog`` it decodes the split
    with beam search, writes ``hypothesis ||| reference`` lines to disk, and
    prints segment statistics plus BLEU / Distinct / F1.

    Args:
        args: parsed ``argparse.Namespace``. ``args.cuda`` is (re)assigned here.
    """
    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Select which GPU(s) to use.
    occupy_mem_new(args.gpu_list.split(','), ratio=args.gpu_ratio, num_devices=args.n_device)

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    # Output directory for decoded files and saved hyperparameters.
    out_dir = os.path.join(args.log, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory. This script does not write into it; it is created
    # only to keep the directory layout unchanged.
    os.makedirs(os.path.join(out_dir, 'checkpoints'), exist_ok=True)

    # Build datasets and loaders (all datasets first, then all loaders).
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset begin... | %s " % time_str)

    split_files = {
        'test_seen': args.test_seen_file,
        'test_unseen': args.test_unseen_file,
        'cmudog': args.cmudog_file,
    }
    datasets = {split: WizardDataset_v2(path) for split, path in split_files.items()}
    loaders = {
        split: get_batch_loader(dataset, collate_fn=WizardDataset_v2.collate_fn, batch_size=args.eval_batch_size, is_test=True)
        for split, dataset in datasets.items()
    }

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset end... | %s " % time_str)

    batcher = BartFullBatcher(
        args.min_segment_len, args.split_threshold, args.know_threshold, args.copy_threshold, args.stop_words_path,
        args.stop_words_size, args.merge, args.percentage, args.drop_know, args.infilling, args.mlm, 0.0, 1.0, False, False, 0.0, True, args.knowledge_trunc,
        args.max_source_len, args.max_target_len, args.text_trunc, args.knowledge_trunc, args.bart_config, args.full_knowledge_attn, args.cuda
    )

    model = BartForConditionalGeneration.from_pretrained(args.bart_config)
    model.resize_token_embeddings(len(batcher.tokenizer))
    # NOTE(review): checkpoint location is hard-coded relative to the working directory.
    checkpoint_path = 'reddit/log/{}/checkpoints/model-best/pytorch_model.bin'.format(args.pretrain_file)
    # Load onto the CPU first so a GPU-saved checkpoint still loads when CUDA is
    # disabled or unavailable; the model is moved to the GPU below if requested.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    if args.cuda:
        model.cuda()

    def dev_step(split, global_step):
        """Decode one split, write hypotheses to disk, print and return metrics."""
        if split not in loaders:
            raise ValueError('Unknown split: {}'.format(split))
        test_loader = loaders[split]
        model.eval()

        test_hyp, test_ref = [], []
        count = 0
        # Running sums of the 7-tuple returned by parse_segments.
        totals = [0] * 7
        with torch.no_grad():
            for history_list, knowledge_list, response_list in test_loader:
                dec_args = batcher(history_list, knowledge_list, training=False)
                dec_args['max_length'] = args.max_length
                dec_args['min_length'] = args.min_length
                dec_args['num_beams'] = args.num_beams
                dec_args['repetition_penalty'] = args.repetition_penalty
                dec_args['no_repeat_ngram_size'] = args.no_repeat_ngram_size
                dec_args['do_sample'] = False
                generated_ids, predicts_z, predicts_m = model.generate(**dec_args)
                predict_list = ids_to_clean_text(batcher.tokenizer, generated_ids)
                stats = parse_segments(generated_ids, predicts_z, predicts_m)
                totals = [total + stat for total, stat in zip(totals, stats)]

                test_hyp.extend(predict_list)
                test_ref.extend(response_list)

                count += 1
                if count % 1000 == 0:
                    print(count)

        (num_words, num_segments, num_context, num_knowledge,
         context_first, num_context_words, num_knowledge_words) = totals

        decoded_path = os.path.join(out_dir, '{}-decoded-iter-{}.txt'.format(split, global_step))
        with open(decoded_path, 'w', encoding='utf-8') as f:
            for _hyp, _ref in zip(test_hyp, test_ref):
                f.write("{} ||| {}\n".format(_hyp, _ref))

        b1, b2, b3, b4 = bleu_metric(test_hyp, test_ref)
        d1, d2 = distinct_metric(test_hyp)
        f1 = f1_metric(test_hyp, test_ref)

        n_hyp = len(test_hyp)
        time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("**********************************")
        print("{} results..........".format(split))
        print('hypothesis: ', n_hyp)
        print("Step: %d \t| %s" % (global_step, time_str))
        # Safe ratios: an empty split or one with no segments must not crash.
        print('Avg segment/context/knowledge: {:.4f}/{:.4f}/{:.4f}'.format(
            _safe_ratio(num_segments, n_hyp), _safe_ratio(num_context, n_hyp), _safe_ratio(num_knowledge, n_hyp)))
        print('Context first: {:.4f}'.format(_safe_ratio(context_first, n_hyp)))
        print('Avg segment/context/knowledge length: {:.4f}/{:.4f}/{:.4f}'.format(
            _safe_ratio(num_words, num_segments),
            _safe_ratio(num_context_words, num_context),
            _safe_ratio(num_knowledge_words, num_knowledge)))
        print("BLEU-1/2/3/4: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(b1, b2, b3, b4))
        print("Distinct-1/2: {:.4f}/{:.4f}".format(d1, d2))
        print("F1: {:.4f}".format(f1))
        print("**********************************")

        return {'f1': f1, 'bleu1': b1, 'bleu2': b2, 'bleu3': b3, 'bleu4': b4, 'distinct1': d1, 'distinct2': d2}

    for split in ('test_seen', 'test_unseen', 'cmudog'):
        dev_step(split, 0)


if __name__ == '__main__':
    # Every command-line option as (flag, type, default), in registration
    # order. main() prints and saves all of them; several (e.g. the
    # training-scheme group) are not otherwise read by main().
    cli_options = (
        ('--model_type', str, ''),

        # files
        ('--train_file', str, ''),
        ('--test_seen_file', str, ''),
        ('--test_unseen_file', str, ''),
        ('--cmudog_file', str, ''),

        # training scheme
        ('--batch_size', int, 2),
        ('--eval_batch_size', int, 2),
        ('--num_steps', int, 1000000),
        ('--accum_steps', int, 32),
        ('--lr', float, 5e-5),
        ('--lr2', float, 5e-5),
        ('--clip', float, 2.0),
        ('--schedule', str, 'linear'),
        ('--opt', str, 'adamw'),
        ('--weight_decay', float, 0.01),
        ('--adam_epsilon', float, 1e-8),
        ('--warmup_steps', int, 5000),
        ('--num_epochs', int, 3),
        ('--print_every', int, 10),
        ('--valid_every', int, 1),

        # save
        ('--exp_name', str, '0601_test'),
        ('--log', str, 'wizard_of_wikipedia/log'),
        ('--seed', int, 42),

        # model
        ('--bart_config', str, ''),
        ('--pretrain_file', str, ''),
        ('--min_segment_len', int, 6),
        ('--split_threshold', float, 0.2),
        ('--know_threshold', float, 0.4),
        ('--copy_threshold', float, 1.0),
        ('--stop_words_path', str, 'debug_data/stop_words.txt'),
        ('--stop_words_size', int, 200),
        ('--merge', str2bool, False),
        ('--percentage', float, 0.125),
        ('--drop_know', float, 0.0),
        ('--infilling', str2bool, False),
        ('--mlm', float, 0.0),
        ('--max_source_len', int, 256),
        ('--max_target_len', int, 256),
        ('--knowledge_trunc', int, 64),
        ('--text_trunc', int, 128),
        ('--full_knowledge_attn', str2bool, False),

        # decoding
        ('--max_length', int, 40),
        ('--min_length', int, 0),
        ('--num_beams', int, 4),
        ('--repetition_penalty', float, 1.2),
        ('--no_repeat_ngram_size', int, 3),

        # gpu
        ('--gpu_list', str, '0'),
        ('--gpu_ratio', float, 0.85),
        ('--n_device', int, 7),
        ('--no_cuda', str2bool, False),
    )

    parser = argparse.ArgumentParser(description='Pre-training for Knowledge-Grounded Conversation')
    for flag, arg_type, default in cli_options:
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()

    # Seed every RNG in use (torch CPU, all CUDA devices, numpy, stdlib).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_fn(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)
