#!/usr/bin/env python3
"""
Enhanced Schema Analyzer Tool
Extracts complete database schema with semantic classification.
"""

import json
import os
import re
import sqlite3

# Keyword tables for the name-based semantic heuristics.  Names are split into
# words first and a keyword must be a whole word or a word ending (to cope with
# run-together names such as "userid"); arbitrary substrings never match, so
# "width" is not an identifier and "country" is not a measure.
_ID_SUFFIXES = ("id", "identifier")
_NOT_ID_WORDS = frozenset({
    "paid", "unpaid", "prepaid", "valid", "invalid", "void", "grid",
    "acid", "solid", "liquid", "fluid", "rapid", "hybrid", "android", "humid",
})
_MEASURE_SUFFIXES = ("count", "total", "sum", "amount")
_NOT_MEASURE_WORDS = frozenset({"account"})
_TEMPORAL_SUFFIXES = ("date", "time", "year", "month", "timestamp")
_NOT_TEMPORAL_WORDS = frozenset({"candidate", "validate", "mandate"})
_BOOLEAN_NAMES = frozenset({"active", "enabled", "deleted", "flag", "is_active"})
_BOOLEAN_PREFIXES = frozenset({"is", "has"})

# Declared-type fragments.  The text markers follow SQLite's TEXT affinity
# rule, so VARCHAR(50), NCHAR and CLOB count as text as well as TEXT.
_TEXT_TYPE_MARKERS = ("CHAR", "CLOB", "TEXT")
_TEMPORAL_TYPE_MARKERS = ("DATE", "TIME")

