#!/usr/bin/env python3
"""
Resilient Schema Analyzer Tool
Extracts database schema with multiple fallback strategies and comprehensive error handling.
"""

import json
import os
import re
import sqlite3
import sys
import time
from typing import Dict, List, Any, Optional

def safe_execute(cursor, query: str, params=None) -> Optional[List]:
    """Run *query* on *cursor* and return all result rows.

    ``params`` is forwarded to ``cursor.execute`` only when it is truthy.
    Any exception raised while executing or fetching is printed (with the
    first 50 characters of the query) and swallowed.

    Returns:
        The list of fetched rows (possibly empty), or ``None`` if the query
        raised, so callers can tell "failed" apart from "no rows".
    """
    try:
        call_args = (query, params) if params else (query,)
        cursor.execute(*call_args)
        rows = cursor.fetchall()
    except Exception as exc:
        print(f"⚠️  Query failed: {query[:50]}... Error: {str(exc)}")
        return None
    return rows

def get_table_names_basic(cursor) -> List[str]:
    """Return the names of user tables reachable through *cursor*.

    Primary strategy: list ``type='table'`` rows of ``sqlite_master`` (sorted
    by name), dropping SQLite's internal ``sqlite_*`` tables. An empty result
    there means the database genuinely has no tables and is returned as-is.

    Fallback (only when that query itself fails, i.e. ``safe_execute``
    returns ``None``): probe a short list of common table names with
    ``SELECT 1 ... LIMIT 1`` and keep every name whose probe did not error.
    """
    result = safe_execute(cursor,
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")

    # Distinguish "query failed" (None) from "no tables" ([]): only the
    # former should trigger the guessing fallback.
    if result is not None:
        tables = [name for (name,) in result if not name.startswith('sqlite_')]
        print(f"✓ Found {len(tables)} tables via sqlite_master")
        return tables

    print("⚠️  sqlite_master query failed, trying fallback methods")

    # Probe well-known table names. A probe that succeeds, even one returning
    # no rows, means the table exists and is readable.
    return [
        candidate
        for candidate in ('users', 'products', 'orders', 'customers')
        if safe_execute(cursor, f"SELECT 1 FROM {candidate} LIMIT 1") is not None
    ]

def analyze_table_robust(cursor, table_name: str) -> Dict[str, Any]:
    """Analyze a single table with multiple fallback strategies.

    The returned dict contains:

    * ``columns``: name -> {type, nullable, default, is_primary_key, position,
      semantic_type}. These come from ``PRAGMA table_info``. If that yields
      nothing, only column names are known, taken from ``cursor.description``
      of a ``SELECT *``, and ``analysis_status`` becomes ``"partial"``.
    * ``row_count``: the exact count when below 100000, ``"100000+"`` at the
      cap, or a descriptive string / ``"unknown"`` / ``"error"`` when counting
      fails.
    * ``primary_keys``: column names flagged as part of the primary key.
    * ``foreign_keys``: declared constraints from ``PRAGMA foreign_key_list``
      (falling back to parsing the ``CREATE TABLE`` SQL only if the PRAGMA
      fails). They are followed by low-confidence guesses for ``<x>_id``
      columns that no declared constraint already covers.
    * ``analysis_status`` and ``warnings``: anything that degraded the result.

    Query failures are reported by ``safe_execute`` and never raise here.
    """
    table_info: Dict[str, Any] = {
        "columns": {},
        "row_count": "unknown",
        "primary_keys": [],
        "foreign_keys": [],
        "analysis_status": "complete",
        "warnings": []
    }
    # Quote the identifier and double any embedded backticks so that names
    # containing spaces, keywords or backticks still produce valid SQL.
    quoted_name = "`" + table_name.replace("`", "``") + "`"

    # Method 1: PRAGMA table_info (most reliable)
    pragma_result = safe_execute(cursor, f"PRAGMA table_info({quoted_name})")

    if pragma_result:
        for col in pragma_result:
            try:
                col_id, col_name, col_type, not_null, default_value, is_pk = col
                table_info["columns"][col_name] = {
                    "type": col_type or "UNKNOWN",
                    "nullable": not not_null,
                    "default": default_value,
                    "is_primary_key": bool(is_pk),
                    "position": col_id,
                    "semantic_type": infer_semantic_type(col_name, col_type)
                }
                if is_pk:
                    table_info["primary_keys"].append(col_name)
            except Exception as e:
                table_info["warnings"].append(f"Column parse error: {str(e)}")
    else:
        # Fallback: recover at least the column names from SELECT *
        table_info["analysis_status"] = "partial"
        sample = safe_execute(cursor, f"SELECT * FROM {quoted_name} LIMIT 1")
        # cursor.description is populated even when the table has no rows,
        # so only a failed query (None) should skip column discovery.
        if sample is not None and cursor.description:
            for i, desc in enumerate(cursor.description):
                col_name = desc[0]
                table_info["columns"][col_name] = {
                    "type": "UNKNOWN",
                    "nullable": True,
                    "default": None,
                    "is_primary_key": False,
                    "position": i,
                    "semantic_type": infer_semantic_type(col_name, None)
                }
            table_info["warnings"].append("Schema from SELECT * - types unknown")

    # Row count, capped so huge tables are not fully scanned.
    count_cap = 100000
    try:
        result = safe_execute(cursor,
            f"SELECT COUNT(*) FROM (SELECT 1 FROM {quoted_name} LIMIT {count_cap})")
        if result:
            count = result[0][0]
            if count == count_cap:
                table_info["row_count"] = "100000+"
                table_info["warnings"].append("Large table - count limited")
            else:
                table_info["row_count"] = count
        else:
            # Sampling approach: a short sample is itself an exact count.
            sample_size = 10
            sample = safe_execute(cursor, f"SELECT 1 FROM {quoted_name} LIMIT {sample_size}")
            if sample is None:
                table_info["row_count"] = "unknown"
            elif len(sample) < sample_size:
                table_info["row_count"] = len(sample)
            else:
                table_info["row_count"] = "10+ (exact count failed)"
    except Exception as e:
        table_info["row_count"] = "error"
        table_info["warnings"].append(f"Row count failed: {str(e)}")

    # Declared foreign keys. The PRAGMA also covers inline
    # "col REFERENCES t(c)" constraints and preserves identifier case.
    declared_fks = safe_execute(cursor, f"PRAGMA foreign_key_list({quoted_name})")
    if declared_fks is not None:
        for fk_row in declared_fks:
            # Row layout: id, seq, table, from, to, on_update, on_delete, match
            ref_table, from_col, to_col = fk_row[2], fk_row[3], fk_row[4]
            table_info["foreign_keys"].append({
                "column": from_col,
                "references_table": ref_table,
                # "to" is NULL when the constraint omits the parent column
                # list, i.e. it targets the parent's primary key.
                "references_column": to_col if to_col is not None else "(primary key)"
            })
    else:
        # Fallback: parse table-level FOREIGN KEY clauses from the DDL,
        # matching case-insensitively but keeping the original name case.
        create_sql = safe_execute(cursor,
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,))
        if create_sql and create_sql[0][0]:
            fk_pattern = r'foreign\s+key\s*\(([^)]+)\)\s+references\s+(\w+)\s*\(([^)]+)\)'
            for match in re.finditer(fk_pattern, create_sql[0][0], re.IGNORECASE):
                table_info["foreign_keys"].append({
                    "column": match.group(1).strip(),
                    "references_table": match.group(2),
                    "references_column": match.group(3).strip()
                })

    # Guess relationships from "<table>_id" column names, skipping columns
    # that already have a declared constraint (avoids duplicate entries).
    declared_columns = {fk["column"] for fk in table_info["foreign_keys"]}
    for col_name in table_info["columns"]:
        if (col_name.endswith('_id') and len(col_name) > len('_id')
                and col_name not in declared_columns):
            potential_table = col_name[:-3]
            table_info["foreign_keys"].append({
                "column": col_name,
                "references_table": f"{potential_table} (inferred)",
                "references_column": "id (assumed)",
                "confidence": "low"
            })

    return table_info

