#!/usr/bin/env python3
"""
Defensive Pattern Detector Tool
Detects SQL patterns with precomputed defaults and fallback strategies.
Focuses on the most critical patterns for SQL generation accuracy.
"""

import copy
import json
import os
import re
import sqlite3
import sys
import time
from typing import Any, Dict, List, Optional

# Precomputed default patterns for common scenarios
# Static guidance used by the detectors below, both as normal output (e.g. the
# aggregation "rules") and as fallback content when schema/value analysis is
# missing or a detector raises. Treat this dict as read-only: hand out copies,
# since mutating a returned alias would change the defaults for later calls.
DEFAULT_PATTERNS = {
    # Column-name hints and rules for choosing COUNT / SUM / AVG.
    "aggregation_guidance": {
        "count_columns": ["id", "user_id", "customer_id", "order_id", "product_id", "student_id"],
        "sum_columns": ["total", "amount", "sales", "revenue", "cost", "price", "quantity"],
        "avg_columns": ["rating", "score", "grade", "percentage", "rate"],
        "rules": [
            "COUNT for entity identifiers (IDs)",
            "SUM for numeric measures (totals, amounts)",
            "AVG for rates and scores",
            "COUNT(*) for row counts",
            "COUNT(DISTINCT x) for unique values"
        ]
    },
    # Generic join templates; "table1"/"table2"/"junction" are placeholders.
    # NOTE(review): a real self-join needs two aliases of the same table.
    "join_patterns": {
        "common": [
            {"pattern": "table1.id = table2.table1_id", "description": "Standard foreign key"},
            {"pattern": "table1.id = junction.table1_id AND junction.table2_id = table2.id", "description": "Many-to-many"},
            {"pattern": "table.parent_id = table.id", "description": "Self-referencing"}
        ]
    },
    # Fallback text-comparison advice when no case statistics are available.
    "case_sensitivity": {
        "default": "Case-insensitive comparison recommended",
        "use_lower": "Use LOWER() for text comparisons when uncertain"
    },
    # Human-readable hints for recognizing and rendering percentage columns.
    "percentage_patterns": {
        "detection": "Columns containing 'percent', 'rate', or values 0-1 or 0-100",
        "calculation": "Multiply by 100.0 for percentage display"
    }
}

def safe_execute(conn, query: str, params=None, timeout: float = 2.0) -> Optional[List]:
    """Execute a query and return all rows, or None on any failure.

    SQLite has no ``query_timeout`` pragma (unknown pragmas are silently
    ignored), so the time limit is enforced with a progress handler that
    interrupts the running statement once ``timeout`` seconds have elapsed.

    Args:
        conn: An open ``sqlite3.Connection``.
        query: SQL text to execute.
        params: Optional parameter sequence/mapping bound to ``query``.
        timeout: Wall-clock budget in seconds before the query is interrupted.

    Returns:
        The fetched rows, or ``None`` if execution failed or timed out
        (a short warning is printed in that case).
    """
    deadline = time.monotonic() + timeout
    cursor = None
    handler_installed = False
    try:
        # A truthy return from the handler makes SQLite abort the statement
        # with OperationalError("interrupted"), which is reported below.
        conn.set_progress_handler(lambda: time.monotonic() > deadline, 1000)
        handler_installed = True
        cursor = conn.cursor()
        if params:
            cursor.execute(query, params)
        else:
            cursor.execute(query)
        return cursor.fetchall()
    except Exception as e:
        print(f"⚠️  Pattern query failed: {str(e)[:100]}")
        return None
    finally:
        if cursor is not None:
            cursor.close()
        if handler_installed:
            # Do not leave the deadline check attached to the shared connection.
            conn.set_progress_handler(None, 0)

def _as_number(value: Any) -> Optional[float]:
    """Return ``value`` if it is a real (non-bool) number, otherwise None."""
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return value
    return None


def _column_value_stats(value_info: Optional[Dict], table_name: str, col_name: str) -> Optional[Dict]:
    """Return sampled-value stats for ``table_name.col_name``.

    Returns None when there are no value samples for the table at all, and an
    empty dict when the table was sampled but this column was not.
    """
    if not value_info:
        return None
    table_values = (value_info.get("tables") or {}).get(table_name)
    if table_values is None:
        return None
    return (table_values.get("columns") or {}).get(col_name) or {}


