"""
Base class for all specialist ABCDE agents.
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Literal, Any
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.prompts import PromptTemplate
from langchain.agents import AgentExecutor, create_react_agent
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage

# Import state types
from ..orchestrator.state import MultiAgentState

class ABCDEAgent(ABC):
    """Base class for ABCDE agents with P&E/ReAct modes.

    Subclasses implement :meth:`analyze`. Depending on ``mode`` the
    constructor also prepares an execution backend:

    * ``"react"`` -- a LangChain ``AgentExecutor`` built with
      ``create_react_agent`` (requires :meth:`get_react_prompt_template`).
    * ``"function_calling"`` -- a compiled LangGraph ``StateGraph`` that
      alternates between the tool-bound model and tool execution.
    * ``"plan_execute"`` -- no backend is built here; presumably the
      subclass drives its own plan/execute loop (TODO confirm).
    """

    def __init__(
        self,
        agent_name: str,
        llm: BaseLanguageModel,
        tools: List[BaseTool],
        mode: Literal["plan_execute", "react", "function_calling"] = "plan_execute",
        vcot_policy: Literal["never", "on_low_conf", "borderline_only", "always"] = "on_low_conf",
        vcot_threshold_low: float = 0.3,
        vcot_threshold_high: float = 0.7
    ):
        """Store configuration and set up the backend for ``mode``.

        Args:
            agent_name: Name used in log output and in the system prompt.
            llm: Language model used by the ReAct / function-calling backends.
            tools: Tools available to the agent. They are stored keyed by
                ``tool.name``, so a later tool with a duplicate name replaces
                an earlier one.
            mode: Which execution backend to set up (see class docstring).
            vcot_policy: Policy consulted by :meth:`should_trigger_vcot`.
            vcot_threshold_low: Stored for subclasses; not read by this class.
            vcot_threshold_high: Confidence cut-off for the "on_low_conf" policy.
        """
        self.agent_name = agent_name
        self.llm = llm
        self.tools = {tool.name: tool for tool in tools}
        self.mode = mode
        self.vcot_policy = vcot_policy
        self.vcot_threshold_low = vcot_threshold_low
        self.vcot_threshold_high = vcot_threshold_high

        if self.mode == "react":
            self._setup_react_agent()
        elif self.mode == "function_calling":
            self._setup_function_calling_agent()

    def _setup_react_agent(self):
        """Build ``self.agent_executor`` for ReAct mode.

        Every tool is wrapped by :meth:`_wrap_tool_with_image_path`, so the
        model does not have to pass ``image_path`` explicitly. The executor
        feeds parsing errors back to the model (``handle_parsing_errors``) and
        returns intermediate steps. It is capped at 10 iterations / 300 s.
        """
        print(f"[{self.agent_name}] Setting up ReAct agent...")

        # Create wrapped tools that auto-inject image_path
        wrapped_tools = [self._wrap_tool_with_image_path(tool) for tool in self.tools.values()]

        # Subclasses supply the ReAct prompt text
        prompt = PromptTemplate.from_template(self.get_react_prompt_template())

        agent = create_react_agent(
            llm=self.llm,
            tools=wrapped_tools,
            prompt=prompt
        )

        self.agent_executor = AgentExecutor(
            agent=agent,
            tools=wrapped_tools,
            verbose=True,
            max_iterations=10,
            max_execution_time=300,  # 5 minutes
            handle_parsing_errors=True,
            return_intermediate_steps=True
        )

        print(f"[{self.agent_name}] ReAct agent setup complete with {len(self.tools)} tools")

    def _setup_function_calling_agent(self):
        """Build ``self.function_calling_workflow`` and ``self.model_with_tools``.

        Graph shape: ``process`` -> (tool calls? ``execute`` : END) and
        ``execute`` -> ``process``. The loop ends when the model replies
        without tool calls.
        """
        print(f"[{self.agent_name}] Setting up function calling agent...")

        from typing import TypedDict, Annotated
        import operator
        from langchain_core.messages import AnyMessage

        class AgentState(TypedDict):
            # operator.add reducer: messages returned by a node are appended
            messages: Annotated[List[AnyMessage], operator.add]
            image_path: str

        workflow = StateGraph(AgentState)
        workflow.add_node("process", self._process_request)
        workflow.add_node("execute", self._execute_tools)
        workflow.add_conditional_edges(
            "process", self._has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")

        self.function_calling_workflow = workflow.compile()

        # Nodes only read this at run time, so binding after compile() is fine.
        self.model_with_tools = self.llm.bind_tools(list(self.tools.values()))

        print(f"[{self.agent_name}] Function calling agent setup complete with {len(self.tools)} tools")

    def _process_request(self, state):
        """Graph node: invoke the tool-bound model on the conversation.

        A fresh system message (agent name, tool names, current image path) is
        prepended on every call. It is not stored in ``state``.

        Returns:
            ``{"messages": [response]}``, which is appended by the reducer.
        """
        messages = state["messages"]

        system_msg = SystemMessage(content=f"""You are a {self.agent_name} specialist analyzing medical images.
