#!/usr/bin/env python3
"""
Domain Specialist Tool - Provides domain-specific guidance
Analyzes the database to provide specialized patterns for different domains.
"""

import sqlite3
import json
import os

def analyze_domain_specifics(db_path="database.sqlite"):
    """Detect the database's domain and write a domain guide to JSON.

    Table names read from ``sqlite_master`` are lower-cased and matched by
    substring against keyword groups in this order: medical, weather, food,
    sales. The first matching group sets ``detected_domain`` and runs the
    corresponding ``analyze_*_domain`` helper; when nothing matches, the
    domain stays ``"unknown"`` and the generic helper runs.

    Any exception raised during analysis is recorded under the ``"error"``
    key of the guide instead of propagating, so whatever was collected up to
    that point is still written to ``tool_output/domain_guide.json``.

    Args:
        db_path: Path of the SQLite database file to inspect.
    """

    def _json_fallback(value):
        """Encode values json cannot handle natively (SQLite BLOBs are bytes)."""
        if isinstance(value, (bytes, bytearray, memoryview)):
            return bytes(value).decode("utf-8", errors="replace")
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    os.makedirs("tool_output", exist_ok=True)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    domain_guide = {
        "detected_domain": "unknown",
        "special_patterns": [],
        "critical_columns": {},
        "value_variations": {},
        "common_query_patterns": []
    }

    try:
        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        table_names = [row[0] for row in cursor.fetchall()]
        table_names_lower = [name.lower() for name in table_names]

        # Detect domain; branch order matters because the first match wins.
        if any('patient' in t or 'careplan' in t or 'condition' in t for t in table_names_lower):
            domain_guide["detected_domain"] = "medical"
            analyze_medical_domain(conn, cursor, domain_guide, table_names)

        elif any('weather' in t or 'station' in t for t in table_names_lower):
            domain_guide["detected_domain"] = "weather"
            analyze_weather_domain(conn, cursor, domain_guide, table_names)

        elif any('recipe' in t or 'ingredient' in t or 'nutrition' in t for t in table_names_lower):
            domain_guide["detected_domain"] = "food"
            analyze_food_domain(conn, cursor, domain_guide, table_names)

        elif any('sale' in t or 'store' in t or 'item' in t for t in table_names_lower):
            domain_guide["detected_domain"] = "sales"
            analyze_sales_domain(conn, cursor, domain_guide, table_names)

        else:
            # Generic analysis for unknown domains
            analyze_generic_domain(conn, cursor, domain_guide, table_names)

    except Exception as e:
        # Best-effort tool: keep the partial guide and report the failure in it.
        domain_guide["error"] = str(e)
    finally:
        conn.close()

    # Serialize before opening the file so a serialization failure cannot
    # leave a truncated, invalid JSON document on disk.
    serialized = json.dumps(domain_guide, indent=2, default=_json_fallback)

    # Write to output file
    output_path = "tool_output/domain_guide.json"
    with open(output_path, 'w', encoding="utf-8") as f:
        f.write(serialized)

    print(f"Domain analysis complete - results in {output_path}")
    print(f"Detected domain: {domain_guide['detected_domain']}")

def analyze_medical_domain(conn, cursor, guide, tables):
    """Fill *guide* with medical/healthcare SQL patterns and column notes.

    For every table in *tables*, each column name is checked
    (case-insensitively, by substring) for three hints:

    * ``code`` - up to 10 truthy distinct values are recorded under
      ``critical_columns`` when at least one exists.
    * ``race`` / ``ethnicity`` - truthy distinct values are recorded under
      ``critical_columns`` with a reminder of which demographic this is.
    * ``description`` - distinct values matching ``%grass%`` or ``%allerg%``
      are recorded under ``value_variations`` when any exist.

    Args:
        conn: Open database connection (not used directly here).
        cursor: Cursor used for the schema and sampling queries.
        guide: Guide dictionary, updated in place.
        tables: Names of the tables to inspect.
    """
    age_pattern = dict(
        pattern="Age calculation",
        sql="(julianday(event_date) - julianday(birthdate)) / 365.25",
        usage="Calculate patient age at time of event",
    )
    year_pattern = dict(
        pattern="Year extraction",
        sql="strftime('%Y', date_column)",
        usage="Extract year from dates",
    )
    demographics_pattern = dict(
        pattern="Race vs Ethnicity",
        note="These are usually different columns - check both",
    )
    guide["special_patterns"] = [age_pattern, year_pattern, demographics_pattern]

    for table_name in tables:
        cursor.execute(f"PRAGMA table_info(`{table_name}`)")
        # Collect the names first: the same cursor runs the sampling queries below.
        column_names = [info[1] for info in cursor.fetchall()]

        for column in column_names:
            lowered = column.lower()
            qualified = f"{table_name}.{column}"

            # Medical codes
            if 'code' in lowered:
                cursor.execute(f"SELECT DISTINCT `{column}` FROM `{table_name}` LIMIT 20")
                code_samples = [value for (value,) in cursor.fetchall() if value]
                if code_samples:
                    guide["critical_columns"][qualified] = dict(
                        type="medical_code",
                        samples=code_samples[:10],
                        note="Medical codes must be exact",
                    )

            # Race/ethnicity confusion
            is_race = 'race' in lowered
            if is_race or 'ethnicity' in lowered:
                cursor.execute(f"SELECT DISTINCT `{column}` FROM `{table_name}` LIMIT 10")
                demographic_samples = [value for (value,) in cursor.fetchall() if value]
                if is_race:
                    counterpart = 'ethnicity'
                else:
                    counterpart = 'race'
                guide["critical_columns"][qualified] = dict(
                    type="demographic",
                    samples=demographic_samples,
                    note=f"This is {column} - not {counterpart}",
                )

            # Medical descriptions
            if 'description' in lowered:
                cursor.execute(f"SELECT DISTINCT `{column}` FROM `{table_name}` WHERE `{column}` LIKE '%grass%' OR `{column}` LIKE '%allerg%' LIMIT 10")
                matching_descriptions = [value for (value,) in cursor.fetchall()]
                if matching_descriptions:
                    guide["value_variations"][qualified] = dict(
                        samples=matching_descriptions,
                        note="Exact matching required - 'grass pollen' not 'grass'",
                    )

    guide["common_query_patterns"] = [
        "Patient age calculations: Use julianday for accurate age",
        "Medical codes: Must match exactly from evidence",
        "Conditions: COUNT(DISTINCT) for unique conditions per patient",
        "Date ranges: Use strftime for year/month extraction",
    ]

