"""
Utility functions for FRAME experiments.

This module contains reusable utilities for running mathematical discovery experiments,
such as reporting discovered entities, checking for duplicates, and generating visualizations.
"""

import logging
import os
from pathlib import Path
from typing import Optional, List, Dict, Any, TYPE_CHECKING
from datetime import datetime

# Import the logger from logging.py instead of creating a duplicate
from frame.utils.logging import logger

# Module-level registry of thread pools kept so they can be cleaned up.
# NOTE(review): none of the functions reviewed in this module add to or
# drain this list — confirm whether other modules still rely on it.
thread_pools = []
# Re-entrancy guard for cleanup_resources(): it is set to True while a
# cleanup call is running and reset to False when that call finishes.
_shutdown_in_progress = False
# Flag meant to record that visualization-thread messages were already logged.
# NOTE(review): not read or written by the functions reviewed in this
# module — confirm it is still used elsewhere before relying on it.
_logged_viz_threads = False

def cleanup_resources():
    """
    Release any resources held by the experiment utilities.

    Meant to be callable from exit hooks and error handlers. The
    module-level ``_shutdown_in_progress`` flag acts as a guard: a call that
    arrives while the flag is already set returns without doing anything.
    """
    global _shutdown_in_progress

    if not _shutdown_in_progress:
        # Mark cleanup as running so a nested call becomes a no-op
        _shutdown_in_progress = True

        # No cleanup work is performed at the moment; any future release
        # logic belongs here, between setting and clearing the flag.

        # Cleanup finished: clear the guard so later calls run again
        _shutdown_in_progress = False

def _emit(message, log):
    """Print *message* to stdout, then pass it to the logging callable *log*."""
    print(message)
    log(message)


def _interestingness_debug_enabled():
    """Return True when the INTERESTINGNESS_DEBUG environment variable is "1"."""
    return os.environ.get("INTERESTINGNESS_DEBUG", "0") == "1"


def _is_numeric_node_id(node_id):
    """Return True if *node_id* is an int/float or a string made only of digits."""
    return isinstance(node_id, (int, float)) or (
        isinstance(node_id, str) and node_id.isdigit()
    )


def _first_number_in_id(node_id):
    """Return the first run of digits in *node_id* as an int (0 if there is none)."""
    import re

    if isinstance(node_id, (int, float)):
        return node_id
    match = re.search(r'(\d+)', str(node_id))
    return int(match.group(1)) if match else 0


def _sort_ids_by_creation_order(graph, concepts, conjectures):
    """
    Sort both ID lists in place using the best ordering hint available.

    Strategy, in order of preference:
      1. numeric sort, when every ID being sorted is numeric;
      2. ``creation_time`` / ``timestamp`` node attributes, when the first
         graph node carries one of them;
      3. the first run of digits embedded in each ID.
    Any error while sorting is logged and both lists are sorted by entity
    name instead.
    """
    try:
        # Inspect every ID that will actually be sorted. Sampling only a few
        # graph nodes lets a later non-numeric ID break the numeric sort and
        # silently degrade the output to alphabetical order.
        if all(_is_numeric_node_id(n) for n in concepts) and all(
            _is_numeric_node_id(n) for n in conjectures
        ):
            def sort_key(node_id):
                return int(node_id) if isinstance(node_id, str) else node_id
        else:
            # Only the first graph node is probed for timestamp attributes
            has_timestamps = any(
                'creation_time' in graph.nodes[n] or 'timestamp' in graph.nodes[n]
                for n in list(graph.nodes)[:1]
            )
            if has_timestamps:
                def sort_key(node_id):
                    attrs = graph.nodes[node_id]
                    return attrs.get('creation_time', attrs.get('timestamp', 0))
            else:
                sort_key = _first_number_in_id

        concepts.sort(key=sort_key)
        conjectures.sort(key=sort_key)
    except Exception as e:
        # Fallback to alphabetical if there's an error
        logging.error(f"Error sorting by node ID or creation order: {e}")

        def name_key(node_id):
            attrs = graph.nodes[node_id]
            return attrs['entity'].name if 'entity' in attrs else ""

        concepts.sort(key=name_key)
        conjectures.sort(key=name_key)


def _scorer_accepts_graph(scorer):
    """Return True when *scorer* declares at least two parameters."""
    import inspect

    return len(inspect.signature(scorer).parameters) >= 2


