"""Unit tests for the Enhanced Graph of Thoughts (EGoT) algorithm."""

import json
import math
from unittest.mock import Mock, patch, MagicMock

import pytest

from src.algorithms.egot.main import reasoning_model
from src.data_models.task_config import TaskConfig


class TestEGoTReasoningModel:
    """Test cases for the EGoT reasoning_model class."""

    # Baseline constructor arguments shared by every test. Tests override only
    # the argument they exercise via ``_make_model(**overrides)``.
    _MODEL_DEFAULTS = {
        "backbone_llm_name": "gemini-pro",
        "num_analogous_problems": 10,
        "num_solutions_per_problem": 5,
        "num_exploratory_ideas": 50,
        "num_new_rule_sets": 3,
        "num_final_solutions": 3,
        "num_solutions_combinational": 10,
        "num_thoughts_per_step": 10,
        "search_depth": 3,
    }

    # Keys every solution dict returned by ``_traverse_graph_branch`` must carry.
    _SOLUTION_KEYS = (
        'answer', 'rationale', 's', 'rs', 'pr_s', 'root', 'depth', 'solution_idx',
    )

    @pytest.fixture
    def mock_task_config(self):
        """Create a mock TaskConfig for testing."""
        return TaskConfig(
            feasibility_check_points=["Check 1", "Check 2"],
            task_description="Test task description",
            known_solutions=["Solution 1", "Solution 2"]
        )

    @pytest.fixture
    def mock_llm_client(self):
        """Create a mock LLM client; each test configures ``call_openai`` as needed."""
        mock_client = Mock()
        mock_client.call_openai = Mock()
        return mock_client

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _make_model(cls, task_config, **overrides):
        """Build a ``reasoning_model`` from ``_MODEL_DEFAULTS``.

        Args:
            task_config: Passed through unchanged as ``task_config``; validation
                tests deliberately pass an invalid value here.
            **overrides: Constructor arguments that replace the defaults.

        Returns:
            The constructed ``reasoning_model`` instance.
        """
        kwargs = {**cls._MODEL_DEFAULTS, **overrides}
        return reasoning_model(task_config=task_config, **kwargs)

    @staticmethod
    def _assert_single_llm_call(client, temperature, prompt_fragment):
        """Assert ``client.call_openai`` ran exactly once with the given settings.

        Args:
            client: The mock LLM client.
            temperature: Expected value of the ``temperature`` keyword argument.
            prompt_fragment: Text that must appear in the ``prompt`` keyword argument.
        """
        client.call_openai.assert_called_once()
        call_kwargs = client.call_openai.call_args[1]
        assert call_kwargs['temperature'] == temperature
        assert prompt_fragment in call_kwargs['prompt']

    @staticmethod
    def _branch_responses(branching, depth, root=None):
        """Build the mock LLM replies for one root branch, in expected call order.

        The sequence is one METHODNODE reply, then for each level ``k`` in
        ``range(depth)`` and each of its ``branching ** k`` parent nodes,
        ``branching`` solutions that each get an ANSWERINGNODE and an
        EVALUATIONNODE reply. Every level except the last also gets an
        AGGREGATERATIONALENODE reply per solution.

        Args:
            branching: Solutions generated per node (``num_final_solutions``).
            depth: Number of graph levels (``graph_depth``).
            root: Optional root index. When given it prefixes every tag and
                shifts scores by ``10 * root`` so branches are distinguishable.

        Returns:
            A list of JSON strings suitable for ``Mock.side_effect``.
        """
        responses = [json.dumps({"ma": "method analysis", "me": "method explanation"})]
        root_prefix = "" if root is None else f"{root}-"
        root_offset = 0 if root is None else 10 * root
        for level in range(depth):
            is_last_level = level == depth - 1
            for node_idx in range(branching ** level):
                # Level 0 has a single node, so its tags omit the node index.
                node_part = "" if level == 0 else f"{node_idx}-"
                for solution_idx in range(branching):
                    tag = f"{root_prefix}{level}-{node_part}{solution_idx}"
                    # Deeper scores may exceed 100; EVALUATIONNODE clamps them
                    # (see test_evaluation_node_score_bounds).
                    score = 80 + 5 * level + root_offset + node_idx * branching + solution_idx
                    # Compute Pr(s) numerically so it is a genuine probability
                    # in [0, 1] that rises with depth.
                    probability = round(min(1.0, 0.8 + 0.05 * level + 0.01 * solution_idx), 2)
                    responses.append(json.dumps({"a": f"answer {tag}", "ra": f"rationale {tag}"}))
                    responses.append(json.dumps({"s": score, "rs": f"reasoning {tag}", "Pr(s)": probability}))
                    if not is_last_level:
                        responses.append(json.dumps({"rpr": f"aggregated rationale {tag}"}))
        return responses

    # ---------------------------------------------------------- initialization

    def test_init_valid_parameters(self, mock_task_config):
        """Test initialization with valid parameters."""
        model = self._make_model(mock_task_config)

        # Check that required parameters are stored
        assert model.task_config == mock_task_config
        assert model.backbone_llm_name == "gemini-pro"
        assert model.num_final_solutions == 3

        # Check EGoT-specific parameters have default values
        assert model.graph_depth == 3
        assert model.num_root_nodes == 3
        assert model.tmax == 0.7
        assert model.threshold_extreme == 70
        assert model.threshold_normal == 50
        assert model.e == 2.718

        # Check intermediate_logs is initialized as empty list
        assert model.intermediate_logs == []

    def test_init_invalid_task_config(self):
        """Test initialization with invalid task_config."""
        with pytest.raises(ValueError, match="task_config must be a TaskConfig object"):
            self._make_model("invalid")

    def test_init_invalid_backbone_llm_name(self, mock_task_config):
        """Test initialization with invalid backbone_llm_name."""
        with pytest.raises(ValueError, match="backbone_llm_name must be a non-empty string"):
            self._make_model(mock_task_config, backbone_llm_name="")

    def test_init_invalid_num_final_solutions(self, mock_task_config):
        """Test initialization with invalid num_final_solutions."""
        with pytest.raises(ValueError, match="num_final_solutions must be a positive integer"):
            self._make_model(mock_task_config, num_final_solutions=0)

    # -------------------------------------------------------------- node calls

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_method_node(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test METHODNODE execution."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.return_value = json.dumps(
            {"ma": "method analysis", "me": "method explanation"}
        )

        model = self._make_model(mock_task_config)
        ma, me = model._method_node("test prompt", 0)

        assert ma == "method analysis"
        assert me == "method explanation"
        self._assert_single_llm_call(mock_llm_client, temperature=0, prompt_fragment="test prompt")

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_answering_node(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test ANSWERINGNODE execution."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.return_value = json.dumps({"a": "answer", "ra": "rationale"})

        model = self._make_model(mock_task_config)
        a, ra = model._answering_node("test prompt", 0.5)

        assert a == "answer"
        assert ra == "rationale"
        self._assert_single_llm_call(mock_llm_client, temperature=0.5, prompt_fragment="test prompt")

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_evaluation_node(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test EVALUATIONNODE execution."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.return_value = json.dumps(
            {"s": 85, "rs": "reasoning", "Pr(s)": 0.8}
        )

        model = self._make_model(mock_task_config)
        s, rs, pr_s = model._evaluation_node("test prompt", 0)

        assert s == 85
        assert rs == "reasoning"
        assert pr_s == 0.8
        self._assert_single_llm_call(mock_llm_client, temperature=0, prompt_fragment="test prompt")

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_evaluation_node_score_bounds(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test EVALUATIONNODE clamps out-of-range score and probability."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.return_value = json.dumps(
            {"s": 150, "rs": "reasoning", "Pr(s)": 1.5}
        )

        model = self._make_model(mock_task_config)
        s, rs, pr_s = model._evaluation_node("test prompt", 0)

        # Check that scores are bounded while the reasoning passes through
        assert s == 100  # Should be clamped to 100
        assert pr_s == 1.0  # Should be clamped to 1.0
        assert rs == "reasoning"

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_aggregate_rationale_node(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test AGGREGATERATIONALENODE execution."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.return_value = json.dumps({"rpr": "aggregated rationale"})

        model = self._make_model(mock_task_config)
        rpr = model._aggregate_rationale_node("test prompt", 0)

        assert rpr == "aggregated rationale"
        self._assert_single_llm_call(mock_llm_client, temperature=0, prompt_fragment="test prompt")

    # ------------------------------------------------------------- temperature

    def test_calculate_temperature(self, mock_task_config):
        """Test temperature calculation with cosine annealing."""
        model = self._make_model(mock_task_config)

        # Test with normal values
        tu = model._calculate_temperature(s=80, pr_s=0.8, nc=1, nt=3)

        # Verify temperature is calculated correctly
        assert isinstance(tu, float)
        assert 0 <= tu <= 1  # Temperature should be bounded

        # Test with edge cases
        tu_min = model._calculate_temperature(s=0, pr_s=0, nc=0, nt=1)
        tu_max = model._calculate_temperature(s=100, pr_s=1, nc=3, nt=3)

        assert 0 <= tu_min <= 1
        assert 0 <= tu_max <= 1

    def test_calculate_temperature_bounds(self, mock_task_config):
        """Test temperature calculation bounds enforcement."""
        model = self._make_model(mock_task_config)

        # Test with extreme values that should be bounded
        tu = model._calculate_temperature(s=200, pr_s=2.0, nc=10, nt=1)

        # Temperature should still be bounded between 0 and 1
        assert 0 <= tu <= 1

    # --------------------------------------------------------------- traversal

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_traverse_graph_branch(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test graph branch traversal."""
        # With 3 solutions per node over graph_depth=3 levels the structure is:
        # Level 0: 1 node -> 3 solutions -> 3 AGGREGATERATIONALENODEs
        # Level 1: 3 nodes -> 9 solutions -> 9 AGGREGATERATIONALENODEs
        # Level 2: 9 nodes -> 27 solutions -> 0 AGGREGATERATIONALENODEs (last level)
        branching, depth = 3, 3
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.side_effect = self._branch_responses(branching, depth)

        model = self._make_model(mock_task_config, num_final_solutions=branching)
        solutions = model._traverse_graph_branch(0)

        # Should have 39 solutions total (3 + 9 + 27 = 39)
        assert len(solutions) == 39

        # Check solution structure
        for solution in solutions:
            missing = [key for key in self._SOLUTION_KEYS if key not in solution]
            assert not missing, f"solution is missing keys: {missing}"
            assert solution['root'] == 1  # root_idx + 1

    # --------------------------------------------------------------------- run

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_run_method(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test the main run method."""
        # For num_final_solutions=2 and graph_depth=3, each of the 3 roots needs:
        # 1 METHODNODE + (2 + 4 + 8) ANSWERINGNODE + (2 + 4 + 8) EVALUATIONNODE
        # + (2 + 4) AGGREGATERATIONALENODE replies.
        branching, depth, num_roots = 2, 3, 3
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.side_effect = [
            response
            for root in range(num_roots)
            for response in self._branch_responses(branching, depth, root=root)
        ]

        # Request only 2 final solutions
        model = self._make_model(mock_task_config, num_final_solutions=branching)
        solution_text, intermediate_logs = model.run()

        # Check return values
        assert isinstance(solution_text, str)
        assert intermediate_logs == []  # EGoT should return empty intermediate logs

        # Check that solution text contains expected content
        assert "Enhanced Graph of Thoughts" in solution_text
        assert "Final Solutions" in solution_text

        # Should have called LLM multiple times (3 roots * multiple nodes per root)
        assert mock_llm_client.call_openai.call_count > 0

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_run_method_no_solutions(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test that run() wraps LLM failures in an 'EGoT algorithm execution failed' error."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.side_effect = Exception("LLM call failed")

        model = self._make_model(mock_task_config)

        with pytest.raises(Exception, match="EGoT algorithm execution failed"):
            model.run()

    # -------------------------------------------------------------- formatting

    def test_format_solutions(self, mock_task_config):
        """Test solution formatting."""
        model = self._make_model(mock_task_config)

        # Test with empty solutions
        formatted = model._format_solutions([])
        assert formatted == "No solutions generated."

        # Test with sample solutions
        solutions = [
            {
                'answer': 'Solution 1',
                'rationale': 'Rationale 1',
                's': 85,
                'rs': 'Reasoning 1',
                'pr_s': 0.8,
                'root': 1,
                'depth': 1,
                'solution_idx': 1
            },
            {
                'answer': 'Solution 2',
                'rationale': 'Rationale 2',
                's': 90,
                'rs': 'Reasoning 2',
                'pr_s': 0.9,
                'root': 2,
                'depth': 2,
                'solution_idx': 2
            }
        ]

        formatted = model._format_solutions(solutions)

        assert "Enhanced Graph of Thoughts" in formatted
        assert "Solution 1" in formatted
        assert "Solution 2" in formatted
        assert "Confidence: 68.00" in formatted  # 85 * 0.8 = 68
        assert "Confidence: 81.00" in formatted  # 90 * 0.9 = 81

    # ---------------------------------------------------------- error handling

    @patch('src.algorithms.egot.main.LLMAPIClient')
    def test_node_methods_error_handling(self, mock_llm_client_class, mock_task_config, mock_llm_client):
        """Test error handling in node methods."""
        mock_llm_client_class.return_value = mock_llm_client
        mock_llm_client.call_openai.side_effect = Exception("LLM API error")

        model = self._make_model(mock_task_config)

        # Test each node method raises appropriate exception
        with pytest.raises(Exception, match="METHODNODE execution failed"):
            model._method_node("test", 0)

        with pytest.raises(Exception, match="ANSWERINGNODE execution failed"):
            model._answering_node("test", 0.5)

        with pytest.raises(Exception, match="EVALUATIONNODE execution failed"):
            model._evaluation_node("test", 0)

        with pytest.raises(Exception, match="AGGREGATERATIONALENODE execution failed"):
            model._aggregate_rationale_node("test", 0)
