#! /usr/bin/env python3
# coding=utf-8


import math
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import spacy
import kenlm
import sng_parser
from nltk.tokenize import word_tokenize
from datasets import load_metric
from nltk.util import ngrams
from pycocoevalcap.spice.spice import Spice
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from spacy.tokens import Token
from spacy.lang.en.stop_words import STOP_WORDS  # import stop words from language data
from transformers import GPT2LMHeadModel
from transformers import (
    GPT2Tokenizer, GPT2Config,
)

# tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
# config = GPT2Config.from_pretrained('gpt2-medium')
# model = GPT2LMHeadModel.from_pretrained('gpt2-medium', config=config)
# model.eval()


# Getter for a custom spaCy token extension: a token counts as a stop word when
# spaCy's built-in flag is set, or when its lower-cased text or its lemma is in
# STOP_WORDS.
# NOTE(review): attributes registered with set_extension are presumably read
# through `token._` (i.e. `token._.is_stop`); the functions in this file read
# `token.is_stop`, so they appear to use the built-in flag rather than this
# getter -- confirm which one is intended.
stop_words_getter = lambda token: token.is_stop or token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS
Token.set_extension('is_stop', getter=stop_words_getter)  # set attribute with getter


# Shared spaCy pipeline; used throughout this module for tokenization,
# lemmatization, POS tags and stop-word / punctuation flags.
nlp = spacy.load("en_core_web_sm")
# Metric objects created once at import time and reused by the *_score helpers
# in this file (each helper feeds them with add_batch() and then calls compute()).
meteor_metric = load_metric('meteor')
rouge_metric = load_metric('rouge')
bleu_metric = load_metric('bleu')
sacrebleu_metric = load_metric('sacrebleu')
bert_metric = load_metric('bertscore')


# KenLM language model used by compute_ppl_kenlm. Exactly one model file is
# active; the commented-out lines are the alternatives for the other datasets.
# ken_lm = kenlm.Model('yelp.bin')
# ken_lm = kenlm.Model('allsides.bin')
ken_lm = kenlm.Model('formal.bin')

# Optional model configuration
# The classifier is configured with the two labels 0 and 1.
model_args = ClassificationArgs()
model_args.labels_list = [0, 1]

# SPICE scorer instance used by compute_spice.
scorer = Spice()


# Create a ClassificationModel
# Style classifier used by compute_acc / compute_acc_single. Only one
# checkpoint directory is active; the commented-out variants load the
# classifiers for the other datasets.
# NOTE(review): presumably the active classifier and the active KenLM model
# above should belong to the same dataset -- switch them together.
# class_model = ClassificationModel(
#     "roberta", "./yelp_classifier", cuda_device=1, args=model_args
# )

# class_model = ClassificationModel(
#     "roberta", "./allsides_classifier_new", cuda_device=1, args=model_args
# )

# NOTE(review): cuda_device=1 presumably pins the model to the second GPU;
# confirm this index exists on the machine running the evaluation.
class_model = ClassificationModel(
    "roberta", "./formality_classifier", cuda_device=1, args=model_args
)


