import os
import json
import re
from datasets import load_dataset
from evaluation import Evaluator
from collections import defaultdict
import argparse


def extract_sql_queries(text):
    """Return the contents of every ```sql fenced block found in *text*.

    The match is lazy and spans newlines, so each block ends at the first
    closing fence that follows its opening tag.  The captured text is
    returned raw (not stripped), in order of appearance.  Note that the
    opening tag is matched as a plain prefix, so anything written directly
    after "sql" on the fence line ends up in the captured text.
    """
    fenced_block = re.compile(r"```sql(.*?)```", re.DOTALL)
    return [match.group(1) for match in fenced_block.finditer(text)]


def extract_output_sql_queries(prediction):
    """Pull a single-line SQL query out of a raw model response.

    Steps:
      1. Everything from the first "===== LƯỢT CỦA BẠN" marker onwards is
         discarded.
      2. The ```sql fenced blocks of the remaining text are collected with
         ``extract_sql_queries``.  With one block, that block is used; with
         several, the shortest one is used (the earliest wins a tie); with
         none, the truncated response text itself is used.
      3. The chosen text has its newlines replaced by spaces and is stripped.
         This is applied on every path, including the no-SQL fallback, so
         the caller can always store one prediction per line.

    Args:
        prediction: Raw response text produced by the model.

    Returns:
        A ``(sql_text, counters)`` tuple.  ``counters`` is a
        ``defaultdict(int)`` in which the following keys are set to 1 when
        the matching condition was observed:
          - "duplicate_cot": "CoT" occurs more than once in the full response.
          - "data_redundancy": the "===== LƯỢT CỦA BẠN" marker is present.
          - "no_sql_in_response": no ```sql block was found.
          - "sub_sql": more than one ```sql block was found.
    """
    counters = defaultdict(int)

    # Counted on the full response, before any truncation.
    if prediction.count("CoT") > 1:
        counters["duplicate_cot"] += 1

    marker = "===== LƯỢT CỦA BẠN"
    if marker in prediction:
        counters["data_redundancy"] += 1
        prediction = prediction.split(marker)[0]

    sqls = extract_sql_queries(prediction)
    if not sqls:
        counters["no_sql_in_response"] += 1
    else:
        if len(sqls) > 1:
            counters["sub_sql"] += 1
        # min() returns the first of equally short candidates, which is the
        # same element a stable sort by length would put first.
        prediction = min(sqls, key=len)

    return prediction.replace("\n", " ").strip(), counters


def extract_tokens(sql):
    """Break a SQL string into a flat list of rough tokens.

    Parentheses, semicolons and commas are turned into spaces, the text is
    split on single spaces, and every piece is split again on dots.  A piece
    that contains a double quote and sits directly after an "=" piece is
    dropped, and so are empty pieces.

    The "directly after" test runs on the list that still contains the empty
    pieces, so a quoted piece separated from "=" by more than one space is
    kept.
    """
    spaced = sql.translate(str.maketrans("();,", "    "))

    pieces = []
    for chunk in spaced.split(" "):
        pieces.extend(chunk.split("."))

    kept = []
    for position, piece in enumerate(pieces):
        follows_equals = position != 0 and pieces[position - 1] == "="
        if '"' in piece and follows_equals:
            continue
        if piece == "":
            continue
        kept.append(piece)
    return kept


