import transformers
from transformers import BertModel, BertTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoTokenizer, LlamaModel, AutoConfig
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
 
from dataclasses import dataclass, field
from typing import Optional, List
from peft import (
    get_peft_model,
)

from safetensors.torch import load_file
from icecream import ic as pprint
from math import sqrt
import random

import json


@dataclass
class ModelArguments:
    """Model checkpoints and LoRA hyper-parameters.

    ``model_name_or_path`` is loaded as the compressor (``LLMCompressor.llm_encoder``)
    and ``target_model`` as the target LLM / memory-fusion source in
    ``LLMCompressor.__init__``.
    """
    model_name_or_path: str = field(default=None, metadata={"help": "Path to model."})
    target_model: str = field(default=None, metadata={"help": "Path to target model."})
    
    # LoRA hyper-parameters. They are not read in the visible part of this file;
    # presumably used where the ``lora_config`` passed to LLMCompressor is built — TODO confirm.
    lora_r: int = 128
    lora_dropout: float = 0.05
    lora_alpha: int = 32

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with compressor-specific options.

    Most of the extra fields are copied onto ``LLMCompressor`` in its ``__init__``.
    """

    # If True, the target LLM and its tokenizer are loaded (stored as LLMCompressor.training).
    is_train: bool = True
    # Flags / weights copied onto LLMCompressor; consumed outside the visible code.
    ppl_memory: bool = True
    mrMR: bool = True
    reconstruct: bool = True
    alpha: float = 1.0
    # Widths used for the post_append alignment layer and memory-token embeddings.
    compressor_hidden_size: int = 4096
    target_llm_hidden_size: int = 4096
    # Number of memory-token ids appended to the compressor vocabulary.
    mem_size: int = 32
    # Context tokens per segment in LLMCompressor.split_segments.
    segment_size: int = 200
    # This field was previously declared twice (None, then "accuracy"); the later
    # declaration won, so "accuracy" is kept as the effective default.
    benchmark_metric: str = "accuracy"
    # Number of decoder layers instantiated for the compressor model.
    compressor_hidden_layers: int = 2
    # Tokens averaged into one memory slot (presumably the default merge_size
    # passed to the pooling methods — TODO confirm).
    merge_size: int = 4
    head_num: int = 8
    scale: float = 0.5
    mean: bool = True
    # Number of decoder layers instantiated for the memory fusion model.
    num_mem_fusion_layers: int = 1
    # Wrap the memory fusion model with LoRA adapters.
    mem_lora: bool = True

    post_append: bool = False
    # Sample the merge size from {4, 8} in generate_pooling_memorys.
    is_random: bool = True

    split: bool = False
    autoregressive: bool = True

    early_stopping_patience: int = 1
    # If True, freeze the compressor (and fusion model) and train only the target
    # LLM's q/k/v projections (see LLMCompressor.init).
    fine_tune: bool = False

    lm_ratio: float = 0.5
    leave_len: int = 100

    prefix_type: str = 'rs_prefix'
    full: bool = False

    keft: bool = False
    restatement_ratio: float = 0.5

    icae_infer: bool = False

    # Apply the dimension-alignment layer and memory fusion model to pooled memories.
    use_transform_layer: bool = True

@dataclass
class QGCArguments:
    """Options for a QGC-style compressor setup.

    None of these fields are read in the visible part of this file; they are
    presumably consumed by separate training/evaluation code — TODO confirm.
    """
    compressor_path: str = None
    lm_model_path: str = None
    lm_model_name: str = 'longchat'
    num_compressor_layers: int = 4
    num_compressor_encoder_layers: int = 2
    fix_compressor_mlp_parameters: bool = False
    num_attention_heads: int = 32
    attn_doc_topp: float = 0.25
    # compressor_hidden_size: int = 4096
    # lm_model_hidden_size: int = 5120
    generation_split_token: str = None
    pool_window_size: int = 4
    random_pool_window_size: bool = False
    # NOTE(review): annotated List[int] but defaults to None.
    cand_pool_window_sizes: List[int] = None

    # Label value for padded positions; -100 matches the ignore_index used by
    # the CrossEntropyLoss in this file.
    label_pad_token_id: int = -100

    # inference args
    # NOTE(review): annotated List[int] but defaults to None.
    pw_window_sizes: List[int] = None
    data_path: str = None
    train_data_path: str = None
    num_eval_documents: int = 5

    num_gold_documents: int = 1
    use_answer_as_target: bool = False
    instruction_name: str = 'base'
    # instruction_name: str = 'summary'
    gold_first_for_kd: bool = False

    min_num_documents: int = 1
    max_num_documents: int = 5
    random_num_documents: bool = False

    max_new_tokens: int = 100
    max_doc_tokens: int = 512
    
    restatement: bool = True

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for param in model.parameters():
        param.requires_grad = False
        
def freeze_mlp(model):
    """Turn off gradients for parameters whose qualified name contains ``'mlp'``."""
    mlp_params = (tensor for pname, tensor in model.named_parameters() if 'mlp' in pname)
    for tensor in mlp_params:
        tensor.requires_grad = False

def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts of ``model``.

    A model without any parameters reports ``0.0 %`` instead of raising
    ``ZeroDivisionError``.
    """
    trainable_parameters = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_parameters += count
    ratio = 100 * trainable_parameters / all_param if all_param else 0.0
    print(f"trainable params: {trainable_parameters} || all params: {all_param} || trainable: {ratio} %")

def write_txt(name, item):
    """Append ``str(item)`` plus a newline to the text file at path ``name``."""
    with open(name, "a+") as handle:
        print(item, file=handle)
    
def select_key_values(key_values, need_idx):
    """Index the second-to-last axis of every ``(key, value)`` pair.

    Returns a list with one ``[key, value]`` list (not a tuple) per input pair.
    """
    selected = []
    for key, value in key_values:
        selected.append([key[..., need_idx, :], value[..., need_idx, :]])
    return selected

def get_embeding_from_block(
    text,
    model,
    tokenizer
):
    """Return a mean-pooled embedding vector for ``text``.

    ``text`` is tokenized (truncated to 512 tokens) and run through ``model``
    without gradients. The last hidden states of the first batch item are
    averaged over the sequence axis. The first and last positions are excluded;
    they are assumed to be special tokens such as [CLS]/[SEP] (TODO confirm for
    the tokenizer in use). If no interior positions exist, all positions are
    averaged instead.

    Returns:
        A 1-D tensor of size ``hidden_size``.
    """
    encoded_input = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        output = model(**encoded_input)

    last_hidden_states = output.last_hidden_state  # [batch, seq_len, hidden]
    # Drop the boundary tokens only when something remains; an empty slice
    # would make the mean NaN.
    if last_hidden_states.shape[1] > 2:
        last_hidden_states = last_hidden_states[:, 1:-1, :]
    # Index only the batch axis. The previous extra ``[0]`` reduced the embedding
    # to its first scalar component, which broke ``query_embed @ segment_embed``.
    return torch.mean(last_hidden_states, dim=1)[0]

def mrMR_loss(
    query_embeds, 
    memory_embeds,
):
    """Max-relevance / min-redundancy loss between a query and its memory slots.

    For each batch item the query tokens are mean-pooled into one vector ``q``.
    Each memory slot ``m`` then contributes
    ``mean(e @ m over other slots e) - q @ m`` (redundancy minus relevance).
    "Other slots" are those not exactly equal to ``m``. Contributions are
    averaged over slots, then over the batch.

    Args:
        query_embeds: [batch_size, query_len, hidden_size].
        memory_embeds: [batch_size, memory_size, hidden_size].

    Returns:
        The loss (a 0-d tensor). It is ``0.0`` for an empty batch, and an item
        with no memory slots contributes nothing.
    """
    batch_size = query_embeds.shape[0]
    if batch_size == 0:
        return 0.0

    distilation_loss = 0
    for idx in range(batch_size):
        query_embed = torch.mean(query_embeds[idx], dim=0)
        slots = memory_embeds[idx]
        if len(slots) == 0:
            continue
        loss = 0
        for memory_embed in slots:
            relevance = query_embed @ memory_embed
            others = [v for v in slots if not torch.equal(memory_embed, v)]
            # With a single (or fully duplicated) slot there is nothing to be
            # redundant with; use 0 instead of dividing by zero.
            if others:
                redundancy = sum(e @ memory_embed for e in others) / len(others)
            else:
                redundancy = 0
            loss += redundancy - relevance
        distilation_loss += loss / len(slots)

    return distilation_loss / batch_size

