import os
import random
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig
from .utils import remove_rotary_pos_emb, apply_rotary_pos_emb, concate_past_key_value


class E2RAGBase(nn.Module):
    def load_lora_parameters(self, load_path, adepter_name):
        """Load saved LoRA weights into the adapter called ``adepter_name``.

        Every key of the checkpoint has the substring ``"default"`` replaced
        by ``adepter_name`` before loading, so weights saved under the
        default adapter name can be loaded into a differently named adapter.

        Args:
            load_path: Path of a checkpoint written with ``torch.save`` that
                holds a mapping of parameter name to tensor.
            adepter_name: Adapter name substituted for ``"default"`` in the
                checkpoint keys.
        """
        # Deserialize onto the CPU: the checkpoint may have been written from
        # a device that does not exist on this host, and load_state_dict
        # copies every tensor onto the device of the matching parameter.
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from a trusted source.
        lora_params = torch.load(load_path, map_location="cpu")
        to_load = {
            name.replace("default", adepter_name): tensor
            for name, tensor in lora_params.items()
        }
        # strict=False: the checkpoint only covers part of the module, and
        # keys that match nothing are ignored.
        self.load_state_dict(to_load, strict=False)

    def get_positions(self, input_ids):
        """Locate the ``self.mem_id`` placeholders and blank them out.

        Args:
            input_ids: Integer tensor of token ids.

        Returns:
            A ``(mem_positions, input_ids)`` pair: a boolean mask that is True
            wherever ``input_ids == self.mem_id``, and a copy of ``input_ids``
            in which those entries hold the tokenizer's pad token id.
        """
        mem_positions = input_ids == self.mem_id
        # Build a new tensor instead of writing into the argument: the caller's
        # batch keeps its placeholders and can be fed through the model again.
        cleaned_ids = input_ids.masked_fill(mem_positions, self.tokenizer.pad_token_id)
        return mem_positions, cleaned_ids
    
    def add_mem(self, input_ids, mem_positions):
        """Embed ``input_ids`` and write the memory embeddings into the marked slots.

        Args:
            input_ids: Token ids fed to the model's input embedding layer.
            mem_positions: Boolean mask selecting the slots to overwrite. The
                number of selected slots must equal batch size times the
                number of rows in ``self.memory_embeddings``.

        Returns:
            The embedded sequence with ``self.memory_embeddings`` written at
            every selected position.
        """
        embed_layer = self.llama.get_input_embeddings()
        token_embeds = embed_layer(input_ids).to(self.device)

        # One copy of the memory embeddings per sample, flattened so that it
        # lines up with the (row, column) pairs produced by nonzero().
        batch_size = token_embeds.shape[0]
        per_sample_mem = self.memory_embeddings.repeat(batch_size, 1, 1).to(self.device)
        flat_mem = per_sample_mem.view(-1, per_sample_mem.shape[-1])

        rows, cols = mem_positions.nonzero(as_tuple=True)
        token_embeds[rows, cols, :] = flat_mem
        return token_embeds

    def get_key_values(self, past_key_values, positions, num=None):
        """Keep only the key/value entries at the marked sequence positions.

        Args:
            past_key_values: Iterable of ``(key, value)`` pairs, one per
                layer. Each tensor is indexed as
                ``[batch, head, sequence, feature]``.
            positions: Boolean mask over ``[batch, sequence]``. Every row must
                select exactly ``num`` positions.
            num: Number of selected positions per sample. Defaults to
                ``self.num_mem``.

        Returns:
            A tuple of ``(key, value)`` pairs whose sequence axis has length
            ``num`` and contains only the selected positions, in order.
        """
        num = self.num_mem if num is None else num
        bs_indices, seq_indices = positions.nonzero(as_tuple=True)

        def gather(layer_tensor):
            # Indexing dims 0 and 2 together moves the selected entries to the
            # front: (batch * num, head, feature). Regroup them per sample and
            # put the head axis back in front of the sequence axis.
            picked = layer_tensor[bs_indices, :, seq_indices, :]
            return picked.view(layer_tensor.shape[0], num, layer_tensor.shape[1], -1).transpose(1, 2)

        return tuple(
            (gather(layer_key), gather(layer_value))
            for layer_key, layer_value in past_key_values
        )

    def remove_position(self, past_key_values):
        """Undo the rotary position embedding on every cached key.

        Positions ``0 .. seq_len - 1`` are used, where ``seq_len`` is the
        size of axis 2 of the first layer's key. Values are passed through
        unchanged.

        Args:
            past_key_values: Iterable of per-layer ``(key, value)`` pairs.

        Returns:
            A tuple of ``(key_without_rotary, value)`` pairs.
        """
        reference_key = past_key_values[0][0]
        seq_len = reference_key.shape[2]
        position_ids = torch.arange(0, seq_len, device=reference_key.device).unsqueeze(0)
        cos, sin = self.rotary_emb(reference_key, position_ids)

        stripped_layers = []
        for layer_key, layer_value in past_key_values:
            stripped_layers.append((remove_rotary_pos_emb(layer_key, cos, sin), layer_value))
        return tuple(stripped_layers)
    
    def re_position(self, past_key_values):
        """Apply the rotary position embedding to every cached key.

        Positions ``0 .. seq_len - 1`` are used, where ``seq_len`` is the
        size of axis 2 of the first layer's key, so the entries are numbered
        consecutively in their current order. Values are passed through
        unchanged.

        Args:
            past_key_values: Iterable of per-layer ``(key, value)`` pairs.

        Returns:
            A tuple of ``(key_with_rotary, value)`` pairs.
        """
        reference_key = past_key_values[0][0]
        seq_len = reference_key.shape[2]
        re_position_ids = torch.arange(0, seq_len, device=reference_key.device).unsqueeze(0)
        re_cos, re_sin = self.rotary_emb(reference_key, re_position_ids)

        rotated_layers = []
        for layer_key, layer_value in past_key_values:
            rotated_layers.append((apply_rotary_pos_emb(layer_key, re_cos, re_sin), layer_value))
        return tuple(rotated_layers)
    
    def compress(self, input_ids, init_past_key_values=None):
        """Run the model over text plus memory slots and keep only the slot outputs.

        Args:
            input_ids: Token ids of shape ``[batch_size, seq_len]`` in which
                the memory slots are marked with ``self.mem_id``.
            init_past_key_values: Optional cache forwarded to the model as
                ``past_key_values``.

        Returns:
            A dict with, for the memory slots only:
            ``"mem_past_key_values"`` (per-layer keys/values with the rotary
            embedding removed), ``"mem_last_hidden_states"`` (last-layer
            hidden states) and ``"mem_logits"`` (output logits).
        """
        mem_positions, input_ids = self.get_positions(input_ids)
        inputs_embeds = self.add_mem(input_ids, mem_positions)

        # Encoder pass over text tokens followed by the memory slots.
        encoder_output = self.llama(
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            past_key_values=init_past_key_values,
        )

        # Strip the positional rotation first, then keep the memory slots only.
        unrotated_cache = self.remove_position(encoder_output.past_key_values)
        mem_cache = self.get_key_values(unrotated_cache, mem_positions)

        rows, cols = mem_positions.nonzero(as_tuple=True)

        def take_mem_slots(tensor):
            # (batch, seq, feature) -> (batch, num_mem, feature)
            return tensor[rows, cols, :].view(tensor.shape[0], self.num_mem, -1)

        final_hidden = encoder_output.hidden_states[-1]
        result = {
            "mem_past_key_values": mem_cache,
            "mem_last_hidden_states": take_mem_slots(final_hidden),
            "mem_logits": take_mem_slots(encoder_output.logits),
        }
        return result


