#! /usr/bin/env python3
# coding=utf-8



import math
import torch
import numpy as np
import pandas as pd
import spacy
import kenlm
import sng_parser
from nltk.tokenize import word_tokenize
from datasets import load_metric
from nltk.util import ngrams
from pycocoevalcap.spice.spice import Spice
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from spacy.tokens import Token
from spacy.lang.en.stop_words import STOP_WORDS  # import stop words from language data
from transformers import GPT2LMHeadModel
from transformers import (
    GPT2Tokenizer, GPT2Config,
)

# --- Module-level model / metric setup ---------------------------------------
# Everything below runs at import time: it loads GPT-2 medium, a spaCy
# pipeline, several HuggingFace metric scripts, a KenLM model file
# ('allsides.bin'), a SPICE scorer and a RoBERTa classifier from
# './allsides_classifier'. Importing this module is therefore slow and needs
# those files to be present relative to the working directory.

# GPT-2 medium, used by compute_ppl() to score fluency. The model is not moved
# to a GPU here; compute_ppl() feeds it CPU tensors.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
config = GPT2Config.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium', config=config)
model.eval()  # inference mode (disables dropout)


# Treat a token as a stop word if spaCy already flags it, or if its lowercase
# form or its lemma is in spaCy's English STOP_WORDS list.
# NOTE(review): spaCy exposes custom extensions under the `._` namespace, so
# this getter is reachable as `token._.is_stop`; code that reads the built-in
# `token.is_stop` (e.g. get_kg_overlap) does not go through it. set_extension
# also raises if 'is_stop' is already registered (e.g. when this module is
# re-executed via importlib.reload) unless force=True is passed.
stop_words_getter = lambda token: token.is_stop or token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS
Token.set_extension('is_stop', getter=stop_words_getter)  # set attribute with getter


# spaCy English pipeline: tokenization, POS tags and lemmas for the metrics below.
nlp = spacy.load("en_core_web_sm")
# HuggingFace `datasets` metric objects; each wrapper below feeds one of them
# with add_batch() and then calls compute().
# NOTE(review): `datasets.load_metric` is deprecated in recent `datasets`
# releases in favour of the separate `evaluate` library.
meteor_metric = load_metric('meteor')
rouge_metric = load_metric('rouge')
bleu_metric = load_metric('bleu')
sacrebleu_metric = load_metric('sacrebleu')
bert_metric = load_metric('bertscore')


# KenLM n-gram language model used by compute_ppl_kenlm().
# Use the commented-out 'yelp.bin' model when evaluating the Yelp data instead.
# ken_lm = kenlm.Model('yelp.bin')
ken_lm = kenlm.Model('allsides.bin')

# Classifier configuration: binary label set [0, 1] used by compute_acc().
model_args = ClassificationArgs()
model_args.labels_list = [0, 1]

# SPICE scorer used by compute_spice().
scorer = Spice()


# Style classifier used by compute_acc(). The commented-out variant is the
# Yelp classifier; the active one is the AllSides classifier.
# class_model = ClassificationModel(
#     "roberta", "./yelp_classifier", cuda_device=1, args=model_args
# )

# NOTE(review): cuda_device=1 is hard-coded, so this requires a second GPU.
class_model = ClassificationModel(
    "roberta", "./allsides_classifier", cuda_device=1, args=model_args
)


def compute_ppl(texts):
    """
    Return the mean GPT-2 perplexity of the given texts. Lower is better.

    Each text is scored independently with the module-level GPT-2 model, and
    the per-text perplexities are averaged. This is an unweighted mean over
    texts, not a token-weighted corpus perplexity.

    Texts that are empty after stripping are skipped, as are texts that encode
    to a single GPT-2 token. The LM loss is computed on labels shifted by one
    position, so a one-token input has no prediction target and would produce
    a NaN loss.

    :param texts: iterable of texts; each item is converted with ``str()`` and
        stripped before scoring.
    :return: mean perplexity, or 0 when no text could be scored.
    """
    # Put inputs wherever the model lives instead of assuming the CPU.
    device = next(model.parameters()).device
    ppls = []
    for text in texts:
        text = str(text).strip()
        if not text:
            # Skip empty entries rather than aborting the whole evaluation.
            continue
        input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(device)
        if input_ids.size(1) < 2:
            # No shifted target exists for a single token -> loss would be NaN.
            continue
        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)
        loss = outputs[0]
        ppls.append(math.exp(loss.item()))
    if not ppls:
        return 0
    return np.mean(ppls)


