import torch
from typing import Optional, Tuple
from transformers.cache_utils import Cache
from dllm_cache import FeatureCache
import torch.nn as nn
import types
def logout_cache_Dream(
    model: nn.Module, tf_block_module_key_name: str
) -> None:
    
    target_module: Optional[nn.ModuleList] = None
    for name, module in model.named_modules():
        if name == tf_block_module_key_name:
            target_module = module  # type: ignore
    if target_module is None:
        return
    for tf_block in target_module:
        forward_fn = getattr(tf_block, "_old_forward", tf_block.forward)
        tf_block.forward = forward_fn

def register_cache_Dream(
    model: nn.Module,
    tf_block_module_key_name: str,
    test_flops: bool = False,
) -> None:
    target_module: Optional[nn.ModuleList] = None
    for name, module in model.named_modules():
        if name == tf_block_module_key_name:
            target_module = module  # type: ignore
    for tf_block in target_module:
        setattr(tf_block, "_old_forward", tf_block.forward)
        if test_flops:
            tf_block.forward = types.MethodType(decoder_hook, tf_block)  # test_flops
        else:
            tf_block.forward = types.MethodType(decoder_hook, tf_block) 
        setattr(tf_block.self_attn, "_old_forward", tf_block.self_attn.forward)
        tf_block.self_attn.forward = types.MethodType(attention, tf_block.self_attn)


def refresh_index(
    new_features: torch.Tensor,
    cached_features: torch.Tensor = None,
    transfer_ratio: float = 0.5,
    layer_id:int = 0
) -> torch.Tensor:
    batch_size, gen_len, d_model = new_features.shape
    # 计算需要替换的token数量
    num_replace = int(gen_len * transfer_ratio)
    cos_sim = torch.nn.functional.cosine_similarity(new_features, cached_features, dim=-1)
    # 按相似度排序，返回相似度低的token
    transfer_index = torch.topk(cos_sim, largest=False, k=num_replace).indices  
    return transfer_index



