# db_utils.py
import os
import subprocess
import psycopg2
from psycopg2 import OperationalError
import psycopg2.extras
from psycopg2.pool import SimpleConnectionPool
import json
import pandas as pd
from tabulate import tabulate

# Empty placeholder mapping; nothing in this module reads it.
USER_CONFIG = {}

# Lazily-created connection pools, keyed by database name (see _get_or_init_pool).
_postgresql_pools = {}

# Connection settings for the per-database pools: "minconn"/"maxconn" size each
# SimpleConnectionPool, the remaining keys are passed to psycopg2.
# NOTE(review): credentials are hard-coded here; consider sourcing them from the
# environment or a config file.
DEFAULT_DB_CONFIG = dict(
    minconn=1,
    maxconn=5,
    user="root",
    password="123123",
    host="localhost",
    port=5432,
)


def _get_or_init_pool(db_name):
    """
    Return the connection pool for ``db_name``, building it on first use.

    New pools are ``SimpleConnectionPool`` instances configured from
    ``DEFAULT_DB_CONFIG`` and cached in ``_postgresql_pools``.
    """
    if db_name in _postgresql_pools:
        return _postgresql_pools[db_name]

    settings = dict(DEFAULT_DB_CONFIG, dbname=db_name)
    new_pool = SimpleConnectionPool(
        settings["minconn"],
        settings["maxconn"],
        dbname=settings["dbname"],
        user=settings["user"],
        password=settings["password"],
        host=settings["host"],
        port=settings["port"],
    )
    _postgresql_pools[db_name] = new_pool
    return new_pool


def perform_query_on_postgresql_databases(query, db_name, conn=None):
    """
    Execute ``query`` on ``db_name`` and return ``(result, conn)``.

    Each call sets a 60 s ``statement_timeout``, executes the query and commits,
    whether the statement reads or writes.

    Args:
        query: SQL text to execute.
        db_name: Database whose connection pool is used (created on demand).
        conn: Connection to reuse. When None, one is checked out of the pool.

    Returns:
        tuple: ``(result, conn)``.

        - Statements starting with ``select``/``with`` (case-insensitive, after
          leading whitespace): ``result`` holds at most 10000 rows.
        - Other statements: ``result`` holds all returned rows, or None when the
          statement produces no result set.
        - ``conn``: on success it is NOT returned to the pool. The caller keeps
          it for later queries and releases it with
          ``close_postgresql_connection``.

    Raises:
        Whatever the execution raised. Before re-raising, the transaction is
        rolled back. A connection checked out by this call is handed back to the
        pool, because the caller never receives it in that case.
    """
    MAX_ROWS = 10000
    pool = _get_or_init_pool(db_name)
    acquired_here = conn is None
    if acquired_here:
        conn = pool.getconn()

    try:
        # The cursor context manager closes the cursor even when execute() fails.
        with conn.cursor() as cursor:
            cursor.execute("SET statement_timeout = '60s';")  # 60s query timeout
            cursor.execute(query)
            lower_q = query.strip().lower()
            conn.commit()

            if lower_q.startswith(("select", "with")):
                # Cap the result size; fetching one extra row is harmless and
                # keeps the fetch bounded.
                result = cursor.fetchmany(MAX_ROWS + 1)[:MAX_ROWS]
            else:
                try:
                    result = cursor.fetchall()
                except psycopg2.ProgrammingError:
                    # Statement produced no result set (e.g. INSERT without RETURNING).
                    result = None
    except Exception:
        try:
            conn.rollback()
        except psycopg2.Error:
            # The connection may already be unusable; keep the original error.
            pass
        if acquired_here:
            # The caller never sees this connection, so give it back here.
            pool.putconn(conn)
        raise

    return (result, conn)


def close_postgresql_connection(db_name, conn):
    """
    Hand ``conn`` back to the pool registered for ``db_name``.

    Does nothing when no pool is registered under that name.
    """
    if db_name not in _postgresql_pools:
        return
    _postgresql_pools[db_name].putconn(conn)


def close_all_postgresql_pools():
    """
    Close every connection of every registered pool, then forget all pools.

    Intended for application shutdown.
    """
    for name in _postgresql_pools:
        _postgresql_pools[name].closeall()
    _postgresql_pools.clear()


def close_postgresql_pool(db_name):
    """
    Unregister the pool for ``db_name`` (if any) and close all of its connections.
    """
    removed = _postgresql_pools.pop(db_name, None)
    if removed is not None:
        removed.closeall()


