"""
Tests for the OpenAI LLM implementation.
"""

import os
import pytest
from unittest import mock

# Register a stand-in ``openai`` module *before* the project imports below run,
# so importing ``src.llm.openai_llm`` never needs the real OpenAI SDK.
# NOTE(review): the patcher is started but never stopped, so the stub remains
# in ``sys.modules`` for the rest of the test process.
mock_openai = mock.MagicMock()
_openai_modules_patch = mock.patch.dict("sys.modules", {"openai": mock_openai})
_openai_modules_patch.start()

from src.llm.base import LLMInterface
from src.llm.openai_llm import OpenAILLM
from src.utils.retry_utils import RetryConfig


class TestOpenAILLM:
    """Tests for the OpenAILLM class.

    Each test patches ``src.llm.openai_llm.OpenAI`` and
    ``LLMInterface.configure_environment`` so no real client is constructed.
    """

    @staticmethod
    def _make_response(content):
        """Return a stand-in ``parse`` response whose first choice carries ``content``.

        ``MagicMock`` creates ``choices[0].message`` on first attribute access,
        so only the leaf ``content`` attribute has to be assigned.
        """
        response = mock.MagicMock()
        response.choices = [mock.MagicMock()]
        response.choices[0].message.content = content
        return response

    def test_init(self):
        """Test initialisation of OpenAILLM."""
        with (
            mock.patch("src.llm.openai_llm.OpenAI") as mock_openai_cls,
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            # Default parameters.
            llm = OpenAILLM(version_name="gpt-4")
            assert llm.version_name == "gpt-4"
            assert llm.temperature == 0.0
            assert llm.top_p == 1.0
            assert llm.kwargs == {}

            # Construction should build the OpenAI client exactly once.
            mock_openai_cls.assert_called_once()

            # Custom parameters; extra keyword arguments are collected in ``kwargs``.
            llm = OpenAILLM(
                version_name="gpt-3.5-turbo",
                temperature=0.7,
                top_p=0.9,
                extra_param="value",
            )
            assert llm.version_name == "gpt-3.5-turbo"
            assert llm.temperature == 0.7
            assert llm.top_p == 0.9
            assert llm.kwargs == {"extra_param": "value"}

    def test_init_with_env_vars(self):
        """Test that the client is built from OPENAI_BASE_URL and OPENAI_API_KEY."""
        with (
            mock.patch.dict(
                os.environ,
                {
                    "OPENAI_BASE_URL": "https://custom-api.example.com",
                    "OPENAI_API_KEY": "test-api-key",
                },
            ),
            mock.patch("src.llm.openai_llm.OpenAI") as mock_openai_cls,
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            # Only the side effect of construction (building the client) matters.
            OpenAILLM(version_name="gpt-4")

            mock_openai_cls.assert_called_once_with(
                base_url="https://custom-api.example.com", api_key="test-api-key"
            )

    def test_generate(self):
        """Test generating text with OpenAILLM."""
        mock_client = mock.MagicMock()
        mock_client.beta.chat.completions.parse.return_value = self._make_response(
            "Hello world!"
        )

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            llm = OpenAILLM(version_name="gpt-4")

            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ]
            response = llm.generate(messages)

            assert response == "Hello world!"
            # The instance's default sampling parameters are forwarded as-is.
            mock_client.beta.chat.completions.parse.assert_called_once_with(
                model="gpt-4", messages=messages, temperature=0.0, top_p=1.0
            )

    def test_generate_with_kwargs(self):
        """Test that per-call kwargs override defaults and extra kwargs pass through."""
        mock_client = mock.MagicMock()
        mock_client.beta.chat.completions.parse.return_value = self._make_response(
            "Response"
        )

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            llm = OpenAILLM(version_name="gpt-4", temperature=0.5)

            messages = [{"role": "user", "content": "Hello!"}]
            response = llm.generate(messages, temperature=0.8, max_tokens=100)

            assert response == "Response"
            mock_client.beta.chat.completions.parse.assert_called_once_with(
                model="gpt-4",
                messages=messages,
                temperature=0.8,  # per-call value overrides the constructor's 0.5
                top_p=1.0,
                max_tokens=100,  # extra kwarg forwarded unchanged
            )

    def test_model_info(self):
        """Test getting model information."""
        with (
            mock.patch("src.llm.openai_llm.OpenAI"),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            llm = OpenAILLM(
                version_name="gpt-4", temperature=0.7, top_p=0.9, extra_param="value"
            )

            info = llm.model_info

            # Constructor settings, including extra kwargs, are reported back.
            assert info["name"] == "OpenAI"
            assert info["version_name"] == "gpt-4"
            assert info["temperature"] == 0.7
            assert info["top_p"] == 0.9
            assert info["extra_param"] == "value"


class TestOpenAILLMRetryMechanism:
    """Tests for the retry mechanism in OpenAILLM.

    Tests that exercise real retries pass a ``RetryConfig`` with tiny,
    jitter-free delays so they stay fast and deterministic.
    """

    @staticmethod
    def _make_response(content):
        """Return a stand-in ``parse`` response whose first choice carries ``content``.

        ``MagicMock`` creates ``choices[0].message`` on first attribute access,
        so only the leaf ``content`` attribute has to be assigned.
        """
        response = mock.MagicMock()
        response.choices = [mock.MagicMock()]
        response.choices[0].message.content = content
        return response

    def test_init_with_retry_config(self):
        """Test initialisation with custom retry configuration."""
        retry_config = RetryConfig(max_attempts=5, base_delay=2.0)

        with (
            mock.patch("src.llm.openai_llm.OpenAI"),
            mock.patch.object(LLMInterface, "configure_environment"),
            mock.patch("src.llm.openai_llm.create_retry_decorator") as mock_create_retry,
        ):
            llm = OpenAILLM(version_name="gpt-4", retry_config=retry_config)

            # The supplied config is kept and handed to the decorator factory.
            assert llm.retry_config == retry_config
            mock_create_retry.assert_called_once_with(retry_config, logger=mock.ANY)

    def test_init_with_default_retry_config(self):
        """Test initialisation with default retry configuration."""
        with (
            mock.patch("src.llm.openai_llm.OpenAI"),
            mock.patch.object(LLMInterface, "configure_environment"),
            mock.patch("src.llm.openai_llm.create_retry_decorator") as mock_create_retry,
        ):
            llm = OpenAILLM(version_name="gpt-4")

            assert isinstance(llm.retry_config, RetryConfig)
            assert llm.retry_config.max_attempts == 3  # Default value
            mock_create_retry.assert_called_once()

    def test_generate_with_retry_success_on_first_attempt(self):
        """Test generate method succeeds on first attempt."""
        mock_client = mock.MagicMock()
        mock_client.beta.chat.completions.parse.return_value = self._make_response(
            "Success!"
        )

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            llm = OpenAILLM(version_name="gpt-4")

            messages = [{"role": "user", "content": "Hello!"}]
            response = llm.generate(messages)

            assert response == "Success!"
            assert mock_client.beta.chat.completions.parse.call_count == 1

    def test_generate_with_retry_success_after_failure(self):
        """Test generate method succeeds after retryable failure."""
        mock_client = mock.MagicMock()

        class RateLimitError(Exception):
            pass

        # First call raises a (presumed) retryable error, the second succeeds.
        mock_client.beta.chat.completions.parse.side_effect = [
            RateLimitError("Rate limit exceeded"),
            self._make_response("Success after retry!"),
        ]

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            retry_config = RetryConfig(max_attempts=3, base_delay=0.01, jitter=False)
            llm = OpenAILLM(version_name="gpt-4", retry_config=retry_config)

            messages = [{"role": "user", "content": "Hello!"}]
            response = llm.generate(messages)

            assert response == "Success after retry!"
            assert mock_client.beta.chat.completions.parse.call_count == 2

    def test_generate_with_non_retryable_error(self):
        """Test generate method with non-retryable error."""
        mock_client = mock.MagicMock()

        class AuthenticationError(Exception):
            pass

        mock_client.beta.chat.completions.parse.side_effect = AuthenticationError(
            "Invalid API key"
        )

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            # Tiny, jitter-free delays: if the error were wrongly treated as
            # retryable, the call-count assertion below fails quickly instead of
            # the test waiting on the default delay settings.
            retry_config = RetryConfig(max_attempts=3, base_delay=0.01, jitter=False)
            llm = OpenAILLM(version_name="gpt-4", retry_config=retry_config)

            messages = [{"role": "user", "content": "Hello!"}]

            with pytest.raises(AuthenticationError):
                llm.generate(messages)

            # Should only be called once since error is not retryable
            assert mock_client.beta.chat.completions.parse.call_count == 1

    def test_generate_exceeds_max_retries(self):
        """Test generate method when max retries are exceeded."""
        mock_client = mock.MagicMock()

        class RateLimitError(Exception):
            pass

        # Every call fails with a (presumed) retryable error.
        mock_client.beta.chat.completions.parse.side_effect = RateLimitError(
            "Rate limit exceeded"
        )

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
        ):
            retry_config = RetryConfig(max_attempts=2, base_delay=0.01, jitter=False)
            llm = OpenAILLM(version_name="gpt-4", retry_config=retry_config)

            messages = [{"role": "user", "content": "Hello!"}]

            with pytest.raises(RateLimitError):
                llm.generate(messages)

            # Should be called max_attempts times
            assert mock_client.beta.chat.completions.parse.call_count == 2

    def test_generate_with_custom_error_classifier(self):
        """Test that an error raised under a stubbed single-attempt retry propagates.

        The object returned by ``create_retry_decorator`` is replaced by a stub
        whose iteration yields exactly one attempt.
        NOTE(review): this presumably relies on OpenAILLM iterating that object
        (``for attempt in ...``); confirm against the implementation. Despite the
        test name, no custom error classifier is exercised here.
        """
        mock_client = mock.MagicMock()

        class CustomError(Exception):
            pass

        mock_client.beta.chat.completions.parse.side_effect = CustomError("Custom error")

        with (
            mock.patch("src.llm.openai_llm.OpenAI", return_value=mock_client),
            mock.patch.object(LLMInterface, "configure_environment"),
            mock.patch("src.llm.openai_llm.create_retry_decorator") as mock_create_retry,
        ):
            # A list (not an iterator) as the return value makes MagicMock build
            # a fresh iterator on every ``iter()`` call, so repeated iteration
            # still yields the single attempt instead of an exhausted iterator.
            mock_retry_decorator = mock.MagicMock()
            mock_retry_decorator.__iter__.return_value = [mock.MagicMock()]
            mock_create_retry.return_value = mock_retry_decorator

            llm = OpenAILLM(version_name="gpt-4")

            messages = [{"role": "user", "content": "Hello!"}]

            with pytest.raises(CustomError):
                llm.generate(messages)
