#!/usr/bin/env python3
"""
Enhanced Join Optimizer
Maps and optimizes database join paths with preference for simpler approaches.
Enhanced with station name joins and weather pattern handling.
"""

import sqlite3
import json
import os
from collections import defaultdict, deque

def optimize_joins(db_path):
    """Analyze a SQLite database and collect join guidance.

    Opens the database at ``db_path``, lists its tables from
    ``sqlite_master``, builds the join graph and runs every analysis pass
    over it.

    Args:
        db_path: Path passed to ``sqlite3.connect``.

    Returns:
        A dict with the keys ``direct_joins``, ``name_based_joins``,
        ``multi_hop_paths``, ``join_recommendations``, ``station_patterns``,
        ``weather_patterns`` and ``common_join_errors``, filled in by the
        individual analysis passes.
    """
    optimization = {
        'direct_joins': [],
        'name_based_joins': [],
        'multi_hop_paths': [],
        'join_recommendations': {},
        'station_patterns': {},
        'weather_patterns': {},
        'common_join_errors': []
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]

        # Build join graph
        join_graph = build_join_graph(cursor, tables)

        # Find optimal paths; each pass mutates ``optimization`` in place.
        find_direct_joins(join_graph, optimization)
        find_name_based_joins(cursor, tables, optimization)
        find_multi_hop_paths(join_graph, optimization)
        detect_station_patterns(cursor, tables, optimization)
        detect_weather_patterns(cursor, tables, optimization)
        generate_recommendations(optimization, tables)
        identify_common_errors(optimization)
    finally:
        # Close the connection even when one of the passes raises, so a
        # failed analysis does not leave the database handle open.
        conn.close()

    return optimization

def build_join_graph(cursor, tables):
    """Build a graph of join relationships.

    For every table the graph receives:

    * a ``foreign_key`` edge (and a ``foreign_key_reverse`` edge on the
      referenced table) for each row of ``PRAGMA foreign_key_list``;
    * a ``name_based`` edge for each column whose name contains ``name`` or
      ``id`` and that ``columns_match_for_join`` pairs with a column of
      another table.

    Args:
        cursor: An open sqlite3 cursor.
        tables: Table names to inspect.

    Returns:
        A ``defaultdict(list)`` mapping a table name to a list of edge dicts
        with the keys ``to_table``, ``join_condition`` and ``type``.
    """
    graph = defaultdict(list)
    column_cache = {}

    def quote(identifier):
        # Double-quote the identifier (doubling embedded quotes) so table
        # names containing spaces or punctuation do not break the PRAGMA.
        return '"' + identifier.replace('"', '""') + '"'

    def table_columns(name):
        # Fetch table_info once per table; the name-matching loop below
        # would otherwise re-run the same PRAGMA for every column pair.
        if name not in column_cache:
            cursor.execute(f"PRAGMA table_info({quote(name)})")
            column_cache[name] = cursor.fetchall()
        return column_cache[name]

    for table in tables:
        # Get foreign keys
        cursor.execute(f"PRAGMA foreign_key_list({quote(table)})")
        for fk in cursor.fetchall():
            seq, to_table, from_col, to_col = fk[1], fk[2], fk[3], fk[4]

            if to_col is None:
                # The FK was declared without naming a parent column, so it
                # targets the parent's primary key. Resolve the column via
                # the pk position reported by table_info (index 5, 1-based).
                pk_rows = sorted(
                    (row for row in table_columns(to_table) if row[5]),
                    key=lambda row: row[5],
                )
                if seq >= len(pk_rows):
                    # No usable parent column; skip rather than emit a
                    # "<table>.None" condition.
                    continue
                to_col = pk_rows[seq][1]

            graph[table].append({
                'to_table': to_table,
                'join_condition': f"{table}.{from_col} = {to_table}.{to_col}",
                'type': 'foreign_key'
            })
            graph[to_table].append({
                'to_table': table,
                'join_condition': f"{to_table}.{to_col} = {table}.{from_col}",
                'type': 'foreign_key_reverse'
            })

        # Look for potential name-based joins
        col_names = [col[1] for col in table_columns(table)]
        for col in col_names:
            lowered = col.lower()
            if 'name' not in lowered and 'id' not in lowered:
                continue

            # Check other tables for matching columns
            for other_table in tables:
                if other_table == table:
                    continue

                for other_row in table_columns(other_table):
                    other_col = other_row[1]
                    if columns_match_for_join(col, other_col, table, other_table):
                        graph[table].append({
                            'to_table': other_table,
                            'join_condition': f"{table}.{col} = {other_table}.{other_col}",
                            'type': 'name_based'
                        })

    return graph

