"""
Retry utilities for LLM operations.

This module provides configurable retry mechanisms with exponential backoff
for handling transient failures in LLM API calls.
"""

import functools
import logging
import random
import time
from typing import Any, Callable, Dict, List, Optional

from tenacity import (
    Retrying,
    after_log,
    before_sleep_log,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential,
)


class RetryConfig:
    """Configuration class for retry behavior.

    Holds the attempt limit, exponential-backoff parameters, the jitter flag,
    and the list of error class names that should trigger retries.
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_multiplier: float = 2.0,
        jitter: bool = True,
        retryable_errors: Optional[List[str]] = None,
    ):
        """
        Initialize retry configuration.

        Args:
            max_attempts: Maximum total number of attempts, including the
                first call.
            base_delay: Base delay in seconds before the first retry.
            max_delay: Maximum delay in seconds between retries.
            backoff_multiplier: Exponential backoff multiplier.
            jitter: Whether to add random jitter to delays.
            retryable_errors: List of error class names that should trigger
                retries. ``None`` selects the built-in default list; an
                explicit empty list is kept as-is.
        """
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_multiplier = backoff_multiplier
        self.jitter = jitter
        # Fall back to the defaults only when nothing was supplied: ``[]`` is a
        # meaningful value ("no names are retryable") and must not be replaced.
        if retryable_errors is None:
            retryable_errors = [
                "RateLimitError",
                "APIConnectionError",
                "InternalServerError",
                "ServiceUnavailableError",
                "APITimeoutError",
            ]
        # Copy so later mutation of the caller's list does not leak into this config.
        self.retryable_errors = list(retryable_errors)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"max_attempts={self.max_attempts!r}, "
            f"base_delay={self.base_delay!r}, "
            f"max_delay={self.max_delay!r}, "
            f"backoff_multiplier={self.backoff_multiplier!r}, "
            f"jitter={self.jitter!r}, "
            f"retryable_errors={self.retryable_errors!r})"
        )

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "RetryConfig":
        """Create a RetryConfig from a dictionary.

        Recognized keys are the constructor's parameter names. Missing keys fall
        back to the constructor defaults, and unrecognized keys are ignored.
        """
        known_keys = (
            "max_attempts",
            "base_delay",
            "max_delay",
            "backoff_multiplier",
            "jitter",
            "retryable_errors",
        )
        # Forward only the keys that are present so the defaults live in one
        # place (``__init__``) instead of being duplicated here.
        return cls(**{key: config[key] for key in known_keys if key in config})


def is_retryable_openai_error(exception: Exception) -> bool:
    """
    Determine if an OpenAI exception should trigger a retry.

    The checks run in this order:
      1. Content-policy violations are never retried.
      2. Known transient error class names are retried.
      3. An HTTP ``status_code`` of 429 or 5xx is retried, when the attribute
         holds something convertible to ``int``.
      4. Otherwise, the lowercased message is scanned for transient keywords.

    Args:
        exception: The exception to check.

    Returns:
        True if the exception should trigger a retry.
    """
    error_name = type(exception).__name__

    # Policy violations are deterministic; retrying would fail the same way.
    if "ContentPolicyViolationError" in error_name:
        return False

    retryable_names = frozenset(
        {
            "RateLimitError",
            "APIConnectionError",
            "InternalServerError",
            "ServiceUnavailableError",
            "APITimeoutError",
            "Timeout",
            "ConnectionError",
            "HTTPError",
        }
    )
    if error_name in retryable_names:
        return True

    # ``status_code`` may exist but be None (or a string) on some client
    # exceptions. Coerce defensively so a missing/odd value never raises
    # TypeError from inside the retry predicate.
    status = getattr(exception, "status_code", None)
    try:
        status = int(status)
    except (TypeError, ValueError):
        status = None
    if status is not None and (status == 429 or 500 <= status < 600):
        return True

    # Fall back to message keywords that usually indicate a transient issue.
    error_message = str(exception).lower()
    transient_indicators = (
        "rate limit",
        "quota exceeded",
        "server error",
        "service unavailable",
        "timeout",
        "connection",
        "network",
        "temporary",
    )
    return any(indicator in error_message for indicator in transient_indicators)


def create_retry_decorator(
    retry_config: RetryConfig,
    logger: Optional[logging.Logger] = None,
    error_classifier: Optional[Callable[[BaseException], bool]] = None,
):
    """
    Create a tenacity ``Retrying`` controller with the specified configuration.

    Args:
        retry_config: Configuration for retry behavior. ``max_attempts`` bounds
            the total number of attempts. ``base_delay``, ``backoff_multiplier``
            and ``max_delay`` shape the exponential backoff. When ``jitter`` is
            true, each backoff delay is scaled by a random factor in [0.5, 1.0].
        logger: Optional logger for retry events; defaults to this module's logger.
        error_classifier: Function deciding whether an exception is retryable.
            When omitted, an exception is retryable if its class name is listed
            in ``retry_config.retryable_errors`` or if
            ``is_retryable_openai_error`` accepts it.

    Returns:
        A configured ``tenacity.Retrying`` instance that re-raises the last
        exception once retries are exhausted or the error is non-retryable.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    if error_classifier is None:
        configured_names = frozenset(retry_config.retryable_errors or ())

        def error_classifier(exception):
            # Honor the names configured on RetryConfig in addition to the
            # built-in OpenAI heuristics; with the default config this is
            # equivalent to the heuristics alone.
            return (
                type(exception).__name__ in configured_names
                or is_retryable_openai_error(exception)
            )

    backoff = wait_exponential(
        multiplier=retry_config.base_delay,
        max=retry_config.max_delay,
        exp_base=retry_config.backoff_multiplier,
    )

    if retry_config.jitter:

        def wait_strategy(retry_state):
            # Scale the exponential delay by a random factor in [0.5, 1.0] so
            # concurrent clients do not retry in lockstep. The factor is <= 1,
            # so the max_delay cap still holds.
            return backoff(retry_state) * (0.5 + random.random() * 0.5)

    else:
        wait_strategy = backoff

    def retry_condition(exception):
        """Check if exception should trigger a retry, logging the decision."""
        should_retry = error_classifier(exception)
        if should_retry:
            logger.warning("Retryable error encountered: %s", exception)
        else:
            logger.error("Non-retryable error encountered: %s", exception)
        return should_retry

    return Retrying(
        stop=stop_after_attempt(retry_config.max_attempts),
        wait=wait_strategy,
        retry=retry_if_exception(retry_condition),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        after=after_log(logger, logging.INFO),
        reraise=True,
    )


