import abc
import gc
import os
from collections import Counter
from copy import deepcopy
from typing import Optional

import torch
import vllm
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM
from vllm import RequestOutput

from src.classifier_models.base import (
    SafetyClassifierBase,
    PromptHarmfulness,
    ResponseHarmfulness,
)
from src.classifier_models.loader import load_classifier_model
from src.generation_utils import delete_vllm_model, subprocess_inference_with_vllm
from src.generation_utils import (
    load_hf_lm_and_tokenizer,
    generate_completions,
    query_openai_chat_model,
)
from src.templates.single_turn import get_template


class GeneratorModelBase(abc.ABC):
    """
    Abstract base class for text generators.

    Holds the model identifier, the resolved input template and the optional
    filter-classifier settings. Subclasses implement ``load_model`` and
    ``generate_completions``.
    """

    def __init__(
            self,
            model_name_or_path: str,
            model_input_template_path_or_name: str,
            hf_revision: str | None = None,
            batch_size: int | None = None,
            filter_classifier_name: str | None = None,
            filter_model_override_path: str | None = None,
    ):
        """
        Initialize the model.

        Args:
            model_name_or_path: Name or path of the model; stored as-is.
            model_input_template_path_or_name: If it ends with ".txt", the file is read and
                its content becomes the input template. Otherwise it is passed to
                ``get_template`` as ``chat_template`` and the "prompt" entry of the result is used.
            hf_revision: Revision identifier; stored as-is.
            batch_size: Batch size; stored as-is.
            filter_classifier_name: Name of a filter classifier; stored as-is.
            filter_model_override_path: Path overriding the filter model location; stored as-is.
        """
        self.model_name_or_path = model_name_or_path
        self.hf_revision = hf_revision
        self.model_input_template_path_or_name = model_input_template_path_or_name
        if model_input_template_path_or_name.endswith(".txt"):
            # Pin the encoding so the template decodes identically regardless of
            # the locale of the machine running the evaluation.
            with open(model_input_template_path_or_name, "r", encoding="utf-8") as f:
                self.model_input_template = f.read()
        else:
            self.model_input_template = get_template(
                model_name_or_path=model_name_or_path,
                chat_template=model_input_template_path_or_name,
            )["prompt"]

        print(f"*** Model input template: {self.model_input_template}")
        self.batch_size = batch_size
        self.filter_classifier_name = filter_classifier_name
        self.filter_model_override_path = filter_model_override_path

    @abc.abstractmethod
    def load_model(self, model_name_or_path: str) -> tuple[object, object]:
        """
        Load the model from the given name or path.

        Args:
            model_name_or_path: Name or path of the model to load.
                If it is a name, it will be loaded from the Hugging Face model hub, or using the API-based models (e.g. GPT-4).
                If it is a path, it will be loaded from the local path.

        Returns:
            Tuple of model and tokenizer.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def generate_completions(
            self,
            inputs: list[dict[str, str]],
            max_new_tokens: int,
            temperature: float,
            top_p: float,
            repetition_penalty: float | None = None,
            n: int = 1,
            stop: Optional[str | list[str]] = None,
            return_full_outputs: bool = False,
    ) -> list[str] | list[RequestOutput]:
        """
        Generate completions for the given input data.

        Args:
            inputs: List of input data, which will be used to generate the completions.
            max_new_tokens: Maximum number of tokens to generate per completion.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            repetition_penalty: Optional repetition penalty.
            n: Number of completions requested per input.
            stop: Optional stop string or list of stop strings.
            return_full_outputs: Whether full output objects are requested instead of strings.

        Returns:
            List of completions.
        """
        raise NotImplementedError


class HfGeneratorModel(GeneratorModelBase):
    """Generator that loads a Hugging Face causal LM in the current process."""

    def load_model(
            self, model_name_or_path: str
    ) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
        """
        Load the model and tokenizer through ``load_hf_lm_and_tokenizer``.

        Args:
            model_name_or_path: Name or path forwarded to the loader.

        Returns:
            Tuple of model and tokenizer, requested at ``self.hf_revision``.
        """
        model, tokenizer = load_hf_lm_and_tokenizer(model_name_or_path, revision=self.hf_revision)
        return model, tokenizer

    def generate_completions(
            self,
            inputs: list[dict[str, str]],
            max_new_tokens: int,
            temperature: float,
            top_p: float,
            repetition_penalty: float | None = None,
            n: int = 1,
            stop: Optional[str | list[str]] = None,
            return_full_outputs: bool = False,
    ) -> list[str]:
        """
        Generate completions for the given inputs, optionally filtering harmful prompts.

        The model is loaded for this call and released afterwards. When a filter
        classifier is configured, every completion belonging to a prompt classified
        as harmful is replaced by a fixed refusal message.

        Args:
            inputs: List of dicts used to fill the input template via ``str.format``.
            max_new_tokens: Maximum number of new tokens per completion.
            temperature: Sampling temperature; sampling is enabled only when it is > 0.
            top_p: Nucleus-sampling probability mass.
            repetition_penalty: Forwarded unchanged to the generation helper.
            n: Number of completions per input. For n > 1 each prompt is repeated
                n times consecutively before generation.
            stop: A stop string or a list of stop strings.
            return_full_outputs: Not used by this implementation.

        Returns:
            The outputs of the generation helper (n consecutive entries per input when n > 1).
        """
        assert self.batch_size is not None, "Batch size must be provided for HF model."
        formatted_inputs = [self.model_input_template.format(**d) for d in inputs]
        model, tokenizer = self.load_model(model_name_or_path=self.model_name_or_path)
        if stop:
            # A bare string must be treated as ONE stop sequence; iterating it
            # directly would produce one stop sequence per character.
            stop_sequences = [stop] if isinstance(stop, str) else stop
            stop_id_sequences = [
                tokenizer.encode(x, add_special_tokens=False) for x in stop_sequences
            ]
        else:
            stop_id_sequences = None

        samples_per_input = n if n > 1 else 1
        if n > 1:
            formatted_inputs = [prompt for prompt in formatted_inputs for _ in range(n)]
        outputs = generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=formatted_inputs,
            model_name_or_path=self.model_name_or_path,
            batch_size=self.batch_size,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            stop_id_sequences=stop_id_sequences,
        )
        del model
        # Collect garbage first so the allocator cache can actually be returned.
        gc.collect()
        torch.cuda.empty_cache()

        if self.filter_classifier_name is not None:
            filter_model = load_classifier_model(
                self.filter_classifier_name,
                local_model_path=self.filter_model_override_path,
            )
            assert "prompt_harmfulness" in filter_model.get_output_fields()
            cls_inputs = deepcopy(inputs)
            for item in cls_inputs:
                # Accept "instruction"-keyed inputs as the vLLM generator does.
                if "prompt" not in item and "instruction" in item:
                    item["prompt"] = item["instruction"]
                assert "prompt" in item
                item["response"] = ""
            # Classify the prepared copies (with the empty "response" field),
            # not the raw inputs.
            cls_outputs = filter_model.classify(cls_inputs)
            for i, cls_output in enumerate(cls_outputs):
                if cls_output.prompt_harmfulness == PromptHarmfulness.HARMFUL:
                    # Each input owns `samples_per_input` consecutive outputs.
                    # Assumes the generation helper returns one output per prompt, in order.
                    first = i * samples_per_input
                    for j in range(first, first + samples_per_input):
                        outputs[j] = "I'm sorry, I can't help with that."
            if isinstance(getattr(filter_model, "model", None), vllm.LLM):
                delete_vllm_model(filter_model.model)
            del filter_model
            gc.collect()
            torch.cuda.empty_cache()

        return outputs

    @torch.no_grad()
    def get_next_word_predictions(self,
                                  model: AutoModelForCausalLM,
                                  tokenizer: AutoTokenizer,
                                  prompts: list[str],  # list of prompts
                                  choices: list[str],
                                  add_special_tokens=True,
                                  ) -> tuple[list[int], list[list[float]]]:
        """
        Score a fixed set of answer choices as the next token after each prompt.

        Each choice is represented by the last token id of ``" " + choice``.
        Probabilities come from a softmax over the whole vocabulary at the last
        sequence position and are then restricted to the candidate tokens (they
        are not renormalized).

        Args:
            model: Model called with ``input_ids`` and ``attention_mask``.
            tokenizer: Tokenizer used for prompts and choices.
            prompts: Prompts to score, processed in batches of ``self.batch_size``.
            choices: Candidate answer strings.
            add_special_tokens: Forwarded to the tokenizer when encoding prompts.

        Returns:
            Tuple of (index into ``choices`` of the most probable candidate per prompt,
            per-prompt list of candidate probabilities).
        """
        assert self.batch_size is not None, "Batch size must be provided for HF model."
        predictions, probs = [], []
        candidate_token_ids = [tokenizer.encode(" " + answer_choice, add_special_tokens=False)[-1]
                               for answer_choice in choices]

        # The context manager closes the progress bar even if a batch fails.
        with tqdm(total=len(prompts), desc="Getting Predictions") as progress:
            for i in range(0, len(prompts), self.batch_size):
                batch_prompts = prompts[i: i + self.batch_size]
                tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt",
                                              add_special_tokens=add_special_tokens)
                batch_input_ids = tokenized_prompts.input_ids.cuda()
                attention_mask = tokenized_prompts.attention_mask.cuda()

                # NOTE(review): reads the logits at the last position of the padded batch,
                # so this presumably relies on a left-padding tokenizer — confirm.
                batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
                batch_probs = torch.softmax(batch_logits, dim=-1)
                batch_probs = batch_probs[:, candidate_token_ids]
                batch_prediction_indices = torch.argmax(batch_probs, dim=-1)
                predictions += batch_prediction_indices.tolist()
                probs += batch_probs.tolist()

                progress.update(len(batch_prompts))

        assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
        return predictions, probs


class VllmGeneratorModel(GeneratorModelBase):
    """Generator that runs inference through ``subprocess_inference_with_vllm``."""

    def __init__(
            self,
            model_name_or_path: str,
            model_input_template_path_or_name: str,
            hf_revision: str = "main",
            delete_model_after_job_finished: bool = False,
            filter_classifier_name: str | None = None,
            filter_model_override_path: str | None = None,
    ):
        """
        Initialize the vLLM-backed generator.

        Args:
            model_name_or_path: Name or path of the model.
            model_input_template_path_or_name: Template file path (".txt") or template name.
            hf_revision: Revision forwarded to the inference helper.
            delete_model_after_job_finished: Must be True for
                ``delete_model_from_file_system`` to be allowed to run.
            filter_classifier_name: Optional name of a filter classifier.
            filter_model_override_path: Optional local path for the filter classifier.
        """
        super().__init__(
            model_name_or_path=model_name_or_path,
            hf_revision=hf_revision,
            model_input_template_path_or_name=model_input_template_path_or_name,
            filter_classifier_name=filter_classifier_name,
            filter_model_override_path=filter_model_override_path,
        )
        self.delete_model_after_job_finished = delete_model_after_job_finished

    def load_model(self, model_name_or_path: str) -> LLM:
        # NOTE(review): body is a bare placeholder here; generate_completions does
        # not call it and delegates to subprocess_inference_with_vllm instead.
        ...

    def generate_completions(
            self,
            inputs: list[dict[str, str]],
            max_new_tokens: int,
            temperature: float,
            top_p: float,
            repetition_penalty: float = 1.0,
            n: int = 1,
            stop: Optional[str | list[str]] = None,
            return_full_outputs: bool = False,
    ) -> list[str] | list[RequestOutput]:
        """
        Generate completions for the given inputs, optionally filtering harmful items.

        Args:
            inputs: List of dicts used to fill the input template via ``str.format``.
                When a filter classifier is configured, each dict must have "instruction".
            max_new_tokens: Forwarded to the helper as ``max_tokens``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            repetition_penalty: Forwarded unchanged to the helper.
            n: Number of completions per prompt, forwarded to the helper.
            stop: Stop string(s), forwarded to the helper.
            return_full_outputs: Forwarded to the helper; cannot be combined with filtering.

        Returns:
            The outputs of the inference helper, with refusals substituted for
            items the filter classifier marks as harmful.

        Raises:
            NotImplementedError: If ``return_full_outputs`` is set while a filter
                classifier is configured.
        """
        # Reject the unsupported combination up front instead of after a full
        # (and expensive) inference run.
        if return_full_outputs and self.filter_classifier_name is not None:
            raise NotImplementedError("Filtering is not supported for full outputs.")

        formatted_inputs = [self.model_input_template.format(**d) for d in inputs]
        outputs = subprocess_inference_with_vllm(
            do_chat_formatting=False,
            prompts=formatted_inputs,
            model_name_or_path=self.model_name_or_path,
            hf_revision=self.hf_revision,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            use_tqdm=True,
            n=n,
            stop=stop,
            return_full_outputs=return_full_outputs,
            repetition_penalty=repetition_penalty,
        )

        if self.filter_classifier_name is not None:
            filter_model = load_classifier_model(
                self.filter_classifier_name,
                local_model_path=self.filter_model_override_path,
            )
            output_fields = filter_model.get_output_fields()
            checks_prompt = "prompt_harmfulness" in output_fields
            checks_response = "response_harmfulness" in output_fields
            assert checks_prompt or checks_response
            cls_inputs = deepcopy(inputs)
            for item in cls_inputs:
                assert "instruction" in item
                item["prompt"] = item["instruction"]
                # NOTE(review): the response is classified as an empty string, not
                # the generated text — confirm this is intended for response filters.
                item["response"] = ""
            cls_outputs = filter_model.classify(cls_inputs)
            # NOTE(review): outputs are indexed one-per-input; the layout of
            # `outputs` for n > 1 is not visible here — confirm.
            for i, cls_output in enumerate(cls_outputs):
                harmful_prompt = (
                        checks_prompt
                        and cls_output.prompt_harmfulness == PromptHarmfulness.HARMFUL
                )
                harmful_response = (
                        checks_response
                        and cls_output.response_harmfulness == ResponseHarmfulness.HARMFUL
                )
                if harmful_prompt or harmful_response:
                    outputs[i] = "I'm sorry, I can't help with that."
            if isinstance(getattr(filter_model, "model", None), vllm.LLM):
                delete_vllm_model(filter_model.model)
            del filter_model
            torch.cuda.empty_cache()
            gc.collect()

        return outputs

    def delete_model_from_file_system(self):
        """
        Remove ``self.model_name_or_path`` from the local file system.

        Only allowed when ``delete_model_after_job_finished`` is set. Removal is
        best-effort: a missing path is ignored and failures are reported, not raised.
        """
        # Local import keeps this block self-contained; avoids building a shell
        # command from the path (which breaks on spaces / shell metacharacters).
        import shutil

        # remove the model from the cache
        assert self.delete_model_after_job_finished
        print(f"Removing the model from the cache: {self.model_name_or_path}")
        path = os.path.expanduser(self.model_name_or_path)
        try:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            elif os.path.lexists(path):
                # Plain files and symlinks (the link itself, not its target).
                os.remove(path)
        except OSError as err:
            print(f"Failed to remove {path}: {err}")


class APIGeneratorModel(GeneratorModelBase):
    """Generator that queries a hosted API model named ``<provider>:<model_identifier>``."""

    def load_model(self, model_name_or_path: str) -> tuple[str, str]:
        """
        Split the model name into provider and model identifier.

        Args:
            model_name_or_path: String of the form ``<provider>:<model_identifier>``.
                Only the first ":" is used as the separator.

        Returns:
            Tuple of (provider, identifier).
        """
        assert (
                ":" in model_name_or_path
        ), "API model requires format <provider>:<model_identifier>"
        provider, identifier = model_name_or_path.split(":", 1)
        return provider, identifier

    def generate_completions(
            self,
            inputs: list[dict[str, str]],
            max_new_tokens: int,
            temperature: float,
            top_p: float,
            return_full_outputs: bool = False,
    ) -> list[str]:
        """
        Generate completions through the provider API, optionally filtering harmful prompts.

        Args:
            inputs: List of dicts; each must contain "instruction".
            max_new_tokens: Forwarded to the API helper as ``max_tokens``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            return_full_outputs: Not used by this implementation.

        Returns:
            The outputs of the API helper, with refusals substituted for prompts
            the filter classifier marks as harmful.

        Raises:
            NotImplementedError: If the provider is not "openai" (case-insensitive).
        """
        assert self.batch_size is not None, "Batch size must be provided for API model."
        provider, identifier = self.load_model(
            model_name_or_path=self.model_name_or_path
        )
        if provider.lower() == "openai":
            # One request dict per input (a dict comprehension inside a list
            # would collapse everything into a single entry).
            formatted_inputs = [{"prompt": item["instruction"]} for item in inputs]
            outputs = query_openai_chat_model(
                identifier,
                formatted_inputs,
                batch_size=self.batch_size,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
        else:
            raise NotImplementedError("Only OpenAI API models are supported for now.")

        if self.filter_classifier_name is not None:
            filter_model = load_classifier_model(
                self.filter_classifier_name,
                local_model_path=self.filter_model_override_path,
            )
            assert "prompt_harmfulness" in filter_model.get_output_fields()
            cls_inputs = deepcopy(inputs)
            for item in cls_inputs:
                # Inputs are keyed by "instruction" (see the request building above);
                # keep an explicit "prompt" if the caller already supplied one.
                item.setdefault("prompt", item["instruction"])
                assert "prompt" in item
                item["response"] = ""
            # Classify the prepared copies, not the raw inputs.
            cls_outputs = filter_model.classify(cls_inputs)
            for i, cls_output in enumerate(cls_outputs):
                if cls_output.prompt_harmfulness == PromptHarmfulness.HARMFUL:
                    outputs[i] = "I'm sorry, I can't help with that."
            if isinstance(getattr(filter_model, "model", None), vllm.LLM):
                delete_vllm_model(filter_model.model)
            del filter_model
            torch.cuda.empty_cache()
            gc.collect()

        return outputs


class TaskBase(abc.ABC):
    """Abstract evaluation task: loads its data on construction and checks evaluation results."""

    def __init__(self, **kwargs):
        """
        Load the task data and tag every record with an 'id'.

        Extra keyword arguments are accepted and not used here.
        """
        loaded = self.load()
        self.data = self.put_id_on_data(loaded)
        assert len(self.data) > 0, "No data loaded for the task."
        every_record_has_id = all("id" in record for record in self.data)
        assert every_record_has_id, "All data must have an 'id' field."

    def put_id_on_data(self, data: list[dict]) -> list[dict]:
        """
        Return copies of the records with an 'id' key placed first.

        The default id is "<class name>/<position>"; a record that already
        carries its own 'id' keeps that value.

        Args:
            data: Records to tag.

        Returns:
            New list of tagged records.
        """
        tagged = []
        for position, record in enumerate(data):
            entry = {"id": f"{self.__class__.__name__}/{position}"}
            # Keys from the record win, including a pre-existing 'id'.
            entry.update(record)
            tagged.append(entry)
        return tagged

    @abc.abstractmethod
    def required_input_fields(self) -> list[str]:
        """
        Names of the input fields this task needs.

        Returns:
            List of field names.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def load(self) -> list[dict]:
        """
        Produce the raw task records; the constructor stores them in self.data.

        Returns:
            List of records.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _evaluate(
            self, model: GeneratorModelBase | SafetyClassifierBase
    ) -> tuple[dict, list[dict]]:
        """
        Task-specific evaluation, wrapped by ``evaluate``.

        Args:
            model: Model under evaluation.

        Returns:
            Aggregate report dict and the per-instance result dicts.
        """
        raise NotImplementedError

    def evaluate(
            self, model: GeneratorModelBase | SafetyClassifierBase
    ) -> tuple[dict, list[dict]]:
        """
        Run ``_evaluate`` and check that per-instance results carry unique ids.

        Args:
            model: Model under evaluation.

        Returns:
            Aggregate report dict and the per-instance result dicts.
        """
        report, individual_results = self._evaluate(model)
        ids_present = all("id" in row for row in individual_results)
        assert ids_present, "All individual results must have an 'id' field."
        distinct_ids = {row["id"] for row in individual_results}
        assert len(distinct_ids) == len(
            individual_results
        ), "All individual results must have unique 'id' field."
        return report, individual_results


class ClassificationTaskBase(TaskBase, abc.ABC):
    """Base for tasks that evaluate a safety classifier against a labeled field."""

    @abc.abstractmethod
    def required_output_fields(self) -> list[list[str]]:
        """
        Alternative sets of classifier output fields accepted by the task.

        Returns:
            List of field-name lists; a model providing all fields of any one list qualifies.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def label_field(self) -> str:
        """
        Name of the data field holding the label.

        Returns:
            Label field name.
        """
        raise NotImplementedError

    def display_data_stats(self):
        """Print the label field name followed by how often each label value occurs."""
        label = self.label_field
        print(f"Field: {label}")
        tally = Counter(record[label] for record in self.data)
        for value, count in tally.items():
            print(f"{value}: {count}")

    def validate_if_model_is_compatible(self, model: SafetyClassifierBase) -> bool:
        """
        Check that the classifier accepts the task inputs and yields the needed outputs.

        Args:
            model: Classifier to check.

        Returns:
            True when both the input and the output side are compatible;
            an incompatibility fails an assertion instead.
        """
        # Inputs: every required field must be among the model's possible inputs.
        input_compatible = True
        for field in self.required_input_fields():
            if field not in model.get_possible_input_fields():
                input_compatible = False
                break
        assert input_compatible, (
            f"Model input fields {model.get_possible_input_fields()} "
            f"are not compatible with the task input fields {self.required_input_fields()}"
        )

        # Outputs: one fully covered field set is enough.
        output_compatible = False
        for fieldset in self.required_output_fields():
            if all(field in model.get_output_fields() for field in fieldset):
                output_compatible = True
                break
        assert output_compatible, (
            f"Model output fields {model.get_output_fields()} "
            f"are not compatible with the task output fields {self.required_output_fields()}"
        )
        return input_compatible and output_compatible

    @abc.abstractmethod
    def _evaluate(self, model: SafetyClassifierBase) -> tuple[dict, list[dict]]:
        """Task-specific evaluation of a classifier; see ``TaskBase._evaluate``."""
        raise NotImplementedError


