
import torch
from typing import Optional, Tuple
import torch.nn as nn
import types
from dllm_cache import FeatureCache
def logout_cache_LLADA(
    model: nn.Module, tf_block_module_key_name: str
) -> None:
    """Undo the method patches installed on the transformer blocks.

    ``register_cache_LLADA`` stashes the original callables as
    ``_old_forward`` / ``_old_attention`` on every block and as
    ``_old_forward`` on the block's ``rotary_emb``.  This function puts each
    stashed callable back and removes the stash, so the block ends up as it
    was before registration.  Blocks (or rotary modules) that carry no stash
    are left untouched.

    Args:
        model: Model whose ``named_modules()`` is searched.
        tf_block_module_key_name: Exact ``named_modules()`` name of the
            iterable container holding the transformer blocks.

    Returns:
        None.  If no submodule has the requested name the call is a no-op.
    """
    target_module: Optional[nn.ModuleList] = None
    for name, module in model.named_modules():
        if name == tf_block_module_key_name:
            target_module = module  # type: ignore
            break
    if target_module is None:
        return

    for tf_block in target_module:
        # (owner, patched attribute, attribute holding the original)
        patched = [
            (tf_block, "forward", "_old_forward"),
            (tf_block, "attention", "_old_attention"),
        ]
        rotary_emb = getattr(tf_block, "rotary_emb", None)
        if rotary_emb is not None:
            patched.append((rotary_emb, "forward", "_old_forward"))

        for owner, attr_name, backup_name in patched:
            if not hasattr(owner, backup_name):
                continue  # never patched (or already restored)
            setattr(owner, attr_name, getattr(owner, backup_name))
            # Drop the stash so a later registration saves a fresh original.
            delattr(owner, backup_name)

def register_cache_LLADA(
    model: nn.Module,
    tf_block_module_key_name: str,
    test_flops: bool = False,
) -> None:
    """Patch every transformer block so it runs through the feature cache.

    For each block in the container named ``tf_block_module_key_name``:

    * ``forward`` is replaced by ``cache_hook_feature`` (or by
      ``cache_hook_test_flops`` when ``test_flops`` is true);
    * ``attention`` is replaced by ``_attention``, which accepts q and k/v of
      different sequence lengths plus an optional ``q_index``;
    * ``rotary_emb.forward`` (when the block has a ``rotary_emb``) is replaced
      by ``RoPe_forward`` for the same reason.

    The originals are stashed as ``_old_forward`` / ``_old_attention`` (and
    ``rotary_emb._old_forward``).  An existing stash is never overwritten, so
    registering twice keeps the true originals.

    Args:
        model: Model whose ``named_modules()`` is searched.
        tf_block_module_key_name: Exact ``named_modules()`` name of the
            iterable container holding the transformer blocks.
        test_flops: Select the FLOPs-measurement hook instead of the real one.

    Raises:
        ValueError: If no submodule has the requested name.
    """
    target_module: Optional[nn.ModuleList] = None
    for name, module in model.named_modules():
        if name == tf_block_module_key_name:
            target_module = module  # type: ignore
            break
    if target_module is None:
        raise ValueError(
            f"No submodule named {tf_block_module_key_name!r} found in model; "
            "cannot register cache hooks."
        )

    forward_hook = cache_hook_test_flops if test_flops else cache_hook_feature
    for tf_block in target_module:
        # Save the original only once: on a repeated registration
        # ``tf_block.forward`` is already the hook, not the original.
        if not hasattr(tf_block, "_old_forward"):
            setattr(tf_block, "_old_forward", tf_block.forward)
        tf_block.forward = types.MethodType(forward_hook, tf_block)

        # Only changed to cope with q and k/v having different lengths.
        if not hasattr(tf_block, "_old_attention"):
            setattr(tf_block, "_old_attention", tf_block.attention)
        tf_block.attention = types.MethodType(_attention, tf_block)

        # Only changed to cope with q and k/v having different lengths.
        # NOTE(review): the original author flagged that this patch may not
        # support GQA -- confirm before relying on it.
        # ``_attention`` only touches ``rotary_emb`` when ``config.rope`` is
        # set, so a block without one is skipped rather than failing here.
        rotary_emb = getattr(tf_block, "rotary_emb", None)
        if rotary_emb is not None:
            if not hasattr(rotary_emb, "_old_forward"):
                setattr(rotary_emb, "_old_forward", rotary_emb.forward)
            rotary_emb.forward = types.MethodType(RoPe_forward, rotary_emb)