def compute_PINC(refs_toks, hypes_toks):
    """
    Compute a PINC-style novelty score between paired token sequences.

    For every (reference, hypothesis) pair, the fraction of the hypothesis'
    distinct unigrams (resp. bigrams) that also occur in the reference is
    subtracted from 1. The unigram and bigram values are averaged per pair,
    and the sum over all pairs is divided by the number of references.

    :param refs_toks: list of token sequences to compare against
    :param hypes_toks: list of hypothesis token sequences, aligned with
        ``refs_toks``
    :return: the averaged score as a float; 0.0 when ``refs_toks`` is empty
    """
    num_refs = len(refs_toks)
    if num_refs == 0:
        # Nothing to average; dividing by the reference count would raise
        # ZeroDivisionError.
        return 0.0

    score = 0
    for ref_tok, hypes_tok in zip(refs_toks, hypes_toks):
        ref_tok, hypes_tok = list(ref_tok), list(hypes_tok)
        # Adjacent-token pairs, i.e. the bigrams of each sequence.
        bigrams_ref = set(zip(ref_tok, ref_tok[1:]))
        bigrams_hyp = set(zip(hypes_tok, hypes_tok[1:]))
        unigrams_ref = set(ref_tok)
        unigrams_hyp = set(hypes_tok)

        # The 1e-5 term keeps the ratio defined when the hypothesis has no
        # n-grams of the given order.
        bi_sc = 1 - (len(bigrams_ref & bigrams_hyp) / (len(bigrams_hyp) + 1e-5))
        uni_sc = 1 - (len(unigrams_ref & unigrams_hyp) / (len(unigrams_hyp) + 1e-5))
        score += (bi_sc + uni_sc) / 2
    return score / num_refs



def get_kg_overlap(sent1s, sent2s, sents):
    """
    Noun-based overlap score over aligned sentence triples.

    For each triple the lower-cased lemmas of all non-stop-word NOUN tokens
    are extracted with the module-level spaCy pipeline. The score of a triple
    is the number of distinct nouns found in (candidate or sent2) but not in
    (sent1 or sent2), divided by the number of candidate nouns.

    NOTE(review): this function combines the noun lists with unions, whereas
    get_kg_overlap_new uses intersections -- confirm which is intended before
    re-enabling this metric.

    :param sent1s: list of str, first sentences
    :param sent2s: list of str, second sentences
    :param sents: list of str, candidate sentences
    :return: mean of the per-triple scores (``np.mean``)
    """

    def _noun_lemmas(text):
        # Lower-cased lemmas of nouns that are not stop words.
        return [tok.lemma_.lower() for tok in nlp(text) if tok.pos_ == 'NOUN' and not tok.is_stop]

    ratios = []
    for first, second, candidate in zip(sent1s, sent2s, sents):
        first_nouns = _noun_lemmas(first)
        second_nouns = _noun_lemmas(second)
        candidate_nouns = _noun_lemmas(candidate)

        known = set(first_nouns) | set(second_nouns)
        pool = set(candidate_nouns) | set(second_nouns)
        novel = pool - known
        # Denominator counts candidate nouns with duplicates; 1e-5 avoids 0/0.
        ratios.append(len(novel) / (len(candidate_nouns) + 1e-5))
    return np.mean(ratios)


def get_kg_overlap_new(ctx_toks, ref_toks, hype_toks):
    """
    Token-set overlap score over aligned (context, reference, candidate) triples.

    For each triple: of the distinct reference tokens that the candidate
    reproduces, compute the share that does NOT also occur in the context.

    :param ctx_toks: list of token lists for the context sentences
    :param ref_toks: list of token lists for the reference sentences
    :param hype_toks: list of token lists for the candidate sentences
    :return: mean of the per-triple ratios (``np.mean``)
    """
    ratios = []
    for ctx, ref, cand in zip(ctx_toks, ref_toks, hype_toks):
        ref_vocab = set(ref)
        shared_with_ctx = set(ctx) & ref_vocab
        reproduced = ref_vocab & set(cand)
        novel = reproduced - shared_with_ctx
        # 1e-5 keeps the ratio defined when the candidate reproduces nothing.
        ratios.append(len(novel) / (len(reproduced) + 1e-5))

    return np.mean(ratios)


def compute_spice(hypes, refs):
    """
    Run the module-level SPICE scorer on the given inputs.

    Both arguments are forwarded unchanged to ``scorer.compute_score``.

    NOTE(review): pycocoevalcap scorers presumably expect the ground truths
    first and the system outputs second; here ``hypes`` is passed first --
    confirm the argument order against the scorer's signature.

    :param hypes: hypotheses, in the format the scorer expects
    :param refs: references, in the format the scorer expects
    :return: tuple of the two values returned by ``scorer.compute_score``
    """
    overall, per_item = scorer.compute_score(hypes, refs)
    return (overall, per_item)


