from types import SimpleNamespace
from typing import Dict, Any

import numpy as np

from orchestrator_maze_implementation.utils.free_energy_calculator import FreeEnergyCalculator
from orchestrator_maze_implementation.state.maze_state import MazeState

def extract_text_content(message_content):
    """Flatten a message ``content`` value into a single string.

    - A ``str`` is returned unchanged.
    - A ``list`` is reduced to its textual pieces, joined with single spaces.
      Dict items contribute their ``'text'`` value, and plain string items
      contribute themselves. Any other item is skipped.
    - Anything else is ``str()``-ed, except falsy values, which become ``""``.
    """
    if isinstance(message_content, str):
        return message_content

    if not isinstance(message_content, list):
        return "" if not message_content else str(message_content)

    # List form (e.g. reasoning-model output): keep only the textual pieces.
    pieces = [
        piece['text'] if isinstance(piece, dict) else piece
        for piece in message_content
        if (isinstance(piece, dict) and 'text' in piece) or isinstance(piece, str)
    ]
    return ' '.join(pieces)

def preprocess_agent_messages(agent_messages):
    """Return a copy of ``agent_messages`` with complex ``content`` flattened to text.

    For each message whose ``content`` attribute exists and is not ``None``, a
    lightweight stand-in object is produced. Its attributes are:

    - ``content``: the result of ``extract_text_content(msg.content)``.
    - ``type``: copied from the original, defaulting to ``'unknown'``.
    - ``name``: copied from the original, defaulting to ``None``.
    - ``tool_calls``: copied from the original, defaulting to ``None``.

    Messages without usable content are passed through unchanged. A message
    that raises while being processed is reported with a warning and
    omitted (best-effort), matching the original behaviour.

    The input list and its messages are not modified.
    """
    processed_messages = []

    for msg in agent_messages:
        try:
            content = getattr(msg, 'content', None)
            if content is None:
                processed_messages.append(msg)
                continue
            # SimpleNamespace gives an attribute-access object without
            # defining a new throwaway class for every message.
            processed_messages.append(SimpleNamespace(
                content=extract_text_content(content),
                type=getattr(msg, 'type', 'unknown'),
                name=getattr(msg, 'name', None),
                tool_calls=getattr(msg, 'tool_calls', None),
            ))
        except Exception as e:
            print(f"Warning: Error processing message {type(msg)}: {e}")
            continue

    return processed_messages

def _default_fe_metrics() -> Dict[str, float]:
    """Return a fresh copy of the fallback FE metrics used when calculation fails."""
    return {
        'total_fe': 1.0,
        'epistemic_uncertainty': 0.5,
        'pragmatic_cost': 0.5,
    }


def _agent_ids(state: MazeState) -> list:
    """Return the agent ids (keys of ``state['maze_wrappers']``); empty if missing or ``None``."""
    return list(state.get("maze_wrappers") or {})


def _calculate_agent_fe(fe_calculator, agent_id, state: MazeState) -> Dict[str, Any]:
    """Compute FE metrics for one agent, with its messages temporarily preprocessed.

    If the agent has messages, ``state["agent_messages"][agent_id]`` is replaced
    with the preprocessed list only for the duration of the calculator call. The
    original list is always put back afterwards, even if the calculator raises,
    so the caller's state is never left modified.
    """
    messages_by_agent = state.get("agent_messages") or {}
    original_messages = messages_by_agent.get(agent_id, [])
    if not original_messages:
        return fe_calculator.calculate_expected_free_energy(agent_id, state)

    # Non-empty original_messages implies state["agent_messages"] exists.
    state["agent_messages"][agent_id] = preprocess_agent_messages(original_messages)
    try:
        return fe_calculator.calculate_expected_free_energy(agent_id, state)
    finally:
        state["agent_messages"][agent_id] = original_messages


def benchmarking_node(state: MazeState) -> Dict[str, Any]:
    """
    Benchmarking node that calculates Free Energy metrics for current system state.
    Triggered every N turns or when specific conditions are met.

    For every agent in ``state["maze_wrappers"]`` this computes metrics via
    ``FreeEnergyCalculator.calculate_expected_free_energy``. If an agent's
    calculation fails, fixed fallback metrics are substituted for that agent;
    if the node as a whole fails, every agent gets the fallback metrics.

    Returns a partial state update:
      - ``"free_energy_metrics"``: ``{"agent_fe_scores": {agent_id: metrics}}``.
      - ``"entropy_history"``: the existing history with that result appended.
        The existing history is kept on both the success and the error path.
    """

    print(f"🔬 BENCHMARKING NODE: Analyzing system state at turn {state.get('turn_count', 0)}")

    try:
        fe_calculator = FreeEnergyCalculator()

        # Calculate individual agent free energy scores.
        agent_fe_scores = {}
        for agent_id in _agent_ids(state):
            try:
                fe_metrics = _calculate_agent_fe(fe_calculator, agent_id, state)
                agent_fe_scores[agent_id] = fe_metrics

                print(f"  🤖 {agent_id}: FE={fe_metrics['total_fe']:.3f} "
                    f"(Epistemic={fe_metrics['epistemic_uncertainty']:.3f}, "
                    f"Pragmatic={fe_metrics['pragmatic_cost']:.3f})")
            except Exception as e:
                print(f"Warning: Error calculating FE for agent {agent_id}: {e}")
                # Provide default values if calculation fails.
                agent_fe_scores[agent_id] = _default_fe_metrics()

    except Exception as e:
        print(f"Error in benchmarking_node: {e}")
        # Return default values to prevent system failure.
        agent_fe_scores = {agent_id: _default_fe_metrics() for agent_id in _agent_ids(state)}

    benchmark_results = {
        "agent_fe_scores": agent_fe_scores
    }

    # Preserve existing entropy history on both paths.
    # NOTE(review): the full list is returned, which assumes `entropy_history`
    # has no additive reducer in MazeState. Confirm against the state schema.
    updated_entropy_history = list(state.get("entropy_history") or [])
    updated_entropy_history.append(benchmark_results)

    return {
        "entropy_history": updated_entropy_history,
        "free_energy_metrics": benchmark_results,
    }