#!/usr/bin/env python3
"""
Schema Mapper - Deep schema analysis with exact column locations.
Eliminates "no such column" errors by providing complete table.column mappings.
"""

import sqlite3
import json
import os

def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier.

    Embedded double quotes are doubled, so table names containing spaces,
    punctuation or quote characters can be interpolated into SQL safely.
    """
    return '"' + name.replace('"', '""') + '"'


def _add_implicit_relationships(schema_data):
    """Append "implicit" relationships for id-like columns shared by tables.

    A column qualifies when its name ends with ``id``, ``Id`` or ``ID`` and it
    appears in more than one table. A pair of tables is skipped when a
    relationship on that column already links them (in either direction).
    """
    relationships = schema_data["relationships"]
    for col_name, tables_with_col in schema_data["column_index"].items():
        if len(tables_with_col) < 2 or not col_name.endswith(("id", "Id", "ID")):
            continue
        for i, t1 in enumerate(tables_with_col):
            for t2 in tables_with_col[i + 1:]:
                already_tracked = any(
                    rel["from_column"] == col_name
                    and {rel["from_table"], rel["to_table"]} == {t1, t2}
                    for rel in relationships
                )
                if not already_tracked:
                    relationships.append({
                        "from_table": t1,
                        "from_column": col_name,
                        "to_table": t2,
                        "to_column": col_name,
                        "type": "implicit"
                    })


def _write_column_locations(schema_data, path):
    """Write the plain-text column/relationship quick reference to *path*."""
    with open(path, 'w', encoding="utf-8") as f:
        f.write("COLUMN LOCATION QUICK REFERENCE\n")
        f.write("=" * 60 + "\n\n")

        # column_index already maps each column name to the tables holding it,
        # in the same order the tables were analyzed.
        column_index = schema_data["column_index"]
        for col_name in sorted(column_index):
            locations = [f"{table}.{col_name}" for table in column_index[col_name]]
            if len(locations) == 1:
                f.write(f"{col_name}: {locations[0]}\n")
            else:
                f.write(f"{col_name}: {', '.join(locations)} (in multiple tables)\n")

        f.write("\n" + "=" * 60 + "\n")
        f.write("RELATIONSHIPS (for JOINs):\n")
        f.write("-" * 40 + "\n")

        for rel in schema_data["relationships"]:
            f.write(f"{rel['from_table']}.{rel['from_column']} -> {rel['to_table']}.{rel['to_column']}\n")


def analyze_schema(db_path="database.sqlite"):
    """Perform comprehensive schema analysis of a SQLite database.

    Collects, for every table listed in ``sqlite_master``, its columns, row
    count and declared foreign keys, builds a column-name -> tables index,
    and records both declared and implicit (shared id-like column)
    relationships.

    Side effects: writes ``tool_output/schema_mapping.json`` and
    ``tool_output/column_locations.txt`` relative to the current working
    directory (creating ``tool_output`` if needed) and prints a one-line
    summary.

    Args:
        db_path: Path of the SQLite database file to analyze.

    Returns:
        A dict with the keys ``"tables"`` (table name -> columns, row count,
        foreign keys), ``"relationships"`` (list of from/to mappings) and
        ``"column_index"`` (column name -> list of table names).

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)

    schema_data = {
        "tables": {},
        "relationships": [],
        "column_index": {}  # Quick lookup: column_name -> [table1, table2, ...]
    }

    try:
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]

        for table in tables:
            # Table names cannot be bound as parameters, so quote them; this
            # keeps names with spaces or special characters working.
            quoted = _quote_identifier(table)

            # Rows: (cid, name, type, notnull, dflt_value, pk)
            cursor.execute(f"PRAGMA table_info({quoted})")
            columns = cursor.fetchall()

            cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
            row_count = cursor.fetchone()[0]

            # Rows: (id, seq, table, from, to, on_update, on_delete, match)
            # NOTE(review): "to" can come back as None when the FK clause does
            # not name parent columns; it is passed through unchanged here.
            cursor.execute(f"PRAGMA foreign_key_list({quoted})")
            foreign_keys = cursor.fetchall()

            # Store table info
            schema_data["tables"][table] = {
                "columns": [
                    {
                        "name": col[1],
                        "type": col[2],
                        "nullable": col[3] == 0,
                        "primary_key": col[5] > 0,
                        "position": col[0]
                    }
                    for col in columns
                ],
                "row_count": row_count,
                "foreign_keys": [
                    {
                        "column": fk[3],
                        "references_table": fk[2],
                        "references_column": fk[4]
                    }
                    for fk in foreign_keys
                ]
            }

            # Build column index for quick lookup
            for col in columns:
                schema_data["column_index"].setdefault(col[1], []).append(table)

            # Track declared relationships
            schema_data["relationships"].extend(
                {
                    "from_table": table,
                    "from_column": fk[3],
                    "to_table": fk[2],
                    "to_column": fk[4]
                }
                for fk in foreign_keys
            )

        # Find implicit relationships (common column names)
        _add_implicit_relationships(schema_data)

    finally:
        conn.close()

    # Save schema mapping
    os.makedirs('tool_output', exist_ok=True)
    with open('tool_output/schema_mapping.json', 'w', encoding="utf-8") as f:
        json.dump(schema_data, f, indent=2)

    # Also create a quick reference text file
    _write_column_locations(schema_data, 'tool_output/column_locations.txt')

    print(f"Schema mapping complete: {len(schema_data['tables'])} tables analyzed")
    return schema_data

# When run as a script, analyze the default "database.sqlite" in the current
# working directory and write the reports under tool_output/.
if __name__ == "__main__":
    analyze_schema()