#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse
import multiprocessing as mp
import json
import pdb
from func_timeout import func_timeout, FunctionTimedOut
from eval_sql_utils import (
    load_jsonl,
    execute_sql,
    sort_results,
    print_data,
    update_require_opt,
    save_jsonl
)
from tqdm import tqdm

def result_callback(result):
    """Store one finished task result in the module-level ``exec_result`` list.

    Used as the ``callback`` argument of ``pool.apply_async`` in
    ``run_sqls_parallel``; ``exec_result`` must already exist as a module
    global when this runs.
    """
    exec_result.extend([result])

def calculate_ex(predicted_res, ground_truth_res):
    """Return 1 if both results contain the same distinct rows, otherwise 0.

    Both arguments are converted to sets before comparison, so row order
    and duplicated rows have no influence on the outcome.
    """
    same_rows = set(predicted_res) == set(ground_truth_res)
    return 1 if same_rows else 0

def execute_model(
    predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect
):
    """Run one predicted/ground-truth SQL pair under a time limit.

    ``execute_sql`` is called through ``func_timeout`` with ``meta_time_out``
    as the limit and ``calculate_ex`` passed along as its last argument.

    Returns:
        A dict with ``sql_idx`` (the given ``idx``) and ``res`` (the value
        returned by ``execute_sql``). When the call times out or raises any
        other exception, ``res`` is 0 and an additional ``error`` key holds
        the exception text.

    A ``KeyboardInterrupt`` terminates the process via ``sys.exit(0)``.
    """
    outcome = {"sql_idx": idx}
    try:
        outcome["res"] = func_timeout(
            meta_time_out,
            execute_sql,
            args=(predicted_sql, ground_truth, db_place, sql_dialect, calculate_ex),
        )
    except KeyboardInterrupt:
        sys.exit(0)
    except (FunctionTimedOut, Exception) as failure:
        # Timeouts and execution failures are both scored as incorrect;
        # the exception text is kept for later inspection.
        outcome["res"] = 0
        outcome["error"] = str(failure)
    return outcome

def run_sqls_parallel(
    sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite"
):
    """Submit every SQL pair to a process pool and wait for completion.

    Args:
        sqls: Sequence of ``(predicted_sql, ground_truth)`` pairs.
        db_places: Sequence indexed in parallel with ``sqls``; element ``i``
            is forwarded to ``execute_model`` for pair ``i``.
        num_cpus: Number of worker processes in the pool.
        meta_time_out: Time limit forwarded to ``execute_model``.
        sql_dialect: Dialect name forwarded to ``execute_model``.

    Nothing is returned; each task result is handed to ``result_callback``.
    """
    workers = mp.Pool(processes=num_cpus)
    progress = tqdm(sqls, desc="Processing SQL queries")
    for position, (predicted_sql, ground_truth) in enumerate(progress):
        task_args = (
            predicted_sql,
            ground_truth,
            db_places[position],
            position,
            meta_time_out,
            sql_dialect,
        )
        workers.apply_async(
            execute_model, args=task_args, callback=result_callback
        )
    # No further submissions; block until every worker has exited.
    workers.close()
    workers.join()

def compute_ex(exec_results):
    """Return the overall accuracy of ``exec_results`` as a percentage.

    Each entry must provide a numeric ``res`` value. An empty input
    yields 0.
    """
    total = len(exec_results)
    if total <= 0:
        return 0
    correct = sum(entry["res"] for entry in exec_results)
    return correct / total * 100

def _percent_correct(results):
    """Return the mean of the ``res`` values as a percentage (0 if empty)."""
    if not results:
        return 0
    return sum(entry["res"] for entry in results) / len(results) * 100


def compute_acc_by_diff(exec_results, diff_json_path):
    """Compute accuracy percentages per difficulty level and overall.

    Entry ``i`` of the JSONL file at ``diff_json_path`` is paired with
    ``exec_results[i]`` and grouped by its ``difficulty`` field
    (``simple``, ``moderate`` or ``challenging``). Entries with any other
    difficulty value are counted only in the overall figure.

    If ``exec_results`` is shorter than the file, the entries without a
    result are skipped with a printed warning, regardless of difficulty.

    Args:
        exec_results: List of dicts, each providing a numeric ``res``.
        diff_json_path: Path of the JSONL file read via ``load_jsonl``.

    Returns:
        ``(simple_acc, moderate_acc, challenging_acc, all_acc, count_lists)``
        where the accuracies are percentages and ``count_lists`` is
        ``[n_simple, n_moderate, n_challenging, len(exec_results)]``.
    """
    contents = load_jsonl(diff_json_path)
    buckets = {"simple": [], "moderate": [], "challenging": []}

    for i, content in enumerate(contents):
        bucket = buckets.get(content["difficulty"])
        if bucket is None:
            continue
        try:
            bucket.append(exec_results[i])
        except IndexError:
            # Only a missing result is tolerated here; anything else
            # (including KeyboardInterrupt) must propagate.
            print(f"Warning: no execution result for entry {i}; skipping")

    simple_results = buckets["simple"]
    moderate_results = buckets["moderate"]
    challenging_results = buckets["challenging"]
    count_lists = [
        len(simple_results),
        len(moderate_results),
        len(challenging_results),
        len(exec_results),
    ]
    return (
        _percent_correct(simple_results),
        _percent_correct(moderate_results),
        _percent_correct(challenging_results),
        _percent_correct(exec_results),
        count_lists,
    )