def _report_concept_score(scorer, concept_id, concept, graph):
    """
    Print and log the interestingness score of one concept.

    Scoring failures are deliberately quiet: they are only reported when
    INTERESTINGNESS_DEBUG is "1".
    """
    try:
        accepts_graph = _scorer_accepts_graph(scorer)

        if _interestingness_debug_enabled():
            _emit(
                f"   DEBUG: Scoring concept ID: {concept_id} (type: {type(concept_id)})",
                logging.debug,
            )
            _emit(f"   DEBUG: Concept name: {concept.name}", logging.debug)

        try:
            score = scorer(concept_id, graph) if accepts_graph else scorer(concept_id)
            _emit(f"   Interestingness score: {score:.4f}", logging.info)
        except Exception as e:
            # Skip printing interestingness score when it can't be calculated
            if _interestingness_debug_enabled():
                _emit(f"   DEBUG: Skipping interestingness scoring: {e}", logging.debug)
    except Exception as e:
        if _interestingness_debug_enabled():
            _emit(f"   DEBUG: Error in interestingness scoring: {e}", logging.debug)


def _report_examples(concept):
    """Print and log up to five examples and five nonexamples of *concept*."""
    # Fetch both lists before printing anything so a failure in either
    # getter produces no partial output.
    examples = list(concept.examples.get_examples())[:5]
    nonexamples = list(concept.examples.get_nonexamples())[:5]

    if examples:
        _emit(f"   Examples: {[ex.value for ex in examples]}", logging.info)
    else:
        _emit("   No examples available", logging.info)

    if nonexamples:
        _emit(f"   Nonexamples: {[ex.value for ex in nonexamples]}", logging.info)
    else:
        _emit("   No nonexamples available", logging.info)


def print_discovered_entities(graph, interestingness_scorer=None):  # type: (Any, Optional[Any]) -> None
    """
    Print and log all discovered concepts and conjectures.

    Entities are listed in an order that approximates creation order:
    numerically by ID when every listed ID is numeric, otherwise by
    creation timestamp attributes, otherwise by the first number embedded in
    each ID. If sorting fails, entities are listed alphabetically by name.

    Args:
        graph: The knowledge graph containing the entities. It must expose
            ``nodes`` (iterable of IDs, indexable by ID to an attribute
            mapping) and ``get_node(node_id)`` returning a 3-tuple whose
            first item is the entity and whose second item has a ``value``
            of "concept" or "conjecture" for the entities listed here.
        interestingness_scorer: Optional callable used to score entities. It
            is called as ``scorer(entity_id, graph)`` when it declares two or
            more parameters, otherwise as ``scorer(entity_id)``.
    """
    # Split node IDs into concepts and conjectures; other node types are ignored
    concepts = []
    conjectures = []
    for node_id in graph.nodes:
        try:
            _, node_type, _ = graph.get_node(node_id)
            if node_type.value == "concept":
                concepts.append(node_id)
            elif node_type.value == "conjecture":
                conjectures.append(node_id)
        except Exception as e:
            logging.error(f"Error getting node {node_id}: {e}")

    _sort_ids_by_creation_order(graph, concepts, conjectures)

    # Log the number of discovered entities
    logging.info(f"Displaying discovered entities: {len(concepts)} concepts, {len(conjectures)} conjectures")

    # Print and log header for concepts
    _emit("\n" + "="*80 + f"\nDISCOVERED CONCEPTS ({len(concepts)}):\n" + "="*80, logging.info)

    for position, concept_id in enumerate(concepts, start=1):
        try:
            concept, _, _ = graph.get_node(concept_id)
            _emit(
                f"{position}. {concept.name} [ID: {concept_id}]: {concept.description}",
                logging.info,
            )

            if interestingness_scorer:
                _report_concept_score(interestingness_scorer, concept_id, concept, graph)

            if hasattr(concept, "examples"):
                _report_examples(concept)
        except Exception as e:
            _emit(f"{position}. Error retrieving concept information: {e}", logging.error)

        print()  # Add a blank line between concepts

    # Print and log header for conjectures
    _emit("\n" + "="*80 + f"\nDISCOVERED CONJECTURES ({len(conjectures)}):\n" + "="*80, logging.info)

    for position, conjecture_id in enumerate(conjectures, start=1):
        try:
            conjecture, _, _ = graph.get_node(conjecture_id)
            _emit(
                f"{position}. {conjecture.name} [ID: {conjecture_id}]: {conjecture.description}",
                logging.info,
            )

            if interestingness_scorer:
                # Unlike concepts, conjecture scoring failures are always reported
                try:
                    if _scorer_accepts_graph(interestingness_scorer):
                        score = interestingness_scorer(conjecture_id, graph)
                    else:
                        score = interestingness_scorer(conjecture_id)
                    _emit(f"   Interestingness score: {score:.4f}", logging.info)
                except Exception as e:
                    _emit(f"   Error calculating interestingness: {e}", logging.error)
        except Exception as e:
            _emit(f"{position}. Error retrieving conjecture information: {e}", logging.error)

        print()  # Add a blank line between conjectures