def main():
    """Compare saved SQL predictions with the test split and report counters.

    For every index in ``range(--total)`` the function reads
    ``benchmark/predictions/benchmark__<version>/<index>.json``, extracts the
    SQL from its "predict" field, and compares it with the "query" field of
    the matching test row using several string-level checks.  The per-example
    counters are summed and printed, and the cleaned predictions are written,
    newline-separated, to ``benchmark/new_sql/pred_<version>.sql``.
    """
    # Initialize argument parser
    parser = argparse.ArgumentParser(description="SQL Query Prediction Evaluation")
    parser.add_argument("--level", type=str, default="syllable", help="Level type (e.g., syllable)")
    parser.add_argument("--version", type=str, required=True, help="Version of the predictions")
    parser.add_argument("--total", type=int, default=1908, help="Total number of predictions")
    parser.add_argument("--dataset_url", type=str,
                        default="https://huggingface.co/datasets/TeeA/VinAIResearch-ViText2SQL/resolve/main/syllable-level/test-00000-of-00001.parquet",
                        help="Dataset URL")
    # parser.add_argument("--token", type=str, required=True, help="Hugging Face API token")

    # Parse arguments
    args = parser.parse_args()

    evaluator = Evaluator()
    dataset = load_dataset(
        "parquet",
        data_files={"test": args.dataset_url},
        # token=args.token
    )
    test_split = dataset["test"]

    version = args.version
    direct_path = f"benchmark/predictions/benchmark__{version}"
    print(direct_path)

    print(len(os.listdir(direct_path)))

    sql_keywords = [
        "add", "all", "alter", "and", "as", "asc", "avg", "between", "by", "char", "column",
        "count", "create", "delete", "desc", "distinct", "drop", "exists", "from", "group",
        "having", "in", "index", "inner", "insert", "into", "is", "join", "left", "like",
        "max", "min", "not", "null", "or", "order", "outer", "select", "set", "sum", "table",
        "union", "update", "values", "where"
    ]
    accept_redundancy_sql_keywords = ["order", "desc", "by", "limit", "asc"]
    # Keywords whose presence in the text trailing the ground truth marks the
    # extra text as "denied"; built once because it never changes.
    denied_keywords = set(sql_keywords) - set(accept_redundancy_sql_keywords)

    global_logging = defaultdict(int)
    IS_print = False
    total = args.total

    predictions = []
    for index in range(total):
        # Fetch the row once instead of re-indexing the split for each field.
        example = test_split[index]
        ground_truth = example["query"]
        with open(f"{direct_path}/{index}.json", encoding="utf-8") as _f:
            data = json.load(_f)
        prediction, _logging = extract_output_sql_queries(data["predict"])

        # NOTE(security): eval() executes whatever the "sql" field contains.
        # Only run this against a dataset you trust; if the field is a plain
        # Python literal, ast.literal_eval would be the safe replacement
        # (TODO confirm against the dataset before switching).
        hardness = evaluator.eval_hardness(eval(example["sql"]))
        question = example["question"]

        if IS_print:
            print("=" * 50)
            print(f"**[{hardness}] - {question}**")
            print(f"**GroundTruth**: {ground_truth}")
            print(f"**Prediction_**: {prediction}")

        # Quotes are removed on both sides before the string comparisons.
        prediction = prediction.replace('"', "").replace("'", "").replace("count(*)", "count ( * )")
        ground_truth = ground_truth.replace('"', "").replace("'", "")

        if ground_truth in prediction:
            _logging["1.data_redundancy_after_remove_0"] += 1

            # Slice from where the ground truth actually ends inside the
            # prediction; it is not guaranteed to sit at position 0.
            match_end = prediction.find(ground_truth) + len(ground_truth)
            text_after = prediction[match_end:].strip()
            if text_after.startswith("order by"):
                _logging["1.data_redundancy_with_order_by"] += 1
            padded_after = " " + text_after
            if any(f" {keyword} " in padded_after for keyword in denied_keywords):
                _logging["1.data_redundancy_denied"] += 1

        if prediction == ground_truth:
            _logging[f"1.exact_match__{hardness}"] += 1
            _logging["1.total_exact_match"] += 1
        elif prediction.replace(" ", "") == ground_truth.replace(" ", ""):
            _logging["1.match_without_space"] += 1
        else:
            query_toks = example["query_toks"]
            same_compact_length = len(prediction.replace(" ", "")) == len(
                ground_truth.replace(" ", "")
            )
            if all(tok in prediction for tok in query_toks) and same_compact_length:
                _logging[f"1.tokens_match__{hardness}"] += 1

        # Order-insensitive comparison of the rough token lists.
        gt_tokens = extract_tokens(ground_truth)
        pr_tokens = extract_tokens(prediction)

        if sorted(gt_tokens) == sorted(pr_tokens):
            _logging["2.match_tokens_fake"] += 1

        for key, val in _logging.items():
            global_logging[key] += val

        predictions.append(prediction)

    print(global_logging)

    # Create the output directory if needed so a missing folder does not
    # discard the results of the whole run.
    os.makedirs("benchmark/new_sql", exist_ok=True)
    with open(f"benchmark/new_sql/pred_{version}.sql", "w", encoding='utf-8') as _f:
        _f.write("\n".join(predictions))


if __name__ == "__main__":
    main()
