#!/usr/bin/env python3
"""
Maps aggregation targets and identifies GROUP BY requirements.
Determines when DISTINCT is needed and what entities to count.
"""

import sqlite3
import json
import os
from collections import defaultdict

def _quote_ident(name):
    """Return ``name`` quoted as an SQLite identifier.

    Embedded double quotes are doubled, so names containing spaces or quotes,
    or that are reserved words (``order``, ``group``), can be interpolated
    into PRAGMA and SELECT statements safely.
    """
    return '"' + name.replace('"', '""') + '"'


def _primary_key_columns(columns):
    """Return primary-key column names from ``PRAGMA table_info`` rows.

    Field 5 of each row is the column's 1-based position within the primary
    key (0 when it is not part of the key). Sorting on it keeps composite
    keys in key order rather than table-column order.
    """
    pk_rows = sorted((col for col in columns if col[5] > 0), key=lambda col: col[5])
    return [col[1] for col in pk_rows]


def _unique_columns(cursor, table):
    """Return the sorted, de-duplicated columns covered by any UNIQUE index."""
    cursor.execute(f"PRAGMA index_list({_quote_ident(table)})")
    indexes = cursor.fetchall()  # materialize before reusing the cursor below

    unique_cols = set()
    for idx in indexes:
        if idx[2] == 1:  # unique flag
            cursor.execute(f"PRAGMA index_info({_quote_ident(idx[1])})")
            # Field 2 is the column name; it is None for expression indexes.
            unique_cols.update(col[2] for col in cursor.fetchall() if col[2] is not None)
    return sorted(unique_cols)


def _countable_entities(cursor, table, column_names, pk_columns):
    """Return the columns of ``table`` that are useful as COUNT targets.

    The primary key (if any) comes first. It is followed by every ID-like
    column, meaning one whose name ends in "id" or is code/key/identifier.
    An ID-like column is classified by its distinct/total row ratio:
    > 0.9 is a high-cardinality ID, and > 0.1 is a categorical ID.
    Columns at or below 0.1 are ignored.
    """
    countable = []

    # Primary key is always countable
    if pk_columns:
        countable.append({
            "column": pk_columns[0] if len(pk_columns) == 1 else f"({', '.join(pk_columns)})",
            "type": "primary_key",
            "use_for": "counting unique records"
        })

    quoted_table = _quote_ident(table)
    for col_name in column_names:
        lowered = col_name.lower()
        # endswith("id") already covers "_id" suffixes and the bare name "id".
        if not (lowered.endswith("id") or lowered in ("code", "key", "identifier")):
            continue

        cursor.execute(
            f"SELECT COUNT(DISTINCT {_quote_ident(col_name)}), COUNT(*) FROM {quoted_table}"
        )
        distinct_count, total_count = cursor.fetchone()
        if not (distinct_count and total_count):
            continue

        uniqueness_ratio = distinct_count / total_count
        if uniqueness_ratio > 0.9:  # Mostly unique
            countable.append({
                "column": col_name,
                "type": "high_cardinality_id",
                "uniqueness": round(uniqueness_ratio, 3),
                "use_for": "counting distinct entities"
            })
        elif uniqueness_ratio > 0.1:  # Categorical
            countable.append({
                "column": col_name,
                "type": "categorical_id",
                "distinct_values": distinct_count,
                "use_for": "grouping or counting categories"
            })

    return countable


def _grouping_candidates(cursor, table, column_names, pk_columns, limit=10):
    """Return up to ``limit`` non-PK columns suitable for GROUP BY.

    A column qualifies when, over its non-NULL rows, the ratio of distinct
    values to rows lies strictly between 0.001 and 0.5 and it has fewer
    than 1000 distinct values. Candidates are ordered by ascending ratio.
    Each column is measured with a full COUNT(DISTINCT) scan, not a sample.
    """
    grouping = []
    quoted_table = _quote_ident(table)
    for col_name in column_names:
        if col_name in pk_columns:
            continue

        quoted_col = _quote_ident(col_name)
        cursor.execute(f"""
            SELECT COUNT(DISTINCT {quoted_col}) as distinct_vals,
                   COUNT(*) as total_rows
            FROM {quoted_table}
            WHERE {quoted_col} IS NOT NULL
        """)
        distinct_vals, total_rows = cursor.fetchone()
        if not (distinct_vals and total_rows):
            continue

        cardinality_ratio = distinct_vals / total_rows
        # Good for grouping if cardinality is not too high or too low
        if 0.001 < cardinality_ratio < 0.5 and distinct_vals < 1000:
            grouping.append({
                "column": col_name,
                "distinct_values": distinct_vals,
                "cardinality_ratio": round(cardinality_ratio, 3),
                "good_for": "GROUP BY operations"
            })

    # Lower cardinality ratio is often better for grouping
    grouping.sort(key=lambda x: x["cardinality_ratio"])
    return grouping[:limit]


def _has_foreign_keys(cursor, table):
    """Return True if ``table`` declares any foreign key.

    ``PRAGMA foreign_key_list`` reports both table-level ``FOREIGN KEY``
    constraints and column-level ``REFERENCES`` clauses. A plain text search
    of the CREATE statement would miss the latter.
    """
    cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
    return bool(cursor.fetchall())


