#!/usr/bin/env python3
"""
Focused Context Analyzer - Streamlined database analysis for SQL generation
"""

import sqlite3
import os
import json
import re
from collections import defaultdict, Counter

# Substrings that mark a column as worth sampling (matched case-sensitively).
_PROBLEM_PATTERNS = ('id', 'ID', 'type', 'TYPE', 'status', 'STATUS',
                     'category', 'name', 'Name', 'date', 'Date')

# Substrings (matched against the lower-cased column name) of boolean-like columns.
_BOOLEAN_HINTS = ('compiled', 'enabled', 'active', 'valid')


def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier.

    Embedded double quotes are doubled so that table/column names containing
    spaces, quotes or reserved words can be interpolated into SQL text.
    """
    return '"' + name.replace('"', '""') + '"'


def _needs_sample(col_name):
    """Return True when *col_name* matches one of the sampling heuristics."""
    # ID-like columns other than a plain 'id'.
    if 'id' in col_name.lower() and col_name != 'id':
        return True
    # Names containing one of the known problem substrings.
    if any(pattern in col_name for pattern in _PROBLEM_PATTERNS):
        return True
    # Names with spaces must be quoted in generated SQL.
    return ' ' in col_name


def _foreign_key_lines(cursor, table_names):
    """Return the "Foreign Keys:" lines for every table in *table_names*."""
    lines = ["Foreign Keys:"]
    fk_count = 0
    for table in table_names:
        cursor.execute(f"PRAGMA foreign_key_list({_quote_identifier(table)})")
        for fk in cursor.fetchall():
            # Row layout used here: fk[2]=referenced table, fk[3]=local column,
            # fk[4]=referenced column.
            lines.append(f"  {table}.{fk[3]} → {fk[2]}.{fk[4]}")
            fk_count += 1
    if fk_count == 0:
        lines.append("  None defined - infer joins from column names")
    return lines


def _sample_lines(cursor, table_columns):
    """Return sample-value lines for columns selected by `_needs_sample`."""
    lines = []
    for table, columns in table_columns.items():
        header_emitted = False
        for col_name in columns:
            if not _needs_sample(col_name):
                continue

            col = _quote_identifier(col_name)
            try:
                # Up to five most frequent non-NULL values with their counts.
                cursor.execute(f"""
                    SELECT DISTINCT {col}, COUNT(*) as cnt
                    FROM {_quote_identifier(table)}
                    WHERE {col} IS NOT NULL
                    GROUP BY {col}
                    ORDER BY cnt DESC
                    LIMIT 5
                """)
                values = cursor.fetchall()
            except sqlite3.Error:
                # Best effort: a column that cannot be sampled is skipped.
                continue

            if not values:
                continue

            if not header_emitted:
                lines.append(f"\n{table}:")
                header_emitted = True
            lines.append(f"  {col_name}:")

            # The fetched values are distinct, so if lower-casing collapses
            # any of them, two values differ only by letter case.
            found = [row[0] for row in values]
            has_case_variation = (
                all(isinstance(v, str) for v in found)
                and len({v.lower() for v in found}) < len(found)
            )

            for val, cnt in values[:3]:
                if isinstance(val, str) and len(val) > 50:
                    val = val[:50] + "..."
                lines.append(f"    {repr(val)} ({cnt} rows)")

            if has_case_variation:
                lines.append("    WARNING: Case-sensitive values!")
    return lines


def _trap_lines(cursor, table_columns):
    """Return descriptions of boolean-like columns and date/time formats."""
    traps = []

    # Boolean-like columns whose stored values may be counterintuitive.
    for table, columns in table_columns.items():
        for col in columns:
            if not any(word in col.lower() for word in _BOOLEAN_HINTS):
                continue
            quoted = _quote_identifier(col)
            try:
                cursor.execute(f"""
                    SELECT DISTINCT {quoted} FROM {_quote_identifier(table)}
                    WHERE {quoted} IS NOT NULL LIMIT 5
                """)
                values = [row[0] for row in cursor.fetchall()]
            except sqlite3.Error:
                continue
            if values:
                traps.append(f'"{col}" in {table}: values are {values[:3]} - check evidence carefully')

    # One sample per date/time column to expose the stored format.
    for table, columns in table_columns.items():
        for col in columns:
            lowered = col.lower()
            if 'date' not in lowered and 'time' not in lowered:
                continue
            quoted = _quote_identifier(col)
            try:
                cursor.execute(f"""
                    SELECT {quoted} FROM {_quote_identifier(table)}
                    WHERE {quoted} IS NOT NULL LIMIT 1
                """)
                sample = cursor.fetchone()
            except sqlite3.Error:
                continue
            if sample and sample[0]:
                traps.append(f'{table}.{col} format: {repr(sample[0])}')

    return traps


def _quick_reference_lines(cursor, table_columns):
    """Return table row counts followed by up to five non-standard ID columns."""
    lines = ["Table Sizes:"]
    for table in table_columns:
        try:
            cursor.execute(f"SELECT COUNT(*) FROM {_quote_identifier(table)}")
            count = cursor.fetchone()[0]
        except sqlite3.Error:
            continue
        lines.append(f"  {table}: {count} rows")

    # Columns whose name ends in "id" but is not exactly "id".
    id_warnings = []
    for table, columns in table_columns.items():
        for col in columns:
            lowered = col.lower()
            if lowered.endswith('id') and lowered != 'id' and 'foreign' not in lowered:
                id_warnings.append(f"{table}.{col}")

    if id_warnings:
        lines.append("\nNon-standard ID columns:")
        lines.extend(f"  {warning}" for warning in id_warnings[:5])
    return lines


def _build_report(cursor):
    """Run every analysis section against *cursor* and return the report lines."""
    cursor.execute("""
        SELECT name, sql FROM sqlite_master
        WHERE type='table' AND name NOT LIKE 'sqlite_%'
        ORDER BY name
    """)
    tables = cursor.fetchall()
    table_names = [name for name, _ in tables]

    # Section 1: raw CREATE statements.
    output = ["=== DATABASE SCHEMA ===\n"]
    output.extend(create_sql for _, create_sql in tables if create_sql)

    # Section 2: foreign keys.
    output.append("\n\n=== CRITICAL INFORMATION ===\n")
    output.extend(_foreign_key_lines(cursor, table_names))

    # Column names per table, in table-name order (used by later sections).
    table_columns = {}
    for table in table_names:
        cursor.execute(f"PRAGMA table_info({_quote_identifier(table)})")
        table_columns[table] = [col[1] for col in cursor.fetchall()]

    # Section 3: value samples for error-prone columns.
    output.append("\n\n=== KEY DATA SAMPLES ===\n")
    output.append("(Showing only columns that commonly cause errors)\n")
    output.extend(_sample_lines(cursor, table_columns))

    # Section 4: evidence traps, capped at the first ten found.
    output.append("\n\n=== COMMON EVIDENCE TRAPS ===\n")
    traps_found = _trap_lines(cursor, table_columns)
    if traps_found:
        output.extend(f"- {trap}" for trap in traps_found[:10])
    else:
        output.append("- No specific traps detected")

    # Section 5: quick reference.
    output.append("\n\n=== QUICK REFERENCE ===\n")
    output.extend(_quick_reference_lines(cursor, table_columns))

    # Final reminders.
    output.append("\n\n=== REMEMBER ===")
    output.append("- Return ONLY columns explicitly requested")
    output.append("- Use exact matches (=) not LIKE unless needed")
    output.append("- Check samples for case sensitivity")
    output.append("- Evidence can be misleading - verify against actual data")
    return output


def analyze_database(db_path="database.sqlite"):
    """Write a focused schema/data report for the SQLite database at *db_path*.

    The report contains the raw CREATE TABLE statements, declared foreign
    keys, sample values for error-prone columns, boolean-like and date/time
    column hints, table row counts and non-standard ID columns.

    Side effects: creates the ``tool_output`` directory (relative to the
    current working directory) if needed, writes the report to
    ``tool_output/focused_analysis.txt`` as UTF-8 and prints it to stdout.

    Args:
        db_path: Path of the SQLite database file to analyze.

    Raises:
        sqlite3.Error: If the table list or a PRAGMA query fails. Errors from
            per-column and per-table data queries are skipped instead.
        OSError: If the output directory or file cannot be written.
    """
    os.makedirs('tool_output', exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        output = _build_report(conn.cursor())
    finally:
        # Release the database handle even when a query raises.
        conn.close()

    output_text = '\n'.join(output)

    # Explicit UTF-8: the report may contain non-ASCII text such as "→".
    with open('tool_output/focused_analysis.txt', 'w', encoding='utf-8') as f:
        f.write(output_text)

    print(output_text)

# Script entry point: when run directly (not imported), analyze the default
# database path ("database.sqlite" in the current working directory).
if __name__ == "__main__":
    analyze_database()