from copy import deepcopy
from typing import TYPE_CHECKING, Optional

from langchain_core.messages import AIMessage
from langchain_core.runnables import chain
from langgraph.prebuilt import ToolNode

from orchestrator_maze_implementation.tools.basic_tool_functions import MazeTools
from orchestrator_maze_implementation.tools.basic_tool_functions import move_north, move_south, move_east, move_west, mark_dead_end, get_current_view, start_backtracking
from orchestrator_maze_implementation.state.maze_state import MazeState
from orchestrator_maze_implementation.config.config_service import config


def get_current_agent_by_turn(turn_count: int, all_agents: list, steps_per_agent: Optional[int] = None) -> str:
    """Calculate which agent should be active based on turn count.

    Agents take turns in blocks: each agent is active for ``steps_per_agent``
    consecutive turns, then the next agent in ``all_agents`` takes over,
    wrapping around at the end of the list.

    Args:
        turn_count: Current turn number (0-based)
        all_agents: List of all agent IDs
        steps_per_agent: Number of steps each agent takes before rotating
            (default: ``config.get_steps_per_agent()``). Must be positive.

    Returns:
        Agent ID that should be active for this turn, or ``"agent_0"`` when
        ``all_agents`` is empty.

    Raises:
        ValueError: If the effective ``steps_per_agent`` is zero or negative.
    """
    if not all_agents:
        return "agent_0"

    # Fall back to the configured block size only when the caller gave none.
    if steps_per_agent is None:
        steps_per_agent = config.get_steps_per_agent()

    # 0 would raise an opaque ZeroDivisionError below, and a negative value
    # would silently walk the rotation backwards; reject both explicitly.
    if steps_per_agent <= 0:
        raise ValueError(f"steps_per_agent must be a positive integer, got {steps_per_agent!r}")

    # Each agent gets 'steps_per_agent' consecutive turns, cycling through the list.
    agent_index = (turn_count // steps_per_agent) % len(all_agents)
    return all_agents[agent_index]


# TODO: MazeTools still need to be bound to the agents correctly
# (see https://langchain-ai.github.io/langgraph/how-tos/tool-calling/#use-in-a-workflow).
# NOTE: this assignment rebinds the name ``MazeTools`` that is also imported from
# basic_tool_functions at the top of the file; this local list is what ``ToolNode``
# receives in execute_tools_with_injection.
MazeTools = [move_north, move_south, move_east, move_west, mark_dead_end, get_current_view, start_backtracking]

def _build_tool_node_input(state: MazeState, agent_id, ai_message: AIMessage) -> dict:
    """Build a ToolNode input holding only *ai_message* plus the state fields tools may read."""
    return {
        # Only the current tool-calling message; the agent's full history is
        # deliberately left out of the tool execution context.
        "messages": [ai_message],
        "maze_wrappers": state["maze_wrappers"],
        "current_agent": agent_id,
        "system_task": state["system_task"],
        "plan": state["plan"],
        "plan_completed": state["plan_completed"],
        "step_index": state["step_index"],
        "turn_count": state["turn_count"],
        "maze_exit_found": state["maze_exit_found"],
        "exit_position": state["exit_position"],
        "winning_agent": state["winning_agent"],
        "turn_complete": state["turn_complete"],
        "agent_messages": {},  # Empty - we don't want historical messages
        "all_agents": state.get("all_agents", []),
        "orchestrator_guidance": state.get("orchestrator_guidance"),
        "shared_knowledge": state.get("shared_knowledge"),
        "known_openings": state.get("known_openings", {}),
        "free_energy_metrics": state.get("free_energy_metrics", {}),
        "entropy_history": state.get("entropy_history", []),
        "agent_backtracking_state": state.get("agent_backtracking_state", {}),
    }


def execute_tools_with_injection(state: MazeState):
    """Execute the current agent's pending tool calls with ``agent_id`` injected.

    Takes the last message in ``state["agent_messages"][current_agent]``. If it
    carries tool calls, every call's ``args`` gets ``agent_id`` added (via
    ``inject_agent_id``). The calls are then run through a ``ToolNode`` built
    from ``MazeTools``.

    Args:
        state: Graph state for the maze run.

    Returns:
        ``{}`` when the agent has no messages or its last message has no tool
        calls. Otherwise ``{"agent_messages": {agent_id: msgs}}``, where ``msgs``
        is the agent's history with the last AI message replaced by its
        injected copy, followed by the tool result messages.

    Raises:
        ValueError: If ``state["current_agent"]`` is missing or empty, or if the
            last message's ``additional_kwargs["agent_id"]`` names a different
            agent than ``current_agent``.
    """
    # The active agent decides whose message history we act on.
    agent_id = state.get("current_agent")
    if not agent_id:
        raise ValueError("No current_agent found in state - ensure it's set in the agent node")

    agent_msgs = state["agent_messages"].get(agent_id, [])
    if not agent_msgs:
        return {}  # No messages, no update

    last_message = agent_msgs[-1]
    if not getattr(last_message, "tool_calls", None):
        return {}  # No tool calls, no update

    # Safety check: if the agent node stamped its id on the message, it must
    # agree with state["current_agent"].
    embedded_id = last_message.additional_kwargs.get("agent_id")
    if embedded_id and embedded_id != agent_id:
        raise ValueError(f"Mismatch between state.current_agent ({agent_id}) and embedded id ({embedded_id})")

    injected_tool_calls = inject_agent_id.invoke({
        "ai_msg": last_message,
        "agent_id": agent_id,
    })

    # Rebuild the AI message with the injected calls. Carry over id/name and
    # metadata too, so the replacement stays identifiable as the same message.
    modified_ai_message = AIMessage(
        content=last_message.content,
        tool_calls=injected_tool_calls,
        additional_kwargs=last_message.additional_kwargs,
        response_metadata=getattr(last_message, "response_metadata", None) or {},
        id=getattr(last_message, "id", None),
        name=getattr(last_message, "name", None),
    )

    # Replace the last message with the modified one for consistency.
    modified_agent_msgs = agent_msgs[:-1] + [modified_ai_message]

    tool_node = ToolNode(MazeTools)
    tool_results = tool_node.invoke(_build_tool_node_input(state, agent_id, modified_ai_message))

    # Targeted update: only this agent's message list is returned.
    return {
        "agent_messages": {agent_id: modified_agent_msgs + tool_results["messages"]},
    }

@chain
def inject_agent_id(inputs: dict):
    """Return copies of an AI message's tool calls with ``agent_id`` added to their args.

    Args:
        inputs: Mapping with ``"ai_msg"`` (a message exposing ``tool_calls``)
            and ``"agent_id"`` (the value written into each call's ``args``).

    Returns:
        A new list of tool-call dicts. Each is a deep copy of the original whose
        ``args["agent_id"]`` is set to the given id, overwriting any existing
        value. The list is empty when the message's ``tool_calls`` is None.
    """
    message = inputs["ai_msg"]
    agent_id = inputs["agent_id"]

    def _stamped(call):
        # Deep-copy so the source message's args dicts are never mutated.
        stamped_call = deepcopy(call)
        stamped_call["args"]["agent_id"] = agent_id
        return stamped_call

    pending = message.tool_calls
    # A None tool_calls attribute means "no calls" rather than an error.
    return [] if pending is None else [_stamped(call) for call in pending]

def model_condition(state: MazeState):
    """Decide whether to use tools, check the turn, continue, or end.

    While the plan is still in progress (``step_index <= len(plan)``), pending
    tool calls on the current agent's last message take precedence over a
    completed turn. Once the plan is exhausted, a completed turn is checked
    first. If neither applies, the run ends when the current agent's maze
    wrapper reports it is at the exit, or when ``step_index`` reaches 25.
    Otherwise it continues.

    Args:
        state: Graph state. Reads ``current_agent``, ``agent_messages``,
            ``step_index``, ``plan``, ``turn_complete`` and ``maze_wrappers``.
            A missing or None ``step_index``/``plan`` is treated as 0 / empty.

    Returns:
        One of ``"tools"``, ``"check_turn"``, ``"end"`` or ``"continue"``.
    """
    max_step_index = 25  # Hard stop to prevent infinite agent/tool loops.

    agent_id = state["current_agent"]
    # Get the last message - handle empty message lists safely.
    agent_messages = state["agent_messages"].get(agent_id, [])
    last_message = agent_messages[-1] if agent_messages else None
    has_tool_calls = bool(last_message) and bool(getattr(last_message, "tool_calls", None))
    turn_complete = state.get("turn_complete")

    # Guard against None values, which would otherwise raise TypeError in the
    # comparison below.
    step_index = state.get("step_index") or 0
    plan = state.get("plan") or []

    if step_index <= len(plan):
        # PRIORITY 1: pending tool calls are always executed first.
        if has_tool_calls:
            print("DEBUG MODEL CONDITION: Tool calls detected. Routing to tools node")
            return "tools"
        # PRIORITY 2: only after tools are handled, check whether the turn loop is done.
        if turn_complete:
            print("DEBUG MODEL CONDITION: Turn complete. Routing to check_turn node")
            return "check_turn"
    else:
        # Plan exhausted: a completed turn takes precedence over leftover tool calls.
        if turn_complete:
            print("DEBUG MODEL CONDITION: Turn complete. Routing to check_turn node")
            return "check_turn"
        if has_tool_calls:
            print("DEBUG MODEL CONDITION: Tool calls detected. Routing to tools node")
            return "tools"

    # Check if maze is solved by the current agent.
    maze_wrapper = state["maze_wrappers"][agent_id]
    if maze_wrapper.is_at_exit():
        return "end"

    if step_index >= max_step_index:
        print("Maximum turns reached. Ending maze exploration.")
        return "end"

    # No tool calls and maze not solved: continue the loop.
    return "continue"

def should_continue_after_tools(state: MazeState):
    """Decide whether maze exploration continues once the tool node has run.

    Uses the maze wrapper of the *current* agent (``state["current_agent"]``).

    Returns:
        "end" if that wrapper reports it is at the exit, if ``step_index`` has
        reached 25, or if the wrapper's move count has reached the configured
        maximum total steps (when one is configured); otherwise "continue".
    """
    active = state["current_agent"]
    wrapper = state["maze_wrappers"][active]

    if wrapper.is_at_exit():
        return "end"

    # Hard cap on step_index (25) to prevent endless agent/tool loops.
    if state["step_index"] >= 25:
        print("Maximum turns reached. Ending maze exploration.")
        return "end"

    turns = state["turn_count"]
    moves = wrapper.get_move_count()
    move_budget = config.get_max_total_steps()
    print(f"DEBUG MODEL CONDITION: Turn count: {turns}, Step count: {moves}, max total steps: {move_budget}")

    # A None budget means "unlimited"; otherwise stop once it is used up.
    budget_exhausted = move_budget is not None and moves >= move_budget
    # Otherwise route back to maze_agent rather than to a confirmation step.
    return "end" if budget_exhausted else "continue"