def get_connection_for_phase(db_name, logger):
    """
    Check out a pooled connection for ``db_name`` and return it.

    The connection is obtained by running ``SELECT 1`` with ``conn=None``, which
    makes the query helper take a connection from the pool. The query result is
    discarded.
    """
    logger.info(f"Acquiring dedicated connection for phase on db: {db_name}")
    _, phase_conn = perform_query_on_postgresql_databases("SELECT 1", db_name, conn=None)
    return phase_conn


def reset_and_restore_database(db_name):
    """
    Drop ``db_name`` and re-create it from its template database.

    The template is ``<base>_template``, where ``<base>`` is the part of
    ``db_name`` before ``"_process_"``. For example, ``foo_process_3`` is
    restored from ``foo_template``; a name without that marker uses
    ``<db_name>_template``.

    Steps:
      1) close this module's connection pool for ``db_name``
      2) terminate every other session connected to ``db_name``
      3) ``dropdb --if-exists``
      4) ``createdb --template <template>``

    Connection settings come from ``DEFAULT_DB_CONFIG``.

    Raises:
        subprocess.CalledProcessError: if any client command fails.
        subprocess.TimeoutExpired: if any command runs longer than 60 seconds.
    """
    conn_args = [
        "-h", DEFAULT_DB_CONFIG["host"],
        "-p", str(DEFAULT_DB_CONFIG["port"]),
        "-U", DEFAULT_DB_CONFIG["user"],
    ]
    env_vars = os.environ.copy()
    env_vars["PGPASSWORD"] = DEFAULT_DB_CONFIG["password"]

    base_db_name = db_name.split("_process_")[0]
    template_db_name = f"{base_db_name}_template"
    # Double any single quotes so db_name is a valid SQL string literal.
    db_literal = db_name.replace("'", "''")

    # 1) Release our own pooled connections first.
    close_postgresql_pool(db_name)

    commands = [
        # 2) DROP DATABASE fails while other sessions are connected.
        [
            "psql", *conn_args, "-d", "postgres", "-c",
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
            f"WHERE datname = '{db_literal}' AND pid <> pg_backend_pid();",
        ],
        # 3) dropdb
        ["dropdb", "--if-exists", *conn_args, db_name],
        # 4) createdb --template=xxx_template
        ["createdb", *conn_args, db_name, "--template", template_db_name],
    ]
    for command in commands:
        subprocess.run(
            command,
            check=True,
            env=env_vars,
            timeout=60,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

    print(f'Reset {db_name} sucessfully')

def get_conn(db_name):
    """
    Check out a psycopg2 connection for ``db_name`` from its pool.

    The pool is created on first use. Release the connection with
    ``close_postgresql_connection(db_name, conn)`` when done.
    """
    return _get_or_init_pool(db_name).getconn()

def create_db_copies(db_name):
    """
    Drop the ``<db_name>_copy`` database if it exists.

    Despite its name, this function currently creates nothing. The step that
    re-created ``<db_name>_copy`` from ``<db_name>_template`` is disabled, and
    the function returns None. What it does:

      1) terminate every other session connected to ``db_name``;
      2) ``dropdb --if-exists <db_name>_copy``. A failure is reported on stdout,
         not raised.

    Connection settings come from ``DEFAULT_DB_CONFIG``.

    Raises:
        subprocess.CalledProcessError: if the session-termination ``psql`` call fails.
        subprocess.TimeoutExpired: if that call exceeds 60 seconds.
    """
    conn_args = [
        "-h", DEFAULT_DB_CONFIG["host"],
        "-p", str(DEFAULT_DB_CONFIG["port"]),
        "-U", DEFAULT_DB_CONFIG["user"],
    ]
    env_vars = os.environ.copy()
    env_vars["PGPASSWORD"] = DEFAULT_DB_CONFIG["password"]

    ephemeral_name = f"{db_name}_copy"
    # Double any single quotes so db_name is a valid SQL string literal.
    db_literal = db_name.replace("'", "''")

    # NOTE(review): this terminates sessions on ``db_name`` (the source database),
    # not on ``ephemeral_name`` (the one being dropped). Confirm the intent.
    terminate_command = [
        "psql", *conn_args, "-d", "postgres", "-c",
        "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
        f"WHERE datname = '{db_literal}' AND pid <> pg_backend_pid();",
    ]
    subprocess.run(
        terminate_command,
        check=True,
        env=env_vars,
        timeout=60,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    drop = subprocess.run(
        ["dropdb", "--if-exists", *conn_args, ephemeral_name],
        check=False,
        env=env_vars,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # check=False means failures do not raise, so only claim success when dropdb succeeded.
    if drop.returncode == 0:
        print(f"Drop table {ephemeral_name} sucessfully")
    else:
        print(f"Failed to drop {ephemeral_name} (dropdb exit code {drop.returncode})")


def drop_db_copies(ephemeral_name, base_template):
    """
    Drop the database ``ephemeral_name`` if it exists.

    Args:
        ephemeral_name: Database to drop.
        base_template: Not used by this function.

    Connection settings come from ``DEFAULT_DB_CONFIG``. A failing ``dropdb`` is
    reported on stdout instead of raising.
    """
    conn_args = [
        "-h", DEFAULT_DB_CONFIG["host"],
        "-p", str(DEFAULT_DB_CONFIG["port"]),
        "-U", DEFAULT_DB_CONFIG["user"],
    ]
    env_vars = os.environ.copy()
    env_vars["PGPASSWORD"] = DEFAULT_DB_CONFIG["password"]

    drop_cmd = ["dropdb", "--if-exists", *conn_args, ephemeral_name]
    try:
        subprocess.run(
            drop_cmd,
            check=True,
            env=env_vars,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to drop ephemeral db {ephemeral_name}: {e}")


def create_ephemeral_db_copies(base_db_names, num_copies):
    """
    Create ``num_copies`` working copies of each base database.

    For every ``base_db``, the copies are named ``<base_db>_process_1`` through
    ``<base_db>_process_<num_copies>`` and are cloned from
    ``<base_db>_template``. An existing database with the same name is dropped
    first; failures of that drop are ignored. Connection settings come from
    ``DEFAULT_DB_CONFIG``.

    Returns:
        dict: ``{base_db: [ephemeral1, ephemeral2, ...], ...}`` in creation order.

    Raises:
        subprocess.CalledProcessError: if a ``createdb`` call fails. Copies
        created before the failure are left in place.
    """
    conn_args = [
        "-h", DEFAULT_DB_CONFIG["host"],
        "-p", str(DEFAULT_DB_CONFIG["port"]),
        "-U", DEFAULT_DB_CONFIG["user"],
    ]
    env_vars = os.environ.copy()
    env_vars["PGPASSWORD"] = DEFAULT_DB_CONFIG["password"]

    ephemeral_db_pool = {}

    for base_db in base_db_names:
        base_template = f"{base_db}_template"
        ephemeral_db_pool[base_db] = []

        for i in range(1, num_copies + 1):
            ephemeral_name = f"{base_db}_process_{i}"
            # If it already exists, drop it first (best effort: check=False).
            subprocess.run(
                ["dropdb", "--if-exists", *conn_args, ephemeral_name],
                check=False,
                env=env_vars,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            subprocess.run(
                ["createdb", *conn_args, ephemeral_name, "--template", base_template],
                check=True,
                env=env_vars,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            ephemeral_db_pool[base_db].append(ephemeral_name)

    return ephemeral_db_pool


def drop_ephemeral_dbs(ephemeral_db_pool_dict, pg_password, logger):
    """
    Delete all ephemeral databases listed in ``ephemeral_db_pool_dict``.

    Args:
        ephemeral_db_pool_dict: ``{base_db: [ephemeral_db, ...]}``, the same
            shape ``create_ephemeral_db_copies`` returns.
        pg_password: Password exported to ``dropdb`` via ``PGPASSWORD``.
        logger: Object with ``info``/``error`` methods used for progress and errors.

    Host, port and user come from ``DEFAULT_DB_CONFIG``. A failing drop is logged
    and skipped, so one bad database does not stop the cleanup.
    """
    conn_args = [
        "-h", DEFAULT_DB_CONFIG["host"],
        "-p", str(DEFAULT_DB_CONFIG["port"]),
        "-U", DEFAULT_DB_CONFIG["user"],
    ]
    env_vars = os.environ.copy()
    env_vars["PGPASSWORD"] = pg_password

    logger.info("=== Cleaning up ephemeral databases ===")
    for ephemeral_list in ephemeral_db_pool_dict.values():
        for ephemeral_db in ephemeral_list:
            logger.info(f"Dropping ephemeral db: {ephemeral_db}")
            try:
                subprocess.run(
                    ["dropdb", "--if-exists", *conn_args, ephemeral_db],
                    check=True,
                    env=env_vars,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to drop ephemeral db {ephemeral_db}: {e}")


def execute_queries(queries, db_name, conn, logger=None, section_title=""):
    """
    Execute ``queries`` in order on the SAME connection, stopping at the first error.

    Args:
        queries: Sequence of SQL strings.
        db_name: Database name passed to the query helper.
        conn: Connection to reuse. If None, the query helper checks one out of the pool.
        logger: Logger for progress/errors. Defaults to this module's logger.
        section_title: Label included in the per-query connection log line.

    Returns:
        tuple: ``(query_result, execution_error, timeout_error, error_message)``.

        - ``query_result`` is the result of the last successful query (None if none).
        - ``execution_error`` / ``timeout_error`` flag the kind of failure.
        - ``error_message`` is the caught exception, or ``""`` when all queries succeeded.
    """
    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    query_result = None
    execution_error = False
    timeout_error = False
    error_message = ""
    for i, query in enumerate(queries, start=1):
        try:
            logger.info(f"Executing query {i}/{len(queries)}: {query}")
            query_result, conn = perform_query_on_postgresql_databases(
                query, db_name, conn=conn
            )

        except psycopg2.errors.QueryCanceled as e:
            # Raised when statement_timeout is exceeded.
            logger.error(f"Timeout error executing query {i}: {e}")
            timeout_error = True
            error_message = e
            break

        except OperationalError as e:
            # Operational errors (e.g., server not available, etc.)
            logger.error(f"OperationalError executing query {i}: {e}")
            execution_error = True
            error_message = e
            break

        except psycopg2.Error as e:
            # Other psycopg2 errors (e.g., syntax errors, constraint violations)
            logger.error(f"psycopg2 Error executing query {i}: {e}")
            execution_error = True
            error_message = e
            break

        except Exception as e:
            # Any other generic error
            logger.error(f"Generic error executing query {i}: {e}")
            execution_error = True
            error_message = e
            break

        finally:
            # Runs on success and on the error `break`s above.
            logger.info(f"[{section_title}] DB: {db_name}, conn info: {conn}")

    return query_result, execution_error, timeout_error, error_message




def connect_postgresql(db_name):
    """
    Open a new, un-pooled psycopg2 connection to ``db_name``.

    Credentials come from ``DEFAULT_DB_CONFIG``, the same settings the pools use.

    Returns:
        The connection, or None if connecting failed. The error is printed.
    """
    try:
        return psycopg2.connect(
            dbname=db_name,
            user=DEFAULT_DB_CONFIG["user"],
            password=DEFAULT_DB_CONFIG["password"],
            host=DEFAULT_DB_CONFIG["host"],
            port=DEFAULT_DB_CONFIG["port"],
        )
    except psycopg2.Error as e:
        print(f"Error connecting to database {db_name}: {e}")
        return None

def run_queries(queries, db_name):
    """
    Run ``queries`` on ``db_name`` over a fresh connection, then reset the database.

    Returns the result of the last successfully executed query (None if none
    succeeded). The connection is closed and ``reset_and_restore_database`` is
    called even if execution raises.
    """
    import logging

    conn = connect_postgresql(db_name)
    try:
        # execute_queries logs through its logger unconditionally, so pass one explicitly.
        result, _, _, _ = execute_queries(
            queries, db_name, conn, logger=logging.getLogger(__name__)
        )
    finally:
        # Close before resetting: the reset terminates every session on db_name.
        if conn is not None:
            conn.close()
        reset_and_restore_database(db_name)
    return result

def get_qep(queries, db_name):
    """
    Return the ``EXPLAIN (FORMAT JSON)`` plan of each query, as indented JSON text.

    Plans are separated by blank lines. A query that cannot be explained
    contributes its original SQL text instead. If no connection can be opened,
    every query falls back to its SQL text.
    """
    conn = connect_postgresql(db_name)
    if conn is None:
        return '\n\n'.join(queries)

    plans = []
    try:
        with conn.cursor() as cur:
            for sql in queries:
                try:
                    cur.execute(f"EXPLAIN (FORMAT JSON) {sql}")
                    plan_json = cur.fetchone()[0]
                except psycopg2.Error:
                    # A failed statement aborts the transaction. Roll back so the
                    # following EXPLAINs are not rejected as well.
                    conn.rollback()
                    plans.append(sql)
                else:
                    plans.append(json.dumps(plan_json, indent=2))
    finally:
        conn.close()
    return '\n\n'.join(plans)

def get_qep_tmp(queries, db_name):
    """
    Debug helper: print the raw ``EXPLAIN (FORMAT JSON)`` row for each query.

    Returns None. Errors from EXPLAIN propagate. The cursor and connection are
    always closed. Does nothing if the connection cannot be opened.
    """
    conn = connect_postgresql(db_name)
    if conn is None:
        return None
    try:
        with conn.cursor() as cur:
            for sql in queries:
                cur.execute(f"EXPLAIN (FORMAT JSON) {sql}")
                print(cur.fetchone())
    finally:
        conn.close()
        
def get_column_table(question_info, generated_sql, db_schema_dic):
    """
    Heuristically find the ``[table, column]`` pairs mentioned in ``generated_sql``.

    A table or column counts as used when its lower-cased name occurs as a
    substring of the lower-cased SQL, so short names can give false positives.
    Columns are only considered for tables that matched.

    Args:
        question_info: dict with a ``'db_id'`` key.
        generated_sql: iterable of SQL strings; they are joined with newlines.
        db_schema_dic: ``{db_id: {table_name: [column_name, ...]}}``.

    Returns:
        list: ``[table, column]`` pairs, both lower-cased, without duplicates,
        ordered by table then column as they appear in the schema.
    """
    tables = db_schema_dic[question_info['db_id']]
    sql_text = '\n'.join(generated_sql).lower()
    final_results = []
    seen = set()
    for table_name, columns in tables.items():
        table = table_name.lower()
        if table not in sql_text:
            continue
        # Index the schema by its real key, not the lower-cased name, so
        # mixed-case table names do not raise KeyError.
        for col in columns:
            column = col.lower()
            if column in sql_text and (table, column) not in seen:
                seen.add((table, column))
                final_results.append([table, column])
    return final_results

def get_column_table_from_detail_schema(generated_sql, detail_schema):
    """
    Heuristically find the ``[table, column]`` pairs mentioned in ``generated_sql``.

    Matching is a case-insensitive substring test of each name against the
    SQL text. Columns are only considered for tables that matched.

    Args:
        generated_sql: a SQL string, or a list of SQL strings joined with newlines.
        detail_schema: dict with ``['detail_info']['table_columns']`` mapping
            table name -> list of column names.

    Returns:
        list: ``[table, column]`` pairs using the schema's original spelling,
        without duplicates, in schema order.
    """
    tab_col_dic = detail_schema['detail_info']['table_columns']
    # Lower-case both input forms; names are compared lower-cased below.
    if isinstance(generated_sql, list):
        sql_text = '\n'.join(generated_sql).lower()
    else:
        sql_text = generated_sql.lower()

    final_results = []
    seen = set()
    for table, columns in tab_col_dic.items():
        if table.lower() not in sql_text:
            continue
        for column in columns:
            if column.lower() in sql_text and (table, column) not in seen:
                seen.add((table, column))
                final_results.append([table, column])
    return final_results


def get_column_table_from_full_schema(idx, generated_sql, full_schema):
    """
    Heuristically find the ``[table, column]`` pairs of ``full_schema[idx]`` used in SQL.

    Matching is a case-insensitive substring test of each name against the
    SQL text. Columns are only considered for tables that matched.

    Args:
        idx: index of the record in ``full_schema``.
        generated_sql: a SQL string, or a list of SQL strings joined with newlines.
        full_schema: sequence of records whose ``'schema_info'`` maps table name
            -> dict with ``'columns_info'``; the first element of each entry is
            the column name.

    Returns:
        list: ``[table, column]`` pairs using the schema's original spelling,
        without duplicates, in schema order. The order is deterministic; the
        previous ``set()``-based dedupe varied between runs.
    """
    info = full_schema[idx]['schema_info']
    if isinstance(generated_sql, list):
        sql_text = '\n'.join(generated_sql).lower()
    else:
        sql_text = generated_sql.lower()

    final_results = []
    seen = set()
    for table, table_info in info.items():
        if table.lower() not in sql_text:
            continue
        for column_entry in table_info['columns_info']:
            column = column_entry[0]
            if column.lower() in sql_text and (table, column) not in seen:
                seen.add((table, column))
                final_results.append([table, column])
    return final_results


def format_postgresql_create_table(
    table_name,
    columns_info,
    primary_keys,
    foreign_keys
):
    """
    Render a PostgreSQL ``CREATE TABLE`` statement from schema metadata.

    Args:
        table_name (str): Name of the table; emitted double-quoted.
        columns_info (list): Tuples ``(column_name, data_type, is_nullable, column_default)``.
            ``is_nullable == "YES"`` renders ``NULL``; anything else renders
            ``NOT NULL``. A truthy ``column_default`` renders ``DEFAULT <value>``.
        primary_keys (list): Primary-key column names (may be empty).
        foreign_keys (list): ``(fk_column, ref_table, ref_column)`` triples (may be empty).

    Returns:
        str: The statement, one column/constraint per line, with every item but
        the last comma-terminated.
    """
    items = []
    for column_name, data_type, is_nullable, column_default in columns_info:
        null_status = "NULL" if is_nullable == "YES" else "NOT NULL"
        default = f"DEFAULT {column_default}" if column_default else ""
        # strip() removes the trailing space left by an empty default. It also
        # removes the leading indent, which existing output relies on.
        items.append(f"    {column_name} {data_type} {null_status} {default}".strip())

    if primary_keys:
        items.append(f"    PRIMARY KEY ({', '.join(primary_keys)})")

    for fk_column, ref_table, ref_column in foreign_keys:
        items.append(f"    FOREIGN KEY ({fk_column}) REFERENCES {ref_table}({ref_column})")

    # Comma placement by position, so tables without constraints and duplicate
    # foreign keys are both rendered correctly.
    lines = [f'CREATE TABLE "{table_name}" (']
    lines.extend(item + "," for item in items[:-1])
    lines.extend(items[-1:])
    lines.append(");")
    return "\n".join(lines)


def get_pretty_schema(idx,  full_schema, column_meaning_dict, selected_schema = None, show_data = True, show_meaning = True):
    """
    Render one database's schema from ``full_schema`` as prompt text.

    Each selected table is emitted in this order:
      1. a ``CREATE TABLE`` statement built by ``format_postgresql_create_table``;
      2. if ``show_meaning``, a "Column Meanings:" section;
      3. if ``show_data``, the example rows rendered with ``tabulate``.
    Table sections are joined with blank lines.

    Args:
        idx: Index of the record in ``full_schema`` to render.
        full_schema: Sequence of records with ``'db_id'`` and ``'schema_info'``.
            ``schema_info`` maps table name -> dict with these keys:
            - ``'columns_info'``: entries whose first element is the column name;
            - ``'primary_keys'``;
            - ``'foreign_keys'``: triples ``(column, ref_table, ref_column)``;
            - ``'example_data'``: rows as ``'|'``-separated strings;
            - ``'example_data_col_names'``.
        column_meaning_dict: Maps ``"db|table|column"`` keys to a description,
            matched case-insensitively. Only read when ``show_meaning`` is true.
        selected_schema: Optional iterable of ``(table, column)`` pairs that
            restricts the output. When falsy, every table and column is included.
        show_data: Append example rows for each table.
        show_meaning: Append column descriptions for each table.

    Returns:
        str: The rendered schema text.

    NOTE(review): columns are selected by name only. ``selected_columns`` is
    shared across tables, so a selected column name is also kept in any other
    selected table that has a column with the same name.
    """
    if show_meaning:
        # Lower-case the keys so the lower-cased lookup below is case-insensitive.
        column_meaning_dict = dict(zip([k.lower() for k in column_meaning_dict.keys()], list(column_meaning_dict.values())))
    db_name = full_schema[idx]['db_id']
    info = full_schema[idx]['schema_info']
    if selected_schema:
        # dict.fromkeys de-duplicates while keeping first-seen table order (a set would not).
        selected_tables = list(dict.fromkeys(t[0] for t in selected_schema))
        selected_columns = list(set([t[1] for t in selected_schema]))
    else:
        # No selection: keep every table and every column name in the schema.
        selected_tables = list(info.keys())
        selected_columns = []
        for tbl in selected_tables:
            cols = [c[0] for c in info[tbl]['columns_info']]
            selected_columns += cols
        selected_columns = list(set(selected_columns))
        
        
    schemas = []
    for tbl in selected_tables:
        columns_info = info[tbl]['columns_info']
        selected_column_info = [i for i in columns_info if i[0] in selected_columns]
        primary_keys = info[tbl]['primary_keys']
        selected_primary_keys = [col for col in primary_keys if col in selected_columns]
        foreign_keys = info[tbl]['foreign_keys']
        # Keep a foreign key only if its column, referenced table and referenced column are all selected.
        selected_foreign_keys = [tup for tup in foreign_keys if (tup[0] in selected_columns and tup[1] in selected_tables and tup[2] in selected_columns)]
        pretty_schema = format_postgresql_create_table(
            tbl,
            selected_column_info,
            selected_primary_keys,
            selected_foreign_keys,
        )
        schema_prompt = pretty_schema
        
        if show_meaning:
            column_meaning_list = []
            selected_columns_for_table = [info[0] for info in selected_column_info]
            for col in selected_columns_for_table:
                meaning_key = f"{db_name}|{tbl}|{col}".lower()
                meaning = column_meaning_dict.get(meaning_key)
                if meaning: 
                    # '#' characters are stripped from the description text.
                    meaning = meaning.replace('#','')
                    column_meaning_list.append(f"{tbl}.{col}: {meaning}")
            if len(column_meaning_list) > 0:
                column_meanings_str = "Column Meanings:\n" + "\n".join(column_meaning_list)
            else: 
                column_meanings_str = ''
            # Appended after a blank line even when the section is empty.
            schema_prompt += f'\n\n{column_meanings_str}'
            
        if show_data:
            example_data = info[tbl]['example_data']
            example_col_names = info[tbl]['example_data_col_names']
            processed_example_row = []
            if len(example_data) != 0:
                for data in example_data:
                    processed_data = data.split('|')
                    processed_data_zip = list(zip(example_col_names, processed_data))
                    # Keep only values whose column name is selected (name-based, as above).
                    selected_data = [d for d in processed_data_zip if d[0] in selected_columns]
                    example_rows = [d[1] for d in selected_data]
                    example_cols = [d[0] for d in selected_data]
                    processed_example_row.append(example_rows)
                # example_cols comes from the last row processed. This assumes
                # every row yields the same selected columns.
                example_df = pd.DataFrame(processed_example_row, columns = example_cols)
                # NOTE(review): the "First 3 rows:" label is fixed and does not
                # reflect len(example_data).
                example_data_str = (
                    "First 3 rows:\n"
                    + tabulate(
                        example_df, headers="keys", tablefmt="simple", showindex=False
                    )
                    + "\n...\n"
                )
                # print(example_data_str)
            else:
                example_data_str = ''
            schema_prompt += f'\n\n{example_data_str}'
        schema_prompt += '\n'
        schemas.append(schema_prompt)
    final_schema_dll = "\n\n".join(schemas)
    return final_schema_dll

def clean_execution_plan(plan):
    """
    Strip planner metadata from an ``EXPLAIN (FORMAT JSON)`` plan.

    - Lists are cleaned element by element into a new list.
    - Dicts are modified IN PLACE: the metadata keys below are removed, any
      nested ``"Plans"`` entry is cleaned recursively, and the same dict object
      is returned.
    - Any other value is returned unchanged.
    """
    noise_keys = (
        "Parallel Aware", "Async Capable", "Startup Cost", "Total Cost",
        "Plan Rows", "Plan Width", "Planned Partitions", "Parent Relationship",
        "Strategy", "Inner Unique", "Scan Direction", "Alias",
    )

    if isinstance(plan, list):
        return [clean_execution_plan(child) for child in plan]
    if not isinstance(plan, dict):
        return plan

    for noise in noise_keys:
        plan.pop(noise, None)
    if "Plans" in plan:
        plan["Plans"] = clean_execution_plan(plan["Plans"])
    return plan

def get_clean_qep(sql_query, db_name):
    """
    Run ``EXPLAIN (FORMAT JSON)`` for ``sql_query`` on ``db_name`` and return the
    metadata-stripped plan as a JSON string.

    Returns:
        str: ``{"Original SQL": ..., "Execution Plan": ...}`` indented with 4 spaces.
        The plan is None when EXPLAIN fails (e.g. invalid SQL). If connecting
        fails, ``{"error": "<message>"}`` is returned instead.

    Credentials come from ``DEFAULT_DB_CONFIG``. The connection is always closed.
    """
    try:
        conn = psycopg2.connect(
            dbname=f"{db_name}",
            user=DEFAULT_DB_CONFIG["user"],
            password=DEFAULT_DB_CONFIG["password"],
            host=DEFAULT_DB_CONFIG["host"],
            port=DEFAULT_DB_CONFIG["port"],
        )
    except psycopg2.Error as e:
        return json.dumps({"error": str(e)}, indent=4)

    response = {"Original SQL": sql_query, "Execution Plan": None}
    try:
        # The cursor is created inside try/finally, so the connection is closed
        # even if cursor creation fails.
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            try:
                cur.execute(f"EXPLAIN (FORMAT JSON) {sql_query}")
                plan_result = cur.fetchone()
            except psycopg2.Error:
                # Unexplainable SQL: report a missing plan rather than an error.
                plan_result = None

        if plan_result and 'QUERY PLAN' in plan_result:
            raw_plan = plan_result['QUERY PLAN'][0].get('Plan', {})
            response["Execution Plan"] = clean_execution_plan(raw_plan)
    except psycopg2.Error as e:
        # Cursor-level failure, not a problem with the SQL itself.
        return json.dumps({"error": str(e)}, indent=4)
    finally:
        conn.close()

    return json.dumps(response, indent=4)
    
def generate_execution_plan_text(sql_queries, db_name):
    """
    Build one text document containing the cleaned plan of every query.

    Each entry is headed ``# Execution Plan for SQL <n>`` (1-based), followed by
    the JSON from ``get_clean_qep``. Entries are separated by blank lines.
    """
    sections = (
        f"# Execution Plan for SQL {number}\n{get_clean_qep(statement, db_name)}"
        for number, statement in enumerate(sql_queries, start=1)
    )
    return "\n\n".join(sections)


def get_target_schema(idx, target_sql, gold_sql, full_schema):
    """
    Return ``(gold_schema, target_schema)`` as lists of ``"table.column"`` strings.

    Both lists are extracted from ``full_schema[idx]`` with
    ``get_column_table_from_full_schema``; the gold SQL is processed first.
    """
    def as_qualified(sql):
        pairs = get_column_table_from_full_schema(idx, sql, full_schema)
        return [f"{tbl}.{col}" for tbl, col in pairs]

    gold = as_qualified(gold_sql)
    target = as_qualified(target_sql)
    return gold, target

def compute_ex(ground_truth, sl_res):
    """
    Exact-match rate: the fraction of questions whose prediction contains every gold item.

    ``ground_truth[i]`` and ``sl_res[i]`` are the gold and predicted items for
    question ``i``. Extra predicted items do not break a match. Raises
    ZeroDivisionError when ``ground_truth`` is empty.
    """
    total = len(ground_truth)
    matched = sum(
        1
        for pos, gold_items in enumerate(ground_truth)
        if all(item in sl_res[pos] for item in gold_items)
    )
    return matched / total

def compute_recall(ground_truth, sl_res):
    """
    Mean per-question recall: found gold items / total gold items.

    A question with no gold items scores 0. Raises ZeroDivisionError when
    ``ground_truth`` is empty.
    """
    scores = []
    for pos, gold_items in enumerate(ground_truth):
        gold_count = len(gold_items)
        if gold_count == 0:
            scores.append(0)
            continue
        found = sum(1 for item in gold_items if item in sl_res[pos])
        scores.append(found / gold_count)
    return sum(scores) / len(scores)

def compute_precision(ground_truth, sl_res):
    """
    Mean per-question precision: correct predicted items / total predicted items.

    A question with no predicted items scores 1. Raises ZeroDivisionError when
    ``sl_res`` is empty.
    """
    scores = []
    for pos, predicted in enumerate(sl_res):
        predicted_count = len(predicted)
        if predicted_count == 0:
            scores.append(1)
            continue
        correct = sum(1 for item in predicted if item in ground_truth[pos])
        scores.append(correct / predicted_count)
    return sum(scores) / len(scores)

def compute_metrics(ground_truth, sl_res):
    """
    Compute and print exact-match, recall, precision and F1 for schema linking.

    Args:
        ground_truth: per-question gold items.
        sl_res: per-question predicted items (same indexing as ``ground_truth``).

    Returns:
        dict: ``{"ex", "recall", "precision", "f1"}``. The same values are printed.
    """
    ex = compute_ex(ground_truth, sl_res)
    recall = compute_recall(ground_truth, sl_res)
    precision = compute_precision(ground_truth, sl_res)
    # The harmonic mean is undefined when both terms are 0; report 0 instead of
    # raising ZeroDivisionError.
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator else 0.0
    print(f"ex: {ex}, recall: {recall}, precision: {precision}, f1: {f1}")
    return {"ex": ex, "recall": recall, "precision": precision, "f1": f1}
    
def compute_schema_metrics(data_jsonl, target_sql_list, full_schema_jsonl):
    """
    Compare the schema items used by predicted SQL against the gold SQL, and print metrics.

    ``data_jsonl[i]['sol_sql']`` is the gold SQL for question ``i``, and
    ``target_sql_list[i]`` is the prediction. Both are mapped to
    ``"table.column"`` lists via ``get_target_schema``, then passed to
    ``compute_metrics``.
    """
    gold_per_question = []
    predicted_per_question = []
    for pos, record in enumerate(data_jsonl):
        reference_sql = record['sol_sql']
        candidate_sql = target_sql_list[pos]
        gold_items, predicted_items = get_target_schema(pos, candidate_sql, reference_sql, full_schema_jsonl)
        gold_per_question.append(gold_items)
        predicted_per_question.append(predicted_items)
    compute_metrics(gold_per_question, predicted_per_question)

if __name__ == "__main__":
    # Nothing to run directly; this module only defines helpers.
    pass