#!/usr/bin/env python3
"""
Query Pattern Matcher Tool
Generates SQL templates for common query patterns based on the database structure.
Helps avoid structural errors and provides proven patterns.
"""

import sqlite3
import json
import os

def _quote_identifier(name):
    """Return *name* wrapped in backticks, doubling any embedded backtick.

    Doubling the delimiter is how SQLite escapes a quote character inside a
    quoted identifier, so names such as ``we`ird`` stay valid SQL.
    """
    return "`" + name.replace("`", "``") + "`"


def _read_schema(conn):
    """Return ``[(table_name, [pk_column, ...]), ...]`` for every user table.

    Tables are ordered by name and SQLite's internal ``sqlite_*`` tables are
    skipped. Raises ``sqlite3.Error`` if the database cannot be read.
    """
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

    schema = []
    for table in tables:
        cursor.execute(f"PRAGMA table_info({_quote_identifier(table)})")
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk);
        # pk is non-zero for columns that belong to the primary key.
        schema.append((table, [col[1] for col in cursor.fetchall() if col[5]]))
    return schema


def _static_patterns(quoted_example_table):
    """Build the schema-independent pattern catalogue.

    *quoted_example_table* is the already-quoted table name used in the
    "Simple count all" example; everything else is fixed text.
    """
    return {
        "count_patterns": [
            {
                "name": "Simple count all",
                "when": "How many total records",
                "template": "SELECT COUNT(*) FROM table",
                "example": f"SELECT COUNT(*) FROM {quoted_example_table}",
                "warning": "Use COUNT(*) not COUNT(column) for total count"
            },
            {
                "name": "Count with condition",
                "when": "How many records match condition",
                "template": "SELECT COUNT(*) FROM table WHERE condition",
                "example": "SELECT COUNT(*) FROM customers WHERE age > 30",
                "warning": "Still use COUNT(*) not COUNT(column)"
            },
            {
                "name": "Count distinct values",
                "when": "Question contains 'unique', 'different', or 'distinct'",
                "template": "SELECT COUNT(DISTINCT column) FROM table",
                "example": "SELECT COUNT(DISTINCT customer_id) FROM orders",
                "warning": "Only use DISTINCT when explicitly needed"
            },
            {
                "name": "Conditional count",
                "when": "Count based on complex condition",
                "template": "SELECT COUNT(CASE WHEN condition THEN 1 END) FROM table",
                "example": "SELECT COUNT(CASE WHEN status = 'active' THEN 1 END) FROM users",
                "warning": "Use CASE WHEN for conditional counting, not SUM"
            }
        ],
        "top_patterns": [
            {
                "name": "Entity with highest value",
                "when": "Which/What entity has the most/highest",
                "template": "SELECT entity_column FROM table ORDER BY metric DESC LIMIT 1",
                "example": "SELECT customer_name FROM customers ORDER BY total_purchases DESC LIMIT 1",
                "warning": "Use ORDER BY LIMIT, NOT subquery with MAX"
            },
            {
                "name": "Entity with most related records",
                "when": "Entity with most associations",
                "template": "SELECT t1.name FROM table1 t1 JOIN table2 t2 ON t1.id = t2.fk GROUP BY t1.id ORDER BY COUNT(*) DESC LIMIT 1",
                "example": "SELECT c.name FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id ORDER BY COUNT(*) DESC LIMIT 1",
                "warning": "GROUP BY before ORDER BY"
            },
            {
                "name": "Top N entities",
                "when": "Top/First N entities by metric",
                "template": "SELECT entity_column FROM table ORDER BY metric DESC LIMIT N",
                "example": "SELECT product_name FROM products ORDER BY price DESC LIMIT 5",
                "warning": "LIMIT comes after ORDER BY"
            }
        ],
        "aggregation_patterns": [
            {
                "name": "SUM for totals",
                "when": "Total/sum of numeric values",
                "template": "SELECT SUM(numeric_column) FROM table",
                "example": "SELECT SUM(amount) FROM transactions",
                "warning": "Use SUM for values, not for counting records"
            },
            {
                "name": "COUNT for occurrences",
                "when": "Number of records/rows/occurrences",
                "template": "SELECT COUNT(*) FROM table",
                "example": "SELECT COUNT(*) FROM orders WHERE status = 'completed'",
                "warning": "Use COUNT not SUM for counting records"
            },
            {
                "name": "AVG for averages",
                "when": "Average/mean value",
                "template": "SELECT AVG(numeric_column) FROM table",
                "example": "SELECT AVG(price) FROM products",
                "warning": "AVG ignores NULL values"
            },
            {
                "name": "Percentage calculation",
                "when": "Percentage or ratio",
                "template": "SELECT CAST(COUNT(CASE WHEN condition THEN 1 END) AS REAL) * 100.0 / COUNT(*)",
                "example": "SELECT CAST(COUNT(CASE WHEN status = 'active' THEN 1 END) AS REAL) * 100.0 / COUNT(*) FROM users",
                "warning": "Cast to REAL for decimal division"
            }
        ],
        # Filled in by the caller from the database schema; kept here so the
        # key order of the emitted JSON stays stable.
        "join_patterns": [],
        "group_by_patterns": [
            {
                "name": "Group by with aggregate",
                "when": "Aggregate per group",
                "template": "SELECT group_column, AGG(metric) FROM table GROUP BY group_column",
                "example": "SELECT category, COUNT(*) FROM products GROUP BY category",
                "warning": "All non-aggregate columns must be in GROUP BY"
            },
            {
                "name": "Group by with HAVING",
                "when": "Filter groups by aggregate condition",
                "template": "SELECT group_column, AGG(metric) FROM table GROUP BY group_column HAVING AGG(metric) > value",
                "example": "SELECT customer_id, COUNT(*) FROM orders GROUP BY customer_id HAVING COUNT(*) > 5",
                "warning": "HAVING comes after GROUP BY"
            },
            {
                "name": "Multiple column grouping",
                "when": "Group by multiple dimensions",
                "template": "SELECT col1, col2, AGG(metric) FROM table GROUP BY col1, col2",
                "example": "SELECT year, month, SUM(sales) FROM transactions GROUP BY year, month",
                "warning": "Include all grouping columns in SELECT and GROUP BY"
            }
        ],
        "disambiguation_rules": [
            {
                "confusion": "COUNT vs COUNT(DISTINCT)",
                "rule": "Use COUNT(DISTINCT) ONLY when 'unique', 'different', or 'distinct' is explicitly mentioned",
                "examples": [
                    "How many orders → COUNT(*)",
                    "How many unique customers → COUNT(DISTINCT customer_id)"
                ]
            },
            {
                "confusion": "SUM vs COUNT",
                "rule": "SUM for numeric values, COUNT for records",
                "examples": [
                    "Total sales amount → SUM(amount)",
                    "Total number of sales → COUNT(*)",
                    "Sum of quantities → SUM(quantity)",
                    "Number of items → COUNT(*)"
                ]
            },
            {
                "confusion": "WHERE subquery vs ORDER BY LIMIT",
                "rule": "Use ORDER BY LIMIT for finding top/most/highest entity",
                "examples": [
                    "Wrong: WHERE salary = (SELECT MAX(salary) FROM ...)",
                    "Right: ORDER BY salary DESC LIMIT 1"
                ]
            },
            {
                "confusion": "Column selection",
                "rule": "Return ONLY what's asked, no extra columns",
                "examples": [
                    "Asked for name → SELECT name (not SELECT id, name, description)",
                    "Asked for count → SELECT COUNT(*) (not SELECT COUNT(*), category)"
                ]
            },
            {
                "confusion": "JOIN necessity",
                "rule": "Only JOIN when data from multiple tables needed",
                "examples": [
                    "Count in single table → No JOIN needed",
                    "Data from related tables → JOIN required"
                ]
            }
        ]
    }