def columns_match_for_join(col1, col2, table1, table2):
    """Determine if two columns likely form a join relationship.

    All comparisons are case-insensitive. The columns match when any of
    these holds:

    * the names are equal;
    * one column is named ``<other table>_id``;
    * both contain ``name`` and are equal once their own table name is
      removed and surrounding underscores are stripped;
    * both contain ``station`` and both contain ``id``, or both contain
      ``name``.

    Args:
        col1: Column name from ``table1``.
        col2: Column name from ``table2``.
        table1: Name of the table owning ``col1``.
        table2: Name of the table owning ``col2``.

    Returns:
        True when one of the rules above applies, otherwise False.
    """
    col1_lower = col1.lower()
    col2_lower = col2.lower()
    table1_lower = table1.lower()
    table2_lower = table2.lower()

    # Direct matches. Compared on the lowered names so this rule agrees
    # with the case-insensitive rules below.
    if col1_lower == col2_lower:
        return True

    # Table-prefixed matches
    if col1_lower == f"{table2_lower}_id" or col2_lower == f"{table1_lower}_id":
        return True

    # Name-based matches: compare what is left after dropping the owning
    # table's name from each column.
    if 'name' in col1_lower and 'name' in col2_lower:
        stem1 = col1_lower.replace(table1_lower, '').strip('_')
        stem2 = col2_lower.replace(table2_lower, '').strip('_')
        if stem1 == stem2:
            return True

    # Station-specific patterns
    if 'station' in col1_lower and 'station' in col2_lower:
        both_ids = 'id' in col1_lower and 'id' in col2_lower
        both_names = 'name' in col1_lower and 'name' in col2_lower
        if both_ids or both_names:
            return True

    return False

def find_direct_joins(join_graph, optimization):
    """Find all direct join paths.

    Appends one entry to ``optimization['direct_joins']`` for every edge of
    type ``foreign_key`` in ``join_graph``. Reverse and name-based edges are
    ignored.

    Args:
        join_graph: Mapping of table name to a list of edge dicts with the
            keys ``to_table``, ``join_condition`` and ``type``.
        optimization: Result dict; its ``direct_joins`` list is extended in
            place.
    """
    def aliased(condition, left_table, right_table):
        # The template aliases the tables as t1/t2, so its ON clause has to
        # reference those aliases instead of the raw table names. Conditions
        # that do not look like "<left>.<col> = <right>.<col>" are returned
        # untouched.
        left, sep, right = condition.partition(" = ")
        left_prefix = f"{left_table}."
        right_prefix = f"{right_table}."
        if sep and left.startswith(left_prefix) and right.startswith(right_prefix):
            return f"t1.{left[len(left_prefix):]} = t2.{right[len(right_prefix):]}"
        return condition

    for table, connections in join_graph.items():
        for edge in connections:
            if edge['type'] != 'foreign_key':
                continue

            target = edge['to_table']
            condition = edge['join_condition']
            optimization['direct_joins'].append({
                'from': table,
                'to': target,
                # 'condition' keeps the table-qualified form from the graph.
                'condition': condition,
                'complexity': 1,
                'template': f"FROM {table} t1 JOIN {target} t2 ON {aliased(condition, table, target)}"
            })

