import transformers
from transformers import BertModel, BertTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoTokenizer, LlamaModel, AutoConfig
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.distributed as dist
from dataclasses import dataclass, field
from typing import Optional, List
from peft import (
    get_peft_model,
)
from safetensors.torch import load_file
from icecream import ic as pprint
from math import sqrt
import random

import json


@dataclass
class ModelArguments:
    """Checkpoint locations and LoRA hyper-parameters."""

    # Encoder / compressor checkpoint (read as ``model_name_or_path``).
    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "Path to model."})
    # Target LLM checkpoint (read as ``target_model``).
    target_model: Optional[str] = field(default=None, metadata={"help": "Path to target model."})

    lora_r: int = 128
    lora_dropout: float = 0.05
    lora_alpha: int = 32

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Project-specific switches layered on top of ``transformers.TrainingArguments``."""

    is_train: bool = True
    ppl_memory: bool = True
    mrMR: bool = True
    reconstruct: bool = True
    alpha: float = 1.0
    compressor_hidden_size: int = 4096
    target_llm_hidden_size: int = 4096
    mem_size: int = 32
    segment_size: int = 200
    # Each field is declared exactly once; "accuracy" is the effective default.
    benchmark_metric: str = "accuracy"
    # Number of hidden layers kept in the compressor (encoder) config.
    compressor_hidden_layers: int = 2
    merge_size: int = 4
    head_num: int = 8
    scale: float = 0.5
    mean: bool = True
    # Number of hidden layers kept in the memory-fusion LLM config.
    num_mem_fusion_layers: int = 1
    # Apply LoRA to the memory-fusion LLM.
    mem_lora: bool = True

    post_append: bool = False
    is_random: bool = True

    split: bool = False
    autoregressive: bool = True

    early_stopping_patience: int = 1
    fine_tune: bool = False

    lm_ratio: float = 0.5
    leave_len: int = 100

    prefix_type: str = 'rs_prefix'
    full: bool = False

    keft: bool = False
    restatement_ratio: float = 0.5

    icae_infer: bool = False

    use_transform_layer: bool = True

    # Enable gradient checkpointing.
    gradient_enable: bool = True

    # TKDR settings.
    launch_tkdr: bool = True
    key_percentage: float = 0.005
    # Comma-separated candidate merge sizes, e.g. "8" or "4,8,16".
    merge_sizes: str = "8"
    adaptive_pick: bool = True
    tau: float = 0.85

    # Make the TKDR selection / merge coefficients learnable parameters.
    lamda_select: bool = True
    lamda_merge: bool = True

    checkpoint_path: str = ""

    lora_the_encoder: bool = False

    # Ablation-study switches.
    coarse_grained_on: bool = True
    fine_grained_on: bool = True
    redun_coarse: bool = True
    redun_fine: bool = True

@dataclass
class COMIArguments:
    """Compressor / LM settings together with data and inference options."""

    compressor_path: Optional[str] = None
    lm_model_path: Optional[str] = None
    lm_model_name: str = 'longchat'
    num_compressor_layers: int = 4
    num_compressor_encoder_layers: int = 2
    fix_compressor_mlp_parameters: bool = False
    num_attention_heads: int = 32
    attn_doc_topp: float = 0.25
    generation_split_token: Optional[str] = None
    pool_window_size: int = 4
    random_pool_window_size: bool = False
    cand_pool_window_sizes: Optional[List[int]] = None

    label_pad_token_id: int = -100

    # Inference arguments.
    pw_window_sizes: Optional[List[int]] = None
    data_path: Optional[str] = None
    train_data_path: Optional[str] = None
    num_eval_documents: int = 5

    num_gold_documents: int = 1
    use_answer_as_target: bool = False
    # A commented-out alternative default in earlier code was 'summary'.
    instruction_name: str = 'base'
    gold_first_for_kd: bool = False

    min_num_documents: int = 1
    max_num_documents: int = 5
    random_num_documents: bool = False

    max_new_tokens: int = 100
    max_doc_tokens: int = 512

    restatement: bool = True

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for weight in model.parameters():
        weight.requires_grad_(False)
        
def freeze_mlp(model):
    """Turn off gradients for parameters whose qualified name contains ``'mlp'``."""
    mlp_weights = (
        weight for qualified_name, weight in model.named_parameters()
        if 'mlp' in qualified_name
    )
    for weight in mlp_weights:
        weight.requires_grad = False

def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters require gradients.

    The percentage is reported as 0.0 for a model without parameters instead
    of raising ZeroDivisionError.
    """
    trainable_parameters = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_parameters += count
    ratio = 100 * trainable_parameters / all_param if all_param else 0.0
    print(f"trainable params: {trainable_parameters} || all params: {all_param} || trainable: {ratio} %")

def write_txt(name, item):
    """Append ``str(item)`` plus a newline to the text file at ``name``.

    The file is created if it does not exist. It is written as UTF-8 so that
    non-ASCII items do not depend on the platform's default encoding.
    """
    # str() rather than an f-string: f"{x}" may differ from str(x), e.g. for 0-d tensors.
    with open(name, "a", encoding="utf-8") as f:
        f.write(str(item) + "\n")
    
def select_key_values(key_values, need_idx):
    """Index ``need_idx`` along the second-to-last axis of every (key, value) pair.

    Returns a list with one ``[key, value]`` list per input pair.
    """
    selected = []
    for layer_key, layer_value in key_values:
        picked_key = layer_key[..., need_idx, :]
        picked_value = layer_value[..., need_idx, :]
        selected.append([picked_key, picked_value])
    return selected

def get_embeding_from_block(
    text,
    model,
    tokenizer
):
    """Return a mean-pooled embedding vector for ``text``.

    ``text`` is tokenised (truncated to 512 tokens) and run through ``model``
    without gradients. The last hidden states of the first sequence are
    averaged over every token except the first and the last (for BERT-style
    tokenizers these are [CLS] and [SEP]). If that leaves no tokens, all
    tokens are averaged instead.

    Returns:
        A 1-D tensor of length ``hidden_size``. The previous ``[0][0]``
        indexing returned only its first element.
    """
    encoded_input = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    # Follow the model to its device when it exposes one (e.g. a GPU-resident BertModel).
    device = getattr(model, "device", None)
    if device is not None:
        encoded_input = {key: value.to(device) for key, value in encoded_input.items()}
    with torch.no_grad():
        output = model(**encoded_input)

    token_states = output.last_hidden_state[0]  # first sequence: [seq_len, hidden_size]
    inner_states = token_states[1:-1]
    if inner_states.shape[0] == 0:
        # Sequences of <= 2 tokens would otherwise average an empty slice (NaN).
        inner_states = token_states
    return torch.mean(inner_states, dim=0)

def mrMR_loss(
    query_embeds,
    memory_embeds,
):
    """Average "redundancy minus relevance" over all memory embeddings.

    For each sample the query vector q is the mean of its query tokens. For
    every memory vector m_i:

    * relevance is ``q . m_i``;
    * redundancy is the mean of ``m_j . m_i`` over the sample's other memory
      vectors. Vectors whose values equal m_i, m_i included, are excluded.

    The per-memory scores ``redundancy - relevance`` are averaged per sample
    and then over the batch.

    Args:
        query_embeds: [batch_size, query_len, hidden_size]
        memory_embeds: [batch_size, memory_size, hidden_size]

    Returns:
        A scalar tensor. A memory with no distinct companion gets redundancy 0
        instead of triggering a division by zero.
    """
    batch_size = query_embeds.shape[0]
    if batch_size == 0:
        return query_embeds.new_zeros(())

    distilation_loss = 0
    for idx in range(batch_size):
        query_embed = torch.mean(query_embeds[idx], dim=0)  # [hidden_size]
        memories = memory_embeds[idx]  # [memory_size, hidden_size]

        relevance = memories @ query_embed  # [memory_size]
        pair_sims = memories @ memories.T  # [memory_size, memory_size]
        # Element-wise equality mirrors the torch.equal() exclusion of value-identical vectors.
        distinct = ~(memories.unsqueeze(1) == memories.unsqueeze(0)).all(dim=-1)
        counts = distinct.sum(dim=-1).clamp(min=1)
        redundancy = pair_sims.masked_fill(~distinct, 0).sum(dim=-1) / counts

        distilation_loss = distilation_loss + (redundancy - relevance).mean()

    return 1.0 * distilation_loss / batch_size

def weight_to_mem_size(
    weights,
    mem_size
):
    """Convert fractional weights into integer memory-slot counts.

    Every weight but the last becomes ``int(w * mem_size)`` (truncated). The
    final entry takes whatever is left of ``mem_size``, so the counts always
    sum to ``mem_size`` for non-empty ``weights``. The last weight's value is
    not used.
    """
    leading = list(weights)[:-1]
    sizes = [int(share * mem_size) for share in leading]
    if len(weights) > 0:
        sizes.append(mem_size - sum(sizes))
    return sizes

def dynamic_allocator(
    segments,
    query,
    mem_size,
    model,
    tokenizer
):
    """Share ``mem_size`` memory slots among ``segments`` by similarity to ``query``.

    The query and each segment are embedded with ``get_embeding_from_block``.
    A segment's weight is its dot product with the query divided by the sum
    of all such dot products. ``weight_to_mem_size`` converts the weights
    into integer slot counts.

    NOTE(review): there is no guard for a zero or negative similarity sum.
    """
    query_vector = get_embeding_from_block(query, model, tokenizer)
    scores = [
        query_vector @ get_embeding_from_block(segment, model, tokenizer)
        for segment in segments
    ]
    total_score = sum(scores)
    shares = [1.0 * score / total_score for score in scores]
    return weight_to_mem_size(shares, mem_size)

class GroupReallocator:
    """Re-split N context tokens into ``ceil(N / group_size)`` contiguous groups of adaptive size.

    The context is first cut into fixed groups of ``group_size`` tokens. Each
    group is represented by its most query-relevant token (cosine
    similarity). Groups are scored as
    ``alpha_select * relevance - beta_select * redundancy``, where redundancy
    is the representative's maximum cosine similarity to the other
    representatives (0 when ``redun_coarse`` is False). Tokens are then shared
    out in proportion to ``softmax(-score)``, so higher-scoring groups receive
    fewer tokens.

    NOTE(review): the probabilities are detached before use, so tensor-valued
    ``alpha_select``/``beta_select`` receive no gradient through this class.
    """

    def __init__(self, alpha_select, beta_select, redun_coarse=True):
        self.alpha_select = alpha_select
        self.beta_select = beta_select
        self.redun_coarse = redun_coarse

    @staticmethod
    def _distribute_integers(probs, total):
        """Split ``total`` into integers proportional to ``probs`` (largest-remainder method)."""
        quotas = [total * p for p in probs]
        ints = [int(q) for q in quotas]
        remainder = max(total - sum(ints), 0)
        # Stable sort: equal fractional parts keep their group order.
        by_fraction = sorted(
            ((q - int(q), i) for i, q in enumerate(quotas)),
            key=lambda item: item[0],
            reverse=True,
        )
        for _, group_idx in by_fraction[:remainder]:
            ints[group_idx] += 1
        return ints

    @staticmethod
    def _ensure_non_empty(sizes):
        """Move single tokens from the largest groups into empty ones.

        An empty group would otherwise pool to NaN downstream.
        """
        sizes = list(sizes)
        for i, size in enumerate(sizes):
            if size != 0:
                continue
            donor = max(range(len(sizes)), key=sizes.__getitem__)
            if sizes[donor] <= 1:
                break  # fewer tokens than groups: nothing left to move
            sizes[donor] -= 1
            sizes[i] = 1
        return sizes

    def reallocate_group_size(
        self,
        group_size,
        context_embeddings,
        query_embedding
    ):
        """Return new group sizes (Python ints summing to N).

        Args:
            group_size: initial integer group length.
            context_embeddings: [N, D] token embeddings.
            query_embedding: query vector with D elements.

        Returns:
            ``[]`` for an empty context and ``[N]`` when a single group
            suffices. Otherwise, one size per group; with ``group_size >= 1``
            every size is at least 1.
        """
        N, D = context_embeddings.shape
        if N == 0:
            return []

        n_group = math.ceil(N / group_size)
        if n_group == 1:
            return [N]

        device = context_embeddings.device
        context_normalized = F.normalize(context_embeddings, p=2, dim=1)  # [N, D]
        query_normalized = F.normalize(query_embedding.reshape(1, -1), p=2, dim=1).squeeze(0)  # [D]
        relevance = context_normalized @ query_normalized  # [N]

        # Pick every group's most relevant token in one pass. Pad the tail
        # group with -inf so relevance reshapes to [n_group, group_size];
        # argmax returns the first maximum, as a per-slice argmax would.
        pad = n_group * group_size - N
        padded = F.pad(relevance, (0, pad), value=float("-inf"))
        local_max = padded.reshape(n_group, group_size).argmax(dim=1)
        rep_idx = torch.arange(n_group, device=device) * group_size + local_max

        pooled = context_embeddings[rep_idx]  # [n_group, D]
        group_relevance = relevance[rep_idx].detach().float()  # [n_group]

        if self.redun_coarse:
            pooled_normalized = F.normalize(pooled, p=2, dim=1)  # [n_group, D]
            cosine_sim_matrix = pooled_normalized @ pooled_normalized.T  # [n_group, n_group]
            cosine_sim_matrix.fill_diagonal_(-float('inf'))  # ignore self-similarity
            redundancy = cosine_sim_matrix.max(dim=1).values
        else:
            redundancy = torch.zeros_like(group_relevance)

        scores = self.alpha_select * group_relevance - self.beta_select * redundancy  # [n_group]
        probs = F.softmax(-scores, dim=0).detach().float().cpu().numpy().copy()

        sizes = self._distribute_integers(probs, N)
        return self._ensure_non_empty(sizes)


