"""
Tests for the base LLM interface.

This module covers the retry configuration, credential loading, and
request/response logging provided by ``LLMInterface`` in base.py.
"""

from unittest.mock import MagicMock, mock_open, patch

import pytest

from src.llm.base import LLMInterface
from src.utils.retry_utils import RetryConfig


class MockLLM(LLMInterface):
    """Minimal concrete ``LLMInterface`` used as a test double.

    ``generate`` ignores its inputs and returns a canned string;
    ``model_info`` reports fixed metadata.
    """

    def generate(self, _messages, **_kwargs):
        """Return a fixed response regardless of the messages/kwargs given."""
        canned = "mock response"
        return canned

    @property
    def model_info(self):
        """Fixed model metadata identifying this mock."""
        return dict(name="mock", version="test")


class TestLLMInterfaceRetryMechanism:
    """Tests for how LLMInterface stores its ``retry_config``."""

    def test_init_with_retry_config(self):
        """A caller-supplied RetryConfig is kept on the instance."""
        supplied = RetryConfig(max_attempts=5, base_delay=2.0)

        cfg = MockLLM(retry_config=supplied).retry_config

        assert cfg == supplied
        assert cfg.max_attempts == 5
        assert cfg.base_delay == 2.0

    def test_init_with_default_retry_config(self):
        """Omitting retry_config yields a RetryConfig with default values."""
        cfg = MockLLM().retry_config

        assert isinstance(cfg, RetryConfig)
        # Expected RetryConfig defaults: 3 attempts, 1.0 base delay.
        assert (cfg.max_attempts, cfg.base_delay) == (3, 1.0)

    def test_init_with_none_retry_config(self):
        """Passing retry_config=None explicitly also yields the defaults."""
        cfg = MockLLM(retry_config=None).retry_config

        assert isinstance(cfg, RetryConfig)
        assert cfg.max_attempts == 3  # Expected RetryConfig default

    def test_retry_config_properties(self):
        """Every RetryConfig field set by the caller is visible on the LLM."""
        cfg = MockLLM(
            retry_config=RetryConfig(
                max_attempts=10,
                base_delay=0.5,
                max_delay=30.0,
                backoff_multiplier=1.5,
                jitter=False,
                retryable_errors=["CustomError"],
            )
        ).retry_config

        expected_fields = (
            ("max_attempts", 10),
            ("base_delay", 0.5),
            ("max_delay", 30.0),
            ("backoff_multiplier", 1.5),
            ("retryable_errors", ["CustomError"]),
        )
        for field_name, wanted in expected_fields:
            assert getattr(cfg, field_name) == wanted, field_name
        # Identity check: jitter must be the bool False, not merely falsy.
        assert cfg.jitter is False

    def test_retry_config_inheritance(self):
        """Any LLMInterface subclass accepts and stores retry_config."""

        class CustomLLM(LLMInterface):
            def generate(self, _messages, **_kwargs):
                return "custom response"

            @property
            def model_info(self):
                return {"name": "custom"}

        supplied = RetryConfig(max_attempts=7)
        cfg = CustomLLM(retry_config=supplied).retry_config

        assert cfg == supplied
        assert cfg.max_attempts == 7


