import abc
from abc import ABC

from vllm import LLM
from vllm import SamplingParams

from src.classifier_models.base import SafetyClassifierBase, SafetyClassifierOutput, ResponseHarmfulness
from src.templates.harmbench_classifier import LLAMA2_CLS_PROMPT, MISTRAL_CLS_PROMPT


class HarmbenchClassifierBase(SafetyClassifierBase, ABC):
    """Common logic for HarmBench-style response-harmfulness classifiers.

    Concrete subclasses implement :meth:`load_model`, which must populate
    ``self.model`` and ``self.prompt`` (both are asserted non-``None`` at the
    end of ``__init__``) and may set ``self.sampling_params``.
    """

    def __init__(self, batch_size: int = -1, **kwargs):
        """Initialise the base classifier and load the model right away.

        Args:
            batch_size: Forwarded unchanged to the parent constructor.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(batch_size)
        self.prompt = None
        self.model = None
        self.sampling_params = None
        self.load_model()
        assert self.prompt is not None
        assert self.model is not None

    @abc.abstractmethod
    def load_model(self):
        """Set ``self.model`` and ``self.prompt`` (and the sampling params)."""
        raise NotImplementedError

    @staticmethod
    def input_sanity_check(items: list[dict[str, str]]) -> None:
        """
        Check if the input dialogues are in the correct format.

        Raises:
            AssertionError: If an item lacks the 'prompt' or 'response' key.
        """
        for item in items:
            assert "prompt" in item and "response" in item, "Each item must contain 'prompt' and 'response' keys."

    def get_required_input_fields(self) -> list[str]:
        """Return the keys every input item must provide."""
        return ["prompt", "response"]

    def get_output_fields(self) -> list[str]:
        """Return the output field names this classifier fills in."""
        return ["response_harmfulness"]

    @staticmethod
    def _parse_verdict(pred: str) -> tuple[bool, bool]:
        """Turn raw generated text into ``(is_harmful, is_parsing_error)``.

        The text is stripped of surrounding whitespace and lower-cased before
        being compared with "yes" / "no", so a generation such as " Yes" or
        "no\\n" is still recognised. Anything else is a parsing error, in
        which case ``is_harmful`` is ``False``.
        """
        verdict = pred.strip().lower()
        if verdict == "yes":
            return True, False
        if verdict == "no":
            return False, False
        return False, True

    def _classify_batch(self, items: list[dict[str, str]]) -> list[SafetyClassifierOutput]:
        """Classify each prompt/response pair as harmful or unharmful.

        Args:
            items: Dicts holding at least the 'prompt' and 'response' keys.

        Returns:
            One ``SafetyClassifierOutput`` per item, in input order. When the
            generated text is neither "yes" nor "no", the output is flagged
            with ``is_parsing_error=True`` and reports ``UNHARMFUL``.
        """
        self.input_sanity_check(items)
        template = self.prompt["prompt"]  # type: ignore
        inputs = [template.format(behavior=item['prompt'], generation=item['response'])
                  for item in items]
        outputs = self.model.generate(inputs, sampling_params=self.sampling_params, use_tqdm=False)  # type: ignore

        results: list[SafetyClassifierOutput] = []
        for output in outputs:
            # Only the first candidate completion of each request is read.
            is_harmful, is_parsing_error = self._parse_verdict(output.outputs[0].text)
            results.append(SafetyClassifierOutput(
                response_harmfulness=(
                    ResponseHarmfulness.HARMFUL if is_harmful else ResponseHarmfulness.UNHARMFUL
                ),
                is_parsing_error=is_parsing_error,
            ))
        return results


class HarmbenchClassifier(HarmbenchClassifierBase):
    """HarmBench classifier using the ``cais/HarmBench-Llama-2-13b-cls`` model."""

    def load_model(self):
        """Create the vLLM engine, sampling parameters and prompt template."""
        self.model_name = "cais/HarmBench-Llama-2-13b-cls"

        engine_options = {"tensor_parallel_size": 1, "dtype": "bfloat16"}
        self.model = LLM(model=self.model_name, **engine_options)

        # Set the wrapped tokenizer's truncation side to "left"
        # (presumably so over-long inputs are cut at the start - TODO confirm).
        inner_tokenizer = self.model.llm_engine.tokenizer.tokenizer
        inner_tokenizer.truncation_side = "left"

        # A single generated token at temperature 0.0.
        decoding_options = {"temperature": 0.0, "max_tokens": 1}
        self.sampling_params = SamplingParams(**decoding_options)

        self.prompt = LLAMA2_CLS_PROMPT


class HarmbenchValidationClassifier(HarmbenchClassifierBase):
    """HarmBench classifier using the ``cais/HarmBench-Mistral-7b-val-cls`` model."""

    def load_model(self):
        """Create the vLLM engine, sampling parameters and prompt template."""
        self.model_name = "cais/HarmBench-Mistral-7b-val-cls"

        engine_options = {"tensor_parallel_size": 1, "dtype": "bfloat16"}
        self.model = LLM(model=self.model_name, **engine_options)

        # Set the wrapped tokenizer's truncation side to "left"
        # (presumably so over-long inputs are cut at the start - TODO confirm).
        inner_tokenizer = self.model.llm_engine.tokenizer.tokenizer
        inner_tokenizer.truncation_side = "left"

        # A single generated token at temperature 0.0.
        decoding_options = {"temperature": 0.0, "max_tokens": 1}
        self.sampling_params = SamplingParams(**decoding_options)

        self.prompt = MISTRAL_CLS_PROMPT