def infer_semantic_type(col_name: str, col_type: Optional[str]) -> str:
    """Infer a coarse semantic type for a column from its name and SQL type.

    Name keywords are checked first, in priority order: identifier, measure,
    temporal, category, boolean, percentage. If none match, the declared SQL
    type decides: ``INT`` -> ``"numeric"``, ``REAL``/``FLOAT`` ->
    ``"decimal"``, ``TEXT``/``CHAR`` -> ``"text"``. Otherwise ``"general"``.

    The short keywords ``id``/``ids``/``uuid``/``guid`` only match as whole
    words (``id``, ``user_id``, ``userId``, ``UserID``). This keeps names such
    as ``width``, ``is_valid`` or ``holiday`` from being misread as
    identifiers.

    >>> infer_semantic_type("user_id", "INTEGER")
    'identifier'
    >>> infer_semantic_type("width", "INTEGER")
    'numeric'
    """
    col_lower = col_name.lower()
    # Split camelCase/PascalCase boundaries, then any non-alphanumeric run,
    # into lowercase word tokens: "CustomerId" -> {"customer", "id"}.
    words = set(re.split(r'[^a-z0-9]+',
                         re.sub(r'(?<=[a-z0-9])(?=[A-Z])', '_', col_name).lower()))

    # Check name patterns
    if words & {'id', 'ids', 'uuid', 'guid'} or any(x in col_lower for x in ['identifier', 'code']):
        return "identifier"
    elif any(x in col_lower for x in ['count', 'total', 'sum', 'amount', 'quantity', 'sales']):
        return "measure"
    elif any(x in col_lower for x in ['date', 'time', 'year', 'month', 'day', 'created', 'updated']):
        return "temporal"
    elif any(x in col_lower for x in ['name', 'title', 'description', 'category', 'type', 'status']):
        return "category"
    elif any(x in col_lower for x in ['is_', 'has_', 'enabled', 'active', 'flag', 'boolean']):
        return "boolean"
    elif any(x in col_lower for x in ['percent', 'ratio', 'rate']):
        return "percentage"
    elif col_type:
        # Fall back to the declared SQL type
        type_upper = col_type.upper()
        if 'INT' in type_upper:
            return "numeric"
        elif 'REAL' in type_upper or 'FLOAT' in type_upper:
            return "decimal"
        elif 'TEXT' in type_upper or 'CHAR' in type_upper:
            return "text"

    return "general"

