from abc import ABC

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.classifier_models.base import SafetyClassifierBase, ConversationTurn, Role, SafetyClassifierOutput, PromptHarmfulness, ResponseHarmfulness
from src.templates.llama_guard import LlamaGuardAgentType, build_prompt, LLAMA_GUARD_CATEGORY, create_conversation


class LlamaGuard2(SafetyClassifierBase):
    """
    LlamaGuard2 can classify both prompt and response harmfulness. If only "prompt" fields are in the inputs, it will only do prompt_harmfulness.
    If both "prompt" and "response" fields are in the inputs, it will classify both prompt_harmfulness and response_harmfulness.
    """
    def __init__(self, batch_size: int = 4, **kwargs):
        super().__init__(batch_size)
        self.load_model()

    def load_model(self):
        self.model_name = "meta-llama/Meta-Llama-Guard-2-8B"
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.safe_loc = torch.tensor([self.tokenizer.encode("safe")[-1]]).unsqueeze(-1).to("cuda")
        self.unsafe_loc = torch.tensor([self.tokenizer.encode("unsafe")[-1]]).unsqueeze(-1).to("cuda")

    def get_required_input_fields(self) -> list[str]:
        return ["prompt"]

    def get_optional_input_fields(self) -> list[str]:
        return ["response"]

    def get_output_fields(self) -> list[str]:
        return ["prompt_harmfulness", "response_harmfulness"]

    @torch.inference_mode()
    def _llama_guard_classify(self, conversations: list[list[dict[str, str]]]) -> tuple[list[float], list[float]]:
        formatted_inputs: list[str] = [self.tokenizer.apply_chat_template(conversation, tokenize=False) for conversation in conversations]  # type: ignore
        assert all(isinstance(item, str) for item in formatted_inputs)
        encoded_inputs = self.tokenizer.batch_encode_plus(
            formatted_inputs,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=4000
        ).to('cuda')

        logits = self.model(**encoded_inputs).logits
        last_token_logits = logits[torch.arange(logits.shape[0]), -1]
        safe_logits = torch.gather(
            last_token_logits, -1,
            self.safe_loc.repeat(last_token_logits.shape[0], 1)
        ).squeeze(-1)
        unsafe_logits = torch.gather(
            last_token_logits, -1,
            self.unsafe_loc.repeat(last_token_logits.shape[0], 1)
        ).squeeze(-1)
        chosen_logits = torch.stack([safe_logits, unsafe_logits], dim=-1)
        probs = torch.nn.functional.softmax(chosen_logits, dim=-1)
        safe_probs = probs[:, 0].detach().cpu().tolist()
        unsafe_probs = probs[:, 1].detach().cpu().tolist()
        return safe_probs, unsafe_probs


    @torch.inference_mode()
    def _classify_batch(self, items: list[dict[str, str]]) -> list[SafetyClassifierOutput]:
        formatted_inputs = []
        assert all("response" in item for item in items) or not any("response" in item for item in items), "response must be included in all or none of items"
        for item in items:
            item_input = []
            item_input.append({"role": "user", "content": item["prompt"]})
            if "response" in item:
                item_input.append({"role": "assistant", "content": item["response"]})
            formatted_inputs.append(item_input)

        outputs: list[SafetyClassifierOutput] = []
        prompt_safe_probs, prompt_unsafe_probs = self._llama_guard_classify([[conversation[0]] for conversation in formatted_inputs])  # pass only user message
        if "response" in items[0]:
            response_safe_probs, response_unsafe_probs = self._llama_guard_classify(formatted_inputs)
        else:
            response_safe_probs, response_unsafe_probs = [None] * len(formatted_inputs), [None] * len(formatted_inputs)

        for i, (prompt_safe_prob, prompt_unsafe_prob, response_safe_prob, response_unsafe_prob) in enumerate(zip(prompt_safe_probs, prompt_unsafe_probs, response_safe_probs, response_unsafe_probs)):
            prompt_harmfulness = PromptHarmfulness.HARMFUL if prompt_unsafe_prob > prompt_safe_prob else PromptHarmfulness.UNHARMFUL
            metadata = {
                "prompt_safe_prob": prompt_safe_prob,
                "prompt_unsafe_prob": prompt_unsafe_prob
            }

            if response_safe_prob is not None and response_unsafe_prob is not None:
                response_harmfulness = ResponseHarmfulness.HARMFUL if response_unsafe_prob > response_safe_prob else ResponseHarmfulness.UNHARMFUL
                metadata["response_safe_prob"] = response_safe_prob
                metadata["response_unsafe_prob"] = response_unsafe_prob
            else:
                response_harmfulness = None

            output = SafetyClassifierOutput(
                prompt_harmfulness=prompt_harmfulness,
                response_harmfulness=response_harmfulness,
                metadata=metadata
            )
            outputs.append(output)
        return outputs