def analyze_weather_domain(conn, cursor, guide, tables):
    """Fill *guide* with weather/climate SQL patterns and column notes.

    For every table in *tables*, each column name is checked
    (case-insensitively, by substring):

    * names containing ``sunrise``, ``sunset`` or ``time`` get up to three
      truthy distinct sample values recorded as a ``time`` column, when at
      least one such value exists;
    * names containing ``temp``, ``tmax``, ``tmin`` or ``depart`` are
      recorded as a ``temperature`` column. This check runs second, so it
      replaces a ``time`` entry for a column matching both.

    Args:
        conn: Open database connection (not used directly here).
        cursor: Cursor used for the schema and sampling queries.
        guide: Guide dictionary, updated in place.
        tables: Names of the tables to inspect.
    """
    time_keywords = ('sunrise', 'sunset', 'time')
    temperature_keywords = ('temp', 'tmax', 'tmin', 'depart')

    guide["special_patterns"] = [
        dict(
            pattern="Date matching between tables",
            sql="sales.date = weather.date AND relation.station = weather.station",
            usage="Join sales and weather data",
        ),
        dict(
            pattern="Time comparisons",
            sql="time(sunrise) < time('05:00:00')",
            usage="Compare time values",
        ),
        dict(
            pattern="Conditional weather sums",
            sql="SUM(CASE WHEN temperature > 90 THEN units ELSE 0 END)",
            usage="Sum only when weather condition met",
        ),
    ]

    for table_name in tables:
        cursor.execute(f"PRAGMA table_info(`{table_name}`)")
        # Collect the names first: the same cursor runs the sampling queries below.
        column_names = [info[1] for info in cursor.fetchall()]

        for column in column_names:
            lowered = column.lower()
            qualified = f"{table_name}.{column}"

            # Time columns
            looks_like_time = any(keyword in lowered for keyword in time_keywords)
            if looks_like_time:
                cursor.execute(f"SELECT DISTINCT `{column}` FROM `{table_name}` LIMIT 5")
                time_samples = [value for (value,) in cursor.fetchall() if value]
                if time_samples:
                    guide["critical_columns"][qualified] = dict(
                        type="time",
                        samples=time_samples[:3],
                        note="Use time() function for comparisons",
                    )

            # Temperature columns
            looks_like_temperature = any(
                keyword in lowered for keyword in temperature_keywords
            )
            if looks_like_temperature:
                guide["critical_columns"][qualified] = dict(
                    type="temperature",
                    note="Numeric comparisons, often used in conditions",
                )

    guide["common_query_patterns"] = [
        "Date-based joins: Match on date AND location",
        "Weather conditions: Use CASE WHEN for conditional aggregation",
        "Time comparisons: Always use time() function",
        "Station relations: Join through relation/mapping table",
    ]