class TestLLMInterfaceCredentialHandling:
    """Tests for how ``load_api_info`` resolves the API URL and key.

    Values may come from environment variables, from a credentials file
    (simulated by patching ``pathlib.Path.exists`` and ``builtins.open``),
    or from a combination of the two.
    """

    # JSON payloads served by the mocked ``open`` call.
    _FILE_WITH_URL_AND_KEY = (
        '{"OPENAI_BASE_URL": "https://file.com", '
        '"OPENAI_API_KEY": "file-key"}'
    )
    _FILE_WITH_KEY_ONLY = '{"OPENAI_API_KEY": "file-key"}'

    @patch.dict(
        "os.environ",
        {"OPENAI_BASE_URL": "https://test.com", "OPENAI_API_KEY": "test-key"},
    )
    def test_load_api_info_from_env(self):
        """Environment variables alone supply both URL and key."""
        url, key = MockLLM().load_api_info()

        assert url == "https://test.com"
        assert key == "test-key"

    @patch.dict("os.environ", {}, clear=True)
    @patch("pathlib.Path.exists", return_value=True)
    @patch("builtins.open", mock_open(read_data=_FILE_WITH_URL_AND_KEY))
    def test_load_api_info_from_file(self, _mock_exists):
        """With an empty environment, both values are read from the file."""
        url, key = MockLLM().load_api_info()

        assert url == "https://file.com"
        assert key == "file-key"

    @patch.dict("os.environ", {}, clear=True)
    @patch("pathlib.Path.exists", return_value=False)
    def test_load_api_info_missing_credentials(self, _mock_exists):
        """No environment values and no file raises ValueError."""
        llm = MockLLM()

        with pytest.raises(ValueError, match="Missing API credentials"):
            llm.load_api_info()

    @patch.dict("os.environ", {}, clear=True)
    @patch("pathlib.Path.exists", return_value=True)
    @patch("builtins.open", mock_open(read_data="invalid json"))
    def test_load_api_info_invalid_json(self, _mock_exists):
        """An unparsable credentials file is treated as missing credentials."""
        llm = MockLLM()

        with pytest.raises(ValueError, match="Missing API credentials"):
            llm.load_api_info()

    @patch.dict("os.environ", {"OPENAI_BASE_URL": "https://env.com"}, clear=True)
    @patch("pathlib.Path.exists", return_value=True)
    @patch("builtins.open", mock_open(read_data=_FILE_WITH_KEY_ONLY))
    def test_load_api_info_mixed_sources(self, _mock_exists):
        """The URL comes from the environment and the key from the file."""
        url, key = MockLLM().load_api_info()

        assert url == "https://env.com"
        assert key == "file-key"


class TestLLMInterfaceLogging:
    """Tests for the request/response logging helpers on LLMInterface."""

    @staticmethod
    def _capture_debug_messages(action):
        """Run *action* with ``logging.getLogger`` patched; return debug messages.

        ``logging.getLogger`` is replaced by a mock whose return value is a
        MagicMock logger. After *action* runs, this asserts that ``debug`` was
        called at least once and returns the first positional argument of
        every recorded ``debug`` call.

        NOTE(review): only loggers obtained via ``logging.getLogger`` while
        *action* runs are intercepted; a logger cached earlier (at import or
        construction time) would bypass the patch. Confirm against base.py.
        """
        with patch("logging.getLogger") as mock_get_logger:
            mock_logger = MagicMock()
            mock_get_logger.return_value = mock_logger

            action()

            mock_logger.debug.assert_called()
            return [recorded[0][0] for recorded in mock_logger.debug.call_args_list]

    def test_log_request(self):
        """log_request logs the message count and the request parameters."""
        llm = MockLLM()
        messages = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello world"},
        ]

        logged = self._capture_debug_messages(
            lambda: llm.log_request(messages, temperature=0.7)
        )

        assert any("Request: 2 messages" in msg for msg in logged)
        assert any("Request parameters" in msg for msg in logged)

    def test_log_response(self):
        """log_response logs a 'Response:' line containing the duration."""
        llm = MockLLM()
        response = "This is a test response"
        duration = 1.5

        logged = self._capture_debug_messages(
            lambda: llm.log_response(response, duration)
        )

        # The duration is expected to be rendered with two decimals ("1.50s").
        assert any("Response:" in msg and "1.50s" in msg for msg in logged)

    def test_log_request_long_content(self):
        """Message content longer than the preview is logged with an ellipsis."""
        llm = MockLLM()
        long_content = "A" * 150  # 150 characters
        messages = [{"role": "user", "content": long_content}]

        logged = self._capture_debug_messages(lambda: llm.log_request(messages))

        # Truncated content is marked with "...".
        assert any("..." in msg for msg in logged)

    def test_log_response_long_content(self):
        """A long response still produces the 'Response:' summary line."""
        llm = MockLLM()
        long_response = "B" * 200  # 200 characters
        duration = 2.0

        logged = self._capture_debug_messages(
            lambda: llm.log_response(long_response, duration)
        )

        # Only the summary line is checked; truncation of the 200-character
        # body is NOT asserted here. TODO(review): pin the truncation format
        # once log_response's behaviour in base.py is confirmed.
        assert any("Response:" in msg and "2.00s" in msg for msg in logged)
