"""Unit tests for the chain_of_thoughts algorithm."""

import unittest
from unittest.mock import Mock, patch, MagicMock
import datetime
import uuid
from typing import List, Tuple, Dict

from src.algorithms.chain_of_thoughts.main import reasoning_model
from src.data_models.task_config import TaskConfig


class TestChainOfThoughtsReasoningModel(unittest.TestCase):
    """Test cases for the chain_of_thoughts reasoning_model class.

    ``LLMAPIClient`` is patched for every test in ``setUp`` so that no test
    constructs a real client. ``self.mock_llm_client`` is the instance the
    patched class returns; individual tests configure its ``call_openai``.
    """

    def setUp(self):
        """Patch the LLM client class and prepare default constructor arguments."""
        llm_client_patcher = patch('src.algorithms.chain_of_thoughts.main.LLMAPIClient')
        mock_llm_client_class = llm_client_patcher.start()
        # Registered immediately after start() so the patch is undone even if
        # the remainder of setUp raises.
        self.addCleanup(llm_client_patcher.stop)
        self.mock_llm_client = Mock()
        mock_llm_client_class.return_value = self.mock_llm_client

        self.task_config = TaskConfig(
            feasibility_check_points=["Check 1", "Check 2"],
            task_description="Test task description",
            known_solutions=["Solution 1", "Solution 2"],
        )
        self.backbone_llm_name = "gpt-4"
        self.num_analogous_problems = 10
        self.num_solutions_per_problem = 5
        self.num_final_solutions = 3

    def _build_model(self, **overrides):
        """Construct a reasoning_model from the fixture defaults.

        Args:
            **overrides: Replacement values keyed by ``task_config``,
                ``backbone_llm_name``, ``num_analogous_problems``,
                ``num_solutions_per_problem`` or ``num_final_solutions``.

        Returns:
            The constructed ``reasoning_model`` instance.

        Raises:
            TypeError: If an override key is not one of the names above.
        """
        arguments = {
            "task_config": self.task_config,
            "backbone_llm_name": self.backbone_llm_name,
            "num_analogous_problems": self.num_analogous_problems,
            "num_solutions_per_problem": self.num_solutions_per_problem,
            "num_final_solutions": self.num_final_solutions,
        }
        unknown = sorted(set(overrides) - set(arguments))
        if unknown:
            raise TypeError(f"Unknown override(s): {', '.join(unknown)}")
        arguments.update(overrides)
        # Passed positionally so the helper does not depend on the
        # constructor's parameter names.
        return reasoning_model(
            arguments["task_config"],
            arguments["backbone_llm_name"],
            arguments["num_analogous_problems"],
            arguments["num_solutions_per_problem"],
            arguments["num_final_solutions"],
        )

    def test_init_valid_parameters(self):
        """Test that __init__ correctly initializes with valid parameters."""
        model = self._build_model()

        # backbone_llm_name must be the value passed in, not a hardcoded one.
        self.assertEqual(model.backbone_llm_name, self.backbone_llm_name)
        self.assertNotEqual(model.backbone_llm_name, 'o1')

        self.assertEqual(model.num_final_solutions, self.num_final_solutions)
        self.assertEqual(model.task_description_text, self.task_config.task_description)

        # intermediate_logs starts out as an empty list.
        self.assertIsInstance(model.intermediate_logs, list)
        self.assertEqual(len(model.intermediate_logs), 0)

        self.assertIsNotNone(model.llm_client)

    def test_init_invalid_task_config(self):
        """Test that __init__ raises ValueError for invalid task_config."""
        with self.assertRaises(ValueError) as context:
            self._build_model(task_config="invalid_config")
        self.assertIn("task_config must be a TaskConfig object", str(context.exception))

    def test_init_invalid_backbone_llm_name(self):
        """Test that __init__ raises ValueError for invalid backbone_llm_name."""
        # subTest reports each bad value separately, so a failure on the
        # empty string does not hide the whitespace-only case.
        for bad_name in ("", "   "):
            with self.subTest(backbone_llm_name=bad_name):
                with self.assertRaises(ValueError) as context:
                    self._build_model(backbone_llm_name=bad_name)
                self.assertIn(
                    "backbone_llm_name must be a non-empty string",
                    str(context.exception),
                )

    def test_init_invalid_num_final_solutions(self):
        """Test that __init__ raises ValueError for invalid num_final_solutions."""
        for bad_count in (0, -1):
            with self.subTest(num_final_solutions=bad_count):
                with self.assertRaises(ValueError) as context:
                    self._build_model(num_final_solutions=bad_count)
                self.assertIn(
                    "num_final_solutions must be a positive integer",
                    str(context.exception),
                )

    @patch('src.algorithms.chain_of_thoughts.main.uuid.uuid4')
    @patch('src.algorithms.chain_of_thoughts.main.datetime')
    def test_run_successful_generation(self, mock_datetime, mock_uuid):
        """Test successful solution generation with proper logging."""
        # Pin the call id and timestamp so the log entry can be compared exactly.
        mock_uuid.return_value = "test-uuid-123"
        mock_datetime.datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00Z"
        self.mock_llm_client.call_openai.return_value = "Test solution with reasoning steps"

        model = self._build_model()
        solution_text, intermediate_logs = model.run()

        self.assertEqual(solution_text, "Test solution with reasoning steps")

        # Exactly one step is logged, holding exactly one LLM call entry.
        self.assertIsInstance(intermediate_logs, list)
        self.assertEqual(len(intermediate_logs), 1)

        step_name, log_entries = intermediate_logs[0]
        self.assertEqual(step_name, "Chain-of-Thought Generation")
        self.assertIsInstance(log_entries, list)
        self.assertEqual(len(log_entries), 1)

        log_entry = log_entries[0]
        self.assertIsInstance(log_entry, dict)
        self.assertEqual(log_entry["llm_call_id"], "test-uuid-123")
        self.assertEqual(log_entry["llm_model_name"], self.backbone_llm_name)
        self.assertEqual(log_entry["temperature"], 0.7)
        self.assertEqual(log_entry["timestamp"], "2024-01-01T12:00:00Z")
        self.assertEqual(log_entry["raw_response"], "Test solution with reasoning steps")
        self.assertIsNone(log_entry["parsed_output"])

        # The prompt carries the chain-of-thought instruction, the task
        # description and the requested number of solutions.
        prompt = log_entry["prompt"]
        self.assertIn("Think step-by-step", prompt)
        self.assertIn("outline your reasoning", prompt)
        self.assertIn("provide the final solution", prompt)
        self.assertIn(self.task_config.task_description, prompt)
        self.assertIn(f"{self.num_final_solutions} distinct creative solutions", prompt)

        # The client is called once, with the same values that were logged.
        self.mock_llm_client.call_openai.assert_called_once()
        call_args = self.mock_llm_client.call_openai.call_args
        self.assertEqual(call_args.kwargs["model_name"], self.backbone_llm_name)
        self.assertEqual(call_args.kwargs["temperature"], 0.7)
        self.assertEqual(call_args.kwargs["prompt"], prompt)

    def test_run_empty_response(self):
        """Test that run raises RuntimeError for empty LLM response."""
        self.mock_llm_client.call_openai.return_value = ""
        model = self._build_model()

        with self.assertRaises(RuntimeError) as context:
            model.run()
        self.assertIn(
            "Failed to generate solution - empty response from LLM",
            str(context.exception),
        )

    def test_run_whitespace_only_response(self):
        """Test that run raises RuntimeError for whitespace-only LLM response."""
        self.mock_llm_client.call_openai.return_value = "   \n\t   "
        model = self._build_model()

        with self.assertRaises(RuntimeError) as context:
            model.run()
        self.assertIn(
            "Failed to generate solution - empty response from LLM",
            str(context.exception),
        )

    def test_run_llm_call_exception(self):
        """Test that run raises RuntimeError when LLM call fails."""
        self.mock_llm_client.call_openai.side_effect = Exception("API Error")
        model = self._build_model()

        with self.assertRaises(RuntimeError) as context:
            model.run()
        self.assertIn("Error during solution generation: API Error", str(context.exception))

    def test_run_return_type(self):
        """Test that run returns the correct tuple type."""
        self.mock_llm_client.call_openai.return_value = "Test solution"
        model = self._build_model()

        result = model.run()
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)

        solution_text, intermediate_logs = result
        self.assertIsInstance(solution_text, str)
        self.assertIsInstance(intermediate_logs, list)

        # Every log item is a (step name, list of dict entries) pair.
        for step_name, log_entries in intermediate_logs:
            self.assertIsInstance(step_name, str)
            self.assertIsInstance(log_entries, list)
            for log_entry in log_entries:
                self.assertIsInstance(log_entry, dict)


# Run this module's tests with unittest's command-line runner when the file
# is executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()