def with_retry(
    retry_config: Optional[RetryConfig] = None,
    logger: Optional[logging.Logger] = None,
):
    """
    Decorator factory that adds retry logic to a function.

    Usage: ``@with_retry()`` or ``@with_retry(RetryConfig(max_attempts=5))``.

    Args:
        retry_config: Configuration for retry behavior; a default
            ``RetryConfig()`` is used when omitted.
        logger: Optional logger for retry events.

    Returns:
        A decorator that wraps a function with retry logic. The wrapper keeps
        the wrapped function's metadata (``__name__``, ``__doc__``, ...).
    """
    if retry_config is None:
        retry_config = RetryConfig()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Build a fresh controller per call so no retry state is shared
            # between independent invocations.
            retrying = create_retry_decorator(retry_config, logger)

            if retry_config.jitter:
                # Small randomized pre-call delay of at most 0.1 s. Cap first,
                # then randomize. Randomizing base_delay and capping afterwards
                # pins the delay to exactly 0.1 s whenever base_delay >= 0.2,
                # which removes the jitter entirely. Clamp at 0 so a negative
                # base_delay cannot make time.sleep raise.
                cap = min(max(retry_config.base_delay, 0.0), 0.1)
                time.sleep(cap * (0.5 + random.random() * 0.5))

            # Each iteration yields an attempt context. A successful call
            # returns directly; failures are classified and either retried or
            # re-raised by the controller.
            for attempt in retrying:
                with attempt:
                    return func(*args, **kwargs)

        return wrapper

    return decorator