class TKDR:
    """Query-aware compressor for one sequence of context embeddings.

    Two strategies are provided:

    * ``compress`` keeps a few query-relevant key tokens verbatim and pools
      the runs of tokens between them into fixed-size groups.
    * ``compress_by_mrmr`` cuts the whole context into contiguous groups and
      pools each group into one vector. Tokens are optionally weighted by an
      mRMR (relevance minus redundancy) score.

    The ``alpha_*``/``beta_*`` coefficients are used only as multipliers, so
    they may be Python floats or scalar tensors such as ``nn.Parameter``.
    """

    def __init__(
        self,
        alpha_select,
        beta_select,
        alpha_merge,
        beta_merge,
        coarse_grained_on,
        fine_grained_on,
        redun_coarse,
        redun_fine,
        adaptive_pick=False,
        tau=0.85,
    ):
        """Store coefficients and ablation switches.

        Args:
            alpha_select, beta_select: relevance / redundancy weights handed to
                ``GroupReallocator`` when ``coarse_grained_on`` is set.
            alpha_merge, beta_merge: relevance / redundancy weights of the
                in-group mRMR pooling.
            coarse_grained_on: size groups with ``GroupReallocator`` instead of
                splitting uniformly.
            fine_grained_on: pool groups with mRMR weights instead of a mean.
            redun_coarse: include redundancy when re-sizing groups.
            redun_fine: include redundancy in mRMR pooling.
            adaptive_pick: in ``compress``, choose the number of key tokens
                with ``pick_by_threshold`` instead of ``key_percentage``.
            tau: attention-mass threshold used when ``adaptive_pick`` is set.
        """
        self.alpha_select = alpha_select
        self.beta_select = beta_select
        self.alpha_merge = alpha_merge
        self.beta_merge = beta_merge

        self.coarse_grained_on = coarse_grained_on
        self.fine_grained_on = fine_grained_on
        self.redun_coarse = redun_coarse
        self.redun_fine = redun_fine

        self.adaptive_pick = adaptive_pick
        self.tau = tau

    @staticmethod
    def max_cosine_sim_except_self(embs: torch.Tensor) -> torch.Tensor:
        """For each row of ``embs`` [N, D], return its max cosine similarity to any other row (-1 if alone)."""
        embs_norm = F.normalize(embs, p=2, dim=1)
        sim_mat = embs_norm @ embs_norm.T
        sim_mat.fill_diagonal_(-1)
        return sim_mat.max(dim=1).values

    @staticmethod
    def cosine_sim(a, b):
        """Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``."""
        a_norm = F.normalize(a, dim=-1)
        b_norm = F.normalize(b, dim=-1)
        return torch.matmul(a_norm, b_norm.transpose(0, 1))

    @staticmethod
    def partition_by_keyframes(N, key_indices):
        """Return the half-open ``(start, end)`` ranges of ``[0, N)`` lying between key indices.

        Key positions themselves are excluded. Ranges may be empty
        (``start == end``).
        """
        key_indices = sorted(key_indices.tolist())
        boundaries = []
        prev = 0
        for ki in key_indices:
            boundaries.append((prev, ki))
            prev = ki + 1
        boundaries.append((prev, N))
        return boundaries

    @staticmethod
    def make_groups(start_idx: int, end_idx: int, group_size: int):
        """Split the inclusive range ``[start_idx, end_idx]`` into inclusive groups of ``group_size``.

        Leftover tokens are appended to the last group. A range shorter than
        ``group_size`` becomes a single group.

        Raises:
            ValueError: if ``end_idx < start_idx``.
        """
        if end_idx < start_idx:
            raise ValueError("end_idx must be >= start_idx")

        total_len = end_idx - start_idx + 1
        num_full = total_len // group_size
        leftover = total_len % group_size

        groups = []
        cur = start_idx
        for i in range(num_full):
            size = group_size + leftover if i == num_full - 1 else group_size
            groups.append((cur, cur + size - 1))
            cur += size

        if not groups:
            groups.append((start_idx, end_idx))
        return groups

    @staticmethod
    def _pad_groups(context_embeddings, groups):
        """Gather inclusive ``(start, end)`` groups into a zero-padded [M, G_max, D] tensor plus a validity mask."""
        D = context_embeddings.shape[-1]
        device = context_embeddings.device
        max_group_len = max(end - start + 1 for start, end in groups)

        group_embs = torch.zeros(len(groups), max_group_len, D,
                                 device=device, dtype=context_embeddings.dtype)
        group_mask = torch.zeros(len(groups), max_group_len,
                                 device=device, dtype=torch.bool)
        for i, (start, end) in enumerate(groups):
            length = end - start + 1
            group_embs[i, :length] = context_embeddings[start:end + 1]
            group_mask[i, :length] = True
        return group_embs, group_mask

    @staticmethod
    def batched_weighted_pooling(group_embs: torch.Tensor,
                                 group_mask: torch.Tensor,
                                 query_norm: torch.Tensor,
                                 alpha: float = 1e-2):
        """Pool each padded group [M, G_max, D] into one vector [M, D].

        Token weights are ``softmax(alpha * (rel2query - rel2self))``:

        * rel2query is the cosine similarity to ``query_norm``;
        * rel2self is the max cosine similarity to another valid token of the
          same group.

        Padded positions get zero weight. A group without valid tokens pools
        to a zero vector.
        """
        if group_embs.shape[0] == 0:
            return torch.empty((0, group_embs.shape[-1]),
                               device=group_embs.device,
                               dtype=group_embs.dtype)

        G_max = group_embs.shape[1]
        embs_norm = F.normalize(group_embs, p=2, dim=-1)

        rel2query = torch.matmul(embs_norm, query_norm)  # [M, G_max]

        sim_mat = torch.bmm(embs_norm, embs_norm.transpose(1, 2))  # [M, G_max, G_max]
        sim_mat.diagonal(dim1=-2, dim2=-1).fill_(-1)
        mask_2d = group_mask.unsqueeze(1).expand(-1, G_max, -1)
        sim_mat.masked_fill_(~mask_2d, -torch.inf)
        rel2self = sim_mat.max(dim=-1).values  # [M, G_max]

        logits = alpha * (rel2query - rel2self)
        logits.masked_fill_(~group_mask, -torch.inf)
        weights = F.softmax(logits, dim=-1)
        # An all -inf row (empty group) is NaN after softmax; give it zero weight.
        weights = weights.masked_fill(~group_mask.any(dim=-1, keepdim=True), 0.0)

        return (weights.unsqueeze(-1) * group_embs).sum(dim=1)

    def select_key_tokens_fast(
        self,
        context_embeds_norm,
        query_embed_norm,
        num_keyframes=8,
        redundancy_thresh=0.85,
        precomputed_rel2query=None
    ):
        """Greedily pick up to ``num_keyframes`` relevant, mutually non-redundant tokens.

        Tokens are visited in decreasing ``rel2query`` order. Relevance is
        cosine similarity unless ``precomputed_rel2query`` is given. A token
        is kept only if its max dot product with the already-kept tokens is
        below ``redundancy_thresh`` (assumes rows of ``context_embeds_norm``
        are L2-normalised, making this cosine similarity).

        Returns:
            The sorted kept indices, twice, as a tuple of the same long tensor.
        """
        if precomputed_rel2query is not None:
            rel2query = precomputed_rel2query
        else:
            rel2query = (context_embeds_norm @ query_embed_norm.unsqueeze(-1)).squeeze(-1)

        key_token_indices = []
        kept_embeds = []
        for idx in torch.argsort(rel2query, descending=True):
            candidate = context_embeds_norm[idx]
            if kept_embeds:
                max_sim = (candidate @ torch.stack(kept_embeds).T).max()
                if not max_sim < redundancy_thresh:
                    continue
            key_token_indices.append(idx.item())
            kept_embeds.append(candidate)
            if len(key_token_indices) >= num_keyframes:
                break

        key_token_indices_tensor = torch.tensor(sorted(key_token_indices),
                                                dtype=torch.long,
                                                device=context_embeds_norm.device)
        return key_token_indices_tensor, key_token_indices_tensor

    def pick_by_threshold(self, attn_weights, tau=0.99):
        """Return the smallest k whose top-k weights hold at least a ``tau`` fraction of the total.

        Returns ``len(attn_weights)`` when the threshold is never reached
        (e.g. ``tau > 1``).
        """
        sorted_w = torch.sort(attn_weights, descending=True).values
        cumsum = torch.cumsum(sorted_w, dim=0)
        total = cumsum[-1]
        hits = (cumsum / total >= tau).nonzero(as_tuple=True)[0]
        if hits.numel() == 0:
            return int(sorted_w.numel())
        return int(hits[0]) + 1

    def compress(self, context_embeddings, query_embedding, compress_rate,
                 key_percentage, redundancy_thresh=0.85, alpha=1e-2):
        """Compress [N, D] context embeddings to about ``round(N / compress_rate)`` vectors.

        Key tokens are selected with ``select_key_tokens_fast`` and kept
        verbatim. The remaining tokens between keys are cut into groups and
        pooled with ``batched_weighted_pooling``. The output keeps the
        original token order.

        Args:
            context_embeddings: [N, D].
            query_embedding: query vector with D elements (flattened).
            compress_rate: target ratio N / output length.
            key_percentage: fraction of N used as the key budget unless
                ``adaptive_pick`` is enabled.
            redundancy_thresh: cosine threshold for rejecting redundant keys.
            alpha: temperature of the weighted pooling.
        """
        N, D = context_embeddings.shape
        device = context_embeddings.device
        dtype = context_embeddings.dtype
        if N == 0:
            return torch.empty((0, D), device=device, dtype=dtype)

        query_vec = query_embedding.reshape(-1)
        context_norm = F.normalize(context_embeddings, dim=-1)
        query_norm = F.normalize(query_vec, dim=-1)

        # NOTE(review): the raw (un-normalised) dot products both drive the
        # attention weights and order key candidates; select_key_tokens_fast
        # would use cosine relevance if none were supplied. Confirm intent.
        rel2query = context_embeddings @ query_vec  # [N]

        if getattr(self, 'adaptive_pick', False):
            attention_weights = F.softmax(rel2query / math.sqrt(D), dim=-1)
            K = self.pick_by_threshold(attention_weights, tau=getattr(self, 'tau', 0.85))
        else:
            K = int(N * key_percentage)
        K = max(1, min(K, N))

        key_token_indices, _ = self.select_key_tokens_fast(
            context_norm, query_norm, K, redundancy_thresh, rel2query
        )
        key_tokens = context_embeddings[key_token_indices]
        # The redundancy filter may return fewer than K keys; budget with the real count.
        num_keys = key_token_indices.numel()

        total_len = N - num_keys  # tokens left to pool
        compressed_len = int(round(N / compress_rate))
        res_len = compressed_len - num_keys  # pooled vectors still affordable
        if res_len <= 0:
            return key_tokens

        group_size = max(1, total_len // res_len)

        all_groups = []
        for start_idx, end_idx in self.partition_by_keyframes(N, key_token_indices):
            if start_idx >= end_idx:
                continue
            all_groups += self.make_groups(start_idx, end_idx - 1, group_size)

        if all_groups:
            group_embs_padded, group_mask = self._pad_groups(context_embeddings, all_groups)
            pooled_embs = self.batched_weighted_pooling(
                group_embs_padded, group_mask, query_norm, alpha
            )
        else:
            pooled_embs = torch.empty((0, D), device=device, dtype=dtype)

        # Re-interleave pooled groups and key tokens by original position.
        items = [(group[0], pooled_embs[i]) for i, group in enumerate(all_groups)]
        items += [(pos, key_tokens[i]) for i, pos in enumerate(key_token_indices.tolist())]
        items.sort(key=lambda item: item[0])

        if not items:
            return torch.empty((0, D), device=device, dtype=dtype)
        return torch.stack([emb for _, emb in items], dim=0)

    def compress_by_mrmr(self, context_embeddings, query_embedding, compress_rate,
                         group_size=None):
        """Cut [N, D] context embeddings into contiguous groups and pool each to one vector.

        Group sizes come from ``GroupReallocator`` when ``coarse_grained_on``
        is set (initial length ``compress_rate``). Otherwise the context is
        split into ``max(1, round(N / compress_rate))`` near-equal groups.
        Groups are pooled with ``batched_weighted_pooling_mrmr`` when
        ``fine_grained_on`` is set, else by a plain mean. ``group_size`` is
        accepted for compatibility and unused.
        """
        N, D = context_embeddings.shape
        if N == 0:
            return context_embeddings[:0]

        query_norm = F.normalize(query_embedding, dim=-1)
        target_length = max(1, int(N / compress_rate))

        if self.coarse_grained_on:
            reallocator = GroupReallocator(
                self.alpha_select,
                self.beta_select,
                redun_coarse=getattr(self, 'redun_coarse', True),
            )
            reallocated_sizes = reallocator.reallocate_group_size(
                group_size=compress_rate,
                context_embeddings=context_embeddings,
                query_embedding=query_embedding,
            )
        else:
            num_groups = max(1, round(N / compress_rate))
            base_group_size, remainder = divmod(N, num_groups)
            reallocated_sizes = [
                base_group_size + (1 if i < remainder else 0) for i in range(num_groups)
            ]

        all_groups = []
        start = 0
        for size in reallocated_sizes:
            all_groups.append((start, start + size - 1))
            start += size

        if not all_groups:
            return context_embeddings[:target_length]

        group_embs_padded, group_mask = self._pad_groups(context_embeddings, all_groups)

        if self.fine_grained_on:
            return self.batched_weighted_pooling_mrmr(group_embs_padded, group_mask, query_norm)

        # Plain mean over each group's valid tokens (an empty group gives zeros).
        sum_pooled = (group_embs_padded * group_mask.unsqueeze(-1)).sum(dim=1)  # [num_groups, D]
        count_per_group = group_mask.sum(dim=1, keepdim=True).clamp(min=1)  # [num_groups, 1]
        return sum_pooled / count_per_group

    def batched_weighted_pooling_mrmr(
        self,
        group_embs: torch.Tensor,
        group_mask: torch.Tensor,
        query_norm: torch.Tensor,
    ):
        """Pool each padded group [M, G_max, D] into one vector [M, D] with mRMR weights.

        Each token's score is ``alpha_merge * rel2query - beta_merge * avg_redundancy``:

        * rel2query is the cosine similarity to ``query_norm``;
        * avg_redundancy is the summed cosine similarity to the other valid
          tokens of the group, divided by the group's valid-token count. It
          is 0 when ``redun_fine`` is False.

        Weights are the softmax of the scores over valid tokens. A group
        without valid tokens pools to a zero vector.
        """
        if group_embs.shape[0] == 0:
            return torch.empty((0, group_embs.shape[-1]),
                               device=group_embs.device,
                               dtype=group_embs.dtype)

        G_max = group_embs.shape[1]
        embs_norm = F.normalize(group_embs, p=2, dim=-1)

        rel2query = torch.matmul(embs_norm, query_norm)  # [M, G_max]

        if self.redun_fine:
            sim_mat = torch.bmm(embs_norm, embs_norm.transpose(1, 2))  # [M, G_max, G_max]
            sim_mat.diagonal(dim1=-2, dim2=-1).fill_(0)  # drop self-similarity
            mask_2d = group_mask.unsqueeze(1).expand(-1, G_max, -1)
            sim_mat = sim_mat.masked_fill(~mask_2d, 0)  # drop padded columns
            valid_count = group_mask.sum(dim=-1, keepdim=True).clamp(min=1)  # [M, 1]
            avg_redundancy = sim_mat.sum(dim=-1) / valid_count  # [M, G_max]
        else:
            avg_redundancy = torch.zeros_like(rel2query)

        mrmr_scores = self.alpha_merge * rel2query - self.beta_merge * avg_redundancy
        mrmr_scores = mrmr_scores.masked_fill(~group_mask, -torch.inf)
        weights = F.softmax(mrmr_scores, dim=-1)  # [M, G_max]
        # An all -inf row (empty group) is NaN after softmax; give it zero weight.
        weights = weights.masked_fill(~group_mask.any(dim=-1, keepdim=True), 0.0)

        return (weights.unsqueeze(-1) * group_embs).sum(dim=1)  # [M, D]

class MemoryFusion(nn.Module):
    """Multi-head self-attention over a sequence, followed by mean pooling.

    All projections are created in bfloat16 and the input is cast to
    bfloat16. ``forward`` maps [batch_size, seq_len, d_model] to
    [batch_size, d_model].
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Creation order (q, k, v, fct) determines the RNG draws used for initialisation.
        self.q_proj = nn.Linear(d_model, d_model).to(torch.bfloat16)
        self.k_proj = nn.Linear(d_model, d_model).to(torch.bfloat16)
        self.v_proj = nn.Linear(d_model, d_model).to(torch.bfloat16)
        self.fct = nn.Linear(d_model, d_model).to(torch.bfloat16)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape [B, S, d_model] to [B, num_heads, S, d_k]."""
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, X, attention_mask=None):
        """Self-attend over ``X`` and average the projected outputs over the sequence.

        Args:
            X: [batch_size, seq_len, d_model]; cast to bfloat16.
            attention_mask: optional mask broadcast over heads and query
                positions; key positions where it equals 0 are suppressed.
                Note that the final mean still covers every position.
        """
        X = X.to(torch.bfloat16)
        batch_size, seq_len, _ = X.size()

        queries = self._split_heads(self.q_proj(X), batch_size, seq_len)
        keys = self._split_heads(self.k_proj(X), batch_size, seq_len)
        values = self._split_heads(self.v_proj(X), batch_size, seq_len)

        scale = torch.sqrt(torch.tensor(self.d_k, dtype=torch.bfloat16))
        scores = (queries @ keys.transpose(-2, -1)) / scale

        if attention_mask is not None:
            key_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(key_mask == 0, -1e9)

        attended = F.softmax(scores, dim=-1) @ values  # [B, num_heads, S, d_k]
        merged = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.fct(merged).mean(dim=1)

