import os
import re
import string
from nltk.metrics import f_measure
from rouge_score import rouge_scorer
import jsonlines
import numpy as np


def normalize_text(s):
    """Canonicalize an answer string before comparing it with another one.

    Steps, in order:
      1. lower-case the text;
      2. delete every ASCII punctuation character (``string.punctuation``);
      3. replace the whole-word articles "a", "an" and "the" with a space;
      4. squeeze all whitespace runs into single spaces, trimming both ends.
    """
    text = s.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text, flags=re.UNICODE)
    return ' '.join(text.split())


def exact_match(generation_batches, references):
    """Compute per-example exact-match accuracy.

    Args:
        generation_batches: for each example, an iterable of generated strings.
        references: the reference answer for each example, aligned with
            *generation_batches* (pairing is done with ``zip``, so surplus
            items in the longer sequence are ignored).

    Returns:
        ``{'em': scores}`` where ``scores[i]`` is the fraction of generations
        in batch ``i`` whose ``normalize_text`` form equals that of the
        reference. An empty batch yields NaN (``np.average`` of nothing).
    """
    ems = []
    for generation_batch, reference in zip(generation_batches, references):
        # The reference is shared by every generation in the batch, so
        # normalize it once instead of once per generation.
        normalized_reference = normalize_text(reference)
        ems_batch = [
            1 if normalize_text(generation) == normalized_reference else 0
            for generation in generation_batch
        ]
        ems.append(np.average(ems_batch))
    return {'em': ems}


def f1_score(generation_batches, references):
    """Compute per-example token-set F1 between generations and references.

    Both texts are passed through ``normalize_text`` and split on whitespace.
    The resulting tokens are compared as *sets*, so a repeated token counts
    only once. Scoring uses NLTK's ``f_measure``.

    Args:
        generation_batches: for each example, an iterable of generated strings.
        references: the reference answer for each example, aligned with
            *generation_batches* via ``zip``.

    Returns:
        ``{"f1": scores}`` where ``scores[i]`` is the mean F1 over the
        generations in batch ``i``.
    """
    f1_scores = []
    for generation_batch, reference in zip(generation_batches, references):
        # Build the reference token set once per batch; f_measure only reads it.
        reference_tokens = set(normalize_text(reference).split())
        f1_scores_batch = []
        for generation in generation_batch:
            score = f_measure(reference_tokens, set(normalize_text(generation).split()))
            if score is None:
                # f_measure returns None when a token set is empty, e.g. the
                # answer is the empty string after normalizing; treat that as
                # no overlap.
                score = 0.0
            f1_scores_batch.append(score)
        f1_scores.append(np.average(f1_scores_batch))
    return {"f1": f1_scores}


def rouge_l(generation_batches, references):
    """Compute per-example ROUGE-L F-measure.

    Args:
        generation_batches: for each example, an iterable of generated strings.
        references: the reference answer for each example, aligned with
            *generation_batches* via ``zip``. Each reference is passed as the
            target and each generation as the prediction.

    Returns:
        ``{"rouge-l": scores}`` where ``scores[i]`` is the mean ROUGE-L
        F-measure (no stemming) over the generations in batch ``i``.
    """
    # The scorer options are identical for every pair, so construct it once
    # rather than once per generation.
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
    scores = []
    for generation_batch, reference in zip(generation_batches, references):
        scores_batch = [
            scorer.score(reference, generation)["rougeL"].fmeasure
            for generation in generation_batch
        ]
        scores.append(np.average(scores_batch))
    return {"rouge-l": scores}


def evaluating(generation_batches, references):
    """Clean raw generations, then score them with every metric.

    Each generation is turned into a string, because it may be a non-string
    such as a list. Whitespace runs are collapsed to single spaces, and the
    text around commas is normalized to ", ".

    Returns:
        A dict with the keys produced by ``exact_match``, ``f1_score`` and
        ``rouge_l``, each mapping to a list of per-example scores.
    """
    def clean(raw_generation):
        # Stringify, trim, and collapse internal whitespace.
        text = " ".join(str(raw_generation).strip().split())
        # Canonicalize the separator between comma-separated items.
        return ", ".join(piece.strip() for piece in text.split(","))

    cleaned_batches = [
        [clean(raw_generation) for raw_generation in batch]
        for batch in generation_batches
    ]

    scores = {}
    for metric in (exact_match, f1_score, rouge_l):
        scores.update(metric(cleaned_batches, references))
    return scores


# Context-length buckets, identified by the prefix of each generation file's
# name (e.g. "64k..."). Files matching neither bucket are ignored.
SHORT_CONTEXT_PREFIXES = ("64k", "128k")
LONG_CONTEXT_PREFIXES = ("512k", "1024k")

# Runs for evaluating generalizability in input_length: sft, base, and sota.
# Each entry is (result folder, substrings every file name must contain).
EVAL_RUNS = [
    ("../train_sft/input_length/generations", ()),  # sft
    ("../results_test/", ("test_full", "qwen25_7b_instruct_1m")),  # base
    ("../results_test/", ("test_full", "gpt41_20250414")),  # gpt-4.1
    ("../results_test/", ("test_full", "gemini25_pro")),  # gemini 2.5 pro
    ("../results_test/", ("test_full", "o4mini_20250416")),  # o4-mini
]


def split_result_files(result_folder, required_substrings=()):
    """Return ``(short_files, long_files)`` from the ``.jsonl`` files in *result_folder*.

    A file is kept only if its name ends with ``.jsonl`` and contains every
    string in *required_substrings*. Kept files are bucketed by name prefix:
    SHORT_CONTEXT_PREFIXES go to the first list and LONG_CONTEXT_PREFIXES to
    the second. Both lists are returned sorted by name.
    """
    short_files, long_files = [], []
    for name in os.listdir(result_folder):
        if not name.endswith(".jsonl"):
            continue
        if not all(part in name for part in required_substrings):
            continue
        if name.startswith(SHORT_CONTEXT_PREFIXES):
            short_files.append(name)
        elif name.startswith(LONG_CONTEXT_PREFIXES):
            long_files.append(name)
    return sorted(short_files), sorted(long_files)


def load_examples(result_folder, file_names):
    """Read ``(answers, generation_batches)`` from jsonl files in *result_folder*.

    Each file name is printed before the file is read. A record is kept only
    if all of the following hold:
      - it has a "generations" field;
      - its "answer" is not "NULL";
      - its "sql_type" is not "multi_simple".
    """
    answers = []
    generation_batches = []
    for file_name in file_names:
        print(file_name)
        with jsonlines.open(os.path.join(result_folder, file_name)) as reader:
            for record in reader:
                if ("generations" in record
                        and record["answer"] != "NULL"
                        and record["sql_type"] != "multi_simple"):
                    answers.append(record["answer"])
                    generation_batches.append(record["generations"])
    return answers, generation_batches


def report_scores(scores):
    """Print the metric names and their per-example means, rounded to 3 places."""
    metrics = list(scores)
    values = [str(round(np.mean(per_example), 3)) for per_example in scores.values()]
    print(metrics, "/".join(values))


if __name__ == "__main__":
    for result_folder, required_substrings in EVAL_RUNS:
        short_files, long_files = split_result_files(result_folder, required_substrings)
        # Report the 64k/128k bucket first, then the 512k/1024k bucket.
        for file_group in (short_files, long_files):
            answers, generation_batches = load_examples(result_folder, file_group)
            report_scores(evaluating(generation_batches, answers))