#! /usr/bin/env python3
# coding=utf-8


import torch
import numpy as np
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from transformers.utils import check_min_version
from rl_lib.utils import *
import spacy
from spacy.lang.en.stop_words import STOP_WORDS  # import stop words from language data
from datasets import load_metric
from spacy.tokens import Token

# Module-level logger. NOTE(review): `logging` is not imported explicitly in
# this file; presumably it arrives via `from rl_lib.utils import *` -- confirm.
logger = logging.getLogger(__name__)

# Prefer PyTorch's bundled TensorBoard writer and fall back to tensorboardX when
# it cannot be imported. (SummaryWriter is not used in the code visible here.)
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Training hyper-parameters. NOTE(review): not referenced by the functions
# visible in this chunk; presumably used elsewhere in the file -- confirm.
BATCH_SIZE = 16
LEARNING_RATE = 5e-4
MAX_EPOCHES = 10000

# transformers raises at import time if the installed version is older than this.
check_min_version("4.5.0.dev0")

# NOTE(review): identical re-definition of `logger` from above; redundant.
logger = logging.getLogger(__name__)

# A token counts as a stop word if spaCy already flags it (built-in is_stop),
# or if its lowercase form or its lemma is in spaCy's English STOP_WORDS set.
stop_words_getter = lambda token: token.is_stop or token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS
# Registers the getter as the custom attribute `token._.is_stop`; it does NOT
# replace the built-in `token.is_stop`. NOTE(review): compute_metric filters on
# the built-in `token.is_stop`, so this extension appears unused in this chunk.
# spaCy also raises if an extension with this name is already registered
# (e.g. on a module reload) unless force=True is passed.
Token.set_extension('is_stop', getter=stop_words_getter)  # set attribute with getter


# Heavy resources loaded once, at import time: the small English spaCy pipeline
# and several Hugging Face `datasets` metrics. Of these, only `nlp` and
# `bleu_metric` are used by the code visible in this chunk.
nlp = spacy.load("en_core_web_sm")
meteor_metric = load_metric('meteor')
rouge_metric = load_metric('rouge')
bleu_metric = load_metric('bleu')
sacrebleu_metric = load_metric('sacrebleu')
bert_metric = load_metric('bertscore')


def decode_chain_argmax(model, tokenizer, encoded_sent1, config):
    """
    Decode one input greedily and return the per-step token distributions.

    There is no teacher forcing: at every step the top-scoring token (after a
    1.1 repetition penalty) is chosen and fed back as the next decoder input.

    :param model: the LM; only its ``generate`` method is called
    :param tokenizer: tokenizer used to turn generated ids back into text
    :param encoded_sent1: token ids (encoded by tokenizer)
    :param config: config object; ``config.max_target_length`` caps the output
    :return: (token_probs, actions, output_tokens) -- softmax over the trimmed
        per-step scores, the generated ids aligned with them (first batch
        row), and the decoded text of the first generated sequence
    """
    generation_kwargs = dict(
        max_length=config.max_target_length,
        remove_invalid_values=True,
        output_scores=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
        repetition_penalty=1.1,
    )
    generated = model.generate(encoded_sent1, **generation_kwargs)
    sequences = generated['sequences']

    decoded = tokenizer.batch_decode(sequences,
                                     skip_special_tokens=True,
                                     clean_up_tokenization_spaces=True)
    first_text = decoded[0]

    # Strip the two leading ids (</s><s>) and the final </s>.
    step_ids = sequences[..., 2:-1]
    # First-batch-row score vector of every step except the first and the
    # last two.
    step_scores = torch.stack([step[0] for step in generated['scores'][1:-2]])

    id_count = step_ids.size(1)
    if step_scores.size(0) > id_count:
        # Keep only as many score rows as there are generated ids.
        step_scores = step_scores[:id_count]

    # NOTE(review): this compares the number of steps with the vocabulary
    # size, so for any realistic output length the trim below always runs.
    if step_scores.size(0) < step_scores.size(1):
        step_scores = step_scores[:-2]
        step_ids = step_ids[..., :-3]

    return F.softmax(step_scores, dim=1), step_ids[0], first_text