def meteor_score(hyps, refs):
    """
    Compute METEOR with the module-level ``meteor_metric``.

    No metric options are passed to ``compute()``, so the metric's own
    defaults apply.

    :param hyps: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    :param refs: list of references, one per prediction. Each
        reference should be a string with tokens separated by spaces.
    :return: the result of ``meteor_metric.compute()``
    """
    # Predictions and references must be aligned one-to-one.
    assert len(hyps) == len(refs)
    meteor_metric.add_batch(references=refs, predictions=hyps)
    result = meteor_metric.compute()
    return result


def rouge_score(hyps, refs):
    """
    Compute ROUGE with the module-level ``rouge_metric``.

    ``use_stemmer=True`` is passed to ``compute()``; all other metric options
    are left at the metric's defaults.

    :param hyps: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    :param refs: list of references, one per prediction. Each
        reference should be a string with tokens separated by spaces.
    :return: the result of ``rouge_metric.compute(use_stemmer=True)``
    """
    rouge_metric.add_batch(references=refs, predictions=hyps)
    result = rouge_metric.compute(use_stemmer=True)
    return result


def bleu_score_original(hyps_token, refs_token, n):
    """
    Compute smoothed BLEU on tokenized input with the module-level ``bleu_metric``.

    Supports several references per hypothesis (1 v n).

    :param hyps_token: tokenized hypotheses, e.g. [hyps_1_tokens]
    :param refs_token: per-hypothesis reference lists,
        e.g. [[ref_1_tokens], [ref_2_tokens]]
    :param n: highest n-gram order, passed to the metric as ``max_order``
    :return: the result of ``bleu_metric.compute(smooth=True, max_order=n)``
    """
    bleu_metric.add_batch(references=refs_token, predictions=hyps_token)
    result = bleu_metric.compute(smooth=True, max_order=n)
    return result


def bleu_score_sacre(hyps, refs):
    """
    Compute SacreBLEU with the module-level ``sacrebleu_metric``.

    Every reference is wrapped in a one-element list before it is handed to
    the metric, so each hypothesis is scored against exactly one reference.
    ``lowercase=True`` is passed to ``compute()``.

    :param hyps: list of <str> to be evaluated
    :param refs: list of <str> references, aligned with ``hyps``
    :return: the result of ``sacrebleu_metric.compute(lowercase=True)``
    """
    wrapped_refs = [[single_ref] for single_ref in refs]
    sacrebleu_metric.add_batch(predictions=hyps, references=wrapped_refs)
    result = sacrebleu_metric.compute(lowercase=True)
    return result


def iBLEU(old_bleu, hyps_token, ctx_toks, n, alpha):
    """
    Combine a reference BLEU score with BLEU against the input sentences.

    Smoothed BLEU of the hypotheses against ``ctx_toks`` is computed with the
    module-level ``bleu_metric`` and mixed with ``old_bleu``.

    NOTE(review): the commonly cited iBLEU definition subtracts the
    source-overlap term; here it is added -- confirm this is intended.

    :param old_bleu: BLEU score of the hypotheses against the references
    :param hyps_token: tokenized hypotheses
    :param ctx_toks: per-hypothesis lists of tokenized input sentences
    :param n: highest n-gram order, passed to the metric as ``max_order``
    :param alpha: weight of ``old_bleu``; the input BLEU gets ``1 - alpha``
    :return: ``alpha * old_bleu + (1 - alpha) * input_bleu``
    """
    bleu_metric.add_batch(references=ctx_toks, predictions=hyps_token)
    input_bleu = bleu_metric.compute(smooth=True, max_order=n)['bleu']
    return alpha * old_bleu + (1 - alpha) * input_bleu


