#!/usr/bin/env python3
"""
Aggregation Disambiguator Tool
Clarifies when to use different aggregation functions.
"""

import json
import os

def _build_disambiguation_rules() -> dict:
    """Return the static aggregation guidance as a JSON-serializable dict.

    Pure data construction with no I/O. The key order matches the order
    of the sections in the generated JSON file.
    """
    # SUM vs COUNT disambiguation
    sum_vs_count = [
        {
            "use_case": "Counting records/rows/entities",
            "function": "COUNT",
            "examples": [
                "How many customers → COUNT(*)",
                "Number of orders → COUNT(*)",
                "Count of products → COUNT(*)",
            ],
            "warning": "Never use SUM for counting records",
        },
        {
            "use_case": "Counting non-null values in a column",
            "function": "COUNT(column)",
            "examples": [
                "How many customers have email → COUNT(email)",
                "Products with description → COUNT(description)",
            ],
        },
        {
            "use_case": "Counting distinct/unique values",
            "function": "COUNT(DISTINCT column)",
            "examples": [
                "How many different countries → COUNT(DISTINCT country)",
                "Number of unique products → COUNT(DISTINCT product_id)",
            ],
        },
        {
            "use_case": "Adding up numeric values",
            "function": "SUM",
            "examples": [
                "Total revenue → SUM(revenue)",
                "Sum of quantities → SUM(quantity)",
                "Total amount → SUM(amount)",
            ],
            "warning": "Only use SUM for numeric columns",
        },
        {
            "use_case": "Conditional counting",
            "function": "COUNT(CASE WHEN condition THEN 1 END)",
            "alternative": "SUM(CASE WHEN condition THEN 1 ELSE 0 END)",
            "examples": [
                "Count of active users → COUNT(CASE WHEN status='active' THEN 1 END)",
                "Number of high-value orders → COUNT(CASE WHEN amount > 1000 THEN 1 END)",
            ],
        },
    ]

    # Common aggregation patterns
    aggregation_patterns = [
        {
            "question_pattern": "Total [numeric_field]",
            "correct_function": "SUM(numeric_field)",
            "wrong_function": "COUNT(numeric_field)",
            "example": "Total sales → SUM(sales_amount) NOT COUNT(sales_amount)",
        },
        {
            "question_pattern": "How many [entities]",
            "correct_function": "COUNT(*)",
            "wrong_function": "SUM(*)",
            "example": "How many orders → COUNT(*) NOT SUM(order_id)",
        },
        {
            "question_pattern": "Number of [entities] per [group]",
            "correct_function": "COUNT(*) ... GROUP BY group",
            "example": "Orders per customer → SELECT customer_id, COUNT(*) GROUP BY customer_id",
        },
        {
            "question_pattern": "Average [numeric_field]",
            "correct_function": "AVG(numeric_field)",
            "note": "Returns mean value, not sum or count",
        },
        {
            "question_pattern": "Maximum/Highest [numeric_field]",
            "correct_function": "MAX(numeric_field)",
            "note": "For entities with max, use ORDER BY DESC LIMIT 1",
        },
        {
            "question_pattern": "[Entity] with most [related_items]",
            "correct_function": "GROUP BY entity ORDER BY COUNT(*) DESC LIMIT 1",
            "wrong": "MAX(COUNT(*)) - This doesn't return the entity",
        },
    ]

    # Percentage calculation formulas
    percentage_formulas = [
        {
            "type": "Percentage of total",
            "formula": "COUNT(condition) * 100.0 / COUNT(*)",
            "sql": "SELECT COUNT(CASE WHEN condition THEN 1 END) * 100.0 / COUNT(*)",
            "note": "Use 100.0 for decimal result",
        },
        {
            "type": "Percentage with CAST",
            "formula": "CAST(COUNT(condition) AS REAL) * 100 / COUNT(*)",
            "sql": "SELECT CAST(COUNT(CASE WHEN active THEN 1 END) AS REAL) * 100 / COUNT(*)",
            "note": "CAST ensures decimal division",
        },
        {
            "type": "Percentage by group",
            "formula": "COUNT(*) * 100.0 / SUM(COUNT(*)) OVER()",
            "sql": "SELECT category, COUNT(*) * 100.0 / SUM(COUNT(*)) OVER() FROM table GROUP BY category",
            "note": "Window function for group percentages",
        },
    ]

    # GROUP BY rules
    group_by_rules = [
        {
            "rule": "Required when mixing aggregate and non-aggregate columns",
            "example": "SELECT category, COUNT(*) FROM products GROUP BY category",
            "warning": "All non-aggregated columns must be in GROUP BY",
        },
        {
            "rule": "GROUP BY comes before ORDER BY",
            "correct": "GROUP BY category ORDER BY COUNT(*) DESC",
            "wrong": "ORDER BY COUNT(*) DESC GROUP BY category",
        },
        {
            "rule": "Can GROUP BY multiple columns",
            "example": "GROUP BY category, subcategory",
            "result": "Creates groups for each unique combination",
        },
        {
            "rule": "HAVING filters groups, WHERE filters rows",
            "example": "GROUP BY customer HAVING COUNT(*) > 5",
            "note": "Use HAVING for aggregate conditions after grouping",
        },
    ]

    # Common mistakes to avoid
    common_mistakes = [
        {
            "mistake": "Using SUM to count records",
            "wrong": "SELECT SUM(1) FROM table",
            "correct": "SELECT COUNT(*) FROM table",
            "impact": "SUM(1) works but COUNT(*) is clearer",
        },
        {
            "mistake": "COUNT(1) vs COUNT(*)",
            "note": "Both work identically in most databases",
            "recommendation": "Use COUNT(*) for clarity",
        },
        {
            "mistake": "SUM(COUNT(*))",
            "wrong": "SELECT SUM(COUNT(*)) FROM table",
            "correct": "SELECT COUNT(*) FROM table",
            "note": "SUM of COUNT rarely makes sense without GROUP BY",
        },
        {
            "mistake": "Missing DISTINCT when needed",
            "wrong": "COUNT(customer_id) for unique customers",
            "correct": "COUNT(DISTINCT customer_id)",
            "impact": "Counts all occurrences vs unique values",
        },
        {
            "mistake": "Using MAX to find entity with highest value",
            "wrong": "WHERE value = (SELECT MAX(value) ...)",
            "correct": "ORDER BY value DESC LIMIT 1",
            # LIMIT 1 keeps a single row, so it drops ties rather than
            # handling them; the MAX subquery is what returns every tie.
            "reason": (
                "Simpler; returns exactly one row, so use the MAX subquery "
                "when all tied rows are required"
            ),
        },
    ]

    return {
        "sum_vs_count": sum_vs_count,
        "aggregation_patterns": aggregation_patterns,
        "percentage_formulas": percentage_formulas,
        "group_by_rules": group_by_rules,
        "common_mistakes": common_mistakes,
        # No database-specific rules are defined; the key is kept (empty)
        # so the output schema stays the same for existing consumers.
        "database_specific": [],
    }


