"""
Tests for Error Tracking Infrastructure

This module tests the error tracking functionality for LLM invocation failures.
"""

import json
import time

import pytest

from src.utils.error_tracking import ErrorTracker, LLMInvocationError


class TestLLMInvocationError:
    """Unit tests for the ``LLMInvocationError`` record type."""

    def test_error_creation(self):
        """Every constructor argument is stored verbatim on the instance."""
        now = time.time()

        def build_kwargs():
            # Built fresh on each call so the expected values compared below are
            # distinct objects from the ones handed to the constructor.
            return dict(
                timestamp=now,
                error_message="Test error message",
                error_type="ValueError",
                input_messages=[{"role": "user", "content": "Test question"}],
                model_config={"name": "test-model", "version": "1.0"},
                retry_attempts=3,
                task_id="task_123",
            )

        instance = LLMInvocationError(**build_kwargs())

        # Check each field in declaration order against its expected value.
        for field_name, expected_value in build_kwargs().items():
            assert getattr(instance, field_name) == expected_value

    def test_error_creation_without_task_id(self):
        """Omitting ``task_id`` leaves the attribute set to ``None``."""
        instance = LLMInvocationError(
            timestamp=time.time(),
            error_message="Test error",
            error_type="RuntimeError",
            input_messages=[],
            model_config={},
            retry_attempts=0,
        )
        assert instance.task_id is None