# Maps a semantic type from infer_semantic_type() to its bucket in the report.
_SEMANTIC_BUCKETS = {
    "identifier": "identifier_columns",
    "measure": "measure_columns",
    "category": "category_columns",
    "temporal": "temporal_columns",
    "boolean": "boolean_columns",
    "percentage": "percentage_columns",
}


def _downgrade_quality(schema_info: Dict[str, Any]) -> None:
    """Mark the report "partial" unless it is already "partial" or "failed"."""
    summary = schema_info["summary"]
    if summary["analysis_quality"] == "complete":
        summary["analysis_quality"] = "partial"


def _analyze_with_deadline(conn, cursor, table_name: str,
                           timeout: Optional[float]) -> Dict[str, Any]:
    """Run analyze_table_robust() for one table under a wall-clock budget.

    SQLite calls the progress handler periodically while a statement runs.
    A truthy return interrupts that statement with ``sqlite3.OperationalError``,
    which safe_execute() reports and turns into that query's fallback. Once
    the budget is spent, any statement still running (or started later) is
    interrupted at its next progress check. ``None`` or ``<= 0`` disables the
    limit.
    """
    if not timeout or timeout <= 0:
        return analyze_table_robust(cursor, table_name)
    deadline = time.monotonic() + timeout
    conn.set_progress_handler(lambda: time.monotonic() > deadline, 10_000)
    try:
        return analyze_table_robust(cursor, table_name)
    finally:
        # Never leave the handler installed for the next table.
        conn.set_progress_handler(None, 0)


