#!/usr/bin/env python3
"""
Aggregation Disambiguator Tool
Provides clear rules for when to use SUM vs COUNT and other aggregation confusions.
Based on extensive error analysis.
"""

import sqlite3
import json
import os

# Words that contain the substring "count" but do not denote a count
# (e.g. AccountID, Discount, country_id). They are stripped before the
# "count" heuristic so such columns are not described as pre-aggregated counts.
_NON_COUNT_WORDS = ("account", "discount", "country", "county", "encounter")


def _is_numeric_type(declared_type):
    """Return True if *declared_type* names a numeric SQLite column type.

    Uses the substring rules SQLite applies when deriving column affinity
    ("INT" -> INTEGER; "CHAR"/"CLOB"/"TEXT" -> TEXT; "BLOB" -> BLOB;
    "REAL"/"FLOA"/"DOUB" -> REAL). As a result, ``INT``, ``BIGINT``,
    ``DOUBLE PRECISION`` and ``DECIMAL(10,2)`` are recognised alongside
    ``INTEGER``/``REAL``/``NUMERIC``/``FLOAT``/``DOUBLE``.

    SQLite's final catch-all NUMERIC rule is narrowed to declarations that
    explicitly say NUMERIC or DECIMAL. This keeps DATE/BOOLEAN-style types
    from being treated as summable numbers.
    """
    upper = (declared_type or "").upper()
    if "INT" in upper:
        return True
    if any(token in upper for token in ("CHAR", "CLOB", "TEXT", "BLOB")):
        return False
    return any(
        token in upper for token in ("REAL", "FLOA", "DOUB", "NUMERIC", "DECIMAL")
    )


def _looks_like_id(col_name):
    """Return True if *col_name* follows a common identifier naming convention.

    Matches ``id``, ``customer_id``, ``Customer ID``, ``CustomerID`` and
    ``customerId``. Ordinary words that merely end in the letters "id"
    (``paid``, ``valid``, ``VOID``) are not matched.
    """
    lower = col_name.lower()
    if lower == "id" or lower.endswith(("_id", " id")):
        return True
    # camelCase / PascalCase suffix: a lowercase letter directly before "Id"/"ID".
    return (
        len(col_name) > 2
        and col_name[-2:] in ("Id", "ID")
        and col_name[-3].islower()
    )


def _column_hint(col_name, is_id):
    """Return the aggregation hint for a numeric column, or "" if none applies.

    Checks are applied in priority order: quantity, monetary, count, identifier.
    *is_id* says whether the column was classified as an identifier.
    """
    lower = col_name.lower()
    if "qty" in lower or "quantity" in lower:
        return "SUM for total quantity, COUNT for number of records"
    if "amount" in lower or "total" in lower or "price" in lower:
        return "SUM for monetary totals"
    count_probe = lower
    for word in _NON_COUNT_WORDS:
        count_probe = count_probe.replace(word, "")
    if "count" in count_probe:
        return "This is already a count - SUM it, don't COUNT it"
    if is_id:
        return "COUNT or COUNT(DISTINCT) for IDs, never SUM"
    return ""