def get_kg_overlap(sent1s, sent2s, sents):
    """
    Score, per example, the share of a sentence's content words that are new.

    Content words of a text are the lowercased lemmas of its spaCy tokens
    tagged ``PROPN`` or ``NOUN`` that are not flagged by ``token.is_stop``.
    For each aligned triple ``(sent1, sent2, sent)`` the score is::

        |distinct content lemmas of sent absent from both sent1 and sent2|
        / (number of content tokens in sent + 1e-5)

    The numerator counts distinct lemmas, while the denominator counts every
    content token of ``sent``, duplicates included. A sentence with no content
    words scores 0.

    :param sent1s: first comparison side (``__main__`` passes the ``sent1`` column)
    :param sent2s: second comparison side (``__main__`` passes the references)
    :param sents: sentences being scored (``__main__`` passes the predictions)
    :return: mean of the per-example scores (NaN when the inputs are empty)
    """
    def content_lemmas(text):
        # NOTE: `tok.is_stop` is spaCy's built-in flag, not the custom
        # `tok._.is_stop` extension registered at module level.
        return [tok.lemma_.lower() for tok in nlp(text)
                if tok.pos_ in ['PROPN', 'NOUN'] and not tok.is_stop]

    per_example = []
    for first, second, candidate in zip(sent1s, sent2s, sents):
        first_lemmas = content_lemmas(first)
        second_lemmas = content_lemmas(second)
        candidate_lemmas = content_lemmas(candidate)

        unseen = set(candidate_lemmas) - set(first_lemmas) - set(second_lemmas)
        per_example.append(len(unseen) / (len(candidate_lemmas) + 1e-5))
    return np.mean(per_example)


def compute_spice(hypes, refs):
    """
    Compute SPICE with the module-level pycocoevalcap ``Spice`` scorer.

    The scorer's signature is ``compute_score(gts, res)``: references first,
    candidates second. The arguments are passed in that order. The scorer
    requires exactly one candidate per id and at least one reference per id.

    :param hypes: dict mapping id -> one-element list holding the candidate text
    :param refs: dict mapping the same ids -> list of one or more reference texts
    :return: ``(corpus_score, per_item_scores)`` as returned by the scorer
    """
    score, scores = scorer.compute_score(refs, hypes)
    return score, scores


def meteor_score(hyps, refs):
    """
    Compute corpus METEOR with the HuggingFace ``meteor`` metric.

    The metric runs with its default parameters:
    - alpha = 0.9: relative weight of precision vs. recall,
    - beta = 3: shape of the fragmentation penalty,
    - gamma = 0.5: relative weight of the fragmentation penalty.

    :param hyps: predictions, one space-separated token string per example
    :param refs: references, one space-separated token string per prediction,
        paired with ``hyps`` by position
    :return: the result of ``meteor_metric.compute()``
    :raises AssertionError: if ``hyps`` and ``refs`` differ in length
    """
    assert len(hyps) == len(refs)
    meteor_metric.add_batch(references=refs, predictions=hyps)
    result = meteor_metric.compute()
    return result


def rouge_score(hyps, refs):
    """
    Compute ROUGE with the HuggingFace ``rouge`` metric, using the Porter stemmer.

    ``compute`` runs with its default ``rouge_types``. Valid type names are:
    - ``"rouge{n}"`` (e.g. ``"rouge1"``, ``"rouge2"``): n-gram overlap,
    - ``"rougeL"``: longest common subsequence,
    - ``"rougeLsum"``: LCS computed over newline-separated sentences.

    For the rougeL / rougeLsum distinction, see
    https://github.com/huggingface/datasets/issues/617.

    :param hyps: predictions, one space-separated token string per example
    :param refs: references, one space-separated token string per prediction
    :return: the result of ``rouge_metric.compute(use_stemmer=True)``
    """
    rouge_metric.add_batch(references=refs, predictions=hyps)
    result = rouge_metric.compute(use_stemmer=True)
    return result


