import vllm
from transformers import AutoTokenizer
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel

from jailbreakbench.config import CHAT_TEMPLATES, HF_MODEL_NAMES
from jailbreakbench.llm.llm_output import LLMOutput
from jailbreakbench.llm.llm_wrapper import LLM, _count_hf_tokens


class LLMvLLM(LLM):
    """LLM backend that runs generation locally through a vLLM engine.

    Conversations are rendered to prompt strings with the Hugging Face
    tokenizer's chat template and completed in one batched
    ``vllm.LLM.generate`` call.
    """

    def __init__(self, model_name: str):
        """Load the vLLM engine, the tokenizer and the sampling settings.

        Args:
            model_name: Model identifier forwarded to the base-class
                initializer. The resulting ``self.model_name`` is the key
                looked up in ``HF_MODEL_NAMES`` and, when the tokenizer has
                no chat template, in ``CHAT_TEMPLATES``.

        Raises:
            KeyError: If ``self.model_name`` is missing from
                ``HF_MODEL_NAMES``, or from ``CHAT_TEMPLATES`` when the
                fallback template is needed.
        """
        super().__init__(model_name)
        self.hf_model_name = HF_MODEL_NAMES[self.model_name]

        # NOTE(review): presumably clears model-parallel state left behind by a
        # previously created engine in this process -- confirm against vLLM.
        destroy_model_parallel()
        self.model = vllm.LLM(model=self.hf_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_name)

        # Fall back to the project's template when the HF tokenizer ships none.
        if self.tokenizer.chat_template is None:
            self.tokenizer.chat_template = CHAT_TEMPLATES[self.model_name]

        # temperature / top_p / max_n_tokens are read from the instance
        # (assumed to be provided by the LLM base class -- TODO confirm).
        # Greedy decoding is the baseline; top_p is only passed when sampling.
        sampling_kwargs = {"temperature": 0, "max_tokens": self.max_n_tokens}
        if self.temperature > 0:
            sampling_kwargs.update(temperature=self.temperature, top_p=self.top_p)
        self.sampling_params = vllm.SamplingParams(**sampling_kwargs)

    def update_max_new_tokens(self, n: int | None):
        """Set the generation token limit to ``n``; ``None`` leaves it unchanged."""
        if n is None:
            return
        self.sampling_params.max_tokens = n

    def query_llm(self, inputs: list[list[dict[str, str]]]) -> LLMOutput:
        """Generate one completion per conversation in ``inputs``.

        Args:
            inputs: Conversations, each a list of message dicts handed
                unchanged to the tokenizer's ``apply_chat_template``.

        Returns:
            An ``LLMOutput`` holding the stripped completions together with
            the prompt and completion token counts computed by
            ``_count_hf_tokens``.
        """
        # Render every conversation to a prompt string that ends with the
        # generation prompt, so the model continues with a reply.
        prompts = [
            self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
            for conversation in inputs
        ]
        generations = self.model.generate(prompts, self.sampling_params, use_tqdm=False)

        # Keep the first candidate of each request. strip(" ") trims space
        # characters from both ends; other whitespace (e.g. newlines) is kept.
        responses = []
        for generation in generations:
            first_candidate = generation.outputs[0]
            responses.append(first_candidate.text.strip(" "))

        prompt_tokens, completion_tokens = _count_hf_tokens(self.tokenizer, prompts, responses)
        return LLMOutput(
            responses=responses,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
