#!/usr/bin/env python3
"""
Aggregation Analyzer
Determines correct aggregation context and GROUP BY requirements.
"""

import json
import os
import sqlite3

# Column-name substrings used to classify columns (matched case-insensitively).
_AGGREGATABLE_KEYWORDS = (
    'amount', 'count', 'total', 'sum', 'quantity', 'price',
    'score', 'points', 'rebounds', 'assists', 'goals',
    'duration', 'weight', 'height', 'salary',
)
_ENTITY_PATTERNS = ('id', 'name', 'code', 'type', 'category')
_TEMPORAL_PATTERNS = ('year', 'date', 'month', 'season', 'time')

# Declared-type substrings that give a column numeric affinity. SQLite type
# names are case-insensitive and free-form (INT, BIGINT, DOUBLE, DECIMAL(10,2)).
_NUMERIC_TYPE_TOKENS = ('INT', 'REAL', 'FLOA', 'DOUB', 'NUMERIC', 'DECIMAL')


def _is_numeric_type(declared_type):
    """Return True if a declared SQLite column type looks numeric."""
    upper = (declared_type or '').upper()
    return any(token in upper for token in _NUMERIC_TYPE_TOKENS)


def _quote_identifier(name):
    """Return *name* as a safely double-quoted SQLite identifier."""
    return '"' + name.replace('"', '""') + '"'


def _scan_schema(db_path):
    """Classify the columns of every user table in the database.

    Args:
        db_path: Path to an existing SQLite database file.

    Returns:
        A tuple ``(aggregation_columns, entity_columns, temporal_columns)``.
        Each is a dict mapping table name to a list of column names, in
        declaration order:
        - aggregation_columns: numeric columns whose name suggests a metric.
        - entity_columns: columns whose name suggests an identifier or category.
        - temporal_columns: columns whose name suggests a time dimension.

    Raises:
        FileNotFoundError: if *db_path* is not an existing file. Without this
            check, ``sqlite3.connect`` would silently create an empty database.
    """
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    aggregation_columns = {}
    entity_columns = {}
    temporal_columns = {}

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        # Names beginning with "sqlite_" are reserved for SQLite's internal
        # bookkeeping tables (sqlite_sequence, sqlite_stat1, ...), not user data.
        tables = [row[0] for row in cursor.fetchall()
                  if not row[0].lower().startswith('sqlite_')]

        for table in tables:
            # Quote the identifier: names with spaces, quotes or reserved words
            # are otherwise a syntax error inside the PRAGMA.
            cursor.execute(f"PRAGMA table_info({_quote_identifier(table)})")
            columns = cursor.fetchall()

            aggregation_columns[table] = []
            entity_columns[table] = []
            temporal_columns[table] = []

            for col in columns:
                col_name = col[1]
                col_type = col[2]
                lowered = col_name.lower()

                # Numeric columns that can be aggregated
                if _is_numeric_type(col_type) and any(
                        keyword in lowered for keyword in _AGGREGATABLE_KEYWORDS):
                    aggregation_columns[table].append(col_name)

                # Entity identifiers for GROUP BY
                if any(pattern in lowered for pattern in _ENTITY_PATTERNS):
                    entity_columns[table].append(col_name)

                # Temporal columns
                if any(pattern in lowered for pattern in _TEMPORAL_PATTERNS):
                    temporal_columns[table].append(col_name)
    finally:
        # Always release the connection, even if a query fails.
        conn.close()

    return aggregation_columns, entity_columns, temporal_columns


