#!/usr/bin/env python3
"""
Format Detector Tool
Specifically detects and documents special data formats that require conversion.
Addresses format-related errors identified in the BIRD paper.
"""

import sqlite3
import json
import re
import os
from typing import Dict, List, Any

def detect_currency_format(values: List[str]) -> Dict[str, Any]:
    """Detect a currency format among the sampled values.

    Only the first 20 entries of ``values`` are inspected. Patterns are tried
    in the order they are listed below, and the first pattern that matches at
    least one sampled value decides the result.

    Args:
        values: Sample column values; falsy entries (``None``, ``''``) are
            skipped.

    Returns:
        A dict with ``type``, ``format``, ``pattern``, ``conversion_template``
        and up to three ``examples``, or an empty dict when nothing matches.
        ``conversion_template`` is a SQL expression in which the word
        ``column`` stands in for the real column name.
    """
    # name -> (regex, human-readable description, SQL conversion template).
    # Single-symbol prefixes are dropped with SUBSTR(column, 2); the "US$"
    # prefix is three characters long, hence SUBSTR(column, 4).
    currency_regexes = {
        'us_dollar_prefix': (r'^US\$[\d,]+\.?\d*$', "US$ prefix format", "CAST(REPLACE(REPLACE(SUBSTR(column, 4), ',', ''), '$', '') AS REAL)"),
        'dollar_sign': (r'^\$[\d,]+\.?\d*$', "$ prefix format", "CAST(REPLACE(SUBSTR(column, 2), ',', '') AS REAL)"),
        'euro_sign': (r'^€[\d,]+\.?\d*$', "€ prefix format", "CAST(REPLACE(SUBSTR(column, 2), ',', '') AS REAL)"),
        'pound_sign': (r'^£[\d,]+\.?\d*$', "£ prefix format", "CAST(REPLACE(SUBSTR(column, 2), ',', '') AS REAL)"),
        'suffix_currency': (r'^[\d,]+\s*(USD|EUR|GBP|JPY|CNY)$', "Suffix currency code", "CAST(REPLACE(SUBSTR(column, 1, LENGTH(column)-3), ',', '') AS REAL)"),
    }

    sample = values[:20]
    for pattern, description, conversion in currency_regexes.values():
        matching_values = [v for v in sample if v and re.match(pattern, str(v))]
        if matching_values:
            return {
                'type': 'currency',
                'format': description,
                'pattern': pattern,
                'conversion_template': conversion,
                'examples': matching_values[:3]
            }

    return {}

def detect_percentage_format(values: List[Any]) -> Dict[str, Any]:
    """Detect percentage-like values by numeric range or a trailing ``%``.

    Two heuristics are applied in order:

    1. Every entry among the first 50 that ``float()`` can parse is
       collected. If all of them lie in ``[0, 1]`` the column is reported as
       a decimal percentage; otherwise, if all lie in ``[0, 100]``, as a
       0-100 percentage. Entries that cannot be parsed are ignored.
    2. Otherwise the first 20 entries are searched for numbers followed by a
       ``%`` sign (for example ``'12.5%'``). The prefix must be a plain
       number because the conversion template only strips the ``%``.

    Args:
        values: Sample column values of any type; ``None`` entries are
            skipped.

    Returns:
        A dict with ``type``, ``format``, ``conversion_template`` and up to
        three ``examples``, or an empty dict when neither heuristic applies.
    """
    numeric_values = []
    for v in values[:50]:
        if v is None:
            continue
        try:
            numeric_values.append(float(v))
        except (TypeError, ValueError, OverflowError):
            # Not a number: ignore it, the range check uses what parsed.
            continue

    if numeric_values:
        # All values in 0-1: likely a fraction that needs scaling by 100.
        if all(0 <= n <= 1 for n in numeric_values):
            return {
                'type': 'percentage',
                'format': 'decimal (0-1)',
                'conversion_template': 'column * 100',
                'examples': numeric_values[:3]
            }
        # All values in 0-100: likely already expressed as a percentage.
        if all(0 <= n <= 100 for n in numeric_values):
            return {
                'type': 'percentage',
                'format': 'percentage (0-100)',
                'conversion_template': 'column',
                'examples': numeric_values[:3]
            }

    # Numbers written with a trailing % sign. Text such as 'abc%' is rejected
    # because CAST(REPLACE(column, '%', '') AS REAL) would turn it into 0.
    percent_pattern = r'\s*[-+]?(?:\d+\.?\d*|\.\d+)\s*%'
    percent_values = [
        v for v in values[:20]
        if v and re.fullmatch(percent_pattern, str(v))
    ]
    if percent_values:
        return {
            'type': 'percentage',
            'format': 'with % sign',
            'conversion_template': "CAST(REPLACE(column, '%', '') AS REAL)",
            'examples': percent_values[:3]
        }

    return {}