class OpenEndedTaskBase(TaskBase, abc.ABC):
    """Base for tasks that evaluate a generator model on open-ended inputs."""

    def validate_if_model_is_compatible(self, model: GeneratorModelBase) -> bool:
        """
        Check that each required input field name occurs in the model's input template.

        Args:
            model: Generator to check.

        Returns:
            True when every required field is found; a missing field fails an assertion instead.
        """
        input_compatible = True
        for field in self.required_input_fields():
            # Membership test against the template text itself.
            if field not in model.model_input_template:
                input_compatible = False
                break
        assert input_compatible, (
            f"Model input fields {model.model_input_template} "
            f"are not compatible with the task input fields {self.required_input_fields()}"
        )
        return input_compatible

    @abc.abstractmethod
    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """Task-specific evaluation of a generator; see ``TaskBase._evaluate``."""
        raise NotImplementedError


class JailbreakBase(abc.ABC):
    """Abstract jailbreak attack bound to one open-ended task."""

    def __init__(self, task: OpenEndedTaskBase, **kwargs):
        """
        Bind the attack to a task.

        Args:
            task: Task whose data the attack operates on; kept as ``self.task``.
            **kwargs: Accepted and not used here.
        """
        self.task = task

    @abc.abstractmethod
    def apply_attack(self) -> list[dict]:
        """
        Produce the attacked version of the task data.

        Returns:
            List of attacked records.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """
        Evaluate the model under attack.

        Intended as two stages: apply the attack to the task data, then
        evaluate the model on the attacked data.

        Args:
            model: Model under evaluation.

        Returns:
            Aggregate report dict and the per-instance result dicts.
        """
        raise NotImplementedError