class LlamaGuardBase(SafetyClassifierBase, ABC):
    def __init__(self, batch_size: int = 4, **kwargs):
        super().__init__(batch_size)
        self.load_model()

    def load_model(self):
        self.model_name = "llamas-community/LlamaGuard-7b"
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.safe_loc = torch.tensor([self.tokenizer.encode("safe")[-1]]).unsqueeze(-1).to("cuda")
        self.unsafe_loc = torch.tensor([self.tokenizer.encode("unsafe")[-1]]).unsqueeze(-1).to("cuda")

    @staticmethod
    def build_input_prompts(dialogues: list[list[ConversationTurn]]) -> list[str]:
        inputs = []

        for dialogue in dialogues:
            formatted_prompt = build_prompt(
                LlamaGuardAgentType.AGENT if dialogue[-1].role == Role.ASSISTANT else LlamaGuardAgentType.USER,
                LLAMA_GUARD_CATEGORY,
                create_conversation([turn.content for turn in dialogue]))
            inputs.append(formatted_prompt)
        return inputs


class LlamaGuardUserRequest(LlamaGuardBase):
    def __init__(self, batch_size: int = 2, **kwargs):
        super().__init__(batch_size=batch_size)

    def get_required_input_fields(self) -> list[str]:
        return ["prompt"]

    def get_output_fields(self) -> list[str]:
        return ["prompt_harmfulness"]

    @torch.inference_mode()
    def _classify_batch(self, items: list[dict[str, str]]) -> list[SafetyClassifierOutput]:
        dialogues = [[ConversationTurn(Role.USER, item["prompt"])] for item in items]

        batch_inputs = self.build_input_prompts(dialogues)
        outputs: list[SafetyClassifierOutput] = []
        encoded_inputs = self.tokenizer.batch_encode_plus(
            batch_inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4000,
        ).to("cuda")
        logits = self.model(**encoded_inputs).logits
        last_token_logits = logits[torch.arange(logits.shape[0]), -1]
        safe_logits = torch.gather(
            last_token_logits, -1,
            self.safe_loc.repeat(last_token_logits.shape[0], 1)
        ).squeeze(-1)
        unsafe_logits = torch.gather(
            last_token_logits, -1,
            self.unsafe_loc.repeat(last_token_logits.shape[0], 1)
        ).squeeze(-1)
        chosen_logits = torch.stack([safe_logits, unsafe_logits], dim=-1)
        probs = torch.nn.functional.softmax(chosen_logits, dim=-1)
        safe_probs = probs[:, 0].detach().cpu().tolist()
        unsafe_probs = probs[:, 1].detach().cpu().tolist()

        for i, (safe_prob, unsafe_prob) in enumerate(zip(safe_probs, unsafe_probs)):
            output = SafetyClassifierOutput(
                prompt_harmfulness=PromptHarmfulness.HARMFUL if unsafe_prob > safe_prob else PromptHarmfulness.UNHARMFUL,
                metadata={
                    "safe_prob": safe_prob,
                    "unsafe_prob": unsafe_prob
                }
            )
            outputs.append(output)
        return outputs


class LlamaGuardModelResponse(LlamaGuardBase):
    def __init__(self, batch_size: int = 2, **kwargs):
        super().__init__(batch_size=batch_size)

    def get_required_input_fields(self) -> list[str]:
        return ["prompt", "response"]

    def get_output_fields(self) -> list[str]:
        return ["response_harmfulness"]

    @torch.inference_mode()
    def _classify_batch(self, items: list[dict[str, str]]) -> list[SafetyClassifierOutput]:
        assert "prompt" in items[0] and "response" in items[0], "This model requires both prompt and response."
        dialogues = [[
            ConversationTurn(Role.USER, item["prompt"]),
            ConversationTurn(Role.ASSISTANT, item["response"])
        ] for item in items]
        for dialogue in dialogues:
            assert dialogue[-1].role == Role.ASSISTANT, "The last turn in each dialogue should be from the assistant."

        batch_inputs = self.build_input_prompts(dialogues)
        outputs = []
        encoded_inputs = self.tokenizer.batch_encode_plus(
            batch_inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4000,
        ).to("cuda")
        logits = self.model(**encoded_inputs).logits
        last_token_logits = logits[torch.arange(logits.shape[0]), -1]
        safe_logits = torch.gather(
            last_token_logits, -1,
            self.safe_loc.repeat(last_token_logits.shape[0], 1)
        ).squeeze(-1)
        unsafe_logits = torch.gather(
            last_token_logits, -1,
            self.unsafe_loc.repeat(last_token_logits.shape[0], 1)
        ).squeeze(-1)
        chosen_logits = torch.stack([safe_logits, unsafe_logits], dim=-1)
        probs = torch.nn.functional.softmax(chosen_logits, dim=-1)
        safe_probs = probs[:, 0].detach().cpu().tolist()
        unsafe_probs = probs[:, 1].detach().cpu().tolist()

        for i, (safe_prob, unsafe_prob) in enumerate(zip(safe_probs, unsafe_probs)):
            output = SafetyClassifierOutput(
                response_harmfulness=ResponseHarmfulness.HARMFUL if unsafe_prob > safe_prob else ResponseHarmfulness.UNHARMFUL,
                metadata={
                    "safe_prob": safe_prob,
                    "unsafe_prob": unsafe_prob
                }
            )
            outputs.append(output)
        return outputs
