"""
Enhanced monitoring and output formatting for LMTune agent tool calls.
"""

import json
import sys
from collections import defaultdict
from typing import Any, Dict, Optional

from langchain_core.callbacks.base import BaseCallbackHandler
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table


# Shared Rich console used by every printing helper in this module
# (log_system, ToolStats.display_stats and both callback handlers).
# The color system is pinned explicitly to "truecolor".
console = Console(color_system="truecolor")


def log_system(msg: str, title: Optional[str] = None) -> None:
    """Print a labelled system message and flush output immediately.

    Args:
        msg: Message text. It is printed with Rich markup enabled, so any
            ``[style]...[/style]`` tags it contains are interpreted.
        title: Label shown in bold blue before the message; ``"system"`` is
            used when it is ``None`` or empty.
    """
    label = title if title else "system"
    console.print(f"[bold blue]{label}[/bold blue]: {msg}")

    # Flush both stdout and the console's own stream so the message shows up
    # right away even when output is buffered.
    sys.stdout.flush()
    console_file = getattr(console, "file", None)
    if console_file is not None:
        console_file.flush()


def _content_to_text(content: Any) -> str:
    """Flatten a ``content`` payload into plain text.

    ``None`` becomes ``""``; strings pass through unchanged; lists/tuples are
    joined with newlines, using an item's ``"text"`` entry when the item is a
    dict that has one; anything else is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        lines = []
        for part in content:
            if isinstance(part, dict) and "text" in part:
                lines.append(_content_to_text(part["text"]))
            else:
                lines.append(_content_to_text(part))
        return "\n".join(lines)
    return str(content)


def format_tool_output(result: Any) -> str:
    """Format a tool result as readable text.

    * Objects with a ``content`` attribute (e.g. message objects): their
      content, flattened to text.
    * Dicts: the ``content`` entry flattened to text, prefixed with
      ``"ERROR: "`` when ``isError`` is ``True`` (``"Unknown error"`` is used
      when that content is missing or empty). Dicts without ``content`` are
      rendered with ``str``.
    * Anything else: ``str(result)`` with literal backslash-n sequences turned
      into real newlines.

    Returns:
        The text to display; always a ``str``.
    """
    # Message-like objects (anything exposing ``.content``).
    if hasattr(result, "content"):
        return _content_to_text(result.content)

    if isinstance(result, dict):
        if result.get("isError") is True:
            detail = _content_to_text(result.get("content")) or "Unknown error"
            return f"ERROR: {detail}"
        if "content" in result:
            return _content_to_text(result["content"])
        return str(result)

    # Strings (and the str() of anything else) may carry literal backslash-n
    # sequences; turn them into real newlines for display.
    text = result if isinstance(result, str) else str(result)
    return text.replace("\\n", "\n")


class ToolStats:
    """Process-wide tally of tool invocations.

    Counts calls per tool name, keeps a running total, and remembers the most
    recent non-empty output reported for each tool. Setting ``enabled`` to
    False turns recording (and the stats table) off. ``get_instance`` returns
    a lazily created shared instance.
    """
    _instance = None

    def __init__(self):
        self.enabled = True
        self.total_calls = 0
        self.tool_calls = defaultdict(int)  # tool name -> recorded call count
        self.tool_outputs = {}  # tool name -> last non-empty output seen

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use."""
        shared = cls._instance
        if shared is None:
            shared = cls._instance = cls()
        return shared

    def record_tool_call(self, tool_name: str, args: Dict[str, Any] = None, output: str = None):
        """Count one call to *tool_name* and remember *output* if it is non-empty.

        ``args`` is accepted for interface compatibility but is not stored.
        Nothing is recorded while ``enabled`` is False.
        """
        if not self.enabled:
            return
        self.total_calls += 1
        self.tool_calls[tool_name] += 1
        if output:
            self.tool_outputs[tool_name] = output

    def get_stats(self) -> Dict[str, int]:
        """Return a plain-dict snapshot of per-tool call counts."""
        return {name: count for name, count in self.tool_calls.items()}

    def display_stats(self):
        """Print a table of call counts (busiest tool first) plus a total row."""
        if not self.enabled or self.total_calls <= 0:
            return

        # Ascending on the negated count == descending on the count, and ties
        # keep their insertion order just like a stable reverse sort.
        ranked = sorted(self.tool_calls.items(), key=lambda pair: -pair[1])

        table = Table(title="Tool Usage Statistics")
        table.add_column("Tool Name", style="cyan")
        table.add_column("Calls", style="yellow")
        for name, count in ranked:
            table.add_row(name, str(count))
        table.add_row("TOTAL", str(self.total_calls), style="bold")

        console.print("\n")
        console.print(table)


