"""
Tests for retry utilities.

This module contains tests for the retry mechanism implementation in retry_utils.py.
"""

import logging
import pytest
import time
from unittest.mock import Mock, patch

from src.utils.retry_utils import (
    RetryConfig,
    create_retry_decorator,
    is_retryable_openai_error,
    with_retry,
)


class TestRetryConfig:
    """Tests covering how ``RetryConfig`` instances are constructed."""

    def test_default_initialization(self):
        """A bare ``RetryConfig()`` exposes the expected default settings."""
        cfg = RetryConfig()

        expected_numbers = {
            "max_attempts": 3,
            "base_delay": 1.0,
            "max_delay": 60.0,
            "backoff_multiplier": 2.0,
        }
        for attr_name, want in expected_numbers.items():
            assert getattr(cfg, attr_name) == want

        assert cfg.jitter is True
        for error_name in ("RateLimitError", "APIConnectionError"):
            assert error_name in cfg.retryable_errors

    def test_custom_initialization(self):
        """Every keyword passed to the constructor is stored verbatim."""
        overrides = {
            "max_attempts": 5,
            "base_delay": 2.0,
            "max_delay": 120.0,
            "backoff_multiplier": 3.0,
            "jitter": False,
            "retryable_errors": ["CustomError"],
        }
        cfg = RetryConfig(**overrides)

        for attr_name, want in overrides.items():
            assert getattr(cfg, attr_name) == want
        # Identity check: jitter must be the real bool, not merely falsy.
        assert cfg.jitter is False

    def test_from_dict(self):
        """``from_dict`` maps every supplied key onto the config."""
        source = {
            "max_attempts": 4,
            "base_delay": 1.5,
            "max_delay": 90.0,
            "backoff_multiplier": 2.5,
            "jitter": False,
            "retryable_errors": ["TestError"],
        }

        cfg = RetryConfig.from_dict(source)

        for attr_name, want in source.items():
            assert getattr(cfg, attr_name) == want
        assert cfg.jitter is False

    def test_from_dict_partial(self):
        """Keys missing from the dict fall back to the constructor defaults."""
        cfg = RetryConfig.from_dict({"max_attempts": 5})

        assert cfg.max_attempts == 5
        # Untouched fields keep their defaults.
        assert cfg.base_delay == 1.0
        assert cfg.jitter is True


class TestIsRetryableOpenAIError:
    """Tests for the ``is_retryable_openai_error`` classifier.

    The local exception classes are deliberately named after real provider
    errors, so their class names must be kept exactly as written.
    """

    def test_retryable_error_names(self):
        """Exceptions whose class name marks them as transient are retried."""
        class RateLimitError(Exception):
            pass

        class APIConnectionError(Exception):
            pass

        for exc_type in (RateLimitError, APIConnectionError):
            assert is_retryable_openai_error(exc_type()) is True

    def test_non_retryable_error_names(self):
        """Auth / request-shape failures are never retried."""
        class AuthenticationError(Exception):
            pass

        class InvalidRequestError(Exception):
            pass

        for exc_type in (AuthenticationError, InvalidRequestError):
            assert is_retryable_openai_error(exc_type()) is False

    def test_content_policy_violation_error(self):
        """Content-policy violations are rejected however they are named."""
        class ContentPolicyViolationError(Exception):
            pass

        class LiteLLMContentPolicyViolationError(Exception):
            pass

        class MyContentPolicyViolationErrorWrapper(Exception):
            pass

        policy_errors = [
            # Class name matches exactly.
            ContentPolicyViolationError("Generic error message"),
            # Class name carries a prefix; the message mimics a LiteLLM/Azure error.
            LiteLLMContentPolicyViolationError(
                "litellm.ContentPolicyViolationError: AzureException - The response was filtered due to the prompt triggering Azure OpenAI's content management policy"
            ),
            # Custom wrapper whose name embeds the policy error name.
            MyContentPolicyViolationErrorWrapper("Wrapped policy error"),
        ]
        for err in policy_errors:
            assert is_retryable_openai_error(err) is False

    def test_error_with_status_code_429(self):
        """An HTTP 429 status attribute makes an error retryable."""
        class CustomError(Exception):
            def __init__(self):
                self.status_code = 429

        assert is_retryable_openai_error(CustomError()) is True

    def test_error_with_5xx_status_codes(self):
        """Server-side 5xx status codes are treated as transient."""
        class ServerError(Exception):
            def __init__(self, status_code):
                self.status_code = status_code

        for code in (500, 502, 503):
            assert is_retryable_openai_error(ServerError(code)) is True

    def test_error_with_4xx_status_codes(self):
        """Client-side 4xx codes other than 429 are not retried."""
        class ClientError(Exception):
            def __init__(self, status_code):
                self.status_code = status_code

        for code in (400, 401, 404):
            assert is_retryable_openai_error(ClientError(code)) is False

    def test_error_message_indicators(self):
        """Messages hinting at a transient condition make an error retryable."""
        class CustomError(Exception):
            def __init__(self, message):
                self.message = message

            def __str__(self):
                return self.message

        transient_messages = (
            "Rate limit exceeded",
            "Server error occurred",
            "Connection timeout",
            "Service unavailable",
        )
        for text in transient_messages:
            assert is_retryable_openai_error(CustomError(text)) is True

    def test_non_transient_error_messages(self):
        """Messages describing permanent failures do not trigger a retry."""
        class CustomError(Exception):
            def __init__(self, message):
                self.message = message

            def __str__(self):
                return self.message

        for text in ("Invalid API key", "Bad request format"):
            assert is_retryable_openai_error(CustomError(text)) is False