def bert_score(hyps, refs, device="cuda:1"):
    """
    Compute BERTScore with the module-level ``bert_metric``.

    Every reference is wrapped in a one-element list before it is handed to
    the metric. ``compute()`` is called with ``lang='en'`` and
    ``rescale_with_baseline=True``.

    :param hyps: (list of str): prediction/candidate sentences
    :param refs: (list of str): reference sentences, aligned with ``hyps``
    :param device: (str): device string forwarded to the metric. Defaults to
        ``"cuda:1"``, the value that was previously hard-coded; pass another
        device (e.g. ``"cuda:0"`` or ``"cpu"``) on machines without a second GPU.
    :return: the result of ``bert_metric.compute(...)``
    """
    bert_metric.add_batch(predictions=hyps, references=[[ref] for ref in refs])
    return bert_metric.compute(lang='en', device=device, rescale_with_baseline=True)


def DIST_score(hyps, n=1):
    """
    Compute the Distinct-n ratio (unique n-grams / total n-grams).

    All hypotheses are tokenized with the module-level spaCy pipeline and
    their tokens are pooled into one sequence, so for ``n > 1`` an n-gram can
    span the boundary between two consecutive hypotheses.

    :param hyps: (list of str): hypothesis sentences
    :param n: n-gram order
    :return: the ratio as a float; 0.0 when there are fewer than ``n`` tokens
        in total (no n-gram exists, so the ratio would be 0/0)
    """
    all_toks = [token.text for hyp in hyps for token in nlp(hyp)]
    if len(all_toks) < n:
        # No n-grams can be formed; avoid dividing by zero below.
        return 0.0
    all_grams = [' '.join(grams) for grams in ngrams(all_toks, n)]
    return len(set(all_grams)) / len(all_grams)


def compute_ppl_kenlm(texts):
    """
    Compute the corpus-level perplexity of ``texts`` under the module-level
    KenLM model ``ken_lm``.

    Each text is lower-cased, stripped and tokenized with NLTK's
    ``word_tokenize``; the tokens are re-joined with single spaces and scored
    with ``ken_lm.score``. The scores are summed over all texts and normalised
    by the total number of tokens.

    NOTE(review): ``ken_lm.score`` presumably also accounts for sentence
    boundary tokens, which are not counted in the denominator here -- confirm
    this matches the perplexity definition you want to report.

    :param texts: iterable of str
    :return: ``10 ** (-total_score / total_tokens)`` as a float; ``nan`` when
        the input contains no tokens at all
    """
    total_score = 0
    total_tokens = 0
    for text in texts:
        line = ' '.join(word_tokenize(text.lower().strip()))
        total_tokens += len(line.split())
        total_score += ken_lm.score(line)
    if total_tokens == 0:
        # Perplexity is undefined without tokens; report NaN instead of
        # raising ZeroDivisionError.
        return float('nan')
    return math.pow(10, -total_score / total_tokens)


def compute_acc(hypes, target_label):
    """
    Share of hypotheses that the module-level classifier assigns to ``target_label``.

    Every hypothesis is converted with ``str()`` before prediction. The
    predicted labels are summed, so the share of label 1 is
    ``sum(labels) / len(hypes)``; for any other target label the complement
    is returned.

    :param hypes: list of hypotheses
    :param target_label: 1 to get the share of label-1 predictions, anything
        else to get ``1 - share``
    :return: the share as a float
    """
    predicted, _ = class_model.predict([str(hyp) for hyp in hypes])
    positive_share = np.sum(predicted) / len(hypes)
    return positive_share if target_label == 1 else 1 - positive_share


def compute_acc_single(hypes):
    """
    Return softmax-normalised classifier outputs for the given hypotheses.

    Despite the name, this returns scores rather than an accuracy: the second
    value returned by ``class_model.predict`` is converted to a tensor and
    passed through softmax.

    :param hypes: list of str passed to ``class_model.predict``
    :return: ``torch.Tensor`` of normalised scores
        (presumably one row per hypothesis and one column per label -- TODO confirm)
    """
    _, raw_outputs = class_model.predict(hypes)
    # Normalise over the last dimension explicitly: calling softmax without
    # `dim` relies on deprecated implicit dimension selection and emits a
    # warning. For 1-D and 2-D inputs the implicit choice is this same axis.
    return F.softmax(torch.tensor(raw_outputs), dim=-1)


