
from utils.db_utils import connect_postgresql, reset_and_restore_database
from utils.util import load_jsonl, save_jsonl
import os
import json
import pandas as pd
from tabulate import tabulate
import tqdm

def _quote_ident(name):
    """Return ``name`` as a double-quoted SQL identifier.

    Embedded double quotes are doubled so that a name containing ``"`` cannot
    terminate the quoted identifier early.
    """
    return '"' + str(name).replace('"', '""') + '"'


def _describe_table(cursor, tbl):
    """Collect columns, keys and up to three sample rows for one table.

    Args:
        cursor: An open DB-API cursor on the target database.
        tbl: Unqualified table name as listed in ``information_schema.tables``.

    Returns:
        dict with the keys ``columns_info``, ``primary_keys``,
        ``foreign_keys``, ``example_data`` and ``example_data_col_names``.
    """
    schema_dic = {}

    # Exclude the system schemas here as well: without this filter a user
    # table that shares its name with a system relation (e.g. "tables")
    # would also pick up that relation's columns, and the sample SELECT
    # below would then fail on columns the user table does not have.
    # NOTE(review): two user schemas holding the same table name are still
    # not told apart, because only the bare table name is tracked.
    cursor.execute(
        """
            SELECT column_name, data_type, is_nullable, column_default
            FROM information_schema.columns
            WHERE table_name = %s
              AND table_schema NOT IN ('information_schema', 'pg_catalog')
            ORDER BY ordinal_position;
            """,
        (tbl,),
    )
    columns_info = cursor.fetchall()
    schema_dic['columns_info'] = columns_info

    # The quoted name is cast to regclass, so it is resolved through the
    # session's search_path and keeps its exact case.
    cursor.execute(
        """
            SELECT a.attname
            FROM   pg_index i
            JOIN   pg_attribute a 
                ON a.attrelid = i.indrelid
                AND a.attnum = ANY(i.indkey)
            WHERE  i.indrelid = %s::regclass
                AND i.indisprimary;
            """,
        (_quote_ident(tbl),),
    )
    schema_dic['primary_keys'] = [row[0] for row in cursor.fetchall()]

    cursor.execute(
        """
            SELECT
                kcu.column_name,
                ccu.table_name AS foreign_table_name,
                ccu.column_name AS foreign_column_name
            FROM
                information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                  ON tc.constraint_name = kcu.constraint_name
                  AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu
                  ON ccu.constraint_name = tc.constraint_name
                  AND ccu.table_schema = tc.table_schema
            WHERE tc.constraint_type = 'FOREIGN KEY'
              AND tc.table_name = %s;
            """,
        (tbl,),
    )
    schema_dic['foreign_keys'] = cursor.fetchall()

    # Identifiers cannot be sent as bound parameters, so they are quoted
    # (with embedded quotes escaped) and interpolated into the statement.
    selected_columns_str = ','.join(_quote_ident(info[0]) for info in columns_info)
    cursor.execute(f'SELECT {selected_columns_str} FROM {_quote_ident(tbl)} LIMIT 3;')
    example_rows = cursor.fetchall()
    col_names = [desc[0] for desc in cursor.description]

    # Each sample row becomes one '|'-joined string; a table without rows
    # yields an empty string rather than an empty list.
    if example_rows:
        example_data = ['|'.join(str(cell) for cell in row) for row in example_rows]
    else:
        example_data = ''
    schema_dic['example_data'] = example_data
    schema_dic['example_data_col_names' ] = col_names
    return schema_dic


def generate_schema_info(db_name, preprocess_sql):
    """Describe every base table of a PostgreSQL database.

    Opens a connection through ``connect_postgresql``, optionally runs the
    given preprocessing statements, then inspects each base table outside the
    ``information_schema`` and ``pg_catalog`` schemas.

    Args:
        db_name: Database name; passed to ``connect_postgresql`` and matched
            against ``table_catalog`` when listing tables.
        preprocess_sql: Iterable of SQL statements executed before
            inspection, or a falsy value to skip this step. If a statement
            raises, the error is printed, the transaction is rolled back and
            inspection continues.

    Returns:
        dict mapping table name to a dict with:
            ``columns_info``: rows of (column_name, data_type, is_nullable,
                column_default) ordered by ordinal position.
            ``primary_keys``: list of primary-key column names.
            ``foreign_keys``: rows of (column_name, foreign_table_name,
                foreign_column_name).
            ``example_data``: list of up to three '|'-joined row strings, or
                ``''`` when the table has no rows.
            ``example_data_col_names``: column names of the sample query.

    Raises:
        Any exception raised by the inspection queries propagates to the
        caller; the cursor and connection are closed first.
    """
    db = connect_postgresql(db_name)
    try:
        cursor = db.cursor()
        try:
            # Best-effort preprocessing: a failure is reported and rolled
            # back so the inspection queries can still run.
            try:
                if preprocess_sql:
                    for pre_sql in preprocess_sql:
                        cursor.execute(pre_sql)
            except Exception as e:
                print(f"Error executing preprocess sql on {db_name}: {preprocess_sql}")
                print(f"Exception: {e}")
                db.rollback()

            query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_catalog = %s
          AND table_schema NOT IN ('information_schema', 'pg_catalog')
          AND table_type = 'BASE TABLE'
        ORDER BY table_name;
    """
            cursor.execute(query, (db_name,))
            all_tables = [
                row[0] for row in cursor.fetchall() if row[0] != 'sqlite_sequence'
            ]

            db_schema_dic = {}
            for tbl in all_tables:
                db_schema_dic[tbl] = _describe_table(cursor, tbl)
            return db_schema_dic
        finally:
            # Release the cursor even when an inspection query raised.
            cursor.close()
    finally:
        db.close()

def get_schema(data_jsonl):
    """Build one schema description per record of ``data_jsonl``.

    Each record must provide ``'db_id'`` and may provide
    ``'preprocess_sql'``; both are forwarded to ``generate_schema_info``.

    Returns:
        list of dicts with the keys ``'idx'`` (position of the record in
        ``data_jsonl``), ``'db_id'`` and ``'schema_info'``, in input order.
    """

    def _entry(position, record):
        # Read the optional key first, then the required one, then inspect.
        pre_statements = record.get('preprocess_sql')
        database = record['db_id']
        info = generate_schema_info(database, pre_statements)
        return {'idx': position, 'db_id': database, 'schema_info': info}

    return [_entry(position, record) for position, record in enumerate(data_jsonl)]



