#!/usr/bin/env python3
"""
Schema Analyzer Tool
Extracts complete database schema with detailed column information.
Inspired by OpenSearch-SQL's preprocessing phase.
"""

import sqlite3
import json
import os

def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier.

    Embedded double quotes are doubled, so table or index names that
    contain quotes, backticks, or spaces can be placed in SQL text safely.
    """
    return '"' + name.replace('"', '""') + '"'


def _describe_table(cursor, table_name):
    """Collect the definition, columns, row count, and indexes of one table.

    Args:
        cursor: An open ``sqlite3`` cursor on the database being analyzed.
        table_name: Name of the table as stored in ``sqlite_master``.

    Returns:
        A dict with the keys ``columns``, ``row_count``, ``primary_keys``,
        ``indexes``, and ``sql_definition``. ``row_count`` is an int, or an
        ``"Error: ..."`` string when the table could not be counted.
    """
    quoted_table = _quote_identifier(table_name)
    table_info = {
        "columns": {},
        "row_count": 0,
        "primary_keys": [],
        "indexes": [],
        "sql_definition": "",
    }

    # Original CREATE TABLE statement.
    cursor.execute(
        "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
        (table_name,),
    )
    definition_row = cursor.fetchone()
    if definition_row:
        table_info["sql_definition"] = definition_row[0]

    # Column details. The last field of each table_info row is 0 for
    # non-key columns, otherwise the column's 1-based position inside the
    # primary key.
    cursor.execute(f"PRAGMA table_info({quoted_table})")
    key_members = []
    for col_id, col_name, col_type, not_null, default_value, pk_rank in cursor.fetchall():
        table_info["columns"][col_name] = {
            "type": col_type,
            "nullable": not not_null,
            "default": default_value,
            "is_primary_key": bool(pk_rank),
            "position": col_id,
        }
        if pk_rank:
            key_members.append((pk_rank, col_name))
    # Report composite keys in key order, not column-declaration order.
    table_info["primary_keys"] = [name for _, name in sorted(key_members)]

    # Row count is best-effort: a failure is recorded in place of the count
    # instead of aborting the whole analysis.
    try:
        cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
        table_info["row_count"] = cursor.fetchone()[0]
    except sqlite3.Error as exc:
        table_info["row_count"] = f"Error: {exc}"

    # Indexes and the columns each one covers. index_list rows carry the
    # index name at position 1 and the unique flag at position 2;
    # index_info rows carry the column name at position 2.
    cursor.execute(f"PRAGMA index_list({quoted_table})")
    for index_row in cursor.fetchall():
        index_name = index_row[1]
        cursor.execute(f"PRAGMA index_info({_quote_identifier(index_name)})")
        table_info["indexes"].append({
            "name": index_name,
            "unique": bool(index_row[2]),
            "columns": [member[2] for member in cursor.fetchall()],
        })

    return table_info


def _find_shared_columns(tables):
    """Map each column name used by more than one table to those tables."""
    occurrences = {}
    for table_name, table_info in tables.items():
        for col_name in table_info["columns"]:
            occurrences.setdefault(col_name, []).append(table_name)
    return {
        col_name: owners
        for col_name, owners in occurrences.items()
        if len(owners) > 1
    }


def analyze_schema(db_path="database.sqlite", output_dir="tool_output"):
    """Extract complete schema information from the database.

    Every user table (names starting with ``sqlite_`` are skipped) is
    described with its columns, primary key, row count, indexes, and
    CREATE statement. The result is written as JSON to
    ``<output_dir>/schema_analysis.json`` and a short summary is printed.

    Args:
        db_path: Path of the SQLite database file to analyze.
        output_dir: Directory that receives ``schema_analysis.json``; it is
            created when missing.

    Returns:
        The schema dictionary that was written to disk. It has the keys
        ``tables`` and ``summary``, plus ``duplicate_columns`` when the
        analysis finished, or ``error`` (the exception message) when it
        was interrupted.

    Note:
        ``sqlite3.connect`` creates an empty database when ``db_path`` does
        not exist, in which case the report simply lists zero tables.
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    schema_info = {
        "tables": {},
        "summary": {
            "total_tables": 0,
            "total_columns": 0,
            "total_rows": 0,
        },
    }
    summary = schema_info["summary"]

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")

        for (table_name,) in cursor.fetchall():
            # Skip SQLite internal tables
            if table_name.startswith("sqlite_"):
                continue

            table_info = _describe_table(cursor, table_name)
            # A failed count is stored as an error string and must not be
            # added to the running total.
            if isinstance(table_info["row_count"], int):
                summary["total_rows"] += table_info["row_count"]

            schema_info["tables"][table_name] = table_info
            summary["total_columns"] += len(table_info["columns"])

        summary["total_tables"] = len(schema_info["tables"])

        # Identify columns that appear in multiple tables
        schema_info["duplicate_columns"] = _find_shared_columns(schema_info["tables"])

    except Exception as e:
        # Tool boundary: record the failure in the report rather than crash,
        # so whatever was gathered so far is still written out.
        schema_info["error"] = str(e)
    finally:
        conn.close()

    # Write to output file
    output_path = os.path.join(output_dir, "schema_analysis.json")
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(schema_info, f, indent=2)

    print(f"Schema analysis complete - results in {output_path}")
    print(f"Found {summary['total_tables']} tables with {summary['total_columns']} total columns")
    print(f"Total rows in database: {summary['total_rows']}")

    if schema_info.get("duplicate_columns"):
        print(f"\nWarning: {len(schema_info['duplicate_columns'])} columns appear in multiple tables")

    return schema_info

# Script entry point: when executed directly (not imported), analyze the
# default database path declared in analyze_schema's signature.
if __name__ == "__main__":
    analyze_schema()