def run_eval(df_name):
    """
    Evaluate a CSV of generated sentences and print the metrics to stdout.

    The CSV is read with a header row and must provide the columns
    ``predictions`` (system outputs), ``sent1`` (input sentences) and
    ``sent2`` (reference sentences). Printed, in this order: classifier
    accuracy for target label 1, BLEU-1 against the references, the ``iBLEU``
    combination with the inputs (alpha 0.5), KenLM perplexity of the
    predictions and of the references, and the "KG Faith" overlap score.
    Further metrics are available but commented out below.

    :param df_name: path to the result CSV
    :return: None
    """
    test_df = pd.read_csv(df_name, header=0)

    hypes = test_df['predictions'].tolist()
    ctxs = test_df['sent1'].tolist()  # for input copy
    refs = test_df['sent2'].tolist()

    def _lemmatize(texts):
        # Parse each text once and return two parallel lists:
        # all lower-cased lemmas, and lemmas without stop words / punctuation.
        # str() guards against non-string cells coming out of the CSV.
        all_lemmas, content_lemmas = [], []
        for text in texts:
            doc = nlp(str(text))
            all_lemmas.append([token.lemma_.lower() for token in doc])
            content_lemmas.append(
                [token.lemma_.lower() for token in doc if not token.is_stop and not token.is_punct]
            )
        return all_lemmas, content_lemmas

    hypes_toks, hypes_content = _lemmatize(hypes)
    ctxs_full, ctxs_content = _lemmatize(ctxs)
    refs_full, refs_content = _lemmatize(refs)

    # The BLEU metric takes a list of references per hypothesis, hence the
    # extra level of nesting.
    ctxs_toks = [[toks] for toks in ctxs_full]
    refs_toks = [[toks] for toks in refs_full]

    acc = compute_acc(hypes, 1)
    print("ACC:", acc)

    bleu_1 = bleu_score_original(hypes_toks, refs_toks, 1)
    print("BLEU-1:", bleu_1['bleu'])

    i_bleu = iBLEU(bleu_1['bleu'], hypes_toks, ctxs_toks, 1, 0.5)
    print("i BLEU:", i_bleu)

    # bleu_2 = bleu_score_original(hypes_toks, refs_toks, 2)
    # print("BLEU-2:", bleu_2['bleu'])

    # rouge_sc = rouge_score(hypes, refs)
    # print("ROUGE-L:", rouge_sc['rougeL'])

    # bleu_sac = bleu_score_sacre(hypes, refs)
    # print("BLEU-Sacre:", bleu_sac['score'])

    # bert_result = bert_score(hypes, refs)
    # print("BERT Score P:", np.mean(bert_result['precision']))
    # print("BERT Score R:", np.mean(bert_result['recall']))
    # print("BERT Score F1:", np.mean(bert_result['f1']))
    # print("BERT Score F1:", bert_result)

    # DIST_1 = DIST_score(hypes, 1)
    # print("hyps Dist-1:", DIST_1)
    #
    # DIST_2 = DIST_score(hypes, 2)
    # print("hyps Dist-2:", DIST_2)
    #
    # DIST_1 = DIST_score(refs, 1)
    # print("refs Dist-1:", DIST_1)
    #
    # DIST_2 = DIST_score(refs, 2)
    # print("refs Dist-2:", DIST_2)

    PPL = compute_ppl_kenlm([str(hyp) for hyp in hypes])
    print("hyps PPL:", PPL)

    PPL = compute_ppl_kenlm([str(ref) for ref in refs])
    print("refs PPL:", PPL)

    # kg_faith = get_kg_overlap(ctxs, refs, hypes)
    # print("KG Faith:", kg_faith)

    # The overlap score works on content words only (no stop words, no punctuation).
    kg_faith = get_kg_overlap_new(ctxs_content, refs_content, hypes_content)
    print("KG Faith:", kg_faith)

    # pinc_sc = compute_PINC(refs_content, hypes_content)
    # print("PINC Score:", pinc_sc)