def bleu_score_original(hyps_token, refs_token, n):
    """
    Compute smoothed corpus BLEU up to order ``n`` on pre-tokenized text.

    Each hypothesis may be scored against several references ("1 vs n").

    :param hyps_token: list of hypotheses, each a list of tokens:
        ``[hyp_1_tokens, hyp_2_tokens, ...]``
    :param refs_token: one entry per hypothesis; each entry is a list of
        references, and each reference is a list of tokens:
        ``[[ref_1a_tokens, ref_1b_tokens], [ref_2a_tokens], ...]``
    :param n: highest n-gram order (forwarded as ``max_order``)
    :return: the dict returned by ``bleu_metric.compute``; the score is under ``'bleu'``
    """
    # TODO
    bleu_metric.add_batch(references=refs_token, predictions=hyps_token)
    result = bleu_metric.compute(max_order=n, smooth=True)
    return result


def bleu_score_sacre(hyps, refs):
    """
    Compute lowercased SacreBLEU with a single reference per hypothesis.

    Options accepted by ``sacrebleu_metric.compute`` (only ``lowercase`` is set here):
    - ``smooth``: smoothing method,
    - ``smooth_value``: the floor used by 'floor' smoothing,
    - ``force``: ignore data that looks already tokenized,
    - ``lowercase``: lowercase the data,
    - ``tokenize``: the tokenizer to use.

    :param hyps: list of hypothesis strings
    :param refs: list of reference strings, one per hypothesis
    :return: the result of ``sacrebleu_metric.compute(lowercase=True)``
    """
    # SacreBLEU expects a list of references for every prediction.
    nested_refs = [[reference] for reference in refs]
    sacrebleu_metric.add_batch(references=nested_refs, predictions=hyps)
    result = sacrebleu_metric.compute(lowercase=True)
    return result


def bert_score(hyps, refs, *, lang='en', device="cuda:1", rescale_with_baseline=False):
    """
    Compute BERTScore with the HuggingFace ``bertscore`` metric.

    Defaults: ``lang='en'``, ``device="cuda:1"``, ``rescale_with_baseline=False``.
    On machines without a second GPU, pass e.g. ``device="cpu"`` or
    ``device="cuda:0"``.

    Other options accepted by ``bert_metric.compute`` are not exposed here:
    - ``model_type``: defaults to the suggested model for ``lang``; at least
      one of ``model_type`` / ``lang`` must be given,
    - ``num_layers``, ``verbose``, ``idf``, ``nthreads``, ``batch_size``.

    ``lang`` is required when ``rescale_with_baseline`` is True.

    :param hyps: list of candidate sentences
    :param refs: one entry per candidate; each entry is either a single
        reference string or a list/tuple of reference strings
    :param lang: language code forwarded to ``compute``
    :param device: device the contextual embedding model is placed on
    :param rescale_with_baseline: rescale scores with the pre-computed baseline
    :return: the dict returned by ``bert_metric.compute`` (``__main__`` reads
        its per-example ``'f1'`` entry)
    """
    # The metric takes a list of references per prediction. Wrap single
    # references, but keep lists/tuples of references as they are so that
    # multi-reference input is not nested one level too deep.
    nested_refs = [list(ref) if isinstance(ref, (list, tuple)) else [ref] for ref in refs]
    bert_metric.add_batch(predictions=hyps, references=nested_refs)
    return bert_metric.compute(lang=lang, device=device,
                               rescale_with_baseline=rescale_with_baseline)