def _classify_aggregation(table_name: str, col_name: str, col_info: Dict,
                          value_info: Optional[Dict]) -> Optional[str]:
    """Return the result bucket key for one column, or None to leave it out."""
    col_lower = col_name.lower()
    semantic_type = col_info.get("semantic_type", "")

    # Name / semantic-type rules, checked in priority order.
    if semantic_type == "identifier" or col_info.get("is_primary_key"):
        return "count_recommended"
    if semantic_type == "measure" or any(
            x in col_lower for x in ("total", "amount", "sum", "sales", "revenue", "quantity")):
        return "sum_recommended"
    if semantic_type == "percentage" or any(
            x in col_lower for x in ("rate", "percentage", "avg", "average", "mean")):
        return "avg_recommended"
    if col_lower.endswith("_id"):
        return "count_recommended"

    col_values = _column_value_stats(value_info, table_name, col_name)

    if col_lower.startswith(("num_", "number_")):
        # The name could mean a count or a summable quantity; defer to the data.
        if col_values is None:
            return "ambiguous"
        return "count_recommended" if col_values.get("looks_like_count") else "sum_recommended"

    # Remaining numeric columns are only classified when value samples exist.
    col_type = (col_info.get("type") or "").upper()
    if not any(x in col_type for x in ("INT", "REAL", "FLOAT", "NUMERIC", "DECIMAL")):
        return None
    if col_values is None:
        return None

    # Stats may contain JSON nulls; treat non-numeric values as unknown
    # instead of letting a comparison raise and discard every result.
    distinct_count = _as_number(col_values.get("distinct_count", 0)) or 0
    if distinct_count > 100:  # High cardinality suggests ID
        return "count_recommended"
    min_value = _as_number(col_values.get("min", 0))
    if min_value is not None and min_value >= 0 and col_values.get("has_negatives") is False:
        return "sum_recommended"
    return "ambiguous"


def detect_aggregation_patterns(conn, schema_info: Dict, value_info: Dict) -> Dict[str, Any]:
    """Detect COUNT vs SUM vs AVG candidates with fallbacks.

    Args:
        conn: Database connection (not used by this detector).
        schema_info: Parsed schema analysis; reads ``tables.<name>.columns``
            with per-column ``semantic_type``, ``is_primary_key`` and ``type``.
        value_info: Parsed value samples; reads ``tables.<name>.columns.<col>``
            stats such as ``looks_like_count``, ``distinct_count``, ``min`` and
            ``has_negatives``. May be empty.

    Returns:
        Dict of ``table.column`` lists per recommended aggregation, the
        aggregation rules, and a confidence level. If analysis raises, the
        confidence drops to "low" and generic column-name hints are used.
    """
    guidance = DEFAULT_PATTERNS["aggregation_guidance"]
    patterns = {
        "count_recommended": [],
        "sum_recommended": [],
        "avg_recommended": [],
        "ambiguous": [],
        # Copied so callers mutating the result cannot alter the module defaults.
        "rules": list(guidance["rules"]),
        "confidence": "high"
    }

    try:
        for table_name, table_data in schema_info.get("tables", {}).items():
            for col_name, col_info in table_data.get("columns", {}).items():
                bucket = _classify_aggregation(table_name, col_name, col_info or {}, value_info)
                if bucket is not None:
                    patterns[bucket].append(f"{table_name}.{col_name}")

    except Exception as e:
        print(f"⚠️  Aggregation pattern detection partial: {str(e)}")
        patterns["confidence"] = "low"
        # Use defaults (copies, not aliases of DEFAULT_PATTERNS)
        patterns["count_recommended"] = list(guidance["count_columns"])
        patterns["sum_recommended"] = list(guidance["sum_columns"])
        patterns["avg_recommended"] = list(guidance["avg_columns"])

    return patterns