def run_eval_single(prediction, sent1, sent2):
    """
    Evaluate a single prediction and print the metrics to stdout.

    Printed, in this order: the classifier's softmax scores (under the label
    "ACC:"), BLEU-1 against the reference, BERTScore F1, KenLM perplexity of
    the prediction and of the reference, and the "KG Faith" overlap score.

    :param prediction: the generated sentence
    :param sent1: the input sentence
    :param sent2: the reference sentence
    :return: None
    """
    hypes = [prediction]
    ctxs = [sent1]  # for input copy
    refs = [sent2]

    def _lemmatize(texts):
        # Parse each text once and return two parallel lists:
        # all lower-cased lemmas, and lemmas without stop words / punctuation.
        all_lemmas, content_lemmas = [], []
        for text in texts:
            doc = nlp(str(text))
            all_lemmas.append([token.lemma_.lower() for token in doc])
            content_lemmas.append(
                [token.lemma_.lower() for token in doc if not token.is_stop and not token.is_punct]
            )
        return all_lemmas, content_lemmas

    hypes_toks, hypes_content = _lemmatize(hypes)
    _, ctxs_content = _lemmatize(ctxs)
    refs_full, refs_content = _lemmatize(refs)

    # The BLEU metric takes a list of references per hypothesis, hence the
    # extra level of nesting.
    refs_toks = [[toks] for toks in refs_full]

    acc = compute_acc_single(hypes)
    print("ACC:", acc)

    bleu_1 = bleu_score_original(hypes_toks, refs_toks, 1)
    print("BLEU-1:", bleu_1['bleu'])

    bert_result = bert_score(hypes, refs)
    # print("BERT Score P:", np.mean(bert_result['precision']))
    # print("BERT Score R:", np.mean(bert_result['recall']))
    print("BERT Score F1:", np.mean(bert_result['f1']))
    # print("BERT Score F1:", bert_result)

    PPL = compute_ppl_kenlm([str(hyp) for hyp in hypes])
    print("hyps PPL:", PPL)

    PPL = compute_ppl_kenlm([str(ref) for ref in refs])
    print("refs PPL:", PPL)

    # The overlap score works on content words only (no stop words, no punctuation).
    kg_faith = get_kg_overlap_new(ctxs_content, refs_content, hypes_content)
    print("KG Faith:", kg_faith)