def package_sqls_from_jsonl(jsonl_path, mode='baseline'):
    """Load a JSONL file and extract the SQL pairs needed for evaluation.

    Args:
        jsonl_path: Path of the JSONL file read via ``load_jsonl``.
        mode: Selects which field holds the predicted SQL:
            ``"baseline"`` -> ``first_code``,
            ``"training"`` -> ``final_code``,
            ``"inference"`` -> ``rag_final_code``.

    Returns:
        ``(data, predicted_sqls, ground_truth_sqls, db_paths)`` where
        ``data`` is the loaded list of entries and the other three lists
        are taken from each entry's prediction field, ``sql`` field and
        ``db_path`` field respectively.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    field_by_mode = {
        "baseline": "first_code",
        "training": "final_code",
        "inference": "rag_final_code",
    }
    # Validate before touching the file so a typo in the mode fails fast
    # with a clear message instead of an UnboundLocalError at return time.
    if mode not in field_by_mode:
        supported = ", ".join(repr(name) for name in field_by_mode)
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of: {supported}"
        )
    prediction_field = field_by_mode[mode]

    data = load_jsonl(jsonl_path)
    predicted_sqls = [entry[prediction_field] for entry in data]
    ground_truth_sqls = [entry["sql"] for entry in data]
    db_paths = [entry["db_path"] for entry in data]
    return data, predicted_sqls, ground_truth_sqls, db_paths

if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments),
    # registered in the order listed.
    option_specs = (
        ("--predicted_sql_path", {"type": str, "required": True, "default": ""}),
        ("--data_mode", {"type": str, "required": True, "default": "dev"}),
        ("--db_root_path", {"type": str, "required": True, "default": ""}),
        ("--num_cpus", {"type": int, "default": 1}),
        ("--meta_time_out", {"type": float, "default": 30.0}),
        ("--mode_gt", {"type": str, "default": "gt"}),
        ("--mode_predict", {"type": str, "default": "gpt"}),
        ("--difficulty", {"type": str, "default": "simple"}),
        ("--engine", {"type": str, "default": ""}),
        ("--sql_dialect", {"type": str, "default": "SQLite"}),
        ("--output_path", {"type": str, "default": "SQLite"}),
        ("--eval_mode", {"type": str, "default": "baseline"}),
    )
    args_parser = argparse.ArgumentParser()
    for flag, flag_options in option_specs:
        args_parser.add_argument(flag, **flag_options)
    args = args_parser.parse_args()

    # result_callback appends to this module-level list, so it has to exist
    # before the first task is submitted.
    exec_result = []

    data, pred_queries, gt_queries, db_paths = package_sqls_from_jsonl(
        args.predicted_sql_path, args.eval_mode
    )

    # Execute every (prediction, ground truth) pair in the worker pool.
    run_sqls_parallel(
        list(zip(pred_queries, gt_queries)),
        db_places=db_paths,
        num_cpus=args.num_cpus,
        meta_time_out=args.meta_time_out,
        sql_dialect=args.sql_dialect,
    )
    exec_result = sort_results(exec_result)

    print("start calculate")
    # The prediction file is also the source of the per-entry difficulty.
    simple_acc, moderate_acc, challenging_acc, acc, count_lists = (
        compute_acc_by_diff(exec_result, args.predicted_sql_path)
    )
    score_lists = [
        simple_acc,
        moderate_acc,
        challenging_acc,
        acc,
        compute_ex(exec_result),
    ]

    print(f"EX for {args.engine} on {args.sql_dialect} set")
    print("start calculate")
    print_data(score_lists, count_lists, metric="EX")
    print(
        "==========================================================================================="
    )
    print(f"Finished EX evaluation for {args.engine} on {args.sql_dialect} set")
    print("\n\n")

    # Pass the entries and execution results through update_require_opt,
    # then write the returned data to the requested output path.
    data = update_require_opt(data, exec_result)
    save_jsonl(data, args.output_path)