def find_name_based_joins(cursor, tables, optimization):
    """Enhanced name-based join detection.

    Considers every unordered pair of tables once and, for each pair of
    columns whose names both contain ``name``, asks ``columns_could_join``
    whether they are joinable. Each accepted pair is appended to
    ``optimization['name_based_joins']``.

    Args:
        cursor: An open sqlite3 cursor.
        tables: Table names to inspect.
        optimization: Result dict; its ``name_based_joins`` list is extended
            in place.
    """
    def name_columns(table):
        # Quote the identifier so table names with spaces or punctuation do
        # not break the PRAGMA statement.
        quoted = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted})")
        return [row[1] for row in cursor.fetchall() if 'name' in row[1].lower()]

    # Read each table's schema once instead of once per table pair.
    name_cols = {table: name_columns(table) for table in tables}

    for t1_idx, t1 in enumerate(tables):
        for t2 in tables[t1_idx + 1:]:
            # Look for name-based join opportunities
            for c1_name in name_cols[t1]:
                for c2_name in name_cols[t2]:
                    # Check if these might be joinable
                    if columns_could_join(c1_name, c2_name, t1, t2):
                        optimization['name_based_joins'].append({
                            'table1': t1,
                            'column1': c1_name,
                            'table2': t2,
                            'column2': c2_name,
                            'template': f"FROM {t1} t1 JOIN {t2} t2 ON t1.{c1_name} = t2.{c2_name}",
                            # Both columns were already filtered to contain
                            # 'name', so this flag is always True.
                            'preferred': True
                        })

def columns_could_join(col1, col2, table1, table2):
    """Decide whether two column names look like a plausible join pair.

    The check is case-insensitive and succeeds when either:

    * both names contain ``station`` and both contain ``name``, or both
      contain ``id``; or
    * one name ends with ``_name`` while the other is exactly ``name``.

    ``table1`` and ``table2`` are accepted but not consulted.
    """
    first, second = col1.lower(), col2.lower()
    pair = (first, second)

    # Station-specific logic: a shared 'name' or a shared 'id' marker.
    if all('station' in column for column in pair):
        if any(all(marker in column for column in pair) for marker in ('name', 'id')):
            return True

    # General name matching: "<something>_name" against a bare "name".
    suffixed_vs_bare = first.endswith('_name') and second == 'name'
    bare_vs_suffixed = second.endswith('_name') and first == 'name'
    return suffixed_vs_bare or bare_vs_suffixed

def find_multi_hop_paths(join_graph, optimization):
    """Record join paths that need more than one hop.

    Every ordered pair of distinct tables in ``join_graph`` is probed with
    ``find_shortest_path``. Paths visiting more than two tables are appended
    to ``optimization['multi_hop_paths']`` together with a SQL template
    produced by ``generate_join_template``.
    """
    # Snapshot the keys so the set of tables is fixed for both loops.
    nodes = list(join_graph.keys())

    for source in nodes:
        for target in nodes:
            if source != target:
                # Shortest route between the two tables (BFS).
                route = find_shortest_path(join_graph, source, target)

                # Direct (two-table) routes and unreachable pairs are skipped.
                if not route or len(route) <= 2:
                    continue

                entry = {
                    'from': source,
                    'to': target,
                    'path': route,
                    'complexity': len(route) - 1,
                    'template': generate_join_template(route, join_graph)
                }
                optimization['multi_hop_paths'].append(entry)

def find_shortest_path(graph, start, end):
    """BFS to find shortest join path.

    Args:
        graph: Mapping of table name to a list of edge dicts; only the
            ``to_table`` key of each edge is read.
        start: Table to start from.
        end: Table to reach.

    Returns:
        The list of table names from ``start`` to ``end`` (inclusive), or
        None when either endpoint is not a key of ``graph`` or no path
        exists. ``start == end`` yields ``[start]``.
    """
    if start not in graph or end not in graph:
        return None

    if start == end:
        return [start]

    queue = deque([(start, [start])])
    # Tables are marked as seen when first enqueued, so each one enters the
    # queue at most once.
    seen = {start}

    while queue:
        current, path = queue.popleft()

        # .get() keeps an edge that points at a table without its own entry
        # from raising KeyError (plain dict) or silently adding a key
        # (defaultdict).
        for connection in graph.get(current, ()):
            next_table = connection['to_table']
            if next_table in seen:
                continue

            next_path = path + [next_table]
            if next_table == end:
                return next_path

            seen.add(next_table)
            queue.append((next_table, next_path))

    return None

