"""
Tool use functionality for Bedrock Claude client.

This module provides classes and utilities for implementing tool use (function calling)
with Claude models, supporting serial, thread-pool parallel, and asyncio-based execution
strategies.
"""

import asyncio
import functools
import inspect
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional, Callable, Union, Awaitable


@dataclass
class Tool:
    """A tool definition that can be advertised to Claude and executed locally.

    ``function`` is always required; ``async_function`` is an optional
    awaitable variant used by async execution paths.
    """

    name: str
    description: str
    input_schema: Dict[str, Any]
    function: Callable
    async_function: Optional[Callable[..., Awaitable]] = None

    def __post_init__(self):
        """Reject definitions with empty metadata or a non-callable ``function``.

        Raises:
            ValueError: on the first failing check, tested in the order
                name, description, input_schema, function.
        """
        must_be_truthy = (
            (self.name, "Tool name cannot be empty"),
            (self.description, "Tool description cannot be empty"),
            (self.input_schema, "Tool input_schema cannot be empty"),
        )
        for value, message in must_be_truthy:
            if not value:
                raise ValueError(message)
        if not callable(self.function):
            raise ValueError("Tool function must be callable")

    def to_claude_format(self) -> Dict[str, Any]:
        """Return the ``{name, description, input_schema}`` dict sent to the API."""
        return dict(
            name=self.name,
            description=self.description,
            input_schema=self.input_schema,
        )

    def is_async(self) -> bool:
        """Return True when the tool can be awaited.

        That is the case if an explicit ``async_function`` is set, or if
        ``function`` itself is a coroutine function.
        """
        if self.async_function is not None:
            return True
        return asyncio.iscoroutinefunction(self.function)


@dataclass
class ToolCall:
    """A single ``tool_use`` request taken from a Claude response.

    ``id``, ``name`` and ``input`` are copied unchanged from the response
    block; ``input`` is the argument mapping for the tool.
    """

    id: str
    name: str
    input: Dict[str, Any]

    @classmethod
    def from_claude_response(cls, tool_use_block: Dict[str, Any]) -> "ToolCall":
        """Build a ``ToolCall`` from a ``tool_use`` content block.

        Extra keys in the block (such as ``type``) are ignored.

        Raises:
            KeyError: if the block lacks ``id``, ``name`` or ``input``
                (checked in that order).
        """
        call_id, call_name, call_input = (
            tool_use_block[key] for key in ("id", "name", "input")
        )
        return cls(id=call_id, name=call_name, input=call_input)


