from torch import nn
from einops import rearrange, repeat
import torch
import math
from xformers.ops.fmha import memory_efficient_attention
from .moe import MoeLayer
from .positional_embeddings import sinusoidal_positional_embedding1d, sinusoidal_positional_embedding2d, precompute_freqs_cis, apply_rotary_emb

def modulate(x, shift, scale):
    """Apply an affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` only need to support the arithmetic operators
    against ``x`` (for tensors: be broadcastable to its shape).
    """
    scaled = x * (1 + scale)
    return scaled + shift

def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    """Repeat every slice of ``keys`` and ``values`` ``repeats`` times along ``dim``.

    Both tensors go through ``torch.repeat_interleave`` with the same
    arguments, so the copies of one slice stay adjacent.

    Returns:
        The ``(keys, values)`` pair after repetition.
    """
    expanded_keys, expanded_values = (
        torch.repeat_interleave(tensor, repeats=repeats, dim=dim)
        for tensor in (keys, values)
    )
    return expanded_keys, expanded_values

def exists(x):
    """Return ``False`` only for ``None``; falsy values such as ``0`` still exist."""
    if x is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    A callable ``d`` is invoked (lazily, only when needed) and its result is
    returned; any other ``d`` is returned as-is.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def pad_token2maxlen(token, max_len, dim):
    """Zero-pad ``token`` at the end of axis ``dim`` up to length ``max_len``.

    The input is returned unchanged (same object) when it is already at least
    ``max_len`` long along ``dim``; it is never truncated. The padding shares
    the dtype and device of ``token``.
    """
    missing = max_len - token.size(dim)
    if missing <= 0:
        return token

    filler_shape = list(token.shape)
    filler_shape[dim] = missing
    filler = token.new_zeros(filler_shape)
    return torch.cat((token, filler), dim=dim)

def tensor_or_list_op(x, op, *args, is_single=False, **kwargs):
    """Apply ``op`` to ``x``, transparently handling grouped (list) inputs.

    Dispatch rules:

    * ``x`` is not a ``list``: return ``op(x, *args, **kwargs)``.
    * ``x`` is a list and ``is_single`` is true: apply ``op`` to every item and
      return the list of results.
    * ``x`` is a list of lists/tuples: transpose it and apply ``op`` to each
      group, i.e. ``[(a1, b1), (a2, b2)] -> [op((a1, a2)), op((b1, b2))]``.
    * ``x`` is a list of anything else (e.g. plain tensors): apply ``op`` to the
      list as a whole. Previously such a list was also pushed through
      ``zip(*x)``, which iterated over the rows of each tensor instead of
      combining the tensors themselves.

    An empty list yields an empty list.

    Args:
        x: A single object or a list as described above.
        op: Callable invoked as ``op(value, *args, **kwargs)``.
        is_single: Keyword-only; map ``op`` over the list items instead of
            transposing.
    """
    if not isinstance(x, list):
        return op(x, *args, **kwargs)

    if is_single:
        return [op(item, *args, **kwargs) for item in x]

    # Only transpose when the items really are groups; a flat list of tensors
    # is itself the argument ``op`` should receive.
    if x and not isinstance(x[0], (list, tuple)):
        return op(x, *args, **kwargs)

    return [op(group, *args, **kwargs) for group in zip(*x)]
    