def analyze_aggregation_context(db_path):
    """Analyze aggregation requirements for the database.

    Args:
        db_path: Path to an existing SQLite database file.

    Returns:
        A dict of aggregation guidance with these keys:
        - aggregation_patterns, grouping_rules, common_errors,
          validation_rules, context_keywords: static guidance lists/dicts.
        - temporal_aggregations: static temporal rules. This list is filled
          only when some table has a column whose name looks temporal;
          otherwise it is empty.
        - entity_aggregations: one entry per user table that has at least one
          entity-like column, listing its entity columns, its aggregatable
          columns and example ``SUM(...) GROUP BY ...`` patterns.

    Raises:
        FileNotFoundError: if *db_path* does not exist.
    """
    aggregation_columns, entity_columns, temporal_columns = _scan_schema(db_path)

    aggregation_rules = {
        'aggregation_patterns': [],
        'grouping_rules': [],
        'temporal_aggregations': [],
        'entity_aggregations': [],
        'common_errors': [],
        'validation_rules': []
    }

    # Aggregation patterns
    aggregation_rules['aggregation_patterns'] = [
        {
            'pattern': 'Per-Entity Total',
            'keywords': ['throughout career', 'overall', 'total per', 'for each'],
            'sql_template': 'SELECT entity_id, SUM(metric) GROUP BY entity_id',
            'example': 'Total points per player',
            'priority': 'HIGH'
        },
        {
            'pattern': 'Single Maximum/Minimum',
            'keywords': ['the most', 'the least', 'the highest', 'the lowest', 'champion'],
            'sql_template': 'SELECT ... ORDER BY aggregate_metric DESC/ASC LIMIT 1',
            'example': 'The player with most rebounds',
            'priority': 'HIGH'
        },
        {
            'pattern': 'Temporal Aggregation',
            'keywords': ['in [year]', 'during [period]', 'between [year] and [year]'],
            'sql_template': 'SELECT ... WHERE temporal_column = value GROUP BY ...',
            'example': 'Total sales in 2020',
            'priority': 'HIGH'
        },
        {
            'pattern': 'Category Aggregation',
            'keywords': ['by category', 'per type', 'for each group'],
            'sql_template': 'SELECT category, AGG(metric) GROUP BY category',
            'example': 'Average score by department',
            'priority': 'MEDIUM'
        },
        {
            'pattern': 'Having Condition',
            'keywords': ['more than X total', 'exceeding', 'with at least'],
            'sql_template': 'GROUP BY ... HAVING AGG(metric) > value',
            'example': 'Players with more than 1000 points',
            'priority': 'MEDIUM'
        }
    ]

    # Grouping rules based on question patterns
    aggregation_rules['grouping_rules'] = [
        {
            'scenario': 'Question mentions "for each" or "per"',
            'action': 'GROUP BY the entity mentioned after these keywords',
            'example': '"points per player" → GROUP BY player_id'
        },
        {
            'scenario': 'Aggregating without "each/per"',
            'action': 'Usually aggregate across all records (no GROUP BY)',
            'example': '"total points" → SUM(points) without grouping'
        },
        {
            'scenario': 'Finding "the most/least"',
            'action': 'May need GROUP BY for intermediate aggregation',
            'example': '"team with most wins" → GROUP BY team, ORDER BY SUM(wins)'
        },
        {
            'scenario': 'Evidence shows COUNT(x) > n',
            'action': 'Use GROUP BY with HAVING clause',
            'example': 'GROUP BY entity HAVING COUNT(x) > n'
        },
        {
            'scenario': 'Multiple aggregations mentioned',
            'action': 'Consider if they need same or different grouping',
            'example': 'May need subqueries for different aggregation levels'
        }
    ]

    # Temporal aggregation rules (only relevant when a time-like column exists)
    if any(temporal_columns.values()):
        aggregation_rules['temporal_aggregations'] = [
            {
                'pattern': 'Single time point',
                'keywords': ['in [specific year]', 'on [date]'],
                'sql': 'WHERE year = value',
                'aggregation': 'Aggregate within that time period'
            },
            {
                'pattern': 'Time range',
                'keywords': ['from X to Y', 'between X and Y'],
                'sql': 'WHERE year BETWEEN X AND Y',
                'aggregation': 'Aggregate across the range'
            },
            {
                'pattern': 'Throughout time',
                'keywords': ['overall', 'all-time', 'career'],
                'sql': 'No time filter',
                'aggregation': 'Aggregate all records'
            },
            {
                'pattern': 'Year-over-year',
                'keywords': ['change from', 'growth', 'improvement'],
                'sql': 'Compare aggregates from different years',
                'aggregation': 'May need self-join or window functions'
            }
        ]

    # Entity aggregation rules: one entry per table with entity-like columns.
    aggregation_rules['entity_aggregations'] = [
        {
            'table': table,
            'entity_columns': entities,
            'aggregatable_columns': aggregation_columns.get(table, []),
            'common_patterns': [
                f"SUM({col}) GROUP BY {entity}"
                for col in aggregation_columns.get(table, [])
                for entity in entities[:2]  # Just show examples
            ]
        }
        for table, entities in entity_columns.items()
        if entities
    ]

    # Common aggregation errors
    aggregation_rules['common_errors'] = [
        {
            'error': 'Missing GROUP BY when using aggregate functions',
            'example': 'SELECT name, COUNT(*) without GROUP BY name',
            'fix': 'Add GROUP BY for non-aggregated columns'
        },
        {
            'error': 'Wrong aggregation level',
            'example': 'SUM per game when asked for career total',
            'fix': 'Check if grouping is needed based on question context'
        },
        {
            'error': 'Using WHERE instead of HAVING for aggregate conditions',
            'example': 'WHERE COUNT(*) > 5',
            'fix': 'Use HAVING for post-aggregation filtering'
        },
        {
            'error': 'Not aggregating when "total" is mentioned',
            'example': 'Returning individual records instead of SUM',
            'fix': 'Look for aggregation keywords: total, sum, average'
        },
        {
            'error': 'Incorrect join before aggregation',
            'example': 'Join causing duplicate rows before SUM',
            'fix': 'Aggregate in subquery if needed'
        }
    ]

    # Validation rules
    aggregation_rules['validation_rules'] = [
        {
            'rule': 'If using aggregate function, check GROUP BY',
            'check': 'All non-aggregated columns in GROUP BY',
            'priority': 'CRITICAL'
        },
        {
            'rule': 'If "per/each" mentioned, verify grouping',
            'check': 'GROUP BY matches the "per" entity',
            'priority': 'HIGH'
        },
        {
            'rule': 'If "the most/least", verify ORDER BY and LIMIT',
            'check': 'ORDER BY aggregate DESC/ASC LIMIT 1',
            'priority': 'HIGH'
        },
        {
            'rule': 'If evidence has HAVING, use it',
            'check': 'HAVING clause matches evidence condition',
            'priority': 'CRITICAL'
        }
    ]

    # Context keywords
    aggregation_rules['context_keywords'] = {
        'require_aggregation': [
            'total', 'sum', 'average', 'mean', 'count',
            'maximum', 'minimum', 'highest', 'lowest'
        ],
        'require_grouping': [
            'per', 'each', 'by', 'for every', 'grouped by'
        ],
        'no_grouping': [
            'overall', 'all combined', 'in total', 'altogether'
        ],
        'require_having': [
            'with more than', 'having at least', 'exceeding'
        ]
    }

    return aggregation_rules

