import traceback
from typing import Optional, Dict, List
from pydantic import (
    Field,
    model_validator
)
from src.utils.logsetup import logger
from src.agent import BaseAgent
from src.llm import LLM
from src.prompt.autoagent import SYS_AUTOAGENT_PROPMT
from src.schema.message import Message, Memory, ToolCall, Function, AgentState
from src.utils.exceptions import TokenLimitExceeded
from src.tools.bash import Bash
from src.tools.python_execute import PythonExecute
from src.tools.web_search import WebSearch
from src.tools.answer_summarizer import AnswerSummarizer
import json

class AutoAgent(BaseAgent):
    """Tool-using agent that alternates LLM "think" steps with tool "act" steps.

    The built-in tools (``bash``, ``python_execute``, ``web_search`` and
    ``answer_summarizer``) are executed directly through their ``execute``
    coroutine. Any other tool registered via :meth:`set_tools` is only known
    by its schema dict, and its output is produced by ``toolactor``.
    """

    name: str = "AutoAgent"
    description: str = "AutoAgent"
    # OpenAI-style function schemas sent to the LLM on every think() call.
    tools: Optional[List[Dict]] = Field(
        default_factory=lambda: [
            Bash().to_param(),
            PythonExecute().to_param(),
            WebSearch().to_param(),
            AnswerSummarizer().to_param()
        ],
        description="Tools for AutoAgent"
    )
    # Tool name -> executable tool instance (built-ins) or schema dict
    # (tools added by set_tools, which are handed to the toolactor).
    tool_map: Optional[Dict] = Field(
        default_factory=lambda: {
            "bash": Bash(),
            "python_execute": PythonExecute(),
            "web_search": WebSearch(),
            "answer_summarizer": AnswerSummarizer()
        },
        description="Tool map for AutoAgent"
    )

    system_prompt: str = SYS_AUTOAGENT_PROPMT
    memory: Memory = Field(default_factory=Memory)
    toolactor: Optional[BaseAgent] = None
    # Optional classification; its keys are rendered into the system prompt by think().
    task_classification: Optional[Dict] = None
    # When set, replaces answer_summarizer's final_answer and is forwarded to the toolactor.
    gold_answer: Optional[str] = None

    @model_validator(mode="after")
    def initialize_agent(self) -> "AutoAgent":
        """Replace ``self.llm`` with the ``autoagent``-configured LLM after validation.

        If building that LLM (or reading its ``model``) fails, a warning is
        logged and the instance keeps whatever ``llm`` it already had
        (presumably a default provided by BaseAgent — confirm there).
        """
        try:
            self.llm = LLM(config_name="autoagent")
            logger.info(f"Successfully initialized AutoAgent, model: {self.llm.model}")
        except Exception as e:
            logger.warning(f"Cannot initialize AutoAgent, error: {e}, using default LLM")

        return self

    def set_tools(self, tools: Optional[List[Dict]]):
        """Register extra (simulated) tools on top of the built-in ones.

        Each entry may be either an OpenAI function-calling entry
        (``{"type": "function", "function": {...}}``) or a bare function schema
        with a ``name`` key (as produced by ToolMaker), which is wrapped into
        the former. Every accepted schema is appended to ``self.tools`` and
        stored in ``self.tool_map`` under its name, so that ``execute_tool``
        can pass it to the toolactor.

        Entries whose name matches an executable (non-dict) tool already in
        ``tool_map`` are skipped. Storing a schema dict over the instance would
        break the built-in, because ``execute_tool`` calls ``.execute`` on it.

        Args:
            tools: Tool definitions to add; ``None`` or empty is a no-op.
        """
        logger.info(f"🪛 Set tools: {tools}")
        if not tools:
            return

        # Both fields are Optional, so a caller may have set them to None.
        if self.tools is None:
            self.tools = []
        if self.tool_map is None:
            self.tool_map = {}

        converted_tools = []
        for tool in tools:
            if "function" in tool and "type" in tool:
                # Already in standard OpenAI function calling format
                tool_name = tool["function"]["name"]
                converted_tool = tool
            elif "name" in tool:
                # Direct format from ToolMaker - convert to standard format
                tool_name = tool["name"]
                converted_tool = {
                    "type": "function",
                    "function": tool
                }
            else:
                logger.warning(f"Unknown tool format: {tool}")
                continue

            existing = self.tool_map.get(tool_name)
            if existing is not None and not isinstance(existing, dict):
                logger.warning(
                    f"Tool '{tool_name}' conflicts with a built-in executable tool; skipping"
                )
                continue

            converted_tools.append(converted_tool)
            self.tool_map[tool_name] = converted_tool

        # Advertise the accepted tools to the LLM alongside the existing ones.
        self.tools.extend(converted_tools)

    def set_toolactor(self, toolactor: Optional[BaseAgent]):
        """Set the agent used to simulate tools that have no executable instance."""
        logger.info(f"🤖 Set toolactor: {toolactor}")
        self.toolactor = toolactor

    async def step(
        self,
        current_step=1,
        **kwargs
    ):
        """Execute a single step: think, then act if tool calls were chosen.

        Args:
            current_step: The current step number in the execution sequence;
                stored on ``self.current_step``.
            **kwargs: Accepted for call-site compatibility; currently ignored.

        Returns:
            str: The output of ``act()``, or a fixed message when ``think()``
            decided that no action is needed.
        """
        # Save current step number for use elsewhere in the agent.
        self.current_step = current_step

        should_act = await self.think()
        if not should_act:
            return "Thinking complete - no action needed"
        return await self.act()

    def _build_system_prompt(self) -> str:
        """Return ``system_prompt``, extended with task-classification guidance when set."""
        system_prompt = self.system_prompt
        if self.task_classification:
            classification_info = f"""

## Task Classification Guidance
Task Type: {self.task_classification.get('primary_type', 'unknown')}
Complexity: {self.task_classification.get('complexity_level', 'medium')}
Domain: {self.task_classification.get('domain', 'general')}
Strategy: {self.task_classification.get('autoagent_strategy', 'Use appropriate tools to complete the task')}
Estimated Steps: {self.task_classification.get('estimated_steps', 'multiple')}

Based on this classification, adapt your approach accordingly."""
            system_prompt = system_prompt + classification_info
        return system_prompt

    async def think(
        self,
    ):
        """Ask the LLM for the next action and record its reply in memory.

        Returns:
            bool: ``True`` when the reply contains tool calls for ``act()`` to
            run. ``False`` when it contains none, when a definitive answer was
            detected (``state`` is then set to FINISHED), or when the token
            limit was hit.

        Raises:
            Exception: Any error from building the request or from
                ``llm.ask_tool`` is re-raised, unless its ``__cause__`` is a
                ``TokenLimitExceeded``.
        """
        try:
            messages = Memory.filter_empty_content(self.messages)
            system_prompt = self._build_system_prompt()

            response = await self.llm.ask_tool(
                messages,
                system_msgs=(
                    [Message.system_message(system_prompt, cache=True)]
                    if system_prompt
                    else None
                ),
                tools=self.tools,
            )
        except Exception as e:
            # The retry wrapper surfaces TokenLimitExceeded as the __cause__ of
            # the error it raises (e.g. a RetryError); stop gracefully then.
            if isinstance(e.__cause__, TokenLimitExceeded):
                token_limit_error = e.__cause__
                logger.error(
                    f"🚨 Token limit error (from RetryError): {token_limit_error}"
                )
                self.memory.add_message(
                    Message.assistant_message(
                        f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
                    )
                )
                return False
            raise

        if not response.tool_calls:
            logger.warning(f"response.tool_calls is None, input messages: {messages}")
            self.tool_calls = []

        logger.info(f"✨ {self.name}'s thoughts: {response.content}")
        # Kept so simulated tools can see the reasoning behind their call.
        self.tool_reasoning = response.content

        if response.tool_calls:
            # Skip None entries; everything below works on the filtered list.
            self.tool_calls = [
                ToolCall(
                    id=tc.id,
                    function=Function(
                        name=tc.function.name,
                        arguments=tc.function.arguments
                    )
                )
                for tc in response.tool_calls
                if tc is not None
            ]

            logger.info(
                f"🧰 Tools being prepared: {[call.function.name for call in self.tool_calls]}"
            )
            logger.info(
                f"🔧 Tool arguments: {[call.function.arguments for call in self.tool_calls]}"
            )

            self.memory.add_message(
                Message.from_tool_calls(
                    content=response.content,
                    tool_calls=self.tool_calls
                )
            )
            return True

        self.memory.add_message(Message.assistant_message(content=response.content))

        # Check if this response contains a definitive answer
        if self._has_definitive_answer(response.content):
            logger.info("✅ Detected definitive answer, finishing task")
            self.state = AgentState.FINISHED
            return False

        if not self.tool_reasoning:
            logger.warning("tool_calls and response are both None")
            # NOTE(review): this nudge is recorded with the assistant role, so the
            # model sees it as its own words; consider a user-role message instead.
            self.memory.add_message(
                Message.assistant_message(
                    "Please use the `answer_summarizer` tool to summarize your findings and provide a final answer."
                )
            )

        return False

    def _has_definitive_answer(self, content: str) -> bool:
        """Heuristically decide whether ``content`` already states a final answer.

        Returns ``True`` in any of these cases:

        - the text mentions ``answer_summarizer``;
        - it contains an ``<answer>`` tag (case-insensitive);
        - it contains a ``\\boxed{`` (case-sensitive);
        - it contains one of the answer phrases below (case-insensitive).

        Empty or ``None`` content yields ``False``.
        """
        if not content:
            return False

        content_lower = content.lower()

        # Check if answer_summarizer tool was mentioned (most reliable indicator)
        if "answer_summarizer" in content_lower:
            return True

        # Check for answer tags or boxed format
        if "<answer>" in content_lower or "\\boxed{" in content:
            return True

        # Substring match, so e.g. "based on my research, i can conclude that:"
        # is already covered by "i can conclude that:".
        answer_indicators = (
            "the answer is:",
            "the answer is ",
            "i can conclude that:",
            "i can conclude that ",
            "the conclusion is:",
            "the conclusion is ",
        )
        return any(indicator in content_lower for indicator in answer_indicators)

    async def act(self) -> str:
        """Execute every pending tool call and record each result in memory.

        Each tool call is followed by exactly one tool message, holding either
        its result or an error string.

        Returns:
            str: The per-tool results, formatted as ``TOOL <n> RESULT:\\n...``
            and joined by blank lines, or a fixed message when there are no
            tool calls.
        """
        if not self.tool_calls:
            return "No tool calls to execute"

        results = []
        for idx, tool_call in enumerate(self.tool_calls, start=1):
            tool_name = tool_call.function.name

            try:
                result = await self.execute_tool(tool_call)
            except Exception as e:
                error_message = f"Error executing tool {tool_name}: {str(e)}"
                logger.error(error_message)
                result = error_message

            # Ensure that each tool call is followed by a corresponding tool result message
            self.memory.add_message(
                Message.tool_message(
                    content=result,
                    name=tool_name,
                    tool_call_id=tool_call.id
                )
            )
            result_msg = f"TOOL {idx} RESULT:\n{result}"
            logger.info(result_msg)
            results.append(result_msg)

        return "\n\n".join(results)

    @staticmethod
    def _parse_tool_arguments(tool_name: str, raw_arguments: Optional[str]):
        """Decode a tool call's JSON argument payload.

        An empty or ``None`` payload is treated as ``{}``, because some models
        send nothing for parameterless tools.

        Returns:
            tuple: ``(arguments_dict, None)`` on success, or
            ``(None, error_message)`` when the payload is not a JSON object.
        """
        error = f"Error: Invalid JSON arguments for tool {tool_name}"
        try:
            arguments = json.loads(raw_arguments or "{}")
        except (json.JSONDecodeError, TypeError):
            return None, error
        if not isinstance(arguments, dict):
            # Must be a mapping to be expanded as keyword arguments.
            return None, error
        return arguments, None

    async def _execute_answer_summarizer(self, tool_call: ToolCall) -> str:
        """Run answer_summarizer and mark the agent FINISHED on success."""
        tool_name = tool_call.function.name
        try:
            logger.info("🔥 Executing answer_summarizer tool")

            arguments, error = self._parse_tool_arguments(tool_name, tool_call.function.arguments)
            if error:
                return error

            # If we have gold_answer, replace the final_answer but keep everything else
            if self.gold_answer:
                arguments["final_answer"] = self.gold_answer

            tool_instance = self.tool_map.get(tool_name)
            if not tool_instance:
                return f"Error: Tool '{tool_name}' not found in tool map"

            result = await tool_instance.execute(**arguments)
            self.state = AgentState.FINISHED  # Mark task as finished after answer summarization
            return result
        except Exception as e:
            error_msg = f"⚠️ Tool '{tool_name}' encountered a problem: {str(e)}"
            logger.error(error_msg)
            return f"Error: {error_msg}"

    async def _execute_real_tool(self, tool_call: ToolCall) -> str:
        """Run a built-in executable tool and format its output as an observation."""
        tool_name = tool_call.function.name
        try:
            logger.info(f"🔥 Executing real tool: {tool_name}")

            arguments, error = self._parse_tool_arguments(tool_name, tool_call.function.arguments)
            if error:
                return error

            # A missing entry raises KeyError, which is reported below.
            tool_instance = self.tool_map[tool_name]
            result = await tool_instance.execute(**arguments)

            if isinstance(result, dict):
                formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
            else:
                formatted_result = str(result)

            return f"Observed output of cmd `{tool_name}` executed:\n{formatted_result}"
        except Exception as e:
            error_msg = f"⚠️ Real tool '{tool_name}' encountered a problem: {str(e)}"
            logger.error(error_msg)
            return f"Error: {error_msg}"

    async def _simulate_tool(self, tool_call: ToolCall) -> str:
        """Have the toolactor produce the observation for a non-executable tool.

        Raises:
            NotImplementedError: If no toolactor has been set.
        """
        tool_name = tool_call.function.name
        if not self.toolactor:
            raise NotImplementedError("ToolActorAgent is not set")

        try:
            tool_function = self.tool_map.get(tool_name)
            tool_observation = await self.toolactor.run(
                reasoning=self.tool_reasoning,
                tool_call=tool_call,
                tool_function=tool_function,
                gold_answer=self.gold_answer
            )

            if tool_observation:
                return f"Observed output of cmd `{tool_name}` executed:\n{str(tool_observation)}"
            return f"Cmd `{tool_name}` completed with no output"
        except Exception as e:
            traceback.print_exc()
            error_msg = f"⚠️ Tool '{tool_name}' encountered a problem: {str(e)}"
            logger.error(error_msg)
            return f"Error: {error_msg}"

    async def execute_tool(self, tool_call: ToolCall) -> str:
        """Execute a single tool call and return its observation text.

        Dispatch depends on the tool name:

        - ``answer_summarizer`` runs directly and finishes the task.
        - ``bash``, ``python_execute`` and ``web_search`` run their real
          implementation.
        - Any other name is simulated by ``toolactor``.

        Tool failures are returned as ``"Error: ..."`` strings rather than
        raised.

        Raises:
            NotImplementedError: For a simulated tool when no toolactor is set.
        """
        tool_name = tool_call.function.name

        if tool_name == "answer_summarizer":
            return await self._execute_answer_summarizer(tool_call)

        if tool_name in ("bash", "python_execute", "web_search"):
            return await self._execute_real_tool(tool_call)

        return await self._simulate_tool(tool_call)