class Attention(nn.Module):
    """Multi-head attention computed with xformers' ``memory_efficient_attention``.

    Without ``cross_attn_dim`` the layer is self-attention with full-width
    key/value projections. With ``cross_attn_dim`` the keys/values are
    projected from ``content`` into ``n_kv_heads * head_dim`` channels and each
    key/value head is repeated ``hidden_dim // kv_proj_dim`` times so that it
    lines up with the query heads.

    Args:
        n_heads: Number of query heads; ``head_dim = hidden_dim // n_heads``.
        hidden_dim: Channel width of the query input and of the output.
        n_kv_heads: Number of key/value heads (only affects cross-attention).
        cross_attn_dim: Channel width of ``content``; ``None`` selects
            self-attention.
        positional_embeddings: Stored on the module; not read by this class.
    """

    def __init__(self, n_heads, hidden_dim, n_kv_heads, cross_attn_dim=None, positional_embeddings=None):
        super().__init__()

        self.n_heads = n_heads
        self.hidden_dim = hidden_dim

        head_dim = hidden_dim // n_heads
        self.head_dim = head_dim

        self.n_kv_heads: int = n_kv_heads
        self.scale = self.head_dim**-0.5

        kv_dim = default(cross_attn_dim, hidden_dim)
        # Grouped key/value heads are only used on the cross-attention path.
        kv_proj_dim = n_kv_heads * head_dim if exists(cross_attn_dim) else hidden_dim
        self.repeats = hidden_dim // kv_proj_dim

        self.wq = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.wk = nn.Linear(kv_dim, kv_proj_dim, bias=False)
        self.wv = nn.Linear(kv_dim, kv_proj_dim, bias=False)

        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.positional_embeddings = positional_embeddings

    def forward(
        self,
        x: torch.Tensor,
        content: torch.Tensor = None,
        attn_mask: torch.Tensor = None,
        content_mask: torch.Tensor = None,
        freqs_cis: torch.Tensor = None
    ) -> torch.Tensor:
        """Attend from ``x`` to ``content`` (or to ``x`` itself).

        Args:
            x: Query input of shape ``(b, n, hidden_dim)``.
            content: Key/value source; defaults to ``x``.
            attn_mask: Optional 3-D ``(b, i, j)`` mask, 1 = attend, 0 = blocked.
                Ignored when ``content_mask`` is given.
            content_mask: Optional 2-D ``(b, l)`` mask over ``content`` tokens,
                1 = valid. Takes precedence over ``attn_mask``.
            freqs_cis: Rotary frequencies, applied only when ``content_mask``
                is absent. Either a single tensor / one-element sequence
                (1-D RoPE over the full head dim) or a pair
                ``(freqs_h, freqs_w)`` (2-D RoPE, one per half of the head dim).

        Returns:
            Tensor of shape ``(b, n, hidden_dim)``.
        """
        b, n, d = x.shape

        content = default(content, x)
        xq, xk, xv = self.wq(x), self.wk(content), self.wv(content)

        xq = rearrange(xq, "b n (h d) -> b n h d", d=self.head_dim)
        xk = rearrange(xk, "b n (h d) -> b n h d", d=self.head_dim)
        xv = rearrange(xv, "b n (h d) -> b n h d", d=self.head_dim)

        if not exists(content_mask) and exists(freqs_cis):
            # A bare tensor must be recognised before using len(): its length
            # is the batch size, so it used to be skipped (or unpacked as an
            # (h, w) pair when the batch happened to be 2).
            if torch.is_tensor(freqs_cis) or len(freqs_cis) == 1:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

            elif len(freqs_cis) == 2:
                freqs_cis_h, freqs_cis_w = freqs_cis
                dim = self.head_dim // 2
                # First half of each head rotates with the row frequencies,
                # second half with the column frequencies.
                xhq, xhk = apply_rotary_emb(xq[..., :dim], xk[..., :dim], freqs_cis=freqs_cis_h)
                xwq, xwk = apply_rotary_emb(xq[..., dim:], xk[..., dim:], freqs_cis=freqs_cis_w)
                xq = torch.cat([xhq, xwq], dim=-1)
                xk = torch.cat([xhk, xwk], dim=-1)

        # Tensors are laid out (b, n, h, d): the head axis is dim 2. Repeating
        # along dim 1 would duplicate sequence positions instead of heads.
        key, val = repeat_kv(xk, xv, self.repeats, dim=2)

        mask = None

        # Masks become additive biases: 0 where allowed, -10000 where blocked.
        if exists(content_mask):
            assert len(content_mask.shape) == 2
            mask = (1 - content_mask.to(x.dtype)) * -10000.0
            mask = repeat(mask, "b l -> b h n l", h=self.n_heads, n = n).to(xq.dtype)
        elif exists(attn_mask):
            assert len(attn_mask.shape) == 3
            mask = (1 - attn_mask.to(x.dtype)) * -10000.0
            mask = repeat(mask, "b i j -> b h i j", h=self.n_heads).to(xq.dtype)

        output = memory_efficient_attention(
            xq, key, val, mask
        )

        # Merge the head axis back into the channel axis before the output projection.
        return self.wo(output.view(-1, n, self.hidden_dim))



class FeedForward(nn.Module):
    """Gated MLP: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections.

    Args:
        hidden_dim: Input and output channel width.
        mlp_ratio: Expansion factor; the inner width is
            ``int(hidden_dim * mlp_ratio)``.
    """

    def __init__(self, hidden_dim, mlp_ratio):
        super().__init__()
        inner_dim = int(hidden_dim * mlp_ratio)
        # Creation order is kept (w1, w2, w3) so seeded initialisation and
        # parameter ordering stay the same.
        self.w1 = nn.Linear(hidden_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, inner_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        gate = nn.functional.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

# https://arxiv.org/pdf/1910.07467.pdf
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Args:
        dim: Size of the normalised (last) axis.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x, *args, **kwargs):
        # Extra arguments are accepted and ignored.
        # The statistics are computed in float32, then cast back to x's dtype
        # before the gain is applied.
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

