from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

from hyvideo.modules.activation_layers import get_activation_layer
from hyvideo.modules.norm_layers import get_norm_layer
from hyvideo.modules.embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from hyvideo.modules.attenion import attention, parallel_attention, get_cu_seqlens
from hyvideo.modules.posemb_layers import apply_rotary_emb
from hyvideo.modules.modulate_layers import ModulateDiT, modulate, apply_gate

from scaling_cache.cache_functions import cal_type
from scaling_cache.utils import derivative_approximation, taylor_formula, taylor_cache_init
from scaling_cache.utils import interleaved_cache_update, interleaved_alpha_update, interleaved_error_update, scaling_formula


def _cache_double_stream_feature(
    cache_dic: Dict,
    current: Dict,
    module: str,
    feature: torch.Tensor,
) -> None:
    """Record one double-stream sub-module output in the enabled feature cache.

    Exactly one scheme is used, checked in this order:

    * ``cache_dic['taylor_cache']``: the entry key is ``'<module>-taylor'``;
      ``taylor_cache_init`` is called, then ``derivative_approximation`` is
      called with ``feature``.
    * ``cache_dic['scaling_cache']``: the entry key is ``'<module>'``;
      ``interleaved_error_update`` is called before ``interleaved_cache_update``,
      both with ``feature``.

    If neither flag is set, nothing is recorded.

    Args:
        cache_dic: Global cache state holding the scheme flags.
        current: Per-step bookkeeping; ``current['module']`` is overwritten with
            the key used for this entry.
        module: Base cache key, e.g. ``'img-attn'`` or ``'txt-mlp'``.
        feature: The sub-module output (before gating) to record.
    """
    if cache_dic['taylor_cache']:
        current['module'] = f'{module}-taylor'
        taylor_cache_init(cache_dic, current)
        derivative_approximation(cache_dic, current, feature)
    elif cache_dic['scaling_cache']:
        current['module'] = module
        # Call order kept from the original code: the error update runs before
        # the cache entry is refreshed (presumably it compares against the
        # previously cached value -- TODO confirm in scaling_cache.utils).
        interleaved_error_update(cache_dic=cache_dic, current=current, feature=feature)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=feature)


