from typing import Any

from src.classifier_models.aegis import AegisLlamaGuardPermissive, AegisLlamaGuardDefensive
from src.classifier_models.api_safety_classifiers import OpenAIModerationAPIClassifier
from src.classifier_models.base import SafetyClassifierBase, LegacySafetyClassifierBase, ConversationTurn, Role
from src.classifier_models.beaverdam import BeaverDam
from src.classifier_models.gpt_harmful_classifier import (
    GPTHarmfulDetectorRainbowTeaming,
    GPTHarmfulDetectorHarmbench,
    GPTRubricSafetyEvaluator,
    GPTHarmfulDetectorV0UserRequestOnly,
    GPTHarmfulDetectorV1FullTurn,
    GPTHarmfulDetectorV1UserRequestOnly,
    GPTHarmfulDetectorV2FullTurn,
    GPTNuancedHarmfulDetectorV0,
)
from src.classifier_models.gpt_refusal_classifier import GPTRefusalDetectorV0, GPTRefusalEvaluator, \
    GPTRefusalQuickEvaluator, GPTRefuseThenComplyRefusalDetector, GPTRefuseThenComplyHarmfulnessDetector
from src.classifier_models.gpt_refusal_classifier import GPTRefusalDetectorV0_2
from src.classifier_models.harmbench_classifier import HarmbenchClassifier, HarmbenchValidationClassifier
from src.classifier_models.keyword_based_refusal_classifiers import KeywordBasedRefusalClassifier
from src.classifier_models.librai_longformer import LongformerHarmClassifier, LongformerActionClassifier
from src.classifier_models.llama_guard import LlamaGuardUserRequest, LlamaGuardModelResponse, LlamaGuard2
from src.classifier_models.md_judge import MDJudgeResponseHarmClassifier
from src.classifier_models.off_topic_classifier import GPTOffTopicDetector
from src.classifier_models.open_instruct_gpt_classifier import OpenInstructGPTNuancedRefusalDetector, \
    OpenInstructGPTJailbreakTriggerClassifier
from src.classifier_models.openai_model_safety_classifier import OpenAIModelSafetyClassifier
from src.classifier_models.openai_model_safety_classifier_configs import OpenAIModelSafetyClassifierConfig
from src.classifier_models.sorrybench import SorryBenchClassifier
from src.classifier_models.wildguard import WildGuard


def load_classifier_model(model_name: str, **kwargs) -> SafetyClassifierBase:
    """
    Instantiates the safety classifier registered under `model_name`.

    Most names map directly onto a classifier class that is constructed with `**kwargs`.
    The exceptions are:

    - Names listed in `named_openai_classifiers` are built from the YAML file
      `config_openai/<model_name>.yaml`; `kwargs` is not used for them.
    - "OpenAIModelSafetyClassifier" builds an arbitrary GPT-based classifier from the first
      of the following that is present in `kwargs`:
        1. `config`: an `OpenAIModelSafetyClassifierConfig` instance (other kwargs are ignored),
        2. `config_path`: path of a YAML config file (other kwargs are ignored),
        3. `instruction` and `output_fields` (all kwargs are forwarded to the config class).
    - "OpenAIModeration" and "KeywordBasedRefusalClassifier" are constructed without arguments.

    Args:
        model_name: Name of the classifier to load.
        **kwargs: Arguments forwarded to the classifier (see the exceptions above).

    Returns:
        The instantiated classifier.

    Raises:
        ValueError: If `model_name` is unknown, or if "OpenAIModelSafetyClassifier" is requested
            without any of the accepted configuration arguments.
        TypeError: If `config` is given but is not an `OpenAIModelSafetyClassifierConfig`.
    """
    named_openai_classifiers = {
        "GPTHarmfulDetectorV0UserRequestOnly",
        "GPTRefusalDetectorV0JSON",
    }

    if model_name == "WildGuard":
        return WildGuard(**kwargs)
    elif model_name == "LlamaGuardUserRequest":
        return LlamaGuardUserRequest(**kwargs)
    elif model_name == "LlamaGuardModelResponse":
        return LlamaGuardModelResponse(**kwargs)
    elif model_name == "HarmbenchClassifier":
        return HarmbenchClassifier(**kwargs)
    elif model_name == "HarmbenchValidationClassifier":
        return HarmbenchValidationClassifier(**kwargs)
    elif model_name == "AegisLlamaGuardPermissive":
        return AegisLlamaGuardPermissive(**kwargs)
    elif model_name == "AegisLlamaGuardDefensive":
        return AegisLlamaGuardDefensive(**kwargs)
    elif model_name == "LlamaGuard2":
        return LlamaGuard2(**kwargs)
    elif model_name == "BeaverDam":
        return BeaverDam(**kwargs)
    elif model_name == "sorrybench":
        return SorryBenchClassifier(**kwargs)

    # GPT-based classifiers
    elif model_name in named_openai_classifiers:
        # The config file name is derived from the model name; the path is relative
        # (NOTE(review): presumably resolved against the working directory - confirm in from_yaml).
        config = OpenAIModelSafetyClassifierConfig.from_yaml(f'config_openai/{model_name}.yaml').get_config_dict()
        return OpenAIModelSafetyClassifier(**config)
    elif model_name == "OpenInstructGPTNuancedRefusalDetector":
        return OpenInstructGPTNuancedRefusalDetector(**kwargs)
    elif model_name == "OpenInstructGPTJailbreakTriggerClassifier":
        return OpenInstructGPTJailbreakTriggerClassifier(**kwargs)
    elif model_name == "GPTRefusalDetectorV0":
        return GPTRefusalDetectorV0(**kwargs)
    elif model_name == "GPTRefusalDetectorV0_2":
        return GPTRefusalDetectorV0_2(**kwargs)
    elif model_name == "GPTRefuseThenComplyRefusalDetector":
        return GPTRefuseThenComplyRefusalDetector(**kwargs)
    elif model_name == "GPTRefuseThenComplyHarmfulnessDetector":
        return GPTRefuseThenComplyHarmfulnessDetector(**kwargs)

    # Arbitrary GPT-based classifiers
    elif model_name == "OpenAIModelSafetyClassifier":
        if "config" in kwargs:
            config = kwargs["config"]
            # Explicit check instead of `assert`, which is removed when Python runs with -O.
            if not isinstance(config, OpenAIModelSafetyClassifierConfig):
                raise TypeError(
                    f"'config' must be an OpenAIModelSafetyClassifierConfig, got {type(config).__name__}"
                )
        elif "config_path" in kwargs:
            config = OpenAIModelSafetyClassifierConfig.from_yaml(kwargs['config_path'])
        elif all(k in kwargs for k in ("instruction", "output_fields")):
            config = OpenAIModelSafetyClassifierConfig(**kwargs)
        else:
            raise ValueError(
                "Must provide either 'instruction' and 'output_fields', 'config', or 'config_path'"
            )
        return OpenAIModelSafetyClassifier(**config.get_config_dict())

    # API Classifiers
    elif model_name == "OpenAIModeration":
        return OpenAIModerationAPIClassifier()

    # Miscellaneous external models
    elif model_name == "LongformerHarmClassifier":
        return LongformerHarmClassifier(**kwargs)
    elif model_name == "LongformerActionClassifier":
        return LongformerActionClassifier(**kwargs)
    elif model_name == "MDJudgeResponseHarmClassifier":
        return MDJudgeResponseHarmClassifier(**kwargs)

    # Keyword-based classifiers
    elif model_name == "KeywordBasedRefusalClassifier":
        return KeywordBasedRefusalClassifier()

    else:
        raise ValueError(f"Model {model_name} not found.")


