import os
import argparse
from pathlib import Path
from ..core.utils import load_json, load_jsonl
from ..core.sql_execute import *
from ..core.config import Config


def find_latest_folder(directory):
    """Return the name of the lexicographically greatest sub-directory of *directory*.

    Only immediate children that are directories are considered; regular files
    are ignored. NOTE(review): "latest" relies on folder names sorting in
    creation order (e.g. timestamp-prefixed names) -- confirm the naming scheme.

    Args:
        directory: Directory to scan (str or os.PathLike).

    Returns:
        The bare folder name (not a full path), or None if *directory* contains
        no sub-directories.
    """
    subdirs = [
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    ]
    return max(subdirs) if subdirs else None


def compare_sql_results(db_path, sql1, sql2, timeout=10):
    """Check whether the generated SQL yields the same result set as the gold SQL.

    Both queries are run through ``execute_sql_with_timeout`` against the same
    database. Their results are compared as sets of rows, so row order and
    duplicate rows are ignored.

    Args:
        db_path: Path to the database file both queries are executed against.
        sql1: The gold (reference) SQL query.
        sql2: The generated SQL query being evaluated.
        timeout: Timeout forwarded to ``execute_sql_with_timeout`` for each
            query (units as defined by that helper -- presumably seconds).

    Returns:
        1 if the generated result set matches the gold result set, else 0.
        A diagnostic message is printed on every failure path.
    """
    failed_types = (SQLExecutionResultType.TIMEOUT, SQLExecutionResultType.ERROR)

    gold_exec = execute_sql_with_timeout(db_path, sql1, timeout)
    if gold_exec is None:
        print("The execution of the gold SQL failed!!!! Please check the error message.")
        return 0
    # A gold query that timed out or errored cannot serve as a reference.
    # Without this check it fell through to the "successful but empty" branch
    # and could be scored as a match against an empty generated result.
    if gold_exec.result_type in failed_types:
        print(f"The execution of the gold SQL failed!!!! Please check the error message. The error message is as follows:{gold_exec.error_message}")
        return 0

    pred_exec = execute_sql_with_timeout(db_path, sql2, timeout)
    if pred_exec is None:
        print("The generated SQL did not produce a result object.")
        return 0

    if pred_exec.result_type == SQLExecutionResultType.TIMEOUT:
        print(f"The generated SQL timed out during execution. The error message is as follows:{pred_exec.error_message}")
        return 0
    if pred_exec.result_type == SQLExecutionResultType.ERROR:
        print(f"The generated SQL code failed to run and an error occurred. The error message is as follows:{pred_exec.error_message}")
        return 0

    if gold_exec.result is None:
        print("The execution of the Gold SQL was successful but the result was empty.")
        # Two empty results count as a match.
        return 1 if pred_exec.result is None else 0

    if pred_exec.result is None:
        print("The generated SQL was executed successfully but the result was empty.")
        return 0

    # Set comparison: order-insensitive and duplicate-insensitive.
    return 1 if set(gold_exec.result) == set(pred_exec.result) else 0



def _resolve_result_path(result_path):
    """Return *result_path* as a Path, defaulting to the newest intermediate-results folder."""
    if result_path is not None:
        return Path(result_path)
    intermediate_results_dir = Path("./results/intermediate_results")
    latest_folder = find_latest_folder(intermediate_results_dir)
    if latest_folder is None:
        raise ValueError("Cannot find the result folder.")
    return intermediate_results_dir / latest_folder


def _index_generated_sql(generated_sql_data):
    """Map question_id -> generated_sql, keeping the FIRST record seen for each id."""
    lookup = {}
    for record in generated_sql_data:
        # setdefault keeps the first occurrence, matching the old linear scan that
        # stopped at the first matching record.
        lookup.setdefault(record.get("question_id"), record.get("generated_sql"))
    return lookup


def _ex_percent(correct_cnt, total_cnt):
    """Return correct/total as a percentage rounded to 2 decimals (0.0 when total is 0)."""
    if total_cnt == 0:
        return 0.0
    return round(correct_cnt * 100 / total_cnt, 2)


