#!/usr/bin/env python3
"""
Query Pattern Generator - Creates SQL templates built from the database's actual schema,
so every table and column they reference exists. The eval model can adapt these patterns directly.
"""

import sqlite3
import json
import os

def _quote_ident(conn, name):
    """Return *name* rendered as an SQLite identifier usable in generated SQL.

    Plain ASCII identifiers that SQLite parses unquoted are returned unchanged
    so the templates stay readable. Anything else (spaces, punctuation,
    reserved words such as ``order``) is wrapped in double quotes with any
    embedded double quote doubled.
    """
    if name.isascii() and name.isidentifier():
        try:
            # Reserved words (e.g. ORDER, GROUP) fail to parse as a bare alias;
            # a word that parses here is treated as safe to emit unquoted.
            # Interpolation is safe: *name* is restricted to [A-Za-z0-9_] here.
            conn.execute(f"SELECT 0 AS {name}")
        except sqlite3.OperationalError:
            pass
        else:
            return name
    return '"' + name.replace('"', '""') + '"'


def _is_numeric_type(declared_type):
    """Return True if a declared column type name looks numeric.

    Matches the substrings SQLite's type-affinity rules use for INTEGER and
    REAL affinity (INT, REAL, FLOA, DOUB) plus NUMERIC/NUMBER/DECIMAL names.
    """
    upper = (declared_type or "").upper()
    return any(token in upper for token in ("INT", "REAL", "FLOA", "DOUB", "NUM", "DEC"))


def _table_patterns(conn, table, columns):
    """Build the markdown lines and pattern records for one table.

    *columns* are the rows returned by ``PRAGMA table_info`` for *table*
    (index 1 = column name, 2 = declared type, 5 = primary-key position,
    0 when the column is not part of the key).

    Returns a ``(lines, patterns)`` tuple.
    """
    q_table = _quote_ident(conn, table)
    col_names = [col[1] for col in columns]
    q_cols = {name: _quote_ident(conn, name) for name in col_names}
    pk_cols = [col[1] for col in columns if col[5]]
    numeric_cols = [col[1] for col in columns if _is_numeric_type(col[2])]

    lines = [f"## Table: {table}", "", "### Basic Queries"]
    patterns = []

    def add(kind, label, sql):
        # Record one template in both the markdown view and the JSON list.
        lines.append(f"- {label}: `{sql}`")
        patterns.append({"type": kind, "table": table, "sql": sql})

    add("select_all", "All records", f"SELECT * FROM {q_table};")
    if len(col_names) > 1:
        sample = ", ".join(q_cols[c] for c in col_names[:3])
        add("select_columns", "Specific columns", f"SELECT {sample} FROM {q_table};")
    if col_names:
        add("select_where", "Filter by column",
            f"SELECT * FROM {q_table} WHERE {q_cols[col_names[0]]} = ?;")

    lines += ["", "### Aggregations"]
    add("count", "Count all", f"SELECT COUNT(*) FROM {q_table};")
    if pk_cols:
        add("count_distinct", "Count distinct",
            f"SELECT COUNT(DISTINCT {q_cols[pk_cols[0]]}) FROM {q_table};")

    if numeric_cols:
        num_col = numeric_cols[0]
        q_num = q_cols[num_col]
        for agg in ("SUM", "AVG", "MAX", "MIN"):
            add(agg.lower(), agg, f"SELECT {agg}({q_num}) FROM {q_table};")

        if len(col_names) > 1:
            # Aggregate named explicitly instead of reusing the loop variable
            # after the loop; MIN keeps the "min_group" type stable for consumers.
            group_agg = "MIN"
            group_col = q_cols[next((c for c in col_names if c != num_col), col_names[0])]
            add(f"{group_agg.lower()}_group", f"{group_agg} with GROUP BY",
                f"SELECT {group_col}, {group_agg}({q_num}) FROM {q_table} GROUP BY {group_col};")

    lines.append("")
    return lines, patterns


def _join_patterns(conn, foreign_keys):
    """Build JOIN templates from a ``{table: [fk, ...]}`` mapping.

    Each fk entry is read for ``from_column``, ``to_table`` and ``to_column``
    (presumably the shape written to schema_validation.json; not verified
    here). Only the first relationship seen between any unordered pair of
    tables produces templates.

    Returns a ``(lines, patterns)`` tuple.
    """
    lines = ["## JOIN PATTERNS", ""]
    patterns = []
    seen_pairs = set()

    for from_table, fks in foreign_keys.items():
        for fk in fks:
            to_table = fk["to_table"]
            pair = tuple(sorted((from_table, to_table)))
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)

            t1 = _quote_ident(conn, from_table)
            t2 = _quote_ident(conn, to_table)
            c1 = _quote_ident(conn, fk["from_column"])
            c2 = _quote_ident(conn, fk["to_column"])
            join = f"FROM {t1} t1\nJOIN {t2} t2 ON t1.{c1} = t2.{c2}"

            lines.append(f"### Join: {from_table} ↔ {to_table}")
            variants = (
                (None, "join", f"SELECT t1.*, t2.*\n{join};"),
                ("With filter:", "join_where",
                 f"SELECT t1.*, t2.*\n{join}\nWHERE t1.{c1} = ?;"),
                ("With aggregation:", "join_aggregate",
                 f"SELECT t1.{c1}, COUNT(*)\n{join}\nGROUP BY t1.{c1};"),
            )
            for label, kind, sql in variants:
                if label:
                    lines.append(label)
                lines.append(f"```sql\n{sql}\n```")
                patterns.append({"type": kind, "tables": [from_table, to_table], "sql": sql})
            lines.append("")

    return lines, patterns