def disambiguate_aggregations(db_path="database.sqlite", output_dir="tool_output"):
    """Generate disambiguation rules for aggregation functions.

    Builds a report of generic aggregation rules: SUM vs COUNT, GROUP BY,
    HAVING vs WHERE, DISTINCT, and percentage formulas. It adds column-level
    hints derived from the schema of the SQLite database at *db_path*. The
    report is written to ``<output_dir>/aggregation_disambiguation.json``,
    and a short summary is printed.

    Args:
        db_path: Path to an existing SQLite database. It is opened
            read-only, so a wrong path is reported as an error instead of
            silently creating an empty database file.
        output_dir: Directory for the JSON report; created if missing.

    Returns:
        The report dict that was written. If the database could not be
        read, the dict has an ``"error"`` key and no schema hints. The
        generic rules are still present in that case.
    """
    from pathlib import Path

    os.makedirs(output_dir, exist_ok=True)

    # The generic rules do not depend on the database, so they are built up
    # front and survive even if schema inspection below fails.
    disambiguation = {
        "sum_vs_count_rules": [
            {
                "rule": "COUNT for counting records",
                "when": "Counting rows, records, occurrences, or entities",
                "use": "COUNT(*) or COUNT(column)",
                "examples": [
                    "How many orders → COUNT(*)",
                    "Number of customers → COUNT(*)",
                    "Count of transactions → COUNT(*)"
                ],
                "warning": "NEVER use SUM to count records"
            },
            {
                "rule": "SUM for numeric totals",
                "when": "Adding up numeric values",
                "use": "SUM(numeric_column)",
                "examples": [
                    "Total sales amount → SUM(amount)",
                    "Sum of quantities → SUM(quantity)",
                    "Total revenue → SUM(revenue)"
                ],
                "warning": "Only use SUM on numeric value columns"
            },
            {
                "rule": "COUNT for IDs",
                "when": "Counting based on ID columns",
                "use": "COUNT(id) or COUNT(DISTINCT id)",
                "examples": [
                    "Number of unique customers → COUNT(DISTINCT customer_id)",
                    "How many products → COUNT(product_id)"
                ],
                "warning": "Do NOT use SUM(id) to count IDs"
            },
            {
                "rule": "SUM(OrderQty) vs COUNT(OrderQty)",
                "when": "Dealing with quantity columns",
                "use": "SUM for total quantity, COUNT for number of orders",
                "examples": [
                    "Total quantity ordered → SUM(OrderQty)",
                    "Number of orders → COUNT(*) or COUNT(OrderQty)"
                ],
                "warning": "Common confusion - check what's being asked"
            }
        ],
        "group_by_rules": [
            {
                "rule": "Required with mixed aggregates",
                "when": "SELECT has both aggregate and non-aggregate columns",
                "template": "SELECT col1, AGG(col2) FROM table GROUP BY col1",
                "warning": "ALL non-aggregate columns must be in GROUP BY"
            },
            {
                "rule": "GROUP BY before ORDER BY",
                "when": "Grouping and sorting",
                "template": "GROUP BY column ORDER BY aggregate",
                "warning": "GROUP BY comes before ORDER BY"
            },
            {
                "rule": "Not needed for simple aggregates",
                "when": "Only aggregate functions in SELECT",
                "template": "SELECT COUNT(*) FROM table",
                "warning": "No GROUP BY needed if no non-aggregate columns"
            }
        ],
        "having_vs_where": [
            {
                "rule": "WHERE for row filtering",
                "when": "Filter before aggregation",
                "template": "WHERE column = value",
                "example": "WHERE age > 30"
            },
            {
                "rule": "HAVING for aggregate filtering",
                "when": "Filter after aggregation",
                "template": "HAVING COUNT(*) > 5",
                "example": "GROUP BY customer_id HAVING SUM(amount) > 1000"
            }
        ],
        "distinct_rules": [
            {
                "rule": "COUNT(DISTINCT) only when needed",
                "when": "Question explicitly asks for unique/different/distinct",
                "correct": "COUNT(DISTINCT column)",
                "incorrect": "Using DISTINCT when not needed",
                "examples": [
                    "How many customers → COUNT(*) if one row per customer",
                    "How many unique customers → COUNT(DISTINCT customer_id)"
                ]
            },
            {
                "rule": "DISTINCT in SELECT",
                "when": "Remove duplicate rows from results",
                "template": "SELECT DISTINCT column FROM table",
                "warning": "Applied to entire row, not individual columns"
            }
        ],
        "percentage_patterns": [
            {
                "pattern": "Percentage of total",
                "formula": "CAST(COUNT(CASE WHEN condition THEN 1 END) AS REAL) * 100.0 / COUNT(*)",
                "example": "Percentage of active users",
                "warning": "Cast to REAL for decimal division"
            },
            {
                "pattern": "Ratio as percentage",
                "formula": "CAST(numerator AS REAL) * 100 / denominator",
                "example": "Success rate percentage",
                "warning": "Multiply by 100 if percentage needed"
            },
            {
                "pattern": "Percentage change",
                "formula": "((new_value - old_value) * 100.0) / old_value",
                "example": "Percentage increase",
                "warning": "Check if absolute or percentage change needed"
            }
        ],
        "database_specific_hints": []
    }

    conn = None
    try:
        # Open read-only via a URI. A mistyped path raises instead of
        # silently creating an empty database, and the tool can never
        # modify the data.
        db_uri = Path(db_path).absolute().as_uri() + "?mode=ro"
        conn = sqlite3.connect(db_uri, uri=True)
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

        numeric_columns = {}
        id_columns = {}
        for table in tables:
            # Quote as an SQL identifier, doubling embedded quotes, so unusual
            # table names cannot break the PRAGMA statement.
            quoted_table = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_table})")
            numeric_columns[table] = []
            id_columns[table] = set()
            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            for _cid, col_name, col_type, _notnull, _default, pk in cursor.fetchall():
                if _is_numeric_type(col_type):
                    numeric_columns[table].append(col_name)
                if pk or _looks_like_id(col_name):
                    id_columns[table].add(col_name)

        # Only the first 5 tables and the first 3 numeric columns of each
        # table are hinted, which keeps the report short.
        for table in tables[:5]:
            for col in numeric_columns[table][:3]:
                hint = _column_hint(col, col in id_columns[table])
                if hint:
                    disambiguation["database_specific_hints"].append(
                        {"table": table, "column": col, "hint": hint}
                    )
    except Exception as e:
        # Best-effort tool: record the failure in the report rather than
        # crash, so the generic rules are still written out.
        disambiguation["error"] = str(e)
    finally:
        if conn is not None:
            conn.close()

    # Add critical disambiguation notes
    disambiguation["critical_notes"] = [
        "SUM adds values, COUNT counts records - never confuse them",
        "COUNT(*) counts all rows, COUNT(column) counts non-null values",
        "GROUP BY is required when mixing aggregates with regular columns",
        "DISTINCT inside COUNT() for unique values, SELECT DISTINCT for unique rows",
        "Always CAST to REAL for percentage calculations",
        "Evidence formulas override these rules"
    ]

    output_path = os.path.join(output_dir, "aggregation_disambiguation.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(disambiguation, f, indent=2)

    print(f"Aggregation disambiguation complete - results in {output_path}")
    print(f"Generated {len(disambiguation['sum_vs_count_rules'])} SUM vs COUNT rules")
    print(f"Created {len(disambiguation['database_specific_hints'])} database-specific hints")

    return disambiguation

if __name__ == "__main__":
    # Script entry point: build the aggregation report using the default
    # database location.
    disambiguate_aggregations()