Available tools: {list(self.tools.keys())}

Focus on your specialty area and use the appropriate tools to analyze the image.
Current image: {state.get('image_path', 'Unknown')}

Be systematic and thorough in your analysis.""")

        messages_with_system = [system_msg] + messages
        response = self.model_with_tools.invoke(messages_with_system)

        return {"messages": [response]}

    def _execute_tools(self, state):
        """Graph node: run every tool call requested by the last message.

        Unknown tool names and tool exceptions are not raised. They are
        returned to the model as error text inside the ``ToolMessage``, so the
        graph can loop back to ``process`` and let the model recover.

        Args:
            state: Graph state with ``messages`` (the last one carries
                ``tool_calls``) and, optionally, ``image_path``.

        Returns:
            ``{"messages": [...]}`` with one ``ToolMessage`` per tool call.
        """
        tool_calls = state["messages"][-1].tool_calls
        # Read defensively, like _process_request: a missing key must not
        # crash the whole graph with a KeyError.
        image_path = state.get("image_path")
        results = []

        for call in tool_calls:
            name = call["name"]
            print(f"[{self.agent_name}] Executing tool: {name}")

            tool = self.tools.get(name)
            if tool is None:
                print(f"[{self.agent_name}] Invalid tool: {name}")
                result = f"Error: Invalid tool {name}"
            else:
                # Copy so the model's tool_call args are never mutated.
                args = dict(call.get("args") or {})
                # Auto-inject image_path when the model omitted it and the
                # state actually has one (enhancement over main.py).
                if "image_path" not in args and image_path is not None:
                    args["image_path"] = image_path

                try:
                    result = tool.invoke(args)
                except Exception as e:
                    # Deliberately broad: any tool failure is reported back
                    # to the model instead of aborting the graph.
                    result = f"Error executing {name}: {str(e)}"

            results.append(ToolMessage(
                tool_call_id=call["id"],
                name=name,
                content=str(result)
            ))

        return {"messages": results}

    def _has_tool_calls(self, state) -> bool:
        """Return True if the last message requests at least one tool call.

        A message without a ``tool_calls`` attribute, or with an empty or
        ``None`` value, counts as no tool calls.
        """
        return bool(getattr(state["messages"][-1], "tool_calls", None))

    def _wrap_tool_with_image_path(self, tool):
        """Wrap ``tool`` as a single-input ``Tool`` for the ReAct executor.

        The wrapper converts the ReAct action input (usually a JSON string)
        into a dict and injects ``self._current_image_path`` when
        ``image_path`` is absent. It then invokes the original tool. Any input
        that is not a JSON object becomes ``{}``.

        Failures are returned, not raised, as
        ``{"error": ..., "tool_input_received": ...}``:

        * a grounding tool (name contains "grounding") was called without a
          ``phrase`` key;
        * the wrapped tool raised.

        NOTE(review): the image path is read from shared instance state set by
        :meth:`set_current_image_path`. This is not safe if one agent instance
        analyses several images concurrently.
        """
        import json
        from langchain_core.tools import Tool

        def wrapped_invoke(tool_input):
            print(f"[{self.agent_name}] Wrapping tool {tool.name} with input: {tool_input}")

            if isinstance(tool_input, str):
                try:
                    parsed_input = json.loads(tool_input)
                except ValueError:
                    # json.JSONDecodeError subclasses ValueError. Non-JSON
                    # text is treated as empty parameters.
                    parsed_input = None
                # Only a JSON object maps onto keyword parameters
                tool_input = parsed_input if isinstance(parsed_input, dict) else {}
            elif isinstance(tool_input, dict):
                # Copy so injecting image_path never mutates the caller's dict
                tool_input = dict(tool_input)
            else:
                tool_input = {}

            # Auto-inject image_path if not provided
            if "image_path" not in tool_input:
                if hasattr(self, '_current_image_path'):
                    tool_input["image_path"] = self._current_image_path
                    print(f"[{self.agent_name}] Auto-injected image_path: {self._current_image_path}")
                else:
                    print(f"[{self.agent_name}] WARNING: No current image path set!")

            print(f"[{self.agent_name}] Final tool input: {tool_input}")

            # Grounding tools need a 'phrase'; report that to the model
            # instead of letting the tool fail with a less clear error.
            if "grounding" in tool.name.lower() and "phrase" not in tool_input:
                error_msg = f"ERROR: Grounding tool {tool.name} requires 'phrase' parameter but it was not provided. Tool input: {tool_input}"
                print(f"[{self.agent_name}] {error_msg}")
                return {"error": error_msg, "tool_input_received": tool_input}

            try:
                result = tool.invoke(tool_input)
                print(f"[{self.agent_name}] Tool {tool.name} executed successfully")
                return result
            except Exception as e:
                # Deliberately broad: the error is surfaced to the ReAct loop
                error_msg = f"Error executing {tool.name}: {str(e)}"
                print(f"[{self.agent_name}] {error_msg}")
                return {"error": error_msg, "tool_input_received": tool_input}

        return Tool(
            name=tool.name,
            description=tool.description,
            func=wrapped_invoke
        )

    @abstractmethod
    def analyze(self, state: MultiAgentState) -> MultiAgentState:
        """
        The main method for the agent. It takes the current state,
        performs its analysis, and returns the updated state.
        """
        pass

    def get_react_prompt_template(self) -> str:
        """Return the ReAct prompt template text for this agent.

        Raises:
            NotImplementedError: If the agent is in "react" mode and the
                subclass has not overridden this method.

        Returns:
            ``""`` in any other mode, where no template is needed.
        """
        if self.mode == "react":
            raise NotImplementedError(f"{self.agent_name} must implement get_react_prompt_template() for ReAct mode")
        return ""  # Not needed for non-ReAct modes

    def should_trigger_vcot(self, confidence: float, is_borderline: bool = False) -> bool:
        """Decide whether V-CoT should run, according to ``self.vcot_policy``.

        * "never" -> False
        * "always" -> True
        * "borderline_only" -> ``is_borderline``
        * "on_low_conf" -> ``confidence < self.vcot_threshold_high``
        * any other value -> False

        NOTE(review): ``vcot_threshold_low`` is never consulted here. Confirm
        that "on_low_conf" is meant to use the *high* threshold.
        """
        if self.vcot_policy == "never":
            return False
        elif self.vcot_policy == "always":
            return True
        elif self.vcot_policy == "borderline_only":
            return is_borderline
        elif self.vcot_policy == "on_low_conf":
            return confidence < self.vcot_threshold_high
        return False

    def set_current_image_path(self, image_path: str):
        """Set the image path that ReAct-mode tool wrappers auto-inject."""
        self._current_image_path = image_path