"""Manual React agent with fine-grained interaction control using LangGraph."""

import logging
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, TypedDict, Union

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

# Hard upper bound on the number of agent interactions (one interaction =
# agent_reasoning -> tool_execution -> check_stopping_condition). It is compared
# against the state's ``interaction_count`` as a safety net against endless
# loops, e.g. when no interaction target is configured.
# NOTE: this counts interactions, not LangGraph supersteps, so it is distinct
# from the ``recursion_limit`` entry of a LangGraph run config.
SAFE_RECURSION_LIMIT = 120

class AgentStopType(Enum):
    """Stopping policies understood by the manual React agent.

    Members are looked up from their lowercase string value, e.g.
    ``AgentStopType("default")``.
    """

    # Keep gathering information (even after an answer) until the interaction
    # target is reached, then ask for the final answer.
    INTERACTION_SCALING = "interaction_scaling"
    # Stop as soon as the model produces a final answer.
    DEFAULT = "default"


class ManualReactState(TypedDict):
    """State for the manual React agent.

    Graph nodes return partial dicts containing a subset of these keys, which
    LangGraph merges into the running state.
    """
    # Conversation so far. The ``add_messages`` reducer merges the message
    # lists returned by nodes into this history instead of overwriting it.
    messages: Annotated[List[Union[SystemMessage, HumanMessage, AIMessage, ToolMessage]], add_messages]
    # Number of completed reasoning/tool cycles; incremented by one per cycle
    # in the check_stopping_condition node.
    interaction_count: int
    # Interaction budget. May be None at runtime despite the ``int``
    # annotation; any falsy value disables the target check.
    target_interactions: int
    # ``AgentStopType`` value string, e.g. "default" or "interaction_scaling".
    agent_stop_type: str
    # Whether the latest model reply was judged a final answer (non-empty
    # text content and no tool calls).
    has_final_answer: bool
    # Tools the tool_execution node may run, looked up by name.
    tools_available: List[Any]
    # Model name carried along with the run; no graph node in this module
    # reads it.
    model_name: str
    # Set to True once the "provide your final answer" prompt has been
    # injected, so that prompt is sent at most once per run.
    reached_target_interactions: bool