def _hashable_value_key(items):
    """
    Build an order-independent, hashable key from the ``value`` of each item.

    Dict values are converted with ``str`` first. When the values cannot be
    ordered against each other or are unhashable (for example a mix of ints
    and strings, or list values), the key is built from their ``repr``
    instead and tagged so it cannot collide with a directly sorted key.
    """
    values = [
        str(item.value) if isinstance(item.value, dict) else item.value
        for item in items
    ]
    try:
        key = tuple(sorted(values))
        hash(key)  # raises TypeError if any value is unhashable
        return key
    except TypeError:
        return ("repr", tuple(sorted(repr(value) for value in values)))


def check_for_duplicates(graph):  # type: (Any) -> None
    """
    Check for duplicate concepts in the knowledge graph.

    Two concepts count as duplicates when they have the same set of example
    values and the same set of nonexample values. Concepts without an
    ``examples`` attribute are skipped. The groups of duplicates and a
    summary are printed and logged.

    Args:
        graph: The knowledge graph to check for duplicates. It must expose
            ``nodes`` (iterable of IDs, indexable by ID to a mapping with an
            'entity' key) and ``get_node(node_id)``.
    """
    # Log and print header
    header = "\n" + "="*80 + "\nCHECKING FOR DUPLICATE CONCEPTS:\n" + "="*80
    print(header)
    logging.info(header)

    # Get all concepts; nodes that cannot be read are logged and skipped
    concepts = []
    for node_id in graph.nodes:
        try:
            _, node_type, _ = graph.get_node(node_id)
            if node_type.value == "concept":
                concepts.append(node_id)
        except Exception as e:
            # Catch Exception rather than everything, so KeyboardInterrupt
            # and SystemExit still propagate.
            logging.error(f"Error getting node {node_id}: {e}")
            continue

    # Group concepts by their (examples, nonexamples) key
    concept_examples = {}
    duplicate_count = 0

    for concept_id in concepts:
        concept = graph.nodes[concept_id]['entity']
        if not hasattr(concept, 'examples'):
            continue

        example_key = (
            _hashable_value_key(concept.examples.get_examples()),
            _hashable_value_key(concept.examples.get_nonexamples()),
        )

        group = concept_examples.setdefault(example_key, [])
        if group:
            # Every concept after the first in a group is one duplicate
            duplicate_count += 1
        group.append(concept_id)

    # Print and log the results
    for concept_ids in concept_examples.values():
        if len(concept_ids) > 1:
            duplicate_msg = f"Found {len(concept_ids)} concepts with identical examples:"
            print(duplicate_msg)
            logging.info(duplicate_msg)
            for concept_id in concept_ids:
                concept = graph.nodes[concept_id]['entity']
                concept_detail_msg = f"  - {concept.name}: {concept.description}"
                print(concept_detail_msg)
                logging.info(concept_detail_msg)
            print()  # Keep blank line for console readability

    # Print and log summary
    summary_lines = [
        f"Total unique example sets: {len(concept_examples)}",
        f"Total concepts: {len(concepts)}",
        f"Total duplicates found: {duplicate_count}",
        "="*80,
    ]
    for line in summary_lines:
        print(line)
    for line in summary_lines:
        logging.info(line)

def generate_visualizations(
    graph,  # type: Any
    viz_dir,  # type: str
    timestamp,  # type: str
    logger=None,  # type: Optional[logging.Logger]
) -> None:
    """
    Render the knowledge graph's construction tree into ``viz_dir``.

    The output path is ``<viz_dir>/construction_tree_<timestamp>`` and is
    handed to ``graph.visualize_construction_tree``. Any exception raised
    while building the path or rendering is reported as a warning instead of
    being propagated.

    Args:
        graph: The knowledge graph to visualize
        viz_dir: Output directory for visualization files
        timestamp: Timestamp string to use in filenames
        logger: Optional logger for messages (will use print if None)
    """
    # Pick the message sinks once: logger methods when given, print otherwise
    if logger:
        report = logger.info
        report_problem = logger.warning
    else:
        report = print
        report_problem = print

    try:
        tree_path = os.path.join(viz_dir, f"construction_tree_{timestamp}")
        report(f"Generating construction tree visualization: {tree_path}")

        # Rendering is delegated entirely to the graph object
        graph.visualize_construction_tree(output_file=tree_path)

        report("Construction tree visualization complete")
        report(f"Construction tree visualization saved to {tree_path}.png")
    except Exception as failure:
        report_problem(f"Error generating visualizations: {failure}")