
import json
from typing import Any, Dict, List, Optional

from pydantic import (
    Field,
    model_validator
)

from src.utils.logsetup import logger
from src.agent import BaseAgent
from src.llm import LLM
from src.schema.message import Message
from src.prompt.toolmaker import TOOLMAKERPROMPT, TOOLMAKERPROMPT_GOLDANSWER


class ToolMakerAgent(BaseAgent):
    """ToolMaker agent to make tools based on the task query.

    Asks an LLM (``toolmaker_llm``) to produce tool definitions for a task and
    normalizes the returned JSON into a list of OpenAI-style tool dicts.
    """
    name: str = "ToolMakerAgent"
    description: Optional[str] = "ToolMaker agent to make tools based on the task query."

    # LLM used to generate tools. If the caller does not supply one,
    # `initialize_agent` tries the "toolmaker" config and falls back to `self.llm`.
    toolmaker_llm: Optional[LLM] = Field(
        default=None,
        description="LLM for toolmaker"
    )

    @model_validator(mode="after")
    def initialize_agent(self) -> "ToolMakerAgent":
        """Resolve ``toolmaker_llm`` once model validation has finished.

        A caller-supplied ``toolmaker_llm`` is kept unchanged. Otherwise an
        ``LLM`` is built from the ``"toolmaker"`` config. If that (or reading
        its ``model`` attribute) raises, ``self.llm`` is used instead; that
        attribute is presumably defined by ``BaseAgent``.
        """
        if self.toolmaker_llm is not None:
            # Respect an explicitly injected LLM instead of overwriting it.
            return self

        try:
            self.toolmaker_llm = LLM(config_name="toolmaker")
            logger.info(f"Successfully initialized ToolMakerAgent, model: {self.toolmaker_llm.model}")
        except Exception as e:
            logger.warning(f"Cannot initialize ToolMakerAgent, error: {e}, using default LLM")
            self.toolmaker_llm = self.llm

        return self

    @staticmethod
    def _list_parameters_to_schema(parameters: List[Any]) -> Dict[str, Any]:
        """Convert a list of parameter descriptors into an OpenAI JSON-schema object.

        Only dict entries that have a ``name`` key are used; anything else is
        ignored. ``type`` defaults to ``"string"`` and ``description`` to ``""``.
        ``enum`` and ``items`` are copied over when present, because OpenAI
        rejects ``array`` parameters that have no ``items``.

        A parameter is required when its ``required`` key is the boolean
        ``True``. When no boolean ``required`` key is present, a parameter is
        required if it has no ``default`` key. If a name repeats, the last
        descriptor wins, so ``required`` never contains duplicates.

        Returns:
            ``{"type": "object", "properties": {...}, "required": [...]}``.
        """
        properties: Dict[str, Any] = {}
        required_flags: Dict[str, bool] = {}

        for param in parameters:
            if not isinstance(param, dict) or "name" not in param:
                continue

            param_name = param["name"]
            schema: Dict[str, Any] = {
                "type": param.get("type", "string"),
                "description": param.get("description", ""),
            }
            for key in ("enum", "items"):
                if key in param:
                    schema[key] = param[key]
            properties[param_name] = schema

            explicit_required = param.get("required")
            if isinstance(explicit_required, bool):
                required_flags[param_name] = explicit_required
            else:
                # Heuristic: a parameter without a default must be supplied.
                required_flags[param_name] = "default" not in param

        return {
            "type": "object",
            "properties": properties,
            "required": [name for name, is_required in required_flags.items() if is_required],
        }

    def _fix_tool_parameters_format(self, tools: List[Dict]) -> List[Dict]:
        """
        Fix tool parameter format to ensure OpenAI API compliance.

        Tools whose ``function.parameters`` is a list are converted to the
        JSON-schema object form via ``_list_parameters_to_schema``. Tools whose
        parameters are already in another form are passed through unchanged.

        The following entries are skipped with a warning:
        - entries that are not dicts;
        - entries that lack ``type`` or ``function``;
        - entries whose ``function`` is not a dict;
        - entries whose ``function`` has no ``parameters``.

        An entry that raises during processing is logged and skipped.

        The input dicts are never mutated; converted tools are returned as
        shallow copies.

        Args:
            tools: Raw tool definitions produced by the LLM.

        Returns:
            The tools that passed validation, in their original order.
        """
        fixed_tools: List[Dict] = []

        for tool in tools:
            try:
                if not isinstance(tool, dict):
                    logger.warning(f"Skipping non-dict tool: {tool}")
                    continue

                # Check basic tool structure
                if "type" not in tool or "function" not in tool:
                    logger.warning(f"Tool missing basic structure: {tool}")
                    continue

                function_def = tool["function"]
                # Require a dict: on a string, `in` would do a substring match.
                if not isinstance(function_def, dict) or "parameters" not in function_def:
                    logger.warning(f"Tool missing parameters: {tool}")
                    continue

                parameters = function_def["parameters"]
                if isinstance(parameters, list):
                    logger.info(
                        "Converting list-format parameters to dict format for tool: "
                        f"{function_def.get('name', 'unknown')}"
                    )
                    # Copy instead of mutating the caller's dicts.
                    tool = {
                        **tool,
                        "function": {
                            **function_def,
                            "parameters": self._list_parameters_to_schema(parameters),
                        },
                    }

                fixed_tools.append(tool)

            except Exception as e:
                # Best effort: one malformed tool must not drop the rest.
                logger.error(f"Error fixing tool parameters: {e}, tool: {tool}")

        logger.info(f"Fixed {len(fixed_tools)} tools out of {len(tools)} total tools")
        return fixed_tools

    def _build_prompt(
        self,
        task_query: str,
        gold_answer: Optional[str],
        task_classification: Optional[Dict],
    ) -> str:
        """Render the ToolMaker prompt for ``task_query``.

        A truthy ``gold_answer`` selects ``TOOLMAKERPROMPT_GOLDANSWER``.
        Otherwise ``TOOLMAKERPROMPT`` is filled from ``task_classification``,
        and a missing or empty classification, or missing keys, fall back to
        generic defaults.
        """
        if gold_answer:
            return TOOLMAKERPROMPT_GOLDANSWER.format(task_query, gold_answer)

        classification = task_classification or {}
        return TOOLMAKERPROMPT.format(
            task_query=task_query,
            task_type=classification.get("primary_type", "unknown"),
            complexity=classification.get("complexity_level", "medium"),
            domain=classification.get("domain", "general"),
            toolmaker_guidance=classification.get(
                "toolmaker_guidance", "Generate appropriate tools for this task"
            ),
        )

    def _extract_tool_list(self, payload: Any) -> Optional[List[Any]]:
        """Pull the raw list of tools out of an LLM JSON response.

        Accepted shapes:
        - a list, returned as-is;
        - a dict with a list under ``"tools"``, in which case that list is returned;
        - any other dict, which is wrapped in a one-element list;
        - a JSON string encoding any of the shapes above.

        Returns:
            The tool list, or ``None`` for any other shape, including
            unparseable strings.
        """
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON string: {e}")
                return None
            logger.info(f"Successfully parsed JSON string, type: {type(payload)}")

        if isinstance(payload, list):
            return payload

        if isinstance(payload, dict):
            nested = payload.get("tools")
            if isinstance(nested, list):
                logger.info("ToolMaker returned dict with 'tools' key, extracting tools array")
                return nested
            logger.info("ToolMaker returned dict, wrapping in list")
            return [payload]

        return None

    async def run(
        self,
        task_query: str,
        gold_answer: Optional[str] = None,
        task_classification: Optional[Dict] = None
    ) -> List[Dict]:
        """Generate tool definitions for ``task_query``.

        Args:
            task_query: The task to build tools for.
            gold_answer: Optional reference answer. When truthy, the
                gold-answer prompt is used and ``task_classification`` is ignored.
            task_classification: Optional dict. Its ``primary_type``,
                ``complexity_level``, ``domain`` and ``toolmaker_guidance``
                values are injected into the prompt.

        Returns:
            The normalized tool definitions. Returns ``[]`` when the LLM
            output has an unsupported shape or when any exception occurs; such
            exceptions are logged and not raised.
        """
        try:
            messages = [
                Message.user_message(
                    self._build_prompt(task_query, gold_answer, task_classification)
                )
            ]

            tools = await self.toolmaker_llm.ask_json(messages)

            logger.debug(f"🔧 DEBUG: ToolMaker returned original data type: {type(tools)}")
            logger.debug(f"🔧 DEBUG: ToolMaker returned data content: {str(tools)[:500]}...")

            raw_tools = self._extract_tool_list(tools)
            if raw_tools is None:
                logger.warning(f"ToolMaker returned unsupported format, type: {type(tools)}, returning empty list")
                return []

            return self._fix_tool_parameters_format(raw_tools)

        except Exception as e:
            # Best effort for callers: log with traceback, return no tools.
            logger.exception(f"Error in ToolMaker: {e}")
            return []

    async def step(self, current_step: int = 1, **kwargs):
        """Delegate a single step to ``BaseAgent.step`` unchanged."""
        return await super().step(current_step, **kwargs)