import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from .Model import Model


class Llama3(Model):
    """Hugging Face causal LM (Llama 3) behind the project's ``Model`` interface.

    Loads the tokenizer and float16 weights for ``self.name`` (assumed to be
    set by the ``Model`` base class from ``config`` -- TODO confirm) onto
    ``config["params"]["device"]``. :meth:`query` returns only the text
    generated after the prompt.
    """

    def __init__(self, config, args):
        """Load the tokenizer and model weights.

        Args:
            config: Nested config dict. Reads ``params.max_output_tokens``,
                ``params.device``, and uses ``api_key_info.api_key_use`` as an
                index into ``api_key_info.api_keys`` to pick the Hugging Face
                access token.
            args: Namespace-like object; only ``args.step`` is read.
        """
        super().__init__(config)
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        self.device = config["params"]["device"]

        api_pos = int(config["api_key_info"]["api_key_use"])
        hf_token = config["api_key_info"]["api_keys"][api_pos]

        # `token=` replaces the deprecated `use_auth_token=` keyword.
        self.tokenizer = AutoTokenizer.from_pretrained(self.name, token=hf_token)
        # Reuse EOS as the padding token so generate() has a pad id to use.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.name, torch_dtype=torch.float16, token=hf_token
        ).to(self.device)
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.step = args.step

    def query(self, msg):
        """Generate a continuation of ``msg`` and return only the new text.

        Args:
            msg: Prompt string.

        Returns:
            The decoded continuation with special tokens removed. The prompt
            itself is not included.
        """
        # Put inputs on the model's configured device (previously hard-coded
        # to "cuda", which broke CPU / non-default-GPU configs).
        inputs = self.tokenizer(msg, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        outputs = self.model.generate(
            input_ids=input_ids,
            # pad == eos, so generate() cannot infer the mask on its own.
            attention_mask=inputs["attention_mask"],
            # NOTE(review): do_sample is not set, so this is deterministic beam
            # search and `temperature` has no effect; kept as-is to preserve
            # existing outputs.
            temperature=self.temperature,
            max_new_tokens=self.max_output_tokens,
            num_beams=3,
            early_stopping=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Drop the prompt at the token level. Slicing the decoded string by
        # len(msg) is wrong whenever decode() does not reproduce the prompt
        # exactly (whitespace cleanup, normalization, etc.).
        prompt_len = input_ids.shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
