from langgraph.graph import StateGraph
from langgraph.graph import START, END
from orchestrator_maze_implementation.config.config_service import config
from orchestrator_maze_implementation.state.maze_state import MazeState
#from orchestrator_maze_implementation.agents.maze_planning_agent import maze_planning_agent
from orchestrator_maze_implementation.agents.maze_execution_agent import maze_execution_agent
from orchestrator_maze_implementation.agents.maze_orchestration_agent import maze_orchestration_agent
from orchestrator_maze_implementation.agents.benchmarking_node import benchmarking_node
#from orchestrator_maze_implementation.agents.sync_knowledge_agent import sync_knowledge_node
from orchestrator_maze_implementation.router.router_support import execute_tools_with_injection, model_condition, should_continue_after_tools, get_current_agent_by_turn  # Add this import

def create_maze_workflow(state: MazeState):
    """Create LangGraph workflow for maze solving with ExecutionCell integration.

    Graph wiring built here:
        START -> plan_node -> maze_execution_agent
        maze_execution_agent --model_condition--> tools | maze_execution_agent
                                                  | check_turn | END
        tools --should_continue_after_tools--> maze_execution_agent | END
        check_turn --validate_turn_count--> maze_orchestration_agent
                                            | maze_execution_agent
            (only when the orchestration ablation is enabled; otherwise
            check_turn always routes back to maze_execution_agent)
        maze_orchestration_agent -> maze_execution_agent

    Args:
        state: Not used by this function; kept so existing callers that pass
            an initial state keep working.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    # A missing/None "ablations" section falls back to an empty mapping so the
    # per-flag defaults below apply instead of crashing on ``None.get``.
    ablation_config = config.get('ablations') or {}
    orchestration_enabled = bool(
        ablation_config.get("enable_orchestration_agent", True)
    )

    # Build LangGraph workflow
    workflow = StateGraph(MazeState)

    workflow.add_node("maze_execution_agent", maze_execution_agent)
    if orchestration_enabled:
        workflow.add_node("maze_orchestration_agent", maze_orchestration_agent)
    workflow.add_node("plan_node", plan_node)
    workflow.add_node("tools", execute_tools_with_injection)
    workflow.add_node("check_turn", check_turn)
    workflow.add_edge(START, 'plan_node')
    workflow.add_edge('plan_node', 'maze_execution_agent')

    workflow.add_conditional_edges(
        "maze_execution_agent",
        model_condition,
        {
            "tools": "tools",
            "continue": "maze_execution_agent",
            "check_turn": "check_turn",
            "end": END
        }
    )
    workflow.add_conditional_edges(
        "tools",
        should_continue_after_tools,
        {
            "continue": "maze_execution_agent",
            "end": END
        }
    )

    # After a turn: either orchestrate (ablation enabled) or loop straight back.
    if orchestration_enabled:
        workflow.add_conditional_edges(
            "check_turn",
            validate_turn_count,
            {
                "orchestrate": "maze_orchestration_agent",
                "continue": "maze_execution_agent",
            }
        )
        # Orchestration updates the execution agent's policy, then hands back.
        workflow.add_edge("maze_orchestration_agent", "maze_execution_agent")
    else:
        # Direct continuation without orchestration
        workflow.add_conditional_edges(
            "check_turn",
            lambda s: "continue",  # Always continue, never orchestrate
            {
                "continue": "maze_execution_agent",
            }
        )

    # return compiled workflow and pass into object 'chain' (main.py)
    return workflow.compile()

# Fixed three-step per-turn plan handed to the execution agent.
_TURN_PLAN_STEPS = (
    "Use get_current_view() to observe your surroundings. Return FINAL ANSWER with what you see.",
    "Choose the best direction and use ONE move tool (move_north/south/east/west). Return FINAL ANSWER with movement result.",
    "OPTIONAL: Only if you are absolutely certain a position is a dead end based on complete exploration, use mark_path_as_dead_end(). Else, if you keep on walking on previously visited tiles, call tool start_backtracking() to get back to the last known unexplored opening. Otherwise return FINAL ANSWER: Turn complete. Provide a brief resoning before issuing your response.",
)


def plan_node(state: MazeState):
    """Seed the state with the per-turn plan and reset turn progress.

    Prints the incoming state for debugging, then returns an update that sets
    ``plan`` to a fresh list of the fixed plan steps, ``step_index`` to 0 and
    ``turn_complete`` to False.
    """
    print(f"DEBUG PLAN NODE: Launching plan node with state: {state}")
    # Copy into a new list each call so downstream mutation cannot leak
    # into the shared module-level tuple.
    update = {
        'plan': list(_TURN_PLAN_STEPS),
        "step_index": 0,
        "turn_complete": False,
    }
    return update


def check_turn(state: MazeState):
    """Advance to the next turn and select the agent that should act in it.

    Increments ``turn_count`` (a missing or None count is treated as 0),
    resolves the active agent with ``get_current_agent_by_turn`` from
    ``all_agents``, and resets the per-turn fields so the plan restarts at
    step 0.

    NOTE(review): this node never ends the run when ``turn_count`` exceeds the
    configured number of iterations; routing after it (e.g.
    ``validate_turn_count``) only chooses between orchestration and continuing.

    Returns:
        dict: state update with ``turn_count``, ``turn_complete``,
        ``step_index`` and ``current_agent``.

    Raises:
        Exception: any error from agent resolution is printed and re-raised.
        Previously it was swallowed and the node returned None, which left
        ``turn_count`` unchanged and could stall the turn loop.
    """
    new_turn_count = (state.get("turn_count") or 0) + 1

    # Calculate which agent should be active for the new turn
    all_agents = state.get("all_agents", [])
    try:
        new_current_agent = get_current_agent_by_turn(new_turn_count, all_agents)
    except Exception as e:
        print(f"Error in turn count validation: {e}")
        raise

    # Print agent rotation info for debugging
    old_agent = state.get("current_agent")
    if new_current_agent != old_agent:
        print(f"🔄 Agent rotation: {old_agent} -> {new_current_agent} (Turn {new_turn_count})")
    else:
        print(f"🔄 Turn {new_turn_count}: {new_current_agent} continues")

    return {
        "turn_count": new_turn_count,
        "turn_complete": False,
        # Each new turn restarts the plan from its first step.
        "step_index": 0,
        "current_agent": new_current_agent
    }

def validate_turn_count(state: MazeState):
    """Route after ``check_turn``: orchestrate or keep executing.

    Returns:
        ``"orchestrate"`` when ``state["turn_count"]`` is at least
        ``config.get_num_turn_iterations()``, otherwise ``"continue"``.

    NOTE(review): ``turn_count`` is only ever incremented in the visible
    code, so once the threshold is reached every later turn also routes to
    orchestration unless something outside this module resets it — confirm.
    """
    turns_taken = state["turn_count"]
    limit_reached = turns_taken >= config.get_num_turn_iterations()
    return "orchestrate" if limit_reached else "continue"