def detect_date_format(values: List[str]) -> Dict[str, Any]:
    """Detect a date/datetime text format among the sampled values.

    Only the first 20 entries of ``values`` are inspected. Patterns are tried
    in the order listed below, and the first one that matches at least one
    sampled value decides the result.

    Args:
        values: Sample column values; falsy entries are skipped.

    Returns:
        A dict with ``type``, ``format``, ``pattern``, ``conversion_template``,
        ``year_extraction``, ``month_extraction`` and up to three
        ``examples``, or an empty dict when nothing matches. In the SQL
        templates the word ``column`` stands in for the real column name.
    """
    # SQLite's DATE() only understands ISO-8601 style strings, so a value
    # such as '1/5/2020' has to be split on its slashes and re-assembled as
    # zero-padded 'YYYY-MM-DD'. Years shorter than four digits are
    # zero-padded, not expanded to a century.
    first_slash = "instr(column, '/')"
    after_month = f"substr(column, {first_slash} + 1)"
    month = f"substr(column, 1, {first_slash} - 1)"
    day = f"substr({after_month}, 1, instr({after_month}, '/') - 1)"
    year = f"substr({after_month}, instr({after_month}, '/') + 1)"
    unpadded_mdy = f"DATE(printf('%04d-%02d-%02d', {year}, {month}, {day}))"

    # (regex, format label, SQL conversion template). Order matters: the
    # strict MM/DD/YYYY form must be tried before the looser M/D/YY one.
    date_patterns = [
        (r'^\d{4}-\d{2}-\d{2}$', 'YYYY-MM-DD', "DATE(column)"),
        (r'^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}', 'YYYY-MM-DD HH:MM:SS', "DATE(column)"),
        (r'^\d{2}/\d{2}/\d{4}$', 'MM/DD/YYYY', "DATE(substr(column,7,4)||'-'||substr(column,1,2)||'-'||substr(column,4,2))"),
        (r'^\d{1,2}/\d{1,2}/\d{2,4}$', 'M/D/YY', unpadded_mdy),
        (r'^\d{8}$', 'YYYYMMDD', "DATE(substr(column,1,4)||'-'||substr(column,5,2)||'-'||substr(column,7,2))"),
        (r'^\d{4}/\d{2}/\d{2}$', 'YYYY/MM/DD', "DATE(REPLACE(column, '/', '-'))"),
        (r'^\d{14}$', 'YYYYMMDDHHmmss', "DATETIME(substr(column,1,4)||'-'||substr(column,5,2)||'-'||substr(column,7,2)||' '||substr(column,9,2)||':'||substr(column,11,2)||':'||substr(column,13,2))"),
    ]

    sample = values[:20]
    for pattern, format_name, conversion in date_patterns:
        matching_values = [v for v in sample if v and re.match(pattern, str(v))]
        if matching_values:
            return {
                'type': 'date',
                'format': format_name,
                'pattern': pattern,
                'conversion_template': conversion,
                'year_extraction': f"strftime('%Y', {conversion})",
                'month_extraction': f"strftime('%m', {conversion})",
                'examples': matching_values[:3]
            }

    return {}

def detect_numeric_in_text(values: List[str]) -> Dict[str, Any]:
    """Detect numbers stored as text with thousands separators or parentheses.

    Only the first 20 entries of ``values`` are inspected. Two shapes are
    recognised, in this order:

    * digits with optional commas, an optional decimal part and an optional
      leading minus sign (plain digit strings such as ``'123'`` match too);
    * the same shape wrapped in parentheses, treated as a negative number.

    Both shapes must contain at least one digit, so strings made only of
    separators (``','`` or ``'(,)'``) are not reported.

    Args:
        values: Sample column values; falsy entries are skipped.

    Returns:
        A dict with ``type``, ``format``, ``conversion_template`` and up to
        three ``examples``, or an empty dict when nothing matches. In the SQL
        template the word ``column`` stands in for the real column name.
    """
    sample = [v for v in values[:20] if v]

    # The (?=.*\d) lookahead rejects separator-only strings.
    comma_numbers = [v for v in sample if re.match(r'^-?(?=.*\d)[\d,]+\.?\d*$', str(v))]
    if comma_numbers:
        return {
            'type': 'numeric_text',
            'format': 'comma-separated',
            'conversion_template': "CAST(REPLACE(column, ',', '') AS REAL)",
            'examples': comma_numbers[:3]
        }

    # Parenthesised numbers: the template strips the parentheses and negates.
    paren_numbers = [v for v in sample if re.match(r'^\((?=.*\d)[\d,]+\.?\d*\)$', str(v))]
    if paren_numbers:
        return {
            'type': 'numeric_text',
            'format': 'parentheses (negative)',
            'conversion_template': "-1 * CAST(REPLACE(REPLACE(REPLACE(column, '(', ''), ')', ''), ',', '') AS REAL)",
            'examples': paren_numbers[:3]
        }

    return {}

