import random

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig


class LLM_Wrapper():
    """Batched text generation with a locally stored chat-tuned causal LM.

    Each input item is inserted into a randomly chosen instruction template,
    wrapped as a single-turn ``user`` chat message, and sent to the model in
    one left-padded batch.  Only the newly generated text is returned.
    """

    # Defaults kept as class constants so existing callers that pass only
    # ``instructions`` / ``device_id`` get exactly the previous configuration.
    DEFAULT_MODEL_PATH = r"/xxx/public_data/LLMs/gemma/gemma-7b-it"
    DEFAULT_MAX_NEW_TOKENS = 200

    def __init__(self, instructions, device_id=None,
                 model_path=DEFAULT_MODEL_PATH,
                 max_new_tokens=DEFAULT_MAX_NEW_TOKENS):
        """Load the tokenizer, the model and its generation config.

        Args:
            instructions: List of ``str.format`` templates; one is picked at
                random for every item passed to ``__call__`` and formatted
                with that item as its single positional argument.
            device_id: CUDA device index to load the model on.  ``None``
                lets ``device_map='auto'`` decide the placement.
            model_path: Directory (or hub id) to load tokenizer, weights and
                generation config from.
            max_new_tokens: Default cap on generated tokens per prompt.

        Raises:
            AssertionError: If ``instructions`` is not a list.
        """
        self.instructions = instructions
        assert isinstance(self.instructions, list), \
            'instructions must be a list of format strings'
        self.max_new_tokens = max_new_tokens

        # Left padding keeps every prompt flush against the generation
        # position, which decoder-only models need for batched generation.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            padding_side='left',
        )

        if device_id is None:
            device_str = 'auto'
        else:
            device_str = 'cuda:{}'.format(device_id)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            pad_token_id=self.tokenizer.pad_token_id,
            device_map=device_str,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.model.generation_config = GenerationConfig.from_pretrained(
            model_path,
            pad_token_id=self.tokenizer.pad_token_id,
        )

    def input_to_model(self, all_raw_text, max_new_tokens=None):
        """Generate one response per prompt in ``all_raw_text``.

        Args:
            all_raw_text: Iterable of user-message strings.
            max_new_tokens: Per-call override of the generation length cap;
                ``None`` uses the value given at construction time.

        Returns:
            List of decoded responses (special tokens stripped), in the same
            order as the inputs.  An empty input yields an empty list.
        """
        questions = list(all_raw_text)
        if not questions:
            # Nothing to tokenize or generate; avoid handing an empty batch
            # to the tokenizer / model.
            return []

        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        batch_raw_text = [
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": q}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for q in questions
        ]

        # The chat template has already rendered the special tokens it needs,
        # so tokenizing with add_special_tokens=True would prepend them a
        # second time (e.g. a duplicated BOS).
        encoded = self.tokenizer(
            batch_raw_text,
            padding='longest',
            add_special_tokens=False,
            return_tensors='pt',
        )
        batch_input_ids = encoded['input_ids'].to(self.model.device)
        # Pass the mask explicitly so padded positions are ignored rather
        # than relying on generate() to infer it from the pad token id.
        attention_mask = encoded['attention_mask'].to(self.model.device)

        batch_out_ids = self.model.generate(
            batch_input_ids,
            attention_mask=attention_mask,
            return_dict_in_generate=False,
            generation_config=self.model.generation_config,
            max_new_tokens=max_new_tokens,
        )
        print('batch_out_ids:{}'.format(batch_out_ids.shape))

        # Every row is padded to the same prompt length, so the continuation
        # of each sample starts at the same column.
        prompt_len = batch_input_ids.shape[1]
        generated_ids = batch_out_ids[:, prompt_len:]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return response

    def __call__(self, batch_items):
        """Format every item with a random instruction and run generation.

        Args:
            batch_items: Iterable of objects; each is substituted into a
                randomly chosen template via ``template.format(item)``.

        Returns:
            List of generated responses, one per item, in input order.
        """
        all_processed_sentences = [
            random.choice(self.instructions).format(item_obj)
            for item_obj in batch_items
        ]
        return self.input_to_model(all_processed_sentences)