def decode_chain_sampling(model, tokenizer, encoded_sent1, config):
    """
    Decode one input by sampling and return the per-step token distributions.

    Each token is sampled from the temperature-scaled (T=1.2) distribution.
    If sampling-based ``generate`` raises, the input is printed, a warning is
    logged, and the call falls back to ``decode_chain_argmax`` (greedy
    decoding with the same post-processing as below).

    :param model: the LM; only its ``generate`` method is called
    :param tokenizer: tokenizer used to turn generated ids back into text
    :param encoded_sent1: token ids (encoded by tokenizer)
    :param config: config object; ``config.max_target_length`` caps the output
    :return: (token_probs, actions, output_tokens) -- softmax over the trimmed
        per-step scores, the generated ids aligned with them (first batch
        row), and the decoded text of the first generated sequence
    """
    try:
        outputs = model.generate(encoded_sent1,
                                 max_length=config.max_target_length,
                                 do_sample=True,  # sample instead of taking the argmax
                                 temperature=1.2,  # >1 flattens the distribution: more diverse
                                 remove_invalid_values=True,
                                 early_stopping=True,
                                 output_scores=True,
                                 output_hidden_states=True,
                                 return_dict_in_generate=True)
    except Exception as err:
        # Not a bare `except:` -- KeyboardInterrupt/SystemExit must propagate
        # instead of silently triggering a second (greedy) generation.
        print("Exception!", encoded_sent1)
        logger.warning("Sampling generate() failed (%r); falling back to greedy decoding", err)
        # The greedy fallback is exactly decode_chain_argmax's generate call
        # and post-processing, so reuse it instead of duplicating it.
        return decode_chain_argmax(model, tokenizer, encoded_sent1, config)

    output_tokens = tokenizer.batch_decode(outputs['sequences'],
                                           skip_special_tokens=True,
                                           clean_up_tokenization_spaces=True)[0]

    actions = outputs['sequences'][..., 2:-1]  # remove </s><s> at the beginning and </s> at the end
    # First-batch-row score vector of every step except the first and last two.
    scores = torch.stack([score[0] for score in outputs['scores'][1:-2]])
    # actions should be [1, seq_len]; scores should be [seq_len, vocab_size]

    if scores.size(0) > actions.size(1):
        # Keep only as many score rows as there are generated ids.
        scores = scores[:actions.size(1)]

    # NOTE(review): compares the number of steps with the vocabulary size, so
    # for any realistic output length this trim always runs.
    if scores.size(0) < scores.size(1):
        scores = scores[:-2, ...]
        actions = actions[..., :-3]

    token_probs = F.softmax(scores, dim=1)
    return token_probs, actions[0], output_tokens


def bleu_score_original(hyps_token, refs_token, n):
    """
    Corpus BLEU over pre-tokenized text via the module-level ``bleu_metric``.

    Each hypothesis may be scored against one or more references.

    :param hyps_token: one token list per hypothesis, e.g. [hyp_1_tokens, ...]
    :param refs_token: for each hypothesis, a list of reference token lists,
        e.g. [[ref_1a_tokens, ref_1b_tokens], ...]
    :param n: highest n-gram order (passed as ``max_order``)
    :return: the dict produced by ``bleu_metric.compute`` (smoothed BLEU)
    """
    # Passing predictions/references straight to compute() queues them with
    # add_batch and then computes, in one call.
    return bleu_metric.compute(predictions=hyps_token,
                               references=refs_token,
                               smooth=True,
                               max_order=n)


def get_kg_overlap_new(ctx_toks, ref_toks, hype_toks):
    """
    Mean share of reference-matching hypothesis tokens that are not copied.

    For every (context, reference, hypothesis) triple, the "hits" are the
    distinct tokens the hypothesis shares with the reference. A hit is "real"
    when it does not also occur in the context, i.e. it could not simply have
    been copied from the input. The per-example score is
    real_hits / hits, or 0.0 when there are no hits. The three lists are
    paired with ``zip``, so surplus items in a longer list are ignored.

    :param ctx_toks: token lists of the inputs (sent1)
    :param ref_toks: token lists of the references (sent2)
    :param hype_toks: token lists of the model predictions
    :return: the mean per-example score; NaN when there are no examples
    """
    scores = []
    for ctx, ref, hyp in zip(ctx_toks, ref_toks, hype_toks):
        hits = set(ref) & set(hyp)
        # Every hit is already in the reference, so "hit also in context"
        # is the same as "hit in the context/reference overlap".
        real_hits = hits - set(ctx)
        # Exact ratio; the previous `+ 1e-5` denominator skewed every score
        # (e.g. 1.0 came out as 0.99999) just to avoid dividing by zero.
        scores.append(len(real_hits) / len(hits) if hits else 0.0)

    if not scores:
        # Same NaN np.mean([]) would give, without its RuntimeWarning.
        return float('nan')
    return np.mean(scores)