def detect_join_patterns(conn, schema_info: Dict) -> Dict[str, Any]:
    """Detect join patterns with relationship inference.

    Args:
        conn: Database connection (not used by this detector).
        schema_info: Parsed schema analysis; reads
            ``relationships.foreign_keys`` (entries with ``from``/``to``),
            ``relationships.inferred_joins``, ``tables.<name>.columns`` and
            ``ambiguous_columns.<col>.appears_in``.

    Returns:
        Dict with explicit foreign keys, ``*_id``-inferred relationships,
        junction tables, self-references, and columns shared across tables
        (each listed once). If analysis raises, the confidence drops to "low"
        and generic join templates replace the explicit foreign keys.
    """
    patterns = {
        "explicit_foreign_keys": [],
        "inferred_relationships": [],
        "junction_tables": [],
        "self_references": [],
        "common_join_columns": [],
        "confidence": "high"
    }
    self_reference_names = {"parent_id", "manager_id", "supervisor_id", "parent", "parent_node"}

    try:
        relationships = schema_info.get("relationships", {})

        # Extract from schema foreign keys
        for fk in relationships.get("foreign_keys", []):
            patterns["explicit_foreign_keys"].append({
                "from": fk["from"],
                "to": fk["to"],
                "confidence": fk.get("confidence", "high")
            })

        # Find junction tables
        patterns["junction_tables"].extend(
            join_info for join_info in relationships.get("inferred_joins", [])
            if join_info.get("type") == "junction"
        )

        ambiguous_columns = schema_info.get("ambiguous_columns", {})
        # The result list holds dicts, so track already-reported names separately.
        seen_join_columns = set()

        # Detect common join patterns
        for table_name, table_info in schema_info.get("tables", {}).items():
            for col_name in table_info.get("columns", {}):
                col_lower = col_name.lower()

                # Self-reference detection
                if col_lower in self_reference_names:
                    patterns["self_references"].append({
                        "table": table_name,
                        "column": col_name,
                        "pattern": f"{table_name}.{col_name} = {table_name}.id"
                    })

                # Common FK patterns; a bare "_id" has no table hint to offer.
                if col_lower.endswith("_id") and len(col_lower) > 3:
                    patterns["inferred_relationships"].append({
                        "from": f"{table_name}.{col_name}",
                        "to_table_hint": col_lower[:-3],  # Remove _id
                        "confidence": "medium"
                    })

                # Track columns that appear in multiple tables (for joins)
                if col_name in ambiguous_columns and col_name not in seen_join_columns:
                    seen_join_columns.add(col_name)
                    patterns["common_join_columns"].append({
                        "column": col_name,
                        "tables": ambiguous_columns[col_name].get("appears_in", [])
                    })

    except Exception as e:
        print(f"⚠️  Join pattern detection error: {str(e)}")
        patterns["confidence"] = "low"
        # Copy the templates so the result never aliases DEFAULT_PATTERNS.
        patterns["explicit_foreign_keys"] = [
            dict(template) for template in DEFAULT_PATTERNS["join_patterns"]["common"]
        ]

    return patterns

_ISO_DATE_RE = re.compile(r'\d{4}-\d{2}-\d{2}')
_SLASH_DATE_RE = re.compile(r'\d{2}/\d{2}/\d{4}')


def _unwrap_samples(col_data: Dict) -> List[Any]:
    """Return sample values, unwrapping entries stored as ``{"value": ...}``."""
    return [s.get("value") if isinstance(s, dict) else s
            for s in col_data.get("sample_values") or []]


def _date_format(value: Any) -> str:
    """Label the layout of a sample date by matching its leading characters."""
    text = str(value)
    if _ISO_DATE_RE.match(text):
        return "YYYY-MM-DD"
    if _SLASH_DATE_RE.match(text):
        # NOTE(review): DD/MM/YYYY has the same shape; this label assumes US order.
        return "MM/DD/YYYY"
    return "unknown"