class E2RAGPretraining(E2RAGBase):
    def __init__(
            self,
            llama_path: str,
            lora_config: LoraConfig,
            num_mem: int,
            mem_id: int,
            device_map,
            load_path: Optional[str] = None,
            max_length: int = 500,
            hidden_dim: Union[int, str] = "auto"
    ):
        """Load the base model and attach the "encoder" LoRA adapter.

        Args:
            llama_path: Name or path given to ``from_pretrained`` for both the
                causal LM and its tokenizer.
            lora_config: LoRA configuration of the "encoder" adapter.
            num_mem: Number of memory slots (rows of ``memory_embeddings``).
            mem_id: Token id that marks a memory slot in the input ids.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            load_path: Optional checkpoint with LoRA weights for the "encoder"
                adapter; ``None`` and ``''`` both mean "do not load".
            max_length: Stored on the instance as ``self.max_length``.
            hidden_dim: Width of the memory embeddings; ``"auto"`` uses the
                model's ``config.hidden_size``.
        """
        # Two-argument super starting *after* E2RAGBase in the MRO; E2RAGBase
        # defines no __init__ of its own.
        super(E2RAGBase, self).__init__()
        llama = AutoModelForCausalLM.from_pretrained(
            llama_path,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            trust_remote_code=True
        )
        if hidden_dim == "auto":
            hidden_dim = llama.config.hidden_size
        # The rotary module sits under `.model` or under `.transformer`,
        # depending on the loaded architecture.
        if hasattr(llama, "model"):
            self.rotary_emb = llama.model.rotary_emb
        else:
            self.rotary_emb = llama.transformer.rotary_emb
        self.tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)
        # Padding reuses the EOS token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("llama tokenizer loaded.")
        self.llama = get_peft_model(llama, lora_config, adapter_name="encoder")

        self.device = self.llama.device
        self.max_length = max_length
        self.num_mem = num_mem
        self.mem_id = mem_id
        # Filled lazily by subclasses that cache the prompt-prefix keys/values.
        self.pre_token_cache = None
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
        # One embedding vector per memory slot, randomly initialised in bfloat16.
        self.memory_embeddings = nn.Parameter(torch.randn(1, num_mem, hidden_dim, dtype=torch.bfloat16).to(self.device))

        # Names of every parameter that is neither a LoRA weight nor the memory
        # embeddings -- presumably read by the saving code to skip the frozen
        # base weights (TODO confirm against the caller).
        self._keys_to_ignore_on_save = [name for name, _ in self.named_parameters() if 'lora' not in name and 'memory_embeddings' not in name]
        if load_path is not None and load_path != '':
            self.load_lora_parameters(load_path, "encoder")

        # Freeze everything, then re-enable only the "encoder" LoRA weights.
        # NOTE(review): `memory_embeddings` contains neither 'lora' nor
        # 'encoder' in its name, so it stays frozen here -- confirm that this
        # is intended.
        for name, param in self.named_parameters():
            param.requires_grad = False
            if 'lora' in name and 'encoder' in name:
                param.requires_grad = True
                print(name)

        print(f"Total parameters of llama: {sum(p.numel() for p in self.llama.parameters())}")
    
    def forward(self, encoder_input_ids, decoder_input_ids, decoder_label):
        """Compress the encoder input and score the decoder input against it.

        Args:
            encoder_input_ids: Token ids containing the memory-slot markers;
                passed to ``compress``.
            decoder_input_ids: Token ids decoded on top of the compressed
                cache.
            decoder_label: Target ids, flattened and compared one-to-one with
                the decoder logits; ``-100`` entries are ignored by the loss.

        Returns:
            ``{'loss': ..., 'logits': ...}`` with the cross-entropy loss and
            the decoder logits.
        """
        # Encoder (LoRA active): keep only the memory-slot keys/values.
        compressed = self.compress(encoder_input_ids)
        memory_cache = self.re_position(compressed["mem_past_key_values"])

        # Decoder: the same model with the adapter switched off, attending to
        # the compressed cache.
        embed_layer = self.llama.get_input_embeddings()
        decoder_embeddings = embed_layer(decoder_input_ids)
        with self.llama.disable_adapter():
            decoder_output = self.llama(
                inputs_embeds=decoder_embeddings,
                past_key_values=memory_cache,
            )

        all_logits = decoder_output.logits
        vocab_size = all_logits.size(-1)
        flat_logits = all_logits.view(-1, vocab_size)
        flat_labels = decoder_label.view(-1)
        loss = self.criterion(flat_logits, flat_labels)

        return {'loss': loss, 'logits': all_logits}


