"""Unit tests for commercialized_reasoning_model algorithm reasoning_model class."""

import pytest
from unittest.mock import patch, MagicMock
from src.algorithms.commercialized_reasoning_model.main import reasoning_model
from src.data_models.task_config import TaskConfig


class TestCommercializedReasoningModel:
    """Exercise construction and ``run()`` of the commercialized reasoning_model.

    Every test that calls ``run()`` first replaces ``llm_client.call_llm_model``
    on the model with a ``MagicMock``, so the call is stubbed out.
    """

    # ------------------------------------------------------------------
    # Shared helpers (names do not start with ``test`` so pytest skips them)
    # ------------------------------------------------------------------

    @staticmethod
    def _task_config(description="Test task", checks=("Check 1",), solutions=("Solution 1",)):
        """Return a TaskConfig built from the given description, checks and solutions."""
        return TaskConfig(
            feasibility_check_points=list(checks),
            task_description=description,
            known_solutions=list(solutions),
        )

    def _model(self, llm_name="gpt-4o", **config_options):
        """Return a reasoning_model using the usual ``10, 5, 3`` trailing arguments."""
        config = self._task_config(**config_options)
        return reasoning_model(config, llm_name, 10, 5, 3)

    @staticmethod
    def _stub_llm_call(model, **mock_options):
        """Replace the model's ``call_llm_model`` with a MagicMock configured by ``mock_options``."""
        model.llm_client.call_llm_model = MagicMock(**mock_options)

    @staticmethod
    def _check_single_call_log(logs, llm_name, response):
        """Assert that ``logs`` holds exactly one "LLM Call" step with one matching record."""
        assert len(logs) == 1
        step = logs[0]
        assert step[0] == "LLM Call"
        call_records = step[1]
        assert len(call_records) == 1
        record = call_records[0]
        assert record["model_name"] == llm_name
        assert record["temperature"] == 0.7
        assert record["response"] == response

    def _expect_run_error(self, pattern, **mock_options):
        """Stub the LLM call with ``mock_options`` and assert ``run()`` raises a matching RuntimeError."""
        model = self._model()
        self._stub_llm_call(model, **mock_options)
        with pytest.raises(RuntimeError, match=pattern):
            model.run()

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def test_init_success(self):
        """A valid config and backbone name yield a populated model."""
        model = self._model(
            description="Test task description",
            checks=("Check 1", "Check 2"),
            solutions=("Solution 1", "Solution 2"),
        )

        assert model.task_description_text == "Test task description"
        # The backbone name passed to the constructor must be kept as-is.
        assert model.backbone_llm_name == "gpt-4o"
        assert model.llm_client is not None

    def test_init_invalid_task_config(self):
        """Anything other than a TaskConfig is rejected with ValueError."""
        for bad_config in ("invalid", None):
            with pytest.raises(ValueError, match="task_config must be a TaskConfig object"):
                reasoning_model(bad_config, "gpt-4", 10, 5, 3)

    def test_init_invalid_backbone_llm_name(self):
        """Empty, None and whitespace-only backbone names are rejected with ValueError."""
        config = self._task_config()
        for bad_name in ("", None, "   "):
            with pytest.raises(ValueError, match="backbone_llm_name must be a non-empty string"):
                reasoning_model(config, bad_name, 10, 5, 3)

    def test_backbone_llm_name_dynamic(self):
        """Each instance keeps the backbone name it was constructed with."""
        shared_config = self._task_config()
        # The first model deliberately omits the final positional argument.
        four_arg_model = reasoning_model(shared_config, "gpt-4o", 10, 5)
        claude_model = reasoning_model(shared_config, "claude-3.5-sonnet", 10, 5, 3)
        deepseek_model = reasoning_model(shared_config, "deepseek-reasoner", 10, 5, 3)

        assert four_arg_model.backbone_llm_name == "gpt-4o"
        assert claude_model.backbone_llm_name == "claude-3.5-sonnet"
        assert deepseek_model.backbone_llm_name == "deepseek-reasoner"

    def test_llm_client_initialization(self):
        """The model's ``llm_client`` is an LLMAPIClient instance."""
        model = self._model()

        from src.utils.llm_api_client import LLMAPIClient
        assert isinstance(model.llm_client, LLMAPIClient)

    def test_intermediate_logs_initialization(self):
        """A fresh model exposes ``intermediate_logs`` as an empty list."""
        model = self._model()

        assert hasattr(model, 'intermediate_logs')
        logs = model.intermediate_logs
        assert isinstance(logs, list)
        assert len(logs) == 0

    # ------------------------------------------------------------------
    # run(): successful paths
    # ------------------------------------------------------------------

    def test_run_success(self):
        """``run()`` returns the stubbed response and logs a single LLM call."""
        model = self._model(description="Test task description")
        canned_reply = "Generated solution for the test task"
        self._stub_llm_call(model, return_value=canned_reply)

        solution_text, intermediate_logs = model.run()

        assert solution_text == canned_reply
        self._check_single_call_log(intermediate_logs, "gpt-4o", canned_reply)
        model.llm_client.call_llm_model.assert_called_once_with(
            prompt="Test task description\n\nPlease provide 3 distinct creative solutions.",
            model_name="gpt-4o",
            temperature=0.7,
        )

    def test_run_with_different_task_description(self):
        """The prompt and logged model name follow the configured description and backbone."""
        model = self._model(
            "claude-3.5-sonnet",
            description="Another test task with different content",
        )
        canned_reply = "Solution for another task"
        self._stub_llm_call(model, return_value=canned_reply)

        solution_text, intermediate_logs = model.run()

        assert solution_text == canned_reply
        self._check_single_call_log(intermediate_logs, "claude-3.5-sonnet", canned_reply)
        model.llm_client.call_llm_model.assert_called_once_with(
            prompt="Another test task with different content\n\nPlease provide 3 distinct creative solutions.",
            model_name="claude-3.5-sonnet",
            temperature=0.7,
        )

    def test_temperature_parameter_dynamic(self):
        """The LLM call receives ``temperature=0.7`` as a keyword argument."""
        model = self._model()
        self._stub_llm_call(model, return_value="Test solution")

        model.run()

        # Index 1 of call_args holds the keyword arguments of the last call.
        recorded_kwargs = model.llm_client.call_llm_model.call_args[1]
        assert recorded_kwargs['temperature'] == 0.7

    # ------------------------------------------------------------------
    # run(): failure paths
    # ------------------------------------------------------------------

    def test_run_llm_client_exception(self):
        """A ValueError from the LLM call surfaces as RuntimeError carrying its message."""
        self._expect_run_error(
            "Error during solution generation: Missing API key",
            side_effect=ValueError("Missing API key"),
        )

    def test_run_empty_solution(self):
        """An empty-string response makes ``run()`` raise RuntimeError."""
        self._expect_run_error("Failed to generate solution", return_value="")

    def test_run_none_solution(self):
        """A ``None`` response makes ``run()`` raise RuntimeError."""
        self._expect_run_error("Failed to generate solution", return_value=None)

    def test_run_llm_client_import_error(self):
        """An ImportError from the LLM call surfaces as RuntimeError carrying its message."""
        self._expect_run_error(
            "Error during solution generation: No module named 'openai'",
            side_effect=ImportError("No module named 'openai'"),
        )

    def test_run_llm_client_api_error(self):
        """A generic Exception from the LLM call surfaces as RuntimeError carrying its message."""
        self._expect_run_error(
            "Error during solution generation: API rate limit exceeded",
            side_effect=Exception("API rate limit exceeded"),
        )