def detect_query_patterns(conn, value_info: Dict) -> Dict[str, Any]:
    """Detect common query patterns and pitfalls from sampled values.

    Args:
        conn: Database connection (not used by this detector).
        value_info: Parsed value samples; reads
            ``global_patterns.case_patterns`` and, per table, ``row_count`` and
            per-column ``looks_like_percentage``, ``percentage_format``,
            ``min``/``max``, ``data_type_inferred``, ``sample_values`` and
            ``null_count``. May be empty.

    Returns:
        Dict with case-sensitivity advice (defaults when no case statistics
        exist), percentage / boolean / date columns, columns with more than
        50% NULLs, warnings and general recommendations. An error stops
        detection early but keeps what was found and records a warning.
    """
    patterns = {
        "case_sensitivity": {},
        "percentage_columns": [],
        "boolean_patterns": {},
        "date_formats": {},
        "null_heavy_columns": [],
        "warnings": [],
        "recommendations": []
    }

    try:
        value_info = value_info or {}

        # Case matters when more than one casing style was observed.
        case_patterns = value_info.get("global_patterns", {}).get("case_patterns", {})
        if len(case_patterns) > 1:
            patterns["case_sensitivity"] = {
                "matters": True,
                "recommendation": "Use exact case from sample values",
                "patterns_found": list(case_patterns)
            }
        elif case_patterns:
            patterns["case_sensitivity"] = {
                "matters": False,
                "recommendation": "Case-insensitive comparisons safe",
                "dominant_pattern": next(iter(case_patterns))
            }
        else:
            # Copy so edits to the result cannot mutate DEFAULT_PATTERNS.
            patterns["case_sensitivity"] = dict(DEFAULT_PATTERNS["case_sensitivity"])

        for table_name, table_data in value_info.get("tables", {}).items():
            row_count = table_data.get("row_count")
            has_rows = isinstance(row_count, int) and row_count > 0

            for col_name, col_data in table_data.get("columns", {}).items():
                full_name = f"{table_name}.{col_name}"
                inferred_type = col_data.get("data_type_inferred")

                # Find percentage columns
                if col_data.get("looks_like_percentage"):
                    patterns["percentage_columns"].append({
                        "column": full_name,
                        "format": col_data.get("percentage_format", "unknown"),
                        "range": f"{col_data.get('min', 'unknown')}-{col_data.get('max', 'unknown')}"
                    })

                # Boolean patterns: keep a few raw values to show the encoding.
                if inferred_type == "boolean":
                    patterns["boolean_patterns"][full_name] = _unwrap_samples(col_data)[:3]
                # Date formats, detected from the first sample only.
                elif inferred_type == "date":
                    samples = _unwrap_samples(col_data)
                    if samples:
                        patterns["date_formats"][full_name] = _date_format(samples[0])

                # NULL-heavy columns; a JSON null count is treated as zero.
                null_count = col_data.get("null_count") or 0
                if has_rows and isinstance(null_count, (int, float)) and null_count > 0:
                    null_percentage = (null_count / row_count) * 100
                    if null_percentage > 50:
                        patterns["null_heavy_columns"].append({
                            "column": full_name,
                            "null_percentage": round(null_percentage, 1)
                        })

    except Exception as e:
        print(f"⚠️  Query pattern detection error: {str(e)}")
        patterns["warnings"].append(f"Pattern detection incomplete: {str(e)}")

    # Add general recommendations
    patterns["recommendations"] = [
        "Use exact column names from schema",
        "Qualify ambiguous columns with table names",
        "Check sample values for exact case matching",
        "Use COUNT(*) for row counts, not COUNT(column)",
        "Multiply by 100.0 for percentage calculations",
        "Use IS NULL / IS NOT NULL for NULL checks"
    ]

    return patterns

