#!/usr/bin/env python3
"""
Query Pattern Generator Tool
Generates database-specific SQL patterns and templates.
"""

import sqlite3
import json
import os

def _load_json_if_exists(path):
    """Return the parsed JSON document at *path*, or ``{}`` if the file is absent."""
    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def _table_patterns(schema_info):
    """Build per-table SELECT and COUNT patterns for every table with rows.

    Args:
        schema_info: Parsed ``schema_analysis.json``; each entry under
            ``"tables"`` must provide ``"row_count"`` and ``"columns"``.

    Returns:
        A ``(basic_patterns, count_patterns)`` tuple of fresh lists.
    """
    basic_patterns = []
    count_patterns = []
    for table_name, table_info in schema_info.get("tables", {}).items():
        if table_info["row_count"] == 0:
            continue  # No point suggesting queries against an empty table.

        basic_patterns.append({
            "description": f"Select all from {table_name}",
            "pattern": f"SELECT * FROM {table_name}",
            "use_case": "Retrieve all data from table"
        })

        # Only the first three columns, to keep the example short.
        columns = list(table_info["columns"])[:3]
        if columns:
            basic_patterns.append({
                "description": f"Select specific columns from {table_name}",
                "pattern": f"SELECT {', '.join(columns)} FROM {table_name}",
                "use_case": "Return only requested columns"
            })

        count_patterns.append({
            "description": f"Count rows in {table_name}",
            "pattern": f"SELECT COUNT(*) FROM {table_name}",
            "use_case": f"How many {table_name}",
            "important": "Use COUNT(*) not COUNT(column) unless counting non-null values"
        })
    return basic_patterns, count_patterns


def _join_patterns(relationships, limit=15):
    """Build a two-table JOIN template for each of the first *limit* foreign keys."""
    join_patterns = []
    for fk in relationships.get("foreign_keys", [])[:limit]:
        src, dst = fk["from_table"], fk["to_table"]
        src_col, dst_col = fk["from_column"], fk["to_column"]
        join_patterns.append({
            "tables": f"{src} → {dst}",
            "pattern": (
                "SELECT t1.*, t2.*\n"
                f"FROM {src} t1\n"
                f"JOIN {dst} t2 ON t1.{src_col} = t2.{dst_col}"
            ),
            "key_columns": f"{src_col} = {dst_col}"
        })
    return join_patterns


def _guidance_patterns(patterns, per_table=2):
    """Build SUM and GROUP BY patterns from ``patterns.json`` aggregation guidance.

    At most *per_table* SUM columns and *per_table* GROUP BY columns are used
    for each table.
    """
    result = []
    for table_name, agg_info in patterns.get("aggregation_guidance", {}).items():
        for sum_col in agg_info.get("sum_columns", [])[:per_table]:
            result.append({
                "description": f"Total {sum_col['column']} in {table_name}",
                "pattern": f"SELECT SUM({sum_col['column']}) FROM {table_name}",
                "important": f"SUM not COUNT - {sum_col['reason']}"
            })
        for group_col in agg_info.get("group_by_columns", [])[:per_table]:
            col = group_col["column"]
            result.append({
                "description": f"Group {table_name} by {col}",
                "pattern": f"SELECT {col}, COUNT(*)\nFROM {table_name}\nGROUP BY {col}",
                "selectivity": group_col["selectivity"]
            })
    return result


def _percentage_patterns():
    """Return generic percentage/ratio templates (built fresh on each call)."""
    return [
        {
            "description": "Calculate percentage (multiply by 100)",
            "pattern": "SELECT (COUNT(CASE WHEN condition THEN 1 END) * 100.0 / COUNT(*)) AS percentage",
            "important": "Use 100.0 not 1.0 for percentages"
        },
        {
            "description": "Calculate ratio (decimal)",
            "pattern": "SELECT (COUNT(CASE WHEN condition THEN 1 END) * 1.0 / COUNT(*)) AS ratio",
            "important": "Use 1.0 for ratios, 100.0 for percentages"
        },
        {
            "description": "Percentage with CAST",
            "pattern": "SELECT CAST(SUM(CASE WHEN condition THEN 1 ELSE 0 END) AS REAL) * 100 / COUNT(*)",
            "use_case": "Alternative percentage calculation"
        }
    ]


def _exists_patterns():
    """Return NOT EXISTS / EXISTS templates for "all" vs "at least one" filters."""
    return [
        {
            "description": "All items satisfy condition",
            "pattern": (
                "SELECT * FROM table t1\n"
                "WHERE NOT EXISTS (\n"
                "    SELECT 1 FROM related_table t2\n"
                "    WHERE t2.foreign_key = t1.id\n"
                "    AND NOT (condition)\n"
                ")"
            ),
            "example": "Students with A in ALL courses"
        },
        {
            "description": "At least one satisfies",
            "pattern": (
                "SELECT * FROM table t1\n"
                "WHERE EXISTS (\n"
                "    SELECT 1 FROM related_table t2\n"
                "    WHERE t2.foreign_key = t1.id\n"
                "    AND (condition)\n"
                ")"
            ),
            "example": "Students with at least one A"
        }
    ]