def _attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        q_index: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        
        B, q_len, C = q.size()  
        B, k_len, C = k.size()
        B, v_len, C = v.size()
        dtype = k.dtype 
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)
        q = q.view(B, q_len, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        k = k.view(B, k_len, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        v = v.view(B, v_len, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None
        if self.config.rope:
            # Apply rotary embeddings.
            q, k = self.rotary_emb(q, k,q_index=q_index)
        if attention_bias is not None: 
            attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=False,
        )
        att = att.transpose(1, 2).contiguous().view(B,  q_len, C)
        return self.attn_out(att), present

def RoPe_forward(self, q: torch.Tensor, k: torch.Tensor, q_index: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to q and k, with optional q positions.

    Bound onto a rotary-embedding module as its ``forward``.  Keys always get
    positions ``0 .. key_len - 1``.  Queries get the last ``query_len`` of
    those positions, unless ``q_index`` is given, in which case row ``i`` of
    the batch is rotated with the positions listed in ``q_index[i]``.

    Args:
        q: Query tensor; its second-to-last axis is the query length.
        k: Key tensor; its second-to-last axis is the key length.
        q_index: Optional 2-D integer tensor, one row of positions per batch
            row of ``q``.

    Returns:
        ``(q_rotated, k_rotated)`` cast back to the dtypes of ``q`` and ``k``.
    """
    use_fp32 = self.config.rope_full_precision
    q_work = q.float() if use_fp32 else q
    k_work = k.float() if use_fp32 else k

    # Rotation is done outside autocast so the chosen precision is respected.
    with torch.autocast(q.device.type, enabled=False):
        # The two lengths may differ (layer_past, or q being a token subset).
        n_query, n_key = q_work.shape[-2], k_work.shape[-2]
        sin_table, cos_table = self.get_rotary_embedding(n_key, q_work.device)
        sin_table = sin_table.type_as(q_work)
        cos_table = cos_table.type_as(q_work)

        if q_index is None:
            # Queries occupy the trailing positions of the key range.
            tail = slice(n_key - n_query, n_key)
            q_rot = self.apply_rotary_pos_emb(
                sin_table[:, :, tail, :],
                cos_table[:, :, tail, :],
                q_work,
            )
        else:
            # One rotation per batch row, using that row's own positions.
            n_rows, _ = q_index.shape
            rotated_rows = [
                self.apply_rotary_pos_emb(
                    sin_table[:, :, q_index[row], :],
                    cos_table[:, :, q_index[row], :],
                    q_work[row].unsqueeze(0),
                )
                for row in range(n_rows)
            ]
            q_rot = torch.cat(rotated_rows, dim=0)

        k_rot = self.apply_rotary_pos_emb(sin_table, cos_table, k_work)

    return q_rot.type_as(q), k_rot.type_as(k)


# V-verify: rank tokens by how far their fresh features drift from the cached ones.
def refresh_index(
    new_features: torch.Tensor,
    cached_features: torch.Tensor = None,
    transfer_ratio: float = 0.5,
    layer_id:int = 0
) -> torch.Tensor:
    """Pick the tokens whose features differ most from the cached ones.

    Args:
        new_features: Freshly computed features, ``(batch, gen_len, d_model)``.
        cached_features: Cached features compared against ``new_features``.
        transfer_ratio: Fraction of ``gen_len`` to select; the count is
            ``int(gen_len * transfer_ratio)``.
        layer_id: Not used by this function.

    Returns:
        Integer tensor ``(batch, num_selected)`` holding, per batch row, the
        positions with the lowest cosine similarity.
    """
    _, gen_len, _ = new_features.shape
    # Number of tokens to replace.
    n_selected = int(gen_len * transfer_ratio)
    similarity = torch.nn.functional.cosine_similarity(
        new_features, cached_features, dim=-1
    )
    # Lowest similarity first: these are the tokens that changed the most.
    least_similar = similarity.topk(n_selected, largest=False)
    return least_similar.indices



def cache_hook_feature(
    self,
    x: torch.Tensor,
    attention_bias: Optional[torch.Tensor] = None,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    """Block ``forward`` replacement that reuses cached attention/MLP features.

    The sequence is split at ``FeatureCache().prompt_length`` into a prompt
    part and a generated part.  For this layer the cache decides, via
    ``refresh_prompt`` / ``refresh_gen``, which part is recomputed; the other
    part is read back from the cache.  That gives four cases, handled the
    same way for the attention sub-layer and the feed-forward sub-layer:

    * both refreshed: full recompute, every cache entry rewritten;
    * only generated refreshed: prompt K/V and outputs come from the cache;
    * only prompt refreshed: generated K/V and outputs come from the cache;
    * neither refreshed: everything comes from the cache.

    When the generated part is *not* refreshed and ``0 < transfer_ratio <= 1``,
    a subset of generated tokens is still updated: V is recomputed for all
    generated tokens, ``refresh_index`` selects the tokens whose V changed
    most, and only those tokens get new K, attention and MLP outputs, which
    are scattered into the cached tensors **in place**.

    Args:
        x: Block input, ``(batch, seq_len, dim)``.
        attention_bias: Forwarded to ``self.attention``.
        layer_past: Forwarded to ``self.attention``.
        use_cache: Forwarded to ``self.attention``; the KV pair it returns is
            discarded.

    Returns:
        ``(hidden_states, None)``; the second element is always ``None``.
    """
    # NOTE(review): FeatureCache() is built on every call yet is expected to
    # hold state across calls -- presumably a singleton; confirm in dllm_cache.
    feature_cache = FeatureCache()
    feature_cache.update_step(self.layer_id)
    prompt_length = feature_cache.prompt_length
    x_prompt = x[:, :prompt_length, :]  # prompt part
    x_gen = x[:, prompt_length:, :]      # generated part
    refresh_gen = feature_cache.refresh_gen(layer_id=self.layer_id)
    refresh_prompt = feature_cache.refresh_prompt(layer_id=self.layer_id)
    transfer_ratio = feature_cache.transfer_ratio
    bs, seq_len, dim = x.shape
    feature_cache.expect_length = seq_len - prompt_length
    # Partial refresh of generated tokens is active only for ratios in (0, 1].
    transfer = transfer_ratio > 0 and transfer_ratio <= 1

    def attention(q, k, v, q_index: torch.Tensor = None):
        """Run self.attention (checkpointed if configured); drop its KV output."""
        if self._activation_checkpoint_fn is not None:
            att, _ = self._activation_checkpoint_fn(self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, q_index=q_index)
        else:
            att, _ = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, q_index=q_index)
        return att

    def compute_mlp(input_x):
        """Feed-forward sub-layer without the residual: norm, gated act, out."""
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.ff_norm, input_x)
        else:
            x = self.ff_norm(input_x)
        x, x_up = self.ff_proj(x), self.up_proj(x)
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.act, x)
        else:
            x = self.act(x)
        x = x * x_up
        return self.ff_out(x)

    def project(x):
        """Normalise the input and return its q, k, v projections."""
        x_normed = self.attn_norm(x)
        q = self.q_proj(x_normed)
        k = self.k_proj(x_normed)
        v = self.v_proj(x_normed)
        return q, k, v

    # ---- Attention sub-layer ----
    if refresh_gen and refresh_prompt:
        # Full recompute; rewrite every cache entry of this layer.
        q, k, v = project(x)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k[:, :prompt_length, :], "v": v[:, :prompt_length, :]}, cache_type="prompt")
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k[:, prompt_length:, :], "v": v[:, prompt_length:, :]}, cache_type="gen")
        att = attention(q, k, v)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att[:, :prompt_length, :], cache_type="prompt")
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att[:, prompt_length:, :], cache_type="gen")

    elif refresh_gen and not refresh_prompt:
        # Generated queries attend over cached prompt K/V plus fresh gen K/V.
        q, k_gen, v_gen = project(x_gen)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k_gen, "v": v_gen}, cache_type="gen")
        kv_cache_prompt = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="prompt")
        k = torch.cat([kv_cache_prompt["k"], k_gen], dim=1)
        v = torch.cat([kv_cache_prompt["v"], v_gen], dim=1)
        att_gen = attention(q, k, v)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_gen, cache_type="gen")
        att_prompt_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="prompt")
        att = torch.cat([att_prompt_cache, att_gen], dim=1)

    elif not refresh_gen and refresh_prompt:
        q_prompt, k_prompt, v_prompt = project(x_prompt)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k_prompt, "v": v_prompt}, cache_type="prompt")
        kv_cache_gen = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="gen")
        att_gen_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="gen")
        if transfer:
            x_gen_normed = self.attn_norm(x_gen)
            v_gen = self.v_proj(x_gen_normed)
            # Positions (within the generated part) whose V drifted the most.
            index = refresh_index(v_gen, kv_cache_gen["v"], transfer_ratio, self.layer_id)
            index_expanded = index.unsqueeze(-1).expand(-1, -1, dim)  # [batch_size, num_replace, d_model]
            x_gen_selected = torch.gather(x_gen_normed, dim=1, index=index_expanded)
            q_gen_index = self.q_proj(x_gen_selected)
            k_gen_index = self.k_proj(x_gen_selected)
            kv_cache_gen["v"] = v_gen
            # Scatter index sized to K's own width, which can differ from d_model.
            k_index_expanded = index.unsqueeze(-1).expand(-1, -1, k_gen_index.shape[-1])
            kv_cache_gen["k"].scatter_(dim=1, index=k_index_expanded, src=k_gen_index)
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": kv_cache_gen["k"], "v": kv_cache_gen["v"]}, cache_type="gen")
        k = torch.cat([k_prompt, kv_cache_gen["k"]], dim=1)
        v = torch.cat([v_prompt, kv_cache_gen["v"]], dim=1)
        # Absolute positions of the prompt queries, built on the query's
        # device so no host-side index tensor reaches the rotary lookup.
        prompt_index = torch.arange(prompt_length, device=q_prompt.device).unsqueeze(0).expand(bs, -1)
        if transfer:
            q_prompt_gen_index = torch.cat([q_prompt, q_gen_index], dim=1)
            gen_index = index + prompt_length  # shift to absolute positions
            att_prompt_gen_index = attention(q_prompt_gen_index, k, v, q_index=torch.cat([prompt_index, gen_index], dim=1))
            att_prompt = att_prompt_gen_index[:, :prompt_length, :]
            att_gen_index = att_prompt_gen_index[:, prompt_length:, :]
            att_gen_cache.scatter_(dim=1, index=index_expanded, src=att_gen_index)
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_gen_cache, cache_type="gen")
        else:
            att_prompt = attention(q_prompt, k, v, q_index=prompt_index)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_prompt, cache_type="prompt")
        att = torch.cat([att_prompt, att_gen_cache], dim=1)
    else:
        att_gen_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="gen")
        if transfer:
            x_gen_normed = self.attn_norm(x_gen)
            v_gen = self.v_proj(x_gen_normed)
            kv_cache_gen = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="gen")
            kv_cache_prompt = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="prompt")
            index = refresh_index(v_gen, kv_cache_gen["v"], transfer_ratio, self.layer_id)
            index_expanded = index.unsqueeze(-1).expand(-1, -1, dim)  # [batch_size, num_replace, d_model]
            x_gen_selected = torch.gather(x_gen_normed, dim=1, index=index_expanded)
            # Replace part of the K/V and write it back to the cache.
            q_gen_index = self.q_proj(x_gen_selected)
            k_gen_index = self.k_proj(x_gen_selected)
            kv_cache_gen["v"] = v_gen
            # Scatter index sized to K's own width, which can differ from d_model.
            k_index_expanded = index.unsqueeze(-1).expand(-1, -1, k_gen_index.shape[-1])
            kv_cache_gen["k"].scatter_(dim=1, index=k_index_expanded, src=k_gen_index)
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": kv_cache_gen["k"], "v": kv_cache_gen["v"]}, cache_type="gen")
            k = torch.cat([kv_cache_prompt["k"], kv_cache_gen["k"]], dim=1)
            v = torch.cat([kv_cache_prompt["v"], kv_cache_gen["v"]], dim=1)
            # Attention for the selected tokens only.
            att_gen_index = attention(q_gen_index, k, v, q_index=index + prompt_length)
            # Replace part of the cached gen attention and write it back.
            att_gen_cache.scatter_(dim=1, index=index_expanded, src=att_gen_index)
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_gen_cache, cache_type="gen")

        att_prompt_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="prompt")
        att = torch.cat([att_prompt_cache, att_gen_cache], dim=1)

    x = x + self.dropout(att)

    # ---- Feed-forward sub-layer ----
    og_x = x
    x_prompt = x[:, :prompt_length, :]
    x_gen = x[:, prompt_length:, :]

    if refresh_gen and refresh_prompt:
        x = compute_mlp(x)
        feature_cache.set_cache(self.layer_id, "mlp", x[:, prompt_length:, :], cache_type="gen")
        feature_cache.set_cache(self.layer_id, "mlp", x[:, :prompt_length, :], cache_type="prompt")

    elif refresh_gen and not refresh_prompt:
        x_gen = compute_mlp(x_gen)
        feature_cache.set_cache(self.layer_id, "mlp", x_gen, cache_type="gen")
        x_prompt_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="prompt")
        x = torch.cat([x_prompt_cache, x_gen], dim=1)

    elif refresh_prompt and not refresh_gen:
        x_gen_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="gen")
        if transfer:
            # Reuses the token selection made in the attention sub-layer.
            x_gen_selected = torch.gather(x_gen, dim=1, index=index_expanded)
            x_prompt_gen_index = torch.cat([x_prompt, x_gen_selected], dim=1)
            x_prompt_gen_index = compute_mlp(x_prompt_gen_index)
            x_prompt = x_prompt_gen_index[:, :prompt_length, :]
            x_gen_index = x_prompt_gen_index[:, prompt_length:, :]
            x_gen_cache.scatter_(dim=1, index=index_expanded, src=x_gen_index)
            feature_cache.set_cache(self.layer_id, "mlp", x_gen_cache, cache_type="gen")
        else:
            x_prompt = compute_mlp(x_prompt)
        feature_cache.set_cache(self.layer_id, "mlp", x_prompt, cache_type="prompt")
        x = torch.cat([x_prompt, x_gen_cache], dim=1)

    else:
        x_gen_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="gen")
        if transfer:
            # Reuses the token selection made in the attention sub-layer.
            x_gen_selected = torch.gather(x_gen, dim=1, index=index_expanded)
            x_gen_index = compute_mlp(x_gen_selected)
            x_gen_cache.scatter_(dim=1, index=index_expanded, src=x_gen_index)
            feature_cache.set_cache(self.layer_id, "mlp", x_gen_cache, cache_type="gen")
        x_prompt_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="prompt")
        x = torch.cat([x_prompt_cache, x_gen_cache], dim=1)

    x = self.dropout(x)
    x = og_x + x

    return x, None





# NOTE: for reasons not yet understood, using scatter_ makes GPU memory usage spike; this is a bug.
def cache_hook_test_flops(
    self,
    x: torch.Tensor,
    attention_bias: Optional[torch.Tensor] = None,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    """FLOPs-measurement variant of the cached block ``forward``.

    Runs the same projections, attention and MLP calls as
    ``cache_hook_feature`` for each of the four refresh cases (prompt and/or
    generated part refreshed, decided by ``FeatureCache``), but every in-place
    ``scatter_`` write-back into a cached tensor is skipped.  The partially
    recomputed K, attention and MLP values are therefore computed and then
    dropped, so the cached generated-part tensors are stored back unchanged
    (except V, which is replaced whole).  The returned hidden states are not
    meant to match the real hook.

    Args:
        x: Block input, ``(batch, seq_len, dim)``.
        attention_bias: Forwarded to ``self.attention``.
        layer_past: Forwarded to ``self.attention``.
        use_cache: Forwarded to ``self.attention``; the KV pair it returns is
            discarded.

    Returns:
        ``(hidden_states, None)``; the second element is always ``None``.
    """
    # NOTE(review): FeatureCache() is built on every call yet is expected to
    # hold state across calls -- presumably a singleton; confirm in dllm_cache.
    feature_cache = FeatureCache()
    feature_cache.update_step(self.layer_id)
    prompt_length = feature_cache.prompt_length
    x_prompt = x[:, :prompt_length, :]  # prompt part
    x_gen = x[:, prompt_length:, :]      # generated part
    refresh_gen = feature_cache.refresh_gen(layer_id=self.layer_id)
    refresh_prompt = feature_cache.refresh_prompt(layer_id=self.layer_id)
    transfer_ratio = feature_cache.transfer_ratio
    bs, seq_len, dim = x.shape
    # Partial refresh of generated tokens is active only for ratios in (0, 1].
    transfer = transfer_ratio > 0 and transfer_ratio <= 1

    def attention(q, k, v, q_index: torch.Tensor = None):
        """Run self.attention (checkpointed if configured); drop its KV output."""
        if self._activation_checkpoint_fn is not None:
            att, _ = self._activation_checkpoint_fn(self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, q_index=q_index)
        else:
            att, _ = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, q_index=q_index)
        return att

    def compute_mlp(input_x):
        """Feed-forward sub-layer without the residual: norm, gated act, out."""
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.ff_norm, input_x)
        else:
            x = self.ff_norm(input_x)
        x, x_up = self.ff_proj(x), self.up_proj(x)
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.act, x)
        else:
            x = self.act(x)
        x = x * x_up
        return self.ff_out(x)

    def project(x):
        """Normalise the input and return its q, k, v projections."""
        x_normed = self.attn_norm(x)
        q = self.q_proj(x_normed)
        k = self.k_proj(x_normed)
        v = self.v_proj(x_normed)
        return q, k, v

    # ---- Attention sub-layer ----
    if refresh_gen and refresh_prompt:
        # Full recompute; rewrite every cache entry of this layer.
        q, k, v = project(x)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k[:, :prompt_length, :], "v": v[:, :prompt_length, :]}, cache_type="prompt")
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k[:, prompt_length:, :], "v": v[:, prompt_length:, :]}, cache_type="gen")
        att = attention(q, k, v)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att[:, :prompt_length, :], cache_type="prompt")
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att[:, prompt_length:, :], cache_type="gen")

    elif refresh_gen and not refresh_prompt:
        # Generated queries attend over cached prompt K/V plus fresh gen K/V.
        q, k_gen, v_gen = project(x_gen)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k_gen, "v": v_gen}, cache_type="gen")
        kv_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="prompt")
        k = torch.cat([kv_cache["k"], k_gen], dim=1)
        v = torch.cat([kv_cache["v"], v_gen], dim=1)
        att_gen = attention(q, k, v)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_gen, cache_type="gen")
        att_prompt_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="prompt")
        att = torch.cat([att_prompt_cache, att_gen], dim=1)

    elif not refresh_gen and refresh_prompt:
        q_prompt, k_prompt, v_prompt = project(x_prompt)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": k_prompt, "v": v_prompt}, cache_type="prompt")
        kv_cache_gen = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="gen")
        att_gen_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="gen")
        if transfer:
            x_gen_normed = self.attn_norm(x_gen)
            v_gen = self.v_proj(x_gen_normed)
            # Positions (within the generated part) whose V drifted the most.
            index = refresh_index(v_gen, kv_cache_gen["v"], transfer_ratio, self.layer_id)
            index_expanded = index.unsqueeze(-1).expand(-1, -1, dim)  # [batch_size, num_replace, d_model]
            x_gen_selected = torch.gather(x_gen_normed, dim=1, index=index_expanded)
            q_gen_index = self.q_proj(x_gen_selected)
            # Computed but never stored: the scatter_ of the new K rows into
            # the cached K is intentionally skipped in this variant.
            k_gen_index = self.k_proj(x_gen_selected)
            kv_cache_gen["v"] = v_gen
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": kv_cache_gen["k"], "v": kv_cache_gen["v"]}, cache_type="gen")
        k = torch.cat([k_prompt, kv_cache_gen["k"]], dim=1)
        v = torch.cat([v_prompt, kv_cache_gen["v"]], dim=1)
        # Absolute positions of the prompt queries, built on the query's
        # device so no host-side index tensor reaches the rotary lookup.
        prompt_index = torch.arange(prompt_length, device=q_prompt.device).unsqueeze(0).expand(bs, -1)
        if transfer:
            q_prompt_gen_index = torch.cat([q_prompt, q_gen_index], dim=1)
            gen_index = index + prompt_length  # shift to absolute positions
            att_prompt_gen_index = attention(q_prompt_gen_index, k, v, q_index=torch.cat([prompt_index, gen_index], dim=1))
            att_prompt = att_prompt_gen_index[:, :prompt_length, :]
            # Sliced out but not written back (scatter_ skipped on purpose).
            att_gen_index = att_prompt_gen_index[:, prompt_length:, :]
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_gen_cache, cache_type="gen")
        else:
            att_prompt = attention(q_prompt, k, v, q_index=prompt_index)
        feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_prompt, cache_type="prompt")
        att = torch.cat([att_prompt, att_gen_cache], dim=1)
    else:
        att_gen_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="gen")
        if transfer:
            x_gen_normed = self.attn_norm(x_gen)
            v_gen = self.v_proj(x_gen_normed)
            kv_cache_gen = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="gen")
            kv_cache_prompt = feature_cache.get_cache(layer_id=self.layer_id, feature_name="kv_cache", cache_type="prompt")
            index = refresh_index(v_gen, kv_cache_gen["v"], transfer_ratio, self.layer_id)
            index_expanded = index.unsqueeze(-1).expand(-1, -1, dim)  # [batch_size, num_replace, d_model]
            x_gen_selected = torch.gather(x_gen_normed, dim=1, index=index_expanded)
            q_gen_index = self.q_proj(x_gen_selected)
            # Computed but never stored: the scatter_ of the new K rows into
            # the cached K is intentionally skipped in this variant.
            k_gen_index = self.k_proj(x_gen_selected)
            kv_cache_gen["v"] = v_gen
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="kv_cache", features={"k": kv_cache_gen["k"], "v": kv_cache_gen["v"]}, cache_type="gen")
            k = torch.cat([kv_cache_prompt["k"], kv_cache_gen["k"]], dim=1)
            v = torch.cat([kv_cache_prompt["v"], kv_cache_gen["v"]], dim=1)
            # Attention for the selected tokens only; the result is not
            # written back (scatter_ skipped on purpose).
            att_gen_index = attention(q_gen_index, k, v, q_index=index + prompt_length)
            feature_cache.set_cache(layer_id=self.layer_id, feature_name="attn", features=att_gen_cache, cache_type="gen")

        att_prompt_cache = feature_cache.get_cache(layer_id=self.layer_id, feature_name="attn", cache_type="prompt")
        att = torch.cat([att_prompt_cache, att_gen_cache], dim=1)

    x = x + self.dropout(att)

    # ---- Feed-forward sub-layer ----
    og_x = x
    x_prompt = x[:, :prompt_length, :]
    x_gen = x[:, prompt_length:, :]

    if refresh_gen and refresh_prompt:
        x = compute_mlp(x)
        feature_cache.set_cache(self.layer_id, "mlp", x[:, prompt_length:, :], cache_type="gen")
        feature_cache.set_cache(self.layer_id, "mlp", x[:, :prompt_length, :], cache_type="prompt")

    elif refresh_gen and not refresh_prompt:
        x_gen = compute_mlp(x_gen)
        feature_cache.set_cache(self.layer_id, "mlp", x_gen, cache_type="gen")
        x_prompt_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="prompt")
        x = torch.cat([x_prompt_cache, x_gen], dim=1)

    elif refresh_prompt and not refresh_gen:
        x_gen_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="gen")
        if transfer:
            # Reuses the token selection made in the attention sub-layer.
            x_gen_selected = torch.gather(x_gen, dim=1, index=index_expanded)
            x_prompt_gen_index = torch.cat([x_prompt, x_gen_selected], dim=1)
            x_prompt_gen_index = compute_mlp(x_prompt_gen_index)
            x_prompt = x_prompt_gen_index[:, :prompt_length, :]
            # Sliced out but not written back (scatter_ skipped on purpose).
            x_gen_index = x_prompt_gen_index[:, prompt_length:, :]
            feature_cache.set_cache(self.layer_id, "mlp", x_gen_cache, cache_type="gen")
        else:
            x_prompt = compute_mlp(x_prompt)
        feature_cache.set_cache(self.layer_id, "mlp", x_prompt, cache_type="prompt")
        x = torch.cat([x_prompt, x_gen_cache], dim=1)
    else:
        x_gen_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="gen")
        if transfer:
            # Reuses the token selection made in the attention sub-layer.
            x_gen_selected = torch.gather(x_gen, dim=1, index=index_expanded)
            # Computed but not written back (scatter_ skipped on purpose).
            x_gen_index = compute_mlp(x_gen_selected)
            feature_cache.set_cache(self.layer_id, "mlp", x_gen_cache, cache_type="gen")
        x_prompt_cache = feature_cache.get_cache(self.layer_id, "mlp", cache_type="prompt")
        x = torch.cat([x_prompt_cache, x_gen_cache], dim=1)

    x = self.dropout(x)
    x = og_x + x

    return x, None
