
import typing as tp


# from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils.transformer_utils import DropPath, precompute_freqs_cis, apply_rotary_emb, interleave_tokens, noise_augment
from utils.rar_utils import TargetAwarePE, AnnealSchedule, build_pair_perm, apply_perm, compute_tape_delta_next
import math
import os

# Process-wide PyTorch backend switches; they run once, when this module is imported.
# Allow scaled_dot_product_attention to pick the flash and memory-efficient kernels.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Let float32 matmuls use TF32 tensor cores (reduced mantissa precision in exchange for speed).
torch.backends.cuda.matmul.allow_tf32 = True   # A100+

class TypeEmbedding(nn.Module):
    """Add a learned per-type vector to every token of a sequence, then apply dropout."""

    def __init__(self, num_types: int, embed_dim: int, dropout: float = 0.1):
        super().__init__()
        self.type_embedding = nn.Embedding(num_types, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_embeds: torch.Tensor, type_id: int) -> torch.Tensor:
        """Offset ``token_embeds`` by the embedding of ``type_id``.

        Args:
            token_embeds: Tensor of shape [batch_size, seq_len, embed_dim].
            type_id: A Python int shared by the whole batch. Anything that is
                not an ``int`` is handed to the embedding table unchanged, so a
                tensor of per-sample ids also works.

        Returns:
            Tensor of shape [batch_size, seq_len, embed_dim].
        """
        batch_size, seq_len = token_embeds.size(0), token_embeds.size(1)
        if isinstance(type_id, int):
            # One id per sample; an integer fill value yields an integer index tensor.
            type_id = torch.full((batch_size,), type_id, device=token_embeds.device)
        per_sample = self.type_embedding(type_id)
        # Repeat the per-sample vector along the sequence axis without copying.
        per_token = per_sample.unsqueeze(1).expand(-1, seq_len, -1)
        return self.dropout(token_embeds + per_token)

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension (``eps`` added for stability)."""
        mean_square = (x * x).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the dtype of ``x``, then apply the gain."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

class DyT(nn.Module):
    """Dynamic Tanh layer, used here as a drop-in replacement for a normalisation layer.

    Reference: "Transformers without Normalization", https://arxiv.org/abs/2503.10622

    Computes ``tanh(alpha * x) * gammma + beta`` with a scalar ``alpha`` and
    per-channel ``gammma`` / ``beta``.
    """

    def __init__(self, dim: int, init_alpha: float = 0.5, eps: float = 1e-5):
        super().__init__()
        # ``eps`` is stored but not read by ``forward``.
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        # NOTE: the spelling ``gammma`` is the registered parameter name and
        # therefore a state-dict key; it is kept as is.
        self.gammma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.gammma + self.beta

# @jit.script
# def fused_layerscale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
#     return x * scale

class LayerScale(nn.Module):
    """Learned per-channel scaling of a residual branch.

    Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    """

    def __init__(self, channels: int, init: float = 1e-4):
        super().__init__()
        initial_scale = torch.full((channels,), init)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The scale vector broadcasts over every leading dimension of ``x``.
        return torch.mul(x, self.scale)

def find_multiple(n: int, k: int):
    """Return ``n`` rounded up to the nearest multiple of ``k`` (``n`` itself if already one)."""
    quotient, remainder = divmod(n, k)
    return n if remainder == 0 else (quotient + 1) * k

class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(a) * b)`` where ``a, b`` are the two halves of ``w13(x)``.

    Args:
        d_model: Input and output width.
        ffn_dim_multiplier: Optional factor applied to the hidden width before rounding.
        ffn_dropout_p: Dropout probability applied to the block output.
        multiple_of: The hidden width is rounded up to a multiple of this value.
    """

    def __init__(
        self,
        d_model: int,
        ffn_dim_multiplier: tp.Optional[float] = None,
        ffn_dropout_p: float = 0.1,
        multiple_of: int = 256
    ):
        super().__init__()
        # Two thirds of 4 * d_model, optionally rescaled, then rounded up.
        hidden_dim = int(2 * (4 * d_model) / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, multiple_of)

        # A single matmul produces both halves of the gated activation.
        self.w13 = nn.Linear(d_model, hidden_dim * 2, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.ffn_dropout = nn.Dropout(ffn_dropout_p)

    def forward(self, x):
        gate, value = torch.chunk(self.w13(x), 2, dim=-1)
        projected = self.w2(F.silu(gate) * value)
        return self.ffn_dropout(projected)

class KVCache(nn.Module):
    """Dense key/value cache of shape (max_batch_size, n_head, max_seq_length, head_dim)."""

    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_head, max_seq_length, head_dim)
        for buffer_name in ('k_cache', 'v_cache'):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val`` / ``v_val`` at sequence slots ``input_pos`` and return the full caches.

        Args:
            input_pos: [S] slot indices along the sequence axis.
            k_val: [B, H, S, D] keys to store.
            v_val: values to store, written at the same slots.

        Returns:
            The ``(k_cache, v_cache)`` buffers themselves (not copies).
        """
        assert input_pos.shape[0] == k_val.shape[2]
        # In-place indexed writes into the registered buffers.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

class RingKVCache(nn.Module):
    """Fixed-size ring buffer holding the most recent ``window_size`` key/value steps.

    ``pos`` holds, per batch row, the slot written last (``-1`` before any write) and
    ``valid`` the number of filled slots; both are non-persistent buffers.
    """

    def __init__(self, max_batch_size, window_size, n_head, head_dim, dtype):
        super().__init__()
        self.window_size = int(window_size)
        cache_shape = (max_batch_size, n_head, self.window_size, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('pos', torch.full((max_batch_size,), -1, dtype=torch.long), persistent=False)
        self.register_buffer('valid', torch.zeros((max_batch_size,), dtype=torch.long), persistent=False)

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor):
        """Append one step, overwriting the oldest slot once the window is full.

        ``input_pos`` is only used for the length check; the write slot comes from
        the internal ``pos`` counter. Returns the raw ``(k_cache, v_cache)`` buffers
        in slot order, not in chronological order.
        """
        batch, heads, steps, dim = k_val.shape
        assert steps == input_pos.shape[0]
        assert steps == 1, "RingKVCache.update currently assumes streaming (S == 1)"

        # Advance the write cursor, wrapping around at the window boundary.
        self.pos = (self.pos + 1) % self.window_size
        write_index = self.pos.view(batch, 1, 1, 1).expand(batch, heads, 1, dim)
        self.k_cache.scatter_(dim=2, index=write_index, src=k_val[:, :, 0:1, :])
        self.v_cache.scatter_(dim=2, index=write_index, src=v_val[:, :, 0:1, :])
        self.valid = torch.clamp(self.valid + 1, max=self.window_size)
        return self.k_cache, self.v_cache

    def gather_recent(self):
        """Return the stored keys and values ordered from oldest to newest.

        The fill level and cursor of batch row 0 are used for every row.
        """
        num_valid = int(self.valid[0].item())
        if num_valid == 0:
            return self.k_cache[:, :, :0, :], self.v_cache[:, :, :0, :]

        newest = int(self.pos[0].item())
        # Slots of the last ``num_valid`` writes; the modulo maps negative offsets
        # back into the ring.
        slots = torch.arange(newest - num_valid + 1, newest + 1, device=self.k_cache.device)
        slots = slots % self.window_size
        return self.k_cache[:, :, slots, :], self.v_cache[:, :, slots, :]



class StreamingMultiheadAttention(nn.Module):
    """Self-attention with rotary embeddings, grouped-query K/V heads and an optional KV cache.

    Queries, keys and values come from one fused projection (``wqkv``); attention is
    computed with ``F.scaled_dot_product_attention``.

    Args:
        embed_dim (int): Model width. Must be divisible by ``num_heads``.
        num_heads (int): Number of query heads.
        num_kv_heads (int, optional): Number of key/value heads. Defaults to
            ``num_heads``; ``num_heads`` must be a multiple of it.
        attn_dropout_p (float): Dropout probability inside attention (training only).
        resid_dropout_p (float): Dropout probability applied after the output projection.

    Attributes:
        kv_cache: ``None`` (no caching), a ``KVCache`` or a ``RingKVCache``. Starts as
            ``None`` and is assigned from outside this class.
        sliding_window: Optional window length used to crop a dense ``KVCache``.
    """

    _fsdp_final = True

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: tp.Optional[int] = None,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        total_kv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(embed_dim, total_kv_dim, bias=False) # in_proj
        self.wo = nn.Linear(embed_dim, embed_dim, bias=False) # out_proj
        self.kv_cache = None
        self.sliding_window = None

        # regularization
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout_p)

    def forward(
        self,
        query: torch.Tensor,
        freqs_cis: tp.Optional[torch.Tensor] = None,
        input_pos: tp.Optional[torch.Tensor] = None,
        mask: tp.Optional[torch.Tensor] = None
    ):
        """Attend over ``query`` (and, when a cache is attached, over the cached steps).

        Args:
            query: Input of shape (bsz, seqlen, embed_dim).
            freqs_cis: Rotary table handed to ``apply_rotary_emb`` for queries and keys.
            input_pos: Cache slots of the tokens in ``query`` (shape [seqlen]);
                required whenever a cache is attached.
            mask: Optional attention mask forwarded to scaled-dot-product attention.
                Ignored on the ring-cache and cropped-window paths.

        Returns:
            Tensor of shape (bsz, seqlen, embed_dim).
        """
        bsz, seqlen, _ = query.shape
        kv_size = self.num_kv_heads * self.head_dim
        xq, xk, xv = self.wqkv(query).split([self.embed_dim, kv_size, kv_size], dim=-1)

        xq = xq.view(bsz, seqlen, self.num_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.num_kv_heads, self.head_dim)

        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)

        # (bsz, seqlen, heads, head_dim) -> (bsz, heads, seqlen, head_dim)
        xq, xk, xv = (t.transpose(1, 2) for t in (xq, xk, xv))

        if isinstance(self.kv_cache, RingKVCache):
            # The ring only ever holds past and current steps, so no mask is needed.
            self.kv_cache.update(input_pos, xk, xv)
            keys, values = self.kv_cache.gather_recent()
            attn_mask, is_causal = None, False
        elif self.kv_cache is None:
            # No cache: plain (square) self-attention, causal unless a mask is given.
            keys, values = xk, xv
            attn_mask, is_causal = mask, mask is None
        else:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
            if self.sliding_window is not None:
                # Crop the dense cache to the last `sliding_window` slots up to the
                # newest position. No mask is applied inside the crop.
                end = int(input_pos[-1].item())
                start = max(0, end - self.sliding_window + 1)
                keys = keys[:, :, start:end + 1, :]
                values = values[:, :, start:end + 1, :]
                attn_mask, is_causal = None, False
            elif mask is not None:
                attn_mask, is_causal = mask, False
            else:
                # `is_causal=True` aligns the triangle with the top-left corner of the
                # (seqlen, cache_len) score matrix, so a single decoded token would only
                # see cache slot 0. Build the mask from the absolute positions instead:
                # query at position p may attend to cache slots <= p.
                slots = torch.arange(keys.size(2), device=keys.device)
                attn_mask = slots.unsqueeze(0) <= input_pos.to(keys.device).unsqueeze(1)
                is_causal = False

        output = F.scaled_dot_product_attention(
            xq, keys, values,
            attn_mask=attn_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p if self.training else 0,
            enable_gqa=True,
        )

        # (bsz, heads, seqlen, head_dim) -> (bsz, seqlen, embed_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.embed_dim)

        output = self.resid_dropout(self.wo(output))

        return output


class StreamingTransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention and gated FFN, each on a residual branch.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of attention heads.
        layer_scale (float, optional): If not None, LayerScale is applied to both
            branches with the given value as initial scale.
        drop_path (float): Drop-path rate for both branches; ``0`` disables it.
        ffn_dim_multiplier (float, optional): Forwarded to ``FeedForward``.
        ffn_dropout_p (float): Forwarded to ``FeedForward``.
        multiple_of (int): Forwarded to ``FeedForward``.
        num_kv_heads (int, optional): Forwarded to ``StreamingMultiheadAttention``.
        attn_dropout_p (float): Forwarded to ``StreamingMultiheadAttention``.
        resid_dropout_p (float): Forwarded to ``StreamingMultiheadAttention``.
        norm_type (str): ``"rms"`` (``nn.RMSNorm``) or ``"dyt"`` (``DyT``).

    Raises:
        ValueError: If ``norm_type`` is neither ``"rms"`` nor ``"dyt"``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        layer_scale: tp.Optional[float] = None,
        drop_path=0.0,
        ffn_dim_multiplier: tp.Optional[float] = None,
        ffn_dropout_p: float = 0.1,
        multiple_of: int = 256,
        num_kv_heads: tp.Optional[int] = None,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
        norm_type: str = "rms",

    ):
        super().__init__()

        self.attention = StreamingMultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            attn_dropout_p=attn_dropout_p,
            resid_dropout_p=resid_dropout_p,
        )
        self.feed_forward = FeedForward(d_model, ffn_dim_multiplier, ffn_dropout_p, multiple_of)
        if norm_type == "rms":
            self.attention_norm = nn.RMSNorm(d_model, eps=1e-5)
            self.ffn_norm = nn.RMSNorm(d_model, eps=1e-5)
        elif norm_type == "dyt":
            self.attention_norm = DyT(d_model, init_alpha=0.5, eps=1e-5)
            self.ffn_norm = DyT(d_model, init_alpha=0.5, eps=1e-5)
        else:
            # Without this the norms would simply be missing and `forward` would fail
            # much later with an AttributeError.
            raise ValueError(f"Invalid norm type: {norm_type}")

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module

        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale)
            self.layer_scale_2 = LayerScale(d_model, layer_scale)

    def forward(
        self, x: torch.Tensor,
        freqs_cis: torch.Tensor,
        start_pos: tp.Optional[torch.Tensor],
        mask: tp.Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the attention branch, then the feed-forward branch.

        Args:
            x: Block input.
            freqs_cis: Rotary table forwarded to the attention module.
            start_pos: Forwarded to the attention module as its ``input_pos``
                argument (the cache positions of the tokens in ``x``).
            mask: Optional attention mask forwarded to the attention module.

        Returns:
            The block output, same shape as ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)
        h = x + self.drop_path(self.layer_scale_1(attn_out))
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + self.drop_path(self.layer_scale_2(ffn_out))


class StreamingTransformer(nn.Module):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        causal (bool): Causal mask applied automatically.
        context (int, optional): Receptive field for the causal mask, infinite if None.
        layer_scale (float, optional): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, sin_rope, or none).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        layer_class: (subclass of `StreamingTransformerLayer): class to use
            to initialize the layers, allowing further customization outside of AudioCraft.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `StreamingTransformerLayer`.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        seq_len: int,
        num_types: int = 2,
        type_drop_p: float = 0.1,
        layer_scale: tp.Optional[float] = None,
        drop_path_rate: float = 0.0,
        ffn_dim_multiplier: tp.Optional[float] = None,
        ffn_dropout_p: float = 0.1,
        multiple_of: int = 256,
        num_kv_heads: tp.Optional[int] = None,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
        initializer_range: float = 0.02,
        token_dropout_p: float = 0.0,
        max_period: float = 10_000,
        input_type: str = "interleave",
        noise_augmentation: bool = False,
        trans_cfg_dropout_prob: float = 0.1,
        audio_pretraining: bool = False,
        layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
        k_max: float = 0.5,
        num_aggregated_tokens: int = 1,
        norm_type: str = "rms",
        condition_merge: bool = False,
        enable_rar: bool = False,
        rope_scaling: tp.Optional[dict] = None,
        **kwargs,
    ):
        """Build the embeddings, rotary table, layer stack and output norm.

        Args:
            d_model: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads per layer.
            num_layers: Number of transformer layers.
            seq_len: Base sequence length. With ``input_type="interleave"`` the
                internal length becomes ``seq_len + seq_len * num_aggregated_tokens``;
                with ``"concat"`` it is used unchanged.
            num_types: Size of the type-embedding table; ``<= 0`` disables it.
            type_drop_p: Dropout probability inside the type embedding.
            layer_scale, ffn_dim_multiplier, ffn_dropout_p, multiple_of, num_kv_heads,
                attn_dropout_p, resid_dropout_p: Forwarded to every layer.
            drop_path_rate: Final drop-path rate; layers get linearly spaced rates
                from 0 up to this value.
            initializer_range: Standard deviation used by the weight initialisation.
            token_dropout_p: Dropout probability for the two token-dropout modules.
            max_period: Passed to ``precompute_freqs_cis`` (and as ``base`` of the
                target-aware positional encoding when ``enable_rar``).
            input_type: ``"interleave"`` or ``"concat"``.
            noise_augmentation: Stored; ``forward`` applies ``noise_augment`` to the
                audio tokens when set.
            trans_cfg_dropout_prob: Stored; a learned null embedding is created when
                this is positive or ``audio_pretraining`` is set.
            audio_pretraining: Stored flag, see above.
            layer_class: Class instantiated for each layer.
            k_max: Stored; passed to ``noise_augment``.
            num_aggregated_tokens: See ``seq_len``.
            norm_type: ``"rms"`` or ``"dyt"``; selects the output norm and is
                forwarded to every layer.
            condition_merge: When set, a learned ``init_audio_token`` is created.
            enable_rar: When set, the target-aware positional encoding and its
                anneal schedule are created.
            rope_scaling: Optional RoPE scaling config passed to ``precompute_freqs_cis``.
            **kwargs: Forwarded to ``layer_class``.

        Raises:
            ValueError: If ``input_type`` or ``norm_type`` is not supported.
        """
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.max_period = max_period
        self.input_type = input_type
        self.noise_augmentation = noise_augmentation
        self.k_max = k_max
        self.trans_cfg_dropout_prob = trans_cfg_dropout_prob
        self.audio_pretraining = audio_pretraining
        self.num_aggregated_tokens = num_aggregated_tokens
        self.condition_merge = condition_merge
        self.num_layers = num_layers
        self.enable_rar = enable_rar
        self.context_extension_mode = "none"   # {"pi", "ntk", "sliding", "none"}
        self.sliding_window = None
        # Always define the attribute (even as None): `setup_caches` reads it, and a
        # missing attribute on an nn.Module raises AttributeError.
        self.rope_scaling = rope_scaling

        if self.trans_cfg_dropout_prob > 0.0 or self.audio_pretraining:
            self.null_embedding = nn.Parameter(torch.randn(1, d_model) / math.sqrt(d_model))
        if condition_merge:
            self.init_audio_token = nn.Parameter(torch.randn(1, d_model) / math.sqrt(d_model))

        if input_type == "interleave":
            self.seq_len = seq_len + (seq_len * num_aggregated_tokens) # audio seq_length + video seq_length
        elif input_type == "concat":
            self.seq_len = seq_len
            self.concat_proj = nn.Linear(2 * d_model, d_model, bias=False)
        else:
            raise ValueError(f"Invalid input type: {input_type}")

        # rope: kept as a plain tensor attribute (not a registered buffer)
        self.freqs_cis = precompute_freqs_cis(
            self.seq_len,
            d_model // num_heads,
            max_period,
            rope_scaling=rope_scaling,
        )

        self.x_token_dropout = nn.Dropout(token_dropout_p)
        self.y_token_dropout = nn.Dropout(token_dropout_p)

        if self.enable_rar:
            self.tape = TargetAwarePE(d_model, max_len=self.seq_len, base=max_period)
            self.use_tape_inference = False
            self.rar_schedule = AnnealSchedule()

        # type embedding
        if num_types > 0:
            self.type_emb = TypeEmbedding(num_types=num_types, embed_dim=d_model, dropout=type_drop_p)
        else:
            self.type_emb = None

        # output layer
        if norm_type == "rms":
            self.output_norm = nn.RMSNorm(d_model, eps=1e-5)
        elif norm_type == "dyt":
            self.output_norm = DyT(d_model, init_alpha=0.5, eps=1e-5)
        else:
            # Fail here rather than with a missing `output_norm` at the first forward pass.
            raise ValueError(f"Invalid norm type: {norm_type}")

        # transformer blocks, with a linearly increasing drop-path rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]

        self.layers = nn.ModuleList()
        for layer_id in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model,
                    num_heads=num_heads,
                    layer_scale=layer_scale,
                    drop_path=dpr[layer_id],
                    ffn_dim_multiplier=ffn_dim_multiplier,
                    ffn_dropout_p=ffn_dropout_p,
                    multiple_of=multiple_of,
                    num_kv_heads=num_kv_heads,
                    attn_dropout_p=attn_dropout_p,
                    resid_dropout_p=resid_dropout_p,
                    norm_type=norm_type,
                    **kwargs,
                )
            )
        # KVCache bookkeeping; -1 means "no cache allocated yet"
        self.max_batch_size = -1
        self.max_seq_length = -1

        # weight init
        self.initializer_range = initializer_range
        self.initialize_weights()
    
    def set_context_extension(self, mode: str, *, factor: float = None, window_size: int = None):
        """Select how contexts longer than the training length are handled.

        Only the configuration attributes are updated here; nothing is recomputed.

        Args:
            mode: ``"pi"`` or ``"ntk"`` (RoPE scaling, needs ``factor >= 1``),
                ``"sliding"`` (windowed cache, needs ``window_size > 0``) or ``"none"``.
            factor: Scaling factor for ``"pi"`` / ``"ntk"``.
            window_size: Window length for ``"sliding"``.
        """
        assert mode in {"pi", "ntk", "sliding", "none"}
        self.context_extension_mode = mode

        if mode == "sliding":
            assert window_size is not None and window_size > 0
            self.rope_scaling = None
            self.sliding_window = int(window_size)
            return

        if mode == "none":
            scaling = None
        else:
            assert factor is not None and factor >= 1.0
            # "pi" is expressed as linear RoPE scaling.
            scaling = {"type": "linear" if mode == "pi" else "ntk", "factor": factor}
        self.rope_scaling = scaling
        self.sliding_window = None

    def initialize_weights(self):
        """Initialise all weights, then re-draw the residual output projections.

        After the generic per-module init, ``feed_forward.w2`` and ``attention.wo``
        of every layer are re-initialised with a standard deviation of
        ``initializer_range / sqrt(2 * num_layers)``.
        """
        # Generic init for nn.Linear, nn.Embedding and the norm layers.
        self.apply(self._init_weights)

        residual_std = self.initializer_range / math.sqrt(2 * self.num_layers)
        # (owner attribute on the layer, projection attribute on the owner); the
        # order fixes the order in which random numbers are drawn.
        residual_projections = (("feed_forward", "w2"), ("attention", "wo"))
        for block in self.layers:
            for owner_name, proj_name in residual_projections:
                if not hasattr(block, owner_name):
                    continue
                owner = getattr(block, owner_name)
                if not hasattr(owner, proj_name):
                    continue
                projection = getattr(owner, proj_name)
                nn.init.normal_(projection.weight, mean=0.0, std=residual_std)
                if projection.bias is not None:
                    nn.init.zeros_(projection.bias)

    def _init_weights(self, module):
        """Per-module initialiser intended for ``nn.Module.apply``.

        Linear and embedding weights are drawn from ``N(0, initializer_range)`` and
        linear biases are zeroed; norm layers get ``weight = 1`` and ``bias = 0``
        where those attributes exist. Every other module is left untouched.
        """
        std = self.initializer_range

        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if not isinstance(module, (nn.LayerNorm, nn.RMSNorm, RMSNorm, DyT)):
            return

        if hasattr(module, 'weight'):
            nn.init.ones_(module.weight)
        norm_bias = getattr(module, 'bias', None)
        if norm_bias is not None:
            nn.init.zeros_(norm_bias)

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        """Allocate per-layer KV caches, the causal mask and the rotary table for decoding.

        Args:
            max_batch_size: Batch dimension of the caches and of the causal mask.
            max_seq_length: Requested cache length; rounded up to a multiple of 8.
            dtype: dtype of the cache buffers.

        A ``RingKVCache`` of ``self.sliding_window`` slots is used when a sliding
        window is configured, a dense ``KVCache`` otherwise.
        """
        head_dim = self.d_model // self.num_heads
        max_seq_length = find_multiple(max_seq_length, 8) # max_seq_length should be multiple of 8
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

        for block in self.layers:
            attention = block.attention
            # Keys/values carry `num_kv_heads` heads (grouped-query attention), which
            # can be fewer than `num_heads`; the cache must match what is written
            # into it. Fall back to `num_heads` for layers without that attribute.
            n_kv_heads = getattr(attention, "num_kv_heads", self.num_heads)
            if self.sliding_window is not None:
                attention.kv_cache = RingKVCache(
                    max_batch_size, self.sliding_window, n_kv_heads, head_dim, dtype
                )
                attention.sliding_window = self.sliding_window
            else:
                attention.kv_cache = KVCache(
                    max_batch_size, max_seq_length, n_kv_heads, head_dim, dtype
                )
                attention.sliding_window = None

        # standard causal mask, one copy per batch row
        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
        self.freqs_cis = precompute_freqs_cis(
            self.max_seq_length,
            self.d_model // self.num_heads,
            self.max_period,
            # Tolerate instances on which `rope_scaling` was never assigned.
            rope_scaling=getattr(self, "rope_scaling", None),
            train_seq_len=self.seq_len,
        )

    
    def forward(
        self, 
        video: tp.Optional[torch.Tensor] = None, # video or uncond token
        audio: tp.Optional[torch.Tensor] = None, # audio token
        input_pos: tp.Optional[torch.Tensor] = None, 
        targets: tp.Optional[torch.Tensor] = None,
        mask: tp.Optional[torch.Tensor] = None,
        valid: tp.Optional[torch.Tensor] = None,  
        global_step: int = None,
        total_steps: int = None, 
        perm_ratio: tp.Optional[float] = None,
        *args, 
        **kwargs
    ):
        """Full-sequence pass over the video conditioning and the audio tokens.

        Args:
            video: Conditioning embeddings, treated as (b, t, c). Must be ``None``
                when ``audio_pretraining`` is set. Samples whose absolute values sum
                to less than 1e-6 are replaced by the null embedding.
                NOTE: this tensor is modified in place.
            audio: Audio token embeddings, unpacked as (B, T, C).
            input_pos: Forwarded to every layer; outside training it also indexes
                the rotary table.
            targets: Not used by this method.
            mask: Attention mask forwarded to every layer.
            valid: Not used by this method.
            global_step, total_steps, perm_ratio: Passed to ``rar_schedule`` when
                ``enable_rar`` is set and the module is training.

        Returns:
            The output-normalised hidden states. In interleave mode without
            ``condition_merge`` only every second position (0, 2, 4, ...) is kept.
        """
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if self.condition_merge:
            # Shift the audio right by one step: prepend the learned init token and
            # drop the last token, so the length is unchanged.
            B, T, C = audio.shape
            audio = torch.concat([self.init_audio_token.expand(B, 1, C), audio], dim=1)[:, :-1] # add init token to audio and drop last token
        
        
        if self.audio_pretraining:
            # Audio-only mode: the video stream is filled with the null embedding.
            assert video is None and audio is not None
            B, T, C = audio.shape
            if self.type_emb is not None:
                type_ids_y = int(1) # audio index
                audio = self.type_emb(audio, type_ids_y)
            audio = self.y_token_dropout(audio)
            video = self.null_embedding.expand(B, self.num_aggregated_tokens * T, C) # null embedding for video branch

        else:
            # Per-sample flag: True where the whole video tensor is (numerically) zero.
            dummy_mask = (video.abs().sum(dim=(-2, -1)) < 1e-6)   # video (b, t, c)
            real_mask  = ~dummy_mask
            # print("real_mask", real_mask)
            if real_mask.any():
                # Type embedding and token dropout are applied to real samples only.
                v_real = video[real_mask]
                if self.type_emb is not None:
                    type_id_video = 0                         
                    v_real = self.type_emb(v_real, type_id_video)
                v_real = self.x_token_dropout(v_real)   
                if v_real.dtype != video.dtype:
                    v_real = v_real.to(dtype=video.dtype)  
                # NOTE(review): in-place write into the caller's `video` tensor.
                video[real_mask] = v_real      
            # NOTE(review): `null_embedding` is only created in __init__ when
            # trans_cfg_dropout_prob > 0 or audio_pretraining is set, but it is read
            # here unconditionally - confirm the configuration always provides it.
            null_tokens = self.null_embedding.expand(1, video.shape[1], video.shape[2]).to(video.dtype)  # (1,T,C)
            video[dummy_mask] = null_tokens           
            if self.trans_cfg_dropout_prob > 0.0:
                # Conditioning dropout: with probability trans_cfg_dropout_prob a real
                # sample's whole video stream is swapped for the null tokens. The draw
                # is not gated on self.training.
                drop_mask = (torch.rand(video.size(0), 1, 1, device=video.device)
                            < self.trans_cfg_dropout_prob)          
                drop_mask &= real_mask.view(-1, 1, 1)
                video = torch.where(drop_mask, null_tokens.expand_as(video), video)
            
            type_id_audio = 1
            if self.type_emb is not None:
                audio = self.type_emb(audio, type_id_audio)
            audio = self.y_token_dropout(audio)
        
        if self.noise_augmentation:
            audio = noise_augment(audio, self.k_max)

        if self.input_type == "interleave":
            # `r` is returned by interleave_tokens and asserted to equal 1.0 below.
            if self.condition_merge:
                hidden_state, r = interleave_tokens(audio, video) # audio first
            else:
                hidden_state, r = interleave_tokens(video, audio) # video first
            # TODO: make applicable when r is not 1.
            # if r > 1:
            #     mask = 
        elif self.input_type == "concat": # early fusion.
            # Channel-wise concatenation projected back to d_model.
            assert video.shape[1] == audio.shape[1]
            hidden_state = self.concat_proj(torch.cat((video, audio), dim=-1))
            r = 1.0
        else:
            raise ValueError(f"Invalid input type: {self.input_type}")
        
        
        # apply rar
        hidden_B, hidden_L, _ = hidden_state.shape
        hidden_device = hidden_state.device
        pair_perm = perm_tokens = inv_perm = None
        
        if self.enable_rar:
            # alpha stays 0.0 outside training; while training it comes from the schedule.
            alpha = 0.0
            if self.training:
                alpha = self.rar_schedule(global_step=global_step, total_steps=total_steps, perm_ratio=perm_ratio)
            pair_perm, perm_tokens, inv_perm = build_pair_perm(hidden_B, hidden_L, alpha, hidden_device, anchor0=True)
            hidden_state = apply_perm(hidden_state, perm_tokens)

            if self.training or self.use_tape_inference:
                # Add the target-aware positional encoding, weighted by alpha.
                delta_next = compute_tape_delta_next(
                    L=hidden_L, B=hidden_B, device=hidden_device, perm_tokens=perm_tokens,
                )
                tape = self.tape(delta_next)
                hidden_state = hidden_state + alpha * tape
        
        # apply rope
        # freqs_cis is a plain tensor attribute, so it is moved to the activation
        # device explicitly.
        self.freqs_cis = self.freqs_cis.to(hidden_device)
        if self.training:
            freqs_cis = self.freqs_cis[:hidden_state.size(1)]
        else:
            freqs_cis = self.freqs_cis[input_pos]
        
        # transformer layers
        for layer in self.layers:
            hidden_state = layer(hidden_state, freqs_cis, input_pos, mask)
        
        # output normalization
        hidden_state = self.output_norm(hidden_state)

        # The inverse permutation is applied in training only.
        if self.enable_rar and self.training and inv_perm is not None:
            hidden_state = apply_perm(hidden_state, inv_perm)
        
        if self.input_type == "interleave":
            assert r == 1.0
            if not self.condition_merge:
                # Keep every second position, starting at index 0.
                hidden_state = hidden_state[:, ::2, :]

        return hidden_state

    def inference(
        self, 
        input_token: torch.Tensor,
        type_id: int,
        input_pos: tp.Optional[torch.Tensor] = None, 
        mask: tp.Optional[torch.Tensor] = None,
        trans_cfg_scale: float = 1.0, 
        audio_pretraining: bool = False,
        *args, 
        **kwargs
    ):
        """Run one token per sample through the stack, using the KV caches.

        Args:
            input_token: Token embedding per sample; reshaped to (batch, 1, -1).
            type_id: ``0`` for a video/conditioning token, ``1`` for an audio token.
            input_pos: Cache position(s) of the token; indexes the causal mask and
                the rotary table and is forwarded to every layer.
            mask: Ignored. The attention mask is always taken from the causal mask
                built by ``setup_caches``.
            trans_cfg_scale: When ``> 1`` the batch is doubled for classifier-free
                guidance: a video token is paired with the null embedding, an audio
                token is duplicated.
            audio_pretraining: When set, a video token is replaced by the null
                embedding and the batch is never doubled.

        Returns:
            The output-normalised hidden state (batch doubled under guidance).

        Raises:
            ValueError: If ``type_id`` is neither 0 nor 1.
        """
        if type_id not in (0, 1):
            # Previously this fell through both branches and failed later on an
            # unbound `hidden_state`.
            raise ValueError(f"Invalid type id: {type_id}")

        bs = input_token.size(0)
        input_token = input_token.view(bs, 1, -1)
        assert self.freqs_cis is not None, "Caches must be initialized first"

        # `type_emb` is None when the model was built with num_types <= 0.
        if self.type_emb is None:
            input_emb = input_token
        else:
            input_emb = self.type_emb(input_token, type_id)

        use_cfg = trans_cfg_scale > 1.0
        if type_id == 0:
            B, T, C = input_emb.shape
            if audio_pretraining:
                hidden_state = self.null_embedding.expand(B, T, C)
            else:
                hidden_state = self.x_token_dropout(input_emb)
                if use_cfg:
                    # Conditional rows first, unconditional (null) rows second.
                    cond_null = self.null_embedding.expand(B, T, C)
                    hidden_state = torch.cat([hidden_state, cond_null], dim=0)
        else:
            hidden_state = self.y_token_dropout(input_emb)
            if use_cfg and not audio_pretraining:
                hidden_state = torch.cat([hidden_state, hidden_state], dim=0)

        # Like freqs_cis, the causal mask is a plain tensor attribute and does not
        # follow `module.to(device)`, so align it with the activations here.
        self.causal_mask = self.causal_mask.to(hidden_state.device)
        mask = self.causal_mask[:hidden_state.size(0), None, input_pos]

        self.freqs_cis = self.freqs_cis.to(hidden_state.device)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            hidden_state = layer(hidden_state, freqs_cis, input_pos, mask)

        hidden_state = self.output_norm(hidden_state)

        return hidden_state


class MultiheadAttention(nn.Module):
    """Multi-head attention (self or cross) for the aggregate transformer.

    In self-attention mode one fused projection yields queries, keys and
    values from the input. In cross-attention mode queries are projected from
    `query` and keys/values from `context`. Attention is computed with
    `F.scaled_dot_product_attention`; no mask is applied.

    Args:
        embed_dim (int): Dimension to project to; must be divisible by
            `num_heads`.
        num_heads (int): Number of heads.
        num_kv_heads (int, optional): Accepted but unused by this class.
        attn_type (str): Stored on the module; not consulted by this class.
        cross_attention (bool): Build separate query / key-value projections
            and require a `context` tensor in `forward`.
        attn_dropout_p (float): Attention dropout, active in training mode only.
        resid_dropout_p (float): Dropout applied after the output projection.
    """

    _fsdp_final = True

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: tp.Optional[int] = None,
        attn_type: str = "sdpa",
        cross_attention: bool = False,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0

        # Plain configuration.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attn_type = attn_type
        self.cross_attention = cross_attention
        self.attn_dropout_p = attn_dropout_p

        # Sub-modules; the creation order is kept stable so parameter
        # registration order does not change.
        self.resid_dropout = nn.Dropout(resid_dropout_p)
        if cross_attention:
            self.wq = nn.Linear(embed_dim, embed_dim, bias=False)
            self.wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        else:
            # Fused query/key/value input projection.
            self.wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        # Output projection.
        self.wo = nn.Linear(embed_dim, embed_dim, bias=False)

    def _to_heads(self, proj: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """Reshape [batch, length, embed_dim] to [batch, num_heads, length, head_dim]."""
        return proj.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        context: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from `query` to itself, or to `context` in cross-attention mode.

        Args:
            query: Tensor of shape [batch, seqlen_q, embed_dim].
            context: Tensor of shape [batch, seqlen_k, embed_dim]. Required in
                cross-attention mode, ignored otherwise.

        Returns:
            Tensor of shape [batch, seqlen_q, embed_dim].
        """
        bsz, seqlen_q, _ = query.shape

        if self.cross_attention:
            assert context is not None, "Context must be provided for cross-attention"
            seqlen_k = context.shape[1]
            q_proj = self.wq(query)
            k_proj, v_proj = self.wkv(context).split(self.embed_dim, dim=-1)
        else:
            seqlen_k = seqlen_q
            q_proj, k_proj, v_proj = self.wqkv(query).split(self.embed_dim, dim=-1)

        heads_q = self._to_heads(q_proj, bsz, seqlen_q)
        heads_k = self._to_heads(k_proj, bsz, seqlen_k)
        heads_v = self._to_heads(v_proj, bsz, seqlen_k)

        # Attention dropout is disabled outside of training.
        drop_p = self.attn_dropout_p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            heads_q,
            heads_k,
            heads_v,
            dropout_p=drop_p,
            is_causal=False,
        )

        # Merge heads back: [bsz, heads, seqlen_q, head_dim] -> [bsz, seqlen_q, embed_dim].
        merged = attended.transpose(1, 2).reshape(bsz, seqlen_q, self.embed_dim)
        return self.resid_dropout(self.wo(merged))


