#!/usr/bin/env python3
"""
Generate SQL templates with correct column ordering patterns.
Enhanced to emphasize column order importance.
"""

import sqlite3
import os

class TemplateGenerator:
    """Build a markdown cheat sheet of SQL query templates for a SQLite database.

    The generator inspects the database schema, assembles a set of generic
    templates (column ordering, joins, aggregation, evidence handling) plus a
    few templates derived from the first tables of the schema, and writes
    everything to ``tool_output/query_templates.txt``.
    """

    # (key in ``self.templates``, markdown heading), in output order.
    _SECTIONS = (
        ('column_ordering', 'Column Ordering Templates'),
        ('join_completion', 'Join Completion Patterns'),
        ('evidence_patterns', 'Evidence Application Patterns'),
        ('aggregation', 'Aggregation Patterns'),
    )

    # Common SQL keywords that must be quoted when used as identifiers.
    # Quoting a name that would not strictly need it is harmless in SQLite.
    _SQL_KEYWORDS = frozenset({
        'add', 'all', 'alter', 'and', 'as', 'between', 'by', 'case', 'check',
        'column', 'commit', 'create', 'default', 'delete', 'distinct', 'drop',
        'else', 'exists', 'foreign', 'from', 'group', 'having', 'in', 'index',
        'insert', 'into', 'is', 'join', 'limit', 'not', 'null', 'on', 'or',
        'order', 'primary', 'references', 'select', 'set', 'table', 'then',
        'to', 'transaction', 'union', 'unique', 'update', 'values', 'when',
        'where',
    })

    def __init__(self, db_path: str):
        """Open a connection to the SQLite database at *db_path*."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.templates = {}
        # Populated by _analyze_database(); defined here so the attributes
        # always exist.
        self.tables = []
        self.table_info = {}

    def generate(self) -> bool:
        """Generate SQL templates for common patterns.

        Returns:
            True when the templates were written, False when any step raised.
            The database connection is closed in both cases, so an instance
            can only be used for a single run.
        """
        os.makedirs('tool_output', exist_ok=True)

        try:
            self._analyze_database()
            self._generate_templates()
            self._save_results()
            print("Template generation complete - results in tool_output/query_templates.txt")
            return True
        except Exception as e:
            # Top-level boundary of the tool: report the failure and signal
            # it through the return value instead of crashing.
            print(f"Template generation error: {e}")
            return False
        finally:
            self.conn.close()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier, escaping embedded quotes."""
        return '"' + name.replace('"', '""') + '"'

    def _quote_column(self, col_name: str) -> str:
        """Quote an identifier only when it cannot be written bare.

        Plain names (letters, digits, underscores, not starting with a digit,
        not a SQL keyword) are returned unchanged; anything else is wrapped in
        double quotes with embedded double quotes doubled.
        """
        is_plain = (
            col_name.isidentifier()
            and col_name.lower() not in TemplateGenerator._SQL_KEYWORDS
        )
        if is_plain:
            return col_name
        return TemplateGenerator._quote_identifier(col_name)

    def _analyze_database(self):
        """Collect table names, columns, declared types and foreign keys."""
        # Names starting with "sqlite_" are SQLite's own bookkeeping tables
        # (e.g. sqlite_sequence); they are not part of the user schema.
        self.cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        self.tables = [row[0] for row in self.cursor.fetchall()]

        self.table_info = {}
        for table in self.tables:
            # Always quote here: a bare name containing spaces or equal to a
            # keyword would make the PRAGMA statement a syntax error.
            quoted = self._quote_identifier(table)

            # table_info rows are (cid, name, type, notnull, dflt_value, pk).
            self.cursor.execute(f"PRAGMA table_info({quoted})")
            columns = self.cursor.fetchall()

            self.cursor.execute(f"PRAGMA foreign_key_list({quoted})")
            foreign_keys = self.cursor.fetchall()

            self.table_info[table] = {
                'columns': [col[1] for col in columns],
                'types': {col[1]: col[2] for col in columns},
                'foreign_keys': foreign_keys
            }

    def _generate_templates(self):
        """Generate SQL templates for common query patterns.

        Fills ``self.templates`` with the fixed generic sections and with one
        ``'<table>_basic'`` entry for each of the first three tables that has
        at least two columns.
        """

        # Column ordering templates
        self.templates['column_ordering'] = [
            {
                'pattern': 'Two columns requested',
                'template': '''-- Question: "What is the A and the B?"
SELECT A, B  -- EXACT order: A first, then B
FROM table;''',
                'warning': 'Column order MUST match question order'
            },
            {
                'pattern': 'Three columns requested',
                'template': '''-- Question: "List the X, Y, and Z"
SELECT X, Y, Z  -- EXACT order as listed
FROM table;''',
                'warning': 'Never reorder columns'
            },
            {
                'pattern': 'ID and description',
                'template': '''-- Question: "State the ID and name"
SELECT id, name  -- ID first as requested
FROM table;''',
                'warning': 'Don\'t swap even if name seems more important'
            }
        ]

        # Join completion templates
        self.templates['join_completion'] = [
            {
                'pattern': 'Get name from ID',
                'template': '''-- When question asks for NAME not ID
SELECT t2.name  -- Join to get the name
FROM table1 t1
JOIN table2 t2 ON t1.foreign_id = t2.id
WHERE condition;''',
                'warning': 'Must complete join to get human-readable value'
            },
            {
                'pattern': 'Multiple joins for names',
                'template': '''-- Get multiple names from IDs
SELECT t2.name, t3.description
FROM main_table t1
JOIN name_table t2 ON t1.name_id = t2.id
JOIN desc_table t3 ON t1.desc_id = t3.id;''',
                'warning': 'Complete all joins for final values'
            }
        ]

        # Aggregation templates
        self.templates['aggregation'] = [
            {
                'pattern': 'Count with GROUP BY',
                'template': '''-- Count per entity
SELECT entity, COUNT(*) as count
FROM table
GROUP BY entity  -- Must include all non-aggregated columns
ORDER BY count DESC;''',
                'warning': 'GROUP BY all SELECT columns except aggregates'
            },
            {
                'pattern': 'Conditional counting',
                'template': '''-- Count only matching rows
SELECT COUNT(CASE WHEN status = 'active' THEN 1 END) as active_count
FROM table;''',
                'warning': 'Use CASE WHEN for conditional counts, not COUNT(condition)'
            },
            {
                'pattern': 'Finding maximum with details',
                'template': '''-- Get entity with max value
SELECT entity, value
FROM table
ORDER BY value DESC
LIMIT 1;

-- OR with aggregation
SELECT entity, MAX(value) as max_value
FROM table
GROUP BY entity
ORDER BY max_value DESC
LIMIT 1;''',
                'warning': 'Choose based on whether GROUP BY is needed'
            }
        ]

        # Evidence-based templates
        self.templates['evidence_patterns'] = [
            {
                'pattern': 'Evidence with exact mapping',
                'template': '''-- Evidence: "category refers to cat_type = 'electronics'"
SELECT columns
FROM table
WHERE cat_type = 'electronics';  -- Use exact column and value from evidence''',
                'warning': 'Evidence overrides schema assumptions'
            },
            {
                'pattern': 'Evidence with multiple conditions',
                'template': '''-- Evidence: "year = 2020; status = 'active'; region refers to area_code"
SELECT columns
FROM table
WHERE year = 2020
  AND status = 'active'
  AND area_code = some_value;  -- Apply ALL conditions''',
                'warning': 'Never skip evidence conditions'
            }
        ]

        # Table-specific templates
        for table in self.tables[:3]:  # Limit to first 3 tables
            columns = self.table_info[table]['columns']

            if len(columns) >= 2:
                col1 = self._quote_column(columns[0])
                col2 = self._quote_column(columns[1])
                # The table name needs the same treatment as the columns so
                # the emitted template is valid SQL.
                table_ref = self._quote_column(table)

                self.templates[f'{table}_basic'] = [
                    {
                        'pattern': f'Select from {table}',
                        'template': f'''-- Basic selection
SELECT {col1}, {col2}
FROM {table_ref}
WHERE {col1} = value;''',
                        'warning': 'Check if columns need quoting'
                    }
                ]

    @staticmethod
    def _write_section(f, heading, templates):
        """Write one markdown section: a heading followed by each template entry."""
        f.write(f"## {heading}\n\n")
        for template in templates:
            f.write(f"### {template['pattern']}\n")
            f.write(f"```sql\n{template['template']}\n```\n")
            f.write(f"**⚠️ {template['warning']}**\n\n")

    def _save_results(self):
        """Save templates to ``tool_output/query_templates.txt``.

        The generic sections are written first in a fixed order, followed by
        any table-specific templates, then the static reminder lists.
        """
        # Explicit UTF-8: the text contains non-ASCII symbols that cannot be
        # encoded with some platform default encodings.
        with open('tool_output/query_templates.txt', 'w', encoding='utf-8') as f:
            f.write("# SQL QUERY TEMPLATES\n\n")
            f.write("## ⚠️ CRITICAL: COLUMN ORDER MATTERS\n\n")
            f.write("**ALWAYS return columns in the EXACT order requested in the question!**\n\n")

            # Fixed sections, column ordering first (most important)
            for key, heading in self._SECTIONS:
                if key in self.templates:
                    self._write_section(f, heading, self.templates[key])

            # Everything that is not a fixed section was generated per table.
            fixed_keys = {key for key, _ in self._SECTIONS}
            table_templates = []
            for key, group in self.templates.items():
                if key not in fixed_keys:
                    table_templates.extend(group)
            if table_templates:
                self._write_section(f, "Table-Specific Templates", table_templates)

            # Common mistakes section
            f.write("## Common Mistakes to Avoid\n\n")
            f.write("1. **Wrong column order**: Returning (B, A) when asked for (A, B)\n")
            f.write("2. **Incomplete joins**: Returning ID when name was requested\n")
            f.write("3. **Missing evidence**: Not applying all conditions from evidence\n")
            f.write("4. **Wrong aggregation**: Using COUNT(condition) instead of COUNT(CASE WHEN)\n")
            f.write("5. **Unquoted columns**: Forgetting quotes for columns with spaces\n")
            f.write("6. **Case mismatch**: Using 'value' when data has 'Value'\n")
            f.write("7. **Table confusion**: Using wrong table for overlapping columns\n")

            f.write("\n## Remember\n\n")
            f.write("- **Evidence is law** - Apply all evidence conditions\n")
            f.write("- **Order matters** - Return columns as requested\n")
            f.write("- **Complete joins** - Get final values, not intermediate IDs\n")
            f.write("- **Check warnings** - Each table may have special requirements\n")

if __name__ == "__main__":
    # Script entry point: build the templates for the default database file
    # in the current working directory.
    TemplateGenerator("database.sqlite").generate()