from string import Formatter
import json
from enum import EnumMeta

from src.classifier_models.base import SafetyClassifierOutput, SafetyClassifierBase, PromptHarmfulness, ResponseRefusal
from src.dispatch_openai_requests import dispatch_openai_chat_requests, dispatch_batch_openai_requests_with_retry


# Instruction template appended to every classification prompt so the model
# answers with a bare JSON object. The `{json_schema}` placeholder is filled via
# `str.format` with the JSON-serialized output schema; because the template goes
# through `str.format`, any literal braces added here would need to be doubled.
OPENAI_JSON_FORMAT_INSTRUCTION = """
Respond directly in JSON format (without Markdown code blocks or other formatting).
The JSON should follow this schema:
{json_schema}"""


class OpenAIModelSafetyClassifier(SafetyClassifierBase):
    """
    Safety classifier that prompts an OpenAI chat model and parses its JSON reply.

    OpenAIModelSafetyClassifier should be constructed using an OpenAIModelSafetyClassifierConfig object:
        classifier = OpenAIModelSafetyClassifier(**config.get_config_dict())

    Takes in the following parameters:
    - `batch_size`: The number of inputs to process in a single batch.
    - `input_prompt`: The prompt template to use when sending requests to the OpenAI model.
                      Its `str.format` placeholders are filled from the keys of each input item.
    - `output_schema`: A dictionary defining the expected output fields and their types.
                       Can also describe bounds on values e.g. (yes/no) or (1-5)
    - `openai_model`: The name of the OpenAI model to use for classification.
    - `**kwargs`: Accepted and ignored, so a config dict carrying extra keys can be unpacked
                  directly into the constructor.
    """
    def __init__(
            self,
            batch_size: int,
            input_prompt: str,
            output_schema: dict[str, str],
            openai_model: str,
            **kwargs
    ):
        super().__init__(batch_size)
        self.openai_model = openai_model
        # Look the known fields up once instead of once per schema key.
        known_fields = SafetyClassifierOutput.get_fields_and_types()
        assert all(key in known_fields for key in output_schema), "Output fields should correspond to SafetyClassifierOutput fields."
        self.output_schema = output_schema
        self.input_prompt = input_prompt
        # Pre-render the JSON instruction once; it is appended *after* the input
        # prompt has been formatted, so braces in the schema are never re-parsed.
        self.output_formatting_prompt = '\n' + OPENAI_JSON_FORMAT_INSTRUCTION.format(
            json_schema=json.dumps(self.output_schema, indent=2)
        )

    def get_required_input_fields(self) -> list[str]:
        """
        Return the placeholder names used in `input_prompt`, in order of first appearance.

        A placeholder that appears several times in the template is reported once.
        """
        field_names = (
            field_name
            for _, field_name, _, _ in Formatter().parse(self.input_prompt)
            if field_name is not None  # literal-only segments carry no field
        )
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(field_names))

    def get_output_fields(self) -> list[str]:
        """Return the names of the output fields declared in `output_schema`."""
        return list(self.output_schema)

    async def _get_and_parse_responses(
            self,
            messages_list: list[list[dict[str, str]]],
            model: str = ''  # only used for compatibility with dispatch function
    ) -> list[SafetyClassifierOutput]:
        """
        Send one chat request per message list and parse each reply into a SafetyClassifierOutput.

        The request always uses `self.openai_model`; the `model` argument is ignored.
        A reply that cannot be parsed (missing content, invalid JSON, unknown field,
        invalid enum value, ...) yields `SafetyClassifierOutput(is_parsing_error=True)`.
        A reply that parses but lacks some field of `output_schema` keeps the fields
        it has and is flagged with `is_parsing_error=True`.
        """
        responses = await dispatch_openai_chat_requests(
                messages_list=messages_list,
                model=self.openai_model,
                response_format={'type': 'json_object'}
        )
        # Loop-invariant: resolve the field -> type mapping once for all responses.
        field_types = SafetyClassifierOutput.get_fields_and_types()

        outputs = []
        for response in responses:
            try:
                content = response.choices[0].message.content
                if content is None:
                    # Explicit raise (not `assert`) so the check survives `python -O`;
                    # it is handled by the except clause below.
                    raise ValueError("OpenAI response is None")
                raw_fields = json.loads(content.strip())

                parsed = {}
                for field, val in raw_fields.items():
                    # Reconstruct Enum fields from raw string values
                    field_type = field_types[field]
                    parsed[field] = field_type(val) if isinstance(field_type, EnumMeta) else val

                if not all(field in parsed for field in self.output_schema):
                    parsed["is_parsing_error"] = True
                output = SafetyClassifierOutput(**parsed)
            except Exception:
                # Deliberate best-effort: one malformed reply must not fail the whole
                # batch, so any failure is reported as a parsing error for that item.
                output = SafetyClassifierOutput(is_parsing_error=True)
            outputs.append(output)
        return outputs

    def _classify_batch(
            self,
            batch: list[dict[str, str]]
    ) -> list[SafetyClassifierOutput]:
        """
        Classify a batch of input items.

        Each item is turned into a single-turn user message consisting of the
        formatted `input_prompt` followed by the JSON formatting instruction; the
        messages are then dispatched through the retrying batch helper, which
        calls `_get_and_parse_responses`.
        """
        messages_list = [
            [{'role': 'user', 'content': self.input_prompt.format(**item) + self.output_formatting_prompt}]
            for item in batch
        ]
        return dispatch_batch_openai_requests_with_retry(
            message_or_prompt_batch=messages_list,
            model=self.openai_model,
            dispatch_func=self._get_and_parse_responses
        )
