import os
import re
import string
from nltk.metrics import f_measure
from rouge_score import rouge_scorer
import jsonlines
import numpy as np


def normalize_text(s):
    """Canonicalize a string for answer comparison.

    The steps are applied in this order:
    1. lower-case the text;
    2. delete every character in ``string.punctuation``;
    3. replace the standalone words "a", "an" and "the" with a space;
    4. collapse all runs of whitespace to single spaces and trim the ends.

    Because punctuation is removed before articles, "a-the" becomes "athe"
    and is not treated as two articles.
    """
    lowered = s.lower()
    # Deleting via a translation table is equivalent to filtering out each
    # punctuation character individually.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct, flags=re.UNICODE)
    return ' '.join(without_articles.split())


def exact_match(generation_batches, references):
    """Compute per-example exact-match accuracy after text normalization.

    Args:
        generation_batches: One iterable of generated strings per example.
        references: One reference string per example, aligned with
            ``generation_batches`` (pairs are formed with ``zip``, so any
            surplus items in the longer sequence are ignored).

    Returns:
        ``{'em': scores}``, where ``scores[i]`` is the fraction of generations
        for example ``i`` whose ``normalize_text`` form equals the normalized
        reference.
    """
    ems = []
    for generation_batch, reference in zip(generation_batches, references):
        # The reference is shared by every generation in the batch, so
        # normalize it once instead of once per generation.
        normalized_reference = normalize_text(reference)
        ems_batch = [
            1 if normalize_text(generation) == normalized_reference else 0
            for generation in generation_batch
        ]
        ems.append(np.average(ems_batch))
    return {'em': ems}


def f1_score(generation_batches, references):
    """Compute the per-example token-set F1 between generations and references.

    Both texts are normalized with ``normalize_text`` and split on whitespace.
    The resulting token *sets* are then compared with ``nltk.metrics.f_measure``.
    NOTE: because sets are used, repeated tokens count only once. This differs
    from the multiset (Counter-based) F1 of the official SQuAD script.

    Args:
        generation_batches: One iterable of generated strings per example.
        references: One reference string per example, aligned with
            ``generation_batches`` via ``zip``.

    Returns:
        ``{"f1": scores}``, where ``scores[i]`` is the mean F1 over the
        generations of example ``i``.
    """
    f1_scores = []
    for generation_batch, reference in zip(generation_batches, references):
        # Tokenize the reference once per example, not once per generation.
        reference_tokens = set(normalize_text(reference).split())
        f1_scores_batch = []
        for generation in generation_batch:
            score = f_measure(reference_tokens, set(normalize_text(generation).split()))
            # f_measure returns None when either token set is empty, e.g. when
            # the generation or the reference normalizes to "". Score that as 0.
            if score is None:
                score = 0.0
            f1_scores_batch.append(score)
        f1_scores.append(np.average(f1_scores_batch))
    return {"f1": f1_scores}


def rouge_l(generation_batches, references):
    """Compute the per-example mean ROUGE-L F-measure of generations vs. references.

    Texts are passed to ``rouge_score`` unnormalized, with stemming disabled.

    Args:
        generation_batches: One iterable of generated strings per example.
        references: One reference string per example, aligned with
            ``generation_batches`` via ``zip``.

    Returns:
        ``{"rouge-l": scores}``, where ``scores[i]`` is the mean ROUGE-L
        F-measure over the generations of example ``i``.
    """
    # Build the scorer once and reuse it for every (reference, generation)
    # pair, instead of constructing a new one per generation.
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
    scores = []
    for generation_batch, reference in zip(generation_batches, references):
        # RougeScorer.score takes (target, prediction): the reference comes first.
        batch_scores = [
            scorer.score(reference, generation)["rougeL"].fmeasure
            for generation in generation_batch
        ]
        scores.append(np.average(batch_scores))

    return {"rouge-l": scores}


def evaluating(generation_batches, references):
    """Clean up generations, then score them with EM, F1 and ROUGE-L.

    Each generation is cleaned as follows:
    - it is converted to ``str`` (a generation is not guaranteed to be a string);
    - whitespace runs are collapsed to single spaces and the ends are trimmed;
    - comma-separated parts are re-joined with exactly ", ".

    Returns:
        A dict with keys ``'em'``, ``'f1'`` and ``'rouge-l'``. Each value is a
        list with one averaged score per example.
    """
    def canonicalize(raw_generation):
        # split() with no argument also drops leading/trailing whitespace.
        collapsed = " ".join(str(raw_generation).split())
        return ", ".join([piece.strip() for piece in collapsed.split(",")])

    cleaned_batches = [
        [canonicalize(raw_generation) for raw_generation in batch]
        for batch in generation_batches
    ]

    scores = {}
    for metric_fn in (exact_match, f1_score, rouge_l):
        scores.update(metric_fn(cleaned_batches, references))

    return scores