def _column_selection_patterns():
    """Return guidance mapping question phrasing to the columns to SELECT."""
    return [
        {
            "question_pattern": "What is/are the X",
            "sql_pattern": "SELECT X",
            "wrong_pattern": "SELECT * or SELECT X, Y",
            "example": "What is the name? → SELECT name (NOT SELECT id, name)"
        },
        {
            "question_pattern": "List the X",
            "sql_pattern": "SELECT X",
            "wrong_pattern": "SELECT * FROM",
            "example": "List the titles → SELECT title (NOT SELECT *)"
        },
        {
            "question_pattern": "How many",
            "sql_pattern": "SELECT COUNT(*) or COUNT(DISTINCT column)",
            "wrong_pattern": "SELECT COUNT(*), column",
            "example": "How many students? → SELECT COUNT(*) (NOT SELECT COUNT(*), student_name)"
        },
        {
            "question_pattern": "Show X and Y",
            "sql_pattern": "SELECT X, Y",
            "note": "Return ALL requested columns but ONLY requested columns"
        },
        {
            "question_pattern": "Which X has the most Y",
            "sql_pattern": "SELECT X ... ORDER BY COUNT(Y) DESC LIMIT 1",
            "wrong_pattern": "SELECT X, COUNT(Y)",
            "example": "Which teacher has most students? → SELECT teacher_name (NOT teacher_name, COUNT(*))"
        }
    ]


def _common_mistakes():
    """Return the catalogue of common SQL-generation mistakes and their fixes."""
    return [
        {
            "mistake": "Returning extra columns",
            "example": "Question asks for name, query returns id AND name",
            "fix": "Return ONLY what's explicitly requested"
        },
        {
            "mistake": "Using COUNT for measures",
            "example": "COUNT(total_amount) instead of SUM(total_amount)",
            "fix": "COUNT counts rows, SUM adds values"
        },
        {
            "mistake": "Missing percentage multiplication",
            "example": "Returning 0.15 instead of 15%",
            "fix": "Multiply by 100.0 for percentages"
        },
        {
            "mistake": "Wrong inequality operator",
            "example": "Using < when evidence says <=",
            "fix": "Use exact operator from evidence"
        },
        {
            "mistake": "Adding unnecessary filters",
            "example": "Adding WHERE value > 0 when not requested",
            "fix": "Only add filters explicitly mentioned"
        },
        {
            "mistake": "Case sensitivity errors",
            "example": "WHERE name = 'john' when actual value is 'John'",
            "fix": "Use exact case from value samples"
        }
    ]


def generate_patterns(db_path="database.sqlite", output_dir="tool_output"):
    """Generate SQL patterns specific to this database.

    Reads the optional analyses ``schema_analysis.json``, ``relationships.json``
    and ``patterns.json`` from *output_dir*. A missing file is treated as
    empty. The combined pattern catalogue is written to
    ``<output_dir>/query_patterns.json``.

    Args:
        db_path: Path to the SQLite database.
        output_dir: Directory holding the input analyses and receiving the
            output file. It is created if missing. Defaults to the historical
            ``"tool_output"``.

    Returns:
        The generated pattern dict, exactly as written to disk.

    Raises:
        sqlite3.Error: If *db_path* cannot be opened.
        KeyError: If an analysis file lacks a required field; the failure is
            reported on stdout and re-raised.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load previous analysis if available.
    schema_info = _load_json_if_exists(os.path.join(output_dir, "schema_analysis.json"))
    relationships = _load_json_if_exists(os.path.join(output_dir, "relationships.json"))
    patterns = _load_json_if_exists(os.path.join(output_dir, "patterns.json"))

    # NOTE(review): the database is never queried here. The connection is kept
    # so an unopenable db_path still fails early, as before. Be aware that
    # sqlite3.connect creates an empty database file when db_path is missing.
    conn = sqlite3.connect(db_path)
    try:
        basic_patterns, count_patterns = _table_patterns(schema_info)
        query_patterns = {
            "basic_patterns": basic_patterns,
            "join_patterns": _join_patterns(relationships),
            # Order: per-table counts, then guidance-driven SUM/GROUP BY,
            # then the generic percentage templates.
            "aggregation_patterns": (
                count_patterns + _guidance_patterns(patterns) + _percentage_patterns()
            ),
            "filter_patterns": _exists_patterns(),
            "common_mistakes": _common_mistakes(),
            "column_selection_patterns": _column_selection_patterns(),
        }

        # Save comprehensive patterns (ensure_ascii default keeps the file ASCII).
        output_path = os.path.join(output_dir, "query_patterns.json")
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(query_patterns, f, indent=2)

        print("✓ Query pattern generation complete")
        print(f"✓ Generated {len(query_patterns['join_patterns'])} join patterns")
        print(f"✓ Created {len(query_patterns['aggregation_patterns'])} aggregation patterns")
        print(f"✓ Documented {len(query_patterns['common_mistakes'])} common mistakes")
    except Exception as e:
        print(f"✗ Query pattern generation failed: {e}")
        raise
    finally:
        conn.close()

    return query_patterns

if __name__ == "__main__":
    # Script entry point: analyse the default database file in the CWD.
    generate_patterns(db_path="database.sqlite")