import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from .Model import Model


class Gemma(Model):
    """Hugging Face Gemma causal-LM wrapper that answers single text prompts.

    ``self.name`` (model id passed to ``from_pretrained``) and ``self.temperature``
    are read but not set here. NOTE(review): presumably both are initialised by the
    ``Model`` base class from ``config``; confirm against ``Model``.
    """

    def __init__(self, config, args):
        """Load the tokenizer and fp16 model weights and move the model to the configured device.

        Args:
            config: Nested mapping. Reads ``config["params"]["max_output_tokens"]``
                (cast to int), ``config["params"]["device"]``,
                ``config["api_key_info"]["api_key_use"]`` (an index, cast to int) and
                ``config["api_key_info"]["api_keys"]`` (indexed by that value to pick
                the Hugging Face access token).
            args: Namespace-like object; only ``args.step`` is read and stored.
        """
        super().__init__(config)
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        self.device = config["params"]["device"]

        # Pick which of the configured access tokens to authenticate with.
        api_pos = int(config["api_key_info"]["api_key_use"])
        hf_token = config["api_key_info"]["api_keys"][api_pos]

        # NOTE(review): newer transformers releases deprecate ``use_auth_token`` in
        # favour of ``token``; kept as-is for compatibility with older versions.
        self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_auth_token=hf_token)
        # Use the EOS token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.name,
            torch_dtype=torch.float16,
            use_auth_token=hf_token,
        ).to(self.device)
        # Keep the model config's pad id in sync with the tokenizer's (now EOS) pad id.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.step = args.step

    def query(self, msg):
        """Generate a continuation of ``msg`` and return only the newly generated text.

        Args:
            msg: Prompt string, tokenized as a single unpadded sequence.

        Returns:
            The decoded completion (special tokens removed), without the prompt.
        """
        inputs = self.tokenizer(msg, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        outputs = self.model.generate(
            input_ids,
            # Pass the mask explicitly: pad_token_id equals eos_token_id here, so the
            # mask cannot be reliably inferred from the ids alone.
            attention_mask=inputs.attention_mask,
            # NOTE(review): without do_sample=True, transformers decodes greedily and
            # does not apply temperature; confirm whether sampling was intended.
            temperature=self.temperature,
            max_new_tokens=self.max_output_tokens,
            early_stopping=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Drop the prompt at the token level. Decoding the full sequence is not
        # guaranteed to reproduce ``msg`` character-for-character (special tokens are
        # skipped, whitespace may be normalised), so slicing the decoded string by
        # len(msg) could cut into the answer or leave prompt fragments in it.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