def compute_EX_source_difficulty_based(gold_sql_file: str, result_path: str = None):
    """Compute execution accuracy (EX) per source, per difficulty, and overall.

    Each gold item is matched by ``question_id`` to a generated query in
    ``<result_path>/generated_sql_results.jsonl``. Both queries are executed with
    ``compare_sql_results`` against
    ``<Config().database_dir>/<source>_<db_id>/<db_id>.sqlite``.
    Per-difficulty statistics are collected only for the ``"bird_dev"`` source.

    Args:
        gold_sql_file: JSON file of gold items. Each item is read for the keys
            ``question_id``, ``source``, ``difficulty``, ``db_id`` and ``gold_SQL``.
        result_path: Folder containing ``generated_sql_results.jsonl``. If None,
            the latest folder under ``./results/intermediate_results`` is used.

    Returns:
        A dict with one entry per source: ``{"total_EX": float}``, plus
        ``"difficulty_EX": {difficulty: float}`` for ``"bird_dev"``. It also has
        an ``"overall_EX"`` float. All values are percentages rounded to 2 decimals.

    Raises:
        ValueError: If *result_path* is None and no result folder can be found.
    """
    test_data = load_json(gold_sql_file)

    result_path = _resolve_result_path(result_path)
    generated_sql_file = result_path / "generated_sql_results.jsonl"
    # Built once: O(1) lookup per question instead of a linear scan per question.
    generated_sql_by_id = _index_generated_sql(load_jsonl(str(generated_sql_file)))

    # The database root does not depend on the item, so resolve it once.
    database_dir = Config().database_dir

    source_stats = {}
    global_total_cnt = 0
    global_correct_cnt = 0

    for item in test_data:
        question_id = item.get("question_id")
        source = item.get("source", "")
        difficulty = item.get("difficulty")
        db_id = item.get("db_id", "")
        db_path = str(database_dir / f"{source}_{db_id}" / f"{db_id}.sqlite")

        gold_sql = item.get("gold_SQL")
        # An empty string is used when no generated query exists for this id.
        generated_sql = generated_sql_by_id.get(question_id, "")

        curr_res = compare_sql_results(db_path, gold_sql, generated_sql)
        hit = 1 if curr_res else 0

        stats = source_stats.setdefault(source, {"total_cnt": 0, "correct_cnt": 0})
        if source == "bird_dev":
            diff_stats = stats.setdefault("difficulty_stats", {}).setdefault(
                difficulty, {"total_cnt": 0, "correct_cnt": 0}
            )
            diff_stats["total_cnt"] += 1
            diff_stats["correct_cnt"] += hit

        stats["total_cnt"] += 1
        stats["correct_cnt"] += hit

        global_total_cnt += 1
        global_correct_cnt += hit

    ex_results = {}
    for source, stats in source_stats.items():
        total_cnt = stats["total_cnt"]
        correct_cnt = stats["correct_cnt"]
        EX = _ex_percent(correct_cnt, total_cnt)
        ex_results[source] = {"total_EX": EX}

        if source == "bird_dev" and "difficulty_stats" in stats:
            ex_results[source]["difficulty_EX"] = {}
            for difficulty, diff_stats in stats["difficulty_stats"].items():
                diff_total_cnt = diff_stats["total_cnt"]
                diff_correct_cnt = diff_stats["correct_cnt"]
                diff_EX = _ex_percent(diff_correct_cnt, diff_total_cnt)
                ex_results[source]["difficulty_EX"][difficulty] = diff_EX
                print(f"Source: {source}, Difficulty: {difficulty}, total question num: {diff_total_cnt}, generated sql correct ex num: {diff_correct_cnt}, EX value: {diff_EX}")

        print(f"Source: {source}, total question num: {total_cnt}, generated sql correct ex num: {correct_cnt}, EX value: {EX}")

    # Guarded: an empty gold file previously raised ZeroDivisionError here.
    overall_EX = _ex_percent(global_correct_cnt, global_total_cnt)
    ex_results["overall_EX"] = overall_EX
    print(f"Overall EX: total question num: {global_total_cnt}, generated sql correct ex num: {global_correct_cnt}, EX value: {overall_EX}")

    return ex_results

def main():
    """Command-line entry point: parse arguments and run the EX evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--gold_sql_file',
        type=str,
        required=True,
        help='Path to file containing gold SQL queries',
    )
    cli.add_argument(
        '--result_path',
        type=str,
        default=None,
        help='Path to result folder (default: latest result)',
    )
    options = cli.parse_args()

    # Results are printed by the evaluator; the returned dict is not used here.
    compute_EX_source_difficulty_based(options.gold_sql_file, options.result_path)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, never on import.
    main()