def hy_double_stream_block_forward(
    self,
    img: torch.Tensor,
    txt: torch.Tensor,
    vec: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: tuple = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
    update_cache: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a full double-stream block, optionally refreshing the feature cache.

    Image and text tokens use separate modulation, QKV projections and MLPs,
    but attend jointly: their q/k/v are concatenated (image first) along dim 1,
    attention runs once, and the result is split back at ``img.shape[1]``.

    Args:
        self: The double-stream block module providing ``img_mod``, ``txt_mod``,
            ``img_attn_qkv``, ``img_mlp``, ``txt_attn_qkv``, ``txt_mlp``, etc.
        img: Image-token hidden states (batch first).
        txt: Text-token hidden states (batch first).
        vec: Conditioning vector fed to ``img_mod`` / ``txt_mod``.
        cu_seqlens_q: Cumulative query sequence lengths; must have
            ``2 * img.shape[0] + 1`` entries (asserted).
        cu_seqlens_kv: Cumulative key/value sequence lengths.
        max_seqlen_q: Max query length, forwarded to ``attention``.
        max_seqlen_kv: Max key/value length, forwarded to ``attention``.
        freqs_cis: Rotary embedding frequencies applied to the image q/k only;
            ``None`` disables RoPE.
        cache_dic: Global cache state; ``cache_dic['cal_amount'][current['stream']]``
            is incremented on every call.
        current: Per-step bookkeeping; ``current['module']`` is overwritten
            when caching.
        update_cache: When ``True``, the four pre-gate sub-module outputs
            (img-attn, img-mlp, txt-attn, txt-mlp) are recorded in the cache.

    Returns:
        The updated ``(img, txt)`` hidden states.
    """
    cache_dic["cal_amount"][current['stream']] += 1

    (
        img_mod1_shift,
        img_mod1_scale,
        img_mod1_gate,
        img_mod2_shift,
        img_mod2_scale,
        img_mod2_gate,
    ) = self.img_mod(vec).chunk(6, dim=-1)
    (
        txt_mod1_shift,
        txt_mod1_scale,
        txt_mod1_gate,
        txt_mod2_shift,
        txt_mod2_scale,
        txt_mod2_gate,
    ) = self.txt_mod(vec).chunk(6, dim=-1)

    # Image branch: norm -> modulate -> QKV -> QK-norm -> RoPE.
    img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale)
    img_q, img_k, img_v = rearrange(
        self.img_attn_qkv(img_modulated), "B L (K H D) -> K B L H D", K=3, H=self.heads_num
    )
    img_q = self.img_attn_q_norm(img_q).to(img_v)
    img_k = self.img_attn_k_norm(img_k).to(img_v)
    if freqs_cis is not None:
        img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
        assert (
            img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
        ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
        img_q, img_k = img_qq, img_kk

    # Text branch: same pipeline, without RoPE.
    txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
    txt_q, txt_k, txt_v = rearrange(
        self.txt_attn_qkv(txt_modulated), "B L (K H D) -> K B L H D", K=3, H=self.heads_num
    )
    txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
    txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

    # Joint attention over the concatenated [image | text] sequence.
    q = torch.cat((img_q, txt_q), dim=1)
    k = torch.cat((img_k, txt_k), dim=1)
    v = torch.cat((img_v, txt_v), dim=1)
    assert (
        cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
    ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

    if not self.hybrid_seq_parallel_attn:
        attn = attention(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=img_k.shape[0],
            mode="flash",
        )
    else:
        attn = parallel_attention(
            self.hybrid_seq_parallel_attn,
            q,
            k,
            v,
            img_q_len=img_q.shape[1],
            img_kv_len=img_k.shape[1],
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )

    img_len = img.shape[1]
    img_attn, txt_attn = attn[:, :img_len], attn[:, img_len:]

    # Image stream: gated attention residual, then gated MLP residual.
    img_attn_out = self.img_attn_proj(img_attn)
    if update_cache:
        _cache_double_stream_feature(cache_dic, current, 'img-attn', img_attn_out)
    img = img + apply_gate(img_attn_out, gate=img_mod1_gate)

    img_mlp_out = self.img_mlp(
        modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)
    )
    if update_cache:
        _cache_double_stream_feature(cache_dic, current, 'img-mlp', img_mlp_out)
    img = img + apply_gate(img_mlp_out, gate=img_mod2_gate)

    # Text stream: gated attention residual, then gated MLP residual.
    txt_attn_out = self.txt_attn_proj(txt_attn)
    if update_cache:
        _cache_double_stream_feature(cache_dic, current, 'txt-attn', txt_attn_out)
    txt = txt + apply_gate(txt_attn_out, gate=txt_mod1_gate)

    txt_mlp_out = self.txt_mlp(
        modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)
    )
    if update_cache:
        _cache_double_stream_feature(cache_dic, current, 'txt-mlp', txt_mlp_out)
    txt = txt + apply_gate(txt_mlp_out, gate=txt_mod2_gate)

    return img, txt

def hy_double_stream_block_taylor_forward(
    self,
    img: torch.Tensor,
    txt: torch.Tensor,
    vec: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: tuple = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Double-stream block forward for Taylor-series feature caching.

    ``current['type']`` selects the step kind:

    * ``'full'``: run ``hy_double_stream_block_forward`` with
      ``update_cache=True``, refreshing the four cached sub-module outputs.
    * ``'Taylor'``: skip attention and MLPs entirely; each residual update is
      the gated output of ``taylor_formula`` for the corresponding cache entry
      (``img-attn-taylor``, ``img-mlp-taylor``, ``txt-attn-taylor``,
      ``txt-mlp-taylor``).

    Returns:
        The updated ``(img, txt)`` hidden states.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Taylor'``.
            This matches ``hy_single_stream_block_taylor_forward`` instead of
            silently returning the inputs unchanged.
    """
    step_type = current['type']
    if step_type == 'full':
        return hy_double_stream_block_forward(
            self,
            img,
            txt,
            vec,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            cache_dic,
            current,
            update_cache=True,
        )
    if step_type != 'Taylor':
        raise ValueError(
            f"Unsupported step type for double-stream Taylor forward: {step_type!r}"
        )

    # Only the gates are needed: shift/scale modulate the inputs of the
    # sub-modules, which are not executed on a Taylor step.
    _, _, img_mod1_gate, _, _, img_mod2_gate = self.img_mod(vec).chunk(6, dim=-1)
    _, _, txt_mod1_gate, _, _, txt_mod2_gate = self.txt_mod(vec).chunk(6, dim=-1)

    current['module'] = 'img-attn-taylor'
    img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod1_gate)

    current['module'] = 'img-mlp-taylor'
    img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod2_gate)

    current['module'] = 'txt-attn-taylor'
    txt = txt + apply_gate(taylor_formula(cache_dic, current), gate=txt_mod1_gate)

    current['module'] = 'txt-mlp-taylor'
    txt = txt + apply_gate(taylor_formula(cache_dic, current), gate=txt_mod2_gate)

    return img, txt


