#!/usr/bin/env python3
"""
Column Selector Tool
Maps question patterns to required columns for precise SQL generation.
Prevents returning wrong or extra columns.
"""

import sqlite3
import json
import os

# Substrings that mark a column as holding a human-readable name or label.
_NAME_PATTERNS = (
    'name', 'firstname', 'lastname', 'first', 'last',
    'fullname', 'full_name', 'username', 'display_name',
    'title', 'label', 'description', 'company', 'brand',
)

# Substrings that mark a numeric column as a metric.
_METRIC_KEYWORDS = (
    'amount', 'total', 'count', 'sum', 'quantity',
    'price', 'cost', 'value', 'score', 'rating',
    'salary', 'revenue', 'profit', 'balance', 'qty',
)

# Declared column types treated as numeric. One shared set so that the
# "metric_columns" list and the "column_purpose_map" always agree.
_NUMERIC_TYPES = frozenset({'INTEGER', 'REAL', 'NUMERIC', 'FLOAT', 'DOUBLE'})

# Declared column types treated as temporal.
_TEMPORAL_TYPES = frozenset({'DATE', 'DATETIME', 'TIMESTAMP'})

# Exact (lower-cased) column names treated as categorical.
_CATEGORICAL_NAMES = frozenset(
    {'status', 'state', 'type', 'category', 'active', 'enabled', 'flag'}
)


def _default_selection_patterns():
    """Return a fresh list of the database-independent selection patterns."""
    return [
        {
            "pattern": "Who/Which person",
            "guidance": "Return ONLY name columns (FirstName, LastName or equivalent)",
            "example": "SELECT first_name, last_name",
            "warning": "NEVER return IDs for 'who' questions"
        },
        {
            "pattern": "What [entity]",
            "guidance": "Return ONLY the entity's name/title column",
            "example": "What product → SELECT product_name",
            "warning": "Do NOT return ID unless 'ID' is explicitly mentioned"
        },
        {
            "pattern": "How many",
            "guidance": "Return ONLY COUNT(*) or COUNT(column)",
            "example": "SELECT COUNT(*)",
            "warning": "NEVER return what's being counted, just the count"
        },
        {
            "pattern": "List [entities]",
            "guidance": "Return ONLY identifying columns for the entity",
            "example": "List products → SELECT product_name",
            "warning": "Do NOT add extra columns for context"
        },
        {
            "pattern": "Total/Sum",
            "guidance": "Return ONLY the aggregate",
            "example": "SELECT SUM(amount)",
            "warning": "Use SUM for values, COUNT for records"
        },
        {
            "pattern": "[Entity] with most/highest",
            "guidance": "Return entity identifier ONLY, not the metric",
            "example": "Customer with most orders → SELECT customer_name ... ORDER BY COUNT(*) DESC LIMIT 1",
            "warning": "Do NOT include the count in SELECT"
        },
        {
            "pattern": "Average/Mean",
            "guidance": "Return ONLY AVG(column)",
            "example": "SELECT AVG(price)",
            "warning": "Just the average, no additional columns"
        }
    ]


def _strict_selection_rules():
    """Return a fresh list of the database-independent strict rules."""
    return [
        {
            "rule": "Minimal Columns",
            "description": "Return the MINIMUM columns needed to answer the question",
            "examples": [
                "Asked for name → Return name only",
                "Asked for count → Return count only",
                "Asked for product → Return product name only"
            ]
        },
        {
            "rule": "No Extra Context",
            "description": "Do NOT add columns for 'helpful context'",
            "examples": [
                "Wrong: SELECT id, name, description",
                "Right: SELECT name"
            ]
        },
        {
            "rule": "Evidence Override",
            "description": "Evidence defines exact columns to return",
            "examples": [
                "Evidence: 'name refers to FirstName, LastName' → Return both",
                "Evidence: 'return ID' → Include ID even if not typical"
            ]
        },
        {
            "rule": "COUNT vs DISTINCT COUNT",
            "description": "Use COUNT(DISTINCT) ONLY when 'unique', 'different', or 'distinct' mentioned",
            "examples": [
                "How many images → COUNT(*)",
                "How many unique images → COUNT(DISTINCT img_id)"
            ]
        }
    ]


def _matches_any(text, needles):
    """Return True if any of *needles* occurs as a substring of *text*."""
    return any(needle in text for needle in needles)


