from typing import List

from transformers import AutoModelForCausalLM, AutoTokenizer

from ..message import Message
from .base import BaseModel, build_prompt


# Process-wide cache of the loaded HuggingFace model and tokenizer. Both start
# as None, are filled in by TransformersModel, and are shared by its instances.
MODEL = None
TOKENIZER = None


class TransformersModel(BaseModel):
    """Model builder for HuggingFace Transformers models.

    `model` must be in the format ``model_name@checkpoint_path``:
    ``model_name`` is handed to `build_prompt` when formatting the prompt and
    ``checkpoint_path`` is what `from_pretrained` loads. Only the first ``@``
    separates the two parts, so the checkpoint path may itself contain ``@``.

    The loaded model and tokenizer live in the module-level ``MODEL`` and
    ``TOKENIZER`` globals and are shared by all instances. When an instance
    asks for a different checkpoint than the one this class loaded, the cache
    is replaced rather than silently serving the wrong weights.

    Call with a list of `Message` objects to generate a response.
    """

    supports_system_message = True

    # Checkpoint path this class last loaded into MODEL/TOKENIZER. Stays None
    # when the globals are empty or were populated by something else.
    _loaded_checkpoint = None

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        stream: bool = False,
        top_p: float = 1.0,
        max_tokens: int = 512,
        stop=None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        **kwargs,
    ):
        """Store the generation settings and make sure the checkpoint is loaded.

        Args:
            model: ``model_name@checkpoint_path`` specifier.
            temperature: Sampling temperature; values <= 0 select greedy decoding.
            stream: Stored on the instance; not used by this class.
            top_p: Nucleus-sampling threshold, applied only when sampling.
            max_tokens: Maximum number of newly generated tokens.
            stop: Optional stop string or iterable of stop strings; the
                response is cut at the earliest occurrence of any of them.
            frequency_penalty: Stored on the instance; not used by this class.
            presence_penalty: Stored on the instance; not used by this class.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If `model` is not of the form ``name@checkpoint_path``.
        """
        # Split on the first "@" only so checkpoint paths containing "@" work.
        model_name, separator, checkpoint_path = model.partition("@")
        if not separator or not model_name or not checkpoint_path:
            raise ValueError(
                "model must be in the format model_name@checkpoint_path, "
                f"got {model!r}"
            )
        self.model_name = model_name
        self.checkpoint_path = checkpoint_path
        self.temperature = temperature
        self.stream = stream
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.stop = stop
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self._ensure_loaded()

    def _ensure_loaded(self):
        """Load this instance's checkpoint into the shared globals if needed."""
        global MODEL, TOKENIZER
        loaded = TransformersModel._loaded_checkpoint
        # Only treat the cache as stale when this class loaded it; a model
        # injected into the globals from outside is left untouched.
        stale = loaded is not None and loaded != self.checkpoint_path
        if MODEL is not None and not stale:
            return

        import torch

        # Drop the old references first so their memory can be reclaimed
        # before the replacement weights are allocated.
        MODEL = None
        TOKENIZER = None
        TransformersModel._loaded_checkpoint = None

        model = AutoModelForCausalLM.from_pretrained(
            self.checkpoint_path, torch_dtype=torch.float16, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint_path)
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token

        # Publish only after both pieces loaded, so a failure never leaves a
        # model paired with a missing tokenizer.
        MODEL, TOKENIZER = model, tokenizer
        TransformersModel._loaded_checkpoint = self.checkpoint_path

    @staticmethod
    def _truncate_at_stop(text, stop):
        """Return `text` cut just before the earliest non-empty stop sequence."""
        if not stop:
            return text
        sequences = [stop] if isinstance(stop, str) else stop
        cut = len(text)
        for sequence in sequences:
            if not sequence:
                continue  # an empty stop string would match at index 0
            index = text.find(sequence)
            if index != -1:
                cut = min(cut, index)
        return text[:cut]

    def __call__(self, messages: List[Message], api_key: str = None):
        """Generate a completion for `messages`.

        Args:
            messages: Conversation to turn into a prompt via `build_prompt`.
            api_key: Unused; accepted so the call signature stays unchanged.

        Returns:
            A single-element list containing the decoded, stripped response.
        """
        self._ensure_loaded()
        prompt = build_prompt(messages, self.model_name)
        # Follow the model's own placement instead of assuming CUDA: with
        # device_map="auto" the weights may sit on CPU or a non-default GPU.
        model_inputs = TOKENIZER(prompt, return_tensors="pt").to(MODEL.device)
        # A single unpadded prompt: its length is the input sequence length.
        prompt_len = model_inputs["input_ids"].shape[-1]

        sampling = self.temperature > 0
        generation_kwargs = {
            "max_new_tokens": self.max_tokens,
            "do_sample": sampling,
            "pad_token_id": TOKENIZER.pad_token_id,
        }
        if sampling:
            # Sampling-only knobs; passing them for greedy decoding triggers
            # generation-config warnings.
            generation_kwargs["temperature"] = self.temperature
            generation_kwargs["top_p"] = self.top_p

        output = MODEL.generate(**model_inputs, **generation_kwargs)
        # generate() returns prompt + continuation; keep the continuation only.
        response = TOKENIZER.decode(
            output[0][prompt_len:], skip_special_tokens=True
        )
        response = self._truncate_at_stop(response, self.stop).strip()
        return [response]
