from openai import AsyncOpenAI
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential

from src.classifier_models.base import (
    SafetyClassifierBase,
    SafetyClassifierOutput,
    PromptHarmfulness,
    ResponseHarmfulness
)


class OpenAIModerationAPIClassifier(SafetyClassifierBase):
    """Safety classifier that delegates to the OpenAI moderation endpoint.

    Prompts and/or responses of a batch are sent to
    ``self.aclient.moderations.create``.  The ``flagged`` bit of each result
    becomes the harmfulness label, and the per-category flags and scores are
    stored in the output's ``metadata`` dict.
    """

    def __init__(self, batch_size: int = 8, **kwargs):
        """Create the classifier and its asynchronous OpenAI client.

        Args:
            batch_size: Forwarded to ``SafetyClassifierBase.__init__``.
            **kwargs: Forwarded to ``SafetyClassifierBase.__init__``.
        """
        super().__init__(batch_size, **kwargs)
        self.aclient = AsyncOpenAI()
        # Event loop on which every request of this instance is executed;
        # created on first use by _run_on_private_loop.
        self._moderation_loop = None

    def get_required_input_fields(self) -> list[str]:
        """Return the mandatory input fields (none: either field suffices)."""
        return []

    def get_optional_input_fields(self) -> list[str]:
        """Return the input fields that are classified when present."""
        return ["prompt", "response"]

    def get_output_fields(self):
        """Return the names of the output fields this classifier may set."""
        return ["prompt_harmfulness", "response_harmfulness"]

    def _run_on_private_loop(self, coro):
        """Run ``coro`` to completion on this instance's own event loop.

        ``asyncio.run`` creates a new event loop per call and closes it
        afterwards.  Reusing the long-lived ``self.aclient`` across such
        short-lived loops is the usual cause of "Event loop is closed"
        errors with async HTTP clients, so all requests are driven by one
        loop that lives as long as the classifier.

        NOTE(review): like ``asyncio.run``, this cannot be called from a
        thread that is already running an event loop.

        Args:
            coro: The awaitable to execute.

        Returns:
            Whatever ``coro`` returns.
        """
        # getattr keeps this working even if __init__ was bypassed.
        loop = getattr(self, "_moderation_loop", None)
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            self._moderation_loop = loop
        return loop.run_until_complete(coro)

    @retry(
        stop=stop_after_attempt(20),
        wait=wait_exponential(multiplier=10, min=30, max=120),
        reraise=True,
    )
    def _request_moderation(self, texts: list[str]):
        """Send one moderation request, retrying on failure.

        Only the network call sits behind the retry policy (up to 20
        attempts, exponential wait clamped to 30-120; the last exception is
        re-raised).  A failing request therefore never causes an already
        successful sibling request to be repeated.

        Args:
            texts: The strings to moderate, one result is expected per entry.

        Returns:
            The response object of ``moderations.create``.
        """
        # A fresh coroutine is built on every attempt; a coroutine object
        # cannot be awaited twice.
        return self._run_on_private_loop(
            self.aclient.moderations.create(input=texts)
        )

    @staticmethod
    def _store_category_metadata(output: SafetyClassifierOutput, moderation, prefix: str) -> None:
        """Copy category flags and scores of ``moderation`` into ``output.metadata``.

        Keys are ``"<prefix>_categories"`` and ``"<prefix>_category_scores"``.
        """
        assert output.metadata is not None
        output.metadata[f"{prefix}_categories"] = moderation.categories.model_dump(mode='json')
        output.metadata[f"{prefix}_category_scores"] = moderation.category_scores.model_dump(mode='json')

    def _classify_batch_with_retry(
            self,
            batch: list[dict[str, str]]
    ) -> list[SafetyClassifierOutput]:
        """Classify ``batch`` through the moderation API, retrying failed requests.

        Which fields are classified is decided from the first item of the
        batch.  Responses are skipped when every response is the empty
        string (dummy values).

        Args:
            batch: Items with a ``"prompt"`` and/or ``"response"`` entry.

        Returns:
            One ``SafetyClassifierOutput`` per item, in input order.  An
            empty batch yields an empty list.
        """
        results = [SafetyClassifierOutput(metadata={}) for _ in batch]
        if not batch:
            return results

        if "prompt" in batch[0]:
            # Inputs are collected outside the retried call so that a missing
            # key fails immediately instead of being retried.
            prompts = [item["prompt"] for item in batch]
            moderation = self._request_moderation(prompts)
            for output, verdict in zip(results, moderation.results):
                output.prompt_harmfulness = (
                    PromptHarmfulness.HARMFUL if verdict.flagged else PromptHarmfulness.UNHARMFUL
                )
                self._store_category_metadata(output, verdict, "prompt")

        # Skip response if dummy values were given
        if "response" in batch[0] and any(item["response"] != "" for item in batch):
            responses = [item["response"] for item in batch]
            moderation = self._request_moderation(responses)
            for output, verdict in zip(results, moderation.results):
                output.response_harmfulness = (
                    ResponseHarmfulness.HARMFUL if verdict.flagged else ResponseHarmfulness.UNHARMFUL
                )
                self._store_category_metadata(output, verdict, "response")

        return results

    def _classify_batch(
            self,
            batch: list[dict[str, str]]
    ) -> list[SafetyClassifierOutput]:
        """Validate ``batch`` and classify it.

        Args:
            batch: Items with a ``"prompt"`` and/or ``"response"`` entry.

        Returns:
            One ``SafetyClassifierOutput`` per item; ``[]`` for an empty batch.

        Raises:
            AssertionError: If the first item has neither ``"prompt"`` nor
                ``"response"``.
        """
        if not batch:
            return []
        if "prompt" not in batch[0] and "response" not in batch[0]:
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped under `python -O`; the exception type is kept.
            raise AssertionError("API classifier requires 'prompt' and/or 'response' field.")
        return self._classify_batch_with_retry(batch)