def _load_valid_samples(result_folder, file_names):
    """Yield the usable records from the given JSONL files in ``result_folder``.

    A record is kept only if all three conditions hold:
    - it has a ``"generations"`` field;
    - its ``"answer"`` is not the string ``"NULL"``;
    - its ``"sql_type"`` is not ``"multi_simple"``.

    Each file name is printed when the file is opened, as progress output.
    """
    for file_name in file_names:
        print(file_name)
        with jsonlines.open(os.path.join(result_folder, file_name)) as reader:
            for line in reader:
                if (
                    "generations" in line
                    and line["answer"] != "NULL"
                    and line["sql_type"] != "multi_simple"
                ):
                    yield line


def _collect(result_folder, file_names):
    """Return ``(answers, generation_batches)`` for all usable records in the files."""
    answers, generation_batches = [], []
    for sample in _load_valid_samples(result_folder, file_names):
        answers.append(sample["answer"])
        generation_batches.append(sample["generations"])
    return answers, generation_batches


def _report(label, answers, generation_batches):
    """Print the group size, then every metric's mean rounded to 3 decimals."""
    print(label, len(answers), len(generation_batches))
    scores = evaluating(generation_batches, answers)
    metrics = list(scores)
    values_all = [str(round(np.mean(values), 3)) for values in scores.values()]
    print(metrics, "/".join(values_all))


def _evaluate_split(result_folder, file_names, is_low, low_label, high_label):
    """Partition usable records with the predicate ``is_low`` and report each side.

    Only answers and generations are kept, not whole records. This keeps
    memory proportional to what is actually scored.
    """
    low_answers, low_batches = [], []
    high_answers, high_batches = [], []
    for sample in _load_valid_samples(result_folder, file_names):
        if is_low(sample):
            low_answers.append(sample["answer"])
            low_batches.append(sample["generations"])
        else:
            high_answers.append(sample["answer"])
            high_batches.append(sample["generations"])
    _report(low_label, low_answers, low_batches)
    _report(high_label, high_answers, high_batches)


def main():
    """Print metric breakdowns by question focus, SQL type, SQL length and input length."""
    # NOTE: 'multi_simple' records are excluded by _load_valid_samples, so its
    # presence in this list has no effect on the split.
    in_sql_types = ['multi_ran_filtering_foa', 'multi_ran_organizing', 'multi_simple',
                    'multi_ran_filtering_ofo', 'multi_ran_aggregating', 'multi_ran_filtering_foo']
    in_sql_complexity_threshold = 15
    in_question_focuses = ['author_list', 'title_word_count', 'title_entire', 'author_count',
                           'author_relationship']

    result_folder = "../train_grpo/grpo/generations_sampled_192"
    jsonl_files = [name for name in os.listdir(result_folder) if name.endswith(".jsonl")]
    # Files prefixed 64k/128k form the "short" group; 512k/1024k form the "long" group.
    output_files_short = sorted(name for name in jsonl_files if name.startswith(("64k", "128k")))
    output_files_long = sorted(name for name in jsonl_files if name.startswith(("512k", "1024k")))

    # Question focus, over all short samples.
    print("question focus")
    _evaluate_split(
        result_folder, output_files_short,
        lambda sample: sample["focus"] in in_question_focuses,
        "in", "out",
    )

    # SQL type, over all short samples.
    print("sql type")
    _evaluate_split(
        result_folder, output_files_short,
        lambda sample: sample["sql_type"] in in_sql_types,
        "in", "out",
    )

    # SQL length (whitespace-separated token count), over all short samples.
    print("sql length")
    _evaluate_split(
        result_folder, output_files_short,
        lambda sample: len(sample["sql"].split()) <= in_sql_complexity_threshold,
        "low", "high",
    )

    # Input length: short-input files vs. long-input files.
    print("input length")
    short_answers, short_batches = _collect(result_folder, output_files_short)
    long_answers, long_batches = _collect(result_folder, output_files_long)
    _report("short", short_answers, short_batches)
    _report("long", long_answers, long_batches)


if __name__ == "__main__":
    main()