#!/usr/bin/env python3
"""
Evidence Pattern Extractor - Analyzes database patterns to help with evidence interpretation
Focuses on identifying naming patterns, value patterns, and hints for common
query structures (calculations, temporal columns, aggregations)
"""

import json
import logging
import os
import re
import sqlite3
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set

def connect_db(db_path: str) -> sqlite3.Connection:
    """Open the SQLite database at ``db_path`` and return the connection."""
    connection = sqlite3.connect(db_path)
    return connection

def extract_naming_patterns(conn: sqlite3.Connection) -> Dict[str, Any]:
    """Extract naming patterns and create evidence mapping hints.

    Every table listed in ``sqlite_master`` is inspected with
    ``PRAGMA table_info`` and each column name is classified.

    Args:
        conn: Open SQLite connection to inspect.

    Returns:
        A dict with these keys (columns are written as ``"table.column"``):

        * ``column_aliases``: alias -> list of columns. Aliases are the
          abbreviation expansions plus ``name``, ``temporal``, ``monetary``
          and ``identifier``.
        * ``table_aliases``: always empty; kept for output compatibility.
        * ``common_prefixes`` / ``common_suffixes``: first / last
          underscore-separated token -> columns, kept only when shared by
          more than one column.
        * ``abbreviations``: column -> the last expansion of the last
          abbreviation that matched it.
    """
    cursor = conn.cursor()

    # Get all tables
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    naming_patterns = {
        'column_aliases': defaultdict(list),
        'table_aliases': defaultdict(list),
        'common_prefixes': defaultdict(list),
        'common_suffixes': defaultdict(list),
        'abbreviations': {}
    }

    # Common abbreviations and their expansions
    abbrev_map = {
        'num': ['number', 'numeric'],
        'qty': ['quantity'],
        'amt': ['amount'],
        'desc': ['description'],
        'addr': ['address'],
        'dept': ['department'],
        'emp': ['employee'],
        'mgr': ['manager', 'management'],
        'cust': ['customer'],
        'prod': ['product'],
        'cat': ['category'],
        'loc': ['location'],
        'org': ['organization'],
        'acct': ['account'],
        'trans': ['transaction'],
        'ref': ['reference'],
        'id': ['identifier', 'identity'],
        'dt': ['date'],
        'tm': ['time'],
        'yr': ['year'],
        'mth': ['month'],
        'grp': ['group'],
        'stat': ['status', 'state'],
        'val': ['value'],
        'pct': ['percent', 'percentage'],
        'avg': ['average'],
        'min': ['minimum'],
        'max': ['maximum']
    }

    def add_alias(alias: str, qualified: str) -> None:
        """Record ``qualified`` under ``alias`` without duplicating it."""
        bucket = naming_patterns['column_aliases'][alias]
        if qualified not in bucket:
            bucket.append(qualified)

    for table in tables:
        # Quote the identifier so table names containing spaces, punctuation
        # or embedded double quotes do not break the PRAGMA statement.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")

        for col in cursor.fetchall():
            col_name = col[1]
            col_lower = col_name.lower()
            qualified = f"{table}.{col_name}"

            # Extract prefixes and suffixes
            parts = col_lower.split('_')
            if len(parts) > 1:
                naming_patterns['common_prefixes'][parts[0]].append(qualified)
                naming_patterns['common_suffixes'][parts[-1]].append(qualified)

            # Check for abbreviations.
            # NOTE: this is a substring test, so short abbreviations also
            # match inside longer words (e.g. 'id' inside 'paid').
            for abbrev, expansions in abbrev_map.items():
                if abbrev in col_lower:
                    for expansion in expansions:
                        add_alias(expansion, qualified)
                        naming_patterns['abbreviations'][qualified] = expansion

            # Common column name patterns
            if 'name' in col_lower:
                add_alias('name', qualified)
            if 'date' in col_lower or 'time' in col_lower:
                add_alias('temporal', qualified)
            if any(word in col_lower for word in ('amount', 'price', 'cost', 'total')):
                add_alias('monetary', qualified)
            if col_name.endswith(('_id', 'ID')) or col_name == 'id':
                add_alias('identifier', qualified)

    # Clean up patterns - only keep those that appear multiple times
    for key in ['common_prefixes', 'common_suffixes']:
        naming_patterns[key] = {k: v for k, v in naming_patterns[key].items() if len(v) > 1}

    return naming_patterns