def compute_metric(test_df, batch):
    """
    Print BLEU-1 and "KG Faith" for a table of generations.

    BLEU-1 is computed on lowercased lemmas of all tokens. KG Faith (see
    ``get_kg_overlap_new``) uses lowercased lemmas with stop words and
    punctuation removed.

    :param test_df: DataFrame with text columns 'sent1' (input), 'sent2'
        (reference) and 'predictions' (model output)
    :param batch: batch/step number, only used in the printed header
    :return: dict with the printed values under 'bleu-1' and 'kg_faith'
        (previously None; callers that ignore the result are unaffected)
    """
    hypes = test_df['predictions'].tolist()
    ctxs = test_df['sent1'].tolist()  # for input copy
    refs = test_df['sent2'].tolist()

    print("Batch %d Evaluation Results...." % batch)

    def _lemmas(doc, content_only=False):
        # Lowercased lemmas; with content_only, drop stop words and punctuation.
        # NOTE(review): this uses spaCy's built-in Token.is_stop, not the custom
        # `token._.is_stop` extension registered at module level -- confirm
        # which stop-word definition was intended.
        return [token.lemma_.lower() for token in doc
                if not content_only or (not token.is_stop and not token.is_punct)]

    # Run the spaCy pipeline once per text and reuse each parse for both
    # metrics; previously every hypothesis and reference was parsed twice.
    hype_docs = [nlp(hype) for hype in hypes]
    ref_docs = [nlp(ref) for ref in refs]

    bleu_1 = bleu_score_original([_lemmas(doc) for doc in hype_docs],
                                 [[_lemmas(doc)] for doc in ref_docs],  # one reference each
                                 1)
    print("BLEU-1:", bleu_1['bleu'])

    hypes_toks = [_lemmas(doc, content_only=True) for doc in hype_docs]
    ctxs_toks = [_lemmas(nlp(ctx), content_only=True) for ctx in ctxs]
    refs_toks = [_lemmas(doc, content_only=True) for doc in ref_docs]

    kg_faith = get_kg_overlap_new(ctxs_toks, refs_toks, hypes_toks)
    print("KG Faith:", kg_faith)

    return {'bleu-1': bleu_1['bleu'], 'kg_faith': kg_faith}


def run_eval(test_dataloader, model, tokenizer, batch, device, config):
    """
    Generate predictions for the test set, save them to CSV and print metrics.

    :param test_dataloader: iterable of batch dicts providing 'input_ids',
        'input_tokens' and 'labels_tokens'
    :param model: the LM used for generation
    :param tokenizer: tokenizer used to decode the generated ids
    :param batch: training batch/step number; used in the CSV file name and
        in the printed metric header
    :param device: device the input ids are placed on
    :param config: config object providing ``max_target_length`` and
        ``output_dir``
    :return: None
    """
    frames = []

    for test_batch in tqdm(test_dataloader, desc="Test Iteration", position=0, leave=True):
        # test_batch is a dict:
        #
        # input_ids: (batch, token ids)
        # input_tokens: (batch, string)
        # labels_ids: (batch, list of token ids)
        # labels_tokens: (batch, list of strings)
        # spice_scores: (batch, list of float scores)

        # NOTE(review): only input_ids[0] is generated from, while all of
        # input_tokens/labels_tokens are stored -- presumably the loader
        # yields batches of size 1; confirm.
        # NOTE(review): top_p only takes effect when do_sample=True, so this
        # call presumably decodes greedily -- confirm intent.
        outputs = model.generate(torch.tensor(test_batch['input_ids'][0], device=device),
                                 max_length=config.max_target_length,
                                 return_dict_in_generate=True,
                                 top_p=0.92,
                                 repetition_penalty=1.1)

        output_tokens = tokenizer.batch_decode(outputs['sequences'],
                                               skip_special_tokens=True,
                                               clean_up_tokenization_spaces=True)

        temp_dict = {'sent1': test_batch['input_tokens'],
                     'sent2': test_batch['labels_tokens'],
                     'predictions': output_tokens}
        frames.append(pd.DataFrame.from_dict(temp_dict))

    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, and
    # appending inside the loop re-copied the accumulated table every time.
    # Concatenate once; an empty loader still yields the three columns.
    if frames:
        test_df = pd.concat(frames, ignore_index=True)
    else:
        test_df = pd.DataFrame(columns=['sent1', 'sent2', 'predictions'])

    test_df.to_csv(config.output_dir + '/rl_generations_' + str(batch) + '.csv', index=False)
    compute_metric(test_df, batch)
