from src.classifier_models.base import SafetyClassifierOutput, PromptHarmfulness, ResponseHarmfulness, ResponseRefusal


def test_get_fields_and_types():
    """Verify the field-name -> type mapping reported by SafetyClassifierOutput.

    Each entry in the table below is looked up by name in the mapping returned
    by ``SafetyClassifierOutput.get_fields_and_types()`` and compared against
    the type this test expects for it.
    """
    reported = SafetyClassifierOutput.get_fields_and_types()

    # Field name -> type the test expects; checked in the listed order.
    expected_types = {
        "prompt_harmfulness": PromptHarmfulness,
        "response_refusal": ResponseRefusal,
        "response_harmfulness": ResponseHarmfulness,
        "response_satisfies_intent": bool,
        "metadata": dict,
        "reason": str,
        "is_parsing_error": bool,
    }

    for field_name, expected_type in expected_types.items():
        assert reported[field_name] == expected_type


def test_asdict():
    """Verify that ``asdict()`` yields the expected plain dictionary.

    A fully populated ``SafetyClassifierOutput`` is built, and the result of
    ``asdict()`` is compared against a literal dictionary in which the three
    enum-valued fields are expected to appear as the strings ``"harmful"``,
    ``"compliance"`` and ``"sensitive"``; the remaining fields are expected
    to come back unchanged.
    """
    classifier_output = SafetyClassifierOutput(
        prompt_harmfulness=PromptHarmfulness.HARMFUL,
        response_refusal=ResponseRefusal.COMPLIANCE,
        response_harmfulness=ResponseHarmfulness.SENSITIVE,
        response_satisfies_intent=False,
        metadata={"safe_prob": 0.6, "unsafe_prob": 0.4},
        reason="Testing reason",
        is_parsing_error=False,
    )

    serialized = classifier_output.asdict()

    # Built from fresh literals (not the input objects) so the comparison
    # does not depend on object identity with the constructor arguments.
    expected = {
        "prompt_harmfulness": "harmful",
        "response_refusal": "compliance",
        "response_harmfulness": "sensitive",
        "response_satisfies_intent": False,
        "metadata": {"safe_prob": 0.6, "unsafe_prob": 0.4},
        "reason": "Testing reason",
        "is_parsing_error": False,
    }
    assert serialized == expected