def extract_value_patterns(conn: sqlite3.Connection) -> Dict[str, Any]:
    """Extract common value patterns for evidence matching.

    Each column of each table listed in ``sqlite_master`` is profiled on a
    best-effort basis: a column whose queries fail with a SQLite error is
    logged as a warning and skipped instead of aborting the run.

    Args:
        conn: Open SQLite connection to inspect.

    Returns:
        A dict with these list-valued keys:

        * ``categorical_columns``: columns with fewer than 50 distinct values
          and more than twice as many rows as distinct values, with up to 20
          sample categories.
        * ``boolean_columns``: categorical columns with exactly 2 distinct values.
        * ``date_formats``: columns whose declared type mentions date/time,
          with the format reported by ``detect_date_format``.
        * ``numeric_ranges``: min/max/average plus year/percentage/currency
          heuristics for columns declared INT/REAL/NUMERIC.
        * ``text_patterns``: result of ``analyze_text_patterns`` on up to 100
          values of columns declared TEXT/CHAR.
    """
    cursor = conn.cursor()

    value_patterns = {
        'categorical_columns': [],
        'date_formats': [],
        'numeric_ranges': [],
        'text_patterns': [],
        'boolean_columns': []
    }

    # Get all tables
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    for table in tables:
        # Quote identifiers so names with spaces, punctuation or embedded
        # double quotes produce valid SQL.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        columns = cursor.fetchall()

        for col in columns:
            col_name = col[1]
            col_type = (col[2] or '').upper()
            quoted_col = '"' + col_name.replace('"', '""') + '"'

            try:
                # Check for categorical columns (low cardinality)
                cursor.execute(f"SELECT COUNT(DISTINCT {quoted_col}), COUNT(*) FROM {quoted_table}")
                distinct, total = cursor.fetchone()

                if distinct and total and distinct < 50 and total > distinct * 2:
                    # Get the categories
                    cursor.execute(
                        f"SELECT DISTINCT {quoted_col} FROM {quoted_table} "
                        f"WHERE {quoted_col} IS NOT NULL LIMIT 20"
                    )
                    categories = [row[0] for row in cursor.fetchall()]

                    value_patterns['categorical_columns'].append({
                        'table': table,
                        'column': col_name,
                        'categories': categories,
                        'cardinality': distinct
                    })

                    # Check for boolean-like columns
                    if distinct == 2:
                        value_patterns['boolean_columns'].append({
                            'table': table,
                            'column': col_name,
                            'values': categories
                        })

                # Check for date patterns
                if 'DATE' in col_type or 'TIME' in col_type:
                    cursor.execute(
                        f"SELECT DISTINCT {quoted_col} FROM {quoted_table} "
                        f"WHERE {quoted_col} IS NOT NULL LIMIT 5"
                    )
                    samples = [row[0] for row in cursor.fetchall()]

                    if samples:
                        # Try to detect date format
                        date_format = detect_date_format(samples)
                        if date_format:
                            value_patterns['date_formats'].append({
                                'table': table,
                                'column': col_name,
                                'format': date_format,
                                'samples': samples[:3]
                            })

                # Check for numeric ranges
                if 'INT' in col_type or 'REAL' in col_type or 'NUMERIC' in col_type:
                    cursor.execute(
                        f"SELECT MIN({quoted_col}), MAX({quoted_col}), AVG({quoted_col}) "
                        f"FROM {quoted_table}"
                    )
                    min_val, max_val, avg_val = cursor.fetchone()

                    # SQLite columns are dynamically typed, so a column declared
                    # numeric can still hold text/blob values; comparing those
                    # with ints would raise TypeError, so skip the range then.
                    if isinstance(min_val, (int, float)) and isinstance(max_val, (int, float)):
                        value_patterns['numeric_ranges'].append({
                            'table': table,
                            'column': col_name,
                            'min': min_val,
                            'max': max_val,
                            'average': avg_val,
                            'likely_year': 1900 <= min_val <= 2100 and 1900 <= max_val <= 2100,
                            'likely_percentage': 0 <= min_val <= 100 and 0 <= max_val <= 100,
                            # NOTE(review): SQLite's AVG() yields a float for any
                            # non-NULL input, so the '.' test is almost always
                            # true and this is effectively `min_val >= 0`.
                            'likely_currency': min_val >= 0 and (max_val > 100 or '.' in str(avg_val))
                        })

                # Check for text patterns
                if 'TEXT' in col_type or 'CHAR' in col_type:
                    cursor.execute(
                        f"SELECT {quoted_col} FROM {quoted_table} "
                        f"WHERE {quoted_col} IS NOT NULL LIMIT 100"
                    )
                    # Keep only non-empty strings; other storage classes
                    # cannot be analysed as text.
                    samples = [row[0] for row in cursor.fetchall()
                               if isinstance(row[0], str) and row[0]]

                    if samples:
                        patterns = analyze_text_patterns(samples)
                        if patterns:
                            value_patterns['text_patterns'].append({
                                'table': table,
                                'column': col_name,
                                'patterns': patterns
                            })
            except sqlite3.Error as exc:
                # Best effort: report the column that could not be profiled
                # and carry on with the remaining columns.
                logging.getLogger(__name__).warning(
                    "Skipping %s.%s during value pattern extraction: %s",
                    table, col_name, exc,
                )

    return value_patterns