class AggregateTransformerLayer(nn.Module):
    """Pre-norm transformer block: attention, then feed-forward, each residual.

    Both sub-layers are wrapped as
    `x + drop_path(layer_scale(sublayer(norm(x))))`.

    Args:
        d_model (int): Model width.
        num_heads (int): Number of attention heads.
        layer_scale (float, optional): If given, both residual branches are
            wrapped in `LayerScale(d_model, layer_scale)`; otherwise identity.
        drop_path (float): `DropPath` is used when > 0, otherwise identity.
        ffn_dim_multiplier (float, optional): Forwarded to `FeedForward`.
        ffn_dropout_p (float): Forwarded to `FeedForward`.
        multiple_of (int): Forwarded to `FeedForward`.
        num_kv_heads (int, optional): Forwarded to `MultiheadAttention`.
        attn_dropout_p (float): Forwarded to `MultiheadAttention`.
        resid_dropout_p (float): Forwarded to `MultiheadAttention`.
        attn_type (str): Forwarded to `MultiheadAttention`.
        use_cross_attention (bool): Build the attention module in
            cross-attention mode; `forward` then needs `y`.
        norm_type_agg (str): "rms" (`nn.RMSNorm`) or "dyt" (`DyT`).

    Raises:
        ValueError: If `norm_type_agg` is not "rms" or "dyt".
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        layer_scale: tp.Optional[float] = None,
        drop_path=0.0,
        ffn_dim_multiplier: tp.Optional[float] = None,
        ffn_dropout_p: float = 0.1,
        multiple_of: int = 256,
        num_kv_heads: tp.Optional[int] = None,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
        attn_type: str = "sdpa",
        use_cross_attention: bool = False,
        norm_type_agg: str = "rms",
    ):
        super().__init__()
        # Fail early: an unknown norm type used to leave `attention_norm` and
        # `ffn_norm` undefined, surfacing only as an AttributeError in forward.
        if norm_type_agg not in ("rms", "dyt"):
            raise ValueError(
                f"Unsupported norm_type_agg {norm_type_agg!r}; expected 'rms' or 'dyt'"
            )
        self.use_cross_attention = use_cross_attention

        self.attention = MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            attn_type=attn_type,
            cross_attention=use_cross_attention,
            attn_dropout_p=attn_dropout_p,
            resid_dropout_p=resid_dropout_p,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.feed_forward = FeedForward(d_model, ffn_dim_multiplier, ffn_dropout_p, multiple_of)
        if norm_type_agg == "rms":
            self.attention_norm = nn.RMSNorm(d_model, eps=1e-5)
            self.ffn_norm = nn.RMSNorm(d_model, eps=1e-5)
        else:  # "dyt"
            self.attention_norm = DyT(d_model, init_alpha=0.5, eps=1e-5)
            self.ffn_norm = DyT(d_model, init_alpha=0.5, eps=1e-5)
        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module

        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale)
            self.layer_scale_2 = LayerScale(d_model, layer_scale)

    def forward(
        self, x: torch.Tensor,
        y: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to `x`.

        Args:
            x: Residual stream; it is the attention query.
            y: Context for cross-attention. Only used when the layer was built
                with `use_cross_attention=True`; ignored otherwise.
                NOTE(review): `y` is passed to attention un-normalized --
                confirm whether a context norm is wanted.

        Returns:
            Tensor with the same shape as `x`.
        """
        normed = self.attention_norm(x)
        if self.use_cross_attention:
            # `y` used to be dropped here, so the cross-attention module always
            # failed its "context must be provided" assertion.
            attn_out = self.attention(normed, context=y)
        else:
            attn_out = self.attention(normed)
        h = x + self.drop_path(self.layer_scale_1(attn_out))
        out = h + self.drop_path(self.layer_scale_2(self.feed_forward(self.ffn_norm(h))))
        return out