class ManualReactAgent:
    """Manual React agent with fine-grained interaction control using LangGraph.

    The compiled graph runs ``agent_reasoning -> tool_execution ->
    check_stopping_condition`` in a loop. After each pass ``_should_continue``
    does one of three things:

    - loops straight back to reasoning;
    - injects a human follow-up prompt (``prompt_final_answer`` /
      ``prompt_continue_info``) and then reasons again;
    - ends the run.
    """

    # Most LangGraph supersteps a single interaction can take: agent_reasoning,
    # tool_execution, check_stopping_condition, plus at most one prompt node.
    _STEPS_PER_INTERACTION = 4
    # Headroom added on top of the computed superstep bound.
    _RECURSION_LIMIT_SLACK = 10

    def __init__(
        self,
        model: Any,
        system_prompt: str,
        search_tools: List[Any],
        logger: logging.Logger,
        model_name: str,
        agent_stop_type: str = "default",
        agent_stop_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the manual React agent.

        Args:
            model: Language model instance; must provide ``bind_tools``.
            system_prompt: System prompt for the agent (surrounding whitespace
                is stripped).
            search_tools: List of search tools available to the agent.
            logger: Logger instance.
            model_name: Name of the model (for caching decisions).
            agent_stop_type: Type of stopping condition, an ``AgentStopType``
                value (case-insensitive).
            agent_stop_kwargs: Additional configuration for stopping
                conditions. ``"interaction_rounds"`` sets the interaction
                target used in interaction-scaling mode.

        Raises:
            ValueError: If ``agent_stop_type`` is not an ``AgentStopType`` value.
        """
        self.model = model
        self.system_prompt = system_prompt.strip()
        self.search_tools = search_tools
        self.logger = logger
        self.model_name = model_name
        self.agent_stop_type = AgentStopType(agent_stop_type.lower())
        self.agent_stop_kwargs = agent_stop_kwargs or {}

        # Interaction target; None means "no target" (only the
        # SAFE_RECURSION_LIMIT safety cap then ends interaction scaling).
        self.target_interactions = self.agent_stop_kwargs.get('interaction_rounds', None)
        # Interaction count of the most recent completed ``invoke`` call.
        self._last_interaction_count = 0

        # Configure model with tools and build the manual React graph.
        self.model_with_tools = self.model.bind_tools(self.search_tools)
        self._build_graph()

    def _build_graph(self):
        """Build and compile the manual React agent graph into ``self.graph``."""
        builder = StateGraph(ManualReactState)

        builder.add_node("agent_reasoning", self._agent_reasoning_node)
        builder.add_node("tool_execution", self._tool_execution_node)
        builder.add_node("check_stopping_condition", self._check_stopping_condition_node)
        builder.add_node("prompt_final_answer", self._prompt_final_answer_node)
        builder.add_node("prompt_continue_info", self._prompt_continue_info_node)

        # Main loop: reason -> run tools -> count the interaction.
        builder.add_edge(START, "agent_reasoning")
        builder.add_edge("agent_reasoning", "tool_execution")
        builder.add_edge("tool_execution", "check_stopping_condition")
        # Injected prompts always hand control back to the model.
        builder.add_edge("prompt_final_answer", "agent_reasoning")
        builder.add_edge("prompt_continue_info", "agent_reasoning")

        # Routing keys returned by _should_continue.
        builder.add_conditional_edges(
            "check_stopping_condition",
            self._should_continue,
            {
                "continue": "agent_reasoning",
                "prompt_final": "prompt_final_answer",
                "prompt_more_info": "prompt_continue_info",
                "stop": END,
            },
        )

        self.graph = builder.compile()

    def _agent_reasoning_node(self, state: ManualReactState, config: RunnableConfig) -> Dict[str, Any]:
        """Call the tool-bound model on the conversation so far.

        The system prompt is prepended to the model input (not to the stored
        state) when the conversation does not already start with a
        SystemMessage. If the model call fails, the error is logged and
        returned as an AIMessage. The graph can then still reach its stopping
        logic instead of crashing.
        """
        messages = list(state["messages"])
        if not messages or not isinstance(messages[0], SystemMessage):
            messages.insert(0, SystemMessage(content=self.system_prompt))

        try:
            response = self.model_with_tools.invoke(messages)
        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            self.logger.exception("Error in agent reasoning: %s", e)
            error_response = AIMessage(content=f"Error in reasoning: {str(e)}")
            return {
                "messages": [error_response],
                "has_final_answer": False,
            }

        self.logger.debug("Model response: %s", response)
        return {
            "messages": [response],
            "has_final_answer": self._check_final_answer(response),
        }

    @staticmethod
    def _tool_name(tool: Any) -> str:
        """Return a tool's lookup name: its ``name`` attribute, else ``__name__``."""
        return tool.name if hasattr(tool, 'name') else tool.__name__

    @staticmethod
    def _run_tool(tool: Any, tool_args: Dict[str, Any]) -> Any:
        """Run one tool call.

        Objects with an ``invoke`` method (e.g. LangChain tools) receive the
        args dict. Plain callables are called with the args as keyword
        arguments.
        """
        if hasattr(tool, "invoke"):
            return tool.invoke(tool_args)
        return tool(**tool_args)

    def _tool_execution_node(self, state: ManualReactState, config: RunnableConfig) -> Dict[str, Any]:
        """Execute every tool call requested by the latest message.

        Returns one ToolMessage per tool call, in request order. A tool that
        raises yields an error ToolMessage rather than aborting the run.

        Raises:
            ValueError: If a requested tool name is not in ``tools_available``.
        """
        last_message = state["messages"][-1]
        tool_calls = getattr(last_message, "tool_calls", None)
        if not tool_calls:
            return {"messages": []}

        tools_by_name = {self._tool_name(tool): tool for tool in state["tools_available"]}

        tool_messages = []
        for tool_call in tool_calls:
            tool_name = tool_call["name"]
            tool_call_id = tool_call["id"]
            tool_args = tool_call["args"]

            self.logger.debug("Executing tool: %s with args: %s", tool_name, tool_args)

            tool = tools_by_name.get(tool_name)
            if tool is None:
                raise ValueError(f"Tool '{tool_name}' not found. Available tools: {list(tools_by_name)}")

            try:
                content = str(self._run_tool(tool, tool_args))
            except Exception as e:
                content = f"Error executing {tool_name}: {str(e)}"
            tool_messages.append(ToolMessage(
                content=content,
                name=tool_name,
                tool_call_id=tool_call_id,
            ))

        return {"messages": tool_messages}

    def _check_stopping_condition_node(self, state: ManualReactState, config: RunnableConfig) -> Dict[str, Any]:
        """Count one completed interaction.

        The routing decision itself happens in ``_should_continue``.
        """
        return {"interaction_count": state["interaction_count"] + 1}

    def _prompt_final_answer_node(self, state: ManualReactState, config: RunnableConfig) -> Dict[str, Any]:
        """Ask the model for its final answer and record that this was done."""
        self.logger.info("Prompting for final answer after reaching target interactions")
        return {
            "messages": [HumanMessage(content="Please stop searching. Based on the current information, provide your final answer.")],
            "reached_target_interactions": True,
        }

    def _prompt_continue_info_node(self, state: ManualReactState, config: RunnableConfig) -> Dict[str, Any]:
        """Ask the model to keep gathering information."""
        self.logger.info("Prompting to continue collecting more information")
        return {
            "messages": [HumanMessage(content="Please continue collecting more information to strengthen your answer.")]
        }

    def _should_continue(self, state: ManualReactState) -> str:
        """Route after ``check_stopping_condition``.

        Returns one of ``"continue"``, ``"prompt_final"``,
        ``"prompt_more_info"`` or ``"stop"``. These are the keys of the
        conditional-edge map in ``_build_graph``.
        """
        interaction_scaling = state["agent_stop_type"] == AgentStopType.INTERACTION_SCALING.value

        # Safety net against endless loops, e.g. when no target is configured.
        max_recursion_limit = SAFE_RECURSION_LIMIT
        if state["interaction_count"] >= max_recursion_limit:
            self.logger.warning("Stopping: Reached maximum recursion limit (%s)", max_recursion_limit)
            return "stop"

        # Target reached: stop if answered, otherwise ask for the final answer
        # exactly once.
        target = state["target_interactions"]
        if target and state["interaction_count"] >= target:
            if state["has_final_answer"]:
                self.logger.info("Stopping: Reached target interactions with final answer")
                return "stop"
            if state["reached_target_interactions"]:
                self.logger.info("We have prompted at the very end to obtain the final answer, but still didn't get the final answer. So we stop here.")
                return "stop"
            self.logger.info("Need to prompt once at the very end to obtain the final answer")
            return "prompt_final"

        # Below the target.
        if state["has_final_answer"]:
            # Interaction scaling keeps gathering info even after an answer.
            return "prompt_more_info" if interaction_scaling else "stop"

        if isinstance(state["messages"][-1], ToolMessage):
            self.logger.info("The last message is a tool call, we continue")
            return "continue"
        if interaction_scaling:
            return "prompt_more_info"
        if not state["reached_target_interactions"]:
            # The model ended without tool calls and without usable text
            # (often from over-thinking); give it one chance to answer.
            self.logger.info("Model ended early, need to prompt for final answer")
            return "prompt_final"
        return "stop"

    @staticmethod
    def _message_text(message: Any) -> str:
        """Return the text carried by a message's ``content``.

        Plain-string content is returned as is. For list content, string
        items and ``{"type": "text", "text": ...}`` blocks are concatenated.
        Other block types are ignored. Anything else yields ``""``.
        """
        content = getattr(message, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return ""

    def _check_final_answer(self, message: AIMessage) -> bool:
        """Return True if the message has non-empty text and requests no tools."""
        return bool(self._message_text(message).strip()) and not getattr(message, "tool_calls", None)

    def invoke(self, inputs: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Invoke the manual React agent.

        Args:
            inputs: Input dictionary; its ``"messages"`` entry (default: empty
                list) seeds the conversation.
            config: Optional LangGraph run configuration. It is copied, never
                mutated. If it has no ``recursion_limit``, one large enough for
                ``SAFE_RECURSION_LIMIT`` interactions is filled in. LangGraph's
                own default (25 supersteps) would otherwise abort the run after
                roughly eight interactions.

        Returns:
            The final graph state, including ``messages`` and
            ``interaction_count``.

        Raises:
            Any exception raised while running the graph is logged and re-raised.
        """
        run_config = dict(config) if config else {}
        run_config.setdefault(
            "recursion_limit",
            self._STEPS_PER_INTERACTION * SAFE_RECURSION_LIMIT + self._RECURSION_LIMIT_SLACK,
        )

        messages = inputs.get("messages", [])

        # In default mode the target is a placeholder above SAFE_RECURSION_LIMIT,
        # so the safety cap (not the target) bounds the run.
        if self.agent_stop_type == AgentStopType.INTERACTION_SCALING:
            target_interactions = self.target_interactions
        else:
            target_interactions = 150

        initial_state = {
            "messages": messages,
            "interaction_count": 0,
            "target_interactions": target_interactions,
            "agent_stop_type": self.agent_stop_type.value,
            "has_final_answer": False,
            "tools_available": self.search_tools,
            "model_name": self.model_name,
            "reached_target_interactions": False,
        }

        self.logger.info("Starting manual React agent with %s tools", len(self.search_tools))

        try:
            result = self.graph.invoke(initial_state, run_config)
            self._last_interaction_count = result['interaction_count']
            self.logger.info("Manual React agent completed after %s interactions", result['interaction_count'])
            return result
        except Exception as e:
            self.logger.error("Error in manual React agent: %s", e)
            raise

    def get_interaction_count(self) -> int:
        """Return the interaction count of the most recent completed run (0 before any)."""
        return getattr(self, "_last_interaction_count", 0)

    def get_target_interactions(self) -> Optional[int]:
        """Return the configured interaction target (None if not configured)."""
        return self.target_interactions


def create_manual_react_agent(
    model: Any,
    system_prompt: str,
    search_tools: List[Any],
    logger: logging.Logger,
    model_name: str,
    agent_stop_type: str = "default",
    agent_stop_kwargs: Optional[Dict[str, Any]] = None,
) -> ManualReactAgent:
    """Factory function to create a manual React agent.

    Args:
        model: Language model instance.
        system_prompt: System prompt for the agent.
        search_tools: List of search tools available to the agent.
        logger: Logger instance.
        model_name: Name of the model (for caching decisions).
        agent_stop_type: Type of stopping condition (an ``AgentStopType`` value).
        agent_stop_kwargs: Additional configuration for stopping conditions;
            ``None`` means no extra configuration.

    Returns:
        ManualReactAgent instance.

    Raises:
        ValueError: Propagated from ``ManualReactAgent`` when
            ``agent_stop_type`` is not a valid ``AgentStopType`` value.
    """
    return ManualReactAgent(
        model=model,
        system_prompt=system_prompt,
        search_tools=search_tools,
        logger=logger,
        model_name=model_name,
        agent_stop_type=agent_stop_type,
        agent_stop_kwargs=agent_stop_kwargs,
    )
