import argparse
import os

import torch
import numpy as np
import random
from str2bool import str2bool

from datetime import datetime

from utils import (
    occupy_mem_new,
    save_hparams,
    RedditSentiDataset,
    WizardDataset_v2,
    BartSentiBatcher,
    get_batch_loader,
)
from metrics import (
    bleu_metric,
    distinct_metric,
    f1_metric
)
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.optimization import (
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup
)

from transformers import RobertaForSequenceClassification, RobertaTokenizer

def is_batch_valid(batch, required_labels=(0, 1, 2, 3)):
    """Return True if the batch's segment labels cover every required label.

    ``batch['labels_z']`` is flattened with ``reshape(-1).tolist()`` and the
    batch is accepted only when each value in ``required_labels`` occurs at
    least once. The default ``(0, 1, 2, 3)`` reproduces the original check;
    pass a different tuple (e.g. ``(0, 2, 3)``) to relax it without editing
    the function.

    Args:
        batch: mapping holding a tensor under ``'labels_z'``.
        required_labels: iterable of labels that must all be present.

    Returns:
        bool: whether every required label appears in ``labels_z``.
    """
    present = set(batch['labels_z'].reshape(-1).tolist())
    return set(required_labels).issubset(present)


def ids_to_clean_text(tokenizer, generated_ids):
    """Decode every token-id sequence in ``generated_ids`` to text.

    Each sequence goes through ``tokenizer.decode`` with special tokens
    skipped and tokenization-space cleanup disabled.

    Args:
        tokenizer: object exposing a ``decode(ids, skip_special_tokens=...,
            clean_up_tokenization_spaces=...)`` method.
        generated_ids: sequence of token-id sequences (callers pass lists).

    Returns:
        list[str]: one decoded string per input sequence, in input order.
    """
    return [
        tokenizer.decode(seq_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for seq_ids in generated_ids
    ]

def parse_segments(generated_ids, predicts_z, predicts_m):
    """Count and collect the typed segments of a batch of generated responses.

    Every generated sequence is scanned step by step. A segment closes at
    step ``j`` when the next token (``generated_ids[i][j + 1]``) is id 2 or
    when ``predicts_m[i][j] == 1``. The segment's type is ``predicts_z[i][j]``:
    0 -> context, 1 -> knowledge, 2 -> positive, 3 -> negative (names taken
    from the counters below). Scanning of a sequence stops at the first id 1.

    NOTE(review): ids 1 and 2 are presumably the BART pad and ``</s>`` ids —
    confirm against the tokenizer in use.

    Args:
        generated_ids: 2-D tensor of token ids; converted with ``.tolist()``,
            so the caller's tensor is never modified.
        predicts_z: per-step segment-type predictions, indexed ``[i][j]``.
        predicts_m: per-step boundary flags, indexed ``[i][j]``; 1 closes a
            segment.

    Returns:
        dict with:
          - 'new_words': number of scanned steps before the first id 1;
          - 'new_segments' and per-type segment counts ('new_context',
            'new_knowledge', 'new_positive', 'new_negative');
          - 'context_first': number of sequences whose step-0 type is 0;
          - 'num_*_words': summed step length of segments of each type;
          - 'knowledge_ids': token-id slices of every type 1/2/3 segment.

    Raises:
        ValueError: if a closed segment has a type outside {0, 1, 2, 3}.
    """
    new_words, new_segments, new_context, new_knowledge = 0, 0, 0, 0
    new_positive, new_negative = 0, 0
    context_first = 0
    num_context_words, num_knowledge_words = 0, 0
    num_positive_words, num_negative_words = 0, 0
    generated_ids = generated_ids.tolist().copy()
    predicts_z = predicts_z.tolist()
    predicts_m = predicts_m.tolist()
    knowledge_ids = []
    for i in range(len(generated_ids)):
        if predicts_z[i][0] == 0:
            context_first += 1

        # If the sequence contains no id 2, force one at the end so the final
        # segment is closed by the scan below.
        # NOTE(review): if sequences start with a decoder-start id of 2 (BART's
        # usual default), this condition is never true — confirm.
        if all([id != 2 for id in generated_ids[i]]):
            generated_ids[i][-1] = 2

        start = 0  # index of the first step in the currently open segment
        for j in range(len(generated_ids[i]) - 1):
            # Stop at the first id 1 (presumably padding).
            if generated_ids[i][j + 1] == 1:
                break
            # Close the open segment at step j; its length is j - start + 1.
            if generated_ids[i][j + 1] == 2 or predicts_m[i][j] == 1:
                if predicts_z[i][j] == 0:
                    new_context += 1
                    num_context_words += (j - start + 1)
                elif predicts_z[i][j] == 1:
                    new_knowledge += 1
                    num_knowledge_words += (j - start + 1)
                    # NOTE(review): slice covers positions start..j and excludes
                    # token j + 1, the one that triggered the boundary — confirm
                    # this matches how predicts_z/predicts_m are aligned.
                    knowledge_ids.append(generated_ids[i][start:j+1].copy())
                elif predicts_z[i][j] == 2:
                    new_positive += 1
                    num_positive_words += (j - start + 1)
                    knowledge_ids.append(generated_ids[i][start:j + 1].copy())
                elif predicts_z[i][j] == 3:
                    new_negative += 1
                    num_negative_words += (j - start + 1)
                    knowledge_ids.append(generated_ids[i][start:j + 1].copy())
                else:
                    raise ValueError
                start = j + 1
                new_segments += 1
            new_words += 1
    return {
        'new_words': new_words,
        'new_segments': new_segments,
        'new_context': new_context,
        'new_knowledge': new_knowledge,
        'new_positive': new_positive,
        'new_negative': new_negative,
        'context_first': context_first,
        'num_context_words': num_context_words,
        'num_knowledge_words': num_knowledge_words,
        'num_positive_words': num_positive_words,
        'num_negative_words': num_negative_words,
        'knowledge_ids': knowledge_ids,
    }

def _build_scheduler(name, optimizer, num_warmup_steps, num_training_steps):
    """Create the learning-rate scheduler selected by ``--schedule``.

    Args:
        name: 'linear', 'cosine', 'cosine_with_restarts', 'polynomial',
            'constant' or 'constant_with_warmup'.
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: warmup length (unused by 'constant').
        num_training_steps: horizon of the decaying schedules (unused by the
            constant ones).

    Returns:
        The scheduler built by the matching ``transformers.optimization``
        factory.

    Raises:
        ValueError: if ``name`` is not a supported schedule.
    """
    if name == 'linear':
        return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    if name == 'cosine':
        return get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    if name == 'cosine_with_restarts':
        return get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    if name == 'polynomial':
        return get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    if name == 'constant':
        return get_constant_schedule(optimizer)
    if name == 'constant_with_warmup':
        return get_constant_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps)
    raise ValueError(
        "Unknown schedule '{}'; expected one of: linear, cosine, cosine_with_restarts, "
        "polynomial, constant, constant_with_warmup".format(name)
    )