def generate_join_template(path, graph):
    """Generate SQL join template for a path.

    The first table is aliased ``t1`` and the table at position ``i`` of the
    path (0-based) is aliased ``t{i+1}``. For each hop, the first edge in
    ``graph`` leading from the previous table to the current one supplies
    the join condition.

    Args:
        path: Ordered list of table names.
        graph: Mapping of table name to edge dicts with the keys
            ``to_table`` and ``join_condition``.

    Returns:
        The template as a string with one clause per line, or ``""`` when
        the path has fewer than two tables. Hops without a usable condition
        are left out.
    """
    if len(path) < 2:
        return ""

    def aliased(condition, left_table, right_table, left_alias, right_alias):
        # The tables are aliased in the template, so the ON clause has to
        # use those aliases rather than the raw table names. Conditions
        # that are not of the form "<left>.<col> = <right>.<col>" are kept
        # as they are.
        left, sep, right = condition.partition(" = ")
        left_prefix = f"{left_table}."
        right_prefix = f"{right_table}."
        if sep and left.startswith(left_prefix) and right.startswith(right_prefix):
            return (
                f"{left_alias}.{left[len(left_prefix):]}"
                f" = {right_alias}.{right[len(right_prefix):]}"
            )
        return condition

    template_parts = [f"FROM {path[0]} t1"]

    for position, (prev_table, curr_table) in enumerate(zip(path, path[1:]), start=1):
        # Find the join condition: first edge from prev_table to curr_table.
        condition = next(
            (
                edge['join_condition']
                for edge in graph.get(prev_table, ())
                if edge['to_table'] == curr_table
            ),
            None,
        )

        if condition:
            on_clause = aliased(
                condition, prev_table, curr_table, f"t{position}", f"t{position + 1}"
            )
            template_parts.append(f"JOIN {curr_table} t{position + 1} ON {on_clause}")

    return "\n  ".join(template_parts)

def detect_station_patterns(cursor, tables, optimization):
    """Detect station-specific join patterns.

    Tables whose name contains ``station`` are paired with tables whose name
    contains ``trip`` or ``ride``. For each pair:

    * ``station_patterns['name_join']`` is set when the station table has a
      column containing ``name`` and the trip table has a column containing
      both ``station`` and ``name``;
    * ``station_patterns['id_join']`` is set when the station table has a
      column named ``id`` or containing ``station_id`` and the trip table
      has a column containing both ``station`` and ``id``.

    Templates are built from the columns actually found. When several
    columns qualify, ``start_station_name`` / ``start_station_id`` on the
    trip side and ``name`` / ``id`` on the station side are preferred;
    otherwise the first qualifying column is used. Later table pairs
    overwrite entries written by earlier ones.

    Args:
        cursor: An open sqlite3 cursor.
        tables: Table names to inspect.
        optimization: Result dict; its ``station_patterns`` dict is updated
            in place.
    """
    def column_names(table):
        # Quote the identifier so table names with spaces or punctuation do
        # not break the PRAGMA statement.
        quoted = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted})")
        return [row[1] for row in cursor.fetchall()]

    def pick(candidates, preferred):
        # Prefer the conventional column name when present, else fall back
        # to the first detected candidate.
        for candidate in candidates:
            if candidate.lower() == preferred:
                return candidate
        return candidates[0]

    station_tables = [t for t in tables if 'station' in t.lower()]
    trip_tables = [t for t in tables if 'trip' in t.lower() or 'ride' in t.lower()]

    if not station_tables or not trip_tables:
        return

    # Trip-table schemas do not depend on the station table; read them once.
    trip_columns = {trip_table: column_names(trip_table) for trip_table in trip_tables}

    for station_table in station_tables:
        station_col_names = column_names(station_table)
        station_name_cols = [c for c in station_col_names if 'name' in c.lower()]
        station_id_cols = [
            c for c in station_col_names
            if c.lower() == 'id' or 'station_id' in c.lower()
        ]

        for trip_table in trip_tables:
            trip_col_names = trip_columns[trip_table]

            # Check for station name joins
            trip_station_name_cols = [
                c for c in trip_col_names
                if 'station' in c.lower() and 'name' in c.lower()
            ]

            if station_name_cols and trip_station_name_cols:
                trip_col = pick(trip_station_name_cols, 'start_station_name')
                station_col = pick(station_name_cols, 'name')
                optimization['station_patterns']['name_join'] = {
                    'pattern': 'Station name-based join',
                    'template': f"FROM {trip_table} t JOIN {station_table} s ON t.{trip_col} = s.{station_col}",
                    'recommendation': 'PREFER name-based joins for stations over ID-based joins',
                    'reason': 'More reliable and avoids ID mismatch issues'
                }

            # Check for station ID joins
            trip_station_id_cols = [
                c for c in trip_col_names
                if 'station' in c.lower() and 'id' in c.lower()
            ]

            if station_id_cols and trip_station_id_cols:
                trip_col = pick(trip_station_id_cols, 'start_station_id')
                station_col = pick(station_id_cols, 'id')
                optimization['station_patterns']['id_join'] = {
                    'pattern': 'Station ID-based join',
                    'template': f"FROM {trip_table} t JOIN {station_table} s ON t.{trip_col} = s.{station_col}",
                    'recommendation': 'Use ONLY if name-based join not available',
                    'reason': 'ID joins can have referential integrity issues'
                }

