"""
Error Tracking Infrastructure for LLM Invocation Failures

This module provides infrastructure to track and manage LLM invocation errors
that occur after retry mechanisms have been exhausted. It captures full context
for debugging and provides statistics for experiment evaluation.
"""

import time
from collections import Counter
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass(init=True, repr=True, eq=True)
class LLMInvocationError:
    """
    Snapshot of a single LLM invocation that failed after retries were exhausted.

    Every field is described by the comment above it. Instances are ordinary
    dataclasses, so ``dataclasses.asdict`` can convert them to dictionaries.
    """

    # Unix timestamp of the failure.
    timestamp: float
    # Text of the exception (its ``str()``).
    error_message: str
    # Class name of the exception that was raised.
    error_type: str
    # Messages that were sent to the LLM on the failing call.
    input_messages: List[Dict[str, Any]]
    # Configuration of the model that failed.
    model_config: Dict[str, Any]
    # How many retry attempts were made before giving up.
    retry_attempts: int
    # Optional identifier of the task/item being processed; None if not given.
    task_id: Optional[str] = None


class ErrorTracker:
    """
    Tracks and manages LLM invocation errors during experiment execution.

    This class collects error information, provides statistics, and enables
    serialisation of error data for experiment output files.
    """

    def __init__(self):
        """Initialise an empty error tracker."""
        self.errors: List[LLMInvocationError] = []

    def record_error(
        self,
        error: Exception,
        input_messages: List[Dict[str, Any]],
        model_config: Dict[str, Any],
        retry_attempts: int,
        task_id: Optional[str] = None,
    ) -> None:
        """
        Record a new LLM invocation error.

        The current time (``time.time()``) is stored as the error timestamp.

        Args:
            error: The exception that occurred
            input_messages: The messages that were sent to the LLM
            model_config: Configuration of the model that failed
            retry_attempts: Number of retry attempts that were made
            task_id: Optional identifier for the specific task/item
        """
        llm_error = LLMInvocationError(
            timestamp=time.time(),
            error_message=str(error),
            error_type=type(error).__name__,
            input_messages=input_messages,
            model_config=model_config,
            retry_attempts=retry_attempts,
            task_id=task_id,
        )
        self.errors.append(llm_error)

    def get_error_count(self) -> int:
        """
        Get the total number of errors recorded.

        Returns:
            Total number of LLM invocation errors
        """
        return len(self.errors)

    def get_summary(self, max_samples: int = 3) -> Dict[str, Any]:
        """
        Get a summary of error statistics.

        The same set of keys is returned whether or not any errors were
        recorded, so consumers can rely on a stable schema.

        Args:
            max_samples: Maximum number of entries in ``sample_errors``
                (default 3). Negative values are treated as 0.

        Returns:
            Dictionary containing:
            - failed_llm_calls: Total number of failed calls
            - error_types: Count of errors per exception type name
            - average_retry_attempts: Mean ``retry_attempts`` across all
              recorded errors (0 when nothing was recorded)
            - sample_errors: Up to ``max_samples`` condensed errors (type,
              message, task id, retry count, input preview), in recording order
            - errors: Every recorded error, serialised via ``to_dict``
        """
        if not self.errors:
            return {
                "failed_llm_calls": 0,
                "error_types": {},
                "average_retry_attempts": 0,
                "sample_errors": [],
                "errors": [],
            }

        error_types = dict(Counter(error.error_type for error in self.errors))
        average_retry_attempts = sum(
            error.retry_attempts for error in self.errors
        ) / len(self.errors)
        # Clamp so a negative value does not turn into "all but the last N".
        sample_errors = [
            self._summarise_error(error)
            for error in self.errors[: max(max_samples, 0)]
        ]

        return {
            "failed_llm_calls": len(self.errors),
            "error_types": error_types,
            "average_retry_attempts": average_retry_attempts,
            "sample_errors": sample_errors,
            # Kept for backward compatibility with existing output consumers.
            "errors": self.to_dict(),
        }

    def get_success_rate(self, total_attempts: int) -> float:
        """
        Calculate the success rate given total attempts.

        Args:
            total_attempts: Total number of LLM invocation attempts

        Returns:
            Success rate as a float between 0 and 1. Returns 1.0 when
            ``total_attempts`` is 0.

        Raises:
            ValueError: If ``total_attempts`` is negative.
        """
        if total_attempts < 0:
            raise ValueError(
                f"total_attempts must be non-negative, got {total_attempts}"
            )
        if total_attempts == 0:
            return 1.0

        # If the caller under-reports attempts (fewer than recorded failures),
        # clamp at 0 rather than returning a negative "rate".
        successful_attempts = max(total_attempts - len(self.errors), 0)
        return successful_attempts / total_attempts

    def to_dict(self) -> List[Dict[str, Any]]:
        """
        Serialize all errors to a list of dictionaries.

        Returns:
            List of error dictionaries (via ``dataclasses.asdict``) suitable
            for JSON serialization when the stored messages/config are.
        """
        return [asdict(error) for error in self.errors]

    def _summarise_error(self, error: "LLMInvocationError") -> Dict[str, Any]:
        """Condense one recorded error into a small dictionary for summaries."""
        return {
            "error_type": error.error_type,
            "error_message": error.error_message,
            "task_id": error.task_id,
            "retry_attempts": error.retry_attempts,
            "input_preview": self._get_input_preview(error.input_messages),
        }

    def _get_input_preview(
        self, input_messages: List[Dict[str, Any]], max_length: int = 200
    ) -> str:
        """
        Get a preview of the input messages for error summaries.

        Prefers the first user message with non-empty content. Otherwise it
        falls back to the first message's content, and then to its role.

        Args:
            input_messages: The input messages to preview
            max_length: Maximum number of content characters kept before
                appending "..." (default 200)

        Returns:
            A truncated string representation of the input
        """
        if not input_messages:
            return "No input messages"

        # Get the first user message content
        for message in input_messages:
            if message.get("role") == "user" and message.get("content"):
                return self._truncate(str(message["content"]), max_length)

        # Fallback to first message content
        first_message = input_messages[0]
        if first_message.get("content"):
            return self._truncate(str(first_message["content"]), max_length)

        return f"Message with role: {first_message.get('role', 'unknown')}"

    @staticmethod
    def _truncate(text: str, max_length: int) -> str:
        """Return text cut to max_length characters plus "..." if it was longer."""
        if len(text) > max_length:
            return text[:max_length] + "..."
        return text