class AdaRMSNorm(nn.Module):
    r"""
    RMS norm whose output is scaled and shifted by a conditioning embedding.

    The embedding goes through SiLU and a linear layer that doubles its width;
    the result is split in half along the last axis into ``scale`` and
    ``shift``.

    Parameters:
        embedding_dim (`int`): Channel width of both the input and the embedding.
        eps (`float`): Epsilon forwarded to the inner `RMSNorm`.
    """

    def __init__(self, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = RMSNorm(embedding_dim, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        modulation = self.linear(self.silu(emb))
        # First half is the scale, second half the shift.
        scale, shift = modulation.chunk(2, dim=-1)
        normalized = self.norm(x)
        return normalized * (1 + scale) + shift


class TransformerBlock(nn.Module):
    """Transformer block with adaLN-modulated self-attention and feed-forward.

    Layout: gated self-attention, then an optional ungated cross-attention
    (only built when ``cross_attn_dim`` is given), then a gated feed-forward
    (dense or mixture-of-experts).

    Args:
        n_heads: Number of attention heads.
        hidden_dim: Channel width of the block.
        n_kv_heads: Number of key/value heads passed to ``Attention``.
        norm_eps: Epsilon passed to every norm layer.
        cross_attn_dim: Channel width of the cross-attention content; ``None``
            disables cross-attention.
        mlp_ratio: Expansion factor of the feed-forward layers.
        positional_embeddings: Forwarded to ``Attention``.
        moe: Optional mapping with at least ``'num_experts'``; when given the
            feed-forward becomes a ``MoeLayer`` and the mapping is passed on as
            ``moe_args``.
        norm_layer: Norm class constructed as ``norm_layer(hidden_dim, eps=...)``
            and called as ``norm(x, adaln_input)``. ``RMSNorm`` ignores the
            second argument, ``AdaRMSNorm`` requires it.
    """

    def __init__(
        self,
        n_heads,
        hidden_dim,
        n_kv_heads,
        norm_eps,
        cross_attn_dim=None,
        mlp_ratio=None,
        positional_embeddings=None,
        moe=None,
        norm_layer=RMSNorm,
    ):
        super().__init__()
        self.attention1 = Attention(n_heads, hidden_dim, n_kv_heads, positional_embeddings=positional_embeddings)
        self.attention_norm1 = norm_layer(hidden_dim, eps=norm_eps)

        if exists(cross_attn_dim):
            self.attention2 = Attention(n_heads, hidden_dim, n_kv_heads, cross_attn_dim=cross_attn_dim, positional_embeddings=positional_embeddings)
            self.attention_norm2 = norm_layer(hidden_dim, eps=norm_eps)

        self.ffn_norm = norm_layer(hidden_dim, eps=norm_eps)

        # Produces six chunks: shift/scale/gate for attention and for the MLP.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 6 * hidden_dim, bias=True),
        )

        if exists(moe):
            self.feed_forward = MoeLayer(
                experts=[FeedForward(hidden_dim, mlp_ratio) for _ in range(moe['num_experts'])],
                gate=nn.Linear(hidden_dim, moe['num_experts'], bias=False),
                moe_args=moe,
            )
        else:
            self.feed_forward = FeedForward(hidden_dim, mlp_ratio)

    def forward(self, x, adaln_input, content = None, attn_mask = None, content_mask = None, freqs_cis = None):
        '''
        Run the block on ``x``.

        Args:
            x: Input tokens.
            adaln_input: Conditioning embedding; drives the adaLN modulation and
                is also handed to every norm layer.
            content: Cross-attention keys/values source (used only when the
                block has cross-attention).
            attn_mask: Mask for the self-attention.
            content_mask: Mask over ``content`` tokens for the cross-attention.
            freqs_cis: Rotary frequencies for the self-attention.

        NOTE: maybe we can reduce the ada norm computation as same as rmsnorm
        '''
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(adaln_input).chunk(6, dim=-1)

        # The conditioning embedding is passed to the norm so that adaptive
        # norms (AdaRMSNorm) work; previously they raised a TypeError for the
        # missing ``emb`` argument.
        attn_in = modulate(self.attention_norm1(x, adaln_input), shift_msa, scale_msa)
        x = x + gate_msa * self.attention1(
                attn_in,
                attn_mask = attn_mask,
                freqs_cis = freqs_cis
            )

        # attention2 only exists when cross_attn_dim was given at construction.
        if hasattr(self, "attention2"):
            norm_x2 = self.attention_norm2(x, adaln_input)
            x = self.attention2(norm_x2, content=content, content_mask=content_mask, ) + x

        ffn_in = modulate(self.ffn_norm(x, adaln_input), shift_mlp, scale_mlp)
        x = x + gate_mlp * self.feed_forward(ffn_in)
        return x

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is fed
    through a two-layer MLP (Linear -> SiLU -> Linear) of width ``hidden_size``.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) float32 Tensor of positional embeddings, laid out as
                 [cos | sin] with one trailing zero column when ``dim`` is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Return the ``(N, hidden_size)`` embedding of the timesteps ``t``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the MLP's parameter dtype rather than t's dtype: casting to an
        # integer t would truncate the sinusoids and make the Linear layer fail.
        t_freq = t_freq.to(self.mlp[0].weight.dtype)
        return self.mlp(t_freq)
    