def detect_date_format(samples: List) -> Optional[str]:
    """Detect a common date format from sample values.

    Samples are checked in order; the first truthy sample that fully matches
    a known pattern (after stripping surrounding whitespace) decides the
    result.

    Args:
        samples: Sample column values; non-strings are converted with ``str``.

    Returns:
        ``None`` when ``samples`` is empty, the name of the matched format
        (e.g. ``'YYYY-MM-DD'``), or ``'UNKNOWN'`` when no sample matches.
    """
    if not samples:
        return None

    # Common date patterns. The labels only describe the digit layout; the
    # field order of ambiguous layouts (MM/DD vs DD/MM) is not verified.
    patterns = [
        (r'\d{4}-\d{2}-\d{2}', 'YYYY-MM-DD'),
        (r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}', 'YYYY-MM-DD HH:MM:SS'),
        (r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}', 'YYYY-MM-DD HH:MM:SS.SSS'),
        (r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}', 'YYYY-MM-DDTHH:MM:SS'),
        (r'\d{2}/\d{2}/\d{4}', 'MM/DD/YYYY'),
        (r'\d{2}-\d{2}-\d{4}', 'DD-MM-YYYY'),
        (r'\d{4}', 'YYYY'),
        (r'\d{4}-\d{2}', 'YYYY-MM'),
    ]

    for sample in samples:
        if not sample:
            continue
        sample_str = str(sample).strip()
        for pattern, format_name in patterns:
            if re.fullmatch(pattern, sample_str):
                return format_name

    return 'UNKNOWN'