def decoder_hook(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        
        residual = hidden_states
        x = hidden_states
        feature_cache = FeatureCache()
        feature_cache.update_step(self.self_attn.layer_idx)
        prompt_length = feature_cache.prompt_length
        refresh_gen = feature_cache.refresh_gen(layer_id=self.self_attn.layer_idx)
        refresh_prompt = feature_cache.refresh_prompt(layer_id=self.self_attn.layer_idx)
        x_prompt = x[:, :prompt_length]
        x_gen = x[:, prompt_length:]
        transfer_ratio = feature_cache.transfer_ratio
        bs, seq_len, dim = x.shape
        feature_cache.expect_length=feature_cache.gen_length
        transfer = transfer_ratio > 0 and transfer_ratio <= 1

        def project(x):
            x_normed = self.input_layernorm(x)
            q = self.self_attn.q_proj(x_normed)
            k = self.self_attn.k_proj(x_normed)
            v = self.self_attn.v_proj(x_normed)
            return q, k, v
        

        def compute_mlp(x):
            x_norm = self.post_attention_layernorm(x)
            x = self.mlp(x_norm)
            return x
        

        if refresh_gen and refresh_prompt:
            q, k, v = project(x)
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", features={"k": k[:, :prompt_length, :], "v": v[:, :prompt_length, :]}, cache_type="prompt")
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", features={"k": k[:, prompt_length:, :], "v": v[:, prompt_length:, :]}, cache_type="gen")
            attn = self.self_attn(q,k,v,attention_mask,position_embeddings)
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", features=attn[:, :prompt_length, :], cache_type="prompt")
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", features=attn[:, prompt_length:, :], cache_type="gen")

        elif refresh_gen and not refresh_prompt:
            q_gen, k_gen, v_gen = project(x_gen)
            
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", features={"k": k_gen, "v": v_gen}, cache_type="gen")
            kv_cache_prompt = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", cache_type="prompt")
            k = torch.cat([kv_cache_prompt["k"], k_gen], dim=1)
            v = torch.cat([kv_cache_prompt["v"], v_gen], dim=1)
            gen_index = (torch.arange(seq_len-prompt_length)+prompt_length).unsqueeze(0).expand(bs,-1).to(q_gen.device)
            attn_gen = self.self_attn(q_gen,k,v,attention_mask,position_embeddings,gen_index)

            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", features=attn_gen, cache_type="gen")
            att_prompt_cache = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", cache_type="prompt")
            attn = torch.cat([att_prompt_cache, attn_gen], dim=1)
        elif not refresh_gen and refresh_prompt:
            q_prompt, k_prompt, v_prompt = project(x_prompt)
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", features={"k": k_prompt, "v": v_prompt}, cache_type="prompt")
            kv_cache_gen = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", cache_type="gen")
            att_gen_cache = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", cache_type="gen")
            if transfer:
                x_gen_norm = self.input_layernorm(x_gen)
                v_gen = self.self_attn.v_proj(x_gen_norm)
                index = refresh_index(v_gen,kv_cache_gen["v"],transfer_ratio,self.self_attn.layer_idx)
                index_expanded = index.unsqueeze(-1).expand(-1, -1, dim)  # [batch_size, num_replace, d_model]
                x_gen_selected = torch.gather(x_gen_norm, dim=1, index=index_expanded)
                q_gen_index = self.self_attn.q_proj(x_gen_selected)
                k_gen_index = self.self_attn.k_proj(x_gen_selected) 
                kv_cache_gen["v"] = v_gen
                kv_cache_gen["k"].scatter_(dim=1, index=index.unsqueeze(-1).expand(-1, -1, dim//7), src=k_gen_index)
                feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", features={"k": kv_cache_gen["k"], "v": kv_cache_gen["v"]}, cache_type="gen")
            k = torch.cat([k_prompt, kv_cache_gen["k"]], dim=1)
            v = torch.cat([v_prompt, kv_cache_gen["v"]], dim=1)
            if transfer:
                q_prompt_gen_index = torch.cat([q_prompt,q_gen_index],dim=1)
                prompt_index = torch.arange(prompt_length).unsqueeze(0).expand(bs,-1).to(q_prompt_gen_index.device)
                gen_index = index+prompt_length
                att_prompt_gen_index = self.self_attn(q_prompt_gen_index, k, v,attention_mask,position_embeddings,torch.cat([prompt_index,gen_index],dim=1))
                att_prompt = att_prompt_gen_index[:, :prompt_length, :]
                att_gen_index = att_prompt_gen_index[:, prompt_length:, :]
                att_gen_cache.scatter_(dim=1, index=index_expanded, src=att_gen_index)
                feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", features=att_gen_cache, cache_type="gen")
            else:
                # att_prompt = self.self_attn(q_prompt, k, v,q_index=torch.arange(prompt_length).unsqueeze(0).expand(bs,-1))
                att_prompt = self.self_attn(q_prompt, k, v,attention_mask,position_embeddings,torch.arange(prompt_length).unsqueeze(0).expand(bs,-1))
            feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", features=att_prompt, cache_type="prompt")
            attn = torch.cat([att_prompt, att_gen_cache], dim=1)
        else:
            att_gen_cache = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", cache_type="gen")
            if transfer:
                x_gen_norm = self.input_layernorm(x_gen)
                v_gen = self.self_attn.v_proj(x_gen_norm)
                kv_cache_gen = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", cache_type="gen")
                kv_cache_prompt = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", cache_type="prompt")
                index = refresh_index(v_gen,kv_cache_gen["v"],transfer_ratio,self.self_attn.layer_idx)
                index_expanded = index.unsqueeze(-1).expand(-1, -1, dim)  # [batch_size, num_replace, d_model]
                x_gen_selected = torch.gather(x_gen_norm, dim=1, index=index_expanded)
                q_gen_index = self.self_attn.q_proj(x_gen_selected)
                k_gen_index = self.self_attn.k_proj(x_gen_selected) 
                kv_cache_gen["v"] = v_gen
                kv_cache_gen["k"].scatter_(dim=1, index=index.unsqueeze(-1).expand(-1, -1, dim//7), src=k_gen_index)
                feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="kv_cache", features={"k": kv_cache_gen["k"], "v": kv_cache_gen["v"]}, cache_type="gen")
                k = torch.cat([kv_cache_prompt["k"], kv_cache_gen["k"]], dim=1)
                v = torch.cat([kv_cache_prompt["v"], kv_cache_gen["v"]], dim=1)
                gen_index = index+prompt_length
                att_gen_index = self.self_attn(q_gen_index, k, v,attention_mask,position_embeddings,gen_index)
                att_gen_cache.scatter_(dim=1, index=index_expanded, src=att_gen_index)
                feature_cache.set_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", features=att_gen_cache, cache_type="gen")
            att_prompt_cache = feature_cache.get_cache(layer_id=self.self_attn.layer_idx, feature_name="attn", cache_type="prompt")
            attn = torch.cat([att_prompt_cache, att_gen_cache], dim=1)

        # Fully Connected
        x = residual + attn
        residual = x
        x_prompt = x[:, :prompt_length]
        x_gen = x[:, prompt_length:]


        if refresh_gen and refresh_prompt:
            x = compute_mlp(x)
            feature_cache.set_cache(self.self_attn.layer_idx, "mlp", x[:, prompt_length:, :], cache_type="gen")
            feature_cache.set_cache(self.self_attn.layer_idx, "mlp", x[:, :prompt_length, :], cache_type="prompt")
        elif refresh_gen and not refresh_prompt:
            x_gen = compute_mlp(x_gen)
            feature_cache.set_cache(self.self_attn.layer_idx, "mlp", x_gen, cache_type="gen")
            x_prompt_cache = feature_cache.get_cache(self.self_attn.layer_idx, "mlp", cache_type="prompt")
            x = torch.cat([x_prompt_cache, x_gen], dim=1)
        elif not refresh_gen and refresh_prompt:
            x_gen_cache = feature_cache.get_cache(self.self_attn.layer_idx, "mlp", cache_type="gen")
            if transfer:
                x_gen_selected = torch.gather(x_gen, dim=1, index=index_expanded)
                x_prompt_gen_index = torch.cat([x_prompt,x_gen_selected],dim=1)
                x_prompt_gen_index = compute_mlp(x_prompt_gen_index)
                x_prompt = x_prompt_gen_index[:, :prompt_length, :]
                x_gen_index = x_prompt_gen_index[:, prompt_length:, :]
                x_gen_cache.scatter_(dim=1, index=index_expanded, src=x_gen_index)
                feature_cache.set_cache(self.self_attn.layer_idx, "mlp", x_gen_cache, cache_type="gen")
            else:
                x_prompt = compute_mlp(x_prompt)
            feature_cache.set_cache(self.self_attn.layer_idx, "mlp", x_prompt, cache_type="prompt")
            x = torch.cat([x_prompt, x_gen_cache], dim=1)
        else:
            x_gen_cache = feature_cache.get_cache(self.self_attn.layer_idx, "mlp", cache_type="gen")
            if transfer:
                x_gen_selected = torch.gather(x_gen, dim=1, index=index_expanded)
                x_gen_index = compute_mlp(x_gen_selected)
                x_gen_cache.scatter_(dim=1, index=index_expanded, src=x_gen_index)
                feature_cache.set_cache(self.self_attn.layer_idx, "mlp", x_gen_cache, cache_type="gen")
            x_prompt_cache = feature_cache.get_cache(self.self_attn.layer_idx, "mlp", cache_type="prompt")
            x = torch.cat([x_prompt_cache, x_gen_cache], dim=1)
            

        x = residual + x
        outputs = (x,)
        return outputs



def attention(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        q_index:torch.Tensor =None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = query_states.size()
        bsz, k_len, _ = key_states.size()
        # key_states value_states.shape[1,4, prompt_len+generate_len, 128]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        # cos.shape = (bs, prompt_len+generate_len, self.head_dim)
        # sin.shape = (bs, prompt_len+generate_len, self.head_dim)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin,q_index)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # repeat_kv:key_states value_states.shape[1,28, prompt_len+generate_len, 128]
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        # print(f"query_states shape: {query_states.shape}")
        # print(f"key_states shape: {key_states.shape}")
        # print(f"value_states shape: {value_states.shape}")
        # print(f"attention_mask shape: {attention_mask.shape}")
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            # attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output


def apply_rotary_pos_emb(q, k, cos, sin, q_index=None, unsqueeze_dim=1):
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    if q_index is not None:
                bs,_ = q_index.shape
                q_embed = []
                for i in range(bs):
                    q_i = (q[i].squeeze(0) * cos[i,:,q_index[i]]) + (rotate_half(q[i].squeeze(0)) * sin[i,:,q_index[i]])
                    q_embed.append(q_i.unsqueeze(0))
                q_embed = torch.cat(q_embed,dim=0)
    else:
        q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)