class EnhancedToolCallbackHandler(BaseCallbackHandler):
    """Enhanced callback handler with detailed tool information and formatting.

    On tool start it prints the tool name and, optionally, its arguments. On
    tool end it prints a success/failure marker and, optionally, the formatted
    output. On tool error it prints the exception type and message in a panel.
    Each started call is counted once in the shared ``ToolStats`` instance,
    which also keeps the last non-empty output per tool.

    Tool-supplied text (names, arguments, outputs, error messages) is passed
    through ``rich.markup.escape`` before printing. Square-bracket sequences
    in it are therefore shown literally instead of being parsed as Rich
    markup.
    """

    # Case-insensitive substrings that mark a tool's textual output as failed.
    _ERROR_TERMS = (
        "error:",
        "failed",
        "execution failed",
        "traceback",
        "exception",
    )

    def __init__(self, show_args: bool = True, show_output: bool = True, truncate_output: int = 500):
        """
        Initialize the enhanced handler.

        Args:
            show_args: Whether to display tool arguments
            show_output: Whether to display tool output
            truncate_output: Maximum characters to show in successful output
                (0 = no limit); error output is shown in full
        """
        super().__init__()
        self._current_tool = None
        self._current_args = None
        self.show_args = show_args
        self.show_output = show_output
        self.truncate_output = truncate_output
        self.tool_stats = ToolStats.get_instance()

    @staticmethod
    def _parse_args(input_str: Any) -> Dict[str, Any]:
        """Parse a JSON-object tool input, else wrap it as ``{"input": input_str}``."""
        try:
            if input_str.strip().startswith("{"):
                return json.loads(input_str)
        except (json.JSONDecodeError, AttributeError):
            # Malformed JSON or a non-string input: fall back to the raw wrapper.
            pass
        return {"input": input_str}

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Display tool start with formatted arguments and count the call."""
        tool_name = serialized.get("name", "unknown_tool")
        self._current_tool = tool_name
        self._current_args = self._parse_args(input_str)

        console.print(f"[bold cyan]▶ {escape(str(tool_name))}[/bold cyan]", end="")

        if self.show_args and self._current_args:
            if len(str(self._current_args)) > 100:
                # Long argument sets get a pretty-printed JSON panel on their own line.
                console.print()
                console.print(Panel(
                    Syntax(json.dumps(self._current_args, indent=2), "json"),
                    title="Arguments",
                    expand=False,
                ))
            else:
                args_str = json.dumps(self._current_args, indent=None)
                console.print(f" [dim]{escape(args_str)}[/dim]")
        else:
            console.print()

        # The call is counted here only; on_tool_end just stores the output.
        self.tool_stats.record_tool_call(tool_name, self._current_args)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Display a success/failure marker and the formatted tool output."""
        if not self._current_tool:
            return
        tool_name = self._current_tool

        formatted_output = format_tool_output(output)
        # Guard so the string operations below never see None or a non-str.
        if formatted_output is None:
            formatted_output = ""
        elif not isinstance(formatted_output, str):
            formatted_output = str(formatted_output)

        # Heuristic failure detection on the formatted text, so dict results
        # (rendered as "ERROR: ...") and message objects are covered as well
        # as raw string outputs.
        lowered = formatted_output.lower()
        is_error = any(term in lowered for term in self._ERROR_TERMS)

        if is_error:
            console.print("[bold red] ✗ Failed[/bold red]")
            if self.show_output:
                console.print(Panel(
                    escape(formatted_output),
                    title=f"[red]Error from {escape(str(tool_name))}[/red]",
                    style="red",
                ))
        else:
            console.print("[bold green] ✓ Success[/bold green]")
            if self.show_output and formatted_output:
                display_output = formatted_output
                if 0 < self.truncate_output < len(display_output):
                    display_output = display_output[:self.truncate_output] + "... [truncated]"

                # Escaping (after truncating, so the length is measured on the
                # real text) also keeps the "[truncated]" marker from being
                # parsed as a markup tag.
                if len(display_output) > 200 or "\n" in display_output:
                    # Use a panel for long or multi-line output.
                    console.print(Panel(
                        escape(display_output),
                        title=f"[green]Output from {escape(str(tool_name))}[/green]",
                        style="green",
                    ))
                else:
                    # Inline for short output.
                    console.print(f"  [dim green]→ {escape(display_output)}[/dim green]")

        # Already counted in on_tool_start; only remember the output here,
        # applying the same enabled / non-empty checks as record_tool_call.
        if self.tool_stats.enabled and formatted_output:
            self.tool_stats.tool_outputs[tool_name] = formatted_output

        # Reset current tool
        self._current_tool = None
        self._current_args = None

    def on_tool_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Display the exception type and message in a red panel."""
        # Report the error even if the matching start event was not seen,
        # rather than dropping it silently.
        tool_name = self._current_tool or "unknown_tool"
        console.print("[bold red] ✗ Error[/bold red]")
        console.print(Panel(
            f"[red]{type(error).__name__}[/red]: {escape(str(error))}",
            title=f"[red]Tool Error: {escape(str(tool_name))}[/red]",
            style="red",
        ))

        # Reset current tool
        self._current_tool = None
        self._current_args = None

    # Disable all other callbacks for minimal output
    def on_llm_start(self, *args, **kwargs) -> None:
        pass

    def on_llm_end(self, *args, **kwargs) -> None:
        pass

    def on_llm_error(self, *args, **kwargs) -> None:
        pass

    def on_chain_start(self, *args, **kwargs) -> None:
        pass

    def on_chain_end(self, *args, **kwargs) -> None:
        pass

    def on_chain_error(self, *args, **kwargs) -> None:
        pass

    def on_text(self, *args, **kwargs) -> None:
        pass


class SimpleToolCallbackHandler(BaseCallbackHandler):
    """Minimal one-line-per-tool callback handler.

    Prints ``[Tool] ▶ <name>`` when a tool starts and appends ``✓`` or ``✗``
    on the same line when it ends. Failures are detected heuristically from
    the formatted output text; for code-execution failures the full output is
    printed too. Each started call is counted in the shared ``ToolStats``.
    """

    # Case-insensitive substrings that mark a tool's textual output as failed.
    _ERROR_TERMS = (
        "error:",
        "failed",
        "execution failed",
        "traceback",
        "exception",
    )

    def __init__(self):
        """Initialize the handler."""
        super().__init__()
        self._current_tool = None
        self.tool_stats = ToolStats.get_instance()

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Print the tool name (no newline) and count the call."""
        tool_name = serialized.get("name", "unknown_tool")
        self._current_tool = tool_name
        # markup=False: print the line verbatim, never as Rich markup.
        console.print(f"[Tool] ▶ {tool_name}", end="", markup=False)

        # Record tool call
        self.tool_stats.record_tool_call(tool_name)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Append ✓/✗ to the tool's line; print details of code-execution failures."""
        text = format_tool_output(output) if output else ""
        # Guard so the string operations below never see None or a non-str.
        if text is None:
            text = ""
        elif not isinstance(text, str):
            text = str(text)

        # Check the formatted text so dict and message-object results are
        # inspected too, not only raw string outputs.
        lowered = text.lower()
        if any(term in lowered for term in self._ERROR_TERMS):
            console.print(" ✗")
            if "Code execution failed with error:" in text:
                # Plain text: the output may contain square brackets that Rich
                # would otherwise try to interpret as markup.
                console.print(text, markup=False)
        else:
            console.print(" ✓")

    def on_tool_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Mark the tool's line as failed and print the exception type and message."""
        console.print(" ✗")
        console.print(
            f"Tool execution error: {type(error).__name__}: {error}",
            markup=False,
        )

    # Disable all other callbacks for minimal output
    def on_llm_start(self, *args, **kwargs) -> None:
        pass

    def on_llm_end(self, *args, **kwargs) -> None:
        pass

    def on_llm_error(self, *args, **kwargs) -> None:
        pass

    def on_chain_start(self, *args, **kwargs) -> None:
        pass

    def on_chain_end(self, *args, **kwargs) -> None:
        pass

    def on_chain_error(self, *args, **kwargs) -> None:
        pass

    def on_text(self, *args, **kwargs) -> None:
        pass