#!/usr/bin/env python3
"""
Pattern Generator - Creates validated SQL query templates.
Combines template mining with schema validation for working patterns.
"""

import sqlite3
import json
import os

class PatternGenerator:
    """Build SQL query templates from the schema of a SQLite database.

    The generator reads table, column, primary-key and foreign-key
    information through SQLite PRAGMAs and emits a list of template dicts.
    Each dict has the keys ``pattern``, ``description``, ``template``,
    ``example`` and ``tables`` (join templates also carry ``join_columns``).
    Placeholders inside ``template`` strings are always written with single
    braces, e.g. ``{columns}`` or ``{condition}``.
    """

    # Maximum number of distinct table pairs that receive JOIN templates.
    MAX_JOIN_PAIRS = 15

    # A declared column type containing any of these substrings is treated
    # as numeric. This loosely follows SQLite's type-affinity rules so that
    # INT, BIGINT, FLOAT, DOUBLE, DECIMAL(10,2), ... are recognised as well
    # as the exact names INTEGER / REAL / NUMERIC.
    _NUMERIC_TYPE_MARKERS = ("INT", "REAL", "FLOA", "DOUB", "NUMERIC", "DECIMAL")

    # SQLite keywords; an identifier equal to one of these must be quoted.
    _SQL_KEYWORDS = frozenset("""
        ABORT ACTION ADD AFTER ALL ALTER ALWAYS ANALYZE AND AS ASC ATTACH
        AUTOINCREMENT BEFORE BEGIN BETWEEN BY CASCADE CASE CAST CHECK COLLATE
        COLUMN COMMIT CONFLICT CONSTRAINT CREATE CROSS CURRENT CURRENT_DATE
        CURRENT_TIME CURRENT_TIMESTAMP DATABASE DEFAULT DEFERRABLE DEFERRED
        DELETE DESC DETACH DISTINCT DO DROP EACH ELSE END ESCAPE EXCEPT
        EXCLUDE EXCLUSIVE EXISTS EXPLAIN FAIL FILTER FIRST FOLLOWING FOR
        FOREIGN FROM FULL GENERATED GLOB GROUP GROUPS HAVING IF IGNORE
        IMMEDIATE IN INDEX INDEXED INITIALLY INNER INSERT INSTEAD INTERSECT
        INTO IS ISNULL JOIN KEY LAST LEFT LIKE LIMIT MATCH MATERIALIZED
        NATURAL NO NOT NOTHING NOTNULL NULL NULLS OF OFFSET ON OR ORDER
        OTHERS OUTER OVER PARTITION PLAN PRAGMA PRECEDING PRIMARY QUERY RAISE
        RANGE RECURSIVE REFERENCES REGEXP REINDEX RELEASE RENAME REPLACE
        RESTRICT RETURNING RIGHT ROLLBACK ROW ROWS SAVEPOINT SELECT SET TABLE
        TEMP TEMPORARY THEN TIES TO TRANSACTION TRIGGER UNBOUNDED UNION
        UNIQUE UPDATE USING VACUUM VALUES VIEW VIRTUAL WHEN WHERE WINDOW WITH
        WITHOUT
    """.split())

    def __init__(self, db_path="database.sqlite"):
        """Open a connection to the SQLite database at *db_path*.

        Note that ``sqlite3.connect`` creates an empty database file when
        *db_path* does not exist yet.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.templates = []
        self.schema = {}

    @classmethod
    def _quote_identifier(cls, name, force=False):
        """Return *name* in a form that is safe to embed in SQL.

        Plain identifiers that are not SQLite keywords are returned
        unchanged unless *force* is true. Everything else is wrapped in
        square brackets; names containing ``]`` (which brackets cannot
        escape) fall back to double-quote quoting with ``"`` doubled.
        """
        is_plain = name.isidentifier() and name.upper() not in cls._SQL_KEYWORDS
        if is_plain and not force:
            return name
        if "]" in name:
            return '"' + name.replace('"', '""') + '"'
        return f"[{name}]"

    @classmethod
    def _numeric_columns(cls, info):
        """Return the columns of a schema entry whose declared type is numeric."""
        return [
            col
            for col, typ in info["types"].items()
            if any(marker in (typ or "").upper()
                   for marker in cls._NUMERIC_TYPE_MARKERS)
        ]

    def generate_patterns(self, output_dir="tool_output"):
        """Generate all query patterns, save them, and return the list.

        The database connection is closed when this method finishes, whether
        it succeeds or raises, so the instance cannot be reused afterwards.

        Args:
            output_dir: Directory that receives the generated files.

        Returns:
            The list of template dicts accumulated in ``self.templates``.
        """
        try:
            self._load_schema()
            self._generate_basic_patterns()
            self._generate_join_patterns()
            self._generate_aggregation_patterns()
            self._generate_special_patterns()
            self._save_patterns(output_dir)
            return self.templates
        finally:
            self.conn.close()

    def _load_schema(self):
        """Populate ``self.schema`` with columns, types, PKs and FKs per table.

        SQLite's internal bookkeeping tables (``sqlite_sequence`` etc.) are
        skipped because templates for them are not useful.
        """
        self.cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        tables = [row[0] for row in self.cursor.fetchall()]

        for table in tables:
            quoted = self._quote_identifier(table, force=True)

            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            self.cursor.execute(f"PRAGMA table_info({quoted})")
            columns = self.cursor.fetchall()

            # foreign_key_list rows: (id, seq, table, from, to, ...)
            self.cursor.execute(f"PRAGMA foreign_key_list({quoted})")
            fks = self.cursor.fetchall()

            self.schema[table] = {
                "columns": [col[1] for col in columns],
                "types": {col[1]: col[2] for col in columns},
                "pk": [col[1] for col in columns if col[5] > 0],
                # "ref_column" is None when the FK was declared without an
                # explicit parent column (it then targets the parent's PK).
                "fks": [
                    {"column": fk[3], "ref_table": fk[2], "ref_column": fk[4]}
                    for fk in fks
                ],
            }

    def _generate_basic_patterns(self):
        """Add single-table select, count, count-distinct and top-N templates."""
        for table, info in self.schema.items():
            if not info["columns"]:
                continue

            table_ref = self._quote_identifier(table)
            first_col = self._quote_identifier(info["columns"][0])

            # Basic selection
            self.templates.append({
                "pattern": "basic_select",
                "description": f"Select from {table}",
                "template": f"SELECT {{columns}} FROM {table_ref} WHERE {{condition}};",
                "example": f"SELECT * FROM {table_ref} WHERE {first_col} = {{value}};",
                "tables": [table]
            })

            # Count pattern
            self.templates.append({
                "pattern": "count",
                "description": f"Count rows in {table}",
                "template": f"SELECT COUNT(*) FROM {table_ref} WHERE {{condition}};",
                "example": f"SELECT COUNT(*) FROM {table_ref};",
                "tables": [table]
            })

            # Distinct count, illustrated with the first primary-key column
            if info["pk"]:
                pk_col = self._quote_identifier(info["pk"][0])
                self.templates.append({
                    "pattern": "count_distinct",
                    "description": f"Count distinct in {table}",
                    "template": f"SELECT COUNT(DISTINCT {{column}}) FROM {table_ref} WHERE {{condition}};",
                    "example": f"SELECT COUNT(DISTINCT {pk_col}) FROM {table_ref};",
                    "tables": [table]
                })

            # Top-N pattern for tables with at least one numeric column
            numeric_cols = self._numeric_columns(info)
            if numeric_cols:
                num_col = self._quote_identifier(numeric_cols[0])
                self.templates.append({
                    "pattern": "top_n",
                    "description": f"Top N from {table}",
                    "template": f"SELECT {{columns}} FROM {table_ref} WHERE {{column}} > 0 ORDER BY {{column}} {{DESC|ASC}} LIMIT {{n}};",
                    "example": f"SELECT * FROM {table_ref} WHERE {num_col} > 0 ORDER BY {num_col} DESC LIMIT 10;",
                    "tables": [table]
                })

    def _generate_join_patterns(self):
        """Add two-table JOIN templates for FK-linked and id-sharing tables.

        Candidate pairs come first from declared foreign keys and then from
        columns with the same name ending in ``id`` in both tables. Each
        unordered table pair is used once, and at most ``MAX_JOIN_PAIRS``
        distinct pairs produce templates.
        """
        join_pairs = []

        # Find joinable tables via foreign keys
        for table1, info1 in self.schema.items():
            for fk in info1.get("fks", []):
                table2 = fk["ref_table"]
                if table2 not in self.schema:
                    continue
                ref_column = fk["ref_column"]
                if ref_column is None:
                    # Implicit reference to the parent's primary key; only
                    # resolvable here when that key is a single column.
                    parent_pk = self.schema[table2]["pk"]
                    if len(parent_pk) != 1:
                        continue
                    ref_column = parent_pk[0]
                join_pairs.append((table1, table2, fk["column"], ref_column))

        # Find tables with common id-like columns. Walk table1's columns in
        # declaration order so the chosen column is deterministic.
        for table1, info1 in self.schema.items():
            for table2, info2 in self.schema.items():
                if table1 >= table2:
                    continue
                other_columns = set(info2["columns"])
                for col in info1["columns"]:
                    if col in other_columns and col.lower().endswith("id"):
                        join_pairs.append((table1, table2, col, col))

        # De-duplicate first, then cap, so duplicates do not use up the limit.
        seen = set()
        for t1, t2, col1, col2 in join_pairs:
            if len(seen) >= self.MAX_JOIN_PAIRS:
                break
            key = tuple(sorted((t1, t2)))
            if key in seen:
                continue
            seen.add(key)

            t1_ref = self._quote_identifier(t1)
            t2_ref = self._quote_identifier(t2)
            left = f"t1.{self._quote_identifier(col1)}"
            right = f"t2.{self._quote_identifier(col2)}"

            # Basic join
            self.templates.append({
                "pattern": "two_table_join",
                "description": f"Join {t1} with {t2}",
                "template": f"SELECT {{columns}} FROM {t1_ref} t1 JOIN {t2_ref} t2 ON {left} = {right} WHERE {{condition}};",
                "example": f"SELECT t1.*, t2.* FROM {t1_ref} t1 JOIN {t2_ref} t2 ON {left} = {right};",
                "tables": [t1, t2],
                "join_columns": [left, right]
            })

            # Join with aggregation
            self.templates.append({
                "pattern": "join_aggregate",
                "description": f"Aggregate {t2} grouped by {t1}",
                "template": f"SELECT t1.{{group_col}}, {{AGG}}(t2.{{agg_col}}) FROM {t1_ref} t1 JOIN {t2_ref} t2 ON {left} = {right} GROUP BY t1.{{group_col}};",
                "example": f"SELECT {left}, COUNT(*) FROM {t1_ref} t1 JOIN {t2_ref} t2 ON {left} = {right} GROUP BY {left};",
                "tables": [t1, t2]
            })

    def _generate_aggregation_patterns(self):
        """Add SUM/AVG/MAX/MIN and GROUP BY templates for numeric columns."""
        for table, info in self.schema.items():
            numeric_cols = self._numeric_columns(info)
            if not numeric_cols:
                continue

            table_ref = self._quote_identifier(table)
            num_col = numeric_cols[0]
            num_ref = self._quote_identifier(num_col)

            # Basic aggregations
            for agg in ['SUM', 'AVG', 'MAX', 'MIN']:
                self.templates.append({
                    "pattern": f"{agg.lower()}_aggregation",
                    "description": f"{agg} of {num_col} in {table}",
                    "template": f"SELECT {agg}({{column}}) FROM {table_ref} WHERE {{condition}};",
                    "example": f"SELECT {agg}({num_ref}) FROM {table_ref};",
                    "tables": [table]
                })

            # Group by aggregation: group on the first column that is not
            # the aggregated one.
            if len(info["columns"]) > 1:
                group_col = next(
                    (c for c in info["columns"] if c != num_col),
                    info["columns"][0],
                )
                group_ref = self._quote_identifier(group_col)
                self.templates.append({
                    "pattern": "group_by_aggregate",
                    "description": f"Group by in {table}",
                    "template": f"SELECT {{group_col}}, {{AGG}}({{agg_col}}) FROM {table_ref} GROUP BY {{group_col}};",
                    "example": f"SELECT {group_ref}, SUM({num_ref}) FROM {table_ref} GROUP BY {group_ref};",
                    "tables": [table]
                })

    def _generate_special_patterns(self):
        """Add schema-independent templates (percentage, subquery, LIKE, CASE).

        These templates are plain strings, not f-strings, so placeholders
        are written with single braces to match every other template.
        """
        # Percentage calculation pattern
        self.templates.append({
            "pattern": "percentage",
            "description": "Calculate percentage",
            "template": "SELECT CAST({numerator} AS REAL) * 100 / NULLIF({denominator}, 0);",
            "example": "SELECT CAST(COUNT(CASE WHEN condition THEN 1 END) AS REAL) * 100 / NULLIF(COUNT(*), 0);",
            "tables": []
        })

        # Subquery pattern, illustrated with the first table that has columns
        sample = next(
            ((name, info) for name, info in self.schema.items() if info["columns"]),
            None,
        )
        if sample is not None:
            table, info = sample
            table_ref = self._quote_identifier(table)
            col = self._quote_identifier(info["columns"][0])

            self.templates.append({
                "pattern": "subquery_filter",
                "description": "Filter with subquery",
                "template": "SELECT {columns} FROM {table} WHERE {column} = (SELECT {AGG}({column}) FROM {table} WHERE {condition});",
                "example": f"SELECT * FROM {table_ref} WHERE {col} = (SELECT MAX({col}) FROM {table_ref});",
                "tables": []
            })

        # String pattern matching
        self.templates.append({
            "pattern": "string_pattern",
            "description": "String pattern matching",
            "template": "SELECT {columns} FROM {table} WHERE {column} LIKE '{pattern}';",
            "example": "SELECT * FROM table WHERE column LIKE '%pattern%';",
            "tables": []
        })

        # CASE statement pattern
        self.templates.append({
            "pattern": "case_when",
            "description": "Conditional logic",
            "template": "SELECT CASE WHEN {condition1} THEN {value1} WHEN {condition2} THEN {value2} ELSE {default} END;",
            "example": "SELECT CASE WHEN value > 100 THEN 'High' WHEN value > 50 THEN 'Medium' ELSE 'Low' END;",
            "tables": []
        })

    def _save_patterns(self, output_dir="tool_output"):
        """Write the templates to *output_dir* as JSON and as readable text.

        Creates ``query_patterns.json`` (all templates) and
        ``query_templates.txt`` (at most two templates per pattern type),
        then prints a one-line summary.
        """
        os.makedirs(output_dir, exist_ok=True)

        # Save as JSON
        json_path = os.path.join(output_dir, "query_patterns.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump({"templates": self.templates}, f, indent=2)

        # Group by pattern type, preserving first-seen order
        pattern_groups = {}
        for template in self.templates:
            pattern_groups.setdefault(template["pattern"], []).append(template)

        # Save as readable text; explicit UTF-8 so non-ASCII identifiers do
        # not depend on the platform's default encoding.
        text_path = os.path.join(output_dir, "query_templates.txt")
        with open(text_path, "w", encoding="utf-8") as f:
            f.write("SQL QUERY TEMPLATES\n")
            f.write("=" * 60 + "\n\n")

            for pattern_type, templates in pattern_groups.items():
                f.write(f"## {pattern_type.upper().replace('_', ' ')}\n")
                f.write("-" * 40 + "\n")
                for t in templates[:2]:  # Limit examples per type
                    f.write(f"Template: {t['template']}\n")
                    f.write(f"Example:  {t['example']}\n")
                    if t.get("tables"):
                        f.write(f"Tables:   {', '.join(t['tables'])}\n")
                    f.write("\n")
                f.write("\n")

        print(f"Pattern generation complete: {len(self.templates)} templates created")

if __name__ == "__main__":
    # Script entry point: build patterns from the default database path.
    PatternGenerator().generate_patterns()