import re
import json
import pandas as pd
from typing import Dict, List, Tuple
import psycopg2
from tabulate import tabulate

# NOTE: `tabulate` (imported above) renders the sample-row tables in the generated prompt.


# Columns dropped from the sample rows of specific tables.
# Table names are matched exactly (case-sensitive) against these keys.
_EXCLUDED_SAMPLE_COLUMNS = {
    "match": frozenset(
        {
            "shoton",
            "goal",
            "shotoff",
            "foulcommit",
            "card",
            "cross",
            "corner",
            "possession",
        }
    ),
    "posts": frozenset({"body"}),
    "posthistory": frozenset(
        {"flavortext", "originaltext", "printings", "purchaseurls", "text"}
    ),
    "cards": frozenset({"text"}),
    "sets": frozenset({"booster"}),
    "rulings": frozenset({"text"}),
    "comments": frozenset({"text"}),
    "users": frozenset({"aboutme", "profileimageurl"}),
    "event": frozenset({"notes"}),
    "foreign_data": frozenset({"flavortext", "text", "type"}),
}


def additional_filter(curr_table_name, filtered_example_rows):
    """
    Drop the excluded columns of a known table from its example rows.

    Args:
        curr_table_name: table name; compared exactly (case-sensitive)
            against the names in the exclusion table.
        filtered_example_rows: list of dict rows (column name -> value).

    Returns:
        A new list of new dicts without the excluded columns when the table
        has an exclusion entry; otherwise the input list itself, untouched.
    """
    hidden_columns = _EXCLUDED_SAMPLE_COLUMNS.get(curr_table_name)
    if hidden_columns is None:
        # Nothing registered for this table: hand the rows back as they came.
        return filtered_example_rows

    return [
        {name: value for name, value in row.items() if name not in hidden_columns}
        for row in filtered_example_rows
    ]


def get_user_defined_tables(cursor):
    """
    List the tables that live outside PostgreSQL's built-in schemas.

    Args:
        cursor: an open database cursor; the lookup query is executed on it.

    Returns:
        Whatever `cursor.fetchall()` yields for the query: one
        (schemaname, tablename) row per table in pg_catalog.pg_tables whose
        schema is neither pg_catalog nor information_schema, ordered by
        schema name and then table name.
    """
    user_tables_query = """
        SELECT schemaname, tablename
        FROM pg_catalog.pg_tables
        WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
        ORDER BY schemaname, tablename;
    """
    cursor.execute(user_tables_query)
    return cursor.fetchall()


def generate_schema_prompt_postgresql(
    db_name: str,
    conn,
    table_name: str = None,
    num_rows: int = 3,
):
    """
    Generate a textual schema prompt for the user-defined tables in the DB
    (i.e. those not in pg_catalog or information_schema). Each table gets a
    CREATE TABLE statement, its trigger definitions (if any) and a few sample
    rows.

    Args:
        db_name (str): ephemeral DB name. Not used by this function; kept so
            existing callers keep working.
        conn: an active psycopg2 (or similar) connection object to `db_name`.
        table_name (str): if provided, only tables whose name equals this one
            case-insensitively are shown. The schema is ignored, so a table
            with that name in several schemas is shown once per schema.
        num_rows (int): how many sample rows to fetch from each table. It is
            coerced with int() because it is interpolated into the SQL text.

    Returns:
        str: the per-table sections joined by blank lines; an empty string
             when no table matches.
    """
    # LIMIT is spliced into the statement text, so force it to be a number.
    row_limit = int(num_rows)
    wanted_table = table_name.lower() if table_name else None
    schemas_output = []

    cursor = conn.cursor()
    try:
        for schema_name, tbl_name in get_user_defined_tables(cursor):
            # If a specific table_name is requested, skip others.
            # Only the table portion is compared; the schema is ignored.
            if wanted_table is not None and tbl_name.lower() != wanted_table:
                continue

            # Quote both identifiers, doubling embedded double quotes, so
            # mixed-case or otherwise unusual names resolve correctly.
            quoted_schema = schema_name.replace('"', '""')
            quoted_table = tbl_name.replace('"', '""')
            qualified_name = f'"{quoted_schema}"."{quoted_table}"'

            # 1) Gather column info
            cursor.execute(
                """
                SELECT column_name, data_type, is_nullable, column_default
                FROM information_schema.columns
                WHERE table_schema = %s
                  AND table_name = %s
                ORDER BY ordinal_position;
                """,
                (schema_name, tbl_name),
            )
            columns_info = cursor.fetchall()

            # 2) Primary keys. The qualified name goes in as a bind parameter
            #    and is cast to regclass server-side instead of being pasted
            #    into the SQL text.
            cursor.execute(
                """
                SELECT a.attname
                FROM pg_index i
                JOIN pg_attribute a
                    ON a.attrelid = i.indrelid
                    AND a.attnum = ANY(i.indkey)
                WHERE i.indrelid = %s::regclass
                  AND i.indisprimary;
                """,
                (qualified_name,),
            )
            primary_keys = [row[0] for row in cursor.fetchall()]

            # 3) Foreign keys. key_column_usage is also matched on table_name:
            #    constraint names are only unique per table, so without it the
            #    columns of a same-named constraint on another table leak in.
            #    NOTE(review): multi-column foreign keys still yield one row per
            #    (column, referenced column) combination from this join.
            cursor.execute(
                """
                SELECT
                    kcu.column_name,
                    ccu.table_name AS foreign_table_name,
                    ccu.column_name AS foreign_column_name
                FROM
                    information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                      ON tc.constraint_name = kcu.constraint_name
                      AND tc.table_schema = kcu.table_schema
                      AND tc.table_name = kcu.table_name
                    JOIN information_schema.constraint_column_usage AS ccu
                      ON ccu.constraint_name = tc.constraint_name
                      AND ccu.table_schema = tc.table_schema
                WHERE tc.constraint_type = 'FOREIGN KEY'
                  AND tc.table_name = %s
                  AND tc.table_schema = %s;
                """,
                (tbl_name, schema_name),
            )
            foreign_keys = cursor.fetchall()

            # 4) Format CREATE TABLE statement (all columns, bare table name)
            create_table_sql = format_postgresql_create_table(
                table_name=tbl_name,
                columns_info=columns_info,
                primary_keys=primary_keys,
                foreign_keys=foreign_keys,
            )

            # 5) Get triggers
            triggers_list = get_triggers_for_table(cursor, schema_name, tbl_name)

            # 6) Get example data. Identifiers cannot be bind parameters, so the
            #    quoted name is interpolated; no parameters are passed, hence a
            #    literal '%' in a name is not treated as a placeholder.
            cursor.execute(f"SELECT * FROM {qualified_name} LIMIT {row_limit};")
            example_rows = cursor.fetchall()
            col_names = [desc[0] for desc in cursor.description]
            example_dict_rows = [dict(zip(col_names, row)) for row in example_rows]

            # 7) Additional filter for large/noisy columns
            filtered_example_rows = additional_filter(tbl_name, example_dict_rows)

            # 8) Format example data; the heading reflects the requested limit.
            heading = f"First {row_limit} rows:\n"
            if filtered_example_rows:
                example_df = pd.DataFrame(filtered_example_rows)
                example_data_str = (
                    heading
                    + tabulate(
                        example_df,
                        headers="keys",
                        tablefmt="simple",
                        showindex=False,
                    )
                    + "\n...\n"
                )
            else:
                example_data_str = heading + "No data available in this table.\n"

            # Combine everything into the final text for this table.
            sections = [create_table_sql]
            if triggers_list:
                # each definition is terminated with ';' on its own line
                sections.append(
                    "-- Triggers:\n"
                    + "\n".join(f"{trg_def};" for trg_def in triggers_list)
                )
            sections.append(example_data_str)
            schemas_output.append("\n\n".join(sections))
    finally:
        # Release the cursor even when one of the queries raises.
        cursor.close()

    return "\n\n".join(schemas_output)


