"""Unit tests for the Tree of Thoughts (ToT) algorithm."""

import datetime
import json
import unittest
import uuid
from typing import List, Tuple, Dict
from unittest.mock import Mock, patch, MagicMock

from src.algorithms.tot.main import reasoning_model
from data_models.task_config import TaskConfig


class TestTreeOfThoughtsReasoningModel(unittest.TestCase):
    """Test cases for the Tree of Thoughts ``reasoning_model`` class.

    ``setUp`` replaces ``LLMAPIClient`` in the module under test with a mock
    for every test, so no test builds a real client or sends a real request.
    """

    def setUp(self):
        """Create the shared constructor arguments and install the client mock."""
        self.task_config = TaskConfig(
            feasibility_check_points=["Check 1", "Check 2"],
            task_description="Test task description",
            known_solutions=["Solution 1", "Solution 2"],
        )
        self.backbone_llm_name = "gpt-4"
        self.num_analogous_problems = 10
        self.num_solutions_per_problem = 5
        self.num_final_solutions = 3

        # Patch the client class where the module under test looks it up, and
        # undo the patch automatically even when a test fails or errors.
        client_patcher = patch('src.algorithms.tot.main.LLMAPIClient')
        self.mock_llm_client_class = client_patcher.start()
        self.addCleanup(client_patcher.stop)

        # The instance the model receives when it instantiates the client class.
        self.mock_llm_client = Mock()
        self.mock_llm_client_class.return_value = self.mock_llm_client

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _build_model(self, **overrides):
        """Construct a ``reasoning_model`` from the fixture values.

        Args:
            **overrides: Replacement values keyed by fixture name
                (``task_config``, ``backbone_llm_name``,
                ``num_analogous_problems``, ``num_solutions_per_problem``,
                ``num_final_solutions``).

        Returns:
            Whatever ``reasoning_model(...)`` returns.

        Raises:
            TypeError: If an override key is not one of the fixture names.
        """
        arguments = {
            "task_config": self.task_config,
            "backbone_llm_name": self.backbone_llm_name,
            "num_analogous_problems": self.num_analogous_problems,
            "num_solutions_per_problem": self.num_solutions_per_problem,
            "num_final_solutions": self.num_final_solutions,
        }
        unknown = sorted(set(overrides) - set(arguments))
        if unknown:
            raise TypeError(f"Unknown constructor override(s): {', '.join(unknown)}")
        arguments.update(overrides)

        # Passed positionally: these tests only rely on the argument order.
        return reasoning_model(
            arguments["task_config"],
            arguments["backbone_llm_name"],
            arguments["num_analogous_problems"],
            arguments["num_solutions_per_problem"],
            arguments["num_final_solutions"],
        )

    @staticmethod
    def _thoughts(prefix=""):
        """Return five thought labels: ``<prefix>thought1`` .. ``<prefix>thought5``."""
        return [f"{prefix}thought{index}" for index in range(1, 6)]

    def _script_full_run(self, mock_extract_json):
        """Queue the mocked replies for one complete ``run()``.

        The script holds eight generation/voting exchanges (the root plus the
        three nodes expanded from the root vote) followed by the final
        solution text: nine LLM calls in total, which is the call count
        asserted by ``test_run_successful_generation``.

        Args:
            mock_extract_json: The mock standing in for
                ``extract_json_from_response``; it receives the parsed
                payloads in the same order as the raw replies.
        """
        parsed_payloads = [
            self._thoughts("root_"), [1, 3, 5],  # Root generation, root voting
            self._thoughts("l1n1_"), [2, 4],     # Node 1 (from root_thought1)
            self._thoughts("l1n2_"), [1, 3],     # Node 2 (from root_thought3)
            self._thoughts("l1n3_"), [2, 5],     # Node 3 (from root_thought5)
        ]
        mock_extract_json.side_effect = parsed_payloads

        # Derive each raw reply from its parsed payload so the two cannot
        # drift apart; the final reply is plain text and is never parsed.
        raw_replies = [json.dumps(payload) for payload in parsed_payloads]
        self.mock_llm_client.call_openai.side_effect = raw_replies + ["Final solution text"]

    def _assert_single_llm_call(self, *prompt_fragments):
        """Assert exactly one LLM call was made with the expected keyword arguments.

        Args:
            *prompt_fragments: Substrings that must all occur in the prompt.
        """
        self.mock_llm_client.call_openai.assert_called_once()
        call_kwargs = self.mock_llm_client.call_openai.call_args.kwargs
        self.assertEqual(call_kwargs["model_name"], self.backbone_llm_name)
        self.assertEqual(call_kwargs["temperature"], 0.7)
        for fragment in prompt_fragments:
            self.assertIn(fragment, call_kwargs["prompt"])

    # ------------------------------------------------------------------
    # __init__
    # ------------------------------------------------------------------

    def test_init_valid_parameters(self):
        """Test that __init__ correctly initializes with valid parameters."""
        model = self._build_model()

        # Constructor arguments are stored on the instance.
        self.assertEqual(model.backbone_llm_name, self.backbone_llm_name)
        self.assertEqual(model.num_final_solutions, self.num_final_solutions)
        self.assertEqual(model.task_description_text, self.task_config.task_description)

        # ToT-specific parameters.
        self.assertEqual(model.num_thoughts_per_step, 5)
        self.assertEqual(model.search_depth, 2)

        # intermediate_logs starts out as an empty list.
        self.assertIsInstance(model.intermediate_logs, list)
        self.assertEqual(len(model.intermediate_logs), 0)

        # An LLM client is attached.
        self.assertIsNotNone(model.llm_client)

    def test_init_invalid_task_config(self):
        """Test that __init__ raises ValueError for invalid task_config."""
        with self.assertRaises(ValueError) as context:
            self._build_model(task_config="invalid_config")
        self.assertIn("task_config must be a TaskConfig object", str(context.exception))

    def test_init_invalid_backbone_llm_name(self):
        """Test that __init__ raises ValueError for invalid backbone_llm_name."""
        # Both the empty and the whitespace-only name must be rejected.
        for bad_name in ("", "   "):
            with self.subTest(backbone_llm_name=bad_name):
                with self.assertRaises(ValueError) as context:
                    self._build_model(backbone_llm_name=bad_name)
                self.assertIn(
                    "backbone_llm_name must be a non-empty string",
                    str(context.exception),
                )

    def test_init_invalid_num_final_solutions(self):
        """Test that __init__ raises ValueError for invalid num_final_solutions."""
        for bad_count in (0, -1):
            with self.subTest(num_final_solutions=bad_count):
                with self.assertRaises(ValueError) as context:
                    self._build_model(num_final_solutions=bad_count)
                self.assertIn(
                    "num_final_solutions must be a positive integer",
                    str(context.exception),
                )

    # ------------------------------------------------------------------
    # run
    # ------------------------------------------------------------------

    @patch('src.algorithms.tot.main.uuid.uuid4')
    @patch('src.algorithms.tot.main.datetime')
    @patch('src.algorithms.tot.main.extract_json_from_response')
    def test_run_successful_generation(self, mock_extract_json, mock_datetime, mock_uuid):
        """Test successful solution generation with proper logging."""
        # Freeze the call id and timestamp so log entries are deterministic.
        mock_uuid.return_value = "test-uuid-123"
        mock_datetime.datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00Z"
        self._script_full_run(mock_extract_json)

        solution_text, intermediate_logs = self._build_model().run()

        self.assertEqual(solution_text, "Final solution text")

        # 1 root gen + 1 root vote + 3 node gens + 3 node votes + 1 final = 9.
        expected_step_names = [
            "ToT: Root Thought Generation",
            "ToT: Root Thought Voting",
            "ToT: Thought Generation (Level 2, Node 1)",
            "ToT: Thought Voting (Level 2, Node 1)",
            "ToT: Thought Generation (Level 2, Node 2)",
            "ToT: Thought Voting (Level 2, Node 2)",
            "ToT: Thought Generation (Level 2, Node 3)",
            "ToT: Thought Voting (Level 2, Node 3)",
            "ToT: Final Solution Generation",
        ]
        self.assertIsInstance(intermediate_logs, list)
        self.assertEqual(len(intermediate_logs), 9)
        # Compare all names at once so a failure shows the full sequence diff.
        self.assertEqual(
            [step_name for step_name, _ in intermediate_logs],
            expected_step_names,
        )

        for step_name, log_entries in intermediate_logs:
            with self.subTest(step=step_name):
                self.assertIsInstance(log_entries, list)
                self.assertEqual(len(log_entries), 1)

                log_entry = log_entries[0]
                self.assertIsInstance(log_entry, dict)
                self.assertEqual(log_entry["llm_call_id"], "test-uuid-123")
                self.assertEqual(log_entry["llm_model_name"], self.backbone_llm_name)
                self.assertEqual(log_entry["temperature"], 0.7)
                self.assertEqual(log_entry["timestamp"], "2024-01-01T12:00:00Z")

        self.assertEqual(self.mock_llm_client.call_openai.call_count, 9)

    def test_run_llm_call_exception(self):
        """Test that run raises RuntimeError when LLM call fails."""
        self.mock_llm_client.call_openai.side_effect = Exception("API Error")
        model = self._build_model()

        with self.assertRaises(RuntimeError) as context:
            model.run()
        self.assertIn("Error during solution generation", str(context.exception))
        self.assertIn("API Error", str(context.exception))

    @patch('src.algorithms.tot.main.extract_json_from_response')
    def test_run_return_type(self, mock_extract_json):
        """Test that run returns the correct tuple type."""
        # Same nine-call script as test_run_successful_generation, so the
        # last scripted reply really is the one returned as the solution.
        self._script_full_run(mock_extract_json)

        result = self._build_model().run()
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)

        solution_text, intermediate_logs = result
        self.assertIsInstance(solution_text, str)
        self.assertEqual(solution_text, "Final solution text")
        self.assertIsInstance(intermediate_logs, list)

        # Each log item is (step name, list of log-entry dicts).
        for step_name, log_entries in intermediate_logs:
            self.assertIsInstance(step_name, str)
            self.assertIsInstance(log_entries, list)
            for log_entry in log_entries:
                self.assertIsInstance(log_entry, dict)

    # ------------------------------------------------------------------
    # _generate_thoughts
    # ------------------------------------------------------------------

    @patch('src.algorithms.tot.main.extract_json_from_response')
    def test_generate_thoughts_successful(self, mock_extract_json):
        """Test successful thought generation."""
        expected_thoughts = self._thoughts()
        self.mock_llm_client.call_openai.return_value = json.dumps(expected_thoughts)
        mock_extract_json.return_value = expected_thoughts

        thoughts = self._build_model()._generate_thoughts("test context", 0)

        self.assertEqual(len(thoughts), 5)
        self.assertEqual(
            thoughts,
            ["thought1", "thought2", "thought3", "thought4", "thought5"],
        )
        self._assert_single_llm_call("5 diverse and creative thoughts", "test context")

    @patch('src.algorithms.tot.main.extract_json_from_response')
    def test_generate_thoughts_parsing_failure(self, mock_extract_json):
        """Test thought generation with parsing failure fallback."""
        self.mock_llm_client.call_openai.return_value = "Invalid response"
        mock_extract_json.side_effect = Exception("Parse error")

        thoughts = self._build_model()._generate_thoughts("test context", 0)

        # Placeholder thoughts are returned when the reply cannot be parsed.
        self.assertEqual(len(thoughts), 5)
        self.assertEqual(
            thoughts,
            [
                "Generated thought 1",
                "Generated thought 2",
                "Generated thought 3",
                "Generated thought 4",
                "Generated thought 5",
            ],
        )

    # ------------------------------------------------------------------
    # _vote_and_select_thoughts
    # ------------------------------------------------------------------

    @patch('src.algorithms.tot.main.extract_json_from_response')
    def test_vote_and_select_thoughts_successful(self, mock_extract_json):
        """Test successful thought voting and selection."""
        self.mock_llm_client.call_openai.return_value = '[1, 3, 5]'
        mock_extract_json.return_value = [1, 3, 5]

        selected_thoughts = self._build_model()._vote_and_select_thoughts(
            self._thoughts(), "test context", 0
        )

        # Votes are 1-based: [1, 3, 5] picks list positions 0, 2 and 4.
        self.assertEqual(len(selected_thoughts), 3)
        self.assertEqual(selected_thoughts, ["thought1", "thought3", "thought5"])
        self._assert_single_llm_call("3 most promising thoughts", "test context")

    @patch('src.algorithms.tot.main.extract_json_from_response')
    def test_vote_and_select_thoughts_parsing_failure(self, mock_extract_json):
        """Test thought voting with parsing failure fallback."""
        self.mock_llm_client.call_openai.return_value = "Invalid response"
        mock_extract_json.side_effect = Exception("Parse error")

        selected_thoughts = self._build_model()._vote_and_select_thoughts(
            self._thoughts(), "test context", 0
        )

        # Fallback selection keeps the first three thoughts.
        self.assertEqual(len(selected_thoughts), 3)
        self.assertEqual(selected_thoughts, ["thought1", "thought2", "thought3"])

    # ------------------------------------------------------------------
    # _generate_final_solutions
    # ------------------------------------------------------------------

    def test_generate_final_solutions_successful(self):
        """Test successful final solution generation."""
        self.mock_llm_client.call_openai.return_value = "Final solution text"
        accumulated_thoughts = ["thought1", "thought2", "thought3"]

        final_solution = self._build_model()._generate_final_solutions(accumulated_thoughts)

        self.assertEqual(final_solution, "Final solution text")
        # The prompt names the solution count and includes every thought.
        self._assert_single_llm_call("3 distinct creative solutions", *accumulated_thoughts)

    def test_generate_final_solutions_empty_response(self):
        """Test final solution generation with empty response."""
        self.mock_llm_client.call_openai.return_value = ""
        model = self._build_model()

        with self.assertRaises(RuntimeError) as context:
            model._generate_final_solutions(["thought1", "thought2", "thought3"])
        self.assertIn(
            "Failed to generate final solutions - empty response from LLM",
            str(context.exception),
        )

    # ------------------------------------------------------------------
    # ToT-specific parameters
    # ------------------------------------------------------------------

    def test_search_depth_parameter(self):
        """Test that the search depth parameter is correctly set."""
        self.assertEqual(self._build_model().search_depth, 2)

    def test_num_thoughts_per_step_parameter(self):
        """Test that the number of thoughts per step parameter is correctly set."""
        self.assertEqual(self._build_model().num_thoughts_per_step, 5)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