# TODO: add 
class FinalLayer(nn.Module):
    """
    Output head: adaLN-modulated LayerNorm followed by a linear projection.

    The projection maps ``hidden_size`` channels to
    ``patch_size * patch_size * out_channels`` values per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        patch_values = patch_size * patch_size * out_channels
        # The LayerNorm has no affine parameters; shift and scale come from the
        # conditioning vector instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_values, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
    
# TODO: add 
class Transformer(nn.Module):
    def __init__(
        self,
        n_layers: int,
        patch_size: int,
        n_heads: int,
        n_kv_heads: int,
        hidden_dim: int,
        norm_eps: float,
        max_hw_len: int,
        max_video_frame: int,
        max_image_frame: int,
        cross_attn_dim: int = None,
        mlp_ratio: float = 3.5,
        latent_dim: int = 4,
        positional_embeddings = None,
        moe: dict = None,
        norm_layer: nn.Module = RMSNorm
    ):
        """Build the patch embedding, the block stack and the conditioning embedders.

        Args:
            n_layers: Number of transformer blocks. ``forward`` consumes them in
                (spatial, temporal) pairs.
            patch_size: Side length of the square latent patches.
            n_heads: Attention heads per block.
            n_kv_heads: Key/value heads per block.
            hidden_dim: Token channel width.
            norm_eps: Epsilon for the norm layers.
            max_hw_len: Number of patch tokens every frame is padded to.
            max_video_frame: Number of leading frames treated as video frames.
            max_image_frame: Stored on the module; not read in this class.
            cross_attn_dim: Channel width of the conditioning tokens. Also sizes
                ``cap_embedder``, so it must be set.
            mlp_ratio: Feed-forward expansion factor.
            latent_dim: Number of latent channels per pixel.
            positional_embeddings: One of the scheme names handled by
                ``build_positional_embedding``.
            moe: Optional mixture-of-experts settings forwarded to each block.
            norm_layer: Norm class forwarded to each block.
        """
        super().__init__()

        self.patch_size = patch_size
        patch_dim = (patch_size) ** 2 * latent_dim
        self.patch_embedding = nn.Linear(patch_dim, hidden_dim)
        self.final_layer = FinalLayer(hidden_dim, patch_size, latent_dim)
        # Even-indexed blocks get cross-attention; odd-indexed blocks do not.
        self.blocks = nn.ModuleList([
            TransformerBlock(
                n_heads,
                hidden_dim,
                n_kv_heads=n_kv_heads,
                norm_eps=norm_eps,
                cross_attn_dim=cross_attn_dim if i % 2 == 0 else None,
                mlp_ratio=mlp_ratio,
                positional_embeddings=positional_embeddings,
                moe=moe,
                norm_layer=norm_layer
            )
        for i in range(n_layers)] )
        # `norm` and `output` are not used by `forward` (which goes through
        # `final_layer`); they are kept so existing state dicts still load.
        self.norm = norm_layer(hidden_dim, eps=norm_eps)
        self.output = nn.Linear(hidden_dim, patch_dim)
        self.max_hw_len = max_hw_len
        self.max_video_frame = max_video_frame
        self.max_image_frame = max_image_frame
        # Zero-initialised instead of torch.empty so the parameter never holds
        # uninitialised memory (which can contain NaN/inf).
        self.pad_token = nn.Parameter(torch.zeros(hidden_dim))
        self.t_embedder = TimestepEmbedder(hidden_dim)
        self.cap_embedder = nn.Sequential(
            nn.LayerNorm(cross_attn_dim),
            nn.Linear(cross_attn_dim, hidden_dim, bias=True),
        )
        self.positional_embeddings = positional_embeddings
        self.hidden_dim = hidden_dim
        # For "rope2d" the pad token is half as wide as for the other schemes.
        pe_pad_dim = hidden_dim // 2 if self.positional_embeddings == "rope2d" else hidden_dim
        self.pe_pad_token = nn.Parameter(torch.zeros(pe_pad_dim))

        self.head_dim = hidden_dim // n_heads
        # Positional caches are built lazily on first use.
        self.pre_cache = False


    @property
    def device(self) -> torch.device:
        """Device of the first parameter returned by ``self.parameters()``."""
        first_param = next(self.parameters())
        return first_param.device
    
    def pre_cache_positional_embeddings(self):
        """Build the positional-embedding caches for the configured scheme.

        Sinusoidal schemes register ``pe_cache1d`` (and ``pe_cache2d`` for the
        2-D variant) as buffers; rope schemes store them as plain attributes.
        All caches are placed on the model's current device. Sets
        ``self.pre_cache`` so the work is done once.
        """
        device = self.device
        scheme = self.positional_embeddings

        if scheme.startswith("sinusoidal_positional_embedding"):
            # Move to the model device like the rope branch does; this method is
            # called lazily, after the model may already have left the CPU.
            pe_cache1d = sinusoidal_positional_embedding1d(self.hidden_dim).to(device)
            self.register_buffer('pe_cache1d', pe_cache1d)
            if scheme == "sinusoidal_positional_embedding2d":
                pe_cache2d = sinusoidal_positional_embedding2d(self.hidden_dim).to(device)
                self.register_buffer('pe_cache2d', pe_cache2d)
        elif scheme.startswith("rope"):
            # Kept as plain attributes rather than buffers, as in the original code.
            self.pe_cache1d = precompute_freqs_cis(self.head_dim).to(device)
            if scheme == "rope2d":
                # Each spatial axis rotates half of the head dimension.
                self.pe_cache2d = precompute_freqs_cis(self.head_dim // 2).to(device)

        self.pre_cache = True
    
    def unpatchify(self, x_embs, video_shape_list, images_shape_list):
        """
        Turn padded patch tokens back into latents.

        x_embs: (b, f, max_hw, (phz pwz c)); the first ``max_video_frame``
            frames are video frames, the rest are images.
        video_shape_list: one (f, h, w) per batch element.
        images_shape_list: per batch element, a list of (h, w), one per image.

        Returns (video_latent_list, images_latent_list): per batch element a
        (f, c, h, w) video latent and a list of (c, h, w) image latents.
        """
        p = self.patch_size
        split = self.max_video_frame
        video_embs, image_embs = x_embs[:, :split], x_embs[:, split:]

        video_latent_list = []
        for batch_idx, (f, h, w) in enumerate(video_shape_list):
            # Drop the padding tokens before folding patches back into pixels.
            tokens = video_embs[batch_idx, :, :h * w // (p ** 2)]
            video_latent_list.append(
                rearrange(tokens, "f (ph pw) (phz pwz c) ->f c (ph phz) (pw pwz)", phz=p, pwz=p, ph=h//p, pw=w//p)
            )

        images_latent_list = [
            [
                rearrange(
                    image_embs[batch_idx, frame_idx, :h * w // (p ** 2)],
                    "(ph pw) (phz pwz c) ->c (ph phz) (pw pwz)",
                    phz=p, pwz=p, ph=h//p, pw=w//p,
                )
                for frame_idx, (h, w) in enumerate(shapes)
            ]
            for batch_idx, shapes in enumerate(images_shape_list)
        ]

        return video_latent_list, images_latent_list

    def build_positional_embedding(self, ph, pw, max_hw, f = -1):
        """Return the spatial positional embedding for a ``ph x pw`` patch grid.

        Args:
            ph: Number of patch rows.
            pw: Number of patch columns.
            max_hw: Token count the embedding is zero-padded to (axis 0).
            f: When ``>= 1`` the result is repeated along a new leading frame
                axis of that length.

        Returns:
            A tensor of shape ``(max_hw, d)`` (or ``(f, max_hw, d)``); for
            ``"rope2d"`` a two-element list ``[freqs_h, freqs_w]`` of such
            tensors.

        Raises:
            ValueError: If ``self.positional_embeddings`` is not a supported
                scheme (previously this surfaced as an ``UnboundLocalError``).
        """
        if not self.pre_cache:
            self.pre_cache_positional_embeddings()

        scheme = self.positional_embeddings

        if scheme in ("sinusoidal_positional_embedding1d", "rope1d"):
            # Both 1-D schemes index the flattened (row-major) patch position.
            pe = self.pe_cache1d[:ph*pw]
            pe = pad_token2maxlen(pe, max_hw, dim=0)

        elif scheme == "sinusoidal_positional_embedding2d":
            pe = self.pe_cache2d[:ph, :pw]
            pe = rearrange(pe, "h w d -> (h w) d")
            pe = pad_token2maxlen(pe, max_hw, dim=0)

        elif scheme == "rope2d":
            # Row frequencies: each row entry is repeated for every column.
            freqs_h = self.pe_cache2d[:ph] # h x d
            freqs_h = repeat(freqs_h, "h d -> (h w) d", w = pw)
            freqs_h = pad_token2maxlen(freqs_h, max_hw, dim=0)
            # Column frequencies: the column sequence is tiled once per row.
            freqs_w = self.pe_cache2d[:pw] # w x d
            freqs_w = repeat(freqs_w, "w d -> (h w) d", h = ph)
            freqs_w = pad_token2maxlen(freqs_w, max_hw, dim=0)
            pe = [freqs_h, freqs_w]

        else:
            raise ValueError(
                f"Unsupported positional_embeddings: {scheme!r}"
            )

        if f >= 1:
            if isinstance(pe, list):
                pe = [repeat(item, "n d -> f n d", f=f) for item in pe]
            else:
                pe = repeat(pe, "n d -> f n d", f=f)
        return pe

        
    def patchify_and_embed(self, video_latent_list, images_latent_list):
        """Patchify, embed and pad the video and image latents of a batch.

        Args:
            video_latent_list: One ``(f, c, h, w)`` latent per batch element.
                All entries must share ``f`` so they can be stacked.
            images_latent_list: Per batch element, a list of ``(c, h, w)`` image
                latents. An empty outer list, or an empty first entry, means
                "no images".

        Returns:
            A dict with
            ``"x_embs"``: ``(b, f_total, max_hw_len, hidden_dim)`` tokens, video
            frames first, then image frames;
            ``"shape_list"``: ``[video_shape_list, images_shape_list]`` with the
            original ``(f, h, w)`` / ``(h, w)`` sizes for ``unpatchify``;
            ``"mask"``: ``[spatial_mask, temporal_mask]`` of shapes
            ``(b*f_total, max_hw, max_hw)`` and ``(b*max_hw, f_total, f_total)``;
            ``"positional_embedding"``: ``[spatial, temporal]`` embeddings, the
            spatial one being a two-element list for ``"rope2d"``.
        """
        patch_size = self.patch_size
        video_mask_list, images_mask_list, video_shape_list, images_shape_list = [], [], [], []
        video_token_list, images_token_list = [], []
        video_spatial_pe_list, images_spatial_pe_list = [], []

        for video_latent in video_latent_list:
            f, c, h, w = video_latent.shape
            video_shape_list.append((f, h, w))
            # TODO: temporal patch embedding
            ph, pw = h // patch_size, w // patch_size
            num_patches = ph * pw

            video_token = rearrange(video_latent, "f c (ph phz) (pw pwz) -> f (ph pw) (phz pwz c)", phz=patch_size, pwz=patch_size)
            video_token = self.patch_embedding(video_token)

            # Real patches may attend to each other; padding positions are masked out.
            mask = torch.zeros((f, self.max_hw_len, self.max_hw_len), dtype=torch.bool, device=video_token.device)
            mask[:, :num_patches, :num_patches] = 1
            video_mask_list.append(mask)

            # NOTE(review): the embedding is repeated for max_video_frame frames,
            # which assumes f == max_video_frame -- confirm with callers.
            pe = self.build_positional_embedding(ph, pw, self.max_hw_len, f=self.max_video_frame) # f x n x d
            video_token = pad_token2maxlen(video_token, max_len=self.max_hw_len, dim=1) # f n d

            video_token_list.append(video_token)
            video_spatial_pe_list.append(pe)

        video_embs = torch.stack(video_token_list, dim=0) # b x f1 x max_hw x c
        video_mask = torch.stack(video_mask_list, dim=0) # b x f1 x max_hw x max_hw

        video_spatial_pe = tensor_or_list_op(video_spatial_pe_list, torch.stack, dim = 0) # [ (x1, y1), (x2, y2)] -> [op(x1, x2), op(y1, y2)]

        # Guard the outer list too: indexing [0] on an empty list raised IndexError.
        has_images = len(images_latent_list) > 0 and len(images_latent_list[0]) > 0
        if has_images:
            for images_latent in images_latent_list:
                shape_list, image_token_list, batch_mask_list, pe_list = [], [], [], []

                for image_latent in images_latent:
                    c, h, w = image_latent.shape
                    shape_list.append((h, w))
                    ph, pw = h // patch_size, w // patch_size
                    num_patches = ph * pw

                    image_latent = rearrange(image_latent, "c (ph phz) (pw pwz) -> (ph pw) (phz pwz c)", phz=patch_size, pwz=patch_size)
                    image_token = self.patch_embedding(image_latent)

                    # Use this image's own device instead of the variable left
                    # over from the video loop.
                    mask = torch.zeros((self.max_hw_len, self.max_hw_len), dtype=torch.bool, device=image_token.device)
                    mask[:num_patches, :num_patches] = 1
                    batch_mask_list.append(mask)

                    pe = self.build_positional_embedding(ph, pw, self.max_hw_len)
                    image_token = pad_token2maxlen(image_token, self.max_hw_len, dim=0) # max_len x c

                    image_token_list.append(image_token) # max_hw_len x c
                    if exists(pe):
                        pe_list.append(pe)

                images_token_list.append(torch.stack(image_token_list, dim = 0))  # [f x max_hw x c]
                images_shape_list.append(shape_list)
                images_mask_list.append(torch.stack(batch_mask_list, dim=0))
                images_spatial_pe_list.append(tensor_or_list_op(pe_list, torch.stack, dim=0)) # max_image_frames x n x d/ f x (2 x n x d)

            images_embs = torch.stack(images_token_list, dim=0) # b x f2 x max_hw x c
            images_mask = torch.stack(images_mask_list, dim=0)
            images_spatial_pe = tensor_or_list_op(images_spatial_pe_list, torch.stack, dim = 0)

            # Image frames are appended after the video frames along the frame axis.
            x_embs = torch.concat([video_embs, images_embs], dim=1) # b x f x max_hw x c
            spatial_mask = rearrange(torch.concat([video_mask, images_mask], dim=1), "b f i j-> (b f) i j") # bf x max_hw x max_hw
            spatial_positional_embeddings = tensor_or_list_op([video_spatial_pe, images_spatial_pe], torch.concat, dim=1)
        else:
            x_embs = video_embs
            spatial_mask = rearrange(video_mask, "b f i j -> (b f) i j")
            spatial_positional_embeddings = video_spatial_pe

        b, total_num_frame, max_hw, c = x_embs.shape
        video_num_frame = video_embs.shape[1]
        # Temporal mask: video frames attend to every video frame, each image
        # frame only to itself.
        temporal_mask = torch.zeros(total_num_frame, total_num_frame, device=x_embs.device)
        temporal_mask.fill_diagonal_(1)
        temporal_mask[:video_num_frame, :video_num_frame] = 1
        temporal_mask = repeat(temporal_mask, "i j -> bhw i j", bhw=(b * max_hw)) # bhw x f x f

        params = {
            "x_embs": x_embs,
            "shape_list": [video_shape_list, images_shape_list],
            "mask": [spatial_mask, temporal_mask]
        }

        spatial_positional_embeddings = tensor_or_list_op(spatial_positional_embeddings, rearrange, "b f n c -> (b f) n c", is_single=True)

        temporal_positional_embeddings = self.pe_cache1d[:total_num_frame]
        temporal_positional_embeddings = repeat(temporal_positional_embeddings, "f d -> (b n) f d", b=b, n=max_hw)

        params['positional_embedding'] = [spatial_positional_embeddings, temporal_positional_embeddings]

        return params


    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Forward the xformers toggle to every descendant module that supports it.

        Descendants are visited depth-first, parents before their children, and
        ``set_use_memory_efficient_attention_xformers(valid, attention_op)`` is
        called on each one that defines it. ``self`` is not called.
        """
        # Explicit stack popped from the end; children are pushed reversed so
        # the visiting order matches a recursive pre-order walk.
        pending = list(self.children())[::-1]
        while pending:
            submodule = pending.pop()
            if hasattr(submodule, "set_use_memory_efficient_attention_xformers"):
                submodule.set_use_memory_efficient_attention_xformers(valid, attention_op)
            pending.extend(reversed(list(submodule.children())))

    def enable_gradient_checkpointing(self):
        """Call ``enable_checkpointing()`` on every descendant module that defines it.

        Descendants are visited depth-first, parents before their children;
        ``self`` is not called.
        """
        # Explicit stack popped from the end; children are pushed reversed so
        # the visiting order matches a recursive pre-order walk.
        pending = list(self.children())[::-1]
        while pending:
            submodule = pending.pop()
            if hasattr(submodule, "enable_checkpointing"):
                submodule.enable_checkpointing()
            pending.extend(reversed(list(submodule.children())))

    def forward(
        self,
        video_latent_list,
        images_latent_list,
        timesteps,
        cond_embs,
        content_mask = None,
    ) -> torch.Tensor:
        """Run the model on a batch of video and image latents.

        Args:
            video_latent_list: One ``(f, c, h, w)`` latent per batch element.
            images_latent_list: Per batch element, a list of ``(c, h, w)`` latents.
            timesteps: Python number, 0-d tensor or 1-D tensor. A single value
                is broadcast to the whole batch.
            cond_embs: Conditioning tokens of shape ``(b, f, n, d)``.
            content_mask: Optional ``(b, f, n)`` validity mask over the
                conditioning tokens. When omitted, all tokens are pooled with a
                plain mean and cross-attention runs unmasked.

        Returns:
            ``(video_pred_list, images_pred_list)`` as produced by
            ``unpatchify`` (a tuple of lists, despite the return annotation).
        """
        params = self.patchify_and_embed(video_latent_list, images_latent_list)
        x_embs = params['x_embs']
        video_shape_list, images_shape_list = params['shape_list']
        spatial_mask, temporal_mask = params['mask']
        # b x f x n x d
        b, f, n, d = x_embs.shape
        x_embs = rearrange(x_embs, "b f n d -> (b f) n d")
        device = x_embs.device

        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None]
        timesteps = timesteps.to(device=device, dtype=x_embs.dtype)
        # Broadcast a single timestep to the batch; previously only the 0-d
        # tensor case was expanded, so Python scalars broke for b > 1.
        if timesteps.shape[0] != b:
            timesteps = timesteps.expand(b)
        t_embs = self.t_embedder(timesteps)

        # Pool the conditioning tokens per frame (masked mean when a mask is given).
        if exists(content_mask):
            cap_mask_float = content_mask.float().unsqueeze(-1) # b x f x max_len x 1
            # cond_embs: b x f x n x d
            cap_feats_pool = (cond_embs * cap_mask_float).sum(dim=-2) / cap_mask_float.sum(dim=-2)
        else:
            cap_feats_pool = cond_embs.mean(dim=-2)
        cap_feats_pool = cap_feats_pool.to(x_embs.dtype) # b x f x d
        cap_emb = self.cap_embedder(cap_feats_pool) # b x f x d

        # Same conditioning, laid out once per spatial token and once per frame.
        spatial_t_embs = repeat(t_embs, "b d -> (b f) n d", f = f, n = n) + repeat(cap_emb, "b f d -> (b f) n d", n = n)
        temporal_t_embs = repeat(t_embs, "b d -> (b n) f d", f = f, n = n) + repeat(cap_emb, "b f d -> (b n) f d", n = n)

        spatial_content_mask, temporal_content_mask = content_mask, None

        cond_embs = rearrange(cond_embs, "b f n d -> (b f) n d")
        spatial_content_mask = rearrange(spatial_content_mask, "b f l -> (b f) l") if exists(spatial_content_mask) else None

        spatial_freqs_cis, temporal_freqs_cis = None, None
        if self.positional_embeddings.startswith("rope"):
            spatial_freqs_cis, temporal_freqs_cis = params['positional_embedding']

        use_sinusoidal = self.positional_embeddings.startswith("sinusoidal_positional_embedding")

        # Blocks are consumed in (spatial, temporal) pairs.
        for i in range(0, len(self.blocks), 2):
            spatial_block, temporal_block = self.blocks[i:i+2]

            if use_sinusoidal:
                x_embs += params['positional_embedding'][0]
            x_embs = spatial_block(x_embs, spatial_t_embs, cond_embs,
                                   attn_mask=spatial_mask,
                                   content_mask=spatial_content_mask,
                                   freqs_cis=spatial_freqs_cis)

            x_embs = rearrange(x_embs, "(b f) n d -> (b n) f d", f = f)

            if use_sinusoidal:
                # Only video frames get a temporal embedding; the embedding
                # covers all frames, so slice it to the same range.
                x_embs[:, :self.max_video_frame] += params['positional_embedding'][1][:, :self.max_video_frame]
            x_embs = temporal_block(x_embs, temporal_t_embs,
                                    attn_mask = temporal_mask,
                                    content_mask = temporal_content_mask,
                                    freqs_cis=temporal_freqs_cis)
            x_embs = rearrange(x_embs, "(b n) f d -> (b f) n d", n = n)

        x_embs = self.final_layer(x_embs, spatial_t_embs)
        x_embs = rearrange(x_embs, "(b f) n d -> b f n d", f = f)

        video_pred_list, images_pred_list = self.unpatchify(x_embs, video_shape_list, images_shape_list)
        return video_pred_list, images_pred_list


