import litellm
from litellm.llms.prompt_templates.factory import custom_prompt
from transformers import AutoTokenizer

from jailbreakbench.config import HF_MODEL_NAMES, LITELLM_CUSTOM_PROMPTS, LITELLM_MODEL_NAMES
from jailbreakbench.llm.llm_output import LLMOutput
from jailbreakbench.llm.llm_wrapper import LLM, _count_hf_tokens

# Module-level LiteLLM setting: with `drop_params` enabled, LiteLLM drops
# request parameters the target provider does not support instead of raising
# (see the LiteLLM `drop_params` documentation).
litellm.drop_params = True


class LLMLiteLLM(LLM):
    """LLM wrapper that sends chat conversations to a model through LiteLLM.

    The LiteLLM model identifier is looked up in ``LITELLM_MODEL_NAMES``. Models
    listed in ``LITELLM_CUSTOM_PROMPTS`` get their prompt template registered
    with LiteLLM, and models listed in ``HF_MODEL_NAMES`` get a Hugging Face
    tokenizer that is used to count prompt/completion tokens.
    """

    # Retry settings forwarded to `litellm.batch_completion`.
    _N_RETRIES: int = 10
    _RETRY_STRATEGY: str = "exponential_backoff_retry"

    _litellm_model_name: str
    _api_key: str
    system_prompt: str

    def __init__(self, model_name: str, api_key: str):
        """Initializes the LLMLiteLLM with the specified model name.

        For Together AI models, api_key is the API key for the Together AI API.
        For OpenAI models, api_key is the API key for the OpenAI API.

        Args:
            model_name: Name of the model; must be a key of ``LITELLM_MODEL_NAMES``.
            api_key: API key forwarded to LiteLLM on every request.

        Raises:
            KeyError: If ``model_name`` is not a key of ``LITELLM_MODEL_NAMES``.
        """

        super().__init__(model_name)
        self._api_key = api_key
        self._litellm_model_name = LITELLM_MODEL_NAMES[self.model_name]
        self.generation_params: dict[str, int | float] = {}
        if self.temperature == 0:
            self.generation_params["temperature"] = 0
        else:
            # NOTE(review): for a non-zero temperature only `top_p` is set; the
            # temperature value itself is not forwarded to LiteLLM. Confirm
            # this is intended.
            self.generation_params["top_p"] = 0.9
        self.generation_params["max_tokens"] = self.max_n_tokens
        self.generation_params["seed"] = 0
        if self.model_name in LITELLM_CUSTOM_PROMPTS:
            litellm.register_prompt_template(
                model=self._litellm_model_name, roles=LITELLM_CUSTOM_PROMPTS[self.model_name]
            )
        # The tokenizer is only needed for token counting in `query_llm`.
        if self.model_name in HF_MODEL_NAMES:
            self.tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAMES[self.model_name])
        else:
            self.tokenizer = None

    def update_max_new_tokens(self, n: int | None):
        """Update the maximum number of generation tokens.

        Args:
            n: New value for the ``max_tokens`` generation parameter. ``None``
                leaves the current value unchanged.
        """

        if n is not None:
            self.generation_params["max_tokens"] = n

    def query_llm(self, inputs: list[list[dict[str, str]]]) -> LLMOutput:
        """Query the LLM with supplied inputs.

        Args:
            inputs: One conversation per element; each conversation is a list
                of message dicts passed to LiteLLM as ``messages``.

        Returns:
            An ``LLMOutput`` holding the stripped response text and the prompt
            and completion token counts, one entry per input conversation.

        Raises:
            Exception: Any exception object that ``litellm.batch_completion``
                returned in place of a response is re-raised unchanged.
        """

        outputs = litellm.batch_completion(
            model=self._litellm_model_name,
            messages=inputs,
            api_key=self._api_key,
            num_retries=self._N_RETRIES,
            retry_strategy=self._RETRY_STRATEGY,
            **self.generation_params,  # type: ignore
        )

        # A failed request may come back as an exception object inside the
        # result list. Re-raise it so the real cause is reported instead of an
        # opaque "not subscriptable" error when the entry is indexed below.
        for output in outputs:
            if isinstance(output, Exception):
                raise output

        # `content` can be None (e.g. an empty or filtered completion); treat
        # that as an empty response rather than crashing on `.strip()`.
        responses: list[str] = [
            (output["choices"][0]["message"].content or "").strip() for output in outputs
        ]

        if self.model_name in HF_MODEL_NAMES:
            # Invariant set up in __init__; also narrows the type for checkers.
            assert self.tokenizer is not None
            # Rebuild the formatted prompts from the registered role template
            # so the token count is computed on the formatted text.
            full_input_prompts = [
                custom_prompt(
                    role_dict=LITELLM_CUSTOM_PROMPTS[self.model_name],
                    messages=conv,
                )
                for conv in inputs
            ]
            prompt_tokens, completion_tokens = _count_hf_tokens(self.tokenizer, full_input_prompts, responses)
        else:
            # Do not use HF tokenizer for GPT; rely on the usage statistics
            # reported in each response instead.
            prompt_tokens = [output["usage"]["prompt_tokens"] for output in outputs]
            completion_tokens = [output["usage"]["completion_tokens"] for output in outputs]

        return LLMOutput(responses=responses, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
