#!/usr/bin/env python3
"""
Relationship Explorer - Maps all join paths and identifies junction tables
Critical for understanding how to connect entities in queries
"""

import sqlite3
import os
from collections import defaultdict

def explore_relationships(db_path='./database.sqlite'):
    """Explore and document all table relationships and join paths.

    Prints a report to stdout covering: foreign keys (those declared in the
    schema, or inferred from column names when none are declared), junction
    tables, multi-step join paths between entity tables, a rough cardinality
    label per relationship, and columns that match common evidence terms.
    A summary is also written to ./tool_output/relationships.txt (UTF-8).

    Args:
        db_path: Path to the SQLite database file to inspect.

    Returns:
        None. When db_path does not exist a message is printed and nothing
        else is done (no output file is written).
    """
    if not os.path.exists(db_path):
        print(f"Database not found at {db_path}")
        return

    conn = sqlite3.connect(db_path)
    # try/finally so the connection is released even when a query or the
    # report write fails part-way through.
    try:
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall()]

        print("=" * 60)
        print("RELATIONSHIP EXPLORATION")
        print("=" * 60)

        # Collect all foreign keys
        all_fks = {}
        junction_tables = []

        for table in tables:
            cursor.execute(f"PRAGMA foreign_key_list([{table}])")
            fks = cursor.fetchall()
            if fks:
                # fk format: (id, seq, table, from, to, on_update, on_delete, match)
                all_fks[table] = [
                    {'from': fk[3], 'to_table': fk[2], 'to_col': fk[4]}
                    for fk in fks
                ]

                # Identify junction tables (2+ foreign keys)
                if len(fks) >= 2:
                    junction_tables.append(table)

        # If no explicit foreign keys, try to infer from column names
        if not all_fks:
            print("\nNo explicit foreign keys found. Inferring from column names...")
            all_fks = infer_relationships(cursor, tables)

        # Display foreign key relationships
        print("\nDIRECT FOREIGN KEY RELATIONSHIPS:")
        print("-" * 40)

        if all_fks:
            for table, fks in all_fks.items():
                for fk in fks:
                    print(f"{table}.{fk['from']} → {fk['to_table']}.{fk['to_col']}")
        else:
            print("No relationships found")

        # Identify junction tables
        if junction_tables:
            print("\nJUNCTION/BRIDGE TABLES:")
            print("-" * 40)
            for table in junction_tables:
                print(f"\n{table}:")
                # dict.fromkeys removes duplicates but keeps first-seen order,
                # so the line is identical from run to run (a set is not).
                linked_tables = dict.fromkeys(fk['to_table'] for fk in all_fks[table])
                print(f"  Links: {' ↔ '.join(linked_tables)}")

                # Get additional columns (non-FK columns)
                cursor.execute(f"PRAGMA table_info([{table}])")
                all_cols = [row[1] for row in cursor.fetchall()]
                fk_cols = {fk['from'] for fk in all_fks[table]}
                extra_cols = [
                    col for col in all_cols
                    if col not in fk_cols and not col.endswith('_id')
                ]
                if extra_cols:
                    print(f"  Additional data: {', '.join(extra_cols)}")

        # Build relationship graph for path finding
        graph = defaultdict(set)
        for table, fks in all_fks.items():
            for fk in fks:
                graph[table].add(fk['to_table'])
                graph[fk['to_table']].add(table)  # Bidirectional

        # Find common multi-step paths
        print("\n\nCOMMON JOIN PATHS:")
        print("-" * 40)

        # Find entities that might need multi-step joins
        entity_tables = identify_entity_tables(tables)

        paths_found = []
        # Paths are remembered as tuples of table names so that the reverse
        # of a path can be recognised by reversing the *sequence of tables*;
        # reversing the joined string would scramble the table names.
        seen_paths = set()
        for start_table in entity_tables:
            for end_table in entity_tables:
                if start_table == end_table:
                    continue
                if start_table not in graph or end_table not in graph:
                    continue
                path = find_path(graph, start_table, end_table)
                if not path or len(path) <= 2:  # Only multi-step paths
                    continue
                path_key = tuple(path)
                if path_key in seen_paths or path_key[::-1] in seen_paths:
                    continue
                seen_paths.add(path_key)
                path_str = " → ".join(path)
                paths_found.append(path_str)
                print(f"\n{start_table} to {end_table}:")
                print(f"  Path: {path_str}")

        # Analyze relationship cardinality
        print("\n\nRELATIONSHIP CARDINALITY ANALYSIS:")
        print("-" * 40)

        for table, fks in all_fks.items():
            for fk in fks:
                # Check if it's likely one-to-many or many-to-many
                cursor.execute(f"SELECT COUNT(DISTINCT [{fk['from']}]) FROM [{table}]")
                distinct_from = cursor.fetchone()[0]
                cursor.execute(f"SELECT COUNT(*) FROM [{table}]")
                total_from = cursor.fetchone()[0]

                # Relationships with no non-NULL values are not reported.
                if distinct_from > 0:
                    if distinct_from == total_from:
                        cardinality = "1:1 or N:1"
                    else:
                        cardinality = "N:1 (many-to-one)"

                    print(f"{table}.{fk['from']} → {fk['to_table']}.{fk['to_col']}: {cardinality}")

        # Generate evidence reconciliation hints
        print("\n\nEVIDENCE RECONCILIATION HINTS:")
        print("-" * 40)

        # Look for columns that might be referenced differently
        name_variations = defaultdict(list)

        for table in tables:
            cursor.execute(f"PRAGMA table_info([{table}])")
            columns = [row[1] for row in cursor.fetchall()]

            for col in columns:
                col_lower = col.lower()
                # Common variations
                if 'company' in col_lower or 'supplier' in col_lower or 'customer' in col_lower:
                    name_variations['CompanyName'].append(f"{table}.{col}")
                if 'ship' in col_lower and ('via' in col_lower or 'method' in col_lower or 'shipper' in col_lower):
                    name_variations['ShipVia'].append(f"{table}.{col}")
                if 'territory' in col_lower or 'region' in col_lower or 'area' in col_lower:
                    name_variations['Territory'].append(f"{table}.{col}")
                if 'employee' in col_lower or 'staff' in col_lower or 'worker' in col_lower:
                    name_variations['Employee'].append(f"{table}.{col}")

        print("\nCommon evidence terms and their likely locations:")
        for term, locations in name_variations.items():
            if locations:
                print(f"\n'{term}' might refer to:")
                for loc in locations[:5]:  # Limit to 5 suggestions
                    print(f"  → {loc}")

        # Save output
        os.makedirs('./tool_output', exist_ok=True)
        # Explicit UTF-8: the report contains '→', which the platform default
        # encoding (e.g. cp1252) may be unable to encode.
        with open('./tool_output/relationships.txt', 'w', encoding='utf-8') as f:
            f.write("RELATIONSHIP MAP\n")
            f.write("=" * 40 + "\n\n")

            f.write("FOREIGN KEYS:\n")
            for table, fks in all_fks.items():
                for fk in fks:
                    f.write(f"  {table}.{fk['from']} → {fk['to_table']}.{fk['to_col']}\n")

            if junction_tables:
                f.write("\n\nJUNCTION TABLES:\n")
                for table in junction_tables:
                    f.write(f"  {table}\n")

            if paths_found:
                f.write("\n\nMULTI-STEP PATHS:\n")
                for path in paths_found:
                    f.write(f"  {path}\n")
    finally:
        conn.close()

    print("\n\nRelationship map saved to ./tool_output/relationships.txt")