@dataclass
class ToolResult:
    """Outcome of executing one tool call, convertible to a ``tool_result`` block."""

    tool_use_id: str  # id of the ToolCall this result answers
    content: Any  # raw value returned by the tool (None for failures)
    is_error: bool = False
    execution_time: Optional[float] = None  # seconds, when measured
    error_message: Optional[str] = None

    def to_claude_format(self) -> Dict[str, Any]:
        """Convert the result to a Claude ``tool_result`` content block.

        The Messages API accepts ``tool_result`` content only as a string or
        as a list of content blocks. Strings and non-empty lists of
        ``{"type": ...}`` dicts are therefore passed through unchanged.
        Every other value (numbers, bools, dicts, plain lists, None, custom
        objects) is encoded as a JSON string, falling back to ``str()`` when
        JSON cannot encode it (including values nested inside containers).

        Returns:
            A dict with ``tool_use_id``, ``type`` and ``content``, plus
            ``is_error: True`` for failed executions.
        """
        if self.is_error:
            return {
                "tool_use_id": self.tool_use_id,
                "type": "tool_result",
                "content": f"Error: {self.error_message or str(self.content)}",
                "is_error": True
            }

        return {
            "tool_use_id": self.tool_use_id,
            "type": "tool_result",
            "content": self._serialize_content(self.content),
        }

    @staticmethod
    def _serialize_content(content: Any) -> Any:
        """Coerce ``content`` into a string or a list of content blocks."""
        if isinstance(content, str):
            return content
        if (
            isinstance(content, list)
            and content
            and all(isinstance(item, dict) and "type" in item for item in content)
        ):
            # Already shaped as API content blocks (e.g. text/image blocks).
            return content
        try:
            # Encoding the whole value is a deep check: nested values that
            # JSON cannot encode raise here rather than later at request time.
            return json.dumps(content, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(content)


class ToolRegistry:
    """Name-keyed store of ``Tool`` definitions.

    Registering a name that is already present replaces the earlier tool
    and logs a warning.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._tools: Dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        """Store ``tool`` under ``tool.name``, overwriting any existing entry."""
        key = tool.name
        if key in self._tools:
            self.logger.warning(f"Overwriting existing tool: {key}")
        self._tools[key] = tool
        self.logger.info(f"Registered tool: {key}")

    def unregister(self, tool_name: str) -> None:
        """Remove ``tool_name`` if registered; unknown names are ignored."""
        if tool_name not in self._tools:
            return
        del self._tools[tool_name]
        self.logger.info(f"Unregistered tool: {tool_name}")

    def get(self, tool_name: str) -> Optional[Tool]:
        """Return the tool registered as ``tool_name``, or None."""
        try:
            return self._tools[tool_name]
        except KeyError:
            return None

    def list_tools(self) -> List[str]:
        """Return registered tool names in dict insertion order."""
        return [*self._tools]

    def get_tools_for_claude(self, tool_names: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Return Claude-format definitions for all tools, or only ``tool_names``.

        When ``tool_names`` is given, output follows its order and names that
        are not registered are skipped without error.
        """
        if tool_names is not None:
            selected = (self._tools[n] for n in tool_names if n in self._tools)
        else:
            selected = iter(self._tools.values())
        return [entry.to_claude_format() for entry in selected]

    def clear(self) -> None:
        """Remove every registered tool."""
        self._tools.clear()


class ToolExecutor:
    """Runs tool calls serially, on a thread pool, or concurrently via asyncio.

    Every strategy turns failures (unknown tool name, input that fails
    validation, an exception raised by the tool) into an error
    ``ToolResult`` instead of raising. It returns one result per call, in
    the same order as ``tool_calls``.
    """

    # JSON Schema "type" names understood by _check_type, mapped to the
    # Python types accepted for each.
    _JSON_SCHEMA_TYPES = {
        "string": str,
        "number": (int, float),
        "integer": int,
        "boolean": bool,
        "array": list,
        "object": dict,
        "null": type(None),
    }

    def __init__(self, max_workers: int = 4):
        # Upper bound on worker threads used by execute_parallel.
        self.max_workers = max_workers
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _not_found_result(tool_call: ToolCall) -> ToolResult:
        """Build the error result for a call that names an unregistered tool."""
        return ToolResult(
            tool_use_id=tool_call.id,
            content=None,
            is_error=True,
            error_message=f"Tool '{tool_call.name}' not found",
        )

    def _failure_result(self, tool_call: ToolCall, error: BaseException, start_time: float) -> ToolResult:
        """Log ``error`` and wrap it in an error result timed from ``start_time``."""
        execution_time = time.perf_counter() - start_time
        self.logger.error("Tool %s failed: %s", tool_call.name, error)
        return ToolResult(
            tool_use_id=tool_call.id,
            content=None,
            is_error=True,
            error_message=str(error),
            execution_time=execution_time,
        )

    def execute_single_tool(self, tool: Tool, tool_call: ToolCall) -> ToolResult:
        """Run one tool call synchronously in the calling thread.

        When ``tool.function`` is a coroutine function, it is driven to
        completion on a fresh event loop via ``asyncio.run``. Calling this
        from a thread that already has a running loop therefore produces an
        error result for such tools.

        Returns:
            A ``ToolResult`` with the tool's return value, or an error
            ``ToolResult`` if validation or the tool itself raised.
        """
        start_time = time.perf_counter()
        try:
            self.logger.info("Executing tool: %s with input: %s", tool_call.name, tool_call.input)
            self._validate_input(tool_call.input, tool.input_schema)

            if asyncio.iscoroutinefunction(tool.function):
                # asyncio.run creates the loop, then closes it and detaches it
                # from this thread, so no closed loop is left installed.
                result = asyncio.run(tool.function(**tool_call.input))
            else:
                result = tool.function(**tool_call.input)
        except Exception as e:
            return self._failure_result(tool_call, e, start_time)

        execution_time = time.perf_counter() - start_time
        self.logger.info("Tool %s completed in %.2fs", tool_call.name, execution_time)
        return ToolResult(
            tool_use_id=tool_call.id,
            content=result,
            execution_time=execution_time,
        )

    def execute_serial(self, tools: Dict[str, Tool], tool_calls: List[ToolCall]) -> List[ToolResult]:
        """Run tool calls one after another, in order."""
        results: List[ToolResult] = []
        total = len(tool_calls)
        for tool_call in tool_calls:
            tool = tools.get(tool_call.name)
            if tool is None:
                results.append(self._not_found_result(tool_call))
            else:
                results.append(self.execute_single_tool(tool, tool_call))
            self.logger.info("Serial execution: %d/%d tools completed", len(results), total)
        return results

    def execute_parallel(self, tools: Dict[str, Tool], tool_calls: List[ToolCall]) -> List[ToolResult]:
        """Run tool calls concurrently on a thread pool of at most ``max_workers`` threads."""
        if not tool_calls:
            return []

        total = len(tool_calls)
        # Results are stored by input position, so the output order matches
        # tool_calls (even with repeated ids) without a post-hoc sort.
        results: List[Optional[ToolResult]] = [None] * total
        completed = 0

        with ThreadPoolExecutor(max_workers=min(self.max_workers, total)) as executor:
            future_to_index = {}
            for index, tool_call in enumerate(tool_calls):
                tool = tools.get(tool_call.name)
                if tool is None:
                    results[index] = self._not_found_result(tool_call)
                    completed += 1
                else:
                    future = executor.submit(self.execute_single_tool, tool, tool_call)
                    future_to_index[future] = index

            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as e:
                    # execute_single_tool already catches tool errors; this
                    # only guards against failures in the result plumbing.
                    results[index] = ToolResult(
                        tool_use_id=tool_calls[index].id,
                        content=None,
                        is_error=True,
                        error_message=f"Parallel execution failed: {e}",
                    )
                completed += 1
                self.logger.info("Parallel execution: %d/%d tools completed", completed, total)

        return results

    async def execute_async(self, tools: Dict[str, Tool], tool_calls: List[ToolCall]) -> List[ToolResult]:
        """Run all tool calls concurrently on the running event loop.

        ``async_function`` (or a coroutine ``function``) is awaited directly.
        Plain sync functions run in the loop's default executor so they do
        not block the loop.
        """
        async def run_one(tool_call: ToolCall) -> ToolResult:
            tool = tools.get(tool_call.name)
            if tool is None:
                return self._not_found_result(tool_call)

            start_time = time.perf_counter()
            try:
                self.logger.info("Executing async tool: %s", tool_call.name)
                self._validate_input(tool_call.input, tool.input_schema)

                if tool.async_function is not None:
                    result = await tool.async_function(**tool_call.input)
                elif asyncio.iscoroutinefunction(tool.function):
                    result = await tool.function(**tool_call.input)
                else:
                    # run_in_executor forwards positional arguments only, so
                    # the keyword arguments are bound with functools.partial.
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(
                        None, functools.partial(tool.function, **tool_call.input)
                    )
            except Exception as e:
                return self._failure_result(tool_call, e, start_time)

            return ToolResult(
                tool_use_id=tool_call.id,
                content=result,
                execution_time=time.perf_counter() - start_time,
            )

        results = await asyncio.gather(
            *(run_one(tool_call) for tool_call in tool_calls),
            return_exceptions=True,
        )

        final_results: List[ToolResult] = []
        for tool_call, result in zip(tool_calls, results):
            # BaseException rather than Exception: a cancelled call comes back
            # from gather as CancelledError, which is not an Exception subclass.
            if isinstance(result, BaseException):
                final_results.append(ToolResult(
                    tool_use_id=tool_call.id,
                    content=None,
                    is_error=True,
                    error_message=f"Async execution failed: {str(result)}",
                ))
            else:
                final_results.append(result)
        return final_results

    def _validate_input(self, input_data: Dict[str, Any], schema: Dict[str, Any]) -> None:
        """Check ``input_data`` against the top level of ``schema``.

        Only ``required`` and each property's ``type`` are checked. Nested
        schemas, enums and formats are not; a full validator (for example the
        ``jsonschema`` package) would be needed for those.

        Raises:
            ValueError: if the input is not a dict, a required field is
                missing, or a field's value has the wrong type.
        """
        if not isinstance(input_data, dict):
            raise ValueError(f"Tool input must be an object, got {type(input_data).__name__}")

        for name in schema.get("required", []):
            if name not in input_data:
                raise ValueError(f"Required field '{name}' missing from input")

        properties = schema.get("properties", {})
        for name, value in input_data.items():
            spec = properties.get(name)
            # JSON Schema also allows boolean sub-schemas; only dicts carry "type".
            if not isinstance(spec, dict):
                continue
            expected_type = spec.get("type")
            if expected_type and not self._check_type(value, expected_type):
                raise ValueError(f"Field '{name}' has incorrect type. Expected: {expected_type}")

    def _check_type(self, value: Any, expected_type: Union[str, List[str]]) -> bool:
        """Return True if ``value`` satisfies the JSON Schema ``type`` keyword.

        ``expected_type`` may be one type name or a list of names, in which
        case any match suffices. ``bool`` subclasses ``int`` in Python, so
        booleans are rejected explicitly for ``integer``/``number``.
        Unknown type names are allowed.
        """
        if isinstance(expected_type, (list, tuple)):
            return any(self._check_type(value, t) for t in expected_type)

        if not isinstance(expected_type, str):
            return True
        python_types = self._JSON_SCHEMA_TYPES.get(expected_type)
        if python_types is None:
            return True  # Unknown type, allow it
        if isinstance(value, bool) and expected_type in ("integer", "number"):
            return False
        return isinstance(value, python_types)


# Utility functions for creating common tools
# Python annotation -> JSON Schema type used when deriving a tool schema.
_PY_TO_JSON_SCHEMA_TYPE = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
}


