#!/usr/bin/env python3
"""
Precision Analyzer - Focused on ownership clarity and column selection accuracy
Primary goal: Prevent returning wrong columns in SQL queries
"""

import sqlite3
import sys
import os
from collections import defaultdict
import re

def connect_db(db_path='database.sqlite'):
    """Open the SQLite database, searching nearby directories if needed.

    When ``db_path`` does not exist, a file named ``database.sqlite`` is
    looked for in the current directory, its parent and its grandparent
    (in that order) and the first one found is used instead.

    Args:
        db_path: Path of the database file to open.

    Returns:
        An open ``sqlite3.Connection``.

    Raises:
        SystemExit: With status 1, after printing an error line, when no
            database file could be located.
    """
    if not os.path.exists(db_path):
        nearby = (
            os.path.join(parent, 'database.sqlite')
            for parent in ('.', '..', '../..')
        )
        # Keep the caller's path when none of the fallbacks exist either
        db_path = next((path for path in nearby if os.path.exists(path)), db_path)

    if os.path.exists(db_path):
        return sqlite3.connect(db_path)

    print("ERROR: Database not found")
    sys.exit(1)

def get_structure(conn):
    """Collect columns, row count and foreign keys for every table.

    Table names are double-quoted before being placed into SQL so that
    names containing spaces or matching reserved words (``order``,
    ``group``, ``Order Details`` ...) do not break the queries.

    Args:
        conn: Open ``sqlite3.Connection`` to inspect.

    Returns:
        Dict keyed by table name, in the name order returned by SQLite.
        Each value is a dict with:

        * ``'columns'``: list of ``(name, declared_type, pk)`` tuples as
          reported by ``PRAGMA table_info``; ``pk`` is 0 for columns that
          are not part of the primary key and non-zero otherwise.
        * ``'row_count'``: result of ``SELECT COUNT(*)`` on the table.
        * ``'foreign_keys'``: list of ``(from_column, to_table, to_column)``
          tuples as reported by ``PRAGMA foreign_key_list``.
    """
    cursor = conn.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in cursor.fetchall()]

    structure = {}
    for table in tables:
        # Identifiers cannot be bound as parameters; quote them instead,
        # doubling any embedded double quote per SQL rules.
        quoted = '"' + table.replace('"', '""') + '"'

        cursor.execute(f"PRAGMA table_info({quoted})")
        columns = cursor.fetchall()

        cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
        row_count = cursor.fetchone()[0]

        cursor.execute(f"PRAGMA foreign_key_list({quoted})")
        fks = cursor.fetchall()

        structure[table] = {
            'columns': [(c[1], c[2], c[5]) for c in columns],  # name, type, is_pk
            'row_count': row_count,
            'foreign_keys': [(fk[3], fk[2], fk[4]) for fk in fks]  # from, to_table, to_col
        }

    return structure

def analyze_ownership(structure):
    """Split each table's columns into owned attributes and references.

    A column that appears as the source of a foreign key is reported as a
    reference (``"col → table.col"``). Every other column is "owned" by
    the table, except a primary-key column that is merely a bare
    identifier: ``id``, ``pk`` or ``<table>_id``, compared
    case-insensitively.

    Args:
        structure: Mapping as produced by ``get_structure``: table name to
            a dict with ``'columns'`` (``(name, type, is_pk)`` tuples),
            ``'foreign_keys'`` (``(from, to_table, to_col)`` tuples) and
            ``'row_count'``.

    Returns:
        Dict keyed by table name with ``'owns'`` (list of column names),
        ``'references'`` (list of formatted strings) and ``'row_count'``.
    """
    ownership = {}

    for table, info in structure.items():
        # First foreign key wins when a column takes part in several
        fk_targets = {}
        for from_col, to_table, to_col in info['foreign_keys']:
            fk_targets.setdefault(from_col, (to_table, to_col))

        # Lowercase the table name too: the column side is lowercased, so
        # "Customer"."Customer_ID" must still count as a bare identifier.
        bare_id_names = {'id', 'pk', f'{table.lower()}_id'}

        owned = []
        references = []
        for col_name, _col_type, is_pk in info['columns']:
            if col_name in fk_targets:
                to_table, to_col = fk_targets[col_name]
                references.append(f"{col_name} → {to_table}.{to_col}")
            elif not (is_pk and col_name.lower() in bare_id_names):
                owned.append(col_name)

        ownership[table] = {
            'owns': owned,
            'references': references,
            'row_count': info['row_count']
        }

    return ownership

