from pydantic import BaseModel
from typing import Any
from enum import Enum, EnumMeta
import yaml
import os

from src.classifier_models.base import SafetyClassifierOutput


class OpenAIModelSafetyClassifierConfig(BaseModel):
    """Configuration for a safety classifier driven by an OpenAI model.

    The config holds an instruction template plus the output fields the model
    is asked to produce. ``get_config_dict()`` renders the template's
    ``{output_description}`` placeholder and derives an output schema from the
    field types declared on ``SafetyClassifierOutput``.
    """
    # Exposed unchanged under the "batch_size" key of get_config_dict().
    batch_size: int = 16
    # Exposed unchanged under the "openai_model" key of get_config_dict().
    openai_model: str = 'gpt-4-0125-preview'
    # Instruction to the model, including "{output_description}" where the
    # specification of output fields and descriptions should be filled in.
    # Do not directly include the field descriptions in the instruction.
    instruction: str
    # Keys are output field names, and values are the plaintext descriptions of the field.
    # Field names must correspond to SafetyClassifierOutput fields.
    output_fields: dict[str, str]

    def model_post_init(self, __context: Any):
        """Validate the instruction template and the configured output fields.

        Raises:
            ValueError: If ``instruction`` does not contain the literal
                ``{output_description}`` placeholder, or if a key of
                ``output_fields`` is not one of the fields returned by
                ``SafetyClassifierOutput.get_fields_and_types()``.
        """
        if '{output_description}' not in self.instruction:
            raise ValueError("Instruction must contain '{output_description}' to fill in output fields")
        # Look the valid fields up once rather than once per configured field.
        valid_fields = SafetyClassifierOutput.get_fields_and_types()
        for field in self.output_fields:
            if field not in valid_fields:
                raise ValueError(f"{field} is not a valid output field")

    def get_output_schema(self):
        """Return a mapping from each output field name to a type description.

        Enum-typed fields are described as ``type1/type2 (value1/value2/value3)``:
        the type names of the enum's values followed by the values themselves.
        Any other field is described by the name of its type.
        """
        fields_and_types = SafetyClassifierOutput.get_fields_and_types()
        output_schema = {}
        for field in self.output_fields:
            field_type = fields_and_types[field]

            if isinstance(field_type, EnumMeta):
                field_options = [t.value for t in field_type if isinstance(t, Enum)]
                # De-duplicate the value types while keeping first-seen order.
                # A set would make the order of the type names arbitrary (and
                # potentially different from run to run) for mixed-type enums.
                field_type_options = dict.fromkeys(type(field_option) for field_option in field_options)
                # field_desc: type1/type2 (value1/value2/value3)
                type_names = '/'.join(str(t.__name__) for t in field_type_options)
                option_values = '/'.join(str(o) for o in field_options)
                field_type_desc = f"{type_names} ({option_values})"
            else:
                # Not every annotation object carries __name__; fall back to
                # its string form instead of failing with AttributeError.
                field_type_desc = str(getattr(field_type, '__name__', field_type))

            output_schema[field] = field_type_desc

        return output_schema

    def get_config_dict(self):
        """Return the rendered configuration as a plain dict.

        The result has the keys ``batch_size``, ``openai_model``,
        ``input_prompt`` (the instruction with ``{output_description}``
        replaced by one ``name: description`` line per output field) and
        ``output_schema`` (see ``get_output_schema``).
        """
        output_description = '\n'.join(f"{field_name}: {field_desc}" for field_name, field_desc in self.output_fields.items())

        class MissingKeyNoopDict(dict):
            # Used to just replace {output_description} but leave other formatting variables
            # like {prompt} as-is to be substituted later
            def __missing__(self, key):
                return '{' + key + '}'

        # NOTE(review): format_map also collapses escaped braces ("{{" -> "{"),
        # so literal braces in the instruction do not survive as escapes for a
        # later formatting pass - confirm this matches how input_prompt is used.
        return {
            "batch_size": self.batch_size,
            "openai_model": self.openai_model,
            "input_prompt": self.instruction.format_map(MissingKeyNoopDict(output_description=output_description)),
            "output_schema": self.get_output_schema()
        }

    @classmethod
    def from_yaml(cls, config_path: str) -> 'OpenAIModelSafetyClassifierConfig':
        """
        Load a config from a YAML file.

        Config yaml path should be either absolute path or starting from src/classifier_models/ directory.
        A relative path is resolved against the directory containing this module.

        Raises:
            ValueError: If the YAML document is empty or its top level is not
                a mapping, or if the loaded values fail model validation.
        """
        # os.path.isabs also recognises absolute paths on non-POSIX platforms,
        # unlike a plain check for a leading '/'.
        if not os.path.isabs(config_path):
            config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_path)
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        # safe_load returns None for an empty document and may return a list or
        # scalar; report that clearly instead of failing on the ** expansion.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Config file {config_path} must contain a YAML mapping at the top level, "
                f"got {type(config_dict).__name__}"
            )
        return cls(**config_dict)