def _common_patterns(conn, table, columns):
    """Build the "COMMON PATTERNS" section.

    The top-N and subquery examples use *table*'s first column when *columns*
    (its ``PRAGMA table_info`` rows) is non-empty. The string and
    NULL-handling tips are static text.

    Returns a ``(lines, patterns)`` tuple.
    """
    top_n = subquery = None
    if columns:
        q_table = _quote_ident(conn, table)
        q_col = _quote_ident(conn, columns[0][1])
        top_n = f"SELECT * FROM {q_table} ORDER BY {q_col} DESC LIMIT 10;"
        subquery = (
            f"SELECT * FROM {q_table}\n"
            f"WHERE {q_col} = (\n"
            f"    SELECT MAX({q_col})\n"
            f"    FROM {q_table}\n"
            ");"
        )

    lines = ["## COMMON PATTERNS", "", "### Top-N Queries"]
    patterns = []
    if top_n:
        lines.append(f"```sql\n{top_n}\n```")
        patterns.append({"type": "top_n", "sql": top_n})

    lines += ["", "### Subquery Patterns"]
    if subquery:
        lines += ["Find maximum:", f"```sql\n{subquery}\n```"]
        patterns.append({"type": "subquery_max", "sql": subquery})

    lines += [
        "",
        "### String Patterns",
        "- Case-insensitive match: `WHERE LOWER(column) = LOWER('value')`",
        "- Pattern match: `WHERE column LIKE '%pattern%'`",
        "- String concatenation: `SELECT col1 || ' ' || col2 AS full_name`",
        "",
        "### NULL Handling",
        "- Check NULL: `WHERE column IS NULL`",
        "- Check NOT NULL: `WHERE column IS NOT NULL`",
        "- Safe division: `CAST(num AS REAL) / NULLIF(denom, 0)`",
    ]
    return lines, patterns


def generate_patterns():
    """Generate SQL query templates from the live schema of ``database.sqlite``.

    Reads ``database.sqlite`` from the current directory. If
    ``tool_output/schema_validation.json`` exists, its ``foreign_keys`` entry
    is used to build JOIN templates. Writes two files into ``tool_output/``:

    * ``query_patterns.txt``: a markdown summary of every template.
    * ``query_patterns.json``: a list of records with keys ``type``, ``sql``,
      and either ``table`` or ``tables``. The generic top-N/subquery records
      have neither ``table`` nor ``tables``.

    Internal ``sqlite_*`` tables are skipped. Identifiers SQLite would not
    accept bare (spaces, punctuation, reserved words) are double-quoted so the
    emitted SQL parses as written.

    Raises:
        FileNotFoundError: if ``database.sqlite`` does not exist. This is
            checked up front because ``sqlite3.connect`` would otherwise
            silently create an empty database file.
        sqlite3.Error: on a database error while reading the schema.
    """
    db_path = "database.sqlite"
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"{db_path} not found in {os.getcwd()}")

    os.makedirs("tool_output", exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        # User tables only: sqlite_sequence / sqlite_stat* are SQLite internals.
        # '_' is a LIKE wildcard, so it is escaped to match literally.
        tables = [
            row[0]
            for row in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' "
                "AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
            )
        ]
        columns_by_table = {
            table: conn.execute(f"PRAGMA table_info({_quote_ident(conn, table)})").fetchall()
            for table in tables
        }

        schema_info = {}
        if os.path.exists("tool_output/schema_validation.json"):
            with open("tool_output/schema_validation.json", encoding="utf-8") as f:
                schema_info = json.load(f)

        sections = [_table_patterns(conn, t, columns_by_table[t]) for t in tables]

        foreign_keys = schema_info.get("foreign_keys") if isinstance(schema_info, dict) else None
        if foreign_keys:
            sections.append(_join_patterns(conn, foreign_keys))

        first_table = tables[0] if tables else None
        sections.append(
            _common_patterns(conn, first_table, columns_by_table.get(first_table, []))
        )
    finally:
        conn.close()

    output_lines = [
        "# QUERY PATTERN TEMPLATES",
        "# Pre-validated SQL patterns you can adapt",
        "",
    ]
    patterns = []
    for lines, records in sections:
        output_lines.extend(lines)
        patterns.extend(records)

    # Explicit UTF-8: the text contains "↔", which platform-default encodings
    # such as cp1252 cannot encode.
    with open("tool_output/query_patterns.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(output_lines))

    with open("tool_output/query_patterns.json", "w", encoding="utf-8") as f:
        json.dump(patterns, f, indent=2)

    print("Query patterns generated - see tool_output/query_patterns.txt")

if __name__ == "__main__":  # run only when executed as a script, never on import
    generate_patterns()