def weight_to_mem_size(
    weights,
    mem_size      
):
    """Split ``mem_size`` memory slots proportionally to ``weights``.

    Every weight except the last gets ``int(w * mem_size)`` slots (truncated).
    The last one receives the remainder, so the result always sums to
    ``mem_size``. Empty ``weights`` gives an empty list.
    """
    if not len(weights):
        return []
    leading = [int(w * mem_size) for w in list(weights)[:-1]]
    return leading + [mem_size - sum(leading)]

def dynamic_allocator(
    segments, 
    query,
    mem_size,
    model,
    tokenizer
):
    """Distribute ``mem_size`` memory slots over ``segments`` by query similarity.

    The query and each segment are embedded with ``get_embeding_from_block``.
    Each segment is scored by its dot product with the query embedding. Scores
    are normalised by their sum and turned into integer slot counts by
    ``weight_to_mem_size``.

    NOTE(review): negative or zero-sum similarities are not guarded against.
    """
    def embed(text):
        return get_embeding_from_block(text, model, tokenizer)

    query_embed = embed(query)
    scores = [query_embed @ embed(segment) for segment in segments]
    total = sum(scores)
    weights = [1.0 * score / total for score in scores]
    return weight_to_mem_size(weights, mem_size)

class MemoryFusion(nn.Module):
    """Multi-head self-attention over a sequence, followed by mean pooling.

    ``X`` is projected to queries, keys and values and attended per head with
    scaled dot-product attention. The result goes through the output projection
    ``fct`` and is averaged over the sequence axis, giving one ``d_model``
    vector per batch item.
    """

    def __init__(self, d_model, num_heads, dtype=torch.bfloat16):
        """
        Args:
            d_model: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            dtype: dtype of the projection weights and of the computation
                (bfloat16 by default, as before).
        """
        super(MemoryFusion, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.compute_dtype = dtype

        self.q_proj = nn.Linear(d_model, d_model).to(dtype)
        self.k_proj = nn.Linear(d_model, d_model).to(dtype)
        self.v_proj = nn.Linear(d_model, d_model).to(dtype)
        self.fct = nn.Linear(d_model, d_model).to(dtype)

    def _split_heads(self, t):
        """Reshape [batch, seq, d_model] into [batch, num_heads, seq, d_k]."""
        batch_size, seq_len, _ = t.size()
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, X, attention_mask=None):
        """Attend over ``X`` and mean-pool the result.

        Args:
            X: [batch_size, seq_len, d_model]; cast to the module dtype.
            attention_mask: optional [batch_size, seq_len]. Positions equal to 0
                are excluded both as attention keys and from the final mean.

        Returns:
            Tensor of shape [batch_size, d_model].
        """
        X = X.to(self.compute_dtype)
        batch_size, seq_len, _ = X.size()

        Q = self._split_heads(self.q_proj(X))
        K = self._split_heads(self.k_proj(X))
        V = self._split_heads(self.v_proj(X))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len]: mask keys for every head/query.
            key_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Most negative finite value of the dtype; unlike -1e9 this is also valid for float16.
            scores = scores.masked_fill(key_mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.fct(output)

        if attention_mask is None:
            return torch.mean(output, dim=1)

        # Average only over valid positions so padding does not leak into the
        # pooled vector. Accumulate in float32 and clamp the count to avoid 0/0
        # for fully masked rows.
        valid = (attention_mask != 0).unsqueeze(-1).to(torch.float32)  # [batch, seq, 1]
        pooled = (output.float() * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
        return pooled.to(output.dtype)

class LLMCompressor(nn.Module):
    def __init__(
            self,
            model_args, 
            training_args,
            lora_config,
            fusion_lora_config
        ):
        """Build the compressor, the optional memory-fusion stack and, when training, the target LLM.

        Args:
            model_args: provides ``model_name_or_path`` (compressor checkpoint)
                and ``target_model`` (target LLM checkpoint).
            training_args: compressor options; see the extra fields of
                ``TrainingArguments``.
            lora_config: PEFT config applied to the compressor and, when
                ``training_args.mem_lora`` is set, to the memory fusion model.
            fusion_lora_config: not used by this method (only referenced in a
                commented-out line).

        Side effects:
            Loads models and tokenizers via ``from_pretrained``. When
            ``is_train`` is true it also loads the target LLM and calls
            ``self.init()`` to freeze weights.

        NOTE(review):
            - ``self.training`` (the ``nn.Module`` train/eval flag) is overwritten
              with ``is_train``; a later ``.train()``/``.eval()`` call resets it.
            - ``semantic_alignment_layer`` uses hidden sizes from
              ``training_args`` and float16, while the other added layers use
              sizes from the model configs and bfloat16.
            - ``memory_fusion_layer`` exists only when ``post_append`` is false
              and ``use_transform_layer`` is true.
        """
        
        super().__init__()

        self.model_args = model_args
        self.training_args = training_args
        
        self.model_name = model_args.model_name_or_path
        self.target_llm = model_args.target_model

        # NOTE: shadows nn.Module's train/eval flag (see docstring).
        self.training = training_args.is_train
        self.ppl_memory = training_args.ppl_memory
        self.post_append = training_args.post_append
        # pprint(self.ppl_memory)
        self.mrMR = training_args.mrMR
        self.alpha = training_args.alpha
        self.reconstruct = training_args.reconstruct
        self.merge_size = training_args.merge_size
        self.compressor_hidden_size = training_args.compressor_hidden_size
        self.llm_hidden_size = training_args.target_llm_hidden_size
        self.mem_size = training_args.mem_size
        self.segment_size = training_args.segment_size
        self.mean = training_args.mean
        self.split = training_args.split
        self.autoregressive = training_args.autoregressive
        self.fine_tune = training_args.fine_tune
        self.lm_ratio = training_args.lm_ratio
        self.scale = training_args.scale
        self.leave_len = training_args.leave_len
        self.full = training_args.full
        self.keft = training_args.keft
        self.restatement_ratio = training_args.restatement_ratio
        self.use_transform_layer = training_args.use_transform_layer
        
        # self.memory_fusion_layer = get_peft_model(self.memory_fusion_layer, fusion_lora_config)
        # exit(0)
        # Configs of the compressor checkpoint and of the target LLM; the latter
        # is also used to build the memory fusion model below.
        self.compressor_config = AutoConfig.from_pretrained(self.model_name)
        self.fusion_layer_config = AutoConfig.from_pretrained(self.target_llm)
        
        self.compressor_size = self.compressor_config.hidden_size
        self.target_size = self.fusion_layer_config.hidden_size
        
        # Instantiate a shallower compressor with only this many hidden layers.
        self.compressor_config.num_hidden_layers = training_args.compressor_hidden_layers
        # pprint(self.compressor_config)
        # exit(0)
        self.llm_encoder = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            config=self.compressor_config
        )
        
        self.llm_encoder = self.llm_encoder.to(torch.bfloat16)
        # if not self.fine_tune:
        # Wrap the compressor with LoRA adapters.
        self.llm_encoder = get_peft_model(self.llm_encoder, lora_config)
        self.is_random = training_args.is_random
        if self.post_append:
            self.semantic_alignment_layer = nn.Linear(self.compressor_hidden_size \
                                                , self.llm_hidden_size).to(dtype=torch.float16)
        
        # load memory fusion layer
        if not self.post_append:
            # The fusion model is the target LLM architecture with only
            # num_mem_fusion_layers hidden layers.
            self.fusion_layer_config.num_hidden_layers = training_args.num_mem_fusion_layers
            if self.use_transform_layer:
                # Project compressor states to the target width when they differ.
                if self.compressor_size != self.target_size:
                    self.dimension_alignment_layer = nn.Linear(self.compressor_size \
                                                        , self.target_size).to(dtype=torch.bfloat16) 
                     
                self.memory_fusion_layer = AutoModelForCausalLM.from_pretrained(
                    self.target_llm,
                    config=self.fusion_layer_config
                )
                self.memory_fusion_layer = self.memory_fusion_layer.to(torch.bfloat16)
                if training_args.mem_lora:
                    # NOTE(review): reuses lora_config rather than fusion_lora_config.
                    self.memory_fusion_layer = get_peft_model(self.memory_fusion_layer, lora_config)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.vocab_size = self.llm_encoder.config.vocab_size   # original vocab size; memory-token ids start here
        # Pad with the unk token (assumes the tokenizer defines one).
        self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.vocab_size_with_mem = self.vocab_size + self.mem_size

        # resize embeddings so the mem_size memory-token ids have embedding rows
        self.llm_encoder.resize_token_embeddings(self.vocab_size + self.mem_size)
        # self.llm.resize_token_embeddings(self.llm_vocab_size)

        if self.post_append:
            self.memory_token_embed = nn.Embedding(self.mem_size, self.compressor_hidden_size, padding_idx=None)
        # self.special_token_embed = nn.Embedding(1, self.compressor_hidden_size, padding_idx=None)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        # Unreduced (per-element) cross-entropy.
        self.ppl_fct = nn.CrossEntropyLoss(reduction="none")
        self.icae_infer = training_args.icae_infer
        # self.contrast_loss = InfoNCELoss(temperature=0.1)
        # generate 1-d tensor of memory-token ids: [vocab_size, vocab_size + mem_size)
        if self.post_append:
            self.memory_sequence = torch.arange(self.vocab_size, self.vocab_size + self.mem_size)
            
        # The target LLM is only loaded for training; self.init() then freezes
        # the appropriate parts.
        if  self.training:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(self.target_llm)
            self.llm = AutoModelForCausalLM.from_pretrained(
                self.target_llm,
                torch_dtype=torch.bfloat16
            )
            
            # if self.fine_tune:
            # self.llm = get_peft_model(self.llm, lora_config)

            self.llm_vocab_size = self.llm.config.vocab_size
            self.llm_tokenizer.pad_token_id = self.llm_tokenizer.unk_token_id
            self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
            
            self.init()
    
    def generate_merge_size(
        self
    ):
        """Return a pooling merge size drawn uniformly at random from ``[4, 8]``."""
        return random.choice([4, 8])
    
    def generate_post_append_size(
        self  
    ):
        """Return a post-append size drawn uniformly at random from ``[64, 128]``."""
        return random.choice([64, 128])
    
    def mrMR_loss2(
        self,
        query_embeds, 
        memory_embeds,
        query_attention_mask
    ):
        # q : [batch_size, query_len, hidden_size]
        # memory_embeds : [batch_size, memory_size, hidden_size]
        
        distilation_loss = 0
        batch_size = query_embeds.shape[0]
        # pprint(memory_embeds.shape)
        # exit(0)
        
        for idx in range(batch_size):
            effect_query_embeds = torch.cat([x.unsqueeze(0) for x, y in zip(query_embeds[idx], query_attention_mask[idx]) if y == 1], dim=0)
            query_embed = torch.mean(effect_query_embeds, dim=0).squeeze(0)
            # pprint(query_embed.shape)
            loss = 0
            for memory_embed in memory_embeds[idx]:
                debias_memory_embeds = torch.cat([v.unsqueeze(0) for v in memory_embeds[idx] if not torch.equal(memory_embed, v)], dim=0)
                # pprint(debias_memory_embeds.shape)
                # exit(0)
                loss += self.infoNCE_loss(
                    query_embed=query_embed,
                    memory_embed=memory_embed,
                    other_memorys_embed=debias_memory_embeds
                )
            # distilation_loss += (1.0 * loss / memory_embeds.shape[1])
            distilation_loss += loss

        return distilation_loss
    
    def max_relevence_score(
        self,
        query_embeds, 
        memory_embeds,
        query_attention_mask
    ):
        # q : [batch_size, query_len, hidden_size]
        # memory_embeds : [batch_size, memory_size, hidden_size]
        
        total_r = 0
        batch_size = query_embeds.shape[0]
        
        for idx in range(batch_size):
            effect_query_embeds = torch.cat([x.unsqueeze(0) for x, y in zip(query_embeds[idx], query_attention_mask[idx]) if y == 1], dim=0)
            query_embed = torch.mean(effect_query_embeds, dim=0)
            # pprint(query_embed.shape)
            max_r = self.max_relevence(
                query_embed=query_embed,
                mermory_embeds=memory_embeds[idx]    
            )
            total_r += max_r

        return 1.0 * total_r / batch_size
    
    # def distillation_loss(
    #     logits,
    #     target_logits
    # ):

    def infoNCE_loss(
        self, 
        query_embed, 
        memory_embed, 
        other_memorys_embed, 
        temperature=0.1
    ):
        # query_embed: torch.size([768])
        # memory_embed: torch.size([768])
        # other_memorys_embed: torch.size([8, 768])

        memory_embed = memory_embed.unsqueeze(0)  # [1, 768]
        other_memorys_embed = other_memorys_embed  # [8, 768]

        positive_similarity = torch.matmul(query_embed.unsqueeze(0), memory_embed.T) / temperature  # [1, 1]
        negative_similarity = torch.matmul(query_embed.unsqueeze(0), other_memorys_embed.T) / temperature  # [1, 8]

        all_similarity = torch.cat([positive_similarity, negative_similarity], dim=1)  # [1, 9]
        # all_similarity = all_similarity.float() 

        labels = torch.zeros(all_similarity.size(0), dtype=torch.long).to(query_embed.device)  # [9]
        # labels = labels.float()
        # pprint(all_similarity.shape)
        # pprint(labels.shape)
        # exit(0)
        loss = self.loss_fct(all_similarity, labels)
        return loss
    
    def vertical_pooling(
        self,
        query_embeds,
        memorys_hidden_state,
        query_attention_masks,
    ):
        # query_embeds : [batch_size, que_len, hidden_size]
        # memorys_hidden_state : Tuple([batch_size, mem_size, hidden_size]) Tuple_size : 13

        batch_size = query_embeds.shape[0]
        
        batch_global_memorys = []
        for idx in range(batch_size):
            effect_query_embeds = torch.cat([x.unsqueeze(0) for x, y in zip(query_embeds[idx], query_attention_masks[idx]) if y == 1], dim=0)
            # query_embed : [768]
            query_embed = torch.mean(effect_query_embeds, dim=0)
            q = self.q_proj(query_embed)
            # current_memorys_hidden_state : [13, mem_size, 768]
            current_memorys_hidden_state = torch.cat([e[idx].unsqueeze(0) for e in memorys_hidden_state], dim=0)
            k = self.k_proj(current_memorys_hidden_state)
            v = self.v_proj(current_memorys_hidden_state)
            # weights_matrix : [13, mem_size]
            weights_matrix = torch.matmul(k, q) / sqrt(768)
            # pprint(weights_matrix.shape)
            weights_matrix = torch.softmax(weights_matrix, dim=0)
            weights_matrix_expanded = weights_matrix.unsqueeze(2)  

            weighted_memory = weights_matrix_expanded * v

            global_memorys = torch.sum(weighted_memory, dim=0)
            # pprint(global_memorys.shape)
            batch_global_memorys.append(global_memorys.unsqueeze(0))
        
        return torch.cat(batch_global_memorys, dim=0)

    def max_relevence(
        self,
        query_embed,
        mermory_embeds
    ):
        """Mean dot-product similarity between ``query_embed`` and each memory row.

        Args:
            query_embed: query vector, ``[hidden]`` or ``[1, hidden]``.
            mermory_embeds: ``[N, hidden]`` memory matrix.

        Returns:
            0-d tensor holding the mean of the ``N`` similarities.
        """
        per_memory = (query_embed @ mermory_embeds.T).squeeze(0)
        return per_memory.mean(dim=0)

    def freeze_others(
        self
    ):
        for name, param in self.llm.named_parameters():
            if not ('q_proj' in name or 'k_proj' in name or 'v_proj' in name):
            # if not ('q_proj.lora' in name or 'k_proj.lora' in name or 'v_proj.lora' in name):
                param.requires_grad = False

    def init(self):
        """Freeze the non-trained components, switch them to eval mode and report counts.

        Without ``fine_tune``, the target LLM is frozen and put in eval mode.
        With ``fine_tune``, the compressor (and the memory fusion model when
        ``use_transform_layer``) is frozen and put in eval mode, and the target
        LLM is frozen except for its q/k/v projections (``freeze_others``).
        """
        print("Freezing the target LLM...")

        if self.fine_tune:
            freeze_model(self.llm_encoder)
            if self.use_transform_layer:
                freeze_model(self.memory_fusion_layer)
            self.freeze_others()
            self.llm_encoder.eval()
            if self.use_transform_layer:
                self.memory_fusion_layer.eval()
        else:
            freeze_model(self.llm)
            self.llm.eval()

        print_trainable_parameters(self)

    def split_segments(
        self,
        context_input_ids,
        memory_sequence,
        input_mask
    ):
        # pprint(context_input_ids.shape)
        batch_size = context_input_ids.shape[0]
        memory_sequence = memory_sequence.repeat(batch_size, 1)
        # pprint(memory_sequence.shape)
        num_segments = math.ceil(context_input_ids.shape[1] * 1.0 / self.segment_size)
        # pprint(self.segment_size)
        # exit(0)
        segments_input_ids = [context_input_ids[:, i * self.segment_size: (i + 1) \
                                        * self.segment_size] for i in range(num_segments - 1)]
        segments_input_mask = [input_mask[:, i * self.segment_size: (i + 1) \
                                        * self.segment_size] for i in range(num_segments - 1)]
        last_segment_start_index = (num_segments - 1) * self.segment_size
        segments_input_ids.append(context_input_ids[:, last_segment_start_index:])
        segments_input_mask.append(input_mask[:, last_segment_start_index:])

        memory_len = math.ceil(memory_sequence.shape[1] * 1.0 / num_segments)
        memory_sequences = [memory_sequence[:, i * memory_len : (i + 1)\
                                            * memory_len] for i in range(num_segments - 1)]
        memory_sequences.append(memory_sequence[:, (num_segments - 1) * memory_len:])
        return segments_input_ids, memory_sequences, segments_input_mask

    def compute_num_segments(self, total_length):
        assert total_length > 0
        num_segments = math.ceil(total_length * 1.0 / self.segment_size)
        return num_segments

    def get_avg_embeds(self, tensor, attention_mask):
        """Mask-weighted mean of ``tensor`` over its second axis.

        Args:
            tensor: [batch_size, merge_size, hidden_size].
            attention_mask: [batch_size, merge_size]; entries weight the sum and
                their total is the divisor (a 0/1 mask gives a plain masked mean).

        Returns:
            [batch_size, hidden_size]. Rows whose mask sums to zero yield zeros
            instead of NaN.
        """
        attention_mask = attention_mask.unsqueeze(-1)  # [batch_size, merge_size, 1]
        sum_tensor = (tensor * attention_mask).sum(dim=1)  # [batch_size, hidden_size]
        valid_counts = attention_mask.sum(dim=1)  # [batch_size, 1]
        # Only zero counts are replaced, so fractional (soft) masks keep their
        # original divisor while fully masked rows give 0 / 1 = 0.
        valid_counts = valid_counts.masked_fill(valid_counts == 0, 1)
        return sum_tensor / valid_counts

    def generate_pooling_memorys(
        self,
        input_ids,
        input_mask,
        merge_size
    ):
        device = input_ids.device
        if self.is_random:
            merge_size = self.generate_merge_size()
            
        last_hidden_state = self.llm_encoder(
            input_ids=input_ids,
            attention_mask=input_mask,
            output_hidden_states=True
        ).hidden_states[-1]
        
        batch_size = input_ids.shape[0]
        hidden_size = last_hidden_state.shape[2]
        memorys_list = []
        for i in range(batch_size):
            current_mask = input_mask[i]
            is_all_zero = torch.all(current_mask == 0).item()
            # pprint(is_all_zero)
            if is_all_zero:
                sequence_length = last_hidden_state.shape[0]
                if sequence_length % merge_size == 0:
                    pad_length = 0
                else:
                    pad_length = merge_size - (sequence_length % merge_size)
                
                mem_size = int((sequence_length + pad_length) / merge_size)
                memorys = torch.zeros(mem_size, hidden_size).unsqueeze(0).to(device)
                # pprint(memorys.shape)
                memorys_list.append(memorys)
                continue
            
            current_hidden_state = last_hidden_state[i]
            
            select_hidden_state = torch.cat([y.unsqueeze(0) for x, y in zip(current_mask, \
                current_hidden_state) if x == 1], dim=0)
            current_mask = torch.tensor([x for x in current_mask if x == 1]).to(device)
            
            # # select_hidden_state: [seq, hidden_size]
            # generate memorys then transform
            sequence_length = select_hidden_state.shape[0]
            res_len = sequence_length % merge_size
            if res_len != 0:
                res_hidden_state = select_hidden_state[-res_len:,:]
                res_memory = torch.mean(res_hidden_state, dim=0).unsqueeze(0)
                select_hidden_state = select_hidden_state[:-res_len, :]
                
            hidden_state_reshaped = select_hidden_state.reshape(-1, merge_size, hidden_size)
            hidden_state_reshaped = hidden_state_reshaped.to(torch.bfloat16)
            memorys = torch.mean(hidden_state_reshaped, dim=1)
            
            if res_len != 0:
                memorys = torch.cat((memorys, res_memory), dim=0).unsqueeze(0)
            else:
                memorys = memorys.unsqueeze(0)
            
            # if sequence_length % merge_size == 0:
            #     pad_length = 0
            # else:
            #     pad_length = merge_size - (sequence_length % merge_size)
            
            # pad_embeds = torch.zeros(pad_length, select_hidden_state.shape[1]).to(device)
            # pad_mask = torch.zeros(pad_length).to(device)
            # total_hidden_state = torch.cat((select_hidden_state, pad_embeds), dim=0)
            # pad_att_mask = torch.cat((current_mask, pad_mask), dim=0)
            # # pad_att_mask: [mem_len, merge_size]
            # pad_att_mask = pad_att_mask.reshape(-1, merge_size)
            # # hidden_state_reshaped.shape: [mem_len, merge_size, hidden_size]
            # hidden_state_reshaped = total_hidden_state.reshape(-1, merge_size, select_hidden_state.shape[1])
            
            # hidden_state_reshaped = hidden_state_reshaped.to(torch.bfloat16)
            # memorys = self.memory_fusion_layer(
            #     inputs_embeds=hidden_state_reshaped,
            #     output_hidden_states=True,
            #     attention_mask=pad_att_mask
            # ).hidden_states[-1]
            # # memorys = torch.mean(hidden_state_reshaped, dim=1).unsqueeze(0)
            # memorys = self.get_avg_embeds(memorys, pad_att_mask).unsqueeze(0)
            memorys_list.append(memorys)
            
        # padding the memorys
        m_l = [e.shape[1] for e in memorys_list]
        max_len = max(m_l)
        final_memorys_list = []
        att_mask = torch.ones(batch_size, max_len)
        # pprint(att_mask.shape)
        for idx, e in enumerate(memorys_list):
            # pprint(idx)
            # pprint(e.shape)
            # pad the memorys embeddings
            pad_len = max_len - e.shape[1]
            if pad_len == 0:
                pad_memorys = e
            else:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2]).to(device)
                pad_memorys = torch.cat((e, pad_embeds), dim=1)   
            
            is_all_zero = torch.all(e == 0).item()
            if is_all_zero:
                pad_mask = torch.zeros(max_len).to(device)
                att_mask[idx] = pad_mask
            else:
                if pad_len != 0:
                    pad_mask = torch.zeros(pad_len).to(device)
                    att_mask[idx][-pad_len:] = pad_mask
                    # pprint(pad_len)
                    # pprint(pad_mask.shape)
            final_memorys_list.append(pad_memorys)

        att_mask = att_mask.to(device)
        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)
        if self.use_transform_layer:
            if self.compressor_size != self.target_size:
                final_memorys = self.dimension_alignment_layer(final_memorys)
                
            aligned_memorys = self.memory_fusion_layer(
                inputs_embeds=final_memorys,
                attention_mask=att_mask,
                output_hidden_states=True
            ).hidden_states[-1]
        else:
            aligned_memorys = final_memorys
        
        return aligned_memorys, att_mask

    def generate_autoregressive_pooling_memorys(
        self,
        input_ids,
        input_mask,
        merge_size,
        pre_mem_embeds=None,
        pre_mem_attention_mask=None
    ):
        # pprint(input_ids.shape)
        # pprint(input_mask.shape)
        device = input_ids.device
        if pre_mem_embeds is not None:
            input_embeds = self.llm.model.embed_tokens(input_ids)
            final_input_embeds = torch.cat((pre_mem_embeds, input_embeds), dim=1)
            final_attention_mask = torch.cat((pre_mem_attention_mask, input_mask), dim=1)
            last_hidden_state = self.llm_encoder(
                inputs_embeds=final_input_embeds,
                attention_mask=final_attention_mask,
                output_hidden_states=True
            ).hidden_states[-1][:, pre_mem_embeds.shape[1]:, :]
        else:
            last_hidden_state = self.llm_encoder(
                input_ids=input_ids,
                attention_mask=input_mask,
                output_hidden_states=True
            ).hidden_states[-1]
        
        batch_size = input_ids.shape[0]
        hidden_size = last_hidden_state.shape[2]
        memorys_list = []
        for i in range(batch_size):
            current_mask = input_mask[i]
            is_all_zero = torch.all(current_mask == 0).item()
            # pprint(is_all_zero)
            if is_all_zero:
                sequence_length = last_hidden_state.shape[0]
                if sequence_length % merge_size == 0:
                    pad_length = 0
                else:
                    pad_length = merge_size - (sequence_length % merge_size)
                
                mem_size = int((sequence_length + pad_length) / merge_size)
                memorys = torch.zeros(mem_size, hidden_size).unsqueeze(0).to(device)
                # pprint(memorys.shape)
                memorys_list.append(memorys)
                continue
            
            current_hidden_state = last_hidden_state[i]
            
            # pprint(current_hidden_state.shape)
            # pprint(current_mask)
            select_hidden_state = torch.cat([y.unsqueeze(0) for x, y in zip(current_mask, \
                current_hidden_state) if x == 1], dim=0)
            # pprint(select_hidden_state.shape)
            # exit(0)
            
            sequence_length = select_hidden_state.shape[0]
            if sequence_length % merge_size == 0:
                pad_length = 0
            else:
                pad_length = merge_size - (sequence_length % merge_size)
            
            padding = torch.zeros(pad_length, hidden_size).to(device)
            total_hidden_state = torch.cat((select_hidden_state, padding), dim=0)
            # hidden_state_reshaped.shape: [mem_len, merge_size, hidden_size]
            hidden_state_reshaped = total_hidden_state.reshape(-1, merge_size, hidden_size)
            # pprint(hidden_state_reshaped.shape)
            
            hidden_state_reshaped = hidden_state_reshaped.to(torch.bfloat16)
            memorys = self.memory_fusion_layer(
                inputs_embeds=hidden_state_reshaped,
                output_hidden_states=True
            ).hidden_states[-1]
            memorys = torch.mean(memorys, dim=1).unsqueeze(0)
            # pprint(memorys.shape)
            # exit(0)
            memorys_list.append(memorys)
        # exit(0)
            

        # padding the memorys
        m_l = [e.shape[1] for e in memorys_list]
        max_len = max(m_l)
        pprint(max_len)
        final_memorys_list = []
        att_mask = torch.ones(batch_size, max_len)
        # pprint(att_mask.shape)
        for idx, e in enumerate(memorys_list):
            # pprint(idx)
            # pprint(e.shape)
            # pad the memorys embeddings
            pad_len = max_len - e.shape[1]
            if pad_len == 0:
                pad_memorys = e
            else:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2]).to(device)
                pad_memorys = torch.cat((e, pad_embeds), dim=1)   
            
            is_all_zero = torch.all(e == 0).item()
            if is_all_zero:
                pad_mask = torch.zeros(max_len).to(device)
                att_mask[idx] = pad_mask
            else:
                if pad_len != 0:
                    pad_mask = torch.zeros(pad_len).to(device)
                    att_mask[idx][-pad_len:] = pad_mask
                    # pprint(pad_len)
                    # pprint(pad_mask.shape)
            final_memorys_list.append(pad_memorys)

        # pprint(final_memorys_list[0].shape)
        # pprint(final_memorys_list[1].shape)
        # pprint(att_mask)
        # exit(0)
        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)
        aligned_memorys = final_memorys
        # self.semantic_alignment_layer.weight.data = self.semantic_alignment_layer.weight.data.to(torch.bfloat16)
        # pprint(final_memorys.dtype)
        # pprint(self.semantic_alignment_layer.weight.dtype)
        # exit(0)
        # aligned_memorys = self.semantic_alignment_layer(final_memorys).to(torch.bfloat16)
        att_mask = att_mask.to(device)
        
        if pre_mem_embeds is not None:
            total_memorys = torch.cat((pre_mem_embeds, aligned_memorys), dim=1)
            total_attention_mask = torch.cat((pre_mem_attention_mask, att_mask), dim=1)
        else:
            total_memorys = aligned_memorys
            total_attention_mask = att_mask
        
        return total_memorys, total_attention_mask

    def generate_memorys_then_transform(
        self,
        input_ids,
        input_mask,
    ):
        """Mean-pool encoder states into memories, then run them through ``memory_fusion_layer``.

        For every sample, the valid (``mask == 1``) last-layer states are
        right-padded with zeros to a multiple of ``self.merge_size``. Each
        consecutive group of ``merge_size`` states is then averaged into one
        memory vector. Rows are right-padded to a common memory length and
        stacked as bfloat16 before fusion.

        A sample whose mask is entirely zero gets all-zero placeholder memories
        (``ceil(padded_seq_len / merge_size)`` of them), and its attention-mask
        row is zeroed.

        Returns:
            (aligned_memorys, att_mask): the fused memories
            ``[batch, max_mem_len, hidden]`` and a float mask
            ``[batch, max_mem_len]`` with 1 for real memories and 0 for padding.
        """
        device = input_ids.device
        last_hidden_state = self.llm_encoder(
            input_ids=input_ids,
            attention_mask=input_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        batch_size, padded_length, hidden_size = last_hidden_state.shape
        memorys_list = []
        empty_rows = []
        for i in range(batch_size):
            current_mask = input_mask[i]
            if torch.all(current_mask == 0).item():
                # Fully padded sample: emit zero placeholders sized from the padded
                # sequence length (dim 1, not the batch dim); the row is masked out below.
                mem_size = math.ceil(padded_length / self.merge_size)
                memorys_list.append(torch.zeros(1, mem_size, hidden_size, device=device))
                empty_rows.append(i)
                continue

            # Keep only the valid positions, in order.
            select_hidden_state = last_hidden_state[i][current_mask == 1]

            # Zero-pad so the valid length is a multiple of merge_size.
            pad_length = (-select_hidden_state.shape[0]) % self.merge_size
            padding = torch.zeros(pad_length, hidden_size, device=device)
            total_hidden_state = torch.cat((select_hidden_state, padding), dim=0)

            # [mem_len, merge_size, hidden_size] -> mean over each group -> [1, mem_len, hidden_size]
            hidden_state_reshaped = total_hidden_state.reshape(
                -1, self.merge_size, hidden_size
            ).to(torch.bfloat16)
            memorys_list.append(torch.mean(hidden_state_reshaped, dim=1).unsqueeze(0))

        # Right-pad every row to the longest memory sequence and mask the padding.
        max_len = max(e.shape[1] for e in memorys_list)
        att_mask = torch.ones(batch_size, max_len, device=device)
        final_memorys_list = []
        for idx, e in enumerate(memorys_list):
            pad_len = max_len - e.shape[1]
            if pad_len > 0:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2], device=device)
                e = torch.cat((e, pad_embeds), dim=1)
                att_mask[idx, -pad_len:] = 0
            final_memorys_list.append(e)
        # Samples with no valid input tokens contribute no real memories.
        for idx in empty_rows:
            att_mask[idx] = 0

        # final_memorys.shape: [batch_size, mem_len, hidden_size]
        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)
        # NOTE(review): unlike generate_query_guided_pooling_memorys, no attention_mask is
        # passed to the fusion layer here; padding sits at the end of each row.
        aligned_memorys = self.memory_fusion_layer(
            inputs_embeds=final_memorys,
            output_hidden_states=True
        ).hidden_states[-1]

        return aligned_memorys, att_mask
    
    def get_mean_query_embeddings(
        self,
        query_hidden_state, # [batch_size, actual_q_len, hidden_size]
        query_mask
    ):
        """Average the query hidden states over their valid (masked-in) positions.

        Args:
            query_hidden_state: ``[batch_size, query_len, hidden_size]`` states.
            query_mask: ``[batch_size, query_len]`` mask; nonzero marks valid tokens.

        Returns:
            ``[batch_size, hidden_size]`` masked mean. A row with no valid
            tokens yields zeros instead of NaN/inf.
        """
        mask = query_mask.unsqueeze(-1)  # [batch_size, query_len, 1]

        masked_hidden_state = query_hidden_state * mask  # [batch_size, query_len, hidden_size]

        # Clamp so fully padded queries divide by 1 (giving zeros) rather than by 0.
        valid_length = mask.sum(dim=1).clamp(min=1)  # [batch_size, 1]

        mean_query_embeddings = masked_hidden_state.sum(dim=1) / valid_length  # [batch_size, hidden_size]
        return mean_query_embeddings

    def weighted_pooling(
        self,
        mean_query_embedding, 
        hidden_state_reshaped
    ):
        """Collapse each group of token states into one vector, weighted by query relevance.

        Args:
            mean_query_embedding: ``[hidden_size]`` query summary vector.
            hidden_state_reshaped: ``[mem_len, merge_size, hidden_size]`` grouped states.

        Returns:
            ``[mem_len, hidden_size]``. Each group is the softmax-weighted sum of
            its members, where the logits are raw (unscaled) dot products with
            the query.
        """
        # Per-token relevance logits inside each group: [mem_len, merge_size].
        # NOTE: no 1/sqrt(hidden_size) temperature is applied.
        relevance = torch.einsum("msh,h->ms", hidden_state_reshaped, mean_query_embedding)
        group_weights = F.softmax(relevance, dim=-1)
        # Weighted sum across the merge dimension: [mem_len, hidden_size].
        return torch.einsum("msh,ms->mh", hidden_state_reshaped, group_weights)
        
    def res_weighted_pooling(
        self,
        mean_query_embedding, 
        res_hidden_state
    ):
        """Pool the leftover (non-full-group) token states into one vector.

        Args:
            mean_query_embedding: ``[hidden_size]`` query summary vector.
            res_hidden_state: ``[res_len, hidden_size]`` remaining token states.

        Returns:
            ``[hidden_size]`` softmax-weighted sum of the rows. The logits are
            raw (unscaled) dot products with the query.
        """
        # One relevance logit per leftover token: [res_len]. No temperature scaling.
        relevance = torch.einsum("lh,h->l", res_hidden_state, mean_query_embedding)
        token_weights = F.softmax(relevance, dim=-1)
        # Collapse the rows with their weights: [hidden_size].
        return torch.einsum("lh,l->h", res_hidden_state, token_weights)
    
    def generate_query_guided_pooling_memorys(
        self,
        query_ids, # with bos token
        query_mask,
        input_ids,
        input_mask,
        merge_size
    ):
        """Build query-guided pooled memories for a batch of documents.

        ``[query ; document]`` is encoded in a single pass, and the masked mean
        of the query states serves as a relevance probe. Each document's valid
        states are split into full groups of ``merge_size``, which are pooled
        with ``self.weighted_pooling``. A trailing partial group, if any, is
        pooled with ``self.res_weighted_pooling``. Documents with fewer valid
        tokens than ``merge_size`` are supported: they yield a single residual
        memory.

        Returns:
            (aligned_memorys, att_mask): the output of ``memory_fusion_layer``
            over the right-padded bfloat16 memories, and a float
            ``[batch, max_mem_len]`` mask (1 = real memory, 0 = padding).
        """
        device = input_ids.device
        query_len = query_ids.shape[1]

        # The query comes first, so document states are conditioned on it.
        last_hidden_state = self.llm_encoder(
            input_ids=torch.cat((query_ids, input_ids), dim=1),
            attention_mask=torch.cat((query_mask, input_mask), dim=1),
            output_hidden_states=True
        ).hidden_states[-1]

        # mean_query_embeddings.shape: [batch_size, hidden_size]
        mean_query_embeddings = self.get_mean_query_embeddings(
            query_hidden_state=last_hidden_state[:, :query_len, :],
            query_mask=query_mask
        )
        doc_hidden_state = last_hidden_state[:, query_len:, :]

        batch_size = input_ids.shape[0]
        hidden_size = doc_hidden_state.shape[2]
        memorys_list = []
        for i in range(batch_size):
            mean_query_embedding = mean_query_embeddings[i]
            # select_hidden_state: [valid_len, hidden_size]
            select_hidden_state = doc_hidden_state[i][input_mask[i] == 1]
            num_groups, res_len = divmod(select_hidden_state.shape[0], merge_size)
            full_len = num_groups * merge_size

            pieces = []
            if num_groups > 0:
                # Explicit group count: reshape(-1, ...) fails on a 0-row tensor.
                hidden_state_reshaped = select_hidden_state[:full_len].reshape(
                    num_groups, merge_size, hidden_size
                ).to(torch.bfloat16)
                pieces.append(self.weighted_pooling(
                    mean_query_embedding=mean_query_embedding,
                    hidden_state_reshaped=hidden_state_reshaped
                ))
            if res_len > 0:
                res_memory = self.res_weighted_pooling(
                    mean_query_embedding=mean_query_embedding,
                    res_hidden_state=select_hidden_state[full_len:, :]
                )
                pieces.append(res_memory.unsqueeze(0))

            if pieces:
                memorys = torch.cat(pieces, dim=0)
            else:
                # No valid tokens: zero memories; the row is fully padded below.
                memorys = select_hidden_state.new_zeros(0, hidden_size)
            memorys_list.append(memorys.unsqueeze(0))

        # Right-pad every row to the longest memory sequence and mask the padding.
        max_len = max(e.shape[1] for e in memorys_list)
        att_mask = torch.ones(batch_size, max_len, device=device)
        final_memorys_list = []
        for idx, e in enumerate(memorys_list):
            pad_len = max_len - e.shape[1]
            if pad_len > 0:
                e = torch.cat((e, e.new_zeros(1, pad_len, e.shape[2])), dim=1)
                att_mask[idx, -pad_len:] = 0
            final_memorys_list.append(e)

        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)

        aligned_memorys = self.memory_fusion_layer(
            inputs_embeds=final_memorys,
            attention_mask=att_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        return aligned_memorys, att_mask

    # def get_weighted_memorys(
    #     self,
    #     input_ids,
    #     input_mask
    # ):
    #     batch_size = input_ids.shape[0]
    #     output = self.llm_encoder(
    #         input_ids=input_ids,
    #         attention_mask=input_mask,
    #         output_hidden_states=True
    #     )
        
    #     # last_hidden_state.shape: [batch_size, context_tokens, hidden_size]
    #     last_hidden_state = output.hidden_states[-1][:, :-1, :]
    #     # shift_last_hidden_state.shape: [batch_size, -1, merge_size, hidden_size]
    #     shift_last_hidden_state = last_hidden_state.view(batch_size, -1, self.merge_size, self.compressor_hidden_size)
        
    #     logits = output.logits[..., :-1, :].contiguous()
    #     labels = input_ids[..., 1:, :].contiguous()
        
    #     shift_logits = logits.view(-1, logits.size(-1))
    #     shift_labels = labels.view(-1)
        
    #     # loss.shape: [batch_size, context_tokens]
    #     # loss.shape: [batch_size, 512]
    #     loss = self.loss_fct(shift_logits, shift_labels)
    #     group_loss = loss.reshape(batch_size, -1, self.merge_size)
    #     # weighs.shape: [batch_size, -1, merge_size, 1]
    #     weights = torch.softmax(group_loss, dim=-1).unsqueeze(-1)
    #     weighted_hidden_state = shift_last_hidden_state * weights
    #     # weighted_memorys.shape: [batch_size, -1, hidden_size]
    #     weighted_memorys = torch.sum(weighted_hidden_state, dim=-2)
        
    #     return weighted_memorys
    
    def normalize(
        self,
        x
    ):
        """Min-max rescale ``x`` to roughly ``[0, self.scale]`` using its global extrema.

        A 1e-8 epsilon in the denominator makes a constant tensor map to all
        zeros instead of dividing by zero.
        """
        lowest = torch.min(x)
        spread = torch.max(x) - lowest + 1e-8
        return self.scale * ((x - lowest) / spread)
    
    # def min_max_normalize(
    #     self,
    #     data
    # ):
    #     for i in range(data.shape[0]):
    #         for j in range(data.shape[1]):
    #             data[i][j] = self.normalize(data[i][j])
    #             pprint(data[i][j])
                
    #     return data        
    
    def get_weighted_memorys(
        self,
        input_ids,
        input_mask
    ):
        """Pool encoder states into memories weighted by per-token prediction loss.

        Steps:
          1. Right-pad the inputs to a multiple of ``self.merge_size`` with the
             pad token (mask 0).
          2. Prepend BOS and run the encoder.
          3. The hidden state at position ``t`` is paired with the loss of
             predicting token ``t + 1``, and both are grouped ``merge_size`` at
             a time.
          4. Build the weights. If ``self.mean`` is set, each group uses uniform
             weights ``1 / merge_size`` (bfloat16). Otherwise the weights are a
             softmax over the group's loss after ``self.normalize``.
          5. Zero the weights at padded positions and take the weighted sum.
          6. Map the result through ``self.semantic_alignment_layer``.

        Returns:
            (aligned_memorys, memorys_mask): bfloat16 memories
            ``[batch, mem_len, hidden]`` and a mask with ``input_ids.dtype``
            that marks groups containing at least one valid token.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device
        input_ids_dtype = input_ids.dtype

        # Right-pad so the token count is a multiple of merge_size.
        pad_len = (-input_ids.shape[1]) % self.merge_size
        pad_tokens = torch.full((batch_size, pad_len), self.tokenizer.pad_token_id, dtype=input_ids_dtype).to(device)
        pad_mask = torch.zeros(batch_size, pad_len, dtype=input_ids_dtype).to(device)
        input_ids = torch.cat([input_ids, pad_tokens], dim=1)
        input_mask = torch.cat([input_mask, pad_mask], dim=1)

        # Prepend BOS so every real token has a preceding position that predicts it.
        start_tokens = torch.full((batch_size, 1), self.tokenizer.bos_token_id, dtype=input_ids_dtype).to(device)
        start_attention_mask = torch.ones(batch_size, 1, dtype=input_ids_dtype).to(device)
        input_ids = torch.cat([start_tokens, input_ids], dim=1)
        attention_mask = torch.cat([start_attention_mask, input_mask], dim=1)

        output = self.llm_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # Drop the last position: it has no next-token label.
        last_hidden_state = output.hidden_states[-1][:, :-1, :]
        memorys_len = last_hidden_state.shape[1] // self.merge_size  # exact after padding
        hidden_size = last_hidden_state.shape[-1]
        # grouped_hidden_state.shape: [batch_size, memorys_len, merge_size, hidden_size]
        grouped_hidden_state = last_hidden_state.view(batch_size, memorys_len, self.merge_size, hidden_size)

        logits = output.logits[..., :-1, :].contiguous()
        labels = input_ids[:, 1:].contiguous()
        ppl = self.ppl_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        group_ppl = ppl.reshape(batch_size, memorys_len, self.merge_size)
        # Normalize each sample independently.
        normalized_group_ppl = torch.stack([self.normalize(e) for e in group_ppl], dim=0)

        # weights.shape: [batch_size, memorys_len, merge_size]
        if self.mean:
            weights = torch.full(tuple(normalized_group_ppl.shape), 1.0 / self.merge_size).to(torch.bfloat16).to(device)
        else:
            weights = torch.softmax(normalized_group_ppl, dim=-1)

        reshaped_mask = input_mask.reshape(batch_size, memorys_len, self.merge_size)
        # A memory is valid if its group contains at least one real token.
        memorys_mask = (reshaped_mask.sum(dim=-1) != 0).to(input_ids.dtype)
        # Padded positions contribute nothing (weights are not renormalized).
        weights = weights * reshaped_mask

        # weighted_memorys.shape: [batch_size, memorys_len, hidden_size]
        weighted_memorys = torch.sum(grouped_hidden_state * weights.unsqueeze(-1), dim=-2)
        aligned_memorys = self.semantic_alignment_layer(weighted_memorys).to(torch.bfloat16)

        return aligned_memorys, memorys_mask

    def construct_input_embeds(
        self,
        input_ids, 
        input_mask,
        bert_prompt_ids,
        bert_prompt_mask
    ):
        """Compress each segment by appending learned memory tokens and reading back their states.

        Inputs are split via ``self.split_segments``. For every segment, the
        memory-token embeddings are appended after the segment embeddings and
        encoded with ``self.llm_encoder``. The trailing memory-token positions
        of ``last_hidden_state`` are kept. Per-segment memories are
        concatenated along the sequence axis and mapped through
        ``self.semantic_alignment_layer``.

        ``bert_prompt_ids`` is accepted for interface compatibility but is not
        used. ``bert_prompt_mask`` is only used for its device.

        Returns:
            dict with keys ``"final_memorys"``, ``"total_mem_size"`` and
            ``"aligned_memorys_embeds"``. ``total_mem_size`` is
            ``len(segments) * self.mem_size``.
        """
        segments, _, segments_attention_mask = self.split_segments(
            input_ids,
            self.memory_sequence,
            input_mask
        )

        self.memory_sequence = self.memory_sequence.to(segments[0].device)

        batch_size = input_ids.shape[0]
        memorys_embeds = self.memory_token_embed(self.memory_sequence - self.vocab_size)
        memorys_embeds = memorys_embeds.repeat(batch_size, 1, 1)
        memorys_attention_mask = torch.ones((batch_size, \
                                             self.memory_sequence.shape[0])).to(bert_prompt_mask.device)

        total_mem_size = len(segments) * self.mem_size

        # Collect per-segment memories and concatenate once (avoids repeated re-copying).
        segment_memorys = []
        for idx, segment in enumerate(segments):
            segment_embeds = self.llm_encoder.embed_tokens(segment)
            segment_len = segment_embeds.shape[1]
            final_input_embeds = torch.cat((segment_embeds, memorys_embeds), dim=1)
            attention_mask = torch.cat((segments_attention_mask[idx], memorys_attention_mask), dim=1)

            last_hidden_state = self.llm_encoder(
                inputs_embeds=final_input_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True
            ).last_hidden_state
            # The memory tokens occupy the trailing positions. clone() so the full
            # hidden state is not kept alive by a view.
            segment_memorys.append(last_hidden_state[:, segment_len:, :].clone())

            del segment, segment_embeds, final_input_embeds, last_hidden_state
            torch.cuda.empty_cache()

        final_memorys = torch.cat(segment_memorys, dim=1)
        aligned_memorys_embeds = self.semantic_alignment_layer(final_memorys)
        return {
            "final_memorys" : final_memorys,
            "total_mem_size" : total_mem_size,
            "aligned_memorys_embeds" : aligned_memorys_embeds,
        }

    def split_to_segments(
        self,
        enc_doc_ids,
        enc_doc_mask,
        segment_size
    ):
        """Chop token ids and their mask into consecutive column chunks.

        Every chunk except possibly the last has ``segment_size`` columns; the
        last one holds the remainder. Slicing past the end clamps, so the final
        chunk needs no special case. A zero-width input yields two empty lists.

        Returns:
            (segments, segments_masks): parallel lists of ``[batch, <=segment_size]`` slices.
        """
        chunk_count = math.ceil(1.0 * enc_doc_ids.shape[1] / segment_size)
        bounds = [(k * segment_size, (k + 1) * segment_size) for k in range(chunk_count)]
        segments = [enc_doc_ids[:, lo:hi] for lo, hi in bounds]
        segments_masks = [enc_doc_mask[:, lo:hi] for lo, hi in bounds]
        return segments, segments_masks

    def generate_post_append_memorys(
        self,
        input_ids,
        input_mask
    ):
        """Append learned memory tokens after the input and return their encoded states.

        The memory count is chosen as follows:
          - If ``self.is_random`` is set, use ``self.generate_post_append_size()``.
          - Otherwise, if ``self.icae_infer`` is set, use the fixed value for
            ``self.merge_size`` from ``mem_map`` (an unknown ``merge_size``
            raises ``KeyError``).
          - Otherwise, use ``self.mem_size``.

        The count is capped by the length of ``self.memory_sequence``, so the
        returned memories and mask always have the same length.

        Returns:
            (aligned_memorys, memorys_attention_mask): bfloat16 memories
            ``[batch, n_mem, hidden]`` after ``semantic_alignment_layer``, and a
            float all-ones mask ``[batch, n_mem]``.
        """
        device = input_ids.device
        batch_size = input_ids.shape[0]

        # Fixed memory counts used for ICAE-style inference, keyed by merge_size.
        mem_map = {
            4: 128,
            8: 64
        }

        if self.is_random:
            actual_mem_size = self.generate_post_append_size()
        elif self.icae_infer:
            actual_mem_size = mem_map[self.merge_size]
        else:
            actual_mem_size = self.mem_size

        self.memory_sequence = self.memory_sequence.to(device)

        # Embed the memory tokens once and broadcast across the batch.
        memory_ids = self.memory_sequence[:actual_mem_size] - self.vocab_size
        memorys_embeddings = self.memory_token_embed(memory_ids).unsqueeze(0).expand(batch_size, -1, -1)
        # May be smaller than actual_mem_size if memory_sequence is shorter.
        num_memorys = memorys_embeddings.shape[1]
        memorys_attention_mask = torch.ones(batch_size, num_memorys).to(device)

        input_embeddings = self.llm_encoder.get_base_model().model.embed_tokens(input_ids)

        final_input_embeddings = torch.cat((input_embeddings, memorys_embeddings), dim=1)
        attention_mask = torch.cat((input_mask, memorys_attention_mask), dim=1)

        last_hidden_state = self.llm_encoder(
            inputs_embeds=final_input_embeddings,
            attention_mask=attention_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        # Take exactly the appended memory positions (safe even when num_memorys == 0).
        final_memorys = last_hidden_state[:, last_hidden_state.shape[1] - num_memorys:, :]
        aligned_memorys = self.semantic_alignment_layer(final_memorys).to(torch.bfloat16)

        return aligned_memorys, memorys_attention_mask
    
    def forward(
        self,
        enc_doc_ids,
        enc_doc_mask,
        target_doc_ids,
        target_doc_mask,
        llm_ins_ids=None,
        llm_ins_mask=None,
        enc_prefix_ids=None,
        enc_prefix_mask=None,
        enc_repeat_ids=None,
        enc_repeat_mask=None,
        llm_answer_ids=None,
        llm_answer_mask=None,
        llm_que_ids=None,
        llm_que_mask=None,
        enc_continue_ids=None,
        enc_continue_mask=None,
        **kwargs
    ):
        # enc_doc_ids, enc_doc_mask, target_doc_ids, target_doc_mask = self.batch_text_extraction(
        #     enc_doc_ids=enc_doc_ids,
        #     enc_doc_mask=enc_doc_mask,
        # )
        
        if self.split:
            segments, segments_mask = self.split_to_segments(
                enc_doc_ids, 
                enc_doc_mask, 
                self.training_args.segment_size
            )
            segments_merge_size = [self.generate_merge_size() for _ in range(len(segments))]
        else:
            segments = [enc_doc_ids]
            segments_mask = [enc_doc_mask]
        
        if self.post_append:
            aligned_memorys, memorys_masks = zip(*[
                self.generate_post_append_memorys(input_ids=s, input_mask=m)
                for s, m in zip(segments, segments_mask)
            ])     
        else:
            if self.split:
                if self.autoregressive:
                    aligned_memorys = None
                    memorys_mask = None
                    for s, m, m_s in zip(segments, segments_mask, segments_merge_size):
                        aligned_memorys, memorys_mask = self.generate_autoregressive_pooling_memorys(
                            input_ids=s, 
                            input_mask=m,
                            merge_size=m_s,
                            pre_mem_embeds=aligned_memorys,
                            pre_mem_attention_mask=memorys_mask
                        ) 
                else:
                    aligned_memorys, memorys_masks = zip(*[
                        self.generate_pooling_memorys(
                            input_ids=s, 
                            input_mask=m,
                            merge_size=m_s
                        )
                        for s, m, m_s in zip(segments, segments_mask, segments_merge_size)
                    ])
            else:
                if self.keft:
                    aligned_memorys, memorys_masks = zip(*[
                        self.generate_query_guided_pooling_memorys(
                            input_ids=s, 
                            input_mask=m, 
                            query_ids=kwargs['enc_que_ids'],
                            query_mask=kwargs['enc_que_mask'],
                            merge_size=self.merge_size
                        )
                        for s, m in zip(segments, segments_mask)
                    ])
                else:
                    aligned_memorys, memorys_masks = zip(*[
                        self.generate_pooling_memorys(
                            input_ids=s, 
                            input_mask=m, 
                            merge_size=self.merge_size
                        )
                        for s, m in zip(segments, segments_mask)
                    ])

        if not self.autoregressive:
            if len(aligned_memorys) == 1:
                aligned_memorys = aligned_memorys[0]
                memorys_mask = memorys_masks[0]
            else:
                aligned_memorys = torch.cat(aligned_memorys, dim=1)
                memorys_mask = torch.cat(memorys_masks, dim=1)

        pprint(aligned_memorys.shape)  

        # if not self.fine_tune:
        if random.random() < self.restatement_ratio:
            prefix_embeds = self.llm.model.embed_tokens(enc_prefix_ids)
            target_embeds = self.llm.model.embed_tokens(target_doc_ids)
            
            repeat_embeds = self.llm.model.embed_tokens(enc_repeat_ids)
            llm_input_embedings = torch.cat((prefix_embeds, aligned_memorys,\
                repeat_embeds, target_embeds), dim=1)
            pprint(llm_input_embedings.shape)
            llm_attention_mask = torch.cat((enc_prefix_mask, memorys_mask, \
                enc_repeat_mask, target_doc_mask), dim=1)
            
            llm_reconstruct_labels = torch.full_like(llm_attention_mask, -100)
            llm_reconstruct_labels[:, -target_doc_ids.size(1):] = target_doc_ids.masked_fill(
                ~target_doc_mask.bool(), -100,
            )
            
            llm_outputs = self.llm(
                inputs_embeds=llm_input_embedings,
                attention_mask=llm_attention_mask
            )  
            logits = llm_outputs.logits

            effective_logits = logits[:, :-1,:].reshape(-1, logits.size(-1))
            target_ids = llm_reconstruct_labels[:, 1:].reshape(-1).to(torch.long)
            loss = self.loss_fct(effective_logits, target_ids)
        else:
            # device = target_doc_ids.device
            # # target_doc_embeds = self.llm.get_base_model().model.embed_tokens(target_doc_ids)
            # target_doc_embeds = self.llm.model.embed_tokens(target_doc_ids)

            # llm_input_embedings = torch.cat((aligned_memorys, target_doc_embeds), dim=1)
            # llm_attention_mask = torch.cat((memorys_mask, target_doc_mask), dim=1)
            
            # # kwargs['labels'].shape: [batch_size, target_doc_len]
            # # pad_labels.shape: [batch_size, mem_len]
            # pad_labels = torch.full(memorys_mask.shape, -100).to(device)
            # llm_fine_tune_labels = torch.cat((pad_labels, kwargs['labels']), dim=1)
            
            # llm_outputs = self.llm(
            #     inputs_embeds=llm_input_embedings,
            #     attention_mask=llm_attention_mask
            # )  
            # pprint(llm_input_embedings.shape)
            
            # logits = llm_outputs.logits

            # effective_logits = logits[:, :-1,:].reshape(-1, logits.size(-1))
            # target_ids = llm_fine_tune_labels[:, 1:].reshape(-1).to(torch.long)
            # loss = self.loss_fct(effective_logits, target_ids)   
            # pprint(loss)  
            
            # if not self.full:
            # prefix_embeds = self.llm.get_base_model().model.embed_tokens(enc_prefix_ids)
            # lora config
            llm_ins_embeds = self.llm.model.embed_tokens(llm_ins_ids)
            llm_que_embeds = self.llm.model.embed_tokens(llm_que_ids)
            answer_embeds = self.llm.model.embed_tokens(llm_answer_ids)
            # else:
            #     prefix_embeds = self.llm.model.embed_tokens(enc_prefix_ids)
            # llm_ins_embeds = self.llm.model.embed_tokens(llm_ins_ids)
            # llm_que_embeds = self.llm.model.embed_tokens(llm_que_ids)
            # answer_embeds = self.llm.model.embed_tokens(llm_answer_ids)
            
            # llm_input_embedings = torch.cat((prefix_embeds, aligned_memorys, \
            #     llm_ins_embeds, llm_que_embeds, answer_embeds), dim=1)
            # llm_attention_mask = torch.cat((enc_prefix_mask, memorys_mask, \
            #     llm_ins_mask, llm_que_mask, llm_answer_mask), dim=1)

            llm_input_embedings = torch.cat((llm_ins_embeds, aligned_memorys, \
                llm_que_embeds, answer_embeds), dim=1)
            llm_attention_mask = torch.cat((llm_ins_mask, memorys_mask, \
                llm_que_mask, llm_answer_mask), dim=1)
            
            pprint(llm_answer_mask.shape)
            llm_fine_tune_labels = torch.full_like(llm_attention_mask, -100)
            llm_fine_tune_labels[:, -llm_answer_ids.size(1):] = llm_answer_ids.masked_fill(
                ~llm_answer_mask.bool(), -100,
            )
            
            llm_outputs = self.llm(
                inputs_embeds=llm_input_embedings,
                attention_mask=llm_attention_mask
            )  
            pprint(llm_input_embedings.shape)
            
            logits = llm_outputs.logits

            effective_logits = logits[:, :-1,:].reshape(-1, logits.size(-1))
            target_ids = llm_fine_tune_labels[:, 1:].reshape(-1).to(torch.long)
            loss = self.loss_fct(effective_logits, target_ids)   
            # pprint(loss.dtype)
            # exit(0)

        return {"loss" : loss, "logits" : logits}