def analyze_text_patterns(samples: List[str]) -> Dict:
    """Analyze text samples for common patterns.

    Empty values and non-string values are ignored.

    Args:
        samples: Sample text values from one column.

    Returns:
        A dict of flags and statistics aggregated over all usable samples:

        * ``has_uppercase`` / ``has_lowercase`` / ``has_numbers``: at least
          one sample contains such a character.
        * ``has_mixed_case``: both upper- and lowercase occur somewhere in
          the samples (not necessarily in the same sample).
        * ``has_special_chars``: some character is neither alphanumeric nor
          whitespace.
        * ``common_prefix`` / ``common_suffix``: longest string shared by all
          samples at the start / end, or ``None`` when there is none or fewer
          than two samples.
        * ``average_length``: mean sample length (0 when nothing is usable).
        * ``likely_email`` / ``likely_url`` / ``likely_phone``: at least one
          sample looks like that kind of value.
        * ``likely_code``: short (average length < 10), uppercase-only values.
    """
    patterns = {
        'has_uppercase': False,
        'has_lowercase': False,
        'has_mixed_case': False,
        'has_numbers': False,
        'has_special_chars': False,
        'common_prefix': None,
        'common_suffix': None,
        'average_length': 0,
        'likely_email': False,
        'likely_url': False,
        'likely_phone': False,
        'likely_code': False
    }

    # Only non-empty strings can be analysed; anything else is skipped so a
    # stray NULL or number does not raise.
    texts = [s for s in (samples or []) if isinstance(s, str) and s]
    if not texts:
        return patterns

    patterns['has_uppercase'] = any(c.isupper() for s in texts for c in s)
    patterns['has_lowercase'] = any(c.islower() for s in texts for c in s)
    patterns['has_mixed_case'] = patterns['has_uppercase'] and patterns['has_lowercase']
    patterns['has_numbers'] = any(c.isdigit() for s in texts for c in s)
    patterns['has_special_chars'] = any(
        not (c.isalnum() or c.isspace()) for s in texts for c in s
    )
    patterns['average_length'] = sum(len(s) for s in texts) / len(texts)

    # Check for specific patterns
    patterns['likely_email'] = any('@' in s and '.' in s for s in texts)
    patterns['likely_url'] = any(s.startswith('http') or 'www.' in s for s in texts)
    phone_like = re.compile(r'[\d\-\(\)\+\s]{10,}')
    patterns['likely_phone'] = any(phone_like.match(s) for s in texts)

    # A shared prefix/suffix is only meaningful across several samples.
    if len(texts) > 1:
        # commonprefix compares character by character, so it works on
        # arbitrary strings; reversing the samples yields the common suffix.
        prefix = os.path.commonprefix(texts)
        suffix = os.path.commonprefix([s[::-1] for s in texts])[::-1]
        patterns['common_prefix'] = prefix or None
        patterns['common_suffix'] = suffix or None

    # Check if likely a code (short, uppercase, maybe with numbers)
    if patterns['average_length'] < 10 and patterns['has_uppercase'] and not patterns['has_lowercase']:
        patterns['likely_code'] = True

    return patterns