class E2RAGFT(E2RAGPretraining):
    def forward(self, chunk_input_ids, chunk_mask, pre_prompt_tokens, a_input_ids, a_labels):
        """Compress every chunk, prepend the prompt prefix and score the answer.

        Args:
            chunk_input_ids: Chunk token ids, ``[bs, num_chunks, seq_len]``,
                each chunk containing the memory-slot markers.
            chunk_mask: ``[bs, num_chunks]`` mask; the keys/values of a chunk
                are multiplied by its mask entry, so masked chunks become zero.
            pre_prompt_tokens: Token ids of the prompt prefix placed before
                the chunks. Its keys/values are cached on the instance, so the
                prefix is assumed to be the same for every call.
            a_input_ids: Token ids decoded on top of prefix + chunks.
            a_labels: Target ids for the loss; ``-100`` entries are ignored.

        Returns:
            ``{'loss': ..., 'logits': ...}``.
        """
        ####################
        # Encoder - llama+lora
        ####################
        # Drop chunk slots that are masked out for every sample of the batch.
        index = chunk_mask.sum(dim=0).bool()
        chunk_input_ids = chunk_input_ids[:, index, :]
        chunk_mask = chunk_mask[:, index]

        chunks_past_key_values = []
        for i in range(chunk_input_ids.shape[1]):
            chunk = chunk_input_ids[:, i, :]
            chunk_past_key_values = self.compress(chunk)["mem_past_key_values"]
            mask = chunk_mask[:, i].view(-1, 1, 1, 1)
            chunks_past_key_values.append(tuple(
                (layer_key * mask, layer_value * mask)
                for layer_key, layer_value in chunk_past_key_values
            ))
        chunks_past_key_values = concate_past_key_value(chunks_past_key_values)

        ####################
        # Decoder - llama
        ####################
        # The prefix cache carries a batch dimension. Rebuild it when it is
        # missing or was computed for another batch size (e.g. a smaller last
        # batch); otherwise it cannot be joined with this batch's chunks.
        batch_size = pre_prompt_tokens.shape[0]
        cached_prefix = self.pre_token_cache
        if cached_prefix is None or cached_prefix[0][0].shape[0] != batch_size:
            with torch.no_grad():
                pre_prompt_tokens_embeddings = self.llama.get_input_embeddings()(pre_prompt_tokens)
                with self.llama.disable_adapter():
                    pre_decoder_output = self.llama(inputs_embeds=pre_prompt_tokens_embeddings)
                pre_token_cache = self.remove_position(pre_decoder_output.past_key_values)
                self.pre_token_cache = tuple(
                    (layer_key.detach(), layer_value.detach())
                    for layer_key, layer_value in pre_token_cache
                )
        pre_past_key_values = concate_past_key_value([self.pre_token_cache, chunks_past_key_values])
        # Positions are assigned again over the joined prefix + chunk cache.
        final_past_key_values = self.re_position(pre_past_key_values)

        a_tok_embeddings = self.llama.get_input_embeddings()(a_input_ids)
        # The decoder is the base model with the adapter switched off.
        with self.llama.disable_adapter():
            decoder_output = self.llama(inputs_embeds=a_tok_embeddings, past_key_values=final_past_key_values)
        all_logits = decoder_output.logits

        # target tokens: answer
        loss = self.criterion(all_logits.view(-1, all_logits.size(-1)), a_labels.view(-1))
        return {'loss': loss, 'logits': all_logits}