def _json_schema_type(annotation: Any) -> str:
    """Map a parameter annotation to a JSON Schema type name ("string" by default)."""
    if isinstance(annotation, str):
        # Postponed annotations (``from __future__ import annotations``)
        # arrive as plain strings such as "int".
        for py_type, json_type in _PY_TO_JSON_SCHEMA_TYPE.items():
            if py_type.__name__ == annotation:
                return json_type
        return "string"
    return _PY_TO_JSON_SCHEMA_TYPE.get(annotation, "string")


def create_simple_tool(
    name: str,
    description: str,
    function: Callable,
    parameters: Optional[Dict[str, Any]] = None
) -> Tool:
    """Create a ``Tool`` whose input schema is a JSON object schema.

    If ``parameters`` has no ``"properties"`` key, properties are derived
    from ``function``'s signature:

    - Each named parameter becomes a property typed from its annotation
      (``str``/``int``/``float``/``bool``/``list``/``dict``, also when given
      as a string). Any other or missing annotation maps to ``"string"``.
    - Parameters without defaults are listed under ``"required"``.
    - ``*args``/``**kwargs`` are skipped, because tools are invoked as
      ``function(**input)``.

    Args:
        name: Tool name advertised to Claude.
        description: Tool description advertised to Claude.
        function: Callable executed when the tool is used.
        parameters: JSON Schema keys merged into the top-level schema.
            ``None`` is treated as ``{}``. The caller's dict is not modified.

    Returns:
        The constructed ``Tool``.
    """
    # Work on a copy so the caller's dict is never mutated.
    parameters = dict(parameters) if parameters else {}

    if "properties" not in parameters:
        properties: Dict[str, Any] = {}
        required: List[str] = []

        for param_name, param in inspect.signature(function).parameters.items():
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue
            properties[param_name] = {"type": _json_schema_type(param.annotation)}
            if param.default is param.empty:
                required.append(param_name)

        parameters["properties"] = properties
        if required:
            parameters["required"] = required

    schema = {
        "type": "object",
        **parameters
    }

    return Tool(
        name=name,
        description=description,
        input_schema=schema,
        function=function
    )