def discover_patterns(conn, structure, sample_size=15):
    """Sample each non-empty table and report value-format patterns.

    Only the first ``sample_size`` rows of each table are inspected, so
    every finding is a heuristic based on that sample.

    Args:
        conn: Open ``sqlite3.Connection``.
        structure: Mapping as produced by ``get_structure``; the
            ``'columns'`` order must match the table's column order.
        sample_size: Maximum number of rows read per table.

    Returns:
        Dict with four sub-dicts keyed by ``"table.column"``:

        * ``'dates'``: columns whose name contains ``date``/``time`` and
          whose first sampled strings start with ``YYYY-MM-DD``.
        * ``'categories'``: columns with at most 15 distinct sampled
          values that repeat (distinct < 60% of non-NULL samples); the
          value is the sorted list of distinct values.
        * ``'numbers'``: numeric ``price``/``cost`` columns that look like
          integer cents, and ``percent``/``rate`` columns within [0, 1].
        * ``'strings'``: always empty; kept for interface stability.
    """
    cursor = conn.cursor()
    patterns = {
        'dates': {},
        'categories': {},
        'numbers': {},
        'strings': {}
    }

    for table, info in structure.items():
        if info['row_count'] == 0:
            continue

        # Quote the identifier so names with spaces/keywords work, and
        # bind the limit rather than splicing it into the SQL text.
        quoted = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"SELECT * FROM {quoted} LIMIT ?", (int(sample_size),))
        rows = cursor.fetchall()

        if not rows:
            continue

        for col_idx, (col_name, col_type, _) in enumerate(info['columns']):
            values = [row[col_idx] for row in rows if row[col_idx] is not None]

            if not values:
                continue

            col_key = f"{table}.{col_name}"
            name_lower = col_name.lower()

            # Date detection: look at the first two samples only
            if 'date' in name_lower or 'time' in name_lower:
                if any(
                    isinstance(val, str) and re.match(r'\d{4}-\d{2}-\d{2}', val)
                    for val in values[:2]
                ):
                    patterns['dates'][col_key] = {
                        'format': 'YYYY-MM-DD',
                        'examples': values[:2]
                    }

            # Categorical detection
            unique = list(set(values))
            if len(unique) <= 15 and len(unique) < len(values) * 0.6:
                try:
                    ordered = sorted(unique)
                except TypeError:
                    # SQLite columns may mix storage classes (e.g. int and
                    # text); group by type so the sort cannot fail.
                    ordered = sorted(unique, key=lambda v: (type(v).__name__, v))
                patterns['categories'][col_key] = ordered

            # Special numeric patterns
            if any(t in col_type.upper() for t in ['INT', 'REAL', 'NUM']):
                if 'price' in name_lower or 'cost' in name_lower:
                    nums = [v for v in values if isinstance(v, (int, float))]
                    # Large whole numbers in a money column suggest cents
                    if nums and all(isinstance(n, int) and n > 100 for n in nums[:3]):
                        patterns['numbers'][col_key] = {
                            'type': 'cents',
                            'examples': nums[:3]
                        }
                elif 'percent' in name_lower or 'rate' in name_lower:
                    nums = [v for v in values if isinstance(v, (int, float))]
                    if nums and all(0 <= n <= 1 for n in nums):
                        patterns['numbers'][col_key] = {
                            'type': 'decimal_percentage',
                            'examples': nums[:3]
                        }

    return patterns

def create_evidence_map(structure):
    """Map lookup terms to the ``table.column`` names they may refer to.

    Every column is indexed under its own lowercase name. Columns whose
    name contains certain keywords are additionally indexed under a
    generic term (``name``, ``date``, ``population``, ``income``,
    ``average``, ``total``). A column is listed at most once per term,
    even when its own name equals the generic term.

    Args:
        structure: Mapping as produced by ``get_structure``; only the
            ``'columns'`` entry of each table is read.

    Returns:
        Plain dict mapping each term to a list of ``"table.column"``
        strings, in table/column iteration order.
    """
    # (substrings to look for in the lowercase column name, generic term)
    keyword_terms = (
        (('name',), 'name'),
        (('date', 'time'), 'date'),
        (('pop',), 'population'),
        (('income', 'salary'), 'income'),
        (('avg', 'average'), 'average'),
        (('count', 'total'), 'total'),
    )

    evidence = defaultdict(list)

    for table, info in structure.items():
        for col_name, _, _ in info['columns']:
            col_lower = col_name.lower()
            full_name = f"{table}.{col_name}"

            # Ordered, de-duplicated terms for this column: the direct
            # mapping first, then any generic terms it matches.
            terms = {col_lower: None}
            for needles, term in keyword_terms:
                if any(needle in col_lower for needle in needles):
                    terms[term] = None

            for term in terms:
                evidence[term].append(full_name)

    return dict(evidence)

