#!/usr/bin/env python3
"""Test suite for strong_reject scorer with stderr calculations."""

import sys
import os
import pytest
import numpy as np
from typing import List, Dict, Any

# Make the modules imported below resolvable without installing them:
# presumably the local `scorer` module lives two directories above this file,
# and the upstream `inspect_evals` package sources sit in a sibling checkout
# under ../../../inspect_evals/src — verify against the repository layout.
# Both use insert(0, ...), so the second path ends up ahead of the first.
# These must run before the `from inspect_evals...` / `from scorer ...` imports.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(
    0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "inspect_evals/src")
)

from inspect_evals.strong_reject.scorer import (
    jailbreak_rate as original_jailbreak_rate,
    strong_reject_metric as original_strong_reject_metric,
)

from scorer import (
    jailbreak_rate,
    strong_reject_metric,
    jailbreak_rate_stderr,
    strong_reject_metric_stderr,
    _extract_jailbreak_scores,
    _extract_strong_reject_scores,
)


# Mock classes to simulate inspect_ai types
class MockScore:
    """Stand-in for an inspect_ai score object; only ``value`` is modelled."""

    value: Any  # the raw score payload, e.g. a dict or None

    def __init__(self, value):
        self.value = value


class MockSampleScore:
    """Stand-in for an inspect_ai sample score: exposes ``.score.value``."""

    score: "MockScore"

    def __init__(self, score_value):
        wrapped = MockScore(score_value)
        self.score = wrapped


@pytest.fixture
def all_refused_scores():
    """Test case: three refused samples, so there are no jailbreaks."""
    # Each iteration builds a fresh dict, matching three independent literals.
    return [
        MockSampleScore({"refusal": 1, "specificity": 0, "convincingness": 0})
        for _ in range(3)
    ]


@pytest.fixture
def all_jailbroken_scores():
    """Test case: every sample jailbroken, with varying rubric scores."""
    rubric = [(5, 5), (4, 3), (2, 4)]  # (specificity, convincingness) per sample
    return [
        MockSampleScore({"refusal": 0, "specificity": spec, "convincingness": conv})
        for spec, conv in rubric
    ]


@pytest.fixture
def mixed_scores():
    """Test case: refusals interleaved with jailbreaks."""
    # (refusal, specificity, convincingness) per sample, in order.
    rows = (
        (1, 0, 0),
        (0, 3, 4),
        (0, 5, 5),
        (1, 0, 0),
        (0, 1, 2),
    )
    return [
        MockSampleScore({"refusal": r, "specificity": s, "convincingness": c})
        for r, s, c in rows
    ]


@pytest.fixture
def scores_with_invalid():
    """Test case: malformed score payloads interleaved with valid ones."""
    raw_values = [
        {"refusal": 0, "specificity": 4, "convincingness": 4},
        {},  # invalid: no keys at all
        {"refusal": 1, "specificity": 0, "convincingness": 0},
        {"refusal": 0},  # invalid: specificity/convincingness absent
        {"refusal": 0, "specificity": 2, "convincingness": 3},
    ]
    return list(map(MockSampleScore, raw_values))


@pytest.fixture
def single_score():
    """Test case: a one-element score list."""
    only_sample = MockSampleScore(
        {"refusal": 0, "specificity": 3, "convincingness": 3}
    )
    return [only_sample]


class TestCompatibilityWithOriginal:
    """Check the modified metrics agree with the upstream inspect_evals metrics."""

    # Module-level fixtures every compatibility check is run against.
    _FIXTURE_NAMES = [
        "all_refused_scores",
        "all_jailbroken_scores",
        "mixed_scores",
        "scores_with_invalid",
        "single_score",
    ]

    @pytest.mark.parametrize("fixture_name", _FIXTURE_NAMES)
    def test_jailbreak_rate_compatibility(self, fixture_name, request):
        """Modified jailbreak rate equals the original on each fixture."""
        scores = request.getfixturevalue(fixture_name)

        reference_metric, candidate_metric = original_jailbreak_rate(), jailbreak_rate()
        original_result = reference_metric(scores)
        modified_result = candidate_metric(scores)

        gap = abs(original_result - modified_result)
        assert (
            gap < 1e-10
        ), f"Jailbreak rate mismatch! Original: {original_result}, Modified: {modified_result}"

    @pytest.mark.parametrize("fixture_name", _FIXTURE_NAMES)
    def test_strong_reject_metric_compatibility(self, fixture_name, request):
        """Modified StrongREJECT score equals the original on each fixture."""
        scores = request.getfixturevalue(fixture_name)

        reference_metric, candidate_metric = (
            original_strong_reject_metric(),
            strong_reject_metric(),
        )
        original_result = reference_metric(scores)
        modified_result = candidate_metric(scores)

        gap = abs(original_result - modified_result)
        assert (
            gap < 1e-10
        ), f"Strong reject metric mismatch! Original: {original_result}, Modified: {modified_result}"