def format_postgresql_create_table(
    table_name, columns_info, primary_keys, foreign_keys
):
    """
    Generate a CREATE TABLE statement including columns, primary keys, foreign keys.

    Args:
        table_name: name written after CREATE TABLE. The whole value is
            wrapped in one pair of double quotes, so "schema.table" would be
            emitted as a single quoted identifier.
        columns_info: iterable of
            (column_name, data_type, is_nullable, column_default) tuples.
            is_nullable == "YES" renders NULL, anything else NOT NULL; a
            falsy column_default omits the DEFAULT clause.
        primary_keys: sequence of primary-key column names (may be empty).
        foreign_keys: iterable of
            (column, referenced_table, referenced_column) tuples.

    Returns:
        str: the statement, one definition per line. Every definition line is
             indented by four spaces and all but the last end with a comma.
    """
    entries = []

    for col_name, data_type, is_nullable, column_default in columns_info:
        null_status = "NULL" if is_nullable == "YES" else "NOT NULL"
        col_line = f"    {col_name} {data_type} {null_status}"
        # Append DEFAULT only when present; this avoids a trailing space
        # without stripping the leading indentation of the line.
        if column_default:
            col_line += f" DEFAULT {column_default}"
        entries.append(col_line)

    if primary_keys:
        entries.append(f"    PRIMARY KEY ({', '.join(primary_keys)})")

    for fk_col, ref_table, ref_col in foreign_keys:
        # The referenced table is emitted as given, without a schema prefix.
        entries.append(f"    FOREIGN KEY ({fk_col}) REFERENCES {ref_table}({ref_col})")

    lines = [f'CREATE TABLE "{table_name}" (']
    # Every definition except the final one is followed by a comma.
    lines.extend(f"{entry}," for entry in entries[:-1])
    lines.extend(entries[-1:])
    lines.append(");")
    return "\n".join(lines)


def get_triggers_for_table(cursor, schema_name, table_name):
    """
    Retrieve the trigger definitions for the given table from pg_trigger
    (excluding internal triggers).

    Args:
        cursor: an open database cursor; the query is executed on it.
        schema_name: schema that contains the table.
        table_name: name of the table whose triggers are wanted.

    Returns:
        list: the `pg_get_triggerdef` text of each non-internal trigger on
              the table, ordered by trigger name.
    """
    # Quote both identifiers (doubling embedded double quotes) and hand the
    # result over as a bind parameter, so the driver escapes it as a string
    # literal before the server casts it to regclass.
    quoted_schema = schema_name.replace('"', '""')
    quoted_table = table_name.replace('"', '""')
    qualified_name = f'"{quoted_schema}"."{quoted_table}"'

    cursor.execute(
        """
        SELECT tgname, pg_get_triggerdef(pg_trigger.oid)
        FROM pg_trigger
        WHERE NOT tgisinternal
          AND tgrelid = %s::regclass
        ORDER BY tgname;
        """,
        (qualified_name,),
    )
    return [trigger_def for _, trigger_def in cursor.fetchall()]