class E2RAGInference(E2RAGFT):
    def docompress(
            self,
            text: Union[str, list[str]],
            output_path = None):
        """Tokenize ``text``, append the memory slots and compress it.

        A single string becomes a batch of one without padding. A list of
        strings becomes one batch in which every row is
        ``text tokens + num_mem memory slots + pad tokens`` so that all rows
        have length ``max_length + num_mem``.

        A text that tokenizes to exactly one token keeps that token (it used
        to be a crash for a single string and a silent replacement by the pad
        token inside a list).

        Args:
            text: One text or a non-empty list of texts. Each text is
                truncated to ``self.max_length`` tokens.
            output_path: If given, the result dict is written there with
                ``torch.save``.

        Returns:
            The dict returned by ``compress``.

        Raises:
            TypeError: If ``text`` is neither a ``str`` nor a ``list``.
            ValueError: If ``text`` is an empty list.
        """
        def tokenize(one_text):
            # input_ids is (1, n); take row 0 instead of squeeze() so that a
            # one-token text stays 1-D. The cast covers an empty text, whose
            # id tensor may not come back as an integer tensor.
            return self.tokenizer(
                one_text,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
                add_special_tokens=False
            ).input_ids[0].to(torch.long)

        mem_slots = torch.full((self.num_mem,), self.mem_id, dtype=torch.long)

        if isinstance(text, str):
            encoder_input_ids = torch.cat([tokenize(text), mem_slots]).unsqueeze(0).to(self.device)
        elif isinstance(text, list):
            if not text:
                raise ValueError("text must contain at least one string")
            encoder_input_ids_list = []
            for one in text:
                text_tokens = tokenize(one)[:self.max_length]
                parts = [text_tokens, mem_slots]
                padding_length = self.max_length - text_tokens.size(0)
                if padding_length > 0:
                    # Padding goes after the memory slots.
                    parts.append(torch.full((padding_length,), self.tokenizer.pad_token_id, dtype=torch.long))
                encoder_input_ids_list.append(torch.cat(parts))
            encoder_input_ids = torch.stack(encoder_input_ids_list, dim=0).to(self.device)
        else:
            raise TypeError(f"text must be a str or a list of str, got {type(text).__name__}")

        self.llama.set_adapter("encoder")
        compress_res = self.compress(encoder_input_ids.to(torch.long))

        if output_path is not None:
            torch.save(compress_res, output_path)
            print(f"Saved compressed results to {output_path}")

        return compress_res
    
    @torch.no_grad()
    def qa(self, chunks: list, question: Union[list, str], max_new_tokens: int):
        """Answer questions with greedy decoding over compressed documents.

        The chat prompt is split at the ``<chunks>`` placeholder. The part
        before it is encoded once, the compressed documents are inserted
        after it, and the part after it (ending with the question) is
        prefilled before decoding starts.

        Runs without gradient tracking, since only decoded text is returned.

        Args:
            chunks: Documents. For a single question: a list of strings
                (may be empty). For a batch: one list of strings per question.
            question: One question, or a list of questions.
            max_new_tokens: Upper bound on the number of decoding steps.

        Returns:
            A list with one decoded answer string per question.
        """
        system_prompt = "You are an accurate and reliable AI assistant capable of answering questions by referencing external documents. Please note that the external documents may not always be related to the question. If the information in the documents contain the correct answer, you will provide an accurate response. If the documents do not contain the answer, you will refuse to answer."
        user_prompt = """The documents are as follows:
<chunks>

Question: """
        if isinstance(question, str):
            # bs = 1. An empty document list is allowed, so test for emptiness
            # before looking at chunks[0].
            question = [question]
            if not chunks or isinstance(chunks[0], str):
                chunks = [chunks]

        batch_size = len(question)
        pre_prompts = []
        last_prompts = []
        for q in question:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt + q},
            ]
            prompt_inputs = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompt_parts = prompt_inputs.split("<chunks>")
            pre_prompts.append(prompt_parts[0])
            last_prompts.append(prompt_parts[1])

        pre_prompt_tokens = self.tokenizer(
            pre_prompts, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
        last_prompt_tokens = [self.tokenizer(
            last_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device) for last_prompt in last_prompts]
        max_last_prompt_len = max([last_prompt_token.shape[-1] for last_prompt_token in last_prompt_tokens])
        last_prompt_pad_len = [max_last_prompt_len - last_prompt_token.shape[-1] for last_prompt_token in last_prompt_tokens]
        # The final prompt token of each sample is held back: it is the first
        # input of the decoding loop.
        last_prompt_tokens_last_token = torch.stack([last_prompt_token[:, -1] for last_prompt_token in last_prompt_tokens], dim=0)
        # Everything before that token, right-padded with EOS to a common length.
        last_prompt_tokens_input = torch.full((batch_size, max_last_prompt_len - 1), self.tokenizer.eos_token_id, dtype=torch.long, device=self.device)
        for i in range(batch_size):
            if last_prompt_pad_len[i] == 0:
                last_prompt_tokens_input[i, :] = last_prompt_tokens[i][:, :-1]
            else:
                last_prompt_tokens_input[i, :-last_prompt_pad_len[i]] = last_prompt_tokens[i][:, :-1]

        # Samples with fewer documents get placeholder documents in front;
        # chunk_mask marks the real ones.
        max_num_chunks = max([len(sample) for sample in chunks])
        pad_len = [max_num_chunks - len(sample) for sample in chunks]
        paded_chunks = [["<pad><pad><pad>"] * pad_len[i] + chunks[i] for i in range(batch_size)]
        chunk_mask = torch.zeros((batch_size, max_num_chunks), dtype=torch.bool, device=self.device)
        for i in range(batch_size):
            if len(chunks[i]) != 0:
                chunk_mask[i, -len(chunks[i]):] = True
        self.llama.set_adapter("encoder")
        chunks_past_key_values = []
        for i in range(max_num_chunks):
            current_chunk = [chunk[i] for chunk in paded_chunks]
            chunk_past_key_values = self.docompress(current_chunk)["mem_past_key_values"]
            mask = chunk_mask[:, i].view(-1, 1, 1, 1)
            # Multiplying by the mask zeroes the cache of placeholder documents.
            chunks_past_key_values.append(tuple(
                (layer_key * mask, layer_value * mask)
                for layer_key, layer_value in chunk_past_key_values
            ))

        embed_layer = self.llama.get_input_embeddings()
        pre_prompt_tokens_embeddings = embed_layer(pre_prompt_tokens)

        with self.llama.disable_adapter():
            pre_decoder_output = self.llama(inputs_embeds=pre_prompt_tokens_embeddings)
        pre_token_cache = self.remove_position(pre_decoder_output.past_key_values)
        if chunks_past_key_values:
            pre_past_key_values = concate_past_key_value(
                [pre_token_cache, concate_past_key_value(chunks_past_key_values)])
        else:
            # No documents at all: the prefix is just the prompt head.
            pre_past_key_values = pre_token_cache
        past_key_values = self.re_position(pre_past_key_values)

        last_token_embeddings = embed_layer(last_prompt_tokens_input)
        with self.llama.disable_adapter():
            prefill_decoder_output = self.llama(inputs_embeds=last_token_embeddings, past_key_values=past_key_values)

        # Shift every sample's cache to the right by its padding length: the
        # trailing EOS-padding entries are dropped and zeros take their place
        # at the front, so the real entries end at the same index everywhere.
        prefill_past_key_values = self.remove_position(prefill_decoder_output.past_key_values)
        paded_past_key_values = []
        for layer_key, layer_value in prefill_past_key_values:
            new_layer_key = torch.zeros_like(layer_key, device=self.device)
            new_layer_value = torch.zeros_like(layer_value, device=self.device)
            for i in range(batch_size):
                shift = last_prompt_pad_len[i]
                if shift == 0:
                    new_layer_key[i, :, :, :] = layer_key[i, :, :, :]
                    new_layer_value[i, :, :, :] = layer_value[i, :, :, :]
                else:
                    new_layer_key[i, :, shift:, :] = layer_key[i, :, :-shift, :]
                    new_layer_value[i, :, shift:, :] = layer_value[i, :, :-shift, :]
            paded_past_key_values.append((new_layer_key, new_layer_value))
        past_key_values = self.re_position(tuple(paded_past_key_values))

        input_tokens = last_prompt_tokens_last_token
        generated_text = [[] for _ in range(batch_size)]

        # Track whether each sequence in the batch has finished
        eos_mask = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
        eos_token_id = self.tokenizer.eos_token_id

        for _ in range(max_new_tokens):
            input_embeddings = embed_layer(input_tokens)

            with self.llama.disable_adapter():
                output = self.llama(inputs_embeds=input_embeddings, past_key_values=past_key_values)
            past_key_values = output.past_key_values
            # choose the token id with the highest probability
            next_token = torch.argmax(output.logits[:, -1, :], dim=-1)

            # 128001 is a second, hard-coded stop id -- presumably the
            # end-of-text token of the Llama-3 vocabulary (TODO confirm).
            eos_mask = eos_mask | (next_token == eos_token_id) | (next_token == 128001)

            if torch.all(eos_mask):
                break
            # Convert once per step instead of once per sample.
            step_tokens = next_token.tolist()
            finished = eos_mask.tolist()
            for batch_id in range(batch_size):
                if not finished[batch_id]:
                    generated_text[batch_id].append(step_tokens[batch_id])

            input_tokens = next_token.unsqueeze(-1)  # Make it (batch_size, 1)
            input_tokens = input_tokens.to(self.device)
            # Finished sequences keep receiving EOS as their input.
            input_tokens = input_tokens * (~eos_mask).unsqueeze(-1).long() + eos_token_id * eos_mask.unsqueeze(-1).long()

        generated_text = [self.tokenizer.decode(batch, skip_special_tokens=True) for batch in generated_text]

        return generated_text


class E2RAGEditor(E2RAGBase):
    def __init__(
            self,
            llama_path: str,
            lora_config: LoraConfig,
            num_mem: int,
            mem_id: int,
            device_map,
            editor_config: Optional[LoraConfig] = None,
            yoqo_path: Optional[str] = None,
            init_path: Optional[str] = None,
            hidden_dim: Union[int, str] = "auto",
            eval_mode: bool = False,
            num_edit: int = 4
    ):
        """Load the base model with an "editor" and an "encoder" LoRA adapter.

        Args:
            llama_path: Name or path given to ``from_pretrained`` for both the
                causal LM and its tokenizer.
            lora_config: LoRA configuration of the "encoder" adapter.
            num_mem: Number of memory slots (rows of ``memory_embeddings``).
            mem_id: Token id that marks a memory/edit slot in the input ids.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            editor_config: LoRA configuration of the "editor" adapter;
                defaults to ``lora_config``.
            yoqo_path: Optional checkpoint loaded into the "encoder" adapter.
            init_path: Optional checkpoint loaded into the "editor" adapter.
            hidden_dim: Width of the slot embeddings; ``"auto"`` uses the
                model's ``config.hidden_size``.
            eval_mode: When False, ``editing_embeddings`` is initialised from
                the leading rows of ``memory_embeddings``.
            num_edit: Number of edit slots (rows of ``editing_embeddings``).
        """
        # Two-argument super starting *after* E2RAGBase in the MRO, i.e. the
        # lookup skips every class up to and including E2RAGBase. For
        # EditorInference(E2RAGEditor, E2RAGInference) this bypasses
        # E2RAGPretraining.__init__, which would load the model a second time.
        super(E2RAGBase, self).__init__()
        llama = AutoModelForCausalLM.from_pretrained(
            llama_path,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            trust_remote_code=True
        )
        if hidden_dim == "auto":
            hidden_dim = llama.config.hidden_size
        # The rotary module sits under `.model` or under `.transformer`,
        # depending on the loaded architecture.
        if hasattr(llama, "model"):
            self.rotary_emb = llama.model.rotary_emb
        else:
            self.rotary_emb = llama.transformer.rotary_emb
        self.tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)
        # Padding reuses the EOS token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("llama tokenizer loaded.")
        if editor_config is None:
            editor_config = lora_config
        # Two adapters on one base model: "editor" first, then "encoder".
        self.llama = get_peft_model(llama, editor_config, adapter_name="editor")
        self.llama.add_adapter("encoder", lora_config)

        print(f"Total parameters of llama: {sum(p.numel() for p in self.llama.parameters())}")

        self.device = self.llama.device
        self.num_mem = num_mem
        self.mem_id = mem_id
        self.max_length = 500
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
        # Filled lazily by forward() with the prompt-prefix keys/values.
        self.pre_token_cache = None
        # Detached caches of earlier training steps, sampled by add_noise().
        self.queue = []
        self.max_queue_size = 20
        self.num_edit = num_edit

        self.memory_embeddings = nn.Parameter(torch.randn(1, num_mem, hidden_dim, dtype=torch.bfloat16).to(self.device))
        self.editing_embeddings = nn.Parameter(torch.randn(1, num_edit, hidden_dim, dtype=torch.bfloat16).to(self.device))

        if init_path is not None:
            self.load_lora_parameters(init_path, "editor")
        if yoqo_path is not None:
            self.load_lora_parameters(yoqo_path, "encoder")
        if not eval_mode:
            # Start the edit slots from the first `num_edit` memory embeddings
            # (after any checkpoint was loaded). NOTE(review): if num_edit >
            # num_mem the slice yields only num_mem rows, which silently
            # shrinks editing_embeddings.
            self.editing_embeddings.data = self.memory_embeddings.data.clone()[:, :self.editing_embeddings.shape[-2], :]

        # Freeze everything, then re-enable only the "editor" LoRA weights and
        # the editing embeddings.
        for name, param in self.named_parameters():
            param.requires_grad = False
            if ('lora' in name and 'editor' in name) or 'editing_embeddings' in name:
                param.requires_grad = True
                print(name)

    def add_edit(self, input_ids, edit_positions):
        """Embed ``input_ids`` and write the editing embeddings into the marked slots.

        Args:
            input_ids: Token ids fed to the model's input embedding layer.
            edit_positions: Boolean mask selecting the slots to overwrite. The
                number of selected slots must equal batch size times the
                number of rows in ``self.editing_embeddings``.

        Returns:
            The embedded sequence with ``self.editing_embeddings`` written at
            every selected position.
        """
        embed_layer = self.llama.get_input_embeddings()
        token_embeds = embed_layer(input_ids).to(self.device)

        # One copy of the editing embeddings per sample, flattened so that it
        # lines up with the (row, column) pairs produced by nonzero().
        batch_size = token_embeds.shape[0]
        per_sample_edit = self.editing_embeddings.repeat(batch_size, 1, 1).to(self.device)
        flat_edit = per_sample_edit.view(-1, per_sample_edit.shape[-1])

        rows, cols = edit_positions.nonzero(as_tuple=True)
        token_embeds[rows, cols, :] = flat_edit
        return token_embeds
    
    def add_noise(self, trimmed_past_key_values, delta_past_key_values):
        """Mix the edited cache with up to five caches from earlier steps.

        The current cache (memory + edit entries) is joined with a random
        sample of ``self.queue`` in shuffled order. Afterwards detached copies
        of the unedited and the edited cache are appended to the queue, which
        is cut back to its ``max_queue_size`` most recent entries.

        Args:
            trimmed_past_key_values: Per-layer ``(key, value)`` pairs of the
                memory slots.
            delta_past_key_values: Per-layer ``(key, value)`` pairs of the
                edit slots.

        Returns:
            The concatenation of the shuffled caches.
        """
        def detached(cache):
            # Copy of the cache structure without autograd history.
            return tuple((key.detach(), value.detach()) for key, value in cache)

        # Draw order matters for reproducibility: randint, sample, shuffle.
        noise_count = min(random.randint(0, 5), len(self.queue))
        distractors = random.sample(self.queue, noise_count)

        edited_cache = concate_past_key_value([trimmed_past_key_values, delta_past_key_values])
        candidates = [edited_cache, *distractors]
        random.shuffle(candidates)

        self.queue.append(detached(trimmed_past_key_values))
        self.queue.append(detached(edited_cache))
        if len(self.queue) > self.max_queue_size:
            self.queue = self.queue[-self.max_queue_size:]

        return concate_past_key_value(candidates)

    def edit(self, trimmed_past_key_values, e_input_ids):
        """Encode the edit text and append its edit-slot cache to the memory cache.

        Args:
            trimmed_past_key_values: Per-layer ``(key, value)`` pairs of the
                compressed memory slots.
            e_input_ids: Token ids of the edit text in which the edit slots
                are marked with ``self.mem_id``; ``self.num_edit`` slots per
                sample.

        Returns:
            A dict with ``"edit_past_key_values"`` (memory cache joined with
            the edit-slot cache; mixed with earlier caches while training),
            ``"edit_last_hidden_states"`` and ``"edit_logits"`` (last-layer
            hidden states and logits of the edit slots, ``self.num_edit``
            entries per sample).
        """
        edit_positions, input_ids = self.get_positions(e_input_ids)  # shape: [batch_size, seq_len]
        text_tok_embeddings = self.add_edit(input_ids, edit_positions)

        encoder_output = self.llama(inputs_embeds=text_tok_embeddings, past_key_values=None, output_hidden_states=True)
        # Strip the positional rotation, then keep the edit slots only.
        remove_position_past_key_values = self.remove_position(encoder_output.past_key_values)
        delta_past_key_values = self.get_key_values(remove_position_past_key_values, edit_positions, num=self.num_edit)
        if self.training:
            edit_past_key_values = self.add_noise(trimmed_past_key_values, delta_past_key_values)
        else:
            edit_past_key_values = concate_past_key_value([trimmed_past_key_values, delta_past_key_values])

        bs_indices, seq_indices = edit_positions.nonzero(as_tuple=True)

        def take_edit_slots(tensor):
            # There are num_edit marked slots per sample (not num_mem), so the
            # gathered rows regroup as (batch, num_edit, feature).
            return tensor[bs_indices, seq_indices, :].view(tensor.shape[0], self.num_edit, -1)

        edit_last_hidden_states = take_edit_slots(encoder_output.hidden_states[-1])
        edit_logits = take_edit_slots(encoder_output.logits)

        return {
            "edit_past_key_values": edit_past_key_values,
            "edit_last_hidden_states": edit_last_hidden_states,
            "edit_logits": edit_logits,
        }
    
    def forward(self, c_input_ids, e_input_ids, pre_prompt_tokens, a_input_ids, a_labels):
        """Compress the context, apply the edit and score the answer.

        Args:
            c_input_ids: Context token ids with memory-slot markers; encoded
                with the "encoder" adapter without gradient tracking.
            e_input_ids: Edit token ids with edit-slot markers; encoded with
                the "editor" adapter.
            pre_prompt_tokens: Token ids of the prompt prefix placed before
                the compressed cache. Its keys/values are cached on the
                instance, so the prefix is assumed to be the same for every
                call.
            a_input_ids: Token ids decoded on top of prefix + edited cache.
            a_labels: Target ids for the loss; ``-100`` entries are ignored.

        Returns:
            ``{'loss': ..., 'logits': ...}``.
        """
        ####################
        # Encoder - llama+lora
        ####################
        with torch.no_grad():
            self.llama.set_adapter("encoder")
            trimmed_past_key_values = self.compress(c_input_ids)["mem_past_key_values"]
        self.llama.set_adapter("editor")
        edited_past_key_values = self.edit(trimmed_past_key_values, e_input_ids)["edit_past_key_values"]

        ####################
        # Decoder - llama
        ####################
        # The prefix cache carries a batch dimension. Rebuild it when it is
        # missing or was computed for another batch size (e.g. a smaller last
        # batch); otherwise it cannot be joined with this batch's cache.
        batch_size = pre_prompt_tokens.shape[0]
        cached_prefix = self.pre_token_cache
        if cached_prefix is None or cached_prefix[0][0].shape[0] != batch_size:
            with torch.no_grad():
                pre_prompt_tokens_embeddings = self.llama.get_input_embeddings()(pre_prompt_tokens)
                with self.llama.disable_adapter():
                    pre_decoder_output = self.llama(inputs_embeds=pre_prompt_tokens_embeddings)
                pre_token_cache = self.remove_position(pre_decoder_output.past_key_values)
                self.pre_token_cache = tuple(
                    (layer_key.detach(), layer_value.detach())
                    for layer_key, layer_value in pre_token_cache
                )
        pre_past_key_values = concate_past_key_value([self.pre_token_cache, edited_past_key_values])
        # Positions are assigned again over the joined prefix + edited cache.
        final_past_key_values = self.re_position(pre_past_key_values)

        a_tok_embeddings = self.llama.get_input_embeddings()(a_input_ids)
        # The decoder is the base model with the adapters switched off.
        with self.llama.disable_adapter():
            decoder_output = self.llama(inputs_embeds=a_tok_embeddings, past_key_values=final_past_key_values)
        all_logits = decoder_output.logits
        # calculate the cross entropy
        loss = self.criterion(all_logits.view(-1, all_logits.size(-1)), a_labels.view(-1))

        return {'loss': loss, 'logits': all_logits}