def detect_weather_patterns(cursor, tables, optimization):
    """Detect weather-specific join patterns.

    Tables whose name contains ``weather`` are paired with tables whose name
    contains ``trip`` or ``ride``. When both tables of a pair have a column
    containing ``date`` and a column containing ``zip``,
    ``weather_patterns['date_zip_join']`` is set.

    The template is built from the columns actually found. When several
    columns qualify, ``start_date`` / ``zip_code`` on the trip side and
    ``date`` / ``zip_code`` on the weather side are preferred; otherwise the
    first qualifying column is used. Later table pairs overwrite the entry
    written by earlier ones.

    Args:
        cursor: An open sqlite3 cursor.
        tables: Table names to inspect.
        optimization: Result dict; its ``weather_patterns`` dict is updated
            in place.
    """
    def column_names(table):
        # Quote the identifier so table names with spaces or punctuation do
        # not break the PRAGMA statement.
        quoted = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted})")
        return [row[1] for row in cursor.fetchall()]

    def containing(columns, fragment):
        return [c for c in columns if fragment in c.lower()]

    def pick(candidates, preferred):
        # Prefer the conventional column name when present, else fall back
        # to the first detected candidate.
        for candidate in candidates:
            if candidate.lower() == preferred:
                return candidate
        return candidates[0]

    weather_tables = [t for t in tables if 'weather' in t.lower()]
    trip_tables = [t for t in tables if 'trip' in t.lower() or 'ride' in t.lower()]

    if not weather_tables or not trip_tables:
        return

    # Trip-table schemas do not depend on the weather table; read them once.
    trip_columns = {trip_table: column_names(trip_table) for trip_table in trip_tables}

    for weather_table in weather_tables:
        weather_col_names = column_names(weather_table)
        weather_date_cols = containing(weather_col_names, 'date')
        weather_zip_cols = containing(weather_col_names, 'zip')

        for trip_table in trip_tables:
            trip_col_names = trip_columns[trip_table]

            # Check for date and zip code joins
            trip_date_cols = containing(trip_col_names, 'date')
            trip_zip_cols = containing(trip_col_names, 'zip')

            if not (weather_date_cols and weather_zip_cols and trip_date_cols and trip_zip_cols):
                continue

            trip_date = pick(trip_date_cols, 'start_date')
            weather_date = pick(weather_date_cols, 'date')
            trip_zip = pick(trip_zip_cols, 'zip_code')
            weather_zip = pick(weather_zip_cols, 'zip_code')

            optimization['weather_patterns']['date_zip_join'] = {
                'pattern': 'Weather-trip join on date and location',
                'template': (
                    f"FROM {trip_table} t\n"
                    f"  JOIN {weather_table} w ON DATE(t.{trip_date}) = w.{weather_date}\n"
                    f"    AND t.{trip_zip} = w.{weather_zip}"
                ),
                'recommendation': 'Always join on BOTH date AND location',
                'date_conversion': 'Use DATE() function to extract date from datetime',
                'zip_handling': 'Cast to TEXT if types differ'
            }