def infer_relationships(cursor, tables):
    """Infer relationships from column naming patterns.

    A column whose name ends in ``_id`` (any letter case), ``Id`` or ``ID``
    is treated as a candidate foreign key. The part before the suffix is
    compared case-insensitively with each table name, allowing the simple
    plural forms ``s``, ``es`` and ``y`` -> ``ies``. The first matching table
    in *tables* order is used.

    A table's own primary-key column is never reported as a reference to the
    same table (``customers.customer_id`` is the key, not a foreign key).

    Args:
        cursor: Open sqlite3 cursor; used to run ``PRAGMA table_info``.
        tables: Names of the tables to inspect.

    Returns:
        Dict mapping a table name to a list of
        ``{'from': column, 'to_table': table, 'to_col': column}`` dicts.
        Tables with no inferred relationship are absent. ``to_col`` is the
        target's primary-key column, or ``'id'`` when it declares none.
    """
    # Read every table once, keeping both its columns and its primary key so
    # the matching loop below needs no further queries.
    columns_by_table = {}
    pk_by_table = {}
    for table in tables:
        cursor.execute(f"PRAGMA table_info([{table}])")
        rows = cursor.fetchall()
        columns_by_table[table] = [row[1] for row in rows]
        # row[5] is the "pk" field; 1 marks the (first) primary-key column.
        pk_by_table[table] = next((row[1] for row in rows if row[5] == 1), 'id')

    inferred = {}
    for table in tables:
        for col in columns_by_table[table]:
            # Strip the id suffix, including the underscore for _id/_ID/_Id.
            if col.lower().endswith('_id'):
                stem = col[:-3]
            elif col.endswith(('Id', 'ID')):
                stem = col[:-2]
            else:
                continue

            stem = stem.lower()
            if not stem:
                # A column named just "Id"/"ID" carries no table name.
                continue

            candidates = {stem, stem + 's', stem + 'es'}
            if stem.endswith('y'):
                candidates.add(stem[:-1] + 'ies')

            for target_table in tables:
                if target_table.lower() not in candidates:
                    continue
                if target_table == table and col == pk_by_table[table]:
                    # The table's own primary key, not a reference.
                    continue
                inferred.setdefault(table, []).append({
                    'from': col,
                    'to_table': target_table,
                    'to_col': pk_by_table[target_table],
                })
                break

    return inferred