class EditorInference(E2RAGEditor, E2RAGInference):
    def edit_qa(self, chunks: Union[str, list], question: Union[str, list], edit_text: Union[str, list], max_new_tokens: int, multi_edit: bool = False, multi_chunk: bool = False):
        system_prompt = """You are an accurate and reliable AI assistant capable of answering questions by referencing external documents. Please note that the external documents may not always be related to the question. If the information in the documents contain the correct answer, you will provide an accurate response. If the documents do not contain the answer, you will refuse to answer."""
        user_prompt = """The documents are as follows:
<chunks>

Question: """

        if isinstance(chunks, str) and isinstance(question, str):
            # bs = 1
            chunks = [chunks]
            question = [question]
        
        batch_size = len(question)
        pre_prompts = []
        last_prompts = []
        for q in question:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt + q},
            ]
            prompt_inputs = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            pre_prompts.append(prompt_inputs.split("<chunks>")[0])
            last_prompts.append(prompt_inputs.split("<chunks>")[1])        

        pre_prompt_tokens = self.tokenizer(
            pre_prompts, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
        last_prompt_tokens = [self.tokenizer(
            last_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device) for last_prompt in last_prompts]
        max_last_prompt_len = max([last_prompt_token.shape[-1] for last_prompt_token in last_prompt_tokens])
        last_prompt_pad_len = [max_last_prompt_len - last_prompt_token.shape[-1] for last_prompt_token in last_prompt_tokens]
        last_prompt_tokens_last_token = torch.stack([last_prompt_token[:, -1] for last_prompt_token in last_prompt_tokens], dim=0)
        last_prompt_tokens_input = torch.full((batch_size, max_last_prompt_len - 1), self.tokenizer.eos_token_id, dtype=torch.long, device=self.device)
        for i in range(batch_size):
            if last_prompt_pad_len[i] == 0:
                last_prompt_tokens_input[i, :] = last_prompt_tokens[i][:, :-1]
            else:
                last_prompt_tokens_input[i, :-last_prompt_pad_len[i]] = last_prompt_tokens[i][:, :-1]
        if multi_chunk:
            chunks_past_key_values = []
            for i in range(len(chunks[0])):
                current_chunk = [chunk[i] for chunk in chunks]
                chunk_compress_res = self.docompress(current_chunk)
                chunk_past_key_values = chunk_compress_res["mem_past_key_values"]
                chunks_past_key_values.append(chunk_past_key_values)
            edited_past_key_values = self.edit_compressed_kv(chunks_past_key_values[0], edit_text)
            chunks_past_key_values[0] = edited_past_key_values
            edited_past_key_values = concate_past_key_value(chunks_past_key_values)
        else:
            chunk_past_key_values = self.get_compressed_kv(chunks)
            if multi_edit:
                edited_past_key_values = self.multi_edit_compressed_kv(chunk_past_key_values, edit_text)
            else:
                edited_past_key_values = self.edit_compressed_kv(chunk_past_key_values, edit_text)

        pre_prompt_tokens_embeddings = self.llama.get_input_embeddings()(pre_prompt_tokens)
        with self.llama.disable_adapter():
            pre_decoder_output = self.llama(inputs_embeds=pre_prompt_tokens_embeddings)

        pre_token_cache = pre_decoder_output.past_key_values
        pre_token_cache = self.remove_position(pre_token_cache)
        pre_past_key_values = concate_past_key_value([pre_token_cache, edited_past_key_values])
        past_key_values = self.re_position(pre_past_key_values)

        last_token_embeddings = self.llama.get_input_embeddings()(last_prompt_tokens_input)

        with self.llama.disable_adapter():
            prefill_decoder_output = self.llama(inputs_embeds=last_token_embeddings, past_key_values=past_key_values)

        prefill_past_key_values = self.remove_position(prefill_decoder_output.past_key_values)
        paded_past_key_values = []
        for layer_key, layer_value in prefill_past_key_values:
            new_layer_key = torch.zeros_like(layer_key, device=self.device)
            new_layer_value = torch.zeros_like(layer_value, device=self.device)
            for i in range(batch_size):
                if last_prompt_pad_len[i] == 0:
                    new_layer_key[i, :, :, :] = layer_key[i, :, :, :]
                    new_layer_value[i, :, :, :] = layer_value[i, :, :, :]
                else:
                    new_layer_key[i, :, last_prompt_pad_len[i]:, :] = layer_key[i, :, :-last_prompt_pad_len[i], :]
                    new_layer_value[i, :, last_prompt_pad_len[i]:, :] = layer_value[i, :, :-last_prompt_pad_len[i], :]
            paded_past_key_values.append((new_layer_key, new_layer_value))
        past_key_values = self.re_position(tuple(paded_past_key_values))

        input_tokens = last_prompt_tokens_last_token
        generated_text = [[] for _ in range(batch_size)]
            
        # Track whether each sequence in the batch has finished
        eos_mask = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
        
        for _ in range(max_new_tokens):
            input_embeddings = self.llama.get_input_embeddings()(input_tokens)

            with self.llama.disable_adapter():
                output = self.llama(inputs_embeds=input_embeddings, past_key_values=past_key_values)

            logits = output.logits
            past_key_values = output.past_key_values
            next_token = torch.argmax(logits[:, -1, :], dim=-1)

            eos_token_id = self.tokenizer.eos_token_id
            eos_mask = eos_mask | (next_token == eos_token_id) | (next_token == 128001) 

            if torch.all(eos_mask):
                break  
            for batch_id in range(batch_size):
                if eos_mask[batch_id] == False:
                    generated_text[batch_id].append(next_token.tolist()[batch_id])

            input_tokens = next_token.unsqueeze(-1)  # Make it (batch_size, 1)
            input_tokens = input_tokens.to(self.device)
            
            input_tokens = input_tokens * (~eos_mask).unsqueeze(-1).long() + eos_token_id * eos_mask.unsqueeze(-1).long()

        generated_text = [self.tokenizer.decode(batch, skip_special_tokens=True) for batch in generated_text]

        return generated_text
    
    def get_compressed_kv(self, chunks: Union[str, list]):
        """Compress ``chunks`` and return the resulting memory KV cache.

        Activates the "encoder" adapter on ``self.llama`` (it stays active
        after the call), runs ``self.docompress`` and returns the
        ``"mem_past_key_values"`` entry of its result.
        """
        self.llama.set_adapter("encoder")
        return self.docompress(chunks)["mem_past_key_values"]
    
    def multi_edit_compressed_kv(self, old_past_key_values, edit_text: list):
        """Append one block of edit KV entries per edit to ``old_past_key_values``.

        ``edit_text`` is a list with one list of edit strings per sample; a
        flat list of strings is treated as a single sample. The number of
        edits is taken from the first sample, so every sample must provide at
        least that many (extra edits of other samples are ignored).

        For edit index ``i`` the i-th edit of every sample is tokenized,
        followed by ``self.num_edit`` placeholder tokens (``self.mem_id``) and
        right-padded to a common length. The batch is run through
        ``self.llama`` under the "editor" adapter *without* any past cache,
        and the KV entries at the placeholder positions are collected.

        Args:
            old_past_key_values: the compressed memory cache to extend.
            edit_text: edits, as described above.

        Returns:
            ``concate_past_key_value([old_past_key_values, delta_0, delta_1, ...])``.
        """
        self.llama.set_adapter("editor")
        if isinstance(edit_text[0], str):
            edit_text = [edit_text]

        pad_token_id = self.tokenizer.pad_token_id
        # Placeholder tokens whose KV entries are harvested after the forward pass.
        edit_slots = torch.full((self.num_edit,), self.mem_id, dtype=torch.long)

        delta_past_key_values_list = []
        for i in range(len(edit_text[0])):
            # Index the batch dimension instead of calling .squeeze(): a
            # one-token edit would otherwise collapse to a 0-d tensor and
            # break the length computation and concatenation below.
            edit_token_list = [
                self.tokenizer(
                    one[i],
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                    add_special_tokens=False
                ).input_ids[0]
                for one in edit_text
            ]
            max_token = max(tokens.size(0) for tokens in edit_token_list)

            edit_input_ids_list = []
            for tokens in edit_token_list:
                # Padding goes AFTER the placeholder tokens (and is empty for
                # the longest row), so it follows every position that is read.
                padding = torch.full((max_token - tokens.size(0),), pad_token_id, dtype=torch.long)
                edit_input_ids_list.append(torch.cat([tokens, edit_slots, padding]))
            edit_input_ids = torch.stack(edit_input_ids_list, dim=0).to(self.device)

            edit_positions, input_ids = self.get_positions(edit_input_ids)  # shape: [batch_size, seq_len]
            text_tok_embeddings = self.add_edit(input_ids, edit_positions)

            encoder_output = self.llama(inputs_embeds=text_tok_embeddings, past_key_values=None, output_hidden_states=True)
            # get the K V values for the encoder output
            past_key_values = encoder_output.past_key_values
            remove_position_past_key_values = self.remove_position(past_key_values)
            delta_past_key_values = self.get_key_values(remove_position_past_key_values, edit_positions, num=self.num_edit)
            delta_past_key_values_list.append(delta_past_key_values)

        edit_past_key_values = concate_past_key_value([old_past_key_values] + delta_past_key_values_list)
        return edit_past_key_values


    def edit_compressed_kv(self, old_past_key_values, edit_text: Union[str, list]):
        """Apply one edit per sample to a compressed memory cache.

        Each edit string is tokenized, followed by ``self.num_edit``
        placeholder tokens (``self.mem_id``) and right-padded with the pad id
        to the longest edit in the batch. The resulting ids are handed to
        ``self.edit`` together with ``old_past_key_values`` while the "editor"
        adapter is active (it stays active after the call).

        Args:
            old_past_key_values: the compressed memory cache to edit.
            edit_text: a single edit string or a list with one edit per sample.

        Returns:
            The ``"edit_past_key_values"`` entry of ``self.edit``'s result.
        """
        self.llama.set_adapter("editor")
        if isinstance(edit_text, str):
            edit_text = [edit_text]

        # Index the batch dimension instead of calling .squeeze(): a one-token
        # edit would otherwise collapse to a 0-d tensor and break the length
        # computation and concatenation below.
        edit_token_list = [
            self.tokenizer(
                one,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
                add_special_tokens=False
            ).input_ids[0]
            for one in edit_text
        ]
        max_token = max((tokens.size(0) for tokens in edit_token_list), default=0)

        pad_token_id = self.tokenizer.pad_token_id
        # Placeholder tokens appended to every edit.
        edit_slots = torch.full((self.num_edit,), self.mem_id, dtype=torch.long)

        edit_input_ids_list = []
        for tokens in edit_token_list:
            # Padding goes AFTER the placeholder tokens and is empty for the
            # longest row.
            padding = torch.full((max_token - tokens.size(0),), pad_token_id, dtype=torch.long)
            edit_input_ids_list.append(torch.cat([tokens, edit_slots, padding]))
        edit_input_ids = torch.stack(edit_input_ids_list, dim=0).to(self.device)

        edit_results = self.edit(old_past_key_values, edit_input_ids)
        return edit_results["edit_past_key_values"]