def generate_evidence_mappings(naming_patterns: Dict, value_patterns: Dict) -> Dict[str, Any]:
    """Combine naming and value patterns into evidence interpretation hints.

    Args:
        naming_patterns: Dict providing ``column_aliases`` and ``abbreviations``.
        value_patterns: Dict providing ``categorical_columns``,
            ``numeric_ranges`` and ``date_formats``.

    Returns:
        Dict with ``column_mappings``, ``value_mappings``,
        ``calculation_patterns``, ``temporal_patterns`` and
        ``aggregation_hints``.
    """
    alias_index = naming_patterns['column_aliases']

    # Column mapping suggestions; the generic buckets are not real aliases.
    generic_aliases = ('identifier', 'temporal', 'monetary')
    column_mappings = {}
    for alias, cols in alias_index.items():
        if alias in generic_aliases:
            continue
        column_mappings[alias] = {
            'possible_columns': cols,
            'confidence': 'HIGH' if len(cols) == 1 else 'MEDIUM',
        }

    # Abbreviation expansions only fill in aliases that are still missing.
    for qualified_col, expansion in naming_patterns['abbreviations'].items():
        column_mappings.setdefault(expansion, {
            'possible_columns': [qualified_col],
            'confidence': 'HIGH',
        })

    # Value mappings for small categorical columns, keyed by lowercased value.
    value_mappings = {}
    for entry in value_patterns['categorical_columns']:
        if not entry['cardinality'] < 10:
            continue
        for category in entry['categories']:
            if not (category and isinstance(category, str)):
                continue
            lowered = str(category).lower()
            value_mappings[f"{entry['table']}.{entry['column']}.{lowered}"] = {
                'table': entry['table'],
                'column': entry['column'],
                'value': category,
                'type': 'categorical',
            }

    # Calculation patterns: percentage takes precedence over currency.
    calculation_patterns = []
    for rng in value_patterns['numeric_ranges']:
        if rng['likely_percentage']:
            label, hint = 'percentage', "Values range 0-100, likely percentage"
        elif rng['likely_currency']:
            label, hint = 'currency', "Likely monetary amount"
        else:
            continue
        calculation_patterns.append({
            'table': rng['table'],
            'column': rng['column'],
            'pattern': label,
            'hint': hint,
        })

    # Temporal patterns are passed through field by field.
    temporal_patterns = [
        {field: info[field] for field in ('table', 'column', 'format', 'samples')}
        for info in value_patterns['date_formats']
    ]

    # Aggregation hints for the generic monetary / identifier buckets.
    aggregation_hints = []
    hint_specs = (
        ('monetary', ['SUM', 'AVG', 'MAX', 'MIN']),
        ('identifier', ['COUNT(DISTINCT col)']),
    )
    for kind, operations in hint_specs:
        matching_cols = alias_index.get(kind, [])
        if matching_cols:
            aggregation_hints.append({
                'columns': matching_cols,
                'suggested_operations': operations,
                'type': kind,
            })

    return {
        'column_mappings': column_mappings,
        'value_mappings': value_mappings,
        'calculation_patterns': calculation_patterns,
        'temporal_patterns': temporal_patterns,
        'aggregation_hints': aggregation_hints,
    }

def main():
    """Run the extraction on ``./database.sqlite`` and write the JSON report.

    The report is written to ``./tool_output/evidence_patterns.json`` (the
    directory is created if needed) and a short summary is printed.
    """
    db_path = "./database.sqlite"
    output_dir = "./tool_output"
    os.makedirs(output_dir, exist_ok=True)

    conn = connect_db(db_path)
    try:
        # Extract patterns
        naming_patterns = extract_naming_patterns(conn)
        value_patterns = extract_value_patterns(conn)
    finally:
        # Release the database handle even when extraction fails.
        conn.close()

    evidence_mappings = generate_evidence_mappings(naming_patterns, value_patterns)

    # Create comprehensive report
    pattern_report = {
        'naming_patterns': naming_patterns,
        'value_patterns': value_patterns,
        'evidence_mappings': evidence_mappings,
        'summary': {
            'total_categorical_columns': len(value_patterns['categorical_columns']),
            'total_boolean_columns': len(value_patterns['boolean_columns']),
            'total_date_columns': len(value_patterns['date_formats']),
            'total_numeric_columns': len(value_patterns['numeric_ranges']),
            'total_abbreviations': len(naming_patterns['abbreviations']),
            'total_column_mappings': len(evidence_mappings['column_mappings'])
        }
    }

    # Write output; default=str stringifies values json cannot encode
    # natively (e.g. bytes read from BLOB cells).
    with open(f"{output_dir}/evidence_patterns.json", 'w', encoding='utf-8') as f:
        json.dump(pattern_report, f, indent=2, default=str)

    print("Evidence pattern extraction complete")
    print(f"Found {pattern_report['summary']['total_column_mappings']} column mappings")
    print(f"Identified {pattern_report['summary']['total_categorical_columns']} categorical columns")
    print(f"Detected {len(evidence_mappings['calculation_patterns'])} calculation patterns")

# Run the extraction only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()