import torch
# import fastchat.model
# import fastchat
import gc
# import spacy
from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.messages import UserMessage, SystemMessage
from datetime import datetime
import psutil
import time
class LLM:
    def __init__(self, llm_name, dtype:str, llm_path, tokenizer_path, device, debug, parameter=None):
        self.hooks_force = []
        self.llm_name = llm_name
        self.device = device
        self.debug = debug
        # self.parameter = parameter # 似乎参数量并不是一个重要的元素  
        self.dtype = dtype  
        self.parameter = parameter

        self.load(llm_path, tokenizer_path)
        if self.llm_name == "llama2" or self.llm_name == "llama3" or self.llm_name == "qwen2" or self.llm_name == "llama3.1" or self.llm_name == "qwen2.5" or self.llm_name == "qwq":
            self.eos_token_id = self.tokenizer.eos_token_id
        # self.dtype = self.model.dtype
        if self.llm_name == "llama2" or self.llm_name == "llama3" or self.llm_name == "qwen2" or self.llm_name == "llama3.1" or self.llm_name == "qwen2.5" or self.llm_name == "qwq":
            self.eos_token_id = self.tokenizer.eos_token_id
        elif self.llm_name == "mistral":
            self.eos_token_id = self.tokenizer.instruct_tokenizer.tokenizer.eos_id
        elif self.llm_name == "phi3":
            # if eos_token_id == self.tokenizer.eos_token_id, 则为<|endoftext|>，但实际上为<|end|>
            self.eos_token_id = 100266

        if self.llm_name == 'llama2':
            self.embed_matrix = self.model.model.get_input_embeddings().weight
        elif self.llm_name == "qwen2":
            self.embed_matrix = self.model.model.get_input_embeddings().weight
        elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
            self.embed_matrix = self.model.model.get_input_embeddings().weight
        elif self.llm_name == "mistral":
            self.embed_matrix = self.model.model.get_input_embeddings().weight
        elif self.llm_name == "phi3":
            self.embed_matrix = self.model.model.get_input_embeddings().weight
        elif self.llm_name == "qwen2.5"  or self.llm_name == "qwq":
            # model.model.base_model.embed_tokens.weight
            # 获取embedding: model.model.base_model.embed_tokens([input_ids]): input_ids为一维tensor
            self.embed_matrix = self.model.model.get_input_embeddings().weight
        else:
            assert None
    def get_layer_count(self):
        if self.llm_name == "llama2":
            self.layer_count = len(self.model.model.layers)
        elif self.llm_name == "qwen2":
            self.layer_count = len(self.model.model.layers)
        elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
            self.layer_count = len(self.model.model.layers)
        elif self.llm_name == "mistral":
            self.layer_count = len(self.model.model.layers)
        elif self.llm_name == "phi3":
            self.layer_count = len(self.model.model.layers)
        return self.layer_count

    def load(self, llm_path, tokenizer_path):
        """Load ``self.model`` and ``self.tokenizer`` for ``self.llm_name``.

        Reads ``self.device`` and ``self.dtype``. For "qwen2.5"/"qwq" the model is
        placed with ``device_map="auto"`` and ``self.device`` is then set to the
        model's device.

        Args:
            llm_path: model checkpoint path (also the tokenizer source for
                "mistral", "qwen2.5" and "qwq").
            tokenizer_path: tokenizer path for the remaining families.

        Raises:
            ValueError: if ``self.llm_name`` is unsupported, or if ``self.dtype`` is
                not "int8"/"float16" for "llama3.1", "qwen2.5" or "qwq".
        """
        name = self.llm_name
        # Validate up front so a bad dtype fails here instead of leaving self.model
        # unset and surfacing later as an AttributeError.
        if name in ("llama3.1", "qwen2.5", "qwq") and self.dtype not in ("int8", "float16"):
            raise ValueError(
                f"Unsupported dtype {self.dtype!r} for {name!r}; expected 'int8' or 'float16'"
            )

        if name in ("llama2", "llama3"):
            self.model = LlamaForCausalLM.from_pretrained(llm_path).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # A new [PAD] token was added, so the embedding table must grow to match.
            self.model.resize_token_embeddings(len(self.tokenizer))
        elif name == "qwen2":
            self.model = AutoModelForCausalLM.from_pretrained(llm_path).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        elif name == "mistral":
            self.model = AutoModelForCausalLM.from_pretrained(llm_path).to(self.device)
            self.tokenizer = MistralTokenizer.from_file(f"{llm_path}/tokenizer.model.v3")
        elif name == "phi3":
            self.model = AutoModelForCausalLM.from_pretrained(
                llm_path, trust_remote_code=True, torch_dtype="auto"
            ).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        elif name == "llama3.1":
            # Optional dependency, only needed for this family.
            from transformers import BitsAndBytesConfig

            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if self.dtype == "int8":
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    llm_path,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    quantization_config=quantization_config,
                )
            else:  # "float16" -- note the weights are actually loaded as bfloat16.
                self.model = AutoModelForCausalLM.from_pretrained(
                    llm_path, device_map="auto", torch_dtype=torch.bfloat16
                )
        elif name in ("qwen2.5", "qwq"):
            self.tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side='left')
            if self.dtype == "int8":
                # Optional dependency, only needed for GPTQ-quantised checkpoints.
                from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

                quantize_config = BaseQuantizeConfig(
                    bits=8,  # 4 or 8
                    group_size=128,
                    damp_percent=0.01,
                    desc_act=False,  # False speeds up inference; perplexity may be slightly worse
                    static_groups=False,
                    sym=True,
                    true_sequential=True,
                    model_name_or_path=None,
                    model_file_base_name="model",
                )
                self.model = AutoGPTQForCausalLM.from_pretrained(
                    llm_path, quantize_config, device_map="auto"
                )
            else:  # "float16"
                self.model = AutoModelForCausalLM.from_pretrained(
                    llm_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            # device_map="auto" decides placement; keep self.device in sync with it.
            self.device = self.model.device
        else:
            raise ValueError(f"Unsupported llm_name: {name!r}")



    '''
    Convert single token into its corresponding id
    token: str
    return: int
    '''
    def convert_token_to_ids(self, token):
        if self.llm_name == "llama2":
            # id = self.tokenizer.convert_tokens_to_ids(token)
            id = self.tokenizer.encode(token, add_special_tokens = False)[0]
        elif self.llm_name == "qwen2" or self.llm_name == "qwen2.5":
            id = self.tokenizer.encode(token)[0]
        elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
            id = self.tokenizer.encode(token, add_special_tokens = False)[0]
        elif self.llm_name == "mistral":
            id = self.tokenizer.instruct_tokenizer.tokenizer.encode(token, bos = False, eos=  False)[0]
        elif self.llm_name == "phi3":
            id = self.tokenizer.encode(token, add_special_tokens = False)[0]

        return id 
    '''
    id: list
    return: str
    '''
    def convert_id_to_token(self,id):
        if self.llm_name == "llama2":
            token = self.tokenizer.decode(id)
        elif self.llm_name == "qwen2" or self.llm_name == "qwen2.5":
            token = self.tokenizer.decode(id)
        elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
            token = self.tokenizer.decode(id)
        elif self.llm_name == "mistral":
            token = self.tokenizer.instruct_tokenizer.tokenizer.decode(id)
        elif self.llm_name == "phi3":
            token = self.tokenizer.decode(id)

        return token 
    '''
    input_ids_list: List[List]
    '''
    def decode(self, input_ids_list, skip_special_tokens = False):
        string_list = []
        for input_ids in input_ids_list:
            if self.llm_name == "llama2":
                string = self.tokenizer.decode(input_ids, skip_special_tokens = skip_special_tokens)
            elif self.llm_name == "qwen2" or self.llm_name == "qwen2.5":
                string = self.tokenizer.decode(input_ids, skip_special_tokens = skip_special_tokens)
            elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
                string = self.tokenizer.decode(input_ids, skip_special_tokens = skip_special_tokens)
            elif self.llm_name == "mistral":
                string = self.tokenizer.instruct_tokenizer.tokenizer.decode(input_ids, skip_special_tokens = skip_special_tokens)
            elif self.llm_name == "phi3":
                string = self.tokenizer.decode(input_ids, skip_special_tokens = skip_special_tokens)
            string_list.append(string)
            
        return string_list
    '''
    Without add any special and structure tokens to it.
    '''
    def clean_tokenize(self, text_list):
        if self.llm_name == "llama2":
            input_list = text_list
            inputs = self.tokenizer(input_list, padding = True, add_special_tokens = False)
            input_ids = torch.tensor(inputs.input_ids, device = self.device)
            attention_mask = torch.tensor(inputs.attention_mask, device = self.device)
            return input_ids, attention_mask 
        elif self.llm_name == "qwen2" or self.llm_name == "qwen2.5"  or self.llm_name == "qwq":
            input_list = text_list
            inputs = self.tokenizer(input_list, padding = True, add_special_tokens = False)
            input_ids = torch.tensor(inputs.input_ids, device = self.device)
            attention_mask = torch.tensor(inputs.attention_mask, device = self.device)
            return input_ids, attention_mask 
        elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
            input_list = text_list
            inputs = self.tokenizer(input_list, padding = True, add_special_tokens = False)
            input_ids = torch.tensor(inputs.input_ids, device = self.device)
            attention_mask = torch.tensor(inputs.attention_mask, device = self.device)
            return input_ids, attention_mask 
        elif self.llm_name == "mistral":
            input_list = text_list
            input_ids = [self.instruct_tokenizer.tokenizer.encode(input_list[0], bos = False, eos= False)]
            input_ids = torch.tensor(input_ids, device = self.device)
            attention_mask = torch.ones_like(input_ids, device = self.device)
            return input_ids, attention_mask 
        elif self.llm_name == "phi3":
            input_list = text_list
            inputs = self.tokenizer(input_list, padding = True, add_special_tokens = False)
            input_ids = torch.tensor(inputs.input_ids, device = self.device)
            attention_mask = torch.tensor(inputs.attention_mask, device = self.device)
            return input_ids, attention_mask 

            
    '''
    Add specific structure and tokenize,同时generation是true
    text_list: List[str]: 
    '''
    def tokenize(self, text_list):
        if self.llm_name == "llama2":
            input_list = []
            # Attention, structured format is required
            for text in text_list:
                template = fastchat.model.get_conversation_template("llama-2")
                template.append_message(message = text,role =  r"[\INST]")
                input = template.get_prompt()
                # 
                input = input + r"[/INST]" + " "
                input_list.append(input)
            
            inputs = self.tokenizer(input_list, padding = True)
            input_ids = torch.tensor(inputs.input_ids, device = self.device)
            attention_mask = torch.tensor(inputs.attention_mask, device = self.device)
            return input_ids, attention_mask 
        elif self.llm_name == "qwen2" or self.llm_name == "llama3" or self.llm_name == "phi3" or self.llm_name == "llama3.1" or self.llm_name == "qwen2.5"  or self.llm_name == "qwq":
            messages_list = []
            for text in text_list:
                # {"role": "system", "content": "I like you"},
                messages = [
                    {"role": "user", "content": text}
                ]
                template_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                messages_list.append(template_text)

            model_inputs = self.tokenizer(messages_list, return_tensors="pt", padding = True).to(self.device)
            input_ids = model_inputs.input_ids
            attention_mask = model_inputs.attention_mask
            return input_ids, attention_mask
        elif self.llm_name == "mistral":
            messages=[UserMessage(content=text_list[0])]
            completion_request = ChatCompletionRequest(messages = messages)
            input_ids = torch.tensor([self.tokenizer.encode_chat_completion(completion_request).tokens], device = self.device)
            # generated_ids = self.model.generate(
            #     torch.tensor([tokens]).to(self.device),
            #     max_length = 20000
            # )
            attention_mask = torch.ones_like(input_ids, device= self.device)
            return input_ids, attention_mask
    '''
    input_ids: torch.tensor, suppose(batch_size is essential)
    attention_mask: torch.tensor
    # 返回值形状为[batch_size, vocab_size]
    '''
    def generate_one_token(self, input_ids= None, input_embeds = None,attention_mask = None):
        if self.llm_name == "llama2" or self.llm_name == "llama3" or self.llm_name == "phi3" or self.llm_name == "llama3.1":
            # input_ids = torch.tensor(self.tokenizer(text_list).input_ids).to(self.device)
            with torch.no_grad():
                if input_ids != None:
                    outputs = self.model(input_ids = input_ids, attention_mask = attention_mask)
                elif input_embeds != None:
                    outputs = self.model(inputs_embeds = input_embeds, attention_mask = attention_mask)
                logits = outputs.logits
                
            return logits[:, -1, :]
        elif self.llm_name == "qwen2" or self.llm_name == "qwen2.5" or self.llm_name == "qwq":
            # input_ids = torch.tensor(self.tokenizer(text_list).input_ids).to(self.device)
            with torch.no_grad():
                if input_ids != None:
                    outputs = self.model(input_ids = input_ids, attention_mask = attention_mask)
                elif input_embeds != None:
                    outputs = self.model(inputs_embeds = input_embeds, attention_mask = attention_mask)
                logits = outputs.logits
                
            return logits[:, -1, :]
        elif self.llm_name == "mistral":
            with torch.no_grad():
                if input_ids != None:
                    outputs = self.model(input_ids = input_ids, attention_mask = attention_mask)
                elif input_embeds != None:
                    outputs = self.model(inputs_embeds = input_embeds, attention_mask = attention_mask)
                logits = outputs.logits
    

    '''
    input_ids: [batch_size, seq_len], 当然要求batch_size必须为1
    返回值: input_ids: Tensor:[batch_size, seq_len]
           logits_list: LIST:[batch_size, seq_len, vocab_size],注意到返回值的2个batch_size 必须为1
    '''
    def generate_complete_tokens_one_batch(self, input_ids = None, attention_mask = None,  max_seq_length = 10000, display_count = 50):
        ori_input_ids_len = input_ids.shape[1]
        if input_ids.shape[0] != 1:
            raise ValueError("Batch size must be 1 in generate_complete_tokens_one_batch")
        logits_tensor = None
        start_time = time.time()
        for i in range(max_seq_length):
            logits = self.generate_one_token(input_ids = input_ids, attention_mask = attention_mask)
            if logits_tensor is None:
                logits_tensor = logits.unsqueeze(1)
            else:
                logits_tensor = torch.cat((logits_tensor, logits.unsqueeze(1)), dim = 1)
            # logits_list[0].append(logits[0])
            tokens = torch.argmax(logits, dim = -1).unsqueeze(1)
            input_ids = torch.cat((input_ids, tokens), dim = -1)
            attention_mask = torch.cat((attention_mask, torch.ones_like(tokens)), dim = -1)
            gc.collect()
            torch.cuda.empty_cache()
            if i != 0 and i % display_count == 0:
                # current_time = datetime.now()
                current_time = time.time()
                print("-----------------------------------")
                print(f"经过了{current_time - start_time}秒")
                print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: Finish {i} tokens') 
                print(f"解码内容是: {self.tokenizer.decode(input_ids[0][ori_input_ids_len:], skip_special_tokens = True)}")
                # 我还希望记录在上一次打印过去了多久

                print("-----------------------------------")
            if tokens.item() == self.eos_token_id:
                return input_ids, logits_tensor
        return input_ids, logits_tensor
    def generate_complete_tokens_default(self, input_ids = None, attention_mask = None,  max_seq_length = 10000):
        if self.llm_name == "qwen2" or self.llm_name == "llama3" or self.llm_name == "phi3" or self.llm_name == "llama3.1" or self.llm_name == "qwen2.5" or self.llm_name == "qwq":
            with torch.no_grad():
                generated_ids = self.model.generate(
                        input_ids = input_ids,
                        attention_mask = attention_mask,
                        max_length = max_seq_length
                    )
            return generated_ids
    '''
    似乎是删除了前面的废话的decode过程,generated_ids形状为(batch_size, seq_len)
    '''
    def decode_delelte_previous_tokens(self, generated_ids = None):
        output_list = []
        def get_substring_after_last_occurrence(str_a, str_b):
            # Find the index of the last occurrence of str_a in str_b
            last_index = str_b.rfind(str_a)
            # If str_a is not found, return an empty string
            if last_index == -1:
                return ""
            # Return the substring after the last occurrence of str_a
            return str_b[last_index + len(str_a):]

        if self.llm_name == "llama3" or self.llm_name == "llama3.1":
            truncate_part = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        elif self.llm_name == "phi3":
            truncate_part = "<|assistant|>\n"
        elif self.llm_name == "qwen2" or self.llm_name == "qwen2.5":
            truncate_part = '<|im_start|>assistant\n' 
        else:
            raise ValueError("Unsupported model name for decoding")

        for ids in generated_ids:
            complete_string = self.tokenizer.decode(ids)
            truncated_string = get_substring_after_last_occurrence(truncate_part, complete_string)
            truncated_string = truncated_string.replace(self.tokenizer.pad_token, "")
            
            output_list.append(truncated_string)

        return output_list

    '''
    这个函数将大量的input_ids，以batch_size进行划分输入到模型中。最后将模型的输出进行解码为英文，再将英文拼接在list中，进行返回
    '''
    def generate_complete_tokens_default_with_batch_size(self, input_ids = None, attention_mask = None,  max_seq_length = 10000, batch_size = 4):
        if self.llm_name == "qwen2" or self.llm_name == "llama3" or self.llm_name == "phi3" or self.llm_name == "llama3.1" or self.llm_name == "qwen2.5" or self.llm_name == "qwq":
            input_ids_chunks = input_ids.split(batch_size)
            input_ids_chunk_list = [chunk for chunk in input_ids_chunks]
            attention_mask_chunks = attention_mask.split(batch_size)
            attention_mask_chunk_list = [chunk for chunk in attention_mask_chunks]
            output_list = []
            for index in range(len(input_ids_chunk_list)):
                # generated_ids = self.model.generate(
                #         input_ids = input_ids_chunk_list[index],
                #         attention_mask = attention_mask_chunk_list[index],
                #         max_length = max_seq_length
                    # )
                generated_ids = self.generate_complete_tokens_default(input_ids_chunk_list[index],attention_mask_chunk_list[index],max_seq_length)
                # 在decoded_string_list存在非常不完美的情况，即大量的eos_token混入其中。解决方法: 将其编码之后，再解码
                decoded_string_list = self.decode_delelte_previous_tokens(generated_ids)
                # decoded_string_delete_eos_list = []
                # for i in range(len(decoded_string_list)):
                _ids = self.tokenizer(decoded_string_list, padding = False, add_special_tokens = False)
                _decode_string_list = self.tokenizer.batch_decode(_ids.input_ids, skip_special_tokens = True)
                # print(decoded_string_list)
                output_list += _decode_string_list
                current_time = datetime.utcnow()
                # print(f"{current_time.strftime('%a %b %d %H:%M:%S UTC %Y')}: Finish 1 batch")
            return output_list
    
    '''
    这个函数是自己写的使用最简单的贪心算法实现的，所以，请尽量不使用本函数
    input_ids: torch.tensor, suppose(batch_size is essential)
    attention_mask: torch.tensor

    '''
    def generate_complete_tokens(self, input_ids = None, input_embeds = None, attention_mask = None, max_seq_length = None):
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = input_embeds.shape[0]
            output_ids = torch.zeros(batch_size, 1).to(self.device).to(torch.int)
        generating_trigger = [True] * batch_size
        logits_list = []
        for i in range(max_seq_length):
            if True not in generating_trigger:
                break
            if input_ids is not None:
                logits = self.generate_one_token(input_ids = input_ids, attention_mask = attention_mask)
            else:
                logits = self.generate_one_token(input_embeds =  input_embeds, attention_mask = attention_mask)
            logits_list.append(logits)
            # (batch_size, 1)
            tokens = torch.argmax(logits, dim = -1).unsqueeze(1)
            if input_ids is not None:
                input_ids = torch.cat((input_ids, tokens), dim = -1)
            else:
                one_hot_tensor = torch.zeros(input_embeds.shape[0], 1, self.embed_matrix.shape[0], device = self.device)
                if self.llm_name == "phi3":
                    one_hot_tensor = one_hot_tensor.to(self.model.dtype) 
                one_hot_tensor.scatter_(2, tokens.unsqueeze(-1), 1)
                new_embeds = one_hot_tensor @ self.embed_matrix
                input_embeds = torch.cat((input_embeds, new_embeds), dim = 1)
                output_ids = torch.cat((output_ids, tokens), dim = -1)
            attention_mask = torch.cat((attention_mask, torch.ones_like(tokens,device = self.device)), dim = -1)

            for index in range(batch_size):
                if generating_trigger[index]:
                    token = tokens[index]
                    if token.item() == self.eos_token_id:
                        generating_trigger[index] = False
        if input_ids is not None:
            return input_ids, logits_list
        else:
            return input_embeds, logits_list, output_ids[:, 1:]

    '''
    hook hidden state to obtain the clean data
    '''
    def hook_hs(self):
        self.hooks = []
        self.layer_outputs = []
        if self.llm_name == "llama2" or self.llm_name == "qwen2" or self.llm_name == "llama3" or self.llm_name == "mistral" or self.llm_name == "phi3":
            def hook_fn(module, input, output):
                self.layer_outputs.append(output)
            for i, layer in enumerate(self.model.model.layers):
                hook = layer.register_forward_hook(hook_fn)
                self.hooks.append(hook)

    def remove_hooks_hs(self):
        """Remove the hooks registered by ``hook_hs``.

        Safe to call before any hooks exist or more than once; the handle list is
        cleared afterwards.
        """
        for handle in getattr(self, "hooks", []):
            handle.remove()
        self.hooks = []

    def get_hs(self):
        """Return the list that the ``hook_hs`` hooks append layer outputs to."""
        captured = self.layer_outputs
        return captured

    def clean_hs(self):
        self.layer_outputs = []
        torch.cuda.empty_cache()
        gc.collect()
        
    '''
    The function is disigned to force a specfic layer to output a targeted value.
    '''
    def hook_force_output(self, layer_num, word_index, output_value: torch.tensor):
        self.hooks_force = []
        if self.llm_name == "llama2" or self.llm_name == "qwen2" or self.llm_name == "llama3" or self.llm_name == "mistral" or self.llm_name == "phi3" or self.llm_name == "llama3.1" or self.llm_name == "qwen2.5":
            def hook_fn(module, input, output):
                output[0].data[ : , word_index, :] = output_value.clone()[:, word_index, :]
            # for i, layer in enumerate(self.model.model.layers):
            hook = self.model.model.layers[layer_num].register_forward_hook(hook_fn)
            self.hooks_force.append(hook)
    def remove_hs_force_output(self):
        """Detach every forward hook registered by ``hook_force_output``."""
        handles = self.hooks_force
        for handle in handles:
            handle.remove()
    # '''
    # word: single word string
    # '''
    # def if_token_splited(self, word):
    #     if self.llm_name == "llama2":
    #         input_ids = self.tokenizer(word, add_special_tokens = False)
    #         if len(input_ids) > 1:
    #             return True
    #         else:
    #             return False

    # Function to check if all nouns in a sentence are split into multiple tokens
    # def check_nouns_tokenization(self, sentence):
    #     # Process the sentence with SpaCy
    #     nlp = spacy.load("en_core_web_sm")
    #     doc = nlp(sentence)
        
    #     # Extract nouns
    #     nouns = [token.text for token in doc if token.pos_ == "NOUN" or token.pos_ == "PROPN"]
        
    #     # Check tokenization for each noun
    #     results_dict = {}
    #     results_list = []
    #     for noun in nouns:
    #         if self.llm_name == "llama2":
    #             tokens = self.tokenizer.encode(noun, add_special_tokens = False)
    #         results_dict[noun] = tokens
    #         results_list.append(len(tokens) > 1)
        
    #     return results_dict, results_list
    '''
    请不用管这个函数，这是上一个项目的结果hhh
    '''
    '''
    def get_noun_position(self, sentence, index):
        nlp = spacy.load("en_core_web_sm")
        doc = nlp(sentence)

        # Extract nouns and their character offsets
        nouns = [(token.text, token.idx) for token in doc if token.pos_ in ['NOUN', 'PROPN']]

        # Initialize the list to store the positions of noun tokens in input_ids
        noun_positions = []
        # none_index = []
        if self.llm_name == "llama2":
            input_ids = self.tokenizer.encode(sentence, add_special_tokens = False)
        elif self.llm_name == "mistral":
            input_ids = self.tokenizer.instruct_tokenizer.tokenizer.encode(sentence, bos=False, eos=False)
        elif self.llm_name == "qwen2" and index == 0:
            input_ids = self.tokenizer.encode(sentence, add_special_tokens = False)
        elif self.llm_name == "qwen2" and index == 1:
            input_ids = self.tokenizer.encode(" " + sentence, add_special_tokens = False)
        elif self.llm_name == "llama3" and index == 0:
            input_ids = self.tokenizer.encode(sentence, add_special_tokens = False)
        elif self.llm_name == "llama3" and index == 1:
            input_ids = self.tokenizer.encode(" " + sentence, add_special_tokens = False)
        elif self.llm_name == "phi3" and index == 0:
            input_ids = self.tokenizer.encode(sentence, add_special_tokens = False)
        elif self.llm_name == "phi3" and index == 1:
            input_ids = self.tokenizer.encode(" " + sentence, add_special_tokens = False)
        # Iterate over tokenized input_ids to find the positions of nouns
        for noun, char_offset in nouns:
            # Find the position of the noun in the tokenized input_ids
            if char_offset == 0:
                continue
            if self.llm_name == "llama2":
                token_ids = self.tokenizer.encode(noun, add_special_tokens = False)
            elif self.llm_name == "mistral":
                token_ids = self.tokenizer.instruct_tokenizer.tokenizer.encode(noun, bos=False, eos=False) 
            elif self.llm_name == "qwen2" or self.llm_name == "llama3" or self.llm_name == "phi3":
                # if char_offset == 0:
                #     token_ids = self.tokenizer.encode(noun, add_special_tokens = False)
                # else:
                token_ids = self.tokenizer.encode(" "+noun, add_special_tokens = False)
            # elif 
            # for token_id in token_ids:
            for i in range(len(input_ids)):
                succ = True
                for j in range(len(token_ids)):
                    if token_ids[j] != input_ids[i+j]:
                        succ = False
                        break
                if succ:
                    start = i
                    break
            if succ:
                end = start + len(token_ids)
                noun_positions.append([i for i in range(start, end)]) 
        return noun_positions
    '''
    '''
        sequences: {
            type: List[Dict]
            description: 带有对话记录的列表, 其具体格式为: [{"from":xxxx, "value": yyyy}]，其中xxxx代表从哪里, "human"代表从人类, "gpt", "llm"等代表从语言大模型, yyy代表具体的值
            example: [{"from": "human", "value": "Who are you?"}] 
        }
        question: {
            type: str
            description:  你的下一个问题
            exmaple: "你是谁?"
        }
    '''
    def tokenize_with_chat_sequence(self, sequences, question):
        if self.llm_name == "qwen2" or self.llm_name == "phi3" or self.llm_name == "llama3" or self.llm_name == "llama3.1":
            messages_list = []
            # assistant, user, system
            messages = []
            first_met = True
            for element in sequences:
                if element['from'] == "human" or element['from'] == "user":
                    messages.append({"role": "user", "content": element['value']})
                    if first_met:
                        first_met = False
                        user_first_index = len(messages) - 1
                elif element['from'] == "gpt" or element['from'] == "chatgpt" or element['from'] == "bing" or element['from'] == "bard" or element['from'] == "llm":
                    messages.append({"role": "assistant", "content": element['value']})
                elif element['from'] == "system":
                    messages.append({"role": "system", "content": element['value']})
                # else:
            messages.append({"role": "user", "content": question})
                    
                
            template_text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            messages_list.append(template_text)
            model_inputs = self.tokenizer(messages_list, return_tensors="pt", padding = True).to(self.device)
            input_ids = model_inputs.input_ids
            attention_mask = model_inputs.attention_mask
            if first_met:
                return input_ids, attention_mask, messages_list, messages, -1
            return input_ids, attention_mask, messages_list, messages, user_first_index

    def tokenize_with_system(self, text_list, system_prompt_list):
        if self.llm_name == "llama2":
            input_list = []
            # Attention, structured format is required
            for i, text in enumerate(text_list):
                template = fastchat.model.get_conversation_template("llama-2")
                template.set_system_message(system_prompt_list[i])
                template.append_message(message = text,role =  r"[\INST]")
                input = template.get_prompt()
                # 
                input = input + r"[/INST]" + " "
                input_list.append(input)
            messages_list = input_list
            inputs = self.tokenizer(input_list, padding = True)
            input_ids = torch.tensor(inputs.input_ids, device = self.device)
            attention_mask = torch.tensor(inputs.attention_mask, device = self.device)
            return input_ids, attention_mask, messages_list
        elif self.llm_name == "qwen2" or self.llm_name == "phi3" or self.llm_name == "qwen2.5" or self.llm_name == "qwq":
            messages_list = []
            for i, text in enumerate(text_list):
                messages = [
                    {"role": "system", "content": system_prompt_list[i]},
                    {"role": "user", "content": text}
                ]
                template_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                messages_list.append(template_text)

            model_inputs = self.tokenizer(messages_list, return_tensors="pt", padding = True).to(self.device)
            input_ids = model_inputs.input_ids
            attention_mask = model_inputs.attention_mask
            return input_ids, attention_mask, messages_list
        elif self.llm_name == "llama3" or self.llm_name == "llama3.1":
            messages_list = []
            for i, text in enumerate(text_list):
                messages = [
                    {"role": "system", "content": system_prompt_list[i]},
                    {"role": "user", "content": text}
                ]
                template_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                messages_list.append(template_text)

            model_inputs = self.tokenizer(messages_list, return_tensors="pt", padding = True).to(self.device)
            input_ids = model_inputs.input_ids
            attention_mask = model_inputs.attention_mask
            return input_ids, attention_mask, messages_list
        elif self.llm_name == "mistral":
            messages_list = []
            messages=[SystemMessage(content=system_prompt_list[0]), UserMessage(content=text_list[0])]
            completion_request = ChatCompletionRequest(messages = messages)
            input_ids = torch.tensor([self.tokenizer.encode_chat_completion(completion_request).tokens], device = self.device)
            # generated_ids = self.model.generate(
            #     torch.tensor([tokens]).to(self.device),
            #     max_length = 20000
            # )
            messages_list.append(self.tokenizer.instruct_tokenizer.tokenizer.decode(self.tokenizer.encode_chat_completion(completion_request).tokens))
            attention_mask = torch.ones_like(input_ids, device= self.device)
            return input_ids, attention_mask, messages_list
 