class AggregateTransformer(nn.Module):
    """Transformer for spatial aggregation on video frames.

    Learned positional embeddings are added to the grid features `x` and to
    the aggregated tokens `y`, which are then processed by a stack of layers
    in one of two modes (`aggregation_method`):

    * "self": `x` and `y` are concatenated along dim 1, run through every
      layer and the output norm, and only the last `num_aggregated_tokens`
      positions are returned.
    * "cross": `x` is run through every layer with `y` as the second (context)
      argument, followed by the output norm.

    Args:
        d_model (int): Dimension of the data; must be divisible by `num_heads`.
        num_heads (int): Number of heads.
        num_layers (int): Number of layers in the stack.
        grid_feature_length (int): Number of rows of the positional embedding
            added to `x`.
        num_aggregated_tokens (int): Number of rows of the positional
            embedding added to `y`, and number of trailing positions returned
            in "self" mode.
        layer_scale (float, optional): Forwarded to each layer.
        ffn_dim_multiplier (float, optional): Forwarded to each layer.
        ffn_dropout_p (float): Forwarded to each layer.
        attn_dropout_p (float): Forwarded to each layer.
        resid_dropout_p (float): Forwarded to each layer.
        multiple_of (int): Forwarded to each layer.
        num_kv_heads (int, optional): Forwarded to each layer.
        initializer_range (float): Std of the normal init applied to
            `nn.Linear` and `nn.Embedding` weights.
        layer_class: Class used to build each layer.
        aggregate_trans_architecture (str): Forwarded to each layer as
            `attn_type`.
        aggregation_method (str): "self" or "cross" (see above).
        norm_type_agg (str): "rms" (`nn.RMSNorm`) or "dyt" (`DyT`); also
            forwarded to each layer.
        **kwargs: Forwarded to `layer_class`.

    Raises:
        ValueError: If `norm_type_agg` is not "rms" or "dyt".
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        grid_feature_length: int,
        num_aggregated_tokens: int = 1,
        layer_scale: tp.Optional[float] = None,
        ffn_dim_multiplier: tp.Optional[float] = None,
        ffn_dropout_p: float = 0.1,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
        multiple_of: int = 256,
        num_kv_heads: tp.Optional[int] = None,
        initializer_range: float = 0.02,
        layer_class: tp.Type[AggregateTransformerLayer] = AggregateTransformerLayer,
        aggregate_trans_architecture: str = "linear",
        aggregation_method: str = "self",
        norm_type_agg: str = "rms",
        **kwargs,
    ):
        super().__init__()
        assert d_model % num_heads == 0
        # Fail early: an unknown norm type used to leave `output_norm`
        # undefined, surfacing only as an AttributeError in forward.
        if norm_type_agg not in ("rms", "dyt"):
            raise ValueError(
                f"Unsupported norm_type_agg {norm_type_agg!r}; expected 'rms' or 'dyt'"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_aggregated_tokens = num_aggregated_tokens
        self.aggregate_trans_architecture = aggregate_trans_architecture
        self.aggregation_method = aggregation_method

        # Learned positional embeddings, drawn from N(0, 1 / d_model).
        scale = self.d_model ** -0.5
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(grid_feature_length, self.d_model)
        )
        self.aggregated_token_positional_embedding = nn.Parameter(
            scale * torch.randn(num_aggregated_tokens, self.d_model)
        )

        # Output norm.
        if norm_type_agg == "rms":
            self.output_norm = nn.RMSNorm(d_model, eps=1e-5)
        else:  # "dyt"
            self.output_norm = DyT(d_model, init_alpha=0.5, eps=1e-5)

        self.layers = nn.ModuleList()
        use_cross = (aggregation_method == "cross")
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model,
                    num_heads=num_heads,
                    layer_scale=layer_scale,
                    ffn_dim_multiplier=ffn_dim_multiplier,
                    ffn_dropout_p=ffn_dropout_p,
                    multiple_of=multiple_of,
                    num_kv_heads=num_kv_heads,
                    attn_type=aggregate_trans_architecture,
                    use_cross_attention=use_cross,
                    norm_type_agg=norm_type_agg,
                    attn_dropout_p=attn_dropout_p,
                    resid_dropout_p=resid_dropout_p,
                    **kwargs,
                )
            )
        # Weight init.
        self.initializer_range = initializer_range
        self.initialize_weights()

    def initialize_weights(self):
        """Apply `_init_weights` to this module and every sub-module."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal-init `nn.Linear` / `nn.Embedding` weights and zero Linear biases.

        Other module types are left untouched.
        """
        std = self.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(
        self,
        x: tp.Optional[torch.Tensor] = None,
        y: tp.Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ):
        """Aggregate the grid features `x` using the aggregated tokens `y`.

        Args:
            x: Grid features. Both `x` and `y` are required despite the
                `None` defaults. `x` must broadcast against the
                [grid_feature_length, d_model] positional embedding, with the
                sequence on dim 1.
            y: Aggregated tokens; must broadcast against the
                [num_aggregated_tokens, d_model] positional embedding.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            "self" mode: the last `num_aggregated_tokens` positions of the
            normalized output. "cross" mode: the normalized output of the
            layer stack applied to `x`.
        """
        # Add positional embeddings.
        x = x + self.positional_embedding
        y = y + self.aggregated_token_positional_embedding
        if self.aggregation_method == "self":
            hidden_state = torch.cat((x, y), dim=1)
            for layer in self.layers:
                hidden_state = layer(hidden_state)
            hidden_state = self.output_norm(hidden_state)
            # The aggregated tokens were appended last, so they sit at the tail.
            hidden_state = hidden_state[:, -self.num_aggregated_tokens:]

            return hidden_state
        elif self.aggregation_method == "cross":
            # Chain the layers. Previously every iteration was fed the original
            # `x`, so only the last layer contributed to the result.
            # NOTE(review): `x` is the running (query) stream here, so the
            # output keeps x's sequence length, unlike "self" mode, which
            # returns `num_aggregated_tokens` positions -- confirm this is the
            # intended direction.
            hidden_state = x
            for layer in self.layers:
                hidden_state = layer(hidden_state, y)
            hidden_state = self.output_norm(hidden_state)

            return hidden_state
        