# Splits on any non-alphanumeric separator and on camelCase boundaries:
# "CustomerID" -> Customer, ID; "order_total2" -> order, total, 2.
_WORD_RE = re.compile(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|[0-9]+")


def _quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _split_words(col_name):
    """Split a column name into lowercase words (separators and camelCase)."""
    return [word.lower() for word in _WORD_RE.findall(col_name)]


def _mentions(words, suffixes, excluded):
    """Return True if any word is, or ends with, one of *suffixes*."""
    for word in words:
        # Ignore a plural "s" so "order_ids" and "amounts" still match.
        stem = word[:-1] if len(word) > 2 and word.endswith("s") else word
        if stem not in excluded and stem.endswith(suffixes):
            return True
    return False


def _classify_column(col_name, col_type, is_pk):
    """Return the semantic type of a column; the first matching rule wins."""
    words = _split_words(col_name)
    declared = (col_type or "").upper()

    if is_pk or _mentions(words, _ID_SUFFIXES, _NOT_ID_WORDS):
        return "identifier"
    if _mentions(words, _MEASURE_SUFFIXES, _NOT_MEASURE_WORDS):
        return "measure"
    if (
        any(marker in declared for marker in _TEMPORAL_TYPE_MARKERS)
        # "created_at", "updated_at", ...
        or (len(words) > 1 and words[-1] == "at")
        or _mentions(words, _TEMPORAL_SUFFIXES, _NOT_TEMPORAL_WORDS)
    ):
        return "temporal"
    if (
        declared.startswith("BOOL")
        or col_name.lower() in _BOOLEAN_NAMES
        or (len(words) > 1 and words[0] in _BOOLEAN_PREFIXES)
    ):
        return "boolean"
    if any(marker in declared for marker in _TEXT_TYPE_MARKERS):
        return "category"
    return "general"


def _read_foreign_keys(cursor, table_name):
    """Return the foreign keys of a table as reported by SQLite itself."""
    cursor.execute(f"PRAGMA foreign_key_list({_quote_identifier(table_name)})")
    # Rows are (id, seq, table, from, to, ...); the columns of one composite
    # constraint share an id and are ordered by seq.
    rows = sorted(cursor.fetchall(), key=lambda row: (row[0], row[1]))

    constraints = {}
    for fk_id, _seq, parent_table, child_col, parent_col, *_rest in rows:
        entry = constraints.setdefault(
            fk_id, {"table": parent_table, "from": [], "to": []}
        )
        entry["from"].append(child_col)
        entry["to"].append(parent_col)

    return [
        {
            "column": ", ".join(entry["from"]),
            "references_table": entry["table"],
            # SQLite reports no parent column when the constraint omits the
            # column list; keep that visible as null instead of guessing.
            "references_column": (
                None if None in entry["to"] else ", ".join(entry["to"])
            ),
        }
        for entry in constraints.values()
    ]


def _read_indexes(cursor, table_name):
    """Return name, uniqueness and column list for every index of a table."""
    cursor.execute(f"PRAGMA index_list({_quote_identifier(table_name)})")
    indexes = []
    for index_row in cursor.fetchall():
        index_name = index_row[1]
        cursor.execute(f"PRAGMA index_info({_quote_identifier(index_name)})")
        indexes.append({
            "name": index_name,
            "unique": bool(index_row[2]),
            "columns": [info_row[2] for info_row in cursor.fetchall()],
        })
    return indexes


def analyze_schema(db_path="database.sqlite", output_dir="tool_output"):
    """Extract the schema of a SQLite database with semantic analysis.

    For every user table (names starting with "sqlite_" are skipped) this
    collects columns, primary keys, foreign keys, indexes, the CREATE
    statement and the row count, tags each column with a heuristic semantic
    type, and lists column names that occur in more than one table.  The
    result is written to "<output_dir>/schema_analysis.json" and a two-line
    summary is printed.

    Args:
        db_path: Path of an existing SQLite database file.
        output_dir: Directory for the JSON report; created if missing.

    Returns:
        The schema dictionary that was written to the JSON file.

    Raises:
        FileNotFoundError: If db_path is not an existing file.  Checked up
            front because sqlite3.connect() would otherwise create an empty
            database and the analysis would "succeed" with zero tables.
        Exception: Any other failure is reported on stdout and re-raised.
    """
    if not os.path.isfile(db_path):
        print(f"✗ Schema analysis failed: database file not found: {db_path}")
        raise FileNotFoundError(f"Database file not found: {db_path}")

    os.makedirs(output_dir, exist_ok=True)

    schema_info = {
        "tables": {},
        "summary": {
            "total_tables": 0,
            "total_columns": 0,
            "total_rows": 0
        },
        "semantic_classification": {
            "identifier_columns": [],
            "measure_columns": [],
            "category_columns": [],
            "temporal_columns": [],
            "boolean_columns": []
        },
        "ambiguous_columns": {}
    }
    classification = schema_info["semantic_classification"]
    summary = schema_info["summary"]
    column_occurrences = {}

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Fetch name and CREATE statement together: one query for all tables.
        cursor.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name"
        )
        for table_name, create_sql in cursor.fetchall():
            if table_name.startswith("sqlite_"):
                continue

            quoted_table = _quote_identifier(table_name)
            table_info = {
                "columns": {},
                "row_count": 0,
                "primary_keys": [],
                "foreign_keys": _read_foreign_keys(cursor, table_name),
                "indexes": [],
                "sql_definition": create_sql or ""
            }

            # Column information: (cid, name, type, notnull, dflt_value, pk).
            cursor.execute(f"PRAGMA table_info({quoted_table})")
            pk_members = []
            for col_id, col_name, col_type, not_null, default_value, is_pk in cursor.fetchall():
                column_occurrences.setdefault(col_name, []).append(table_name)

                semantic_type = _classify_column(col_name, col_type, is_pk)
                # "general" has no list of its own in the classification.
                bucket = classification.get(f"{semantic_type}_columns")
                if bucket is not None:
                    bucket.append(f"{table_name}.{col_name}")

                table_info["columns"][col_name] = {
                    "type": col_type,
                    "nullable": not not_null,
                    "default": default_value,
                    "is_primary_key": bool(is_pk),
                    "position": col_id,
                    "semantic_type": semantic_type
                }

                if is_pk:
                    pk_members.append((is_pk, col_name))

            # The pk field is the 1-based position inside the primary key, so
            # sorting on it lists composite keys in key order.
            table_info["primary_keys"] = [name for _, name in sorted(pk_members)]

            # Row count is best effort: a failure is recorded, not fatal.
            try:
                cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
                table_info["row_count"] = cursor.fetchone()[0]
                summary["total_rows"] += table_info["row_count"]
            except sqlite3.Error as e:
                table_info["row_count"] = f"Error: {str(e)}"

            table_info["indexes"] = _read_indexes(cursor, table_name)

            schema_info["tables"][table_name] = table_info
            summary["total_columns"] += len(table_info["columns"])

        summary["total_tables"] = len(schema_info["tables"])

        # Identify ambiguous columns (appear in multiple tables)
        for col_name, owning_tables in column_occurrences.items():
            if len(owning_tables) > 1:
                schema_info["ambiguous_columns"][col_name] = {
                    "appears_in": owning_tables,
                    "count": len(owning_tables),
                    "disambiguation_required": True
                }

        # Save to file
        output_path = os.path.join(output_dir, "schema_analysis.json")
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(schema_info, f, indent=2)

        print(f"✓ Schema analysis complete: {summary['total_tables']} tables, {summary['total_columns']} columns")
        print(f"✓ Found {len(schema_info['ambiguous_columns'])} ambiguous columns requiring disambiguation")

    except Exception as e:
        print(f"✗ Schema analysis failed: {str(e)}")
        raise
    finally:
        conn.close()

    return schema_info

# Script entry point: run the analysis with its default arguments when this
# file is executed directly (nothing runs on import).
if __name__ == "__main__":
    analyze_schema()