def analyze_food_domain(conn, cursor, guide, tables):
    """Fill *guide* with recipe/food SQL patterns and ingredient name variations.

    For every table whose name contains ``ingredient`` (case-insensitive),
    each column whose name matches ``%name%`` is sampled (up to 50 distinct
    values). String values are grouped by a normalized base name: lower-cased
    with ``steak`` and ``fillet`` removed and surrounding whitespace stripped.
    Every base that maps to more than one stored value is reported as a
    variation group.

    The entry stored under ``guide["value_variations"]["<table>.<column>"]``
    contains:

    * ``all_variations`` - a list of ``{"base", "variants"}`` dictionaries,
      one per variation group found in the column;
    * ``base`` / ``variants`` - the last group found, kept so consumers of
      the single-group layout keep working;
    * ``note`` - a fixed reminder to use the exact stored variant.

    Args:
        conn: Open database connection (not used directly here).
        cursor: Cursor used for the schema and sampling queries.
        guide: Guide dictionary, updated in place.
        tables: Names of the tables to inspect.
    """

    def _quote_identifier(name):
        """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
        return '"' + name.replace('"', '""') + '"'

    guide["special_patterns"] = [
        {
            "pattern": "Category matching",
            "sql": "category LIKE '%dairy%'",
            "usage": "Use wildcards on both sides"
        },
        {
            "pattern": "Nutritional thresholds",
            "sql": "sodium < 5",
            "usage": "Use < not BETWEEN 0 AND X"
        },
        {
            "pattern": "Quantity conditions",
            "sql": "max_qty = min_qty",
            "usage": "Check quantity relationships"
        }
    ]

    # Look for ingredient variations
    for table in tables:
        if 'ingredient' not in table.lower():
            continue

        # Bind the table name as a parameter so names containing quote
        # characters cannot break (or alter) the statement.
        cursor.execute(
            "SELECT name FROM pragma_table_info(?) WHERE name LIKE '%name%'",
            (table,),
        )
        # Collect the names first: the same cursor runs the sampling queries below.
        name_columns = [row[0] for row in cursor.fetchall()]

        for col_name in name_columns:
            cursor.execute(
                f"SELECT DISTINCT {_quote_identifier(col_name)} "
                f"FROM {_quote_identifier(table)} LIMIT 50"
            )

            # Group stored values by their normalized base name.
            variations = {}
            for (ingredient,) in cursor.fetchall():
                if not ingredient or not isinstance(ingredient, str):
                    continue
                base = ingredient.lower().replace('steak', '').replace('fillet', '').strip()
                variations.setdefault(base, []).append(ingredient)

            # Keep every base with more than one stored spelling. Storing the
            # groups in a list avoids each group overwriting the previous one
            # under the shared "<table>.<column>" key.
            groups = [
                {"base": base, "variants": variants}
                for base, variants in variations.items()
                if len(variants) > 1
            ]
            if not groups:
                continue

            last_group = groups[-1]
            guide["value_variations"][f"{table}.{col_name}"] = {
                "base": last_group["base"],
                "variants": last_group["variants"],
                "all_variations": groups,
                "note": "Must use exact variant from database"
            }

    guide["common_query_patterns"] = [
        "Ingredient matching: Use exact names (e.g., 'sea bass steak')",
        "Categories: Use %category% with wildcards on both sides",
        "Nutritional values: Simple < or > comparisons",
        "Recipe counts: Often need COUNT(DISTINCT recipe_id)"
    ]

def analyze_sales_domain(conn, cursor, guide, tables):
    """Fill *guide* with the fixed sales/retail SQL patterns and query tips.

    No database queries are issued; only the ``special_patterns`` and
    ``common_query_patterns`` entries of *guide* are replaced.

    Args:
        conn: Open database connection (unused).
        cursor: Database cursor (unused).
        guide: Guide dictionary, updated in place.
        tables: Table names (unused).
    """
    store_station_join = dict(
        pattern="Store-station relationships",
        sql="JOIN relation ON sales.store = relation.store",
        usage="Connect stores to weather stations",
    )
    year_filter = dict(
        pattern="Date-based aggregation",
        sql="substr(date, 1, 4) = '2012'",
        usage="Filter by year",
    )
    query_tips = (
        "Sales aggregation: SUM(units) or COUNT(*)",
        "Store relationships: Join through relation table",
        "Item analysis: GROUP BY item_nbr",
        "Date filtering: Use substr or strftime for date parts",
    )

    guide["special_patterns"] = [store_station_join, year_filter]
    guide["common_query_patterns"] = list(query_tips)

def analyze_generic_domain(conn, cursor, guide, tables):
    """Fill *guide* with fallback patterns for databases of unrecognized domain.

    No database queries are issued; only the ``special_patterns`` and
    ``common_query_patterns`` entries of *guide* are replaced.

    Args:
        conn: Open database connection (unused).
        cursor: Database cursor (unused).
        guide: Guide dictionary, updated in place.
        tables: Table names (unused).
    """
    aggregation_pattern = dict(
        pattern="Standard aggregation",
        sql="COUNT(*), SUM(column), AVG(column)",
        usage="Common aggregation functions",
    )
    query_tips = (
        "Use COUNT(DISTINCT) for unique counts",
        "Return only requested columns",
        "Match values exactly from database",
        "Use simplest query structure possible",
    )

    guide["special_patterns"] = [aggregation_pattern]
    guide["common_query_patterns"] = list(query_tips)

if __name__ == "__main__":
    # Script entry point: run the analysis with its default database path.
    analyze_domain_specifics()