#!/usr/bin/env python3
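"""
Detect SQLite tables whose entity columns repeat across rows and which
therefore need GROUP BY clauses when aggregating.

Column names containing spaces or special characters are quoted before being
interpolated into SQL. Results are written as a Markdown guide to
tool_output/aggregation_guide.txt.
"""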

import sqlite3
import os

class AggregationDetector:
    """Scan an SQLite database for tables whose entity columns repeat across rows.

    Columns are classified by case-insensitive substring keywords in their
    names. A table is flagged as needing GROUP BY when one of its first
    ``MAX_ENTITY_COLUMNS_CHECKED`` entity-like columns has a non-NULL value
    that occurs in more than one row. Findings are collected in
    ``aggregation_patterns`` (table name -> dict) and written as a Markdown
    guide by :meth:`detect`.
    """

    # Substrings (matched against the lower-cased column name) used for
    # classification. A column may fall into several categories.
    ENTITY_KEYWORDS = ('id', 'code', 'name', 'type', 'team', 'player')
    ENTITY_EXCLUDE_KEYWORDS = ('date', 'time', 'created')
    TIME_KEYWORDS = ('year', 'date', 'month', 'season', 'quarter')
    MEASURE_KEYWORDS = ('amount', 'count', 'total', 'score', 'price', 'quantity')

    # Only this many entity columns are probed for duplicates per table.
    MAX_ENTITY_COLUMNS_CHECKED = 2
    # Number of duplicated values kept as examples per flagged table.
    MAX_EXAMPLES = 3

    REPORT_FILENAME = 'aggregation_guide.txt'

    def __init__(self, db_path: str):
        """Open a connection to the SQLite database at *db_path*.

        The connection (and its cursor) stays open until :meth:`detect`
        finishes, which closes it.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        # table name -> {'needs_group_by', 'entity_columns', 'time_columns',
        #                'measure_columns', 'examples', 'avg_rows_per_entity'}
        self.aggregation_patterns = {}

    def detect(self, output_dir: str = 'tool_output') -> bool:
        """Detect aggregation patterns, write the report, and close the connection.

        Args:
            output_dir: Directory that receives ``aggregation_guide.txt``;
                created if missing. Defaults to ``tool_output``.

        Returns:
            True on success. False if any exception was raised; the error is
            printed. The database connection is closed either way, so an
            instance can only run detection once.
        """
        report_path = os.path.join(output_dir, self.REPORT_FILENAME)
        try:
            # Inside the try so a failure here still closes the connection.
            os.makedirs(output_dir, exist_ok=True)
            self._detect_patterns()
            self._save_results(output_dir)
            print(f"Aggregation detection complete - results in {report_path}")
            return True
        except Exception as e:  # top-level boundary: report and signal failure
            print(f"Aggregation detection error: {e}")
            return False
        finally:
            self.conn.close()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQLite identifier.

        Embedded double quotes are doubled. This makes any table or column
        name safe to interpolate into SQL, including names with spaces,
        punctuation, or SQL keywords.
        """
        return '"' + name.replace('"', '""') + '"'

    def _quote_column(self, col_name: str) -> str:
        """Return *col_name* quoted for safe use as an SQL identifier."""
        return self._quote_identifier(col_name)

    @classmethod
    def _categorize_columns(cls, columns):
        """Split column names into ``(entity, time, measure)`` lists by name keywords.

        Entity columns whose name also contains a date/time keyword are
        excluded from the entity list. Input order is preserved in each list.
        """
        entity_columns, time_columns, measure_columns = [], [], []
        for col in columns:
            col_lower = col.lower()
            if (any(k in col_lower for k in cls.ENTITY_KEYWORDS)
                    and not any(k in col_lower for k in cls.ENTITY_EXCLUDE_KEYWORDS)):
                entity_columns.append(col)
            if any(k in col_lower for k in cls.TIME_KEYWORDS):
                time_columns.append(col)
            if any(k in col_lower for k in cls.MEASURE_KEYWORDS):
                measure_columns.append(col)
        return entity_columns, time_columns, measure_columns

    def _find_duplicated_entity(self, table_sql: str, entity_col: str):
        """Probe one entity column for non-NULL values that occur in several rows.

        Args:
            table_sql: Already-quoted table identifier.
            entity_col: Raw (unquoted) column name.

        Returns:
            ``(examples, avg_rows_per_entity)`` if duplicates exist, else None.
            The average counts only rows whose column value is non-NULL.

        Raises:
            sqlite3.Error: If either query fails.
        """
        col_sql = self._quote_column(entity_col)
        self.cursor.execute(f"""
            SELECT {col_sql}, COUNT(*) AS cnt
            FROM {table_sql}
            WHERE {col_sql} IS NOT NULL
            GROUP BY {col_sql}
            HAVING COUNT(*) > 1
            ORDER BY cnt DESC
            LIMIT {self.MAX_EXAMPLES}
        """)
        duplicates = self.cursor.fetchall()
        if not duplicates:
            return None

        examples = [
            {'column': entity_col, 'value': str(value), 'count': count}
            for value, count in duplicates
        ]

        # COUNT(col) skips NULLs just like COUNT(DISTINCT col), so rows that
        # belong to no entity do not inflate the per-entity average.
        self.cursor.execute(
            f"SELECT COUNT({col_sql}), COUNT(DISTINCT {col_sql}) FROM {table_sql}"
        )
        non_null_rows, unique_count = self.cursor.fetchone()
        avg_rows_per_entity = non_null_rows / unique_count if unique_count else 0
        return examples, avg_rows_per_entity

    def _detect_patterns(self):
        """Populate ``aggregation_patterns`` with one entry per user table."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        # SQLite reserves the "sqlite_" prefix for its internal tables
        # (e.g. sqlite_sequence); they are not user data.
        tables = [row[0] for row in self.cursor.fetchall()
                  if not row[0].lower().startswith('sqlite_')]

        for table in tables:
            table_sql = self._quote_identifier(table)
            self.cursor.execute(f"PRAGMA table_info({table_sql})")
            columns = [row[1] for row in self.cursor.fetchall()]
            entity_columns, time_columns, measure_columns = self._categorize_columns(columns)

            # Reset per table so no result leaks in from a previously scanned table.
            aggregation_needed = False
            examples = []
            avg_rows_per_entity = 1

            for entity_col in entity_columns[:self.MAX_ENTITY_COLUMNS_CHECKED]:
                try:
                    found = self._find_duplicated_entity(table_sql, entity_col)
                except sqlite3.Error as e:
                    # Best effort: one unreadable column must not abort the scan.
                    print(f"Skipping {table}.{entity_col}: {e}")
                    continue
                if found is not None:
                    examples, avg_rows_per_entity = found
                    aggregation_needed = True
                    break

            self.aggregation_patterns[table] = {
                'needs_group_by': aggregation_needed,
                'entity_columns': entity_columns,
                'time_columns': time_columns,
                'measure_columns': measure_columns,
                'examples': examples,
                'avg_rows_per_entity': avg_rows_per_entity,
            }

    def _save_results(self, output_dir: str = 'tool_output'):
        """Write ``aggregation_patterns`` as a Markdown guide into *output_dir*.

        The directory must already exist. The file is written as UTF-8
        because the report contains non-ASCII text (the warning-sign
        heading), which a platform default encoding may not be able to encode.
        """
        report_path = os.path.join(output_dir, self.REPORT_FILENAME)
        tables_needing_groupby = [
            (table, patterns)
            for table, patterns in self.aggregation_patterns.items()
            if patterns['needs_group_by']
        ]
        tables_single_row = [
            table
            for table, patterns in self.aggregation_patterns.items()
            if not patterns['needs_group_by']
        ]

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("# AGGREGATION REQUIREMENTS\n\n")

            if tables_needing_groupby:
                f.write("## ⚠️ TABLES REQUIRING GROUP BY\n\n")

                for table, patterns in tables_needing_groupby:
                    f.write(f"### {table}\n")
                    f.write(f"- **NEEDS GROUP BY**: ~{patterns['avg_rows_per_entity']:.1f} rows per entity\n")

                    if patterns['entity_columns']:
                        f.write(f"- **Group by columns**: {', '.join(patterns['entity_columns'][:3])}\n")

                    if patterns['measure_columns']:
                        f.write(f"- **Aggregate columns**: {', '.join(patterns['measure_columns'][:3])}\n")

                    if patterns['examples']:
                        f.write("- **Examples**:\n")
                        for ex in patterns['examples'][:2]:
                            f.write(f"  - {ex['column']}='{ex['value']}' has {ex['count']} rows\n")

                    f.write("\n")

                f.write("### Common GROUP BY Patterns\n")
                f.write("```sql\n")
                f.write("-- Count per entity\n")
                f.write("SELECT entity, COUNT(*)\n")
                f.write("FROM table\n")
                f.write("GROUP BY entity;\n\n")

                f.write("-- Sum measures per entity\n")
                f.write("SELECT entity, SUM(measure)\n")
                f.write("FROM table\n")
                f.write("GROUP BY entity;\n\n")

                f.write("-- Entity with time period\n")
                f.write("SELECT entity, year, SUM(value)\n")
                f.write("FROM table\n")
                f.write("GROUP BY entity, year;\n")
                f.write("```\n\n")

            if tables_single_row:
                f.write("## Tables with Single Row per Entity\n")
                f.write("These tables typically don't need GROUP BY:\n\n")
                for table in tables_single_row:
                    f.write(f"- {table}\n")
                f.write("\n")

            # General rules
            f.write("## GROUP BY Rules\n\n")
            f.write("1. **Always GROUP BY when using aggregate functions** (SUM, COUNT, AVG, etc.)\n")
            f.write("2. **Include all non-aggregated columns** from SELECT in GROUP BY\n")
            f.write("3. **Use conditional counting**:\n")
            f.write("   ```sql\n")
            f.write("   COUNT(CASE WHEN condition THEN 1 END)  -- CORRECT\n")
            f.write("   COUNT(condition)  -- WRONG\n")
            f.write("   ```\n")
            f.write("4. **Check the warnings above** for tables with multiple rows per entity\n")

if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default database path.
    db_file = sys.argv[1] if len(sys.argv) > 1 else "database.sqlite"
    detector = AggregationDetector(db_file)
    # Propagate failure through the exit status so callers (shell, CI) notice it.
    sys.exit(0 if detector.detect() else 1)