class TestErrorTracker:
    """Tests for ``ErrorTracker``: recording, counting, summarising, serialising."""

    @pytest.fixture
    def error_tracker(self):
        """Create a fresh ErrorTracker instance."""
        return ErrorTracker()

    @pytest.fixture
    def sample_error(self):
        """Create a sample exception."""
        return ValueError("Sample error message")

    @pytest.fixture
    def sample_input_messages(self):
        """Create sample input messages (a system prompt followed by a user turn)."""
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is 2 + 2?"},
        ]

    @pytest.fixture
    def sample_model_config(self):
        """Create sample model configuration."""
        return {"name": "OpenAI", "version_name": "gpt-4", "temperature": 0.0}

    def test_empty_tracker(self, error_tracker):
        """An untouched tracker reports no errors and a perfect success rate."""
        assert error_tracker.get_error_count() == 0
        assert error_tracker.get_success_rate(10) == 1.0
        assert error_tracker.to_dict() == []

        summary = error_tracker.get_summary()
        assert summary["failed_llm_calls"] == 0
        assert summary["error_types"] == {}
        assert summary["average_retry_attempts"] == 0
        # NOTE(review): the non-empty tests below read the error list from
        # summary["errors"], while this one reads summary["sample_errors"].
        # One of the two keys is probably stale; confirm against
        # ErrorTracker.get_summary() and align both tests on a single key.
        assert summary["sample_errors"] == []

    def test_record_single_error(
        self, error_tracker, sample_error, sample_input_messages, sample_model_config
    ):
        """Recording one error updates the count, success rate and dict export."""
        error_tracker.record_error(
            error=sample_error,
            input_messages=sample_input_messages,
            model_config=sample_model_config,
            retry_attempts=3,
            task_id="task_1",
        )

        assert error_tracker.get_error_count() == 1
        # 4 successes out of 5 total; approx avoids brittle float equality.
        assert error_tracker.get_success_rate(5) == pytest.approx(0.8)

        errors_dict = error_tracker.to_dict()
        assert len(errors_dict) == 1
        assert errors_dict[0]["error_message"] == "Sample error message"
        assert errors_dict[0]["error_type"] == "ValueError"
        assert errors_dict[0]["task_id"] == "task_1"
        assert errors_dict[0]["retry_attempts"] == 3

    def test_record_multiple_errors(
        self, error_tracker, sample_input_messages, sample_model_config
    ):
        """Errors of different types are counted and grouped by exception type."""
        # Record different types of errors
        error_tracker.record_error(
            error=ValueError("First error"),
            input_messages=sample_input_messages,
            model_config=sample_model_config,
            retry_attempts=2,
            task_id="task_1",
        )

        error_tracker.record_error(
            error=RuntimeError("Second error"),
            input_messages=sample_input_messages,
            model_config=sample_model_config,
            retry_attempts=5,
            task_id="task_2",
        )

        error_tracker.record_error(
            error=ValueError("Third error"),
            input_messages=sample_input_messages,
            model_config=sample_model_config,
            retry_attempts=1,
            task_id="task_3",
        )

        assert error_tracker.get_error_count() == 3
        # 7 successes out of 10 total; approx avoids brittle float equality.
        assert error_tracker.get_success_rate(10) == pytest.approx(0.7)

        summary = error_tracker.get_summary()
        assert summary["failed_llm_calls"] == 3
        assert summary["error_types"]["ValueError"] == 2
        assert summary["error_types"]["RuntimeError"] == 1
        assert len(summary["errors"]) == 3

    def test_summary_errors(
        self, error_tracker, sample_input_messages, sample_model_config
    ):
        """Each entry in the summary's error list carries the recorded details."""
        error_tracker.record_error(
            error=ValueError("Test error"),
            input_messages=sample_input_messages,
            model_config=sample_model_config,
            retry_attempts=2,
            task_id="task_1",
        )

        summary = error_tracker.get_summary()
        errors = summary["errors"]

        assert len(errors) == 1
        error = errors[0]

        assert "timestamp" in error
        assert error["error_type"] == "ValueError"
        assert error["error_message"] == "Test error"
        assert error["task_id"] == "task_1"
        assert error["retry_attempts"] == 2
        assert error["input_messages"][1]["content"] == "What is 2 + 2?"

    def test_input_preview_user_message(self, error_tracker, sample_model_config):
        """The preview prefers the user message over the system prompt."""
        input_messages = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "This is a user question"},
        ]

        error_tracker.record_error(
            error=ValueError("Test"),
            input_messages=input_messages,
            model_config=sample_model_config,
            retry_attempts=0,
        )
        # Make the recording step observable rather than a silent no-op.
        assert error_tracker.get_error_count() == 1

        preview = error_tracker._get_input_preview(input_messages)
        assert preview == "This is a user question"

    def test_input_preview_truncation(self, error_tracker, sample_model_config):
        """Long message content is cut to 200 characters plus an ellipsis."""
        long_content = "x" * 250  # Longer than 200 character limit
        input_messages = [{"role": "user", "content": long_content}]

        error_tracker.record_error(
            error=ValueError("Test"),
            input_messages=input_messages,
            model_config=sample_model_config,
            retry_attempts=0,
        )
        assert error_tracker.get_error_count() == 1

        preview = error_tracker._get_input_preview(input_messages)
        assert len(preview) == 203  # 200 chars + "..."
        assert preview.endswith("...")
        # The kept part must be the leading content, not arbitrary text.
        assert preview.startswith("x" * 200)

    def test_input_preview_no_messages(self, error_tracker, sample_model_config):
        """An empty message list yields a placeholder preview."""
        error_tracker.record_error(
            error=ValueError("Test"),
            input_messages=[],
            model_config=sample_model_config,
            retry_attempts=0,
        )
        # Recording with no input messages must still register the error.
        assert error_tracker.get_error_count() == 1

        preview = error_tracker._get_input_preview([])
        assert preview == "No input messages"

    def test_input_preview_fallback(self, error_tracker, sample_model_config):
        """Without a user message, the preview falls back to the first message."""
        input_messages = [{"role": "assistant", "content": "Assistant message"}]

        error_tracker.record_error(
            error=ValueError("Test"),
            input_messages=input_messages,
            model_config=sample_model_config,
            retry_attempts=0,
        )
        assert error_tracker.get_error_count() == 1

        preview = error_tracker._get_input_preview(input_messages)
        assert preview == "Assistant message"

    def test_success_rate_edge_cases(self, error_tracker):
        """A zero total-attempt count reports a success rate of 1.0."""
        # Test with zero total attempts
        assert error_tracker.get_success_rate(0) == 1.0

        # Test with errors but zero total (shouldn't happen in practice)
        error_tracker.record_error(
            error=ValueError("Test"),
            input_messages=[],
            model_config={},
            retry_attempts=0,
        )
        assert error_tracker.get_error_count() == 1
        assert error_tracker.get_success_rate(0) == 1.0

    def test_serialization(
        self, error_tracker, sample_error, sample_input_messages, sample_model_config
    ):
        """``to_dict()`` output survives a JSON round trip."""
        error_tracker.record_error(
            error=sample_error,
            input_messages=sample_input_messages,
            model_config=sample_model_config,
            retry_attempts=3,
            task_id="task_1",
        )

        # Test that to_dict() output is JSON serializable
        errors_dict = error_tracker.to_dict()
        json_str = json.dumps(errors_dict)

        # Verify it can be loaded back with its key fields intact
        loaded_data = json.loads(json_str)
        assert len(loaded_data) == 1
        assert loaded_data[0]["error_message"] == "Sample error message"
        assert loaded_data[0]["error_type"] == "ValueError"
        assert loaded_data[0]["task_id"] == "task_1"
