#!/usr/bin/env python3
"""
Pattern Selector - Context-aware query pattern generation
Combines patterns from iter8 with adaptive complexity
"""

import sqlite3
import os

_DB_PATH = "database.sqlite"
_OUTPUT_DIR = "tool_output"
_OUTPUT_PATH = "tool_output/pattern_selector_output.txt"

# Minimal pattern set written when schema inspection or the main write fails.
_FALLBACK_PATTERNS = [
    "# QUERY PATTERNS",
    "",
    "## BASIC PATTERNS",
    "",
    "### Count",
    "SELECT COUNT(*) FROM table;",
    "",
    "### Sum",
    "SELECT SUM(column) FROM table;",
    "",
    "### Join",
    "SELECT * FROM t1 JOIN t2 ON t1.id = t2.id;",
    "",
    "### Group By",
    "SELECT col, COUNT(*) FROM table GROUP BY col;",
    "",
    "### Top N",
    "SELECT * FROM table ORDER BY col DESC LIMIT N;",
]


def _pragma_quote(name):
    """Return *name* as a double-quoted identifier for use inside a PRAGMA call."""
    # Always quoting keeps reserved words (e.g. a table called "order") and
    # names with unusual characters from turning the PRAGMA into a syntax error.
    return '"' + name.replace('"', '""') + '"'


def _display_identifier(name):
    """Return *name* as it should appear in generated SQL text.

    Plain identifiers are emitted unchanged; anything else is backtick-quoted,
    with embedded backticks doubled.
    """
    # NOTE: reserved words that look like plain identifiers are still emitted
    # unquoted here; only lexically unsafe names are quoted.
    if name.isidentifier():
        return name
    return "`" + name.replace("`", "``") + "`"


def _sql_block(title, *sql_lines):
    """Return the markdown lines for one titled, fenced SQL example."""
    return [f"### {title}", "```sql", *sql_lines, "```", ""]


def _scan_columns(cursor, tables):
    """Return ``(total_columns, has_dates)`` gathered from PRAGMA table_info.

    Tables whose PRAGMA call fails are skipped so that one unreadable table
    does not abort the whole analysis.
    """
    total_columns = 0
    has_dates = False
    for table in tables:
        try:
            cursor.execute(f"PRAGMA table_info({_pragma_quote(table)})")
            columns = cursor.fetchall()
        except sqlite3.Error:
            # Best effort: ignore this table and keep scanning the others.
            continue
        total_columns += len(columns)
        for column in columns:
            # column[2] is the declared type; it may be empty or NULL.
            declared_type = (column[2] or "").upper()
            if "DATE" in declared_type or "TIME" in declared_type:
                has_dates = True
    return total_columns, has_dates


def _classify_complexity(table_count, total_columns):
    """Map table and column counts to SIMPLE, MODERATE or COMPLEX."""
    if table_count <= 5 and total_columns <= 30:
        return "SIMPLE"
    if table_count <= 15 and total_columns <= 100:
        return "MODERATE"
    return "COMPLEX"


def _referenced_column(cursor, parent_table, seq):
    """Resolve the parent column of a foreign key that names no target column.

    Such a key refers to the parent's primary key, so the primary-key column
    at position ``seq`` (0-based) is looked up. Falls back to ``rowid`` when
    no matching primary-key column can be found.
    """
    try:
        cursor.execute(f"PRAGMA table_info({_pragma_quote(parent_table)})")
        # row[5] is the 1-based position within the primary key (0 = not in PK).
        pk_columns = {row[5]: row[1] for row in cursor.fetchall() if row[5]}
    except sqlite3.Error:
        pk_columns = {}
    return pk_columns.get(seq + 1, "rowid")


def _find_fk_examples(cursor, tables, table_limit=3, max_examples=2):
    """Collect up to *max_examples* foreign keys from the first *table_limit* tables.

    Returns a list of ``(from_table, from_column, to_table, to_column)`` tuples.
    """
    examples = []
    for table in tables[:table_limit]:
        try:
            cursor.execute(f"PRAGMA foreign_key_list({_pragma_quote(table)})")
            foreign_keys = cursor.fetchall()
        except sqlite3.Error:
            # Best effort: a table that cannot be inspected yields no examples.
            continue
        for fk in foreign_keys:
            # fk = (id, seq, parent table, from column, to column, ...).
            to_column = fk[4]
            if to_column is None:
                # Implicit reference to the parent's primary key.
                to_column = _referenced_column(cursor, fk[2], fk[1])
            examples.append((table, fk[3], fk[2], to_column))
            if len(examples) >= max_examples:
                return examples
    return examples


def _join_section(cursor, tables):
    """Return the JOIN PATTERNS section, using real foreign keys when available."""
    lines = ["## JOIN PATTERNS", ""]
    fk_examples = _find_fk_examples(cursor, tables)
    if not fk_examples:
        lines += _sql_block(
            "Generic Join",
            "SELECT t1.*, t2.*",
            "FROM table1 t1",
            "JOIN table2 t2 ON t1.foreign_key = t2.id",
            "WHERE condition;",
        )
        return lines

    for from_table, from_col, to_table, to_col in fk_examples:
        lines += _sql_block(
            f"Join {from_table} with {to_table}",
            "SELECT t1.*, t2.*",
            f"FROM {_display_identifier(from_table)} t1",
            f"JOIN {_display_identifier(to_table)} t2 "
            f"ON t1.{_display_identifier(from_col)} = t2.{_display_identifier(to_col)}",
            "WHERE condition;",
        )
    return lines