def _column_purpose(col_lower, col_type, is_pk):
    """Classify one column; the first matching rule wins."""
    if is_pk:
        return "primary_key"
    if col_lower.endswith('_id'):
        return "foreign_key"
    if _matches_any(col_lower, _NAME_PATTERNS):
        return "identifier"
    if col_type in _TEMPORAL_TYPES or 'date' in col_lower or 'time' in col_lower:
        return "temporal"
    if col_type in _NUMERIC_TYPES and _matches_any(col_lower, _METRIC_KEYWORDS):
        return "metric"
    if col_lower in _CATEGORICAL_NAMES:
        return "categorical"
    return "data"


def analyze_column_selection(db_path: str = "database.sqlite") -> dict:
    """Analyze columns to provide selection guidance.

    Inspects every user table of the SQLite database at *db_path*, classifies
    its columns by name and declared type (ID, name, metric, purpose), and
    writes the result together with the static selection patterns and strict
    rules to ``tool_output/column_guidance.json`` (relative to the current
    working directory). Three summary lines are printed to stdout.

    Errors raised while reading the schema are not propagated: the message is
    stored under the ``"error"`` key and whatever was collected so far is
    still written. The static patterns and rules are always present, since
    they do not depend on the database.

    Args:
        db_path: Path of the SQLite database file to inspect.

    Returns:
        The guidance dictionary that was written to the JSON file.
    """
    os.makedirs("tool_output", exist_ok=True)

    # Static guidance is filled in up front so that a schema-introspection
    # failure cannot leave the output without any patterns or rules.
    column_guidance = {
        "entity_columns": {},
        "name_columns": {},
        "id_columns": {},
        "metric_columns": {},
        "selection_patterns": _default_selection_patterns(),
        "column_purpose_map": {},
        "strict_selection_rules": _strict_selection_rules(),
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")

        for (table_name,) in cursor.fetchall():
            # Skip SQLite's internal bookkeeping tables.
            if table_name.startswith("sqlite_"):
                continue

            entity_cols = column_guidance["entity_columns"][table_name] = {}
            name_cols = column_guidance["name_columns"][table_name] = []
            id_cols = column_guidance["id_columns"][table_name] = []
            metric_cols = column_guidance["metric_columns"][table_name] = []
            purposes = column_guidance["column_purpose_map"][table_name] = {}

            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier the standard SQL way (double quotes, doubled inside).
            quoted_table = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_table})")

            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            for col in cursor.fetchall():
                col_name = col[1]
                col_type = (col[2] or "").upper()
                is_pk = bool(col[5])
                col_lower = col_name.lower()

                # ID columns: primary keys and anything whose name ends in "id".
                if is_pk or col_lower.endswith('id'):
                    id_cols.append(col_name)
                    # "<entity>_id" names the entity the column refers to.
                    if col_lower.endswith('_id'):
                        entity_cols[col_lower[:-3]] = col_name

                if _matches_any(col_lower, _NAME_PATTERNS):
                    name_cols.append(col_name)

                if col_type in _NUMERIC_TYPES and _matches_any(col_lower, _METRIC_KEYWORDS):
                    metric_cols.append(col_name)

                purposes[col_name] = _column_purpose(col_lower, col_type, is_pk)

        # Per-table person pattern, using the columns actually present so the
        # example never names a column that does not exist in the table.
        for table_name, name_cols in column_guidance["name_columns"].items():
            first_col = next((c for c in name_cols if 'first' in c.lower()), None)
            last_col = next((c for c in name_cols if 'last' in c.lower()), None)

            if first_col is not None and last_col is not None:
                column_guidance["selection_patterns"].append({
                    "pattern": f"Person from {table_name}",
                    "guidance": "Return FirstName and LastName columns",
                    "example": f"SELECT {first_col}, {last_col} FROM {table_name}",
                    "warning": "Include both name parts for person identification"
                })

    except Exception as e:
        # Best effort: record the failure and still emit what was gathered.
        column_guidance["error"] = str(e)
    finally:
        conn.close()

    # Write to output file
    output_path = "tool_output/column_guidance.json"
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(column_guidance, f, indent=2)

    print(f"Column selection analysis complete - results in {output_path}")
    print(f"Generated {len(column_guidance['selection_patterns'])} selection patterns")
    print(f"Created {len(column_guidance['strict_selection_rules'])} strict rules")

    return column_guidance

if __name__ == "__main__":
    # Executed as a script: run the analysis with the function's default
    # database path.
    analyze_column_selection()