class TestCreateRetryDecorator:
    """Test cases for create_retry_decorator function."""

    def test_retry_decorator_creation(self):
        """Test that the decorator is created and honours ``max_attempts``."""
        config = RetryConfig(max_attempts=3)
        logger = logging.getLogger("test")

        decorator = create_retry_decorator(config, logger)

        assert decorator is not None
        assert decorator.stop.max_attempt_number == 3

    def test_retry_decorator_with_custom_classifier(self):
        """Test retry decorator with a custom error classifier.

        NOTE(review): this checks construction and the stop condition only; it
        does not verify that the classifier is actually consulted on failure.
        """
        config = RetryConfig(max_attempts=2)

        def custom_classifier(exception):
            return isinstance(exception, ValueError)

        decorator = create_retry_decorator(config, error_classifier=custom_classifier)

        assert decorator is not None
        # Supplying a classifier must not alter the configured attempt limit.
        assert decorator.stop.max_attempt_number == 2

    def test_retry_decorator_defaults(self):
        """Test retry decorator when only a config is supplied."""
        config = RetryConfig()

        decorator = create_retry_decorator(config)

        assert decorator is not None
        # The config's attempt limit is applied even without a logger/classifier.
        assert decorator.stop.max_attempt_number == config.max_attempts


class TestWithRetryDecorator:
    """Test cases for with_retry decorator.

    Tests that run with ``jitter=True`` patch ``time.sleep``. As
    ``test_jitter_behavior`` asserts, with_retry sleeps for jitter even when
    the wrapped call succeeds, so leaving it unpatched would make the suite
    slow and timing-dependent.
    """

    @patch('time.sleep')
    def test_successful_function_no_retry(self, _mock_sleep):
        """Test that successful functions execute exactly once, without retry."""
        call_count = 0

        @with_retry(RetryConfig(max_attempts=3))
        def successful_function():
            nonlocal call_count
            call_count += 1
            return "success"

        result = successful_function()
        assert result == "success"
        assert call_count == 1

    def test_function_with_retryable_error(self):
        """Test function that fails with retryable error then succeeds."""
        call_count = 0

        # Classified as retryable by its class name.
        class RateLimitError(Exception):
            pass

        @with_retry(RetryConfig(max_attempts=3, base_delay=0.01, jitter=False))
        def failing_then_success():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise RateLimitError("Rate limit exceeded")
            return "success"

        result = failing_then_success()
        assert result == "success"
        assert call_count == 2

    def test_function_with_non_retryable_error(self):
        """Test that a non-retryable error propagates after a single attempt."""
        call_count = 0

        # A small base_delay keeps the test fast even if a regression starts
        # retrying ValueError.
        @with_retry(RetryConfig(max_attempts=3, base_delay=0.01, jitter=False))
        def always_fails():
            nonlocal call_count
            call_count += 1
            raise ValueError("Invalid input")

        with pytest.raises(ValueError):
            always_fails()
        # The call must not have been retried.
        assert call_count == 1

    def test_function_exceeds_max_attempts(self):
        """Test that every allowed attempt is used before giving up."""
        call_count = 0

        class RateLimitError(Exception):
            pass

        @with_retry(RetryConfig(max_attempts=2, base_delay=0.01, jitter=False))
        def always_fails_retryable():
            nonlocal call_count
            call_count += 1
            raise RateLimitError("Rate limit exceeded")

        # Presumably the original exception is re-raised after the retries are
        # exhausted. Catching Exception also tolerates a wrapper error.
        # TODO confirm the exact type and narrow this check.
        with pytest.raises(Exception):
            always_fails_retryable()
        assert call_count == 2

    @patch('time.sleep')
    def test_default_retry_config(self, _mock_sleep):
        """Test decorator with default retry configuration."""
        @with_retry()
        def test_function():
            return "default_config"

        result = test_function()
        assert result == "default_config"

    @patch('time.sleep')
    def test_jitter_behavior(self, mock_sleep):
        """Test that jitter is applied when enabled."""
        @with_retry(RetryConfig(max_attempts=1, jitter=True))
        def test_function():
            return "jitter_test"

        result = test_function()
        assert result == "jitter_test"
        # Should have called sleep for initial jitter
        mock_sleep.assert_called_once()

    @patch('time.sleep')
    def test_no_jitter_behavior(self, mock_sleep):
        """Test that jitter is not applied when disabled."""
        @with_retry(RetryConfig(max_attempts=1, jitter=False))
        def test_function():
            return "no_jitter_test"

        result = test_function()
        assert result == "no_jitter_test"
        # Should not have called sleep for jitter
        mock_sleep.assert_not_called()

    @patch('time.sleep')
    def test_custom_logger(self, _mock_sleep):
        """Test decorator with custom logger."""
        logger = logging.getLogger("custom_test")

        @with_retry(RetryConfig(max_attempts=1), logger=logger)
        def test_function():
            return "custom_logger"

        result = test_function()
        assert result == "custom_logger"

    def test_function_with_args_and_kwargs(self):
        """Test that positional and keyword arguments are passed through."""
        @with_retry(RetryConfig(max_attempts=1, jitter=False))
        def function_with_params(arg1, arg2, kwarg1=None):
            return f"{arg1}-{arg2}-{kwarg1}"

        result = function_with_params("a", "b", kwarg1="c")
        assert result == "a-b-c"