#!/usr/bin/env python3
"""
Query Template Mining Tool
Analyzes database structure to generate SQL query templates for common patterns.
Focuses on concise, actionable output that helps the eval model generate accurate SQL.
"""

import sqlite3
import json
import os
from typing import Dict, List, Tuple, Set
from collections import defaultdict

class QueryTemplateMiner:
    """Mine a SQLite database for reusable SQL query templates.

    ``mine()`` analyzes the schema, then appends template dicts (keys
    ``pattern``, ``template``, ``example``, ``tables``) to ``self.templates``.
    The connection opened in ``__init__`` is closed when ``mine()`` returns
    or raises, so an instance is single-use.

    Identifiers are rendered bare in templates/examples when that is safe
    and double-quoted otherwise (spaces, punctuation, leading digits, common
    SQL keywords). The ``tables`` lists and ``schema_info`` keys always hold
    the raw, unquoted names.
    """

    # Declared-type substrings treated as numeric. Follows SQLite's
    # INTEGER/REAL affinity rules ("INT", "REAL", "FLOA", "DOUB") plus the
    # explicit NUMERIC/DECIMAL spellings; DATE/BOOLEAN are deliberately left
    # out because SUM/AVG over them is rarely meaningful.
    _NUMERIC_TYPE_MARKERS = ("INT", "REAL", "FLOA", "DOUB", "NUMERIC", "DECIMAL")

    # Column-name fragments that suggest a natural ORDER BY column for top-N.
    _ORDER_HINTS = ("date", "year", "time")

    # Common SQL keywords; identifiers spelled like these are quoted to be safe.
    _SQL_KEYWORDS = frozenset({
        "ALL", "AND", "AS", "ASC", "BETWEEN", "BY", "CASE", "CHECK", "COLLATE",
        "CREATE", "DEFAULT", "DELETE", "DESC", "DISTINCT", "DROP", "ELSE", "END",
        "EXCEPT", "EXISTS", "FROM", "GROUP", "HAVING", "IN", "INDEX", "INSERT",
        "INTERSECT", "INTO", "IS", "JOIN", "LIMIT", "NOT", "NULL", "ON", "OR",
        "ORDER", "PRIMARY", "REFERENCES", "SELECT", "SET", "TABLE", "THEN", "TO",
        "UNION", "UNIQUE", "UPDATE", "USING", "VALUES", "WHEN", "WHERE",
    })

    # Cap on join pairs turned into templates, to avoid an explosion on
    # databases with many related tables.
    _MAX_JOIN_PAIRS = 10

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.templates = []
        self.schema_info = {}

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _quote(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + name.replace('"', '""') + '"'

    @classmethod
    def _ident(cls, name: str) -> str:
        """Render *name* for display: bare if it is a plain identifier, else quoted."""
        if name.isidentifier() and name.upper() not in cls._SQL_KEYWORDS:
            return name
        return cls._quote(name)

    @classmethod
    def _is_numeric_type(cls, declared_type: str) -> bool:
        """True if a column's declared type should be treated as numeric."""
        upper = (declared_type or "").upper()
        return any(marker in upper for marker in cls._NUMERIC_TYPE_MARKERS)

    def _numeric_columns(self, info: Dict) -> List[str]:
        """Columns of a ``schema_info`` entry with a numeric declared type, in order."""
        return [c for c in info["columns"] if self._is_numeric_type(info["types"][c])]

    def _pick_order_column(self, info: Dict):
        """Best ORDER BY column: first date/year/time-named column, else first numeric, else None."""
        for col in info["columns"]:
            lowered = col.lower()
            if any(hint in lowered for hint in self._ORDER_HINTS):
                return col
        numeric = self._numeric_columns(info)
        return numeric[0] if numeric else None

    def _implicit_parent_key(self, ref_table: str, seq: int):
        """Resolve the parent column for an FK declared without a column list.

        Returns the parent's primary-key column at position *seq*, ``"rowid"``
        when the parent has no (long enough) declared primary key, or None
        when the parent table is unknown.
        """
        parent = self.schema_info.get(ref_table)
        if parent is None:
            return None
        pk = parent["pk"]
        return pk[seq] if seq < len(pk) else "rowid"

    # ------------------------------------------------------------------ #
    # Passes
    # ------------------------------------------------------------------ #
    def mine(self):
        """Run every analysis pass and return ``(templates, schema_info)``.

        The database connection is closed afterwards, even if a pass raises.
        """
        try:
            self._analyze_schema()
            self._generate_basic_templates()
            self._generate_join_templates()
            self._generate_aggregation_templates()
            self._generate_special_templates()
            return self.templates, self.schema_info
        finally:
            self.conn.close()

    def _analyze_schema(self):
        """Fill ``self.schema_info`` with columns, PK, types, row count and FKs per table.

        SQLite's internal ``sqlite_*`` tables are skipped. Table names are
        quoted in the executed SQL so names with spaces or keywords work.
        """
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [
            row[0] for row in self.cursor.fetchall()
            if not row[0].lower().startswith("sqlite_")
        ]

        raw_fks = {}
        for table in tables:
            quoted = self._quote(table)

            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            self.cursor.execute(f"PRAGMA table_info({quoted})")
            columns = self.cursor.fetchall()

            # Row count for size assessment
            self.cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
            row_count = self.cursor.fetchone()[0]

            self.schema_info[table] = {
                "columns": [col[1] for col in columns],
                # col[5] is the 1-based position within the primary key (0 when
                # not part of it); sorting keeps composite keys in key order.
                "pk": [col[1] for col in sorted(columns, key=lambda c: c[5]) if col[5]],
                "types": {col[1]: col[2] for col in columns},
                "row_count": row_count,
            }

            # foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
            self.cursor.execute(f"PRAGMA foreign_key_list({quoted})")
            raw_fks[table] = self.cursor.fetchall()

        # Second pass, once every table's PK is known. Table names are
        # case-insensitive in SQLite, so map the referenced name onto the real
        # spelling. A clause written as ``REFERENCES parent`` (no column list)
        # reports ``to`` as NULL and targets the parent's primary key.
        by_lower = {name.lower(): name for name in self.schema_info}
        for table, fks in raw_fks.items():
            resolved = []
            for fk in fks:
                ref_table = by_lower.get(fk[2].lower(), fk[2])
                ref_column = fk[4]
                if ref_column is None:
                    ref_column = self._implicit_parent_key(ref_table, fk[1])
                resolved.append({"column": fk[3], "ref_table": ref_table, "ref_column": ref_column})
            self.schema_info[table]["fks"] = resolved

    def _generate_basic_templates(self):
        """Add select, count and (when an ordering column exists) top-N templates per table."""
        for table, info in self.schema_info.items():
            if not info["columns"]:
                continue
            tbl = self._ident(table)
            first_col = self._ident(info["columns"][0])

            self.templates.append({
                "pattern": "single_table_select",
                "template": f"SELECT {{columns}} FROM {tbl} WHERE {{condition}};",
                "example": f"SELECT * FROM {tbl} WHERE {first_col} = {{value}};",
                "tables": [table],
            })

            self.templates.append({
                "pattern": "count",
                "template": f"SELECT COUNT(*) FROM {tbl} WHERE {{condition}};",
                "example": f"SELECT COUNT(*) FROM {tbl} WHERE {first_col} = {{value}};",
                "tables": [table],
            })

            # Top-N only makes sense with something to order by; the example
            # orders by that detected column rather than an arbitrary one.
            order_col = self._pick_order_column(info)
            if order_col is not None:
                self.templates.append({
                    "pattern": "top_n",
                    "template": f"SELECT {{columns}} FROM {tbl} WHERE {{filter_condition}} ORDER BY {{order_column}} {{ASC|DESC}} LIMIT {{n}};",
                    "example": f"SELECT * FROM {tbl} ORDER BY {self._ident(order_col)} DESC LIMIT 10;",
                    "tables": [table],
                })

    def _generate_join_templates(self):
        """Add join templates for declared FKs and shared ``*id``-named columns.

        Pairs are de-duplicated (including the reversed direction), pairs with
        an unresolved column are skipped, and shared columns are visited in
        sorted order so the capped selection is deterministic.
        """
        join_pairs = []
        seen = set()

        def add_pair(t1, t2, c1, c2):
            if c1 is None or c2 is None:
                return
            key = (t1, t2, c1, c2)
            if key in seen or (t2, t1, c2, c1) in seen:
                return
            seen.add(key)
            join_pairs.append(key)

        for table1, info1 in self.schema_info.items():
            # Declared foreign keys
            for fk in info1.get("fks", []):
                if fk["ref_table"] in self.schema_info:
                    add_pair(table1, fk["ref_table"], fk["column"], fk["ref_column"])

            # Implicit joins on identically named id-like columns
            for table2, info2 in self.schema_info.items():
                if table1 >= table2:
                    continue
                common_cols = set(info1["columns"]) & set(info2["columns"])
                for col in sorted(common_cols):
                    if col.endswith(("id", "Id", "ID")):
                        add_pair(table1, table2, col, col)

        for t1, t2, col1, col2 in join_pairs[:self._MAX_JOIN_PAIRS]:
            n1, n2 = self._ident(t1), self._ident(t2)
            c1, c2 = self._ident(col1), self._ident(col2)
            on = f"t1.{c1} = t2.{c2}"

            self.templates.append({
                "pattern": "two_table_join",
                "template": f"SELECT {{columns}} FROM {n1} t1 JOIN {n2} t2 ON {on} WHERE {{condition}};",
                "example": f"SELECT t1.*, t2.* FROM {n1} t1 JOIN {n2} t2 ON {on};",
                "tables": [t1, t2],
            })

            self.templates.append({
                "pattern": "join_aggregate",
                "template": f"SELECT t1.{{group_col}}, {{agg_func}}(t2.{{agg_col}}) FROM {n1} t1 JOIN {n2} t2 ON {on} GROUP BY t1.{{group_col}};",
                "example": f"SELECT t1.{c1}, COUNT(*) FROM {n1} t1 JOIN {n2} t2 ON {on} GROUP BY t1.{c1};",
                "tables": [t1, t2],
            })

    def _generate_aggregation_templates(self):
        """Add SUM/AVG/MAX/MIN, GROUP BY and percentage templates for tables with numeric columns."""
        for table, info in self.schema_info.items():
            numeric_cols = self._numeric_columns(info)
            if not numeric_cols:
                continue

            # Prefer a measure over a key column: SUM/AVG of an id is noise.
            key_cols = set(info["pk"]) | {fk["column"] for fk in info.get("fks", [])}
            measures = [c for c in numeric_cols
                        if c not in key_cols and not c.lower().endswith("id")]
            num_col = (measures or numeric_cols)[0]
            tbl = self._ident(table)
            num = self._ident(num_col)

            for agg in ("SUM", "AVG", "MAX", "MIN"):
                self.templates.append({
                    "pattern": f"{agg.lower()}_aggregation",
                    "template": f"SELECT {agg}({num}) FROM {tbl} WHERE {{condition}};",
                    "example": f"SELECT {agg}({num}) FROM {tbl};",
                    "tables": [table],
                })

            if len(info["columns"]) > 1:
                group_col = next((c for c in info["columns"] if c != num_col), info["columns"][0])
                grp = self._ident(group_col)
                self.templates.append({
                    "pattern": "group_by_aggregate",
                    "template": f"SELECT {{group_columns}}, {{agg_function}}({{agg_column}}) FROM {tbl} GROUP BY {{group_columns}} HAVING {{having_condition}};",
                    "example": f"SELECT {grp}, COUNT(*), SUM({num}) FROM {tbl} GROUP BY {grp};",
                    "tables": [table],
                })

            self.templates.append({
                "pattern": "percentage",
                "template": f"SELECT (CAST(SUM(CASE WHEN {{condition}} THEN 1 ELSE 0 END) AS REAL) * 100 / COUNT(*)) FROM {tbl};",
                "example": f"SELECT (CAST(SUM(CASE WHEN {num} > 0 THEN 1 ELSE 0 END) AS REAL) * 100 / COUNT(*)) FROM {tbl};",
                "tables": [table],
            })

    def _generate_special_templates(self):
        """Add year-over-year, scalar-subquery and self-join templates."""
        # Year-over-year comparison, for tables with a year-like column
        for table, info in self.schema_info.items():
            year_cols = [col for col in info["columns"] if "year" in col.lower()]
            if not year_cols:
                continue
            tbl = self._ident(table)
            year_col = self._ident(year_cols[0])
            self.templates.append({
                "pattern": "year_over_year",
                "template": f"SELECT ({{metric_year2}} - {{metric_year1}}) FROM (SELECT {{agg}}({{column}}) as {{metric_year1}} FROM {tbl} WHERE {year_col} = {{year1}}) t1, (SELECT {{agg}}({{column}}) as {{metric_year2}} FROM {tbl} WHERE {year_col} = {{year2}}) t2;",
                # Concrete instance of the template above (two years, one delta).
                "example": f"SELECT (y2.cnt - y1.cnt) FROM (SELECT COUNT(*) AS cnt FROM {tbl} WHERE {year_col} = 2022) y1, (SELECT COUNT(*) AS cnt FROM {tbl} WHERE {year_col} = 2023) y2;",
                "tables": [table],
            })

        # Subquery patterns, limited to the first three tables for conciseness
        for table in list(self.schema_info)[:3]:
            columns = self.schema_info[table]["columns"]
            if not columns:
                continue
            tbl = self._ident(table)
            first_col = self._ident(columns[0])
            self.templates.append({
                "pattern": "subquery_filter",
                "template": f"SELECT {{columns}} FROM {tbl} WHERE {{column}} = (SELECT {{subquery_agg}}({{subquery_col}}) FROM {{subquery_table}} WHERE {{subquery_condition}});",
                "example": f"SELECT * FROM {tbl} WHERE {first_col} = (SELECT MAX({first_col}) FROM {tbl});",
                "tables": [table],
            })

        # Self-join: one example from the first table under 100k rows
        for table, info in self.schema_info.items():
            if info["row_count"] < 100000 and info["columns"]:
                tbl = self._ident(table)
                first_col = self._ident(info["columns"][0])
                self.templates.append({
                    "pattern": "self_join",
                    "template": f"SELECT t1.{{columns}} FROM {tbl} t1 JOIN {tbl} t2 ON t1.{{join_col}} = t2.{{join_col}} WHERE {{condition}};",
                    "example": f"SELECT t1.* FROM {tbl} t1 JOIN {tbl} t2 ON t1.{first_col} = t2.{first_col};",
                    "tables": [table],
                })
                break  # One example is enough

def format_templates_output(templates: List[Dict], schema_info: Dict) -> str:
    """Render mined templates and schema as a compact Markdown cheat sheet.

    Args:
        templates: Template dicts as produced by ``QueryTemplateMiner.mine()``;
            each needs ``pattern``, ``template`` and ``tables`` keys.
        schema_info: Per-table dicts with ``columns`` and ``row_count``, and
            optionally ``pk`` and ``fks``.

    Returns:
        The Markdown text. A template section header is emitted only when the
        section has at least one template, and truncated column/FK lists say
        how many entries there are in total.
    """
    output = ["# Query Templates for Database\n", "## Quick Schema Reference"]

    # Quick schema reference
    for table, info in schema_info.items():
        columns = info["columns"]
        cols_str = ", ".join(columns[:7])  # Limit columns shown
        if len(columns) > 7:
            cols_str += f", ... ({len(columns)} total)"
        output.append(f"- **{table}** ({info['row_count']} rows): {cols_str}")
        if info.get("pk"):
            output.append(f"  PK: {', '.join(info['pk'])}")
        fks = info.get("fks") or []
        for fk in fks[:2]:  # Limit FK display
            output.append(f"  FK: {fk['column']} → {fk['ref_table']}.{fk['ref_column']}")
        if len(fks) > 2:
            output.append(f"  FK: ... ({len(fks)} total)")
    output.append("")

    # Group templates by pattern
    patterns = defaultdict(list)
    for template in templates:
        patterns[template["pattern"]].append(template)

    def add_section(header, names, per_pattern=1, label_tables=False):
        """Append *header* plus up to *per_pattern* templates per name, if any exist."""
        present = [name for name in names if name in patterns]
        if not present:
            return
        output.append(header)
        for name in present:
            for t in patterns[name][:per_pattern]:
                label = f"{name} ({' + '.join(t['tables'])})" if label_tables else name
                output.append(f"**{label}**:")
                output.append(f"```sql\n{t['template']}\n```")

    output.append("## Query Templates\n")
    add_section("### Single Table Queries", ["single_table_select", "count", "top_n"])
    add_section("\n### Join Queries", ["two_table_join", "join_aggregate"],
                per_pattern=2, label_tables=True)
    add_section("\n### Aggregation Queries", ["sum_aggregation", "group_by_aggregate", "percentage"])
    add_section("\n### Special Patterns", ["year_over_year", "subquery_filter", "self_join"])

    # Evidence mapping guide
    output.append("\n## Evidence Mapping Guide")
    output.append("When evidence provides mappings:")
    output.append("- `X refers to ColumnName` → Replace X with ColumnName")
    output.append("- `'value' is the ColumnName` → Use WHERE ColumnName = 'value'")
    output.append("- `calculation = FORMULA` → Use the exact FORMULA")
    output.append("- `the oldest` → Use MIN() or ORDER BY ASC LIMIT 1")
    output.append("- `percentage of` → Use percentage template")

    return "\n".join(output)

def main():
    """Mine ``./database.sqlite`` and write the results under ``tool_output/``.

    Writes ``tool_output/templates.json`` (raw templates + schema) and
    ``tool_output/query_templates.txt`` (Markdown summary). If the database
    file is missing, prints an error and returns. If mining or formatting
    raises, prints the error and writes a short fallback note to
    ``query_templates.txt`` instead of propagating, so that file always exists.
    """
    db_path = "./database.sqlite"

    # sqlite3.connect would silently create an empty database, so check first.
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return

    os.makedirs("tool_output", exist_ok=True)

    print("Mining query templates...")

    try:
        miner = QueryTemplateMiner(db_path)
        templates, schema_info = miner.mine()

        # Save raw templates
        with open("tool_output/templates.json", "w", encoding="utf-8") as f:
            json.dump({"templates": templates, "schema": schema_info}, f, indent=2)

        # The formatted text contains non-ASCII (the "→" arrows on FK lines and
        # possibly table/column names), so pin UTF-8 rather than relying on the
        # platform's locale encoding, which may not be able to encode it.
        formatted = format_templates_output(templates, schema_info)
        with open("tool_output/query_templates.txt", "w", encoding="utf-8") as f:
            f.write(formatted)

        print(f"Generated {len(templates)} query templates")
        print("Output saved to tool_output/query_templates.txt")

    except Exception as e:
        print(f"Error: {str(e)}")
        # Best effort: still leave a file behind for downstream consumers.
        with open("tool_output/query_templates.txt", "w", encoding="utf-8") as f:
            f.write(f"# Template Mining Error\nError occurred: {str(e)}\n")
            f.write("Falling back to basic analysis mode.\n")


if __name__ == "__main__":
    main()