def main(db_path="./database.sqlite", output_dir="./tool_output"):
    """Analyze a SQLite database and write its aggregation rules to JSON.

    Args:
        db_path: SQLite database to analyze (defaults to ``./database.sqlite``).
        output_dir: Directory that receives ``aggregation_rules.json``
            (defaults to ``./tool_output``). It is created if missing.

    The defaults reproduce the original hard-coded paths, so calling
    ``main()`` with no arguments behaves as before.
    """
    # Analyze first, so a failed analysis does not leave an empty output dir.
    rules = analyze_aggregation_context(db_path)

    os.makedirs(output_dir, exist_ok=True)

    # Save results; use an explicit encoding so the output does not depend
    # on the platform's locale default.
    output_path = os.path.join(output_dir, "aggregation_rules.json")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(rules, f, indent=2)

    print("Aggregation analysis complete")
    print(f"Generated {len(rules['aggregation_patterns'])} aggregation patterns")
    print(f"Generated {len(rules['grouping_rules'])} grouping rules")
    print(f"Identified {len(rules['common_errors'])} common errors")
    print(f"Results saved to {output_path}")
    print("\nREMEMBER: Context determines aggregation level")
    print("\nREMEMBER: Context determines aggregation level")

if __name__ == "__main__":
    # Run the analysis with main()'s default paths only when executed as a
    # script, not when this module is imported.
    main()