def Latte_S_2(**kwargs):
    """Build the small preset: 12 layers, 6 heads, width 384, patch size 2.

    ``kwargs`` supplies the remaining ``Transformer`` arguments; repeating a
    preset key raises ``TypeError``.
    """
    preset = dict(
        n_layers=12,
        patch_size=2,
        n_heads=6,
        n_kv_heads=6,
        hidden_dim=384,
        norm_eps=1e-05,
        cross_attn_dim=4096,
        mlp_ratio=3.5,
    )
    return Transformer(**preset, **kwargs)

def Latte_XL_2(**kwargs):
    """Build the XL preset: 28 layers, 16 heads, width 1152, patch size 2.

    ``kwargs`` supplies the remaining ``Transformer`` arguments; repeating a
    preset key raises ``TypeError``.
    """
    preset = dict(
        n_layers=28,
        patch_size=2,
        n_heads=16,
        n_kv_heads=16,
        hidden_dim=1152,
        norm_eps=1e-05,
        cross_attn_dim=4096,
        mlp_ratio=3.5,
    )
    return Transformer(**preset, **kwargs)

if __name__ == "__main__":
    # Smoke test: run one forward pass and check that every prediction has the
    # shape of its input latent.
    # NOTE(review): the original script omitted the required max_hw_len /
    # max_video_frame / max_image_frame arguments and used latent/conditioning
    # shapes that do not match Transformer.forward. The values below are one
    # consistent choice -- confirm they reflect the intended configuration.
    device = "cuda"
    b = 3
    latent_dim = 4
    num_video_frames = 4
    num_images = 4
    num_frames = num_video_frames + num_images
    patch_size = 2  # fixed by the Latte_S_2 preset
    max_hw_len = (16 // patch_size) * (16 // patch_size)  # tokens of the largest (16x16) latent

    # Videos are (f, c, h, w); images are (c, h, w).
    video_latent_list = [
        torch.randn((num_video_frames, latent_dim, 16, 16), device=device)
        for _ in range(b)
    ]
    images_latent_list = [
        [torch.randn((latent_dim, 14, 14), device=device) for _ in range(num_images)]
        for _ in range(b)
    ]

    model = Latte_S_2(
        max_hw_len=max_hw_len,
        max_video_frame=num_video_frames,
        max_image_frame=num_images,
        latent_dim=latent_dim,
        positional_embeddings='rope2d',
        moe={
            "num_experts": 4,
            "num_experts_per_tok": 2
        },
        norm_layer=AdaRMSNorm,
    ).to(device)
    t = torch.randn(b, device=device)
    # Conditioning is per frame: (b, f, tokens, dim) with a (b, f, tokens) mask.
    content = torch.randn(b, num_frames, 128, 4096, device=device)
    content_mask = torch.ones(b, num_frames, 128, device=device)
    video_pred_list, images_pred_list = model(video_latent_list, images_latent_list, t, content, content_mask)

    for i, (video_pred, video_latent) in enumerate(zip(video_pred_list, video_latent_list)):
        if video_pred.shape != video_latent.shape:
            print(f"Video latent shape mismatch at index {i}: expected {video_latent.shape}, got {video_pred.shape}")

    for batch_index, (images_pred_batch, images_latent_batch) in enumerate(zip(images_pred_list, images_latent_list)):
        for image_index, (image_pred, image_latent) in enumerate(zip(images_pred_batch, images_latent_batch)):
            if image_pred.shape != image_latent.shape:
                print(f"Image latent shape mismatch at batch {batch_index}, index {image_index}: expected {image_latent.shape}, got {image_pred.shape}")