def _build_patterns(cursor):
    """Inspect the schema behind *cursor* and return ``(output_lines, complexity)``."""
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in cursor.fetchall()]
    table_count = len(tables)

    total_columns, has_dates = _scan_columns(cursor, tables)
    complexity = _classify_complexity(table_count, total_columns)

    lines = [
        "# QUERY PATTERNS",
        "(Context-specific templates for common operations)",
        "",
        f"## Database Complexity: {complexity}",
        f"Tables: {table_count}, Total Columns: {total_columns}",
        "",
    ]

    # Basic patterns (always included).
    lines += ["## BASIC PATTERNS", ""]
    lines += _sql_block(
        "Simple Selection",
        "SELECT column1, column2",
        "FROM table",
        "WHERE condition;",
    )
    lines += _sql_block(
        "Count Records",
        "SELECT COUNT(*)",
        "FROM table",
        "WHERE condition;",
    )

    lines += ["## AGGREGATION PATTERNS", ""]
    lines += _sql_block(
        "Sum with NULL safety",
        "SELECT SUM(COALESCE(column, 0))",
        "FROM table;",
    )
    lines += _sql_block(
        "Average excluding zeros",
        "SELECT AVG(NULLIF(column, 0))",
        "FROM table;",
    )
    lines += _sql_block(
        "Percentage calculation",
        "SELECT CAST(COUNT(CASE WHEN condition THEN 1 END) AS REAL) * 100 / NULLIF(COUNT(*), 0)",
        "FROM table;",
    )

    # JOIN patterns are only emitted for moderate/complex databases.
    if complexity in ("MODERATE", "COMPLEX"):
        lines += _join_section(cursor, tables)

    lines += ["## GROUP BY PATTERNS", ""]
    lines += _sql_block(
        "Count by category",
        "SELECT category, COUNT(*)",
        "FROM table",
        "GROUP BY category",
        "ORDER BY COUNT(*) DESC;",
    )
    lines += _sql_block(
        "Sum by group with HAVING",
        "SELECT group_column, SUM(value_column)",
        "FROM table",
        "GROUP BY group_column",
        "HAVING SUM(value_column) > threshold;",
    )

    lines += ["## TOP-N PATTERNS", ""]
    lines += _sql_block(
        "Top N records",
        "SELECT *",
        "FROM table",
        "ORDER BY column DESC",
        "LIMIT N;",
    )
    lines += _sql_block(
        "Find maximum record",
        "SELECT *",
        "FROM table",
        "WHERE column = (SELECT MAX(column) FROM table);",
    )

    # Date patterns only when a DATE/TIME-typed column was detected.
    if has_dates:
        lines += ["## DATE PATTERNS", ""]
        lines += _sql_block(
            "Date range filter",
            "SELECT *",
            "FROM table",
            "WHERE date_column BETWEEN '2020-01-01' AND '2020-12-31';",
        )
        lines += _sql_block(
            "Extract year/month",
            "SELECT ",
            "  strftime('%Y', date_column) as year,",
            "  strftime('%m', date_column) as month",
            "FROM table;",
        )

    # Advanced patterns only for complex databases.
    if complexity == "COMPLEX":
        lines += ["## ADVANCED PATTERNS", ""]
        lines += _sql_block(
            "Multi-table aggregation",
            "SELECT ",
            "  t1.group_column,",
            "  COUNT(DISTINCT t2.id) as count1,",
            "  SUM(t3.value) as total",
            "FROM table1 t1",
            "LEFT JOIN table2 t2 ON t1.id = t2.foreign_key",
            "LEFT JOIN table3 t3 ON t1.id = t3.foreign_key",
            "GROUP BY t1.group_column;",
        )

    return lines, complexity


def _write_output(lines):
    """Write *lines*, newline-joined, to the tool output file as UTF-8."""
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    with open(_OUTPUT_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def select_patterns():
    """Generate context-appropriate query patterns.

    Opens ``database.sqlite`` in the current working directory, inspects its
    tables, columns and foreign keys, and writes a markdown document of SQL
    templates to ``tool_output/pattern_selector_output.txt``. JOIN, date and
    advanced sections are included depending on the detected schema.

    The step is treated as non-critical: if inspection or the main write
    raises, a message is printed and a short generic pattern list is written
    instead. An error while writing that fallback file still propagates.

    Returns:
        None. Results are written to disk and a status line is printed.
    """
    try:
        conn = sqlite3.connect(_DB_PATH)
        try:
            output, complexity = _build_patterns(conn.cursor())
        finally:
            # Release the database handle even when inspection fails.
            conn.close()

        _write_output(output)
        print(f"Pattern selection complete - {complexity} database, patterns selected accordingly")

    except Exception as e:
        # This is non-critical, so provide basic patterns
        print(f"Pattern selection failed: {e}, using basic patterns")
        _write_output(_FALLBACK_PATTERNS)

# Script entry point: generate the pattern file only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    select_patterns()