#!/usr/bin/env python3
"""
Comprehensive Schema Analyzer - Extracts complete database context
Combines best practices from iter2_distillery and iter8_validation approaches
"""

import sqlite3
import json
import os
from typing import Dict, List, Any, Optional
from collections import defaultdict, Counter

def analyze_database(db_path: str = "database.sqlite") -> Dict[str, Any]:
    """Run every analysis pass against one SQLite database and bundle the results.

    Args:
        db_path: Path of the SQLite database file to open.

    Returns:
        A report dict with the keys ``database_overview``,
        ``schema_structure``, ``relationships``, ``data_patterns``,
        ``critical_warnings`` and ``sql_guidance``.
    """
    connection = sqlite3.connect(db_path)
    connection.row_factory = sqlite3.Row
    cur = connection.cursor()

    try:
        # The later passes consume the schema produced by the first one.
        schema = extract_complete_schema(cur)
        relationships = analyze_relationships(cur, schema)
        data_patterns = analyze_data_patterns(cur, schema)
        critical_issues = detect_critical_issues(schema, data_patterns)

        # Derived sections of the report.
        overview = generate_overview(schema, critical_issues)
        sql_guidance = generate_sql_guidance(schema, relationships, data_patterns)

        return {
            "database_overview": overview,
            "schema_structure": schema,
            "relationships": relationships,
            "data_patterns": data_patterns,
            "critical_warnings": critical_issues,
            "sql_guidance": sql_guidance,
        }
    finally:
        # Always release the connection, even when a pass raises.
        connection.close()

def extract_complete_schema(cursor) -> Dict[str, Any]:
    """Extract the schema of every user table together with basic metadata.

    Args:
        cursor: An open ``sqlite3`` cursor on the database to inspect.

    Returns:
        ``{"tables": {name: info}}`` where each ``info`` dict holds:
        ``columns`` (name, type, nullable, default, primary_key),
        ``primary_keys`` (column names in key order), ``row_count``,
        ``foreign_keys`` (column, references_table, references_column),
        ``samples`` (up to 5 rows as dicts), ``is_empty`` and ``is_junction``.
    """
    schema: Dict[str, Any] = {"tables": {}}

    # '_' is a single-character wildcard in LIKE, so it must be escaped;
    # otherwise user tables such as "sqlitex" are filtered out as well.
    cursor.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
    )
    tables = [row[0] for row in cursor.fetchall()]

    for table in tables:
        # Quote the identifier so names containing spaces, punctuation or
        # reserved words (e.g. "order") are still valid SQL.
        quoted_table = '"' + table.replace('"', '""') + '"'

        # Column information: (cid, name, type, notnull, dflt_value, pk).
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        columns = []
        pk_positions = []
        for col in cursor.fetchall():
            # pk is 0 for non-key columns, otherwise the 1-based position of
            # the column inside the (possibly composite) primary key.
            pk_position = col[5]
            is_primary = pk_position > 0
            columns.append({
                "name": col[1],
                "type": col[2],
                "nullable": col[3] == 0,
                "default": col[4],
                "primary_key": is_primary,
            })
            if is_primary:
                pk_positions.append((pk_position, col[1]))
        primary_keys = [name for _, name in sorted(pk_positions)]

        # Row count
        cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
        row_count = cursor.fetchone()[0]

        # Foreign keys: (id, seq, table, from, to, ...).
        cursor.execute(f"PRAGMA foreign_key_list({quoted_table})")
        foreign_keys = [
            {
                "column": fk[3],
                "references_table": fk[2],
                "references_column": fk[4],
            }
            for fk in cursor.fetchall()
        ]

        # Sample data (first 5 rows)
        samples = []
        if row_count > 0:
            cursor.execute(f"SELECT * FROM {quoted_table} LIMIT 5")
            # Take the names from the result set itself so they always line
            # up with the values that SELECT * actually returned.
            result_names = [desc[0] for desc in cursor.description]
            samples = [dict(zip(result_names, row)) for row in cursor.fetchall()]

        schema["tables"][table] = {
            "columns": columns,
            "primary_keys": primary_keys,
            "row_count": row_count,
            "foreign_keys": foreign_keys,
            "samples": samples,
            "is_empty": row_count == 0,
            # Heuristic: at least two foreign keys and at most two extra columns.
            "is_junction": len(foreign_keys) >= 2 and len(columns) <= len(foreign_keys) + 2,
        }

    return schema

