import json
import jieba
import regex
import string
import numpy as np
from rouge import Rouge
from collections import Counter
import sys; sys.path.append('./')
from tools.utils import *
from evaluation.bleu import *
from evaluation.utils import *

def normalize_answer(s):
    """Canonicalise an answer string for lenient comparison.

    Applied in this order:
      1. lowercase the text;
      2. delete every character found in ``string.punctuation`` (ASCII only);
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and strip the ends.
    """
    punctuation = frozenset(string.punctuation)

    lowered = s.lower()
    without_punc = "".join(c for c in lowered if c not in punctuation)
    without_articles = regex.sub(r'\b(a|an|the)\b', ' ', without_punc)
    return " ".join(without_articles.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference

def rougel_score(prediction, ground_truth, lang):
    """Return the ROUGE-L F-measure of ``prediction`` against ``ground_truth``.

    :param prediction: model output string.
    :param ground_truth: reference string.
    :param lang: "cn" segments both strings with jieba before scoring; any
        other value scores the raw strings.
    :return: float ROUGE-L F-measure, or 0.0 when the rouge package rejects
        the input with a ValueError (e.g. an empty hypothesis).
    """
    if lang == "cn":
        # Rouge splits hypotheses on whitespace, and when handed a *list* it
        # treats every element as a separate hypothesis/reference pair.
        # Re-join the jieba tokens into one space-separated string so the
        # prediction is scored as a single sentence against the reference.
        prediction = " ".join(jieba.lcut(prediction))
        ground_truth = " ".join(jieba.lcut(ground_truth))

    rouge = Rouge()
    # no normalization
    try:
        scores = rouge.get_scores(prediction, ground_truth, avg=True)
    except ValueError:  # "Hypothesis is empty."
        return 0.0
    return scores["rouge-l"]["f"]

def distinct_n_sentence_level(sentence, n, lang="en"):
    """
    Compute distinct-N for a single sentence.
    :param sentence: str, the sentence to score. With lang="en" it is split on
        whitespace into words; with lang="cn" it is split into individual
        characters; for any other lang the value is used unchanged.
    :param n: int, ngram.
    :param lang: str, "en" or "cn"; selects the tokenisation described above.
    :return: float, number of distinct n-grams divided by the number of
        tokens (0.0 when there are no tokens).
    """
    splitters = {
        "en": lambda text: text.split(),
        "cn": list,  # character-level; word segmentation via jieba is not used here
    }
    splitter = splitters.get(lang)
    tokens = sentence if splitter is None else splitter(sentence)

    if len(tokens) == 0:
        return 0.0  # Prevent a zero division
    unique_grams = set(ngrams(tokens, n))
    return len(unique_grams) / len(tokens)

def rl(prediction, ground_truths, lang="en"):
    """ROUGE-L F-measure of ``prediction`` against one or more references.

    Mirrors ``ems`` and ``f1``: ``ground_truths`` may be a single string, or an
    iterable of strings in which case the best score over all references is
    returned (an empty iterable raises ValueError from ``max``).

    :param prediction: model output string.
    :param ground_truths: reference string or iterable of reference strings.
    :param lang: forwarded to ``rougel_score`` ("cn" enables jieba segmentation).
    :return: float ROUGE-L F-measure.
    """
    if isinstance(ground_truths, str):
        return rougel_score(prediction, ground_truths, lang)
    return max(rougel_score(prediction, gt, lang) for gt in ground_truths)

def ems(prediction, ground_truths):
    """Exact match against one reference string, or the best over several.

    A plain string is treated as a single reference; otherwise the result is
    True if the prediction exactly matches (after normalisation) any reference.
    """
    references = [ground_truths] if isinstance(ground_truths, str) else ground_truths
    return max(exact_match_score(prediction, reference) for reference in references)

def f1_score(prediction, ground_truth, lang="en"):
    """Token-level (bag-of-tokens) F1 between a prediction and one reference.

    Both strings are first passed through ``normalize_answer``. Tokens are then
    produced by whitespace splitting for ``lang="en"`` and by jieba word
    segmentation for ``lang="cn"``.

    :param prediction: model output string.
    :param ground_truth: reference string.
    :param lang: "en" or "cn".
    :return: F1 in [0, 1]; 0 when no token is shared.
    :raises ValueError: if ``lang`` is not "en" or "cn".
    """
    if lang == "en":
        tokenize = str.split
    elif lang == "cn":
        def tokenize(text):
            # normalize_answer leaves single spaces between words and jieba
            # emits those spaces as tokens; drop them so shared whitespace is
            # not counted as an overlapping token.
            return [tok for tok in jieba.lcut(text) if tok.strip()]
    else:
        raise ValueError("Unsupported lang {!r}; expected 'en' or 'cn'".format(lang))

    prediction_tokens = tokenize(normalize_answer(prediction))
    ground_truth_tokens = tokenize(normalize_answer(ground_truth))

    # Multiset intersection: each shared token counts min(pred, ref) times.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def f1(prediction, ground_truths, lang="en"):
    """Token F1 against a single reference string, or the best over several.

    ``lang`` is forwarded unchanged to ``f1_score``.
    """
    if not isinstance(ground_truths, str):
        return max(f1_score(prediction, reference, lang) for reference in ground_truths)
    return f1_score(prediction, ground_truths, lang)


def calculate_f1_score_en(predicted_answer, ground_truth_answer):
    """Set-based token F1 between two English answers.

    Both strings are lowercased and split on whitespace. Duplicate tokens are
    collapsed (sets, not bags), and no punctuation/article normalisation is
    applied.

    If either side has no tokens, the result is 1.0 when both are empty and
    0.0 otherwise (the SQuAD 2.0 convention), instead of dividing by zero.

    :param predicted_answer: model output string.
    :param ground_truth_answer: reference string.
    :return: float in [0, 1].
    """
    predicted_tokens = set(predicted_answer.lower().split())
    ground_truth_tokens = set(ground_truth_answer.lower().split())

    # Guard against ZeroDivisionError on empty / whitespace-only answers.
    if not predicted_tokens or not ground_truth_tokens:
        return float(predicted_tokens == ground_truth_tokens)

    overlap = len(predicted_tokens & ground_truth_tokens)
    if overlap == 0:
        return 0.0

    precision = overlap / len(predicted_tokens)
    recall = overlap / len(ground_truth_tokens)
    return 2 * precision * recall / (precision + recall)

def avg_bleu(prediction, ground_truths, lang="en"):
    """Mean of BLEU-1..BLEU-4 for ``prediction`` against ``ground_truths``.

    Dispatches to ``calc_bleu_score_en`` or ``calc_bleu_score_cn``. These are
    provided by the star imports at the top of the file; each is unpacked
    here into four scores. NOTE(review): confirm the argument shapes they
    expect.

    :param prediction: model output.
    :param ground_truths: reference(s), passed through unchanged.
    :param lang: "en" or "cn".
    :return: arithmetic mean of the four BLEU scores.
    :raises ValueError: if ``lang`` is not "en" or "cn".
    """
    if lang == "en":
        scorer = calc_bleu_score_en
    elif lang == "cn":
        scorer = calc_bleu_score_cn
    else:
        # Previously fell through to an UnboundLocalError on bleu_1.
        raise ValueError("Unsupported lang {!r}; expected 'en' or 'cn'".format(lang))

    bleu_1, bleu_2, bleu_3, bleu_4 = scorer(prediction, ground_truths)
    return (bleu_1 + bleu_2 + bleu_3 + bleu_4) / 4

def eval_dialogue_system(infile):
    """Score a JSON-lines file of dialogue outputs with token F1 and ROUGE-L.

    The first line of ``infile`` is skipped. NOTE(review): presumably a header
    or config record; confirm with whoever writes these files. Every following
    non-blank line is parsed as JSON. Each record must carry an ``answer``
    field (string or list of strings) and an ``output`` field whose first
    element is the model response.

    :param infile: path to the JSON-lines file (read as UTF-8).
    :return: tuple ``(F1, RL, lens)``. These are the mean token F1, the mean
        ROUGE-L F-measure and the mean response length in whitespace tokens,
        each rounded to 4 decimals.
    """
    # Context manager closes the handle; explicit UTF-8 so non-ASCII (e.g.
    # Chinese) text decodes the same on every platform.
    with open(infile, 'r', encoding='utf-8') as fin:
        lines = fin.readlines()[1:]

    f1_scores = []
    rl_scores = []
    output_lengths = []
    for raw_line in lines:
        if not raw_line.strip():
            continue  # tolerate blank lines instead of crashing json.loads
        record = json.loads(raw_line)
        answer = record['answer']
        output = record['output'][0]

        f1_scores.append(f1(output, answer))
        rl_scores.append(rl(output, answer))
        output_lengths.append(len(output.split()))

    F1 = round(np.mean(f1_scores), 4)
    RL = round(np.mean(rl_scores), 4)
    lens = round(np.mean(output_lengths), 4)

    return F1, RL, lens

if __name__ == "__main__":
    model_type = "chatgpt"

    input_path = "dataset_input/hotpotqa/" + model_type + "/get_answer_zero-shot.json"
    output_path = "dataset_output/hotpotqa/" + model_type + "/answer_zero-shot.json"

    # Use the same model_type that built the paths; it was previously
    # hard-coded to "chatgpt" here, so changing model_type above only moved
    # the files, not how they were parsed.
    results = mapping_input_output(input_path, output_path, model_type=model_type)
    if len(results) == 0:
        raise SystemExit("No samples found for {} / {}".format(input_path, output_path))

    em, f1_scores, ori_types, cor_types = 0, [], [], []
    for sample in results:
        answer = sample["answer"]
        predict = sample["prediction"]

        ori_types.append(sample["type"])
        # NOTE(review): calculate_f1_score is not defined in this file and must
        # come from one of the star imports; if calculate_f1_score_en (defined
        # above) was intended, switch the call.
        f1_scores.append(calculate_f1_score(predict, answer))
        # "EM" here is case-sensitive substring containment of the gold answer
        # in the prediction, not strict exact match.
        if answer in predict:
            em += 1
            cor_types.append(sample["type"])

    # analysis: HotpotQA questions are either "bridge" or comparison type.
    ori_b_type_count = sum(t == "bridge" for t in ori_types)
    cor_b_type_count = sum(t == "bridge" for t in cor_types)
    cor_c_type_count = len(cor_types) - cor_b_type_count

    print("Original Bridge Type Number: {}, Original Comparision Type Number: {}".format(ori_b_type_count, len(results) - ori_b_type_count))
    print("Correct Bridge Type Number: {}, Correct Comparision Type Number: {}".format(cor_b_type_count, cor_c_type_count))

    print("EM: {}, F1 Score: {}".format(em/len(results), sum(f1_scores)/len(f1_scores)))