def hy_double_stream_block_scaling_forward(
    self,
    img: torch.Tensor,
    txt: torch.Tensor,
    vec: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: tuple = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Double-stream block forward for the interleaved "scaling" feature cache.

    ``current['type']`` selects the step kind:

    * ``'full'``: run ``hy_double_stream_block_forward`` with
      ``update_cache=True``.
    * ``'Scaling'``: for each sub-module (img-attn, img-mlp, txt-attn,
      txt-mlp), either
        - recompute the real output when ``cache_dic['update_alpha']`` is set,
          then call ``interleaved_alpha_update`` and ``interleaved_cache_update``
          with it; or
        - approximate the output with ``scaling_formula`` from the cache.
      The gated output is then added as a residual, as in the full block.

    When recomputing, attention uses ``mode="vanilla"`` if
    ``cache_dic['cache_type'] == 'attention'`` or ``cache_dic['test_FLOPs']``
    is set, otherwise ``"flash"``.

    Returns:
        The updated ``(img, txt)`` hidden states.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Scaling'``.
            This matches the single-stream variant instead of silently
            returning the inputs unchanged.
    """
    step_type = current['type']
    if step_type == 'full':
        return hy_double_stream_block_forward(
            self,
            img,
            txt,
            vec,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            cache_dic,
            current,
            update_cache=True,
        )
    if step_type != 'Scaling':
        raise ValueError(
            f"Unsupported step type for double-stream scaling forward: {step_type!r}"
        )

    (
        img_mod1_shift,
        img_mod1_scale,
        img_mod1_gate,
        img_mod2_shift,
        img_mod2_scale,
        img_mod2_gate,
    ) = self.img_mod(vec).chunk(6, dim=-1)
    (
        txt_mod1_shift,
        txt_mod1_scale,
        txt_mod1_gate,
        txt_mod2_shift,
        txt_mod2_scale,
        txt_mod2_gate,
    ) = self.txt_mod(vec).chunk(6, dim=-1)

    if cache_dic['update_alpha']:
        # Image branch: norm -> modulate -> QKV -> QK-norm -> RoPE.
        img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale)
        img_q, img_k, img_v = rearrange(
            self.img_attn_qkv(img_modulated), "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Text branch: same pipeline, without RoPE.
        txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_q, txt_k, txt_v = rearrange(
            self.txt_attn_qkv(txt_modulated), "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Joint attention over the concatenated [image | text] sequence.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)
        assert (
            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

        if not self.hybrid_seq_parallel_attn:
            use_vanilla = (cache_dic['cache_type'] == 'attention') or cache_dic['test_FLOPs']
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=img_k.shape[0],
                mode="vanilla" if use_vanilla else "flash",
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )

        img_len = img.shape[1]
        img_attn, txt_attn = attn[:, :img_len], attn[:, img_len:]

    # Image stream.
    current['module'] = 'img-attn'
    if cache_dic['update_alpha']:
        img_attn_out = self.img_attn_proj(img_attn)
        interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=img_attn_out)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=img_attn_out)
    else:
        img_attn_out = scaling_formula(cache_dic=cache_dic, current=current)
    img = img + apply_gate(img_attn_out, gate=img_mod1_gate)

    current['module'] = 'img-mlp'
    if cache_dic['update_alpha']:
        img_mlp_out = self.img_mlp(
            modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)
        )
        interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=img_mlp_out)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=img_mlp_out)
    else:
        img_mlp_out = scaling_formula(cache_dic=cache_dic, current=current)
    img = img + apply_gate(img_mlp_out, gate=img_mod2_gate)

    # Text stream.
    current['module'] = 'txt-attn'
    if cache_dic['update_alpha']:
        txt_attn_out = self.txt_attn_proj(txt_attn)
        interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=txt_attn_out)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=txt_attn_out)
    else:
        txt_attn_out = scaling_formula(cache_dic=cache_dic, current=current)
    txt = txt + apply_gate(txt_attn_out, gate=txt_mod1_gate)

    current['module'] = 'txt-mlp'
    if cache_dic['update_alpha']:
        txt_mlp_out = self.txt_mlp(
            modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)
        )
        interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=txt_mlp_out)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=txt_mlp_out)
    else:
        txt_mlp_out = scaling_formula(cache_dic=cache_dic, current=current)
    txt = txt + apply_gate(txt_mlp_out, gate=txt_mod2_gate)

    return img, txt

def hy_single_stream_block_forward(
    self,
    x: torch.Tensor,
    vec: torch.Tensor,
    txt_len: int,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
    update_cache: bool = True,
) -> torch.Tensor:
    """Run a full single-stream block, optionally refreshing the feature cache.

    ``x`` holds image tokens followed by ``txt_len`` text tokens along dim 1.
    One ``linear1`` produces both the QKV and the MLP hidden activations. After
    attention, ``linear2`` fuses ``[attn | mlp_act(mlp)]`` into a single output,
    which is gated and added to ``x``.

    Args:
        self: The single-stream block module providing ``modulation``,
            ``pre_norm``, ``linear1``, ``linear2``, ``q_norm``, ``k_norm``,
            ``mlp_act``, etc.
        x: Hidden states, image tokens first, then ``txt_len`` text tokens.
        vec: Conditioning vector fed to ``self.modulation``.
        txt_len: Number of trailing text tokens in ``x``.
        cu_seqlens_q: Cumulative query sequence lengths; must have
            ``2 * x.shape[0] + 1`` entries (asserted).
        cu_seqlens_kv: Cumulative key/value sequence lengths.
        max_seqlen_q: Max query length, forwarded to ``attention``.
        max_seqlen_kv: Max key/value length, forwarded to ``attention``.
        freqs_cis: Rotary embedding frequencies applied to the image part of
            q/k only; ``None`` disables RoPE.
        cache_dic: Global cache state; ``cache_dic['cal_amount'][current['stream']]``
            is incremented on every call.
        current: Per-step bookkeeping; ``current['module']`` is overwritten
            when caching.
        update_cache: When ``True``, the pre-gate ``linear2`` output is recorded
            under ``'total-taylor'`` (Taylor cache) or ``'total'`` (scaling cache).

    Returns:
        ``x`` plus the gated block output.
    """
    cache_dic["cal_amount"][current['stream']] += 1
    mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
    x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)

    qkv, mlp = torch.split(
        self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
    )
    q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
    q = self.q_norm(q).to(v)
    k = self.k_norm(k).to(v)

    # Explicit split point between image and text tokens. It is defined even when
    # RoPE is off, which the parallel-attention branch below needs (it previously
    # read img_q/img_k, which were unbound when freqs_cis was None). It is also
    # correct for txt_len == 0, unlike slicing with ``-txt_len``.
    img_len = q.shape[1] - txt_len

    if freqs_cis is not None:
        img_q, txt_q = q[:, :img_len], q[:, img_len:]
        img_k, txt_k = k[:, :img_len], k[:, img_len:]
        img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
        assert (
            img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
        ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
        q = torch.cat((img_qq, txt_q), dim=1)
        k = torch.cat((img_kk, txt_k), dim=1)

    assert (
        cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
    ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"

    if not self.hybrid_seq_parallel_attn:
        attn = attention(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=x.shape[0],
            mode="flash",
        )
    else:
        attn = parallel_attention(
            self.hybrid_seq_parallel_attn,
            q,
            k,
            v,
            img_q_len=img_len,
            img_kv_len=img_len,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )

    # Fuse the attention output with the activated MLP stream.
    output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))

    if update_cache:
        if cache_dic['taylor_cache']:
            current['module'] = 'total-taylor'
            taylor_cache_init(cache_dic, current)
            derivative_approximation(cache_dic, current, output)
        elif cache_dic['scaling_cache']:
            current['module'] = 'total'
            interleaved_error_update(cache_dic=cache_dic, current=current, feature=output)
            interleaved_cache_update(cache_dic=cache_dic, current=current, feature=output)

    return x + apply_gate(output, gate=mod_gate)

def hy_single_stream_block_taylor_forward(
    self,
    x: torch.Tensor,
    vec: torch.Tensor,
    txt_len: int,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
) -> torch.Tensor:
    """Single-stream block forward for Taylor-series feature caching.

    ``current['type']`` selects the step kind:

    * ``'full'``: delegate to ``hy_single_stream_block_forward`` with
      ``update_cache=True``.
    * ``'Taylor'``: skip the block's computation and add the gated output of
      ``taylor_formula`` for the ``'total-taylor'`` cache entry.

    Returns:
        ``x`` plus the (real or predicted) gated block output.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Taylor'``.
    """
    step_type = current['type']
    if step_type == 'full':
        # The full forward computes its own modulation, so none is done here.
        return hy_single_stream_block_forward(
            self,
            x,
            vec,
            txt_len,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            cache_dic,
            current,
            update_cache=True,
        )
    if step_type != 'Taylor':
        raise ValueError(
            f"Unsupported step type for single-stream Taylor forward: {step_type!r}"
        )

    # Only the gate is needed on a Taylor step.
    _, _, mod_gate = self.modulation(vec).chunk(3, dim=-1)
    current['module'] = 'total-taylor'
    output = taylor_formula(cache_dic, current)
    return x + apply_gate(output, gate=mod_gate)

def hy_single_stream_block_scaling_forward(
    self,
    x: torch.Tensor,
    vec: torch.Tensor,
    txt_len: int,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
) -> torch.Tensor:
    """Single-stream block forward for the interleaved "scaling" feature cache.

    ``current['type']`` selects the step kind:

    * ``'full'``: delegate to ``hy_single_stream_block_forward`` with
      ``update_cache=True``.
    * ``'Scaling'``: if ``cache_dic['update_alpha']`` is set, recompute the
      real block output, then call ``interleaved_alpha_update`` and
      ``interleaved_cache_update`` with it. Otherwise approximate the output
      with ``scaling_formula``. Either way the cache key is ``'total'``.

    When recomputing, attention uses ``mode="vanilla"`` if
    ``cache_dic['cache_type'] == 'attention'`` or ``cache_dic['test_FLOPs']``
    is set, otherwise ``"flash"``.

    Returns:
        ``x`` plus the gated block output.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Scaling'``.
    """
    step_type = current['type']
    if step_type == 'full':
        return hy_single_stream_block_forward(
            self,
            x,
            vec,
            txt_len,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            cache_dic,
            current,
            update_cache=True,
        )
    if step_type != 'Scaling':
        raise ValueError(
            f"Unsupported step type for single-stream scaling forward: {step_type!r}"
        )

    mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)

    if cache_dic['update_alpha']:
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Explicit image/text split point. It is defined even when RoPE is off,
        # which the parallel-attention branch needs, and it is correct for
        # txt_len == 0.
        img_len = q.shape[1] - txt_len

        if freqs_cis is not None:
            img_q, txt_q = q[:, :img_len], q[:, img_len:]
            img_k, txt_k = k[:, :img_len], k[:, img_len:]
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            q = torch.cat((img_qq, txt_q), dim=1)
            k = torch.cat((img_kk, txt_k), dim=1)

        assert (
            cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"

        if not self.hybrid_seq_parallel_attn:
            use_vanilla = (cache_dic['cache_type'] == 'attention') or cache_dic['test_FLOPs']
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=x.shape[0],
                mode="vanilla" if use_vanilla else "flash",
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_len,
                img_kv_len=img_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )

    current['module'] = 'total'
    if cache_dic['update_alpha']:
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=output)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=output)
    else:
        output = scaling_formula(cache_dic=cache_dic, current=current)

    return x + apply_gate(output, gate=mod_gate)