def DIST_score(hyps, n=1, *, tokenize=None):
    """
    Compute Distinct-n: the number of distinct n-grams over the total number of n-grams.

    N-grams are extracted from each hypothesis separately, so no n-gram spans
    the boundary between two hypotheses. They are then pooled over the whole
    list.

    :param hyps: list of hypothesis strings
    :param n: n-gram order (must be >= 1)
    :param tokenize: optional callable mapping a string to a sequence of token
        strings; defaults to the ``.text`` of the module-level spaCy ``nlp``
        tokens
    :return: ratio in [0, 1]; 0.0 when the hypotheses yield no n-grams
        (empty input, or every hypothesis shorter than ``n`` tokens)
    :raises ValueError: if ``n`` < 1
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if tokenize is None:
        def tokenize(text):
            return [token.text for token in nlp(text)]

    total = 0
    distinct = set()
    for hyp in hyps:
        toks = list(tokenize(hyp))
        # zip over n staggered suffixes yields the contiguous n-grams of toks;
        # tuples are used as keys so no string joining/collision can occur.
        grams = list(zip(*(toks[i:] for i in range(n))))
        total += len(grams)
        distinct.update(grams)
    if total == 0:
        return 0.0
    return len(distinct) / total


def compute_ppl_kenlm(texts):
    """
    Compute the corpus-level perplexity of ``texts`` under the module-level KenLM model.

    Each text is lowercased, stripped and re-tokenized with NLTK's
    ``word_tokenize``. The log10 sentence scores are summed and divided by the
    total number of words::

        ppl = 10 ** (-sum(log10 P(sentence)) / total_words)

    :param texts: iterable of strings
    :return: perplexity as a float (lower is better)
    :raises ValueError: if the texts contain no words at all
    """
    total_log10 = 0.0
    total_words = 0
    for text in texts:
        line = ' '.join(word_tokenize(text.lower().strip()))
        total_words += len(line.split())
        # NOTE(review): Model.score() adds <s>/</s> by default, so it includes
        # the </s> probability, while the word count excludes it. KenLM's own
        # Model.perplexity() divides by len(words) + 1 instead. The formula is
        # kept unchanged so results stay comparable with earlier runs.
        total_log10 += ken_lm.score(line)
    if total_words == 0:
        raise ValueError("compute_ppl_kenlm: no words in texts; perplexity is undefined")
    return math.pow(10, -total_log10 / total_words)


def compute_acc(hypes, target_label):
    """
    Return the fraction of texts that the module-level style classifier assigns to ``target_label``.

    :param hypes: non-empty list of texts to classify with ``class_model``
    :param target_label: label counted as a hit (the classifier is configured
        with ``labels_list = [0, 1]``)
    :return: accuracy in [0, 1]
    :raises ValueError: if ``hypes`` is empty (accuracy is undefined)
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(hypes) == 0:
        raise ValueError("compute_acc: hypes is empty; accuracy is undefined")
    labels, _ = class_model.predict(hypes)
    # Compare against the target itself instead of assuming 0/1 labels and
    # treating every target other than 1 as "1 - positive rate".
    return np.mean(np.asarray(labels) == target_label)


if __name__ == '__main__':
    # Evaluate one generation file. It must contain the columns `predictions`
    # (system outputs), `sent1` (contexts) and `sent2` (references).
    csv_path = './test_data/test_generations_allsides_l2r_rd_ok.csv'
    frame = pd.read_csv(csv_path, header=0)

    hypes = frame['predictions'].tolist()
    ctxs = frame['sent1'].tolist()  # inputs, used to check copying from the input
    refs = frame['sent2'].tolist()

    # id -> [text] dicts (the layout pycocoevalcap scorers such as
    # compute_spice take); not used by the metrics enabled below.
    hypes_dict = {str(idx): [hyp] for idx, hyp in enumerate(hypes)}
    refs_dict = {str(idx): [ref] for idx, ref in enumerate(refs)}

    # Lowercased lemma sequences for BLEU; each hypothesis has one reference.
    hypes_toks = [[tok.lemma_.lower() for tok in nlp(text)] for text in hypes]
    refs_toks = [[[tok.lemma_.lower() for tok in nlp(text)]] for text in refs]

    print("ACC:", compute_acc(hypes, 1))

    for order in (1, 2):
        print(f"BLEU-{order}:", bleu_score_original(hypes_toks, refs_toks, order)['bleu'])

    # rouge_sc = rouge_score(hypes, refs)
    # print("ROUGE-L:", rouge_sc['rougeL'])

    # bleu_sac = bleu_score_sacre(hypes, refs)
    # print("BLEU-Sacre:", bleu_sac['score'])

    bert_result = bert_score(hypes, refs)
    # print("BERT Score P:", np.mean(bert_result['precision']))
    # print("BERT Score R:", np.mean(bert_result['recall']))
    print("BERT Score F1:", np.mean(bert_result['f1']))

    # for label, texts in (("hyps", hypes), ("refs", refs)):
    #     for order in (1, 2):
    #         print(f"{label} Dist-{order}:", DIST_score(texts, order))

    for side, texts in (("hyps", hypes), ("refs", refs)):
        print(f"{side} PPL:", compute_ppl_kenlm(texts))

    print("KG Faith:", get_kg_overlap(ctxs, refs, hypes))