def analyze_relationships(cursor, schema: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise how tables reference each other through foreign keys.

    Args:
        cursor: Accepted for signature symmetry with the other passes; this
            function does not run any query on it.
        schema: Schema dict with a ``tables`` mapping whose entries provide
            ``is_junction`` and ``foreign_keys``.

    Returns:
        A dict with ``foreign_key_map`` ("table.column" -> "table.column"),
        ``junction_tables``, ``join_paths`` and ``relationship_graph``
        (table -> list of referenced tables).
    """
    fk_map = {}
    junctions = []
    graph = defaultdict(list)

    for table_name, table_info in schema["tables"].items():
        if table_info["is_junction"]:
            linked = [ref["references_table"] for ref in table_info["foreign_keys"]]
            junctions.append({
                "table": table_name,
                "connects": linked,
                "warning": f"Junction table - may need DISTINCT when joining through {table_name}"
            })

        # Record each foreign key both as a column mapping and as a graph edge.
        for ref in table_info["foreign_keys"]:
            source = f"{table_name}.{ref['column']}"
            fk_map[source] = f"{ref['references_table']}.{ref['references_column']}"
            graph[table_name].append(ref["references_table"])

    return {
        "foreign_key_map": fk_map,
        "junction_tables": junctions,
        "join_paths": find_join_paths(graph),
        "relationship_graph": graph,
    }

def analyze_data_patterns(cursor, schema: Dict[str, Any]) -> Dict[str, Any]:
    """Analyze data patterns for each table and column."""
    
    patterns = {"tables": {}}
    
    for table, info in schema["tables"].items():
        if info["is_empty"]:
            patterns["tables"][table] = {"empty": True}
            continue
        
        table_patterns = {"columns": {}}
        
        for col in info["columns"]:
            col_name = col["name"]
            col_type = col["type"].upper()
            
            # Get distinct count and samples
            cursor.execute(f"SELECT COUNT(DISTINCT {col_name}) as distinct_count FROM {table}")
            distinct_count = cursor.fetchone()[0]
            
            # Get sample values with exact case
            cursor.execute(f"SELECT DISTINCT {col_name} FROM {table} WHERE {col_name} IS NOT NULL LIMIT 10")
            samples = [row[0] for row in cursor.fetchall()]
            
            # Analyze patterns
            col_pattern = {
                "type": col_type,
                "distinct_count": distinct_count,
                "samples": samples,
                "is_low_cardinality": distinct_count < 100,
                "is_likely_id": col_name.lower().endswith(('_id', 'id', '_code', 'code')),
                "is_likely_name": col_name.lower().endswith(('_name', 'name')),
                "requires_exact_match": distinct_count < 100 or col_name.lower().endswith(('_id', 'id', '_code', 'code'))
            }
            
            # Check for case sensitivity
            if col_type in ['TEXT', 'VARCHAR'] and samples:
                lowercase_samples = [s.lower() for s in samples if isinstance(s, str)]
                col_pattern["is_case_sensitive"] = len(samples) != len(set(lowercase_samples))
            
            # Check for NULL patterns
            cursor.execute(f"SELECT COUNT(*) FROM {table} WHERE {col_name} IS NULL")
            null_count = cursor.fetchone()[0]
            col_pattern["null_ratio"] = null_count / info["row_count"] if info["row_count"] > 0 else 0
            
            table_patterns["columns"][col_name] = col_pattern
        
        patterns["tables"][table] = table_patterns
    
    return patterns

def detect_critical_issues(schema: Dict[str, Any], patterns: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Collect warnings about schema traits that commonly break queries.

    Three checks run in order: empty tables (CRITICAL), column names shared
    by several tables (HIGH), and known easily-confused column pairs (HIGH).

    Args:
        schema: Schema dict with a ``tables`` mapping; each entry provides
            ``is_empty`` and ``columns`` (dicts with ``name``).
        patterns: Data-pattern dict; not consulted by the current checks.

    Returns:
        A list of issue dicts, each with ``severity``, ``type`` and ``message``
        plus check-specific fields.
    """
    tables = schema["tables"]

    # 1. Empty tables
    found = [
        {
            "severity": "CRITICAL",
            "type": "empty_table",
            "table": name,
            "message": f"Table '{name}' is EMPTY - queries will return no results"
        }
        for name, meta in tables.items()
        if meta["is_empty"]
    ]

    # 2. Column names that occur in more than one table.
    # Also gather the lower-cased names needed by check 3 in the same pass.
    owners = {}
    lowered_names = set()
    for name, meta in tables.items():
        for column in meta["columns"]:
            owners.setdefault(column["name"], []).append(name)
            lowered_names.add(column["name"].lower())

    for column_name, owning_tables in owners.items():
        if len(owning_tables) <= 1:
            continue
        found.append({
            "severity": "HIGH",
            "type": "ambiguous_column",
            "column": column_name,
            "tables": owning_tables,
            "message": f"Column '{column_name}' exists in multiple tables: {', '.join(owning_tables)} - must qualify with table name"
        })

    # 3. Pairs of columns that are easy to mix up (matched case-insensitively).
    known_pairs = (
        ("grad_100", "grad_cohort", "grad_100 is RATE (%), grad_cohort is COUNT (#)"),
        ("state", "state_abbr", "state is full name, state_abbr is 2-letter code"),
        ("Value", "Values", "Different columns - check exact spelling"),
        ("ShortName", "LongName", "ShortName is brief, LongName is official full name"),
    )
    for first, second, note in known_pairs:
        both_present = {first.lower(), second.lower()} <= lowered_names
        if both_present:
            found.append({
                "severity": "HIGH",
                "type": "confusing_columns",
                "columns": [first, second],
                "message": note
            })

    return found

def generate_overview(schema: Dict[str, Any], issues: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the headline statistics for the report.

    Args:
        schema: Schema dict with a ``tables`` mapping; each entry provides
            ``row_count``, ``is_empty`` and ``is_junction``.
        issues: Issue dicts, each carrying a ``severity`` field.

    Returns:
        A dict with ``table_count``, ``total_rows``, ``empty_tables``,
        ``junction_tables``, ``critical_warnings`` and ``high_warnings``.
    """
    tables = schema["tables"]

    row_total = 0
    empty_names = []
    junction_names = []
    for name, meta in tables.items():
        row_total += meta["row_count"]
        if meta["is_empty"]:
            empty_names.append(name)
        if meta["is_junction"]:
            junction_names.append(name)

    # Split the issues by severity; other severities are left out.
    critical = []
    high = []
    for issue in issues:
        level = issue["severity"]
        if level == "CRITICAL":
            critical.append(issue)
        elif level == "HIGH":
            high.append(issue)

    return {
        "table_count": len(tables),
        "total_rows": row_total,
        "empty_tables": empty_names,
        "junction_tables": junction_names,
        "critical_warnings": critical,
        "high_warnings": high,
    }

def generate_sql_guidance(schema: Dict[str, Any], relationships: Dict[str, Any], patterns: Dict[str, Any]) -> Dict[str, Any]:
    """Turn the analysis results into hints for writing SQL.

    Args:
        schema: Schema dict; not consulted by the current rules.
        relationships: Must provide ``junction_tables``; ``join_paths`` is
            optional and defaults to an empty dict.
        patterns: Data-pattern dict with a ``tables`` mapping; entries without
            a ``columns`` key are skipped.

    Returns:
        A dict with ``exact_match_columns``, ``like_match_columns``,
        ``case_sensitive_columns``, ``aggregation_targets``,
        ``common_join_paths`` and ``distinct_requirements``.
    """
    text_types = ("TEXT", "VARCHAR")
    numeric_types = ("INTEGER", "REAL", "NUMERIC")

    exact_match = []
    like_match = []
    case_sensitive = []
    aggregations = []

    for table_name, table_entry in patterns["tables"].items():
        column_patterns = table_entry.get("columns") if "columns" in table_entry else None
        if column_patterns is None:
            continue

        for column_name, details in column_patterns.items():
            qualified = f"{table_name}.{column_name}"
            declared_type = details.get("type")
            head_samples = details.get("samples", [])[:3]

            # Matching strategy: exact match wins over LIKE.
            if details.get("requires_exact_match"):
                exact_match.append({
                    "column": qualified,
                    "reason": "Low cardinality or ID/code column",
                    "samples": head_samples
                })
            elif declared_type in text_types and not details.get("is_likely_id"):
                like_match.append({
                    "column": qualified,
                    "reason": "High cardinality text column"
                })

            if details.get("is_case_sensitive"):
                case_sensitive.append({
                    "column": qualified,
                    "samples": details.get("samples", [])[:3]
                })

            if declared_type in numeric_types:
                aggregations.append({
                    "column": qualified,
                    "suggested_functions": ["SUM", "AVG", "MIN", "MAX"]
                })

    # One DISTINCT reminder per junction table.
    distinct_notes = [
        {
            "scenario": f"Joining through {entry['table']}",
            "reason": "Junction table may create duplicates",
            "recommendation": "Use DISTINCT in SELECT"
        }
        for entry in relationships["junction_tables"]
    ]

    return {
        "exact_match_columns": exact_match,
        "like_match_columns": like_match,
        "case_sensitive_columns": case_sensitive,
        "aggregation_targets": aggregations,
        "common_join_paths": relationships.get("join_paths", {}),
        "distinct_requirements": distinct_notes,
    }

def find_join_paths(graph: Dict[str, List[str]]) -> Dict[str, Any]:
    """Describe every direct (single-hop) join in the relationship graph.

    Args:
        graph: Mapping of table name to the tables it references.

    Returns:
        A dict keyed ``"<from>_to_<to>"``; each value holds ``from``, ``to``,
        ``type`` (always ``"direct"``) and ``path`` (``[from, to]``).
    """
    # Only one-hop edges are listed; no multi-hop paths are derived.
    return {
        f"{source}_to_{destination}": {
            "from": source,
            "to": destination,
            "type": "direct",
            "path": [source, destination]
        }
        for source, destinations in graph.items()
        for destination in destinations
    }

def main():
    """Analyse ``database.sqlite`` and write the report as JSON.

    The report goes to ``tool_output/comprehensive_schema.json``. If the
    analysis raises, the failure is printed and a small error document is
    written to the same file instead.
    """
    database_file = "database.sqlite"
    report_file = "tool_output/comprehensive_schema.json"

    # The output directory must exist before either report is written.
    os.makedirs("tool_output", exist_ok=True)

    try:
        report = analyze_database(database_file)

        # default=str stringifies any value json cannot encode natively.
        with open(report_file, 'w') as out:
            json.dump(report, out, indent=2, default=str)

        print("✓ Comprehensive schema analysis complete")
        print(f"  - Analyzed {report['database_overview']['table_count']} tables")
        warnings = report['critical_warnings']
        print(f"  - Found {len(warnings)} critical issues")
        print(f"  - Results saved to {report_file}")

        # List every warning message after the summary.
        if warnings:
            print("\n⚠️  CRITICAL WARNINGS:")
            for entry in warnings:
                print(f"  - {entry['message']}")

    except Exception as exc:
        print(f"✗ Analysis failed: {str(exc)}")
        # Replace the report with an error document so the output file
        # always exists after a run.
        failure = {"error": str(exc), "partial_results": {}}
        with open(report_file, 'w') as out:
            json.dump(failure, out, indent=2)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()