#!/usr/bin/env python3
"""
Error Prevention Validator - Detects and prevents common SQL generation errors
Combines validation from iter7 with comprehensive patterns from iter2
"""

import sqlite3
import json
import os
import re
from typing import Dict, List, Any, Optional
from collections import defaultdict, Counter

class ErrorPreventionValidator:
    """Inspect a SQLite database for schema traits that commonly cause SQL errors.

    The constructor opens the database and eagerly loads its schema (tables,
    columns, row counts). The ``detect_*`` methods derive warnings from that
    schema, the ``generate_*`` methods return static guidance, and
    ``validate_database`` bundles everything into one report dictionary.

    Instances can be used as context managers; the connection is closed on
    exit. Calling ``close()`` explicitly remains supported.
    """

    # Upper bound on distinct values sampled per text column when looking for
    # values that differ only by case. Detection is therefore best-effort.
    CASE_SAMPLE_LIMIT = 20

    # Substrings of a declared column type, modelled on SQLite's type-affinity
    # rules ("INT" wins over the text hints), plus NUMERIC/DECIMAL.
    _NUMERIC_TYPE_HINTS = ("REAL", "FLOA", "DOUB", "NUMERIC", "DECIMAL")
    _TEXT_TYPE_HINTS = ("CHAR", "CLOB", "TEXT")

    # Report sections whose entries are produced by inspecting the database
    # (each entry carries a "warning"), as opposed to static guidance.
    _DETECTION_SECTIONS = (
        "column_confusions",
        "aggregation_patterns",
        "join_errors",
        "empty_tables",
        "case_sensitivity",
    )

    def __init__(self, db_path: str):
        """Open ``db_path`` and load its schema.

        Raises:
            sqlite3.Error: if the database cannot be opened or read. The
                connection is closed before the error propagates.
        """
        self.conn = sqlite3.connect(db_path)
        try:
            self.cursor = self.conn.cursor()
            self.schema = self.load_schema()
        except Exception:
            # Do not leak the connection when schema loading fails.
            self.conn.close()
            raise
        self.common_errors = []
        self.prevention_rules = []

    def __enter__(self) -> "ErrorPreventionValidator":
        """Return the validator itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the database connection; exceptions are not suppressed."""
        self.close()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier."""
        return '"' + name.replace('"', '""') + '"'

    @classmethod
    def _is_numeric_type(cls, declared_type: str) -> bool:
        """Return True if a declared column type looks numeric."""
        upper = (declared_type or "").upper()
        if "INT" in upper:
            return True
        if any(hint in upper for hint in cls._TEXT_TYPE_HINTS):
            return False
        return any(hint in upper for hint in cls._NUMERIC_TYPE_HINTS)

    @classmethod
    def _is_text_type(cls, declared_type: str) -> bool:
        """Return True if a declared column type looks textual."""
        upper = (declared_type or "").upper()
        if "INT" in upper:
            return False
        return any(hint in upper for hint in cls._TEXT_TYPE_HINTS)

    def load_schema(self) -> Dict[str, Any]:
        """Load database schema for validation.

        Returns:
            ``{"tables": {table_name: {"columns": {...}, "row_count": int}}}``
            where each column maps to its declared ``type``, whether it is
            ``nullable`` and whether it is part of the ``primary_key``.
        """
        schema = {"tables": {}}

        # "_" is a LIKE wildcard, so it is escaped to exclude only names that
        # really start with the "sqlite_" prefix.
        self.cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
        )
        tables = [row[0] for row in self.cursor.fetchall()]

        for table in tables:
            quoted_table = self._quote_identifier(table)

            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            self.cursor.execute(f"PRAGMA table_info({quoted_table})")
            columns = {
                row[1]: {
                    "type": row[2],
                    "nullable": row[3] == 0,
                    # pk is the 1-based position within the primary key, so
                    # any positive value marks a key column (composite keys).
                    "primary_key": row[5] > 0,
                }
                for row in self.cursor.fetchall()
            }

            self.cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
            row_count = self.cursor.fetchone()[0]

            schema["tables"][table] = {
                "columns": columns,
                "row_count": row_count
            }

        return schema

    def detect_column_confusion_patterns(self) -> List[Dict[str, Any]]:
        """Detect columns that are commonly confused.

        A known confusion group is reported only when at least two distinct
        actual column names in the database match it (case-insensitively).
        The reported ``columns`` are the names as they appear in the schema.
        """
        confusion_patterns = []

        # Known confusion pairs
        known_confusions = [
            {
                "pair": ["grad_100", "grad_cohort"],
                "explanation": "grad_100 is graduation RATE (percentage), grad_cohort is graduation COUNT (number)",
                "severity": "CRITICAL"
            },
            {
                "pair": ["state", "state_abbr", "state_code"],
                "explanation": "state is full name, state_abbr/state_code is 2-letter abbreviation",
                "severity": "HIGH"
            },
            {
                "pair": ["Value", "Values"],
                "explanation": "Different columns - exact case and spelling matter",
                "severity": "HIGH"
            },
            {
                "pair": ["ShortName", "LongName", "Name"],
                "explanation": "ShortName is brief, LongName is official full name, Name may be either",
                "severity": "MEDIUM"
            },
            {
                "pair": ["id", "ID", "Id"],
                "explanation": "Case matters - check exact column name",
                "severity": "HIGH"
            }
        ]

        # Map lowercased name -> the spellings actually used in the schema.
        actual_names = defaultdict(set)
        for info in self.schema["tables"].values():
            for col in info["columns"]:
                actual_names[col.lower()].add(col)

        for confusion in known_confusions:
            found_columns = []
            for candidate in confusion["pair"]:
                # .get avoids inserting empty entries into the defaultdict.
                for actual in sorted(actual_names.get(candidate.lower(), ())):
                    if actual not in found_columns:
                        found_columns.append(actual)

            # Case-variant groups (id/ID/Id) collapse to one lowercase key,
            # so counting distinct real names prevents a lone "id" column
            # from being reported as a confusion.
            if len(found_columns) >= 2:
                confusion_patterns.append({
                    "columns": found_columns,
                    "warning": confusion["explanation"],
                    "severity": confusion["severity"],
                    "prevention": f"Be explicit about which column: {confusion['explanation']}"
                })

        return confusion_patterns

    def detect_aggregation_patterns(self) -> List[Dict[str, Any]]:
        """Detect common aggregation errors.

        Reports every table that has both numeric-typed and non-numeric-typed
        columns, listing them as aggregation and grouping candidates.
        """
        patterns = []

        for table, info in self.schema["tables"].items():
            numeric_columns = []
            non_numeric_columns = []

            for col, col_info in info["columns"].items():
                if self._is_numeric_type(col_info["type"]):
                    numeric_columns.append(col)
                else:
                    non_numeric_columns.append(col)

            if numeric_columns and non_numeric_columns:
                patterns.append({
                    "table": table,
                    "aggregation_columns": numeric_columns,
                    "grouping_columns": non_numeric_columns,
                    "warning": f"Table {table} has numeric columns that may need aggregation",
                    "prevention": "When using SUM/AVG/COUNT on numeric columns, GROUP BY non-aggregated columns"
                })

        return patterns

    def detect_join_errors(self) -> List[Dict[str, Any]]:
        """Detect potential join-related errors.

        Produces ``ambiguous_column`` entries for column names present in
        more than one table, and ``junction_table`` entries for tables with
        at least two foreign-key-like columns.
        """
        errors = []

        # Find potential join columns (columns that appear in multiple tables)
        column_tables = defaultdict(list)
        for table, info in self.schema["tables"].items():
            for col in info["columns"]:
                column_tables[col].append(table)

        # Detect ambiguous columns
        for col, tables in column_tables.items():
            if len(tables) > 1:
                errors.append({
                    "type": "ambiguous_column",
                    "column": col,
                    "tables": tables,
                    "warning": f"Column '{col}' exists in {len(tables)} tables",
                    "prevention": f"Always qualify as table.{col} in JOINs",
                    "severity": "HIGH"
                })

        # Detect junction tables. Heuristic: a column counts as FK-like when
        # its name ends in "id" (any case) but is not exactly "id".
        for table, info in self.schema["tables"].items():
            fk_like_columns = [
                col for col in info["columns"]
                if col.lower().endswith(('_id', 'id')) and col.lower() != 'id'
            ]
            if len(fk_like_columns) >= 2:
                errors.append({
                    "type": "junction_table",
                    "table": table,
                    "warning": f"Table '{table}' appears to be a junction table",
                    "prevention": "May need DISTINCT when joining through this table",
                    "severity": "MEDIUM"
                })

        return errors

    def detect_empty_table_errors(self) -> List[Dict[str, Any]]:
        """Detect empty tables that will cause query failures."""
        errors = []

        for table, info in self.schema["tables"].items():
            if info["row_count"] == 0:
                errors.append({
                    "type": "empty_table",
                    "table": table,
                    "warning": f"Table '{table}' is EMPTY",
                    "prevention": f"Avoid querying {table} - it will return no results",
                    "severity": "CRITICAL"
                })

        return errors

    def detect_case_sensitivity_issues(self) -> List[Dict[str, Any]]:
        """Detect columns with case-sensitive values.

        For each text-typed column of a non-empty table, samples up to
        ``CASE_SAMPLE_LIMIT`` distinct non-NULL values and reports the column
        when two sampled strings differ only by case. Columns whose sample
        query fails are skipped.
        """
        issues = []

        for table, info in self.schema["tables"].items():
            if info["row_count"] == 0:
                continue

            quoted_table = self._quote_identifier(table)

            for col, col_info in info["columns"].items():
                if not self._is_text_type(col_info["type"]):
                    continue

                quoted_col = self._quote_identifier(col)
                query = (
                    f"SELECT DISTINCT {quoted_col} FROM {quoted_table} "
                    f"WHERE {quoted_col} IS NOT NULL LIMIT {int(self.CASE_SAMPLE_LIMIT)}"
                )
                try:
                    self.cursor.execute(query)
                    values = [row[0] for row in self.cursor.fetchall()]
                except sqlite3.Error:
                    continue  # Skip problematic columns (best-effort scan)

                # Compare like with like: non-string values (e.g. BLOBs kept
                # in a text column) must not count towards either side.
                text_values = [v for v in values if isinstance(v, str)]
                unique_original = len(set(text_values))
                unique_lower = len({v.lower() for v in text_values})

                if unique_original != unique_lower:
                    issues.append({
                        "type": "case_sensitive",
                        "table": table,
                        "column": col,
                        "warning": f"Column {table}.{col} has case-sensitive values",
                        "prevention": "Use exact case from database values",
                        "examples": values[:3],
                        "severity": "HIGH"
                    })

        return issues

    def generate_prevention_checklist(self) -> List[Dict[str, str]]:
        """Generate a checklist to prevent common errors.

        The checklist is static; it does not depend on the loaded schema.
        """
        checklist = [
            {
                "check": "Return EXACTLY the requested columns",
                "why": "Most common error - adding unrequested columns",
                "priority": "CRITICAL"
            },
            {
                "check": "Use = for exact matching by default",
                "why": "LIKE causes missed matches when exact value exists",
                "priority": "CRITICAL"
            },
            {
                "check": "Preserve exact case for text values",
                "why": "Database is case-sensitive",
                "priority": "HIGH"
            },
            {
                "check": "Include GROUP BY for mixed aggregate/non-aggregate",
                "why": "SQL syntax error without it",
                "priority": "CRITICAL"
            },
            {
                "check": "Qualify ambiguous columns with table names",
                "why": "Ambiguity causes SQL errors",
                "priority": "HIGH"
            },
            {
                "check": "Check for empty tables before querying",
                "why": "Empty tables return no results",
                "priority": "MEDIUM"
            },
            {
                "check": "Use IS NULL not = NULL",
                "why": "SQL NULL comparison rules",
                "priority": "HIGH"
            },
            {
                "check": "Consider DISTINCT for junction table joins",
                "why": "Junction tables can create duplicates",
                "priority": "MEDIUM"
            }
        ]

        return checklist

    def generate_sql_antipatterns(self) -> List[Dict[str, Any]]:
        """Generate common SQL anti-patterns to avoid.

        The list is static; it does not depend on the loaded schema.
        """
        antipatterns = [
            {
                "pattern": "SELECT name, age FROM ... ORDER BY age LIMIT 1",
                "issue": "Returns extra column (age) when only name requested",
                "correct": "SELECT name FROM ... ORDER BY age LIMIT 1",
                "frequency": "VERY HIGH"
            },
            {
                "pattern": "WHERE column LIKE '%value%'",
                "issue": "Unnecessary LIKE for exact matches",
                "correct": "WHERE column = 'value'",
                "frequency": "HIGH"
            },
            {
                "pattern": "SELECT dept, COUNT(*) FROM ...",
                "issue": "Missing GROUP BY",
                "correct": "SELECT dept, COUNT(*) FROM ... GROUP BY dept",
                "frequency": "HIGH"
            },
            {
                "pattern": "WHERE column = NULL",
                "issue": "Wrong NULL comparison",
                "correct": "WHERE column IS NULL",
                "frequency": "MEDIUM"
            },
            {
                "pattern": "COUNT(*) / 6",
                "issue": "Hard-coded divisor assumes all periods have data",
                "correct": "COUNT(*) / COUNT(DISTINCT period)",
                "frequency": "MEDIUM"
            }
        ]

        return antipatterns

    def validate_database(self) -> Dict[str, Any]:
        """Run all validations and generate comprehensive report.

        Returns:
            A dictionary with one list per detection/guidance section plus a
            ``summary``. ``total_warnings`` counts only entries found by the
            detection sections; the static checklist and anti-pattern lists
            are guidance, not warnings, and are excluded.
        """
        report = {
            "column_confusions": self.detect_column_confusion_patterns(),
            "aggregation_patterns": self.detect_aggregation_patterns(),
            "join_errors": self.detect_join_errors(),
            "empty_tables": self.detect_empty_table_errors(),
            "case_sensitivity": self.detect_case_sensitivity_issues(),
            "prevention_checklist": self.generate_prevention_checklist(),
            "sql_antipatterns": self.generate_sql_antipatterns(),
            "summary": {}
        }

        detected = [
            item
            for section in self._DETECTION_SECTIONS
            for item in report[section]
        ]

        # Entries without a "severity" key (aggregation patterns) match
        # neither level and so contribute only to the total.
        critical_count = sum(
            1 for item in detected if item.get("severity") == "CRITICAL"
        )
        high_count = sum(
            1 for item in detected if item.get("severity") == "HIGH"
        )

        report["summary"] = {
            "critical_issues": critical_count,
            "high_priority_issues": high_count,
            "total_warnings": len(detected),
            "tables_analyzed": len(self.schema["tables"]),
            "empty_tables": len(report["empty_tables"])
        }

        return report

    def close(self):
        """Close database connection."""
        self.conn.close()

def main():
    """Validate ``database.sqlite`` and write the report to ``tool_output/``.

    On success the full report is written to
    ``tool_output/error_prevention.json`` and a short summary is printed.
    On any failure (including a missing database file) an error report with
    the exception message is written to the same path instead. The database
    connection is closed in both cases.
    """
    db_path = "database.sqlite"
    output_path = "tool_output/error_prevention.json"

    # Ensure output directory exists
    os.makedirs("tool_output", exist_ok=True)

    validator = None
    try:
        # sqlite3.connect() would silently create an empty database for a
        # missing path and the run would "succeed" with zero tables.
        if not os.path.isfile(db_path):
            raise FileNotFoundError(f"Database file not found: {db_path}")

        # Initialize validator
        validator = ErrorPreventionValidator(db_path)

        # Run validation
        report = validator.validate_database()

        # Save to JSON; default=str stringifies values json cannot encode
        # (e.g. bytes among the sampled examples).
        with open(output_path, 'w') as f:
            json.dump(report, f, indent=2, default=str)

        summary = report["summary"]
        print("✓ Error prevention validation complete")
        print(f"  - Found {summary['critical_issues']} CRITICAL issues")
        print(f"  - Found {summary['high_priority_issues']} HIGH priority issues")
        print(f"  - Generated {len(report['prevention_checklist'])} prevention rules")
        print(f"  - Results saved to {output_path}")

        # Print critical issues
        if report["empty_tables"]:
            print("\n⚠️  CRITICAL: Empty tables detected:")
            for issue in report["empty_tables"]:
                print(f"    - {issue['table']}")

        if report["column_confusions"]:
            print("\n⚠️  Column confusion risks:")
            for confusion in report["column_confusions"][:3]:  # Show top 3
                print(f"    - {confusion['warning']}")

    except Exception as e:
        # Top-level boundary: report the failure instead of crashing so a
        # consumer of the output file always finds valid JSON.
        print(f"✗ Validation failed: {str(e)}")
        error_report = {
            "error": str(e),
            "partial_results": {}
        }
        with open(output_path, 'w') as f:
            json.dump(error_report, f, indent=2)
    finally:
        # Release the connection on the failure path as well.
        if validator is not None:
            validator.close()

# Run the validator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()