def _aggregation_hints(aggregation_info):
    """Build per-table hint lists from the collected countable/DISTINCT info."""
    hints_by_table = {}
    for table, countable in aggregation_info["countable_entities"].items():
        hints = []

        if len(countable) > 1:
            hints.append("Multiple countable columns available - choose based on context")

        # A table with foreign keys and an "_id" countable column is treated as a junction table
        if (table in aggregation_info["distinct_requirements"]
                and any("_id" in col["column"] for col in countable)):
            hints.append("Junction table - use DISTINCT when counting through joins")

        if hints:
            hints_by_table[table] = hints
    return hints_by_table


def _render_summary(aggregation_info):
    """Return the human-readable summary as a list of newline-terminated strings."""
    summary = ["# Aggregation Pattern Analysis\n\n", "## Primary Keys (for COUNT DISTINCT):\n"]
    for table, pks in aggregation_info["primary_keys"].items():
        summary.append(f"- {table}: {', '.join(pks)}\n")

    if aggregation_info["countable_entities"]:
        summary.append("\n## Countable Entities by Table:\n")
        for table, entities in aggregation_info["countable_entities"].items():
            summary.append(f"\n{table}:\n")
            for entity in entities[:3]:  # Top 3
                summary.append(f"  - {entity['column']}: {entity['use_for']}\n")

    if aggregation_info["grouping_candidates"]:
        summary.append("\n## Best GROUP BY Columns:\n")
        for table, candidates in aggregation_info["grouping_candidates"].items():
            summary.append(f"\n{table}:\n")
            for candidate in candidates[:3]:  # Top 3
                summary.append(f"  - {candidate['column']}: {candidate['distinct_values']} distinct values\n")

    if aggregation_info["distinct_requirements"]:
        summary.append("\n## Tables Requiring DISTINCT:\n")
        for table, req in aggregation_info["distinct_requirements"].items():
            summary.append(f"- {table}: {req['hint']}\n")

    return summary


def _write_json(aggregation_info, output_dir):
    """Write ``aggregation_info`` to ``<output_dir>/aggregation_patterns.json``."""
    with open(os.path.join(output_dir, "aggregation_patterns.json"), "w", encoding="utf-8") as f:
        json.dump(aggregation_info, f, indent=2)


def analyze_aggregation_patterns(db_path="database.sqlite", output_dir="tool_output"):
    """Profile every table in an SQLite database for aggregation planning.

    For each table the function records:

    * primary keys;
    * UNIQUE-indexed columns;
    * countable entities;
    * GROUP BY candidates;
    * whether foreign keys make DISTINCT advisable when joining.

    Per-table hints are then derived from those results.

    Results are written to ``<output_dir>/aggregation_patterns.json`` and a
    text summary to ``<output_dir>/aggregation_summary.txt``.

    Errors are not raised. Any exception during analysis is printed, and the
    partial results are written to the JSON file with an ``"error"`` key.
    The database connection is always closed.

    Args:
        db_path: Path to the SQLite database file. NOTE: ``sqlite3.connect``
            creates an empty database if the file does not exist.
        output_dir: Directory for the report files; created if missing.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    aggregation_info = {
        "primary_keys": {},
        "unique_columns": {},
        "countable_entities": {},
        "grouping_candidates": {},
        "distinct_requirements": {},
        "aggregation_hints": {}
    }

    try:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]

        for table in tables:
            cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
            columns = cursor.fetchall()
            column_names = [col[1] for col in columns]

            pk_columns = _primary_key_columns(columns)
            if pk_columns:
                aggregation_info["primary_keys"][table] = pk_columns

            unique_cols = _unique_columns(cursor, table)
            if unique_cols:
                aggregation_info["unique_columns"][table] = unique_cols

            countable = _countable_entities(cursor, table, column_names, pk_columns)
            if countable:
                aggregation_info["countable_entities"][table] = countable

            grouping = _grouping_candidates(cursor, table, column_names, pk_columns)
            if grouping:
                aggregation_info["grouping_candidates"][table] = grouping

            # Tables with foreign keys can multiply rows through joins
            if _has_foreign_keys(cursor, table):
                aggregation_info["distinct_requirements"][table] = {
                    "has_foreign_keys": True,
                    "hint": "Use COUNT(DISTINCT primary_key) when joining to avoid duplicates"
                }

        aggregation_info["aggregation_hints"] = _aggregation_hints(aggregation_info)

        os.makedirs(output_dir, exist_ok=True)
        _write_json(aggregation_info, output_dir)
        with open(os.path.join(output_dir, "aggregation_summary.txt"), "w", encoding="utf-8") as f:
            f.writelines(_render_summary(aggregation_info))

        print(f"Aggregation pattern analysis complete - results in {output_dir}/")

    except Exception as e:
        print(f"Error analyzing aggregation patterns: {e}")
        aggregation_info["error"] = str(e)
        # The failure may predate the makedirs above; ensure the error report can be written.
        os.makedirs(output_dir, exist_ok=True)
        _write_json(aggregation_info, output_dir)

    finally:
        conn.close()


if __name__ == "__main__":
    # Script entry point: analyze the default database file in the working directory.
    analyze_aggregation_patterns(db_path="database.sqlite")