#!/usr/bin/env python3



from transformers.generation import GenerationConfig

from partnr.llm.hf_model import HFModel


class Mixtral(HFModel):
    """Load Mixtral using Hugging Face (HF).

    Specializes ``HFModel`` by implementing ``generate_hf``. It calls
    ``self.model.generate`` with settings taken from
    ``self.generation_params``.
    """

    def __init__(self, conf):
        # No Mixtral-specific setup; construction is delegated to HFModel.
        super().__init__(conf)

    @staticmethod
    def _truncate_at_stop(text, stop):
        """Cut ``text`` at the earliest stop sequence and strip trailing whitespace.

        :param text: Decoded model output.
        :param stop: A stop string, an iterable of stop strings, or a falsy
            value (``None`` / ``""``) meaning "no stop sequence".
        :return: ``text`` up to (not including) the first occurrence of any
            stop sequence, with trailing whitespace removed.
        """
        if not stop:
            # str.split("") raises ValueError and str.split(None) splits on
            # whitespace (keeping only the first word), so a missing stop
            # sequence must be handled explicitly.
            return text.rstrip()
        stops = (stop,) if isinstance(stop, str) else tuple(s for s in stop if s)
        cut = len(text)
        for seq in stops:
            idx = text.find(seq)
            if idx != -1 and idx < cut:
                cut = idx
        return text[:cut].rstrip()

    def generate_hf(self, prompt, stop, max_length):
        """Generate the instruction using HF.

        Results are stored on the instance rather than returned:
        ``self.response_raw`` holds the raw ``generate`` output.
        ``self.response`` holds the decoded text. It is a list of strings
        when ``generation_params.batch_response`` is truthy and a single
        string otherwise.

        :param prompt: Text prompt to tokenize and feed to the model.
        :param stop: Stop string (or iterable of stop strings). Each decoded
            output is truncated at its first occurrence. ``None`` / ``""``
            disables truncation.
        :param max_length: Maximum number of *new* tokens to generate
            (passed as ``max_new_tokens``).
        """
        # Prepare the model input from the prompt on the model's device.
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Set the generation parameters, starting from the model's defaults.
        params = self.generation_params
        gen_cfg = GenerationConfig.from_model_config(self.model.config)
        gen_cfg.max_new_tokens = max_length
        gen_cfg.do_sample = params.do_sample
        gen_cfg.num_return_sequences = params.n
        gen_cfg.num_beams = params.best_of
        gen_cfg.temperature = params.temperature
        gen_cfg.repetition_penalty = params.repetition_penalty
        gen_cfg.top_k = params.top_k
        gen_cfg.top_p = params.top_p
        # NOTE(review): scores are kept on response_raw, presumably for
        # callers elsewhere; they are not used in this method.
        gen_cfg.output_scores = True
        # Needed so the output can be indexed as response_raw["sequences"].
        gen_cfg.return_dict_in_generate = True

        # Generate the response.
        self.response_raw = self.model.generate(
            **model_inputs,
            generation_config=gen_cfg,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # The returned sequences include the prompt tokens. Slice them off so
        # only the newly generated text is decoded.
        input_prompt_len = model_inputs.input_ids.shape[1]
        decode_text = self.tokenizer.batch_decode(
            self.response_raw["sequences"][:, input_prompt_len:],
            skip_special_tokens=True,
        )

        # Post-process: truncate at the stop sequence and strip trailing space.
        if params.batch_response:
            # Return a list of responses, one per generated sequence.
            self.response = [
                self._truncate_at_stop(res, stop) for res in decode_text
            ]
        else:
            # Return only the first generated sequence.
            self.response = self._truncate_at_stop(decode_text[0], stop)