class LLMCompressor(nn.Module):
    def __init__(
            self,
            model_args,
            training_args,
            lora_config,
            max_doc_tokens
        ):
        """Load the encoder (plus fusion / target LLMs depending on flags) and set up TKDR.

        Args:
            model_args: provides ``model_name_or_path`` (encoder checkpoint)
                and ``target_model`` (checkpoint for the fusion layer and the
                target LLM).
            training_args: provides every behavioural switch read below.
            lora_config: PEFT config applied with ``get_peft_model`` to the
                encoder (when ``lora_the_encoder``) and to the fusion layer
                (when ``mem_lora``).
            max_doc_tokens: if larger than the encoder's
                ``max_position_embeddings``, linear RoPE scaling with factor
                ``ceil(max_doc_tokens / max_position_embeddings)`` is set.
        """
        super().__init__()

        self.model_args = model_args
        self.training_args = training_args

        self.model_name = model_args.model_name_or_path
        self.target_llm = model_args.target_model

        # NOTE(review): this overwrites nn.Module.training, which train()/eval() also set.
        self.training = training_args.is_train
        self.ppl_memory = training_args.ppl_memory
        self.post_append = training_args.post_append
        self.mrMR = training_args.mrMR
        self.alpha = training_args.alpha
        self.reconstruct = training_args.reconstruct
        self.merge_size = training_args.merge_size
        self.compressor_hidden_size = training_args.compressor_hidden_size
        self.llm_hidden_size = training_args.target_llm_hidden_size
        self.mem_size = training_args.mem_size
        self.segment_size = training_args.segment_size
        self.mean = training_args.mean
        self.split = training_args.split
        self.autoregressive = training_args.autoregressive
        self.fine_tune = training_args.fine_tune
        self.lm_ratio = training_args.lm_ratio
        self.scale = training_args.scale
        self.leave_len = training_args.leave_len
        self.full = training_args.full
        self.keft = training_args.keft
        self.restatement_ratio = training_args.restatement_ratio
        self.use_transform_layer = training_args.use_transform_layer
        self.lora_the_encoder = training_args.lora_the_encoder
        # Ablation-study switches (forwarded to TKDR).
        self.coarse_grained_on = training_args.coarse_grained_on
        self.fine_grained_on = training_args.fine_grained_on
        self.redun_coarse = training_args.redun_coarse
        self.redun_fine = training_args.redun_fine

        # TKDR coefficients: learnable when lamda_select / lamda_merge are set,
        # otherwise fixed at 1.0 so that TKDR below always receives a value.
        if training_args.lamda_select:
            self.alpha_select = nn.Parameter(torch.tensor(1.0))
            self.beta_select = nn.Parameter(torch.tensor(1.0))
        else:
            self.alpha_select = 1.0
            self.beta_select = 1.0

        if training_args.lamda_merge:
            self.alpha_merge = nn.Parameter(torch.tensor(1.0))
            self.beta_merge = nn.Parameter(torch.tensor(1.0))
        else:
            self.alpha_merge = 1.0
            self.beta_merge = 1.0

        self.compressor_config = AutoConfig.from_pretrained(self.model_name)
        self.fusion_layer_config = AutoConfig.from_pretrained(self.target_llm)

        self.compressor_size = self.compressor_config.hidden_size
        self.target_size = self.fusion_layer_config.hidden_size

        self.compressor_config.num_hidden_layers = training_args.compressor_hidden_layers
        self.gradient_checkpointing_enable = training_args.gradient_enable
        self.merge_sizes = training_args.merge_sizes

        # NOTE: we should scale llama rope when documents exceed the encoder's context.
        orig_ctx_len = getattr(self.compressor_config, "max_position_embeddings")
        if max_doc_tokens > orig_ctx_len:
            scaling_factor = float(math.ceil(max_doc_tokens / orig_ctx_len))
            self.compressor_config.rope_scaling = {"type": "linear", "factor": scaling_factor}

        # TKDR compressor.
        self.launch_tkdr = training_args.launch_tkdr
        if self.launch_tkdr:
            self.key_percentage = training_args.key_percentage
            # NOTE(review): training_args.adaptive_pick / training_args.tau are not forwarded to TKDR here.
            self.tkdr = TKDR(
                self.alpha_select,
                self.beta_select,
                self.alpha_merge,
                self.beta_merge,
                self.coarse_grained_on,
                self.fine_grained_on,
                self.redun_coarse,
                self.redun_fine,
            )

        self.llm_encoder = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            config=self.compressor_config
        )
        self.llm_encoder = self.llm_encoder.to(torch.bfloat16)
        if self.lora_the_encoder:
            self.llm_encoder = get_peft_model(self.llm_encoder, lora_config)

        self.is_random = training_args.is_random
        if self.post_append:
            # NOTE(review): float16 here, while the other added layers use bfloat16.
            self.semantic_alignment_layer = nn.Linear(
                self.compressor_hidden_size, self.llm_hidden_size
            ).to(dtype=torch.float16)

        # Memory-fusion layer: a truncated copy of the target LLM
        # (num_mem_fusion_layers hidden layers).
        if not self.post_append:
            self.fusion_layer_config.num_hidden_layers = training_args.num_mem_fusion_layers
            if self.use_transform_layer:
                if self.compressor_size != self.target_size:
                    self.dimension_alignment_layer = nn.Linear(
                        self.compressor_size, self.target_size
                    ).to(dtype=torch.bfloat16)

                self.memory_fusion_layer = AutoModelForCausalLM.from_pretrained(
                    self.target_llm,
                    config=self.fusion_layer_config
                )
                self.memory_fusion_layer = self.memory_fusion_layer.to(torch.bfloat16)
                if training_args.mem_lora:
                    self.memory_fusion_layer = get_peft_model(self.memory_fusion_layer, lora_config)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.vocab_size = self.llm_encoder.config.vocab_size
        # The encoder tokenizer pads with its unk token.
        self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.vocab_size_with_mem = self.vocab_size + self.mem_size

        # Append mem_size extra token embeddings to the encoder vocabulary.
        self.llm_encoder.resize_token_embeddings(self.vocab_size + self.mem_size)

        if self.post_append:
            self.memory_token_embed = nn.Embedding(self.mem_size, self.compressor_hidden_size, padding_idx=None)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        self.ppl_fct = nn.CrossEntropyLoss(reduction="none")
        self.icae_infer = training_args.icae_infer
        if self.post_append:
            # Ids of the appended memory tokens: [vocab_size, vocab_size + mem_size).
            self.memory_sequence = torch.arange(self.vocab_size, self.vocab_size + self.mem_size)

        if self.training:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(self.target_llm)
            self.llm = AutoModelForCausalLM.from_pretrained(
                self.target_llm,
                torch_dtype=torch.bfloat16
            )

            self.llm_vocab_size = self.llm.config.vocab_size
            self.llm_tokenizer.pad_token_id = self.llm_tokenizer.unk_token_id
            self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token

            self.init()

        if self.gradient_checkpointing_enable:
            self.llm_encoder.use_cache = False
            self.llm_encoder.config.use_cache = False
            self.llm_encoder.enable_input_require_grads()
            self.llm_encoder.gradient_checkpointing_enable()

            if hasattr(self, 'memory_fusion_layer'):
                self.memory_fusion_layer.use_cache = False
                self.memory_fusion_layer.config.use_cache = False
                self.memory_fusion_layer.enable_input_require_grads()
                self.memory_fusion_layer.gradient_checkpointing_enable()

            if hasattr(self, 'llm'):
                self.llm.use_cache = False
                self.llm.config.use_cache = False
                self.llm.enable_input_require_grads()
                self.llm.gradient_checkpointing_enable()


    def generate_merge_size(
        self
    ):
        """Sample one merge size uniformly from ``self.merge_sizes``.

        ``merge_sizes`` is a comma-separated string such as ``"8"`` or
        ``"4,8,16"``; a bare int is also accepted and empty parts are ignored.
        The module-level ``random`` generator is used.

        Raises:
            ValueError: if no integer can be parsed from ``merge_sizes``.
        """
        numbers = [int(part) for part in str(self.merge_sizes).split(",") if part.strip()]
        if not numbers:
            raise ValueError(f"merge_sizes must contain at least one integer, got {self.merge_sizes!r}")
        return random.choice(numbers)
    
    def generate_post_append_size(
        self  
    ):
        """Return a post-append memory size drawn uniformly from {64, 128}."""
        return random.choice((64, 128))
    
    def mrMR_loss2(
        self,
        query_embeds, 
        memory_embeds,
        query_attention_mask
    ):
        """Sum an InfoNCE loss over every memory slot of every sample.

        For each sample, the query is summarised as the mean of its valid
        (mask == 1) token embeddings. Each memory slot is then used in turn as
        the positive for ``self.infoNCE_loss``. All *other* slots of that
        sample serve as the negatives.

        Args:
            query_embeds: assumed ``[batch_size, query_len, hidden_size]``.
            memory_embeds: assumed ``[batch_size, memory_size, hidden_size]``.
            query_attention_mask: ``[batch_size, query_len]``; entries equal to
                1 mark valid query tokens.

        Returns:
            The loss summed (not averaged) over slots and samples; ``0`` for an
            empty batch.

        Raises:
            ValueError: if a sample has no valid query token.
        """
        distilation_loss = 0
        for idx in range(query_embeds.shape[0]):
            valid_query = query_embeds[idx][query_attention_mask[idx] == 1]
            if valid_query.shape[0] == 0:
                raise ValueError(f"sample {idx} has no valid query tokens")
            query_embed = valid_query.mean(dim=0)

            sample_memories = memory_embeds[idx]
            loss = 0
            for slot in range(sample_memories.shape[0]):
                # Negatives are selected by position, not by value. A value
                # comparison would also drop any other slot equal to the
                # positive, and would crash when all slots are identical.
                others = torch.cat((sample_memories[:slot], sample_memories[slot + 1:]), dim=0)
                loss += self.infoNCE_loss(
                    query_embed=query_embed,
                    memory_embed=sample_memories[slot],
                    other_memorys_embed=others
                )
            distilation_loss += loss

        return distilation_loss
    
    def max_relevence_score(
        self,
        query_embeds, 
        memory_embeds,
        query_attention_mask
    ):
        """Batch-averaged ``self.max_relevence`` score.

        For every sample, the query vector is the mean of the query token
        embeddings whose mask entry equals 1. It is scored against that
        sample's memory embeddings, and the per-sample scores are averaged.

        Args:
            query_embeds: assumed ``[batch_size, query_len, hidden_size]``.
            memory_embeds: assumed ``[batch_size, memory_size, hidden_size]``.
            query_attention_mask: ``[batch_size, query_len]``.
        """
        n_samples = query_embeds.shape[0]
        per_sample = []
        for b in range(n_samples):
            kept_tokens = torch.stack(
                [tok for tok, keep in zip(query_embeds[b], query_attention_mask[b]) if keep == 1],
                dim=0,
            )
            per_sample.append(
                self.max_relevence(
                    query_embed=kept_tokens.mean(dim=0),
                    mermory_embeds=memory_embeds[b],
                )
            )
        return 1.0 * sum(per_sample) / n_samples
    
    # def distillation_loss(
    #     logits,
    #     target_logits
    # ):

    def infoNCE_loss(
        self, 
        query_embed, 
        memory_embed, 
        other_memorys_embed, 
        temperature=0.1
    ):
        """InfoNCE loss of one query against one positive and a set of negatives.

        Logits are query/memory dot products divided by ``temperature``. The
        positive occupies column 0 and ``self.loss_fct`` is applied with
        target class 0.

        Args:
            query_embed: assumed ``[hidden_size]``.
            memory_embed: the positive, assumed ``[hidden_size]``.
            other_memorys_embed: negatives, assumed ``[num_negatives, hidden_size]``.
            temperature: divisor applied to all logits.
        """
        anchor = query_embed.unsqueeze(0)                                   # [1, H]
        pos_logit = (anchor @ memory_embed.unsqueeze(0).T) / temperature    # [1, 1]
        neg_logits = (anchor @ other_memorys_embed.T) / temperature         # [1, K]
        logits = torch.cat([pos_logit, neg_logits], dim=1)                  # [1, 1 + K]
        # The positive is always the first column, so the target class is 0.
        target = torch.zeros(logits.size(0), dtype=torch.long, device=query_embed.device)  # [1]
        return self.loss_fct(logits, target)
    
    def vertical_pooling(
        self,
        query_embeds,
        memorys_hidden_state,
        query_attention_masks,
    ):
        """Fuse per-layer memories into one memory set per sample by attending over layers.

        For each sample:
        - The query (mean of its valid tokens) is projected with ``self.q_proj``.
        - The per-layer memories, stacked along a new layer axis, are projected
          with ``self.k_proj`` and ``self.v_proj``.
        - For every memory slot, a softmax across the layer axis weights the
          value vectors, and the weighted values are summed.

        Args:
            query_embeds: assumed ``[batch_size, query_len, hidden_size]``.
            memorys_hidden_state: sequence of per-layer tensors, each assumed
                ``[batch_size, mem_size, hidden_size]``.
            query_attention_masks: ``[batch_size, query_len]``; 1 marks valid tokens.

        Returns:
            Tensor of shape ``[batch_size, mem_size, v_dim]``.

        Raises:
            ValueError: if a sample has no valid query token.
        """
        batch_global_memorys = []
        for idx in range(query_embeds.shape[0]):
            valid_query = query_embeds[idx][query_attention_masks[idx] == 1]
            if valid_query.shape[0] == 0:
                raise ValueError(f"sample {idx} has no valid query tokens")
            q = self.q_proj(valid_query.mean(dim=0))                          # [d_k]
            # [num_layers, mem_size, hidden_size]
            layer_memorys = torch.stack([layer[idx] for layer in memorys_hidden_state], dim=0)
            k = self.k_proj(layer_memorys)
            v = self.v_proj(layer_memorys)
            # Scaled dot-product attention: scale by sqrt of the actual key
            # dimension rather than a hard-coded 768.
            weights_matrix = torch.matmul(k, q) / math.sqrt(q.shape[-1])      # [num_layers, mem_size]
            # Normalise across layers, independently for each memory slot.
            weights_matrix = torch.softmax(weights_matrix, dim=0)
            global_memorys = (weights_matrix.unsqueeze(2) * v).sum(dim=0)     # [mem_size, v_dim]
            batch_global_memorys.append(global_memorys)

        return torch.stack(batch_global_memorys, dim=0)

    def max_relevence(
        self,
        query_embed,
        mermory_embeds
    ):
        """Mean dot-product similarity between a query and a set of memories.

        Args:
            query_embed: a ``[hidden_size]`` or ``[1, hidden_size]`` query.
            mermory_embeds: ``[N, hidden_size]`` memory embeddings.

        Returns:
            The N similarities averaged (a 0-d tensor for the shapes above).
        """
        scores = query_embed @ mermory_embeds.T
        return scores.squeeze(0).mean(dim=0)

    def freeze_others(
        self
    ):
        """Freeze every target-LLM parameter except the attention projections.

        Parameters whose name contains ``q_proj``, ``k_proj``, ``v_proj`` or
        ``o_proj`` keep their current ``requires_grad``. All others get
        ``requires_grad = False``.
        """
        trainable_keys = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
        for param_name, weight in self.llm.named_parameters():
            if any(key in param_name for key in trainable_keys):
                continue
            weight.requires_grad = False

    def init(self):
        """Freeze most of the target LLM and print the trainable-parameter summary.

        Only the attention projections (q/k/v/o) of ``self.llm`` are left
        trainable; see ``freeze_others``. The summary is printed by the
        module-level helper ``print_trainable_parameters``.
        """
        # The target LLM is deliberately not frozen entirely: its attention
        # projections stay trainable.
        self.freeze_others()
        print_trainable_parameters(self)

    def split_segments(
        self,
        context_input_ids,
        memory_sequence,
        input_mask
    ):
        """Split the context into fixed-size segments and the memory ids into as many chunks.

        The context ids and mask are cut every ``self.segment_size`` columns,
        and the final segment holds the remainder. ``memory_sequence`` is
        tiled over the batch (``repeat(batch_size, 1)``) and cut into the same
        number of chunks, each ``ceil(mem_len / num_segments)`` columns wide.
        The last chunk again takes whatever remains, and may be empty.

        Returns:
            ``(segments_input_ids, memory_sequences, segments_input_mask)``:
            three lists with one entry per segment.
        """
        seg_width = self.segment_size
        tiled_memory = memory_sequence.repeat(context_input_ids.shape[0], 1)
        n_segments = math.ceil(context_input_ids.shape[1] * 1.0 / seg_width)
        mem_width = math.ceil(tiled_memory.shape[1] * 1.0 / n_segments)

        def _cut(tensor, width):
            # Fixed-width column slices; the last slice is open-ended.
            return [
                tensor[:, k * width:(k + 1) * width] if k < n_segments - 1 else tensor[:, k * width:]
                for k in range(n_segments)
            ]

        return _cut(context_input_ids, seg_width), _cut(tiled_memory, mem_width), _cut(input_mask, seg_width)

    def compute_num_segments(self, total_length):
        """Return how many ``self.segment_size``-long segments cover ``total_length`` tokens."""
        assert total_length > 0
        return math.ceil(total_length / self.segment_size)

    def get_avg_embeds(self, tensor, attention_mask):
        """Masked mean over dimension 1.

        Args:
            tensor: assumed ``[batch_size, merge_size, hidden_size]``.
            attention_mask: ``[batch_size, merge_size]``, 1 for positions to
                include and 0 for padding.

        Returns:
            ``[batch_size, hidden_size]`` average of the unmasked positions.
            Rows with no unmasked position yield zeros instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1)            # [batch_size, merge_size, 1]
        summed = (tensor * mask).sum(dim=1)            # [batch_size, hidden_size]
        counts = mask.sum(dim=1)                       # [batch_size, 1]
        # A fully-masked row has summed == 0. Dividing it by 1 gives 0 rather
        # than 0 / 0 = NaN; non-zero counts are left untouched.
        counts = counts.masked_fill(counts == 0, 1)
        return summed / counts

    # NOTE: The autoregressive version of tkdr
    def generate_autoregressive_tkdr_memorys(
        self,
        input_ids,
        input_mask,
        merge_size,
        query_ids,
        query_input_mask,
        pre_mem_embeds=None,
        pre_mem_attention_mask=None,
        end_flag=False
    ):
        """Compress one context segment into query-aware TKDR memories (autoregressive step).

        Encoding: the segment and the query are encoded together by
        ``self.llm_encoder``. When memories from earlier segments are given,
        they are prepended (detached) as input embeddings, using
        ``self.llm.model.embed_tokens`` for the token part.

        Compression: for every sample, the valid context states (including
        the prepended memories) are compressed by
        ``self.tkdr.compress_by_mrmr``, conditioned on the mean of the valid
        query states. The results are right-padded with zeros to a common
        length and a matching attention mask is built.

        Fully masked samples: a sample whose context mask is all zero is still
        compressed over all context positions, but its memories are fully
        masked out.

        Transform: if ``self.use_transform_layer`` is set, memories go through
        ``self.dimension_alignment_layer`` when compressor and target sizes
        differ. Only when ``end_flag`` is true are they also passed through
        ``self.memory_fusion_layer``.

        Args:
            input_ids, input_mask: the context segment, ``[batch_size, seg_len]``.
            merge_size: forwarded to ``compress_by_mrmr``.
            query_ids, query_input_mask: the query, ``[batch_size, query_len]``.
            pre_mem_embeds, pre_mem_attention_mask: memories produced for
                previous segments, or ``None`` for the first segment.
            end_flag: whether this is the final segment.

        Returns:
            ``(total_memorys, total_attention_mask)``. Without previous
            memories these are the new memories and their mask. Otherwise they
            are previous and new memories concatenated (both detached) and the
            concatenated masks.

        Raises:
            ValueError: if a sample has no valid query token.
        """
        device = input_ids.device

        all_input_ids = torch.cat((input_ids, query_ids), dim=1)
        all_input_mask = torch.cat((input_mask, query_input_mask), dim=1)

        if pre_mem_embeds is not None:
            input_embeds = self.llm.model.embed_tokens(all_input_ids)
            # Previous memories are a prefix for the encoder but get no gradient here.
            final_input_embeds = torch.cat((pre_mem_embeds.detach(), input_embeds), dim=1)
            final_attention_mask = torch.cat((pre_mem_attention_mask, all_input_mask), dim=1)
            last_hidden_state = self.llm_encoder(
                inputs_embeds=final_input_embeds,
                attention_mask=final_attention_mask,
                output_hidden_states=True
            ).hidden_states[-1]
        else:
            last_hidden_state = self.llm_encoder(
                input_ids=all_input_ids,
                attention_mask=all_input_mask,
                output_hidden_states=True
            ).hidden_states[-1]

        batch_size = input_ids.shape[0]
        query_len = query_ids.shape[1]
        # Positions before this index are [pre-memories +] context; the rest is
        # the query. An explicit index avoids the `[:-0]` pitfall when query_len == 0.
        context_end = last_hidden_state.shape[1] - query_len

        memorys_list = []
        fully_masked = []
        for i in range(batch_size):
            if pre_mem_embeds is not None:
                context_mask = torch.cat((pre_mem_attention_mask[i], input_mask[i]), dim=0)
            else:
                context_mask = input_mask[i]
            context_hidden_state = last_hidden_state[i, :context_end]
            query_hidden_state = last_hidden_state[i, context_end:]

            if torch.all(context_mask == 0):
                # NOTE: padded sample. Compress all positions anyway (extra
                # computation) and mask the resulting memories out below.
                select_context_hidden_state = context_hidden_state
                fully_masked.append(True)
            else:
                select_context_hidden_state = context_hidden_state[context_mask == 1]
                fully_masked.append(False)

            select_query_hidden_state = query_hidden_state[query_input_mask[i] == 1]
            if select_query_hidden_state.shape[0] == 0:
                raise ValueError(f"sample {i} has no valid query tokens")
            query_embeds = torch.mean(select_query_hidden_state, dim=0)   # [hidden_size]

            # memory_embeds: [mem_len, hidden_size]
            memory_embeds = self.tkdr.compress_by_mrmr(
                select_context_hidden_state,
                query_embeds,
                merge_size,
            )
            memorys_list.append(memory_embeds.unsqueeze(0))

        # Right-pad every sample's memories to the longest one and mask the padding.
        max_len = max(e.shape[1] for e in memorys_list)
        att_mask = torch.ones(batch_size, max_len, device=device)
        final_memorys_list = []
        for idx, (e, skip) in enumerate(zip(memorys_list, fully_masked)):
            pad_len = max_len - e.shape[1]
            if pad_len > 0:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2], dtype=e.dtype, device=device)
                e = torch.cat((e, pad_embeds), dim=1)
            if skip:
                att_mask[idx] = 0
            elif pad_len > 0:
                att_mask[idx, -pad_len:] = 0
            final_memorys_list.append(e)

        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)
        if self.use_transform_layer:
            if self.compressor_size != self.target_size:
                final_memorys = self.dimension_alignment_layer(final_memorys)

            if end_flag:
                aligned_memorys = self.memory_fusion_layer(
                    inputs_embeds=final_memorys,
                    attention_mask=att_mask,
                    output_hidden_states=True
                ).hidden_states[-1]
            else:
                aligned_memorys = final_memorys
        else:
            aligned_memorys = final_memorys

        if pre_mem_embeds is None:
            return aligned_memorys, att_mask

        # NOTE(review): the new memories are detached as well, so only the first
        # segment's memories carry gradient. Confirm this is intended.
        total_memorys = torch.cat((pre_mem_embeds.detach(), aligned_memorys.detach()), dim=1)
        total_attention_mask = torch.cat((pre_mem_attention_mask, att_mask), dim=1)
        return total_memorys, total_attention_mask

    # NOTE: we need to implement def generate_tkdr_memorys()
    def generate_tkdr_memorys(
        self,
        input_ids,
        input_mask,
        merge_size,
        query_ids,
        query_input_mask
    ):
        """Compress a context into query-aware TKDR memories in a single pass.

        Encoding: the context and query are encoded together by
        ``self.llm_encoder``.

        Compression: for every sample, the valid context states are compressed
        by ``self.tkdr.compress_by_mrmr``, conditioned on the mean of the valid
        query states. The results are right-padded with zeros to a common
        length and a matching attention mask is built.

        Fully masked samples: a sample whose context mask is all zero is
        compressed over all context positions and then fully masked out. This
        mirrors ``generate_autoregressive_tkdr_memorys``; previously such a
        sample crashed.

        Args:
            input_ids, input_mask: the context, ``[batch_size, context_len]``.
            merge_size: forwarded to ``compress_by_mrmr``; replaced by
                ``self.generate_merge_size()`` when ``self.is_random`` is set.
            query_ids, query_input_mask: the query, ``[batch_size, query_len]``.

        Returns:
            ``(aligned_memorys, att_mask)``. When ``self.use_transform_layer``
            is set, the memories are passed through
            ``self.dimension_alignment_layer`` (if sizes differ) and
            ``self.memory_fusion_layer``.

        Raises:
            ValueError: if a sample has no valid query token.
        """
        device = input_ids.device
        if self.is_random:
            merge_size = self.generate_merge_size()

        all_input_ids = torch.cat((input_ids, query_ids), dim=1)
        all_input_mask = torch.cat((input_mask, query_input_mask), dim=1)

        last_hidden_state = self.llm_encoder(
            input_ids=all_input_ids,
            attention_mask=all_input_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        batch_size = input_ids.shape[0]
        context_len = input_ids.shape[1]
        memorys_list = []
        fully_masked = []
        for i in range(batch_size):
            context_mask = input_mask[i]
            context_hidden_state = last_hidden_state[i, :context_len]
            query_hidden_state = last_hidden_state[i, context_len:]

            if torch.all(context_mask == 0):
                # Padded sample: boolean selection would be empty and the
                # compressor would have nothing to work on. Compress all
                # positions and mask the result out below.
                select_context_hidden_state = context_hidden_state
                fully_masked.append(True)
            else:
                select_context_hidden_state = context_hidden_state[context_mask == 1]
                fully_masked.append(False)

            select_query_hidden_state = query_hidden_state[query_input_mask[i] == 1]
            if select_query_hidden_state.shape[0] == 0:
                raise ValueError(f"sample {i} has no valid query tokens")
            query_embeds = torch.mean(select_query_hidden_state, dim=0)   # [hidden_size]

            # memory_embeds: [mem_len, hidden_size]
            memory_embeds = self.tkdr.compress_by_mrmr(
                select_context_hidden_state,
                query_embeds,
                merge_size,
            )
            memorys_list.append(memory_embeds.unsqueeze(0))

        # Right-pad every sample's memories to the longest one and mask the padding.
        max_len = max(e.shape[1] for e in memorys_list)
        att_mask = torch.ones(batch_size, max_len, device=device)
        final_memorys_list = []
        for idx, (e, skip) in enumerate(zip(memorys_list, fully_masked)):
            pad_len = max_len - e.shape[1]
            if pad_len > 0:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2], dtype=e.dtype, device=device)
                e = torch.cat((e, pad_embeds), dim=1)
            if skip:
                att_mask[idx] = 0
            elif pad_len > 0:
                att_mask[idx, -pad_len:] = 0
            final_memorys_list.append(e)

        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)
        if self.use_transform_layer:
            if self.compressor_size != self.target_size:
                final_memorys = self.dimension_alignment_layer(final_memorys)

            aligned_memorys = self.memory_fusion_layer(
                inputs_embeds=final_memorys,
                attention_mask=att_mask,
                output_hidden_states=True
            ).hidden_states[-1]
        else:
            aligned_memorys = final_memorys

        return aligned_memorys, att_mask

    def generate_pooling_memorys(
        self,
        input_ids,
        input_mask,
        merge_size
    ):
        """Compress a context into memories by mean-pooling consecutive token groups.

        Pooling: the valid hidden states of each sample (from
        ``self.llm_encoder``) are averaged in groups of ``merge_size``. A
        trailing remainder shorter than ``merge_size`` is averaged into one
        extra memory.

        Fully masked samples: a sample whose mask is all zero gets
        ``ceil(seq_len / merge_size)`` zero memories that are fully masked out.

        Output: memories are right-padded with zeros to a common length, with
        a matching attention mask.

        Args:
            input_ids, input_mask: the context, ``[batch_size, seq_len]``.
            merge_size: group size; replaced by ``self.generate_merge_size()``
                when ``self.is_random`` is set.

        Returns:
            ``(aligned_memorys, att_mask)``. When ``self.use_transform_layer``
            is set, the memories are passed through
            ``self.dimension_alignment_layer`` (if sizes differ) and
            ``self.memory_fusion_layer``.
        """
        device = input_ids.device
        if self.is_random:
            merge_size = self.generate_merge_size()

        last_hidden_state = self.llm_encoder(
            input_ids=input_ids,
            attention_mask=input_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        batch_size, seq_len, hidden_size = last_hidden_state.shape
        memorys_list = []
        fully_masked = []
        for i in range(batch_size):
            current_mask = input_mask[i]
            if torch.all(current_mask == 0):
                # Placeholder sized from the sequence length (dim 1), not the
                # batch dimension, so it never widens the padded output for nothing.
                placeholder_len = math.ceil(seq_len / merge_size)
                memorys_list.append(torch.zeros(1, placeholder_len, hidden_size, device=device))
                fully_masked.append(True)
                continue

            select_hidden_state = last_hidden_state[i][current_mask == 1]   # [n_valid, hidden_size]
            res_len = select_hidden_state.shape[0] % merge_size
            if res_len != 0:
                # The leftover tail forms one extra (shorter) group.
                res_memory = select_hidden_state[-res_len:].mean(dim=0, keepdim=True)
                select_hidden_state = select_hidden_state[:-res_len]

            groups = select_hidden_state.reshape(-1, merge_size, hidden_size).to(torch.bfloat16)
            memorys = groups.mean(dim=1)                                      # [n_groups, hidden_size]
            if res_len != 0:
                memorys = torch.cat((memorys, res_memory), dim=0)
            memorys_list.append(memorys.unsqueeze(0))
            fully_masked.append(False)

        # Right-pad every sample's memories to the longest one and mask the padding.
        max_len = max(e.shape[1] for e in memorys_list)
        att_mask = torch.ones(batch_size, max_len, device=device)
        final_memorys_list = []
        for idx, (e, skip) in enumerate(zip(memorys_list, fully_masked)):
            pad_len = max_len - e.shape[1]
            if pad_len > 0:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2], dtype=e.dtype, device=device)
                e = torch.cat((e, pad_embeds), dim=1)
            if skip:
                att_mask[idx] = 0
            elif pad_len > 0:
                att_mask[idx, -pad_len:] = 0
            final_memorys_list.append(e)

        final_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)
        if self.use_transform_layer:
            if self.compressor_size != self.target_size:
                final_memorys = self.dimension_alignment_layer(final_memorys)

            aligned_memorys = self.memory_fusion_layer(
                inputs_embeds=final_memorys,
                attention_mask=att_mask,
                output_hidden_states=True
            ).hidden_states[-1]
        else:
            aligned_memorys = final_memorys

        return aligned_memorys, att_mask

    def generate_autoregressive_pooling_memorys(
        self,
        input_ids,
        input_mask,
        merge_size,
        pre_mem_embeds=None,
        pre_mem_attention_mask=None
    ):
        """Compress one segment into memories by fusing and pooling token groups.

        Encoding: the segment is encoded by ``self.llm_encoder``. When
        previous memories exist, they are prepended as input embeddings and
        only the segment's own output states are kept.

        Grouping and pooling: each sample's valid states are split into groups
        of ``merge_size``, zero-padding the last group. Each group is run
        through ``self.memory_fusion_layer`` with a mask that hides that
        padding, and averaged over its real tokens only.

        Fully masked samples: a sample whose mask is all zero gets
        ``ceil(seq_len / merge_size)`` zero memories that are fully masked out.

        Args:
            input_ids, input_mask: the segment, ``[batch_size, seg_len]``.
            merge_size: number of tokens per memory.
            pre_mem_embeds, pre_mem_attention_mask: memories produced for
                previous segments, or ``None`` for the first segment.

        Returns:
            ``(total_memorys, total_attention_mask)``: previous memories (if
            any) followed by the new ones, and the matching mask.
        """
        device = input_ids.device
        if pre_mem_embeds is not None:
            input_embeds = self.llm.model.embed_tokens(input_ids)
            final_input_embeds = torch.cat((pre_mem_embeds, input_embeds), dim=1)
            final_attention_mask = torch.cat((pre_mem_attention_mask, input_mask), dim=1)
            # Keep only the states of the current segment's tokens.
            last_hidden_state = self.llm_encoder(
                inputs_embeds=final_input_embeds,
                attention_mask=final_attention_mask,
                output_hidden_states=True
            ).hidden_states[-1][:, pre_mem_embeds.shape[1]:, :]
        else:
            last_hidden_state = self.llm_encoder(
                input_ids=input_ids,
                attention_mask=input_mask,
                output_hidden_states=True
            ).hidden_states[-1]

        batch_size, seq_len, hidden_size = last_hidden_state.shape
        memorys_list = []
        fully_masked = []
        for i in range(batch_size):
            current_mask = input_mask[i]
            if torch.all(current_mask == 0):
                # Placeholder sized from the sequence length (dim 1), not the batch dimension.
                placeholder_len = math.ceil(seq_len / merge_size)
                memorys_list.append(torch.zeros(1, placeholder_len, hidden_size, device=device))
                fully_masked.append(True)
                continue

            select_hidden_state = last_hidden_state[i][current_mask == 1]   # [n_valid, hidden_size]
            n_valid = select_hidden_state.shape[0]
            pad_length = (-n_valid) % merge_size
            padding = torch.zeros(pad_length, hidden_size, dtype=select_hidden_state.dtype, device=device)
            # groups: [n_groups, merge_size, hidden_size]
            groups = torch.cat((select_hidden_state, padding), dim=0)
            groups = groups.reshape(-1, merge_size, hidden_size).to(torch.bfloat16)
            # 1 for real tokens, 0 for the zero padding in the last group.
            group_mask = torch.cat((
                torch.ones(n_valid, device=device),
                torch.zeros(pad_length, device=device),
            )).reshape(-1, merge_size)

            fused = self.memory_fusion_layer(
                inputs_embeds=groups,
                attention_mask=group_mask,
                output_hidden_states=True
            ).hidden_states[-1]
            # Masked mean so that padding positions do not dilute the last memory.
            weights = group_mask.unsqueeze(-1).to(fused.dtype)
            memorys = (fused * weights).sum(dim=1) / weights.sum(dim=1)       # [n_groups, hidden_size]
            memorys_list.append(memorys.unsqueeze(0))
            fully_masked.append(False)

        # Right-pad every sample's memories to the longest one and mask the padding.
        max_len = max(e.shape[1] for e in memorys_list)
        att_mask = torch.ones(batch_size, max_len, device=device)
        final_memorys_list = []
        for idx, (e, skip) in enumerate(zip(memorys_list, fully_masked)):
            pad_len = max_len - e.shape[1]
            if pad_len > 0:
                pad_embeds = torch.zeros(1, pad_len, e.shape[2], dtype=e.dtype, device=device)
                e = torch.cat((e, pad_embeds), dim=1)
            if skip:
                att_mask[idx] = 0
            elif pad_len > 0:
                att_mask[idx, -pad_len:] = 0
            final_memorys_list.append(e)

        aligned_memorys = torch.cat(final_memorys_list, dim=0).to(torch.bfloat16)

        if pre_mem_embeds is not None:
            total_memorys = torch.cat((pre_mem_embeds, aligned_memorys), dim=1)
            total_attention_mask = torch.cat((pre_mem_attention_mask, att_mask), dim=1)
        else:
            total_memorys = aligned_memorys
            total_attention_mask = att_mask

        return total_memorys, total_attention_mask

    def generate_memorys_then_transform(
        self,
        input_ids,
        input_mask,
    ):
        """Mean-pool encoder hidden states into memory slots, then fuse them.

        The last hidden layer of ``self.llm_encoder`` is computed for the batch.
        For every sample, the states at positions where ``input_mask == 1`` are
        gathered in order and right-padded with zeros to a multiple of
        ``self.merge_size``. Each group of ``merge_size`` consecutive states is
        averaged (in bfloat16) into one memory vector. A sample whose mask is
        entirely zero gets an all-zero placeholder memory that is fully masked
        out. The per-sample memories are right-padded with zeros to a common
        length and passed through ``self.memory_fusion_layer``.

        Args:
            input_ids: token ids indexed as ``[batch_size, seq_len]``.
            input_mask: mask of the same shape; positions equal to 1 are kept.

        Returns:
            (aligned_memorys, att_mask): the last hidden state produced by
            ``self.memory_fusion_layer`` for the padded memories, and a float
            mask ``[batch_size, max_mem_len]`` holding 1 for real memory slots
            and 0 for padding and for empty samples.
        """
        device = input_ids.device
        merge_size = self.merge_size
        last_hidden_state = self.llm_encoder(
            input_ids=input_ids,
            attention_mask=input_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        batch_size = input_ids.shape[0]
        seq_len = last_hidden_state.shape[1]
        hidden_size = last_hidden_state.shape[2]

        memorys_list = []  # per-sample [mem_len, hidden_size] tensors
        is_empty = []      # True where the sample's mask has no kept tokens
        for i in range(batch_size):
            current_mask = input_mask[i]
            if torch.all(current_mask == 0).item():
                # Placeholder sized from the sequence length (dim 1 of the
                # hidden state, not the batch dim); it is fully masked below.
                mem_len = (seq_len + merge_size - 1) // merge_size
                memorys_list.append(
                    torch.zeros(mem_len, hidden_size, dtype=torch.bfloat16, device=device)
                )
                is_empty.append(True)
                continue

            # Boolean indexing keeps the kept tokens in order without a
            # per-token Python loop (and its per-element device syncs).
            select_hidden_state = last_hidden_state[i][current_mask == 1]
            pad_length = (-select_hidden_state.shape[0]) % merge_size
            if pad_length:
                padding = select_hidden_state.new_zeros(pad_length, hidden_size)
                select_hidden_state = torch.cat((select_hidden_state, padding), dim=0)

            # [mem_len, merge_size, hidden_size] -> average each group.
            grouped = select_hidden_state.reshape(-1, merge_size, hidden_size).to(torch.bfloat16)
            memorys_list.append(grouped.mean(dim=1))
            is_empty.append(False)

        # Right-pad every sample to the longest memory and build the mask from
        # the recorded lengths / emptiness (not from the memory values).
        max_len = max(mem.shape[0] for mem in memorys_list)
        att_mask = torch.zeros(batch_size, max_len, device=device)
        padded_memorys = []
        for idx, (mem, empty) in enumerate(zip(memorys_list, is_empty)):
            padded_memorys.append(F.pad(mem, (0, 0, 0, max_len - mem.shape[0])))
            if not empty:
                att_mask[idx, :mem.shape[0]] = 1

        # final_memorys.shape: [batch_size, max_len, hidden_size]
        final_memorys = torch.stack(padded_memorys, dim=0).to(torch.bfloat16)
        aligned_memorys = self.memory_fusion_layer(
            inputs_embeds=final_memorys,
            output_hidden_states=True
        ).hidden_states[-1]

        return aligned_memorys, att_mask
    
    def get_mean_query_embeddings(
        self,
        query_hidden_state, # [batch_size, actual_q_len, hidden_size]
        query_mask
    ):
        """Average the query hidden states over the positions kept by the mask.

        Args:
            query_hidden_state: ``[batch_size, query_len, hidden_size]``.
            query_mask: ``[batch_size, query_len]``; non-zero entries are
                included in the average (weighted by their value).

        Returns:
            ``[batch_size, hidden_size]`` masked means. A row whose mask is all
            zero yields a zero vector instead of NaN (its valid length is
            clamped to 1), so downstream softmax weights stay finite.
        """
        # 1. Add a trailing axis so the mask broadcasts over hidden_size.
        mask = query_mask.unsqueeze(-1)  # [batch_size, query_len, 1]

        # 2. Zero out the padded query positions.
        masked_hidden_state = query_hidden_state * mask  # [batch_size, query_len, hidden_size]

        # 3. Number of valid positions per sample; clamp to avoid 0 / 0.
        valid_length = mask.sum(dim=1).clamp(min=1)  # [batch_size, 1]

        # 4. Sum over the query axis and divide by the valid length.
        mean_query_embeddings = masked_hidden_state.sum(dim=1) / valid_length  # [batch_size, hidden_size]
        return mean_query_embeddings

    def weighted_pooling(
        self,
        mean_query_embedding, 
        hidden_state_reshaped
    ):
        """Pool each merge group with softmax weights scored against the query.

        Every token state in a group is scored by its (unscaled) dot product
        with ``mean_query_embedding``. The scores are softmax-normalised within
        the group, and the group collapses to the resulting weighted sum.

        Args:
            mean_query_embedding: ``[hidden_size]`` query vector.
            hidden_state_reshaped: ``[mem_len, merge_size, hidden_size]``.

        Returns:
            ``[mem_len, hidden_size]`` pooled memory vectors.
        """
        # Per-token relevance inside each group: [mem_len, merge_size].
        # (A 1/sqrt(hidden_size) temperature was tried and is intentionally off.)
        group_scores = torch.einsum("msh,h->ms", hidden_state_reshaped, mean_query_embedding)
        group_weights = group_scores.softmax(dim=-1)
        # Weighted sum over the merge axis: [mem_len, hidden_size].
        return torch.einsum("msh,ms->mh", hidden_state_reshaped, group_weights)
        
    def res_weighted_pooling(
        self,
        mean_query_embedding, 
        res_hidden_state
    ):
        """Collapse the leftover (non-full-group) token states into one vector.

        This is the single-group analogue of ``weighted_pooling``. Each leftover
        state is scored by its unscaled dot product with the query, the scores
        are softmax-normalised, and the states are summed with those weights.

        Args:
            mean_query_embedding: ``[hidden_size]`` query vector.
            res_hidden_state: ``[res_len, hidden_size]`` trailing token states.

        Returns:
            ``[hidden_size]`` pooled vector.
        """
        # One relevance score per leftover token: [res_len].
        # (Scaling by 1/sqrt(hidden_size) is intentionally disabled.)
        token_scores = torch.einsum("lh,h->l", res_hidden_state, mean_query_embedding)
        token_weights = token_scores.softmax(dim=-1)
        return torch.einsum("lh,l->h", res_hidden_state, token_weights)
    
    def generate_query_guided_pooling_memorys(
        self,
        query_ids, # with bos token
        query_mask,
        input_ids,
        input_mask,
        merge_size
    ):
        """Build memories by query-guided weighted pooling of context states.

        The query and context are encoded together as ``[query; context]``.
        The masked mean of the query's last-layer states
        (``self.get_mean_query_embeddings``) guides the pooling. For every
        sample, the kept context states (``input_mask == 1``) are split into
        full groups of ``merge_size``, each pooled with ``self.weighted_pooling``.
        Any trailing partial group is pooled into one extra slot with
        ``self.res_weighted_pooling``. The per-sample memories are right-padded
        with zeros to a common length and passed through
        ``self.memory_fusion_layer`` with the matching mask.

        A sample whose context mask is entirely zero produces no memory slots,
        so its row of the returned mask is all zeros.

        Args:
            query_ids / query_mask: ``[batch_size, query_len]`` query tokens.
            input_ids / input_mask: ``[batch_size, ctx_len]`` context tokens.
            merge_size: number of context tokens pooled per memory slot.

        Returns:
            (aligned_memorys, att_mask): last hidden state of
            ``self.memory_fusion_layer`` and a float ``[batch_size, max_mem_len]``
            mask (1 = real slot, 0 = padding).
        """
        device = input_ids.device
        query_len = query_ids.shape[1]

        last_hidden_state = self.llm_encoder(
            input_ids=torch.cat((query_ids, input_ids), dim=1),
            attention_mask=torch.cat((query_mask, input_mask), dim=1),
            output_hidden_states=True
        ).hidden_states[-1]

        # mean_query_embeddings.shape: [batch_size, hidden_size]
        mean_query_embeddings = self.get_mean_query_embeddings(
            query_hidden_state=last_hidden_state[:, :query_len, :],
            query_mask=query_mask
        )
        context_hidden_state = last_hidden_state[:, query_len:, :]

        batch_size = input_ids.shape[0]
        hidden_size = context_hidden_state.shape[2]
        memorys_list = []  # per-sample [mem_len, hidden_size]
        for i in range(batch_size):
            mean_query_embedding = mean_query_embeddings[i]
            # Kept context states in order: [seq, hidden_size].
            select_hidden_state = context_hidden_state[i][input_mask[i] == 1]

            res_len = select_hidden_state.shape[0] % merge_size
            res_memory = None
            if res_len != 0:
                # Trailing partial group -> one extra pooled slot.
                res_memory = self.res_weighted_pooling(
                    mean_query_embedding=mean_query_embedding,
                    res_hidden_state=select_hidden_state[-res_len:, :]
                ).unsqueeze(0)
                select_hidden_state = select_hidden_state[:-res_len, :]

            # [mem_len, merge_size, hidden_size]
            hidden_state_reshaped = select_hidden_state.reshape(-1, merge_size, hidden_size).to(torch.bfloat16)
            memorys = self.weighted_pooling(
                # Match the pooled tensor's dtype so the einsum cannot mismatch.
                mean_query_embedding=mean_query_embedding.to(hidden_state_reshaped.dtype),
                hidden_state_reshaped=hidden_state_reshaped
            )
            if res_memory is not None:
                memorys = torch.cat((memorys, res_memory.to(memorys.dtype)), dim=0)
            memorys_list.append(memorys)

        # Right-pad to the longest memory; mask marks the real slots.
        max_len = max(mem.shape[0] for mem in memorys_list)
        att_mask = torch.zeros(batch_size, max_len, device=device)
        padded_memorys = []
        for idx, mem in enumerate(memorys_list):
            att_mask[idx, :mem.shape[0]] = 1
            padded_memorys.append(F.pad(mem, (0, 0, 0, max_len - mem.shape[0])))
        final_memorys = torch.stack(padded_memorys, dim=0).to(torch.bfloat16)

        aligned_memorys = self.memory_fusion_layer(
            inputs_embeds=final_memorys,
            attention_mask=att_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        return aligned_memorys, att_mask

    # def get_weighted_memorys(
    #     self,
    #     input_ids,
    #     input_mask
    # ):
    #     batch_size = input_ids.shape[0]
    #     output = self.llm_encoder(
    #         input_ids=input_ids,
    #         attention_mask=input_mask,
    #         output_hidden_states=True
    #     )
        
    #     # last_hidden_state.shape: [batch_size, context_tokens, hidden_size]
    #     last_hidden_state = output.hidden_states[-1][:, :-1, :]
    #     # shift_last_hidden_state.shape: [batch_size, -1, merge_size, hidden_size]
    #     shift_last_hidden_state = last_hidden_state.view(batch_size, -1, self.merge_size, self.compressor_hidden_size)
        
    #     logits = output.logits[..., :-1, :].contiguous()
    #     labels = input_ids[..., 1:, :].contiguous()
        
    #     shift_logits = logits.view(-1, logits.size(-1))
    #     shift_labels = labels.view(-1)
        
    #     # loss.shape: [batch_size, context_tokens]
    #     # loss.shape: [batch_size, 512]
    #     loss = self.loss_fct(shift_logits, shift_labels)
    #     group_loss = loss.reshape(batch_size, -1, self.merge_size)
    #     # weighs.shape: [batch_size, -1, merge_size, 1]
    #     weights = torch.softmax(group_loss, dim=-1).unsqueeze(-1)
    #     weighted_hidden_state = shift_last_hidden_state * weights
    #     # weighted_memorys.shape: [batch_size, -1, hidden_size]
    #     weighted_memorys = torch.sum(weighted_hidden_state, dim=-2)
        
    #     return weighted_memorys
    
    def normalize(
        self,
        x
    ):
        """Min-max rescale ``x`` into ``[0, self.scale]`` using its global extrema.

        The minimum and maximum are taken over every element of ``x`` (not per
        row). The ``1e-8`` added to the range keeps a constant tensor from
        dividing by zero; such a tensor maps to all zeros.

        Args:
            x: tensor of any shape. ``get_weighted_memorys`` calls this with one
                sample's grouped losses, shaped ``[mermorys_len, merge_size]``.

        Returns:
            Tensor shaped like ``x``.
        """
        lo = torch.min(x)
        hi = torch.max(x)
        span = hi - lo + 1e-8
        return self.scale * ((x - lo) / span)
    
    # def min_max_normalize(
    #     self,
    #     data
    # ):
    #     for i in range(data.shape[0]):
    #         for j in range(data.shape[1]):
    #             data[i][j] = self.normalize(data[i][j])
    #             pprint(data[i][j])
                
    #     return data        
    
    def get_weighted_memorys(
        self,
        input_ids,
        input_mask
    ):
        """Pool encoder states into memories weighted by per-token loss (or evenly).

        Steps:

        1. The input is right-padded to a multiple of ``self.merge_size``; padded
           positions get mask 0.
        2. ``bos_token_id`` is prepended and the sequence is encoded.
        3. The last position's state is dropped, so state ``t`` lines up with
           ``input_mask[:, t]`` and with the loss of predicting the token at
           ``t + 1`` of the BOS-prefixed sequence.
        4. States are grouped into ``merge_size`` chunks.
        5. Each group is collapsed with weights that are either uniform
           (``self.mean``) or a softmax over the min-max normalised per-token
           losses from ``self.ppl_fct``.

        Masked positions get weight 0; the weights are not renormalised after
        masking. The pooled vectors go through ``self.semantic_alignment_layer``.

        Args:
            input_ids: ``[batch_size, seq_len]`` token ids.
            input_mask: ``[batch_size, seq_len]`` mask (1 = real token).

        Returns:
            (aligned_memorys, memorys_mask): bfloat16 memories
            ``[batch_size, mem_len, hidden]`` and a mask ``[batch_size, mem_len]``
            (dtype of ``input_ids``) that is 1 where a group has any real token.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device
        merge_size = self.merge_size

        # 1. Right-pad to a multiple of merge_size. The pad id only has to be
        #    a valid token: padded positions are masked out below.
        pad_len = (-input_ids.shape[1]) % merge_size
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = 0
        pad_tokens = torch.full((batch_size, pad_len), pad_token_id, dtype=input_ids.dtype, device=device)
        pad_mask = torch.zeros(batch_size, pad_len, dtype=input_mask.dtype, device=device)
        input_ids = torch.cat([input_ids, pad_tokens], dim=1)
        input_mask = torch.cat([input_mask, pad_mask], dim=1)

        # 2. Prepend BOS so every real token is predicted by some position.
        start_tokens = torch.full((batch_size, 1), self.tokenizer.bos_token_id, dtype=input_ids.dtype, device=device)
        start_mask = torch.ones(batch_size, 1, dtype=input_mask.dtype, device=device)
        input_ids = torch.cat([start_tokens, input_ids], dim=1)
        attention_mask = torch.cat([start_mask, input_mask], dim=1)

        output = self.llm_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # 3./4. Drop the last state and group into chunks of merge_size using
        # the encoder's actual hidden size.
        last_hidden_state = output.hidden_states[-1][:, :-1, :]
        memorys_len = last_hidden_state.shape[1] // merge_size
        hidden_size = last_hidden_state.shape[-1]
        # [batch_size, memorys_len, merge_size, hidden_size]
        grouped_hidden_state = last_hidden_state.reshape(batch_size, memorys_len, merge_size, hidden_size)

        # 5. Per-token weights: [batch_size, memorys_len, merge_size].
        if self.mean:
            weights = torch.full(
                (batch_size, memorys_len, merge_size), 1.0 / merge_size, device=device
            ).to(torch.bfloat16)
        else:
            # Per-token loss of predicting token t+1 (only needed here).
            logits = output.logits[:, :-1, :]
            labels = input_ids[:, 1:]
            ppl = self.ppl_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            group_ppl = ppl.reshape(batch_size, memorys_len, merge_size)
            normalized_group_ppl = torch.stack([self.normalize(e) for e in group_ppl], dim=0)
            weights = torch.softmax(normalized_group_ppl, dim=-1)

        reshaped_mask = input_mask.reshape(batch_size, memorys_len, merge_size)
        memorys_mask = (reshaped_mask.sum(dim=-1) != 0).float().to(input_ids.dtype)
        weights = weights * reshaped_mask

        # weighted_memorys.shape: [batch_size, memorys_len, hidden_size]
        weighted_memorys = (grouped_hidden_state * weights.unsqueeze(-1)).sum(dim=-2)
        aligned_memorys = self.semantic_alignment_layer(weighted_memorys).to(torch.bfloat16)

        return aligned_memorys, memorys_mask

    def construct_input_embeds(
        self,
        input_ids, 
        input_mask,
        bert_prompt_ids,
        bert_prompt_mask
    ):
        """Encode each segment with appended memory tokens and collect their states.

        The input is split with ``self.split_segments``. For every segment,
        ``[segment_embeds; memory_token_embeds]`` is encoded by
        ``self.llm_encoder``, and the last-hidden-state rows at the
        memory-token positions are kept. The rows from all segments are
        concatenated along the sequence axis and passed through
        ``self.semantic_alignment_layer``.

        Args:
            input_ids / input_mask: document tokens and mask.
            bert_prompt_ids: accepted for interface compatibility; not used.
            bert_prompt_mask: only its device is used (for the memory mask).

        Returns:
            dict with ``final_memorys`` (concatenated memory states),
            ``total_mem_size`` (``len(segments) * self.mem_size``) and
            ``aligned_memorys_embeds``.
        """
        segments, _, segments_attention_mask = self.split_segments(
            input_ids,
            self.memory_sequence,
            input_mask
        )

        self.memory_sequence = self.memory_sequence.to(segments[0].device)
        batch_size = input_ids.shape[0]

        # Memory-token embeddings shared by every segment: [batch, num_mem, hidden].
        memorys_embeds = self.memory_token_embed(self.memory_sequence - self.vocab_size)
        memorys_embeds = memorys_embeds.repeat(batch_size, 1, 1)
        memorys_attention_mask = torch.ones((batch_size, \
                                             self.memory_sequence.shape[0])).to(bert_prompt_mask.device)

        total_mem_size = len(segments) * self.mem_size
        segment_memorys = []
        for idx, segment in enumerate(segments):
            segment_embeds = self.llm_encoder.embed_tokens(segment)
            segment_len = segment_embeds.shape[1]
            final_input_embeds = torch.cat((segment_embeds, memorys_embeds), dim=1)
            attention_mask = torch.cat((segments_attention_mask[idx], memorys_attention_mask), dim=1)

            last_hidden_state = self.llm_encoder(
                inputs_embeds=final_input_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True
            ).last_hidden_state
            # Memory tokens occupy every position after the segment tokens.
            segment_memorys.append(last_hidden_state[:, segment_len:, :])

            del segment, segment_embeds, final_input_embeds, last_hidden_state
            torch.cuda.empty_cache()

        # Single concatenation instead of re-copying the growing tensor per segment.
        final_memorys = torch.cat(segment_memorys, dim=1)
        aligned_memorys_embeds = self.semantic_alignment_layer(final_memorys)
        return {
            "final_memorys" : final_memorys,
            "total_mem_size" : total_mem_size,
            "aligned_memorys_embeds" : aligned_memorys_embeds,
        }

    def split_to_segments(
        self,
        enc_doc_ids,
        enc_doc_mask,
        segment_size
    ):
        """Split ids and mask along the sequence axis into ``segment_size`` chunks.

        Chunk boundaries are computed from ``enc_doc_ids.shape[1]`` and applied
        to both tensors; the last chunk holds the remainder and may be shorter.
        This is a pure slicing helper: it no longer draws a merge size, because
        the draw was never used and only advanced the RNG.

        Args:
            enc_doc_ids: ``[batch_size, seq_len]`` token ids.
            enc_doc_mask: mask aligned with ``enc_doc_ids``.
            segment_size: positive int, tokens per chunk.

        Returns:
            (segments, segments_masks): two lists of ``[batch_size, <=segment_size]``
            slices (empty lists for a zero-length input).
        """
        starts = range(0, enc_doc_ids.shape[1], segment_size)
        segments = [enc_doc_ids[:, s : s + segment_size] for s in starts]
        segments_masks = [enc_doc_mask[:, s : s + segment_size] for s in starts]
        return segments, segments_masks

    def generate_post_append_memorys(
        self,
        input_ids,
        input_mask
    ):
        """Append learned memory tokens after the input and read back their states.

        The memory-slot count is chosen in priority order:

        - ``self.generate_post_append_size()`` when ``self.is_random`` is set;
        - otherwise ``mem_map[self.merge_size]`` when ``self.icae_infer`` is set;
        - otherwise ``self.mem_size``.

        That many tokens are taken from ``self.memory_sequence``, embedded, and
        appended to the input embeddings. The encoder's last hidden layer at
        the appended positions is passed through ``self.semantic_alignment_layer``.

        The states are sliced after the input length, so the returned memories
        always match the mask, even when fewer memory tokens exist than were
        requested.

        Returns:
            (aligned_memorys, memorys_attention_mask): bfloat16
            ``[batch_size, num_mem, hidden]`` and a float ones mask
            ``[batch_size, num_mem]``.
        """
        device = input_ids.device
        batch_size = input_ids.shape[0]

        # Memory-slot counts used when self.icae_infer is set, keyed by merge size.
        mem_map = {
            4: 128,
            8: 64
        }

        if self.is_random:
            actual_mem_size = self.generate_post_append_size()
        elif self.icae_infer:
            actual_mem_size = mem_map[self.merge_size]
        else:
            actual_mem_size = self.mem_size

        self.memory_sequence = self.memory_sequence.to(device)

        # Same memory embeddings for every sample: [batch_size, num_mem, hidden].
        memorys_embeddings = self.memory_token_embed(
            self.memory_sequence[:actual_mem_size] - self.vocab_size
        ).unsqueeze(0).repeat(batch_size, 1, 1)
        memorys_attention_mask = torch.ones(memorys_embeddings.shape[0], memorys_embeddings.shape[1]).to(device)

        input_embeddings = self.llm_encoder.get_base_model().model.embed_tokens(input_ids)
        input_len = input_embeddings.shape[1]

        final_input_embeddings = torch.cat((input_embeddings, memorys_embeddings), dim=1)
        attention_mask = torch.cat((input_mask, memorys_attention_mask), dim=1)

        last_hidden_state = self.llm_encoder(
            inputs_embeds=final_input_embeddings,
            attention_mask=attention_mask,
            output_hidden_states=True
        ).hidden_states[-1]

        # Positions after the input are exactly the appended memory tokens.
        final_memorys = last_hidden_state[:, input_len:, :]
        aligned_memorys = self.semantic_alignment_layer(final_memorys).to(torch.bfloat16)

        return aligned_memorys, memorys_attention_mask
    
    def forward(
        self,
        enc_doc_ids,
        enc_doc_mask,
        target_doc_ids,
        target_doc_mask,
        llm_ins_ids=None,
        llm_ins_mask=None,
        enc_prefix_ids=None,
        enc_prefix_mask=None,
        enc_repeat_ids=None,
        enc_repeat_mask=None,
        llm_answer_ids=None,
        llm_answer_mask=None,
        llm_que_ids=None,
        llm_que_mask=None,
        enc_continue_ids=None,
        enc_continue_mask=None,
        **kwargs
    ):
        if self.split:
            # NOTE: we want all GPUs have the same seq_len to esure proper behaviour in autoregressive
            batch_size, seq_len = enc_doc_ids.shape[:2]
            seq_len_global = torch.tensor(seq_len, device=enc_doc_ids.device)
            device = enc_doc_ids.device
            
            dist.all_reduce(seq_len_global, op=dist.ReduceOp.MAX)

            diff = seq_len_global - seq_len
            doc_mask = torch.zeros(batch_size, diff, dtype=enc_doc_ids.dtype, device=device)
            pad_mask = torch.zeros(batch_size, diff, dtype=enc_doc_mask.dtype, device=device)
            enc_doc_ids = torch.cat([enc_doc_ids, doc_mask], dim=1)
            enc_doc_mask = torch.cat([enc_doc_mask, pad_mask], dim=1)

        # enc_doc_ids, enc_doc_mask, target_doc_ids, target_doc_mask = self.batch_text_extraction(
        #     enc_doc_ids=enc_doc_ids,
        #     enc_doc_mask=enc_doc_mask,
        # )
        
        if self.split:
            segments, segments_mask = self.split_to_segments(
                enc_doc_ids, 
                enc_doc_mask, 
                self.training_args.segment_size
            )
            segments_merge_size = [self.generate_merge_size() for _ in range(len(segments))]
        else:
            segments = [enc_doc_ids]
            segments_mask = [enc_doc_mask]
        
        if self.post_append:
            aligned_memorys, memorys_masks = zip(*[
                self.generate_post_append_memorys(input_ids=s, input_mask=m)
                for s, m in zip(segments, segments_mask)
            ])     
        else:
            # NOTE: In autoregressive settings, we should set:
            # segment_size 
            # split = True 
            # autoregressive = True
            if self.split:
                if self.autoregressive:
                    if self.is_random:
                        autoregressive_merge_size = self.generate_merge_size()
                    aligned_memorys = None
                    memorys_mask = None
                    for idx, (s, m, m_s) in enumerate(zip(segments, segments_mask, segments_merge_size)):
                        aligned_memorys, memorys_mask = self.generate_autoregressive_tkdr_memorys(
                            input_ids=s, 
                            input_mask=m,
                            merge_size=autoregressive_merge_size,
                            query_ids=kwargs['enc_que_ids'],
                            query_input_mask=kwargs['enc_que_mask'],
                            pre_mem_embeds=aligned_memorys,
                            pre_mem_attention_mask=memorys_mask,
                            end_flag=(idx == len(segments) - 1)
                        )   
                else:
                    if self.is_random:
                        autoregressive_merge_size = self.generate_merge_size()
                    if self.launch_tkdr:
                        aligned_memorys, memorys_masks = zip(*[
                            self.generate_tkdr_memorys(
                                input_ids=s, 
                                input_mask=m,
                                merge_size=autoregressive_merge_size,
                                query_ids=kwargs['enc_que_ids'],
                                query_input_mask=kwargs['enc_que_mask']
                            )
                            for s, m, m_s in zip(segments, segments_mask, segments_merge_size)
                        ])
                    else:
                        aligned_memorys, memorys_masks = zip(*[
                            self.generate_pooling_memorys(
                                input_ids=s, 
                                input_mask=m,
                                merge_size=m_s
                            )
                            for s, m, m_s in zip(segments, segments_mask, segments_merge_size)
                        ])
            else:
                if self.keft:
                    aligned_memorys, memorys_masks = zip(*[
                        self.generate_query_guided_pooling_memorys(
                            input_ids=s, 
                            input_mask=m, 
                            query_ids=kwargs['enc_que_ids'],
                            query_mask=kwargs['enc_que_mask'],
                            merge_size=self.merge_size
                        )
                        for s, m in zip(segments, segments_mask)
                    ])
                elif self.launch_tkdr:
                    aligned_memorys, memorys_masks = zip(*[
                        self.generate_tkdr_memorys(
                            input_ids=s, 
                            input_mask=m,
                            merge_size=self.merge_size,
                            query_ids=kwargs['enc_que_ids'],
                            query_input_mask=kwargs['enc_que_mask']
                        )
                        for s, m in zip(segments, segments_mask)
                    ])
                else:
                    aligned_memorys, memorys_masks = zip(*[
                        self.generate_pooling_memorys(
                            input_ids=s, 
                            input_mask=m, 
                            merge_size=self.merge_size
                        )
                        for s, m in zip(segments, segments_mask)
                    ])

        if not self.autoregressive:
            if len(aligned_memorys) == 1:
                aligned_memorys = aligned_memorys[0]
                memorys_mask = memorys_masks[0]
            else:
                aligned_memorys = torch.cat(aligned_memorys, dim=1)
                memorys_mask = torch.cat(memorys_masks, dim=1)

        pprint(aligned_memorys.shape)  

        # if not self.fine_tune:
        if random.random() < self.restatement_ratio:
            prefix_embeds = self.llm.model.embed_tokens(enc_prefix_ids)
            target_embeds = self.llm.model.embed_tokens(target_doc_ids)
            
            repeat_embeds = self.llm.model.embed_tokens(enc_repeat_ids)
            llm_input_embedings = torch.cat((prefix_embeds, aligned_memorys,\
                repeat_embeds, target_embeds), dim=1)
            pprint(llm_input_embedings.shape)
            llm_attention_mask = torch.cat((enc_prefix_mask, memorys_mask, \
                enc_repeat_mask, target_doc_mask), dim=1)
            
            llm_reconstruct_labels = torch.full_like(llm_attention_mask, -100)
            llm_reconstruct_labels[:, -target_doc_ids.size(1):] = target_doc_ids.masked_fill(
                ~target_doc_mask.bool(), -100,
            )
            
            llm_outputs = self.llm(
                inputs_embeds=llm_input_embedings,
                attention_mask=llm_attention_mask
            )  
            logits = llm_outputs.logits

            effective_logits = logits[:, :-1,:].reshape(-1, logits.size(-1))
            target_ids = llm_reconstruct_labels[:, 1:].reshape(-1).to(torch.long)
            loss = self.loss_fct(effective_logits, target_ids)
        else:
            # device = target_doc_ids.device
            # # target_doc_embeds = self.llm.get_base_model().model.embed_tokens(target_doc_ids)
            # target_doc_embeds = self.llm.model.embed_tokens(target_doc_ids)

            # llm_input_embedings = torch.cat((aligned_memorys, target_doc_embeds), dim=1)
            # llm_attention_mask = torch.cat((memorys_mask, target_doc_mask), dim=1)
            
            # # kwargs['labels'].shape: [batch_size, target_doc_len]
            # # pad_labels.shape: [batch_size, mem_len]
            # pad_labels = torch.full(memorys_mask.shape, -100).to(device)
            # llm_fine_tune_labels = torch.cat((pad_labels, kwargs['labels']), dim=1)
            
            # llm_outputs = self.llm(
            #     inputs_embeds=llm_input_embedings,
            #     attention_mask=llm_attention_mask
            # )  
            # pprint(llm_input_embedings.shape)
            
            # logits = llm_outputs.logits

            # effective_logits = logits[:, :-1,:].reshape(-1, logits.size(-1))
            # target_ids = llm_fine_tune_labels[:, 1:].reshape(-1).to(torch.long)
            # loss = self.loss_fct(effective_logits, target_ids)   
            # pprint(loss)  
            
            # if not self.full:
            # prefix_embeds = self.llm.get_base_model().model.embed_tokens(enc_prefix_ids)
            # lora config
            llm_ins_embeds = self.llm.model.embed_tokens(llm_ins_ids)
            llm_que_embeds = self.llm.model.embed_tokens(llm_que_ids)
            answer_embeds = self.llm.model.embed_tokens(llm_answer_ids)
            # else:
            #     prefix_embeds = self.llm.model.embed_tokens(enc_prefix_ids)
            # llm_ins_embeds = self.llm.model.embed_tokens(llm_ins_ids)
            # llm_que_embeds = self.llm.model.embed_tokens(llm_que_ids)
            # answer_embeds = self.llm.model.embed_tokens(llm_answer_ids)
            
            # llm_input_embedings = torch.cat((prefix_embeds, aligned_memorys, \
            #     llm_ins_embeds, llm_que_embeds, answer_embeds), dim=1)
            # llm_attention_mask = torch.cat((enc_prefix_mask, memorys_mask, \
            #     llm_ins_mask, llm_que_mask, llm_answer_mask), dim=1)

            llm_input_embedings = torch.cat((llm_ins_embeds, aligned_memorys, \
                llm_que_embeds, answer_embeds), dim=1)
            llm_attention_mask = torch.cat((llm_ins_mask, memorys_mask, \
                llm_que_mask, llm_answer_mask), dim=1)
            
            pprint(llm_answer_mask.shape)
            llm_fine_tune_labels = torch.full_like(llm_attention_mask, -100)
            llm_fine_tune_labels[:, -llm_answer_ids.size(1):] = llm_answer_ids.masked_fill(
                ~llm_answer_mask.bool(), -100,
            )
            
            llm_outputs = self.llm(
                inputs_embeds=llm_input_embedings,
                attention_mask=llm_attention_mask
            )  
            pprint(llm_input_embedings.shape)
            
            logits = llm_outputs.logits

            effective_logits = logits[:, :-1,:].reshape(-1, logits.size(-1))
            target_ids = llm_fine_tune_labels[:, 1:].reshape(-1).to(torch.long)
            loss = self.loss_fct(effective_logits, target_ids)   
            # pprint(loss.dtype)
            # exit(0)
        pprint(loss)

        return {"loss" : loss, "logits" : logits}