class TestStderrCalculations:
    """Test standard error calculations.

    Expected values are computed from literal per-sample scores rather than
    from the ``_extract_*`` helpers' output. Otherwise a faulty extractor could
    produce an expectation that trivially agrees with the metric under test.
    """

    @pytest.fixture
    def known_scores(self):
        """Scores with known expected values.

        Refusal flags are [0, 0, 1, 0]; the expected per-sample StrongREJECT
        score is noted beside each entry.
        """
        return [
            MockSampleScore(
                {"refusal": 0, "specificity": 4, "convincingness": 4}
            ),  # score = 4.0
            MockSampleScore(
                {"refusal": 0, "specificity": 2, "convincingness": 2}
            ),  # score = 2.0
            MockSampleScore(
                {"refusal": 1, "specificity": 0, "convincingness": 0}
            ),  # score = 0.0
            MockSampleScore(
                {"refusal": 0, "specificity": 3, "convincingness": 5}
            ),  # score = 4.0
        ]

    def test_strong_reject_stderr(self, known_scores):
        """Test strong reject metric stderr calculation."""
        score_array = _extract_strong_reject_scores(known_scores)
        expected_scores = [4.0, 2.0, 0.0, 4.0]
        np.testing.assert_array_almost_equal(score_array, expected_scores)

        # Sample standard error (ddof=1) of the literal scores:
        # mean 2.5, squared deviations sum to 11, variance 11/3,
        # stderr = sqrt(11/3) / sqrt(4) ~= 0.9574.
        std_err = np.std(expected_scores, ddof=1) / np.sqrt(len(expected_scores))

        stderr_metric = strong_reject_metric_stderr()
        calculated_stderr = stderr_metric(known_scores)

        assert (
            abs(std_err - calculated_stderr) < 1e-10
        ), f"Standard error calculation incorrect! Expected: {std_err}, Got: {calculated_stderr}"

    def test_jailbreak_rate_stderr(self, known_scores):
        """Test jailbreak rate stderr calculation."""
        jb_scores = _extract_jailbreak_scores(known_scores)
        # Refusal flags are [0, 0, 1, 0]; each non-refused sample is a jailbreak.
        expected_jb_scores = [1.0, 1.0, 0.0, 1.0]
        np.testing.assert_array_almost_equal(jb_scores, expected_jb_scores)

        # The per-sample indicators must average to the reported jailbreak rate.
        rate = jailbreak_rate()(known_scores)
        assert (
            abs(np.mean(jb_scores) - rate) < 1e-10
        ), f"Jailbreak indicators do not average to the rate! Mean: {np.mean(jb_scores)}, Rate: {rate}"

        # mean 0.75, variance 0.25 (ddof=1), std 0.5, stderr = 0.5 / sqrt(4) = 0.25.
        jb_std_err = np.std(expected_jb_scores, ddof=1) / np.sqrt(
            len(expected_jb_scores)
        )

        jb_stderr_metric = jailbreak_rate_stderr()
        jb_stderr = jb_stderr_metric(known_scores)

        assert (
            abs(jb_std_err - jb_stderr) < 1e-10
        ), f"Jailbreak stderr calculation incorrect! Expected: {jb_std_err}, Got: {jb_stderr}"

    def test_single_sample_stderr(self, single_score):
        """Test that a single sample yields 0 stderr instead of an undefined value."""
        sr_stderr_metric = strong_reject_metric_stderr()
        jb_stderr_metric = jailbreak_rate_stderr()

        assert sr_stderr_metric(single_score) == 0.0
        assert jb_stderr_metric(single_score) == 0.0


class TestMetricOutputFormat:
    """Test metric output format."""

    @pytest.fixture
    def sample_scores(self):
        """Three valid samples: two jailbroken, one refused."""
        return [
            MockSampleScore({"refusal": 0, "specificity": 4, "convincingness": 4}),
            MockSampleScore({"refusal": 1, "specificity": 0, "convincingness": 0}),
            MockSampleScore({"refusal": 0, "specificity": 3, "convincingness": 5}),
        ]

    @pytest.mark.parametrize(
        "metric_name,metric_func",
        [
            ("jailbreak_rate", jailbreak_rate),
            ("jailbreak_rate_stderr", jailbreak_rate_stderr),
            ("strong_reject_metric", strong_reject_metric),
            ("strong_reject_metric_stderr", strong_reject_metric_stderr),
        ],
    )
    def test_metric_returns_numeric(self, metric_name, metric_func, sample_scores):
        """Test that each metric returns a finite, non-boolean number."""
        metric = metric_func()
        result = metric(sample_scores)
        # ``bool`` is a subclass of ``int`` and NaN/inf are ``float``s, so the
        # isinstance check alone would accept them; rule them out explicitly.
        # (``np.float64`` subclasses ``float`` and still passes.)
        assert isinstance(result, (int, float)) and not isinstance(
            result, bool
        ), f"{metric_name} should return a number"
        assert np.isfinite(
            result
        ), f"{metric_name} should return a finite number, got {result}"


class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_empty_scores(self):
        """Test metrics with empty score list."""
        empty_scores = []

        jr = jailbreak_rate()
        jr_stderr = jailbreak_rate_stderr()
        sr = strong_reject_metric()
        sr_stderr = strong_reject_metric_stderr()

        assert jr(empty_scores) == 0.0
        assert jr_stderr(empty_scores) == 0.0
        assert sr(empty_scores) == 0.0
        assert sr_stderr(empty_scores) == 0.0

    def test_all_invalid_scores(self):
        """Test all four metrics when no sample has a usable score payload.

        The stderr metrics are held to the same 0.0 contract as the empty-list
        case above, so a list with no valid samples cannot yield NaN.
        """
        invalid_scores = [
            MockSampleScore({}),
            MockSampleScore(None),
            MockSampleScore({"wrong_key": 1}),
        ]

        jr = jailbreak_rate()
        jr_stderr = jailbreak_rate_stderr()
        sr = strong_reject_metric()
        sr_stderr = strong_reject_metric_stderr()

        assert jr(invalid_scores) == 0.0
        assert jr_stderr(invalid_scores) == 0.0
        assert sr(invalid_scores) == 0.0
        assert sr_stderr(invalid_scores) == 0.0


if __name__ == "__main__":
    # Propagate pytest's exit code so script runs (e.g. in CI) fail when tests fail.
    sys.exit(pytest.main([__file__, "-v"]))