def generate_query_patterns(db_path="database.sqlite", *, output_dir="tool_output", max_join_tables=3):
    """Generate SQL pattern templates for common queries.

    Reads the table list and primary keys from the SQLite database at
    *db_path*, combines them with a fixed catalogue of query patterns and
    writes the result to ``<output_dir>/query_patterns.json``.

    Args:
        db_path: Path of the SQLite database to inspect.
        output_dir: Directory for the JSON file; created if missing.
        max_join_tables: Only the first this-many tables (by name) are
            considered for table-specific JOIN templates; ``None`` means all.

    Returns:
        The patterns dict that was written. It carries an ``"error"`` key
        with the SQLite error message when the schema could not be read; the
        generic patterns are still included in that case.
    """
    os.makedirs(output_dir, exist_ok=True)

    schema = []
    error = None
    conn = sqlite3.connect(db_path)
    try:
        schema = _read_schema(conn)
    except sqlite3.Error as e:
        # Best effort: a broken database must not cost us the generic patterns.
        error = str(e)
    finally:
        conn.close()

    example_table = _quote_identifier(schema[0][0]) if schema else "`table`"
    patterns = _static_patterns(example_table)

    # Table-specific JOIN templates, limited to the first few tables for brevity.
    for table, pk_columns in schema[:max_join_tables]:
        if not pk_columns:
            continue
        pk = pk_columns[0]
        patterns["join_patterns"].append({
            "table": table,
            "primary_key": pk,
            "template": f"FROM {_quote_identifier(table)} T1 JOIN other_table T2 ON T1.{_quote_identifier(pk)} = T2.foreign_key",
            "note": "Use backticks if table name is reserved word"
        })

    if error is not None:
        patterns["error"] = error

    # Write to output file
    output_path = f"{output_dir}/query_patterns.json"
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(patterns, f, indent=2)

    print(f"Query pattern generation complete - results in {output_path}")
    print(f"Generated {sum(len(v) for v in patterns.values() if isinstance(v, list))} pattern templates")
    print(f"Created {len(patterns['disambiguation_rules'])} disambiguation rules")

    return patterns

# Script entry point: when run directly (not imported), generate the patterns
# using the function's default database path.
if __name__ == "__main__":
    generate_query_patterns()