def _quote_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier with quotes escaped."""
    return '"' + str(name).replace('"', '""') + '"'

def _has_text_affinity(declared_type: str) -> bool:
    """Return True if a declared column type gets TEXT affinity in SQLite.

    Follows SQLite's affinity rules: a type containing "INT" is INTEGER;
    otherwise one containing "CHAR", "CLOB" or "TEXT" is TEXT.
    """
    upper = (declared_type or '').upper()
    if 'INT' in upper:
        return False
    return any(token in upper for token in ('CHAR', 'CLOB', 'TEXT'))

def analyze_table_formats(conn: sqlite3.Connection, table: str) -> Dict[str, Any]:
    """Analyze every column of a table for special value formats.

    For each column, up to 100 distinct non-NULL values are sampled and run
    through the detectors in a fixed order: currency, percentage, date and,
    for text-typed columns only, numeric-in-text. The first detector that
    returns a non-empty result wins.

    Args:
        conn: Open SQLite connection.
        table: Name of the table to inspect.

    Returns:
        Mapping of column name to the detected format dict. Columns with no
        non-NULL values or no detected format are omitted.
    """
    cursor = conn.cursor()
    quoted_table = _quote_identifier(table)

    # PRAGMA table_info rows: index 1 is the column name, index 2 its
    # declared type.
    cursor.execute(f"PRAGMA table_info({quoted_table})")
    columns = [(col[1], col[2]) for col in cursor.fetchall()]

    format_info = {}

    for col_name, col_type in columns:
        quoted_col = _quote_identifier(col_name)
        cursor.execute(
            f"SELECT DISTINCT {quoted_col} FROM {quoted_table} "
            f"WHERE {quoted_col} IS NOT NULL LIMIT 100"
        )
        values = [row[0] for row in cursor.fetchall()]

        if not values:
            continue

        text_values = [str(v) for v in values]

        # Detector order defines precedence; later ones run only if the
        # earlier ones found nothing. Percentage detection receives the raw
        # values because it relies on numeric conversion.
        detected_format = detect_currency_format(text_values)

        if not detected_format:
            detected_format = detect_percentage_format(values)

        if not detected_format:
            detected_format = detect_date_format(text_values)

        # Declared types such as 'text', 'VARCHAR(50)' or 'NVARCHAR' all
        # count as text columns here.
        if not detected_format and _has_text_affinity(col_type):
            detected_format = detect_numeric_in_text(text_values)

        if detected_format:
            format_info[col_name] = detected_format

    return format_info

def main(db_path: str = './database.sqlite', output_dir: str = './tool_output') -> None:
    """Detect special formats in every user table and write three reports.

    Args:
        db_path: Path of the SQLite database to analyze.
        output_dir: Directory that receives the reports; created if missing.

    Writes, inside ``output_dir``:
        * ``format_analysis.json`` - the detected format per table and column;
        * ``conversion_templates.json`` - per ``table.column`` key, the SQL
          conversion with the ``column`` placeholder replaced by the
          backtick-quoted column name;
        * ``format_summary.txt`` - a human-readable summary.
    """
    all_formats = {}
    conversion_templates = {}

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Names starting with "sqlite_" are SQLite's own bookkeeping tables
        # (for example sqlite_sequence), so they are not analyzed. The
        # underscore is escaped because LIKE treats it as a wildcard.
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\' "
            "ORDER BY name"
        )
        tables = [row[0] for row in cursor.fetchall()]

        print("Detecting special formats in database...")

        for table in tables:
            print(f"  Analyzing {table}...")
            table_formats = analyze_table_formats(conn, table)

            if not table_formats:
                continue

            all_formats[table] = table_formats

            # Build ready-to-use conversions for this table's columns.
            for col_name, format_info in table_formats.items():
                if 'conversion_template' in format_info:
                    key = f"{table}.{col_name}"
                    conversion_templates[key] = {
                        'format': format_info['format'],
                        'conversion': format_info['conversion_template'].replace('column', f"`{col_name}`"),
                        'examples': format_info.get('examples', [])
                    }
    finally:
        # Release the connection even if analysis of a table fails.
        conn.close()

    os.makedirs(output_dir, exist_ok=True)
    analysis_path = f"{output_dir}/format_analysis.json"
    templates_path = f"{output_dir}/conversion_templates.json"
    summary_path = f"{output_dir}/format_summary.txt"

    # Explicit UTF-8 so the reports do not depend on the platform locale.
    with open(analysis_path, 'w', encoding='utf-8') as f:
        json.dump(all_formats, f, indent=2, default=str)

    with open(templates_path, 'w', encoding='utf-8') as f:
        json.dump(conversion_templates, f, indent=2, default=str)

    # Human-readable summary.
    summary_lines = ["=== SPECIAL FORMAT DETECTION ===\n"]

    for table, formats in all_formats.items():
        summary_lines.append(f"\n{table}:")
        for col, format_info in formats.items():
            summary_lines.append(f"  {col}: {format_info['type']} - {format_info['format']}")
            if 'conversion_template' in format_info:
                summary_lines.append(f"    Conversion: {format_info['conversion_template']}")
            if 'examples' in format_info:
                summary_lines.append(f"    Examples: {format_info['examples'][:2]}")

    if not all_formats:
        summary_lines.append("\nNo special formats detected.")

    # Example values may contain non-ASCII text such as currency symbols.
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(summary_lines))

    print("\nFormat detection complete!")
    print(f"  Detailed analysis: {analysis_path}")
    print(f"  Conversion templates: {templates_path}")
    print(f"  Summary: {summary_path}")

# Run the detector only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()