def identify_entity_tables(tables):
    """Return the tables whose names do not look like junction or lookup tables.

    A name is dropped when, lower-cased, it contains any junction marker
    ('_to_', '_x_', 'junction', 'bridge', 'link', 'map') or any lookup marker
    ('lookup', 'ref_', 'dim_'). Input order is preserved.
    """
    junction_markers = ('_to_', '_x_', 'junction', 'bridge', 'link', 'map')
    lookup_markers = ('lookup', 'ref_', 'dim_')
    excluded_markers = junction_markers + lookup_markers

    def looks_like_entity(name):
        lowered = name.lower()
        for marker in excluded_markers:
            if marker in lowered:
                return False
        return True

    return [name for name in tables if looks_like_entity(name)]

def find_path(graph, start, end, path=None):
    """Find shortest path between two tables in relationship graph.

    Uses a breadth-first search, so each table is visited at most once and
    no recursion is involved.

    Args:
        graph: Mapping of table name -> iterable of neighbouring table names.
            It is only read (membership test and item lookup on existing
            keys), so a defaultdict is not grown by the search.
        start: Table to start from.
        end: Table to reach.
        path: Optional list of tables already walked. They are prepended to
            the result and are never stepped onto by the search.

    Returns:
        A list of table names running from the beginning of *path* (if any)
        through *start* to *end*, or None when *end* cannot be reached.
        ``[..., start]`` is returned when start == end.
    """
    prefix = list(path) if path else []
    if start == end:
        return prefix + [start]
    if start not in graph:
        return None

    blocked = set(prefix)
    # came_from doubles as the visited set; it maps each reached table to the
    # table it was reached from, which lets the route be rebuilt at the end.
    came_from = {start: None}
    frontier = [start]
    while frontier:
        next_frontier = []
        for node in frontier:
            neighbours = graph[node] if node in graph else ()
            for neighbour in neighbours:
                if neighbour in came_from or neighbour in blocked:
                    continue
                came_from[neighbour] = node
                if neighbour == end:
                    # Walk the back-pointers from end to start, then flip.
                    route = [neighbour]
                    while route[-1] != start:
                        route.append(came_from[route[-1]])
                    route.reverse()
                    return prefix + route
                next_frontier.append(neighbour)
        frontier = next_frontier
    return None

# When executed as a script, run the explorer with its default database path.
if __name__ == "__main__":
    explore_relationships()