def _merge_table_info(schema_info: Dict[str, Any],
                      column_occurrences: Dict[str, List[str]],
                      table_name: str,
                      table_info: Dict[str, Any]) -> None:
    """Fold one table's analysis into the aggregate report.

    Records where each column name occurs, its semantic bucket, the table's
    foreign keys, summary totals, the quality downgrade and any warnings.
    """
    classification = schema_info["semantic_classification"]
    for col_name, col_info in table_info["columns"].items():
        column_occurrences.setdefault(col_name, []).append(table_name)
        bucket = _SEMANTIC_BUCKETS.get(col_info.get("semantic_type", "general"))
        if bucket is not None:
            classification[bucket].append(f"{table_name}.{col_name}")

    for fk in table_info.get("foreign_keys", []):
        schema_info["relationships"]["foreign_keys"].append({
            "from": f"{table_name}.{fk['column']}",
            "to": f"{fk['references_table']}.{fk['references_column']}",
            "confidence": fk.get("confidence", "high")
        })

    schema_info["tables"][table_name] = table_info

    summary = schema_info["summary"]
    summary["total_columns"] += len(table_info["columns"])
    if isinstance(table_info["row_count"], int):
        summary["total_rows"] += table_info["row_count"]

    if table_info["analysis_status"] == "partial":
        _downgrade_quality(schema_info)

    schema_info["analysis_errors"].extend(
        f"{table_name}: {w}" for w in table_info.get("warnings", [])
    )


def _infer_junction_joins(tables: List[str]) -> List[Dict[str, Any]]:
    """Return one "junction" entry per table named ``<a>_<b>`` where both
    ``<a>`` and ``<b>`` are themselves tables (e.g. ``users_roles``)."""
    known = set(tables)
    joins = []
    for table in tables:
        parts = table.split('_')
        if len(parts) == 2 and parts[0] in known and parts[1] in known:
            joins.append({
                "type": "junction",
                "table": table,
                "connects": [parts[0], parts[1]]
            })
    return joins


def _save_schema_report(schema_info: Dict[str, Any],
                        output_path: str = "tool_output/schema_analysis.json") -> None:
    """Write *schema_info* as JSON to *output_path* and print a short summary.

    If that fails, try to leave a minimal ``tool_output/schema_analysis.error``
    file instead. That attempt is best effort and its own I/O errors are
    ignored.
    """
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(schema_info, f, indent=2, default=str)
        summary = schema_info["summary"]
        print(f"\n✓ Schema analysis saved to {output_path}")
        print(f"  Quality: {summary['analysis_quality']}")
        print(f"  Tables: {summary['total_tables']}")
        print(f"  Columns: {summary['total_columns']}")

        if schema_info["analysis_errors"]:
            print(f"  ⚠️  {len(schema_info['analysis_errors'])} warnings/errors logged")

    except Exception as e:
        print(f"❌ Failed to save results: {str(e)}")
        try:
            with open("tool_output/schema_analysis.error", 'w', encoding='utf-8') as f:
                f.write(f"Analysis failed: {str(e)}\n")
                f.write(f"Partial data: {len(schema_info.get('tables', {}))} tables found\n")
        except OSError:
            # Nowhere left to report this; the console message above stands.
            pass