def generate_recommendations(optimization, tables):
    """Derive join recommendations from the collected findings.

    Each candidate recommendation is paired with the finding that triggers
    it; those whose trigger is truthy are stored, in a fixed order, as a
    list under ``optimization['join_recommendations']``. ``tables`` is
    accepted but not consulted.
    """
    station_patterns = optimization['station_patterns']

    candidates = [
        # Station joins: only when a name-based station pattern was found.
        (
            bool(station_patterns) and 'name_join' in station_patterns,
            {
                'type': 'station_joins',
                'priority': 'HIGH',
                'recommendation': 'Always prefer station name joins over ID joins',
                'reason': 'Name joins are more reliable and avoid referential integrity issues'
            },
        ),
        # Weather joins: any weather pattern triggers the advice.
        (
            optimization['weather_patterns'],
            {
                'type': 'weather_joins',
                'priority': 'HIGH',
                'recommendation': 'Join weather data on both date AND location',
                'reason': 'Weather varies by location and time'
            },
        ),
        # Preference for direct foreign-key joins.
        (
            optimization['direct_joins'],
            {
                'type': 'general',
                'priority': 'MEDIUM',
                'recommendation': 'Use direct foreign key joins when available',
                'reason': 'Simpler queries are more maintainable and performant'
            },
        ),
        # Caution about name-based joins.
        (
            optimization['name_based_joins'],
            {
                'type': 'general',
                'priority': 'LOW',
                'recommendation': 'Validate name-based joins with sample data',
                'reason': 'Name matching can be ambiguous'
            },
        ),
    ]

    optimization['join_recommendations'] = [
        entry for applies, entry in candidates if applies
    ]

def identify_common_errors(optimization):
    """List join mistakes worth warning about.

    Station and weather warnings are included only when the matching
    patterns were detected; two general warnings are always appended. The
    resulting list replaces ``optimization['common_join_errors']``.
    """
    station_error = {
        'error': 'Using station ID when name is available',
        'impact': 'Can cause mismatches or missing results',
        'fix': 'Switch to name-based joins'
    }
    weather_error = {
        'error': 'Joining weather only on date without location',
        'impact': 'Returns incorrect weather for different locations',
        'fix': 'Add zip_code or city to join condition'
    }
    # Reported regardless of what was detected.
    general_errors = [
        {
            'error': 'Over-complicating joins with unnecessary tables',
            'impact': 'Performance degradation and complexity',
            'fix': 'Use shortest join path available'
        },
        {
            'error': 'Not handling NULL values in join columns',
            'impact': 'Missing results from LEFT/RIGHT joins',
            'fix': 'Use COALESCE or IS NOT NULL checks'
        },
    ]

    found = []
    if optimization['station_patterns']:
        found.append(station_error)
    if optimization['weather_patterns']:
        found.append(weather_error)
    found.extend(general_errors)

    optimization['common_join_errors'] = found

def main():
    """Run the optimizer against ./database.sqlite and report the outcome.

    Prints an error and returns when the database file is missing.
    Otherwise writes the results as JSON to
    tool_output/optimized_joins.json and prints a short summary followed by
    up to three recommendations.
    """
    db_path = "./database.sqlite"
    database_present = os.path.exists(db_path)
    if not database_present:
        print(f"Error: Database not found at {db_path}")
        return

    # Make sure the output directory exists before running the analysis.
    os.makedirs("tool_output", exist_ok=True)

    optimization = optimize_joins(db_path)

    # Persist the full result set.
    output_path = "tool_output/optimized_joins.json"
    with open(output_path, 'w') as out_file:
        json.dump(optimization, out_file, indent=2)

    direct_count = len(optimization['direct_joins'])
    name_based_count = len(optimization['name_based_joins'])
    multi_hop_count = len(optimization['multi_hop_paths'])

    print("✅ Join optimization complete")
    print(f"🔗 Direct joins found: {direct_count}")
    print(f"📝 Name-based joins: {name_based_count}")
    print(f"🌉 Multi-hop paths: {multi_hop_count}")
    print(f"💾 Results saved to: {output_path}")

    # Show at most the first three recommendations.
    recommendations = optimization['join_recommendations']
    if recommendations:
        print("\n🎯 Key Recommendations:")
        for rec in recommendations[:3]:
            print(f"  • {rec['recommendation']}")


# Run the optimizer only when executed as a script, not on import.
if __name__ == "__main__":
    main()