if __name__ == '__main__':
    # Script entry point: evaluates the result files listed in the loop further
    # down. The commented-out calls are earlier evaluation runs, kept for reference.

    # run_eval('./test_data/yelp/test_gen_yelp_bart.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_500.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_1000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_2000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_3000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_4000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_5000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_8000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_10000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_20000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg.csv')

    # run_eval('./test_data/allsides/rd/test_gen_allsides_rd_500.csv')
    # run_eval('./test_data/allsides/rd/test_gen_allsides_rd_1000.csv')
    # run_eval('./test_data/allsides/rd/test_gen_allsides_rd_2000.csv')
    # run_eval('./test_data/allsides/rd/test_gen_allsides_rd_3000.csv')
    # run_eval('./test_data/allsides/rd/test_gen_allsides_rd_4000.csv')
    # run_eval('./test_data/allsides/rd/test_gen_allsides_rd_5000.csv')

    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_bart.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_500.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_1000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_2000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_3000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_4000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_5000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_8000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_10000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_20000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg.csv')

    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_bart.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_500.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_1000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_2000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_3000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_4000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_5000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_8000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_10000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm_20000.csv')
    # run_eval('./test_data/yelp/lm/test_gen_yelp_p2n_lm.csv')

    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_bart.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_500.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_1000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_2000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_3000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_4000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_5000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_8000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_10000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg_20000.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_p2n_lm_kg.csv')

    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_500.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_1000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_2000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_3000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_4000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_5000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_8000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_10000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd_20000.csv')
    # run_eval('./test_data/yelp/rd/test_gen_yelp_p2n_rd.csv')

    # run_eval('./test_data/allsides/lm/test_gen_allsides_bart.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_500.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_1000.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_1500.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_2000.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_2500.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_3000.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm_3500.csv')
    # run_eval('./test_data/allsides/lm/test_gen_allsides_lm.csv')

    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_bart.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_500.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_1000.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_1500.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_2000.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_2500.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_3000.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_3500.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg.csv')

    # run_eval('./test_data/formal/test_gen_formal_rd.csv')
    # run_eval('./test_data/formal/test_gen_formal_lm.csv')
    # run_eval('./test_data/formal/test_gen_formal_lm_kg_200_04.csv')

    # run_eval('./test_data/formal/CAE/cae_formality_rd.csv')
    # run_eval('./test_data/formal/CAE/cae_formality_rd.csv')
    # run_eval('./test_data/formal/StyleTrans/style_transfer_formality_rd_10000.csv')
    #
    # run_eval('./test_data/formal/lm_kg/test_gen_formal_lm_kg_50_04.csv')
    # run_eval('./test_data/formal/lm_kg/test_gen_formal_lm_kg_200_03.csv')
    # run_eval('./test_data/formal/lm_kg/test_gen_formal_lm_kg_200_04.csv')
    # run_eval('./test_data/formal/lm_kg/test_gen_formal_lm_kg_200_05.csv')
    # run_eval('./test_data/formal/lm_kg/test_gen_formal_lm_kg_500_04.csv')

    # run_eval('./test_data/yelp/ablation/test_gen_yelp_lm_kg_10_06.csv')
    # run_eval('./test_data/yelp/ablation/test_gen_yelp_lm_kg_100_06.csv')
    # run_eval('./test_data/yelp/ablation/cae_yelp_lm_kg_pred.csv')

    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_100_05.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_100_06.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_10_06.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_200_06.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_200_07.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_200_08.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_280_06.csv')

    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_100_05.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_100_06.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_10_06.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_200_06.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_200_07.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_200_08.csv')
    # run_eval('./test_data/yelp/lm_kg/test_gen_yelp_lm_kg_50_06.csv')

    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_50_04.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_500_03.csv')
    # run_eval('./test_data/allsides/lm_kg/test_gen_allsides_lm_kg_500_04.csv')

    # Result files evaluated by the current run, in this order.
    for result_csv in (
        './test_data/yelp/style_transformer/style_transformer_yelp_lm_kg_1000.csv',
        './test_data/yelp/style_transformer/style_transformer_yelp_lm_kg_5000.csv',
        './test_data/yelp/style_transformer/style_transformer_yelp_lm_kg_10000.csv',
    ):
        run_eval(result_csv)
    # run_eval('./test_data/yelp/CAE/cae_yelp_lm_pred.csv')
    # run_eval('./test_data/yelp/CAE/cae_yelp_lm_kg_pred.csv')

    # Single-sentence examples for run_eval_single (disabled).
    # prediction = "otherwise we will a terrible experience and we will go again ."
    # sent1 = "otherwise a great experience and we will go again ."
    # sent2 = "otherwise a terrible experience and we will not go again ."
    #
    # run_eval_single(sent2, sent1, sent2)
    #
    # run_eval_single(prediction, sent1, sent2)
    #
    # prediction = "otherwise i 'll have to have him put to sleep"
    # run_eval_single(prediction, sent1, sent2)
    #
    # prediction = "otherwise a terrible experience and we will go again ."
    # run_eval_single(prediction, sent1, sent2)
    #
    # prediction = "overall, not a good experience and we will not go again."
    # run_eval_single(prediction, sent1, sent2)