def disambiguate_aggregations(output_dir: str = "tool_output") -> dict:
    """Generate clear rules for choosing aggregation functions.

    Builds the rule set, writes it as indented JSON to
    ``<output_dir>/aggregation_rules.json`` (creating the directory if
    needed), and prints a short summary to stdout.

    Args:
        output_dir: Directory that receives ``aggregation_rules.json``.
            Defaults to ``"tool_output"`` relative to the current
            working directory.

    Returns:
        The rules dictionary that was written to disk.

    Raises:
        OSError: If the directory cannot be created or the file cannot
            be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    disambiguation_rules = _build_disambiguation_rules()
    output_path = os.path.join(output_dir, "aggregation_rules.json")

    # json.dump escapes non-ASCII characters by default, so the file
    # content is plain ASCII; the explicit encoding only removes the
    # dependence on the platform's default locale.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(disambiguation_rules, f, indent=2)

    print("Aggregation disambiguation complete")
    print(f"Generated {len(disambiguation_rules['sum_vs_count'])} SUM vs COUNT rules")
    print(f"Generated {len(disambiguation_rules['aggregation_patterns'])} aggregation patterns")
    print(f"Results saved to {output_path}")

    return disambiguation_rules

# Script entry point: generate the rules file only when this module is
# executed directly, not when it is imported.
if __name__ == "__main__":
    disambiguate_aggregations()