def detect_patterns(db_path="database.sqlite"):
    """Run all pattern detectors and save the result to tool_output/patterns.json.

    Loads ``tool_output/schema_analysis.json`` and
    ``tool_output/value_samples.json`` when present, connects to ``db_path``,
    runs the aggregation, join and query detectors, and writes the combined
    result. Any error during detection marks the quality "failed" and
    substitutes copies of the module defaults.

    NOTE: ``sqlite3.connect`` creates an empty database file when ``db_path``
    does not exist.

    Args:
        db_path: Path of the SQLite database to connect to.

    Returns:
        The combined pattern result dict (the same data that is written to disk).
    """

    os.makedirs("tool_output", exist_ok=True)

    pattern_result = {
        "aggregation_patterns": {},
        "join_patterns": {},
        "query_patterns": {},
        "critical_rules": [],
        "summary": {
            "detection_quality": "complete",
            "patterns_found": 0,
            "warnings": []
        }
    }
    summary = pattern_result["summary"]

    conn = None
    try:
        # Load previous analysis results
        schema_info = {}
        value_info = {}

        if os.path.exists("tool_output/schema_analysis.json"):
            with open("tool_output/schema_analysis.json", 'r', encoding="utf-8") as f:
                schema_info = json.load(f)
        else:
            summary["warnings"].append("No schema analysis found - using defaults")
            summary["detection_quality"] = "degraded"

        if os.path.exists("tool_output/value_samples.json"):
            with open("tool_output/value_samples.json", 'r', encoding="utf-8") as f:
                value_info = json.load(f)
        else:
            summary["warnings"].append("No value samples found - patterns limited")
            # Do not mask the more severe "degraded" status of a missing schema.
            if summary["detection_quality"] == "complete":
                summary["detection_quality"] = "partial"

        # Connect to database
        conn = sqlite3.connect(db_path, timeout=30.0)
        print("✓ Connected for pattern detection")

        # Detect patterns
        print("  Detecting aggregation patterns...")
        aggregation = detect_aggregation_patterns(conn, schema_info, value_info)
        pattern_result["aggregation_patterns"] = aggregation
        summary["patterns_found"] += len(aggregation.get("count_recommended", []))
        summary["patterns_found"] += len(aggregation.get("sum_recommended", []))

        print("  Detecting join patterns...")
        joins = detect_join_patterns(conn, schema_info)
        pattern_result["join_patterns"] = joins
        summary["patterns_found"] += len(joins.get("explicit_foreign_keys", []))

        print("  Detecting query patterns...")
        queries = detect_query_patterns(conn, value_info)
        pattern_result["query_patterns"] = queries
        summary["patterns_found"] += len(queries.get("percentage_columns", []))

        # Generate critical rules
        pattern_result["critical_rules"] = [
            "ALWAYS return only the columns explicitly requested",
            "Use COUNT(*) for counting rows, not COUNT(column_name)",
            "Use SUM for numeric measures, COUNT for entities",
            "Multiply by 100.0 for percentage calculations",
            "Check case sensitivity - use exact values from samples",
            "Qualify ambiguous column names with table prefix",
            "Use IS NULL / IS NOT NULL for NULL checks",
            "Apply evidence formulas exactly as specified"
        ]

        # Add specific warnings based on findings
        if queries.get("null_heavy_columns"):
            pattern_result["critical_rules"].append(
                f"WARNING: {len(queries['null_heavy_columns'])} columns have >50% NULLs"
            )

        if aggregation.get("ambiguous"):
            pattern_result["critical_rules"].append(
                f"AMBIGUOUS: {len(aggregation['ambiguous'])} columns need context for COUNT/SUM decision"
            )

    except Exception as e:
        summary["detection_quality"] = "failed"
        summary["warnings"].append(f"Critical error: {str(e)}")
        print(f"❌ Pattern detection failed: {str(e)}")

        # Use all defaults. Deep copies keep the saved/returned result from
        # aliasing DEFAULT_PATTERNS.
        pattern_result["aggregation_patterns"] = copy.deepcopy(DEFAULT_PATTERNS["aggregation_guidance"])
        pattern_result["join_patterns"] = copy.deepcopy(DEFAULT_PATTERNS["join_patterns"])
        # Only the query-level defaults belong here; aggregation and join
        # defaults already have their own sections above.
        pattern_result["query_patterns"] = {
            "case_sensitivity": copy.deepcopy(DEFAULT_PATTERNS["case_sensitivity"]),
            "percentage_patterns": copy.deepcopy(DEFAULT_PATTERNS["percentage_patterns"]),
        }

    finally:
        if conn:
            conn.close()

    # Save results
    output_path = "tool_output/patterns.json"
    try:
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(pattern_result, f, indent=2, default=str)
        print(f"\n✓ Pattern detection saved to {output_path}")
        print(f"  Quality: {summary['detection_quality']}")
        print(f"  Patterns: {summary['patterns_found']}")

        if summary["warnings"]:
            print(f"  ⚠️  {len(summary['warnings'])} warnings")

    except Exception as e:
        print(f"❌ Failed to save patterns: {str(e)}")
        try:
            with open("tool_output/patterns.error", 'w', encoding="utf-8") as f:
                f.write(f"Pattern detection failed: {str(e)}\n")
        except OSError:
            # Best effort: nothing further can be reported if this write fails too.
            pass

    return pattern_result

if __name__ == "__main__":
    # Only an outright detection failure is a non-zero exit; degraded or
    # partial results are still saved and count as success.
    result = detect_patterns()
    failed = result["summary"]["detection_quality"] == "failed"
    sys.exit(1 if failed else 0)