def load_legacy_classifier_model(model_name: str, **kwargs) -> LegacySafetyClassifierBase:
    """
    Instantiates the legacy safety classifier registered under `model_name`.

    Only the `batch_size` keyword argument is consumed (default 2); it is passed positionally to
    every classifier except the two rubric evaluators, which are constructed with their
    `rubric_type` only. Any other keyword argument is ignored.

    Raises:
        ValueError: If `model_name` does not match any known legacy classifier.
    """
    size = kwargs.pop("batch_size", 2)

    # Each entry is a zero-argument builder so that a class is only touched when it is selected.
    builders = {
        # Refusal models
        "GPTRefusalQuickEvaluator": lambda: GPTRefusalQuickEvaluator(size),
        "DoAnythingNowGPTRefusalDetector": lambda: GPTRefusalEvaluator(size),

        # Harmful models, request only
        "GPTHarmfulDetectorV0UserRequestOnly": lambda: GPTHarmfulDetectorV0UserRequestOnly(size),
        "GPTHarmfulDetectorV1UserRequestOnly": lambda: GPTHarmfulDetectorV1UserRequestOnly(size),
        "GPTNuancedHarmfulDetectorV0": lambda: GPTNuancedHarmfulDetectorV0(size),

        # Harmful models, full turn
        "GPTHarmfulDetectorRainbowTeaming": lambda: GPTHarmfulDetectorRainbowTeaming(size),
        "GPTHarmfulDetectorHarmbench": lambda: GPTHarmfulDetectorHarmbench(size),
        "DoNotAnswerGPTRubricEvaluator": lambda: GPTRubricSafetyEvaluator(rubric_type="do_not_answer"),
        "DoAnythingNowGPTRubricEvaluator": lambda: GPTRubricSafetyEvaluator(rubric_type="do_anything_now"),
        "GPTHarmfulDetectorV1FullTurn": lambda: GPTHarmfulDetectorV1FullTurn(size),
        "GPTHarmfulDetectorV2FullTurn": lambda: GPTHarmfulDetectorV2FullTurn(size),

        # Off-topic models
        "GPTOffTopicDetector": lambda: GPTOffTopicDetector(size),
    }

    build = builders.get(model_name)
    if build is None:
        raise ValueError(f"Model {model_name} not found.")
    return build()


def convert_data_items_to_legacy_classifier_input(
        model: LegacySafetyClassifierBase,
        data: list[dict[str, Any]]
) -> list[list[ConversationTurn]]:
    """
    Converts a list of data items to a list of ConversationTurn lists that can be used as input
    to the specified classifier model.

    Every item yields a user turn built from `item["prompt"]`. If the model's classes include
    "is_assistant_response_harmful", an assistant turn built from `item["response"]` is appended
    (an empty string is used when the item has no "response" key).

    Args:
        model: Classifier whose `get_classes()` decides whether the response turn is included.
        data: Items to convert; each must contain a "prompt" key.

    Returns:
        One list of conversation turns per item, in the same order as `data`.

    Raises:
        KeyError: If an item has no "prompt" key.
    """
    # get_classes() takes no per-item input, so evaluate the membership test once
    # instead of once per data item.
    include_response = "is_assistant_response_harmful" in model.get_classes()

    formatted = []
    for item in data:
        turns = [ConversationTurn(content=item["prompt"], role=Role.USER)]
        if include_response:
            turns.append(ConversationTurn(content=item.get("response", ""), role=Role.ASSISTANT))
        formatted.append(turns)
    return formatted