def identify_junction_tables(structure):
    """Return the names of tables that look like many-to-many link tables.

    A table qualifies when it has at least two foreign-key entries and at
    most two columns beyond those entries.

    Args:
        structure: Mapping of table name to a dict holding ``'columns'``
            and ``'foreign_keys'`` lists.

    Returns:
        List of qualifying table names, in ``structure`` iteration order.
    """
    def looks_like_junction(info):
        n_fks = len(info['foreign_keys'])
        extra_columns = len(info['columns']) - n_fks
        return n_fks >= 2 and extra_columns <= 2

    return [name for name, info in structure.items() if looks_like_junction(info)]

def main():
    """Run every analysis step on the database and print the report.

    Connects via ``connect_db()`` with its default path, then prints six
    numbered sections to stdout: complexity, ownership map, format
    patterns, evidence map, relationships (plus junction tables) and a
    fixed list of column-selection reminders. The connection is closed
    even if one of the steps raises.
    """
    conn = connect_db()

    try:
        print("=== PRECISION DATABASE ANALYSIS ===\n")

        # Get structure
        structure = get_structure(conn)

        # Complexity is judged by table count alone
        total_tables = len(structure)
        total_cols = sum(len(info['columns']) for info in structure.values())

        if total_tables <= 5:
            complexity = "Simple"
        elif total_tables <= 10:
            complexity = "Medium"
        else:
            complexity = "Complex"

        print("## 1. COMPLEXITY ASSESSMENT")
        print(f"{complexity}: {total_tables} tables, {total_cols} total columns")
        print(f"Analysis depth: {'Standard' if complexity == 'Simple' else 'Enhanced'}\n")

        # Ownership
        ownership = analyze_ownership(structure)

        print("## 2. ATTRIBUTION OWNERSHIP MAP")
        for table, info in ownership.items():
            print(f"\n{table} ({info['row_count']} rows):")
            if info['owns']:
                # Limit output for readability
                owns_list = info['owns'][:20]
                if len(info['owns']) > 20:
                    owns_list.append(f"... and {len(info['owns'])-20} more")
                print(f"  OWNS: {', '.join(owns_list)}")
            else:
                print("  OWNS: Nothing (likely junction table)")

            if info['references']:
                refs_list = info['references'][:5]
                if len(info['references']) > 5:
                    refs_list.append(f"... and {len(info['references'])-5} more")
                print(f"  REFERENCES: {', '.join(refs_list)}")

        # Patterns
        patterns = discover_patterns(conn, structure)

        print("\n## 3. FORMAT PATTERNS DISCOVERED")

        if patterns['dates']:
            print("\nDates:")
            for col, info in list(patterns['dates'].items())[:8]:
                examples = ', '.join(f"'{e}'" for e in info['examples'])
                print(f"  {col}: {info['format']} (examples: {examples})")

        if patterns['categories']:
            print("\nCategories:")
            # Show at most ten categorical columns
            for col, values in list(patterns['categories'].items())[:10]:
                if len(values) <= 8:
                    print(f"  {col}: {values}")
                else:
                    print(f"  {col}: {values[:5]} ... [{len(values)} total]")

        if patterns['numbers']:
            print("\nNumeric Patterns:")
            for col, info in patterns['numbers'].items():
                examples = ', '.join(str(e) for e in info['examples'])
                if info['type'] == 'cents':
                    print(f"  {col}: Stored as cents (examples: {examples})")
                elif info['type'] == 'decimal_percentage':
                    print(f"  {col}: Decimal percentage (examples: {examples})")

        # Evidence map
        evidence = create_evidence_map(structure)

        print("\n## 4. EVIDENCE RECONCILIATION MAP")
        priority_terms = ['name', 'date', 'population', 'income', 'average', 'total',
                          'count', 'price', 'status', 'type']

        for term in priority_terms:
            matches = evidence.get(term)
            if matches:
                cols = matches[:3]
                if len(matches) > 3:
                    cols.append(f"... {len(matches)-3} more")
                print(f"  '{term}' → {cols}")

        # Relationships
        print("\n## 5. CRITICAL RELATIONSHIPS")

        relationships = [
            (table, from_col, to_table, to_col)
            for table, info in structure.items()
            for from_col, to_table, to_col in info['foreign_keys']
        ]
        # Cap the listing at fifteen relationships
        for table, from_col, to_table, to_col in relationships[:15]:
            print(f"{table}.{from_col} → {to_table}.{to_col}")

        # Junction tables
        junctions = identify_junction_tables(structure)
        if junctions:
            print(f"\nJunction Tables: {', '.join(junctions)}")

        # Column selection reminders
        print("\n## 6. COLUMN SELECTION REMINDERS")
        print("- If asked for ONE thing, return ONE column")
        print("- Never add counts unless asked 'how many'")
        print("- Never add names unless asked for names")
        print("- Check ownership before EVERY join")
        print("- Evidence values must be used exactly")
    finally:
        # Release the database handle even when a step above fails
        conn.close()

    print("\n=== ANALYSIS COMPLETE ===")

# Run the analysis only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()