def main(args):
    """Train the sentiment-adapter BART model and periodically evaluate it.

    Builds the datasets and loaders, the BART generator selected by
    ``args.model_type`` and a RoBERTa sequence classifier that labels the
    knowledge/sentiment spans found in generated responses. Then it runs
    ``args.num_steps`` optimisation steps:
      - the first ``args.adapter_pretrain_steps`` steps update only the adapter
        optimizer;
      - later steps update the other two optimizers.
    Whenever test-seen F1 improves, the BART model is saved to
    ``<log>/<exp_name>/checkpoints/model-best``.

    Args:
        args: parsed command-line ``argparse.Namespace``.
    """
    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Selecting which GPU to use
    occupy_mem_new(args.gpu_list.split(','), ratio=args.gpu_ratio, num_devices=args.n_device)

    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device('cuda' if args.cuda else 'cpu')

    # Output directory for models and summaries
    out_dir = os.path.join(args.log, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory
    checkpoint_dir = os.path.join(out_dir, 'checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build dataset
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset begin... | %s " % time_str)

    train_dataset = RedditSentiDataset(args.train_file)
    test_seen_dataset = WizardDataset_v2(args.test_seen_file)
    test_unseen_dataset = WizardDataset_v2(args.test_unseen_file)
    cmudog_dataset = WizardDataset_v2(args.cmudog_file)

    train_loader = get_batch_loader(train_dataset, collate_fn=RedditSentiDataset.collate_fn, batch_size=args.batch_size, is_test=False)
    test_seen_loader = get_batch_loader(test_seen_dataset, collate_fn=WizardDataset_v2.collate_fn, batch_size=args.eval_batch_size, is_test=True)
    test_unseen_loader = get_batch_loader(test_unseen_dataset, collate_fn=WizardDataset_v2.collate_fn, batch_size=args.eval_batch_size, is_test=True)
    cmudog_loader = get_batch_loader(cmudog_dataset, collate_fn=WizardDataset_v2.collate_fn, batch_size=args.eval_batch_size, is_test=True)

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("Create training dataset end... | %s " % time_str)

    batcher = BartSentiBatcher(
        args.min_segment_len, args.split_threshold, args.know_threshold, args.stop_words_path, args.stop_words_size,
        args.percentage, args.test_knowledge_truncate, args.test_knowledge_num, args.senti_threshold, args.senti_percentage,
        args.max_source_len, args.max_target_len, args.text_trunc, args.knowledge_trunc, args.bart_config, args.cuda
    )
    if args.model_type == 'pos':
        from model.modeling_senti_pos import BartForConditionalGeneration
    elif args.model_type == 'neg':
        from model.modeling_senti_neg import BartForConditionalGeneration
    else:
        from model.modeling_senti import BartForConditionalGeneration

    model = BartForConditionalGeneration.from_pretrained(args.pretrain_file)
    model.resize_token_embeddings(len(batcher.tokenizer))
    if args.cuda:
        model.cuda()

    roberta_tokenizer = RobertaTokenizer.from_pretrained(args.bert_config, do_lower_case=True)
    roberta_model = RobertaForSequenceClassification.from_pretrained(args.sst_pretrain_file)
    roberta_model.resize_token_embeddings(len(roberta_tokenizer))
    if args.cuda:
        roberta_model.cuda()
    # The classifier is only used for inference in dev_step.
    roberta_model.eval()

    # optimizer: parameters whose name contains 'model' (excluding adapters),
    # split into weight-decay / no-decay groups.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if 'model' in n and not any(nd in n for nd in no_decay) and 'adapter' not in n],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if 'model' in n and any(nd in n for nd in no_decay) and 'adapter' not in n],
            "weight_decay": 0.0,
        },
    ]
    # NOTE(review): the loop below runs args.num_steps steps, independent of
    # total_steps; decaying schedules bottom out once stepped total_steps times.
    total_steps = args.num_epochs * (len(train_dataset) / (args.batch_size * args.accum_steps))

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
    # optimizer2: parameters whose name lacks 'model'; optimizer3: adapters.
    # NOTE(review): an adapter parameter whose name lacks 'model' would be in
    # both optimizer2 and optimizer3 — confirm the parameter naming.
    optimizer2 = AdamW([p for n, p in model.named_parameters() if 'model' not in n], lr=args.lr2, eps=args.adam_epsilon)
    optimizer3 = AdamW([p for n, p in model.named_parameters() if 'adapter' in n], lr=args.lr3, eps=args.adam_epsilon)

    scheduler = _build_scheduler(args.schedule, optimizer, args.warmup_steps, total_steps)
    scheduler2 = _build_scheduler(args.schedule, optimizer2, args.warmup_steps, total_steps)
    scheduler3 = _build_scheduler(args.schedule, optimizer3, args.warmup_steps, total_steps)

    def train_step(global_step, train_adapter=True):
        """Run one optimisation step over ``args.accum_steps`` valid batches.

        With ``train_adapter`` only the positive/negative losses are
        back-propagated and only ``optimizer3`` steps. Otherwise the LM, KL and
        sentiment losses are back-propagated, and ``optimizer`` and
        ``optimizer2`` step.
        """
        lm_loss_total, pos_loss_total, neg_loss_total, kl_abs_total, kl_mask_total = 0.0, 0.0, 0.0, 0.0, 0.0
        batches = []
        # Draw batches until accum_steps of them pass is_batch_valid.
        while True:
            history_list, knowledge_list, tree_list, check_list, segment_dict_list = next(train_loader)
            batch = batcher(history_list, knowledge_list, tree_list, check_list, segment_dict_list, training=True)
            if is_batch_valid(batch):
                batches.append(batch)
                if len(batches) >= args.accum_steps:
                    break

        if train_adapter:
            # optimize adapter
            for batch in batches:
                model.train()
                _, positive_loss, negative_loss, _, _ = model(**batch)
                positive_loss = positive_loss / args.accum_steps
                negative_loss = negative_loss / args.accum_steps
                loss = positive_loss + negative_loss
                loss.backward()
                pos_loss_total += positive_loss.item()
                neg_loss_total += negative_loss.item()
            torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.clip)
            optimizer3.step()
            scheduler3.step()
            # Reset gradients of the other BART parameters too: only the
            # sentiment adapters are updated in this phase.
            model.zero_grad()

        else:
            # optimize bart and prior modules
            for batch in batches:
                model.train()
                lm_loss, positive_loss, negative_loss, kl_abs_loss, kl_mask_loss = model(**batch)
                lm_loss = lm_loss / args.accum_steps
                kl_abs_loss = kl_abs_loss / args.accum_steps
                positive_loss = positive_loss / args.accum_steps
                negative_loss = negative_loss / args.accum_steps
                kl_mask_loss = kl_mask_loss / args.accum_steps
                loss = lm_loss + kl_abs_loss + kl_mask_loss + positive_loss + negative_loss
                loss.backward()
                lm_loss_total += lm_loss.item()
                pos_loss_total += positive_loss.item()
                neg_loss_total += negative_loss.item()
                kl_abs_total += kl_abs_loss.item()
                kl_mask_total += kl_mask_loss.item()
            torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            optimizer2.step()
            scheduler2.step()
            optimizer2.zero_grad()
            model.zero_grad()

        if global_step % args.print_every == 0 and global_step != 0:
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            # get_last_lr() is the supported accessor outside scheduler.step();
            # get_lr() warns there.
            print("Step: %d \t| lm_loss: %.3f \t| pos_loss: %.3f \t| neg_loss: %.3f \t| module_loss: %.3f \t| mask_loss: %.3f \t| lr: %.8f \t| %s" % (
                global_step, lm_loss_total, pos_loss_total, neg_loss_total, kl_abs_total, kl_mask_total, scheduler.get_last_lr()[0], time_str
            ))

    def dev_step(split, global_step):
        """Decode one evaluation split, write the outputs and print metrics.

        Writes ``<split>-decoded-iter-<step>.txt`` (hypothesis ||| reference)
        and ``...-spans.txt`` (span ||| RoBERTa label) into ``out_dir``.

        Returns:
            dict with 'f1', 'bleu1'..'bleu4', 'distinct1' and 'distinct2'.

        Raises:
            ValueError: if ``split`` is not 'test_seen', 'test_unseen' or
                'cmudog'.
        """
        if split == 'test_seen':
            test_loader = test_seen_loader
        elif split == 'test_unseen':
            test_loader = test_unseen_loader
        elif split == 'cmudog':
            test_loader = cmudog_loader
        else:
            raise ValueError("Unknown split: {}".format(split))
        model.eval()

        test_hyp, test_ref = [], []
        test_span, test_label = [], []
        count = 0

        num_words, num_segments, num_context, num_knowledge, num_positive, num_negative = 0, 0, 0, 0, 0, 0
        context_first = 0
        num_context_words, num_knowledge_words, num_positive_words, num_negative_words = 0, 0, 0, 0
        with torch.no_grad():
            for history_list, knowledge_list, response_list in test_loader:
                dec_args = batcher(history_list, knowledge_list, training=False)
                dec_args['max_length'] = args.max_length
                dec_args['min_length'] = args.min_length
                dec_args['num_beams'] = args.num_beams
                dec_args['repetition_penalty'] = args.repetition_penalty
                dec_args['no_repeat_ngram_size'] = args.no_repeat_ngram_size
                dec_args['do_sample'] = False
                generated_ids, predicts_z, predicts_m = model.generate(**dec_args)

                predict_list = ids_to_clean_text(batcher.tokenizer, generated_ids.tolist())
                parse_results = parse_segments(generated_ids, predicts_z, predicts_m)
                num_words += parse_results['new_words']
                num_segments += parse_results['new_segments']
                num_context += parse_results['new_context']
                num_knowledge += parse_results['new_knowledge']
                num_positive += parse_results['new_positive']
                num_negative += parse_results['new_negative']
                context_first += parse_results['context_first']
                num_context_words += parse_results['num_context_words']
                num_knowledge_words += parse_results['num_knowledge_words']
                num_positive_words += parse_results['num_positive_words']
                num_negative_words += parse_results['num_negative_words']

                test_hyp.extend(predict_list)
                test_ref.extend(response_list)

                # Label every knowledge/sentiment span with the RoBERTa classifier.
                if len(parse_results['knowledge_ids']) > 0:
                    knowledge_spans = ids_to_clean_text(batcher.tokenizer, parse_results['knowledge_ids'])
                    batch = roberta_tokenizer(text=knowledge_spans, padding=True, truncation=True, max_length=16, return_tensors='pt')
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    logits = roberta_model(return_dict=True, input_ids=input_ids, attention_mask=attention_mask)['logits']
                    knowledge_labels = torch.argmax(logits, dim=1).tolist()
                    test_span.extend(knowledge_spans)
                    test_label.extend(knowledge_labels)

                count += 1
                if count % 1000 == 0:
                    print(count)

        with open(os.path.join(out_dir, '{}-decoded-iter-{}.txt'.format(split, global_step)), 'w', encoding='utf-8') as f:
            for _hyp, _ref in zip(test_hyp, test_ref):
                f.write("{} ||| {}\n".format(_hyp, _ref))
        with open(os.path.join(out_dir, '{}-decoded-iter-{}-spans.txt'.format(split, global_step)), 'w', encoding='utf-8') as f:
            for _span, _label in zip(test_span, test_label):
                f.write("{} ||| {}\n".format(_span, _label))

        b1, b2, b3, b4 = bleu_metric(test_hyp, test_ref)
        d1, d2 = distinct_metric(test_hyp)
        f1 = f1_metric(test_hyp, test_ref)
        # Sentiment accuracy: fraction of spans classified as the target label
        # (1 for 'pos' models, 0 for 'neg' models). Other model types have no
        # target, so this is 1.0 (NaN if no spans were produced).
        if 'pos' in args.model_type:
            acc = np.mean([pred == 1 for pred in test_label])
        elif 'neg' in args.model_type:
            acc = np.mean([pred == 0 for pred in test_label])
        else:
            acc = np.mean([1 for pred in test_label])

        time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("**********************************")
        print("{} results..........".format(split))
        print('hypothesis: ', len(test_hyp))
        print('knowledge spans: ', len(test_span))
        print("Step: %d \t| %s" % (global_step, time_str))
        print('Avg context/knowledge/positive/negative: {:.4f}/{:.4f}/{:.4f}/{:.4f}'.format(
            num_context * 1. / len(test_hyp),
            num_knowledge * 1. / len(test_hyp),
            num_positive * 1. / len(test_hyp),
            num_negative * 1. / len(test_hyp)
        ))
        print('Context first: {:.4f}'.format(context_first * 1. / len(test_hyp)))
        print('Avg context/knowledge/positive/negative length: {:.4f}/{:.4f}/{:.4f}/{:.4f}'.format(
            num_context_words * 1. / (num_context + 1e-5),
            num_knowledge_words * 1. / (num_knowledge + 1e-5),
            num_positive_words * 1. / (num_positive + 1e-5),
            num_negative_words * 1. / (num_negative + 1e-5)
        ))
        print('Sentiment Acc: {:.4f}'.format(acc))
        print("BLEU-1/2/3/4: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(b1, b2, b3, b4))
        print("Distinct-1/2: {:.4f}/{:.4f}".format(d1, d2))
        print("F1: {:.4f}".format(f1))
        print("**********************************")

        return {'f1': f1, 'bleu1': b1, 'bleu2': b2, 'bleu3': b3, 'bleu4': b4, 'distinct1': d1, 'distinct2': d2}

    best_f1 = 0.
    for i in range(args.num_steps):
        if i < args.adapter_pretrain_steps:
            train_step(i + 1, train_adapter=True)
        else:
            train_step(i + 1, train_adapter=False)

        if (i + 1) % args.valid_every == 0 and (i + 1) >= args.start_eval_steps:
            # Only test_seen drives model selection; the test_unseen / cmudog
            # splits can be evaluated the same way via dev_step.
            test_seen_results = dev_step("test_seen", i + 1)

            if test_seen_results["f1"] > best_f1:
                best_f1 = test_seen_results["f1"]

                save_path = "{}-best".format(checkpoint_prefix)
                os.makedirs(save_path, exist_ok=True)
                model_to_save = model.module if hasattr(model, "module") else model
                model_to_save.save_pretrained(save_path)

                print("Saved model checkpoint to {}\n".format(save_path))


def _build_parser():
    """Return the command-line parser for this training script.

    Options are registered in a fixed order, grouped as: model type, data
    files, training scheme, saving, model/segmentation settings, decoding
    settings and GPU settings.
    """
    parser = argparse.ArgumentParser(
        description='Pre-training for Knowledge-Grounded Conversation'
    )
    add = parser.add_argument

    add('--model_type', type=str, default='')

    # files
    add('--train_file', type=str, default='')
    add('--test_seen_file', type=str, default='')
    add('--test_unseen_file', type=str, default='')
    add('--cmudog_file', type=str, default='')

    # training scheme
    add('--batch_size', type=int, default=2)
    add('--eval_batch_size', type=int, default=2)
    add('--num_steps', type=int, default=1000000)
    add('--accum_steps', type=int, default=32)
    add('--lr', type=float, default=5e-5)
    add('--lr2', type=float, default=5e-5)
    add('--lr3', type=float, default=5e-5)
    add('--clip', type=float, default=2.0)
    add('--schedule', type=str, default='linear')
    add('--adapter_pretrain_steps', type=int, default=500)

    add('--weight_decay', type=float, default=0.01)
    add('--adam_epsilon', type=float, default=1e-8)
    add('--warmup_steps', type=int, default=5000)
    add('--num_epochs', type=int, default=3)

    add('--print_every', type=int, default=10)
    add('--valid_every', type=int, default=1)
    add('--start_eval_steps', type=int, default=500)

    # save
    add('--exp_name', type=str, default='0601_test')
    add('--log', type=str, default='wizard_of_wikipedia/log')
    add('--seed', type=int, default=42)

    # model
    add('--bart_config', type=str, default='/home2/xxx/Data/pretrain-models/facebook/bart-base')
    add('--pretrain_file', type=str, default='')
    add('--bert_config', type=str, default='/home2/xxx/Data/pretrain-models/roberta-base')
    add('--sst_pretrain_file', type=str, default='sst2/log/1231_sst_block32_1/checkpoints/model-best')

    add('--min_segment_len', type=int, default=6)
    add('--split_threshold', type=float, default=0.2)
    add('--know_threshold', type=float, default=0.4)
    add('--stop_words_path', type=str, default='debug_data/stop_words.txt')
    add('--stop_words_size', type=int, default=200)
    add('--percentage', type=float, default=0.125)
    add('--test_knowledge_truncate', type=int, default=64)
    add('--test_knowledge_num', type=int, default=64)
    add('--senti_threshold', type=float, default=0.3)
    add('--senti_percentage', type=float, default=0.5)

    add('--max_source_len', type=int, default=256)
    add('--max_target_len', type=int, default=256)
    add('--knowledge_trunc', type=int, default=64)
    add('--text_trunc', type=int, default=128)

    # decoding
    add('--max_length', type=int, default=40)
    add('--min_length', type=int, default=0)
    add('--num_beams', type=int, default=4)
    add('--repetition_penalty', type=float, default=1.2)
    add('--no_repeat_ngram_size', type=int, default=3)

    # gpu
    add('--gpu_list', type=str, default='0')
    add('--gpu_ratio', type=float, default=0.85)
    add('--n_device', type=int, default=7)
    add('--no_cuda', type=str2bool, default=False)

    return parser


def _seed_everything(seed):
    """Seed torch (CPU and all CUDA devices), numpy and ``random``.

    Also sets cuDNN to deterministic mode.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    cli_args = _build_parser().parse_args()
    _seed_everything(cli_args.seed)
    main(cli_args)