def analyze_schema(db_path="database.sqlite", table_timeout: Optional[float] = 5.0):
    """Analyze every table in a SQLite database and save a JSON report.

    Steps:
      1. Refuse to analyze a missing database file. Otherwise
         ``sqlite3.connect`` would silently create an empty one.
      2. List tables via get_table_names_basic() and analyze each with
         analyze_table_robust() under a ``table_timeout``-second budget.
      3. Aggregate semantic buckets, foreign keys, row and column totals,
         column names shared by several tables, and junction tables.
      4. Write the report to ``tool_output/schema_analysis.json``.

    Args:
        db_path: SQLite database path (``":memory:"``/``""`` are passed through).
        table_timeout: per-table time budget in seconds. ``None`` or ``<= 0``
            disables it.

    Returns:
        The report dict. ``summary.analysis_quality`` is ``"complete"``,
        ``"partial"`` (some table degraded or failed) or ``"failed"``.
    """
    os.makedirs("tool_output", exist_ok=True)

    schema_info = {
        "tables": {},
        "summary": {
            "total_tables": 0,
            "total_columns": 0,
            "total_rows": 0,
            "analysis_quality": "complete"
        },
        "semantic_classification": {
            "identifier_columns": [],
            "measure_columns": [],
            "category_columns": [],
            "temporal_columns": [],
            "boolean_columns": [],
            "percentage_columns": []
        },
        "ambiguous_columns": {},
        "relationships": {
            "foreign_keys": [],
            "inferred_joins": []
        },
        "analysis_errors": []
    }

    # sqlite3.connect() creates a new, empty database file for a missing
    # path; an analyzer must not leave such files behind.
    if db_path not in ("", ":memory:") and not os.path.exists(db_path):
        message = f"Database file not found: {db_path}"
        schema_info["analysis_errors"].append(message)
        schema_info["summary"]["analysis_quality"] = "failed"
        print(f"❌ {message}")
        _save_schema_report(schema_info)
        return schema_info

    conn = None
    try:
        # Busy timeout: wait up to 30 s for a locked database instead of failing
        conn = sqlite3.connect(db_path, timeout=30.0)
        cursor = conn.cursor()
        print(f"✓ Connected to database: {db_path}")

        tables = get_table_names_basic(cursor)

        if not tables:
            schema_info["analysis_errors"].append("No tables found or accessible")
            schema_info["summary"]["analysis_quality"] = "failed"
            print("❌ No tables found in database")
        else:
            column_occurrences: Dict[str, List[str]] = {}

            for table_name in tables:
                print(f"  Analyzing table: {table_name}")
                try:
                    table_info = _analyze_with_deadline(conn, cursor, table_name, table_timeout)
                    _merge_table_info(schema_info, column_occurrences, table_name, table_info)
                except Exception as e:
                    print(f"  ⚠️  Error analyzing table {table_name}: {str(e)}")
                    schema_info["analysis_errors"].append(f"Table {table_name}: {str(e)}")
                    schema_info["tables"][table_name] = {
                        "columns": {},
                        "analysis_status": "failed",
                        "error": str(e)
                    }
                    _downgrade_quality(schema_info)

            schema_info["summary"]["total_tables"] = len(schema_info["tables"])

            # Column names present in more than one table need qualification.
            for col_name, owners in column_occurrences.items():
                if len(owners) > 1:
                    schema_info["ambiguous_columns"][col_name] = {
                        "appears_in": owners,
                        "count": len(owners),
                        "disambiguation_required": True
                    }

            schema_info["relationships"]["inferred_joins"].extend(
                _infer_junction_joins(tables)
            )

    except sqlite3.Error as e:
        schema_info["analysis_errors"].append(f"Database error: {str(e)}")
        schema_info["summary"]["analysis_quality"] = "failed"
        print(f"❌ Database error: {str(e)}")

    except Exception as e:
        schema_info["analysis_errors"].append(f"Unexpected error: {str(e)}")
        schema_info["summary"]["analysis_quality"] = "failed"
        print(f"❌ Unexpected error: {str(e)}")

    finally:
        if conn:
            conn.close()

    _save_schema_report(schema_info)
    return schema_info

if __name__ == "__main__":
    # Script entry point: analyze the default database and map the outcome to
    # an exit status. Only a "failed" analysis is an error; "partial" and
    # "complete" both exit successfully.
    quality = analyze_schema()["summary"]["analysis_quality"]
    sys.exit(1 if quality == "failed" else 0)