from re import L
from turtle import forward
from typing import Iterable, Dict, List
from attr import has

import torch
from einops import rearrange, repeat
from torch import Tensor
from torch import nn
from torch.nn import Identity
from custom_model.common_models import Reshape
from perceiver_pytorch.caching import cache_by_name_fn
from perceiver_pytorch.modalities import InputModality, modality_encoding
from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, cache_fn, fourier_encode, \
    FeedForwardGELU
from perceiver_pytorch.common import build_perceiver_layers
from perceiver_pytorch.perceiver_pytorch import default, checkpoint, partial, einsum, exists
from timm.models.vision_transformer import Attention
import copy
from fmoe.gates import NoisyGate, NaiveGate
from custom_model.vmoe_module import VMoETransformerAttentionQKV, VMoETransformerMLP, VMoETransformerMLPUnlimitCapacity, VMoETransformerSeperateQKV, NoisyVMoEGate, AttentionWithPrint

from timm.models.vision_transformer import VisionTransformer
from thop import profile, clever_format

# from timm.models.vision_transformer import Attention

# Registry mapping a gate name (string) to the gate class it selects.
# NOTE(review): not referenced in the visible part of this file — presumably
# looked up by configuration code elsewhere; confirm before removing.
GATES = {
    'NoisyGate': NoisyGate,
    'NoisyVMoEGate': NoisyVMoEGate
}

from custom_model.vmoe_module import MultiModalityConfig

def modality_encoding(batch_size: int, axes, modality_index: int, num_modalities: int, embed=None,
                      device=torch.device('cpu')) -> Tensor:
    """
    Build the per-position modality code for one modality.

    Without ``embed`` the code is row ``modality_index`` of a
    ``num_modalities x num_modalities`` identity matrix (a one-hot vector);
    with ``embed`` it is ``embed[modality_index]``. The code is then broadcast
    (as an expanded view, no copy) to ``[batch_size, *axes, code_width]`` so
    it can be concatenated with the modality data along the last dimension.

    :param batch_size: size of the leading batch dimension.
    :param axes: sizes of the data axes between batch and channel dimension.
    :param modality_index: which modality to encode.
    :param num_modalities: total number of modalities (one-hot width).
    :param embed: optional table indexed by modality; replaces the one-hot code.
    :param device: device for the one-hot matrix (ignored when ``embed`` is given).
    :return: tensor of shape ``[batch_size, *axes, code_width]``.
    """
    if embed is None:
        code = torch.eye(num_modalities, num_modalities, device=device)[modality_index]
        code_width = num_modalities
    else:
        code = embed[modality_index]
        code_width = len(embed[0])

    leading_dims = [batch_size, *axes]
    # One singleton dimension per leading dim, so expand() can broadcast over them.
    for _ in leading_dims:
        code = code.unsqueeze(0)

    return code.expand(leading_dims + [code_width])


def findmodalityandindex(ms, mn):
    """Return ``(modality, index)`` for the first entry of ``ms`` whose ``name`` equals ``mn``.

    Returns ``None`` when no entry matches.
    """
    hits = ((m, i) for i, m in enumerate(ms) if mn == m.name)
    return next(hits, None)

class Mlp(nn.Module):
    """ Two-layer feed-forward network: linear -> activation -> dropout -> linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are falsy (``None`` or ``0``). The same dropout probability ``drop``
    is used after both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        # Apply the five stages strictly in order.
        for stage in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = stage(x)
        return x
    
class Attention(nn.Module):
    """Multi-head self-attention with a single fused query/key/value projection.

    Input and output are ``(B, N, C)`` tensors where ``C == dim``; ``dim`` is
    split evenly across ``num_heads`` and scores are scaled by
    ``(dim // num_heads) ** -0.5``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, C // heads), then split along the leading axis.
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        mixed = torch.matmul(weights, values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))


class AttentionMoEQKVMerge(nn.Module):
    """Multi-head self-attention whose fused QKV projection is a ``VMoETransformerAttentionQKV`` layer.

    Apart from the projection layer the computation matches the plain
    ``Attention`` block: scaled dot-product attention over ``num_heads`` heads
    followed by a dense output projection.

    :param dim: channel width of the input and output tokens.
    :param num_heads: number of attention heads; ``dim`` is split evenly across them.
    :param qkv_bias: forwarded as ``bias`` to the expert QKV layer.
    :param attn_drop: dropout probability on the attention weights.
    :param proj_drop: dropout probability after the output projection.
    :param args: configuration object; ``args.num_experts`` and ``args.gate``
        are read here and the whole object is forwarded to the expert layer.
    :param top_k: forwarded as ``top_k`` to the expert QKV layer.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., args = None, top_k=2):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Honour the caller's ``top_k``; previously a literal 2 was passed here,
        # so the constructor argument was silently ignored.
        self.qkv = VMoETransformerAttentionQKV(args.num_experts, dim, dim * 3, bias = qkv_bias, gate=args.gate, args = args, top_k=top_k)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def gate_loss(self, task_idx = None, modality_name = None):
        """Return the gate loss reported by the expert QKV layer."""
        return self.qkv.gate_loss(task_idx, modality_name)

    def get_expert_count(self, num_modalities = 2):
        """Return the expert count reported by the QKV layer.

        ``num_modalities`` is accepted for interface parity with the other
        attention variants but is not forwarded.
        """
        return self.qkv.get_expert_count()

    def gate_topk_logits(self, task_idx = None, modality_name = None):
        """Return the top-k gate logits reported by the expert QKV layer."""
        return self.qkv.gate_topk_logits(task_idx = task_idx, modality_name = modality_name)

    def forward(self, x, task_idx = None, modality_name = None):
        """Apply self-attention to ``x`` of shape ``(B, N, C)`` and return the same shape.

        ``task_idx`` and ``modality_name`` are passed through to the expert
        QKV layer unchanged.
        """
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x, task_idx, modality_name).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class AttentionWithPrintMoE(AttentionWithPrint):
    """``AttentionWithPrint`` variant whose key/value projection is a ``VMoETransformerAttentionQKV`` layer.

    The attention weights of the most recent forward pass are kept on
    ``self.printattn``.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0, args = None):
        super().__init__(query_dim, context_dim, heads, dim_head, dropout)
        # Keys and values are produced side by side, hence twice the inner width.
        kv_width = (dim_head * heads) * 2
        self.to_kv = VMoETransformerAttentionQKV(
            args.num_experts,
            context_dim,
            kv_width,
            bias=False,
            gate=args.gate,
            args=args,
            top_k=args.attn_top_k,
        )
        self.args = args

    def gate_loss(self, task_idx = None, modality_name = None):
        """Sum the key/value gate loss over every entry of ``modality_name`` (an iterable of names)."""
        total = 0.
        for name in modality_name:
            total += self.to_kv.gate_loss(task_idx, name)
        return total

    def get_expert_count(self, num_modalities = 2):
        """Return the key/value expert count wrapped in a dict under the key ``'kv'``."""
        return {'kv': self.to_kv.get_expert_count()}

    def get_expert_info(self, num_modalities = 2):
        """Return the raw expert count of the key/value layer."""
        return self.to_kv.get_expert_count()

    def forward(self, x, context=None, mask=None, modality_name = None, task_id = None):
        """Attend from ``x`` to ``context`` (defaults to ``x`` itself).

        ``task_id`` and ``modality_name`` are forwarded to the expert
        key/value projection. ``mask``, when given, marks the context
        positions that may be attended to.
        """
        num_heads = self.heads

        queries = self.to_q(x)
        context = default(context, x)
        projected_kv = self.to_kv(context, task_idx=task_id, modality_name=modality_name)
        keys, values = projected_kv.chunk(2, dim=-1)

        # Scores are computed in float32 to avoid instability as attention
        # weights grow during training.
        keys = keys.float()
        queries = queries.float()

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> (b h) n d', h=num_heads)

        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)

        similarity = checkpoint(partial(einsum, 'b i d, b j d -> b i j'), queries, keys) * self.scale

        if exists(mask):
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(similarity.dtype).max
            flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=num_heads)
            # Masked-out positions get the most negative finite value before softmax.
            similarity.masked_fill_(~flat_mask, fill_value)

        weights = similarity.softmax(dim=-1)
        self.printattn = weights
        attended = checkpoint(partial(einsum, 'b i j, b j d -> b i d'), weights, values)
        # Cast back to the dtype of the input.
        attended = attended.type(x.dtype)
        attended = rearrange(attended, '(b h) n d -> b n (h d)', h=num_heads)
        return self.to_out(attended)

class AttentionMoEQKVSeperate(nn.Module):
    """Multi-head self-attention with three separate ``VMoETransformerSeperateQKV`` projections.

    Queries, keys and values each come from their own expert layer. When
    ``modality_topk`` is set (a dict ``modality name -> top-k``), the token
    sequence is split into ``len(modality_topk)`` equal contiguous chunks,
    taken in sorted-key order, and each chunk is projected with its own top-k
    before attention runs over the full re-assembled sequence.

    :param dim: channel width of the input and output tokens.
    :param num_heads: number of attention heads.
    :param qkv_bias: forwarded as ``bias`` to the three expert layers.
    :param attn_drop: dropout probability on the attention weights.
    :param proj_drop: dropout probability after the output projection.
    :param args: configuration object; ``args.num_experts`` and ``args.gate``
        are read here and the whole object is forwarded to the expert layers.
    :param top_k: forwarded as ``top_k`` to the three expert layers.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., args = None, top_k=2):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = VMoETransformerSeperateQKV(args.num_experts, dim, dim, bias = qkv_bias, gate=args.gate, args = args, top_k=top_k)
        self.k = VMoETransformerSeperateQKV(args.num_experts, dim, dim, bias = qkv_bias, gate=args.gate, args = args, top_k=top_k)
        self.v = VMoETransformerSeperateQKV(args.num_experts, dim, dim, bias = qkv_bias, gate=args.gate, args = args, top_k=top_k)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.modality_topk = None

    def set_topk(self, topk):
        """Set the same top-k on the query, key and value expert layers."""
        self.q.set_topk(topk)
        self.k.set_topk(topk)
        self.v.set_topk(topk)

    def set_modality_topk(self, modality_topk):
        """Enable (dict) or disable (``None``) per-modality top-k in ``forward``."""
        self.modality_topk = modality_topk

    def gate_loss(self, task_idx = None, modality_name = None):
        """Return the sum of the gate losses of the three projections."""
        return self.q.gate_loss(task_idx, modality_name) + self.k.gate_loss(task_idx, modality_name) + self.v.gate_loss(task_idx, modality_name)

    def get_expert_count(self, num_modalities = 2):
        """Return the expert counts of the three projections keyed ``'q'``, ``'k'``, ``'v'``."""
        return {'q': self.q.get_expert_count(num_modalities), 'k': self.k.get_expert_count(num_modalities), 'v': self.v.get_expert_count(num_modalities)}

    def gate_topk_logits(self, task_idx = None, modality_name = None):
        """Return the top-k gate logits of the three projections keyed ``'q'``, ``'k'``, ``'v'``."""
        return {'q': self.q.gate_topk_logits(task_idx, modality_name), 'k': self.k.gate_topk_logits(task_idx, modality_name), 'v': self.v.gate_topk_logits(task_idx, modality_name)}

    def _project_heads(self, x, task_idx, modality_name):
        """Project ``x`` (B, n, C) to a ``(q, k, v)`` tuple, each shaped (B, heads, n, C // heads)."""
        B, n, C = x.shape
        head_shape = (B, n, self.num_heads, C // self.num_heads)
        return tuple(
            proj(x, task_idx, modality_name).reshape(head_shape).permute(0, 2, 1, 3)
            for proj in (self.q, self.k, self.v)
        )

    def forward(self, x, task_idx = None, modality_name = None):
        """Apply self-attention to ``x`` of shape ``(B, N, C)`` and return the same shape.

        With ``modality_topk`` set, ``N`` is expected to be a multiple of the
        number of modalities; the top-k of the last modality (in sorted order)
        stays active on the expert layers after the call.
        """
        B, N, C = x.shape
        if self.modality_topk is None:
            q, k, v = self._project_heads(x, task_idx, modality_name)
        else:
            modality_list = sorted(self.modality_topk.keys())
            n_seq = N // len(modality_list)
            per_modality = []
            for i, name in enumerate(modality_list):
                self.set_topk(self.modality_topk[name])
                x_input = x[:, n_seq * i: n_seq * (i + 1)]
                per_modality.append(self._project_heads(x_input, task_idx, modality_name))
            # The projections are already (B, heads, n_seq, d), so the chunks
            # must be joined on the token axis (dim=2). Joining on dim=1 stacked
            # modalities as extra heads: no cross-modality attention, and the
            # final reshape interleaved tokens out of their input order.
            q, k, v = (torch.cat(parts, dim=2) for parts in zip(*per_modality))

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class VanillaTransformerBlock(nn.Module):
    """Dense pre-norm transformer block: residual self-attention followed by a residual MLP.

    ``args`` is only stored on the instance; ``task_idx`` and ``modality_name``
    are accepted by ``forward`` but not used.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, args : MultiModalityConfig = None):
        super().__init__()
        self.args = args
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, task_idx = None, modality_name = None):
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

class MultiModalityTransformerBlock(nn.Module):
    """Pre-norm transformer block whose attention and MLP sub-layers can each be dense or mixture-of-experts.

    The sub-layer types are chosen from ``args``:

    * attention: dense ``Attention`` when ``args.attn_use_moe`` is false;
      otherwise ``AttentionMoEQKVSeperate`` when ``args.seperate_qkv`` exists
      and equals ``True``, else ``AttentionMoEQKVMerge``.
    * MLP: dense ``Mlp`` when ``args.mlp_use_moe`` is false; otherwise
      ``VMoETransformerMLPUnlimitCapacity`` when
      ``args.unlimited_capacity_on_mlp`` is set, else ``VMoETransformerMLP``.

    :param dim: token channel width.
    :param num_heads: number of attention heads.
    :param mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
    :param qkv_bias: whether the QKV projection has a bias.
    :param drop: dropout probability for the attention output projection and the MLP.
    :param attn_drop: dropout probability on the attention weights.
    :param act_layer: activation class used by the MLP.
    :param norm_layer: normalisation class applied before each sub-layer.
    :param args: configuration object (see the attributes read in the methods below).
    :param attn_top_k: ``top_k`` forwarded to the expert attention layer.
    :param mlp_top_k: ``top_k`` forwarded to the expert MLP layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, args: "MultiModalityConfig" = None, attn_top_k=2, mlp_top_k=2):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.args = args

        if not args.attn_use_moe:
            self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        elif getattr(args, 'seperate_qkv', False) == True:  # noqa: E712 - keep the original equality test
            self.attn = AttentionMoEQKVSeperate(dim, num_heads, qkv_bias, attn_drop, proj_drop=drop, args = args, top_k=attn_top_k)
        else:
            self.attn = AttentionMoEQKVMerge(dim, num_heads, qkv_bias, attn_drop, proj_drop=drop, args = args, top_k=attn_top_k)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # Build only the MLP that will be used. The expert MLP used to be
        # constructed unconditionally and then replaced by a dense one.
        if not args.mlp_use_moe:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            mlp_cls = VMoETransformerMLPUnlimitCapacity if args.unlimited_capacity_on_mlp else VMoETransformerMLP
            self.mlp = mlp_cls(num_expert=args.num_experts,
                               d_model=dim,
                               d_hidden=mlp_hidden_dim,
                               activation=act_layer,
                               capacity_per_expert=args.capacity_per_expert,
                               args = args,
                               drop=drop,
                               top_k = mlp_top_k)

    def set_topk(self, topk):
        """Set ``topk`` on every sub-layer that exposes ``set_topk``; dense sub-layers are skipped."""
        for module in (self.attn, self.mlp):
            if hasattr(module, 'set_topk'):
                module.set_topk(topk)

    def gate_logits(self, task_idx):
        """Return the MLP's gate logits, or ``None`` when the MLP has no ``gate_logits``."""
        if hasattr(self.mlp, 'gate_logits'):
            return self.mlp.gate_logits(task_idx)
        return None

    def gate_topk_logits(self, task_idx, modality_name):
        """Collect top-k gate logits: the MLP's under ``'mlp'`` plus every key the attention layer reports."""
        ret_dict = {}
        if hasattr(self.mlp, 'gate_topk_logits'):
            ret_dict['mlp'] = self.mlp.gate_topk_logits(task_idx, modality_name)
        if hasattr(self.attn, 'gate_topk_logits'):
            attn_ret = self.attn.gate_topk_logits(task_idx, modality_name)
            for key in attn_ret:
                ret_dict[key] = attn_ret[key]
        return ret_dict

    def _sublayer_gate_loss(self, module, task_idx, modalities_name, modality_specific):
        """Gate loss of one sub-layer, optionally accumulated per modality with per-modality weights."""
        if not modality_specific:
            return module.gate_loss(task_idx)
        # Per-modality weights apply only when enabled AND a map is configured;
        # otherwise every modality has weight 1.
        loss_map = self.args.gating_loss_map if self.args.auto_gate_loss else None
        total = None
        for mn in modalities_name:
            g_weight = 1 if loss_map is None else loss_map[mn]
            if total is None:
                total = module.gate_loss(task_idx, mn) * g_weight
            elif self.args.modality_joint:
                # With modality_joint only the first modality contributes.
                break
            else:
                total = total + module.gate_loss(task_idx, mn) * g_weight
        return total

    def gate_loss(self, task_idx, modalities_name = None):
        """Return the combined gate loss of the expert sub-layers.

        :param task_idx: forwarded to the sub-layers' ``gate_loss``.
        :param modalities_name: iterable of modality names, used when the
            corresponding ``*_modality_specific`` flag is set.
        :return: sum of the available losses, or ``None`` when neither
            sub-layer yields one (no MoE sub-layer, or modality-specific mode
            with no modality names).
        """
        if modalities_name is None:
            modalities_name = ()
        losses = []
        if self.args.attn_use_moe:
            losses.append(self._sublayer_gate_loss(
                self.attn, task_idx, modalities_name, self.args.attn_modality_specific))
        if self.args.mlp_use_moe:
            losses.append(self._sublayer_gate_loss(
                self.mlp, task_idx, modalities_name, self.args.mlp_modality_specific))
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        total = losses[0]
        for loss in losses[1:]:
            total = total + loss
        return total

    def set_contribution(self, args):
        """No-op placeholder."""
        pass

    def set_modality_topk(self, modality_topk):
        """Store ``modality_topk`` directly on both sub-layers."""
        self.mlp.modality_topk = modality_topk
        self.attn.modality_topk = modality_topk

    def set_capacity(self, batch_size, token_num, modality_num, task_idx):
        """Recompute and set expert capacity on sub-layers that support it.

        Capacity is ``round(top_k * batch_size * token_num * ratio / num_experts) * modality_num``,
        where ``ratio`` is ``args.capacity_ratios[task_idx]`` when that list is
        configured and ``args.capacity_ratio`` otherwise. The attention capacity
        is applied only when the attention layer has separate ``q``/``k``/``v``
        projections.
        """
        if self.args.capacity_ratios is not None:
            capacity_ratio = self.args.capacity_ratios[int(task_idx)]
        else:
            capacity_ratio = self.args.capacity_ratio
        token_count = int(batch_size) * int(token_num)
        if hasattr(self.attn, 'q'):
            attn_capacity = round(self.args.attn_top_k * token_count * capacity_ratio / self.args.num_experts) * modality_num
            if hasattr(self.attn.q, 'set_capacity'):
                self.attn.q.set_capacity(attn_capacity)
                self.attn.k.set_capacity(attn_capacity)
                self.attn.v.set_capacity(attn_capacity)
        if hasattr(self.mlp, 'set_capacity'):
            mlp_capacity = round(self.args.mlp_top_k * token_count * capacity_ratio / self.args.num_experts) * modality_num
            self.mlp.set_capacity(mlp_capacity)

    def get_expert_info(self, num_modalities = 2):
        """Return expert counts of both sub-layers.

        A dict-valued attention count is flattened into
        ``attn_expert_count_<key>`` entries; otherwise it is stored under
        ``attn_expert_count``. The MLP count is stored under
        ``mlp_expert_count``. Sub-layers without ``get_expert_count`` report
        ``None``. ``num_modalities`` is forwarded only when ``args.co_input``
        is set.
        """
        count_kwargs = {'num_modalities': num_modalities} if self.args.co_input else {}

        def expert_count(module):
            if hasattr(module, 'get_expert_count'):
                return module.get_expert_count(**count_kwargs)
            return None

        ret_dict = {}
        attn_expert_count = expert_count(self.attn)
        if isinstance(attn_expert_count, dict):
            for key, value in attn_expert_count.items():
                ret_dict['attn_expert_count_{}'.format(key)] = value
        else:
            ret_dict["attn_expert_count"] = attn_expert_count
        ret_dict['mlp_expert_count'] = expert_count(self.mlp)
        return ret_dict

    def forward(self, x, task_idx = None, modality_name = None):
        """Apply residual attention then residual MLP; expert sub-layers also receive ``task_idx`` and ``modality_name``."""
        if self.args.attn_use_moe:
            x = x + self.attn(self.norm1(x), task_idx, modality_name)
        else:
            x = x + self.attn(self.norm1(x))
        if self.args.mlp_use_moe:
            x = x + self.mlp(self.norm2(x), task_idx, modality_name)
        else:
            x = x + self.mlp(self.norm2(x))
        return x




# An implementation of Perceiver that can accept multiple data modalities in the same forward.
class MultiModalityMoETransformer(nn.Module):
    """Perceiver encoder run once per modality, followed by MoE transformer blocks over the fused latents.

    Each modality is encoded independently with the same Perceiver layers and
    the same initial latents. The resulting latent arrays are concatenated
    along the channel dimension, projected to ``mmoe_attn_dim`` by a per-task
    linear layer, passed through ``cross_depth`` ``MultiModalityTransformerBlock``
    layers, and the last latent position is mapped to logits.
    """

    def __init__(
            self,
            *,
            modalities: Iterable[InputModality],
            depth,
            modalities_num = [2],
            mmoe_attn_dim = 64,
            num_latents=512,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=None,
            attn_dropout=0.,
            ff_dropout=0.,
            embed=False,
            embed_size=10,
            weight_tie_layers=False,
            num_latent_blocks_per_layer=1,
            use_gelu: bool = False,
            cross_depth=2,
            cross_cross_heads=4,
            recon=None,
            args = None
    ):
        """
        :param modalities: the input modalities; iterated more than once here and
               stored as-is, so it must be re-iterable and support ``len``.
        :param depth: Number of times the perceiver will perform cross-attention between latent and input.
        :param modalities_num: one entry per task; entry ``i`` is the number of modalities whose
               latents are concatenated for task ``i`` (sizes fusion layer ``str(i)``). Never mutated.
        :param mmoe_attn_dim: channel width of the MoE transformer blocks and of the logits head input.
        :param num_latents: number of latent vectors.
        :param latent_dim: width of each latent vector.
        :param cross_heads: heads of the latent-to-input cross-attention.
        :param latent_heads: heads of the latent self-attention.
        :param cross_dim_head: per-head width of the cross-attention.
        :param latent_dim_head: per-head width of the latent self-attention.
        :param num_classes: output width of the final linear layer.
        :param attn_dropout: dropout for all attention layers.
        :param ff_dropout: dropout for the feed-forward layers and the MoE blocks.
        :param embed: True: use a learned modality embedding of width ``embed_size`` instead of one-hot.
        :param embed_size: width of the learned modality embedding.
        :param weight_tie_layers: True: share weights across layers, False no shared weights.
        :param num_latent_blocks_per_layer: Number of blocks in the latent transformer.
        :param use_gelu: Use GELU activation like the Perceiver preprint indicates. False,
               with Lucidrains' GEGLU activation in feed forward instead.
        :param cross_depth: number of MoE transformer blocks applied to the fused latents.
        :param cross_cross_heads: unused; kept for interface compatibility.
        :param recon: unused; kept for interface compatibility.
        :param args: configuration object forwarded to every ``MultiModalityTransformerBlock``.
        """
        super().__init__()
        self.modalities = modalities
        self.embed_size = embed_size
        # One-hot modality encoding needs one dim per modality; a learned
        # embedding uses embed_size dims instead.
        nummodalities = sum(1 for _ in modalities)
        modality_encoding_dim = embed_size if embed else nummodalities
        self.modality_encoding_dim = modality_encoding_dim
        # input_dim is the maximum dimension over all input modalities:
        input_dim = max(modality.input_dim for modality in modalities) + modality_encoding_dim
        self.max_modality_dim = input_dim
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        ff_type = FeedForwardGELU if use_gelu else FeedForward
        self.embed = None
        if embed:
            self.embed = torch.nn.Parameter(torch.randn(nummodalities, embed_size))

        get_cross_attn = lambda: PreNorm(latent_dim,
                                         AttentionWithPrint(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head,
                                                   dropout=attn_dropout), context_dim=input_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout))
        get_latent_attn = lambda: PreNorm(latent_dim,
                                          AttentionWithPrint(latent_dim, heads=latent_heads, dim_head=latent_dim_head,
                                                    dropout=attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout))

        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_by_name_fn, (
            get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])

        build_perceiver_layers(self.layers, depth, get_cross_attn, get_cross_ff,
                               get_latent_attn, get_latent_ff,
                               weight_tie_layers,
                               num_latent_blocks_per_layer=num_latent_blocks_per_layer)

        # The head consumes one token of the MoE blocks' output, whose width is
        # mmoe_attn_dim. It was previously sized latent_dim * 2, which only
        # worked when the two happened to be equal.
        self.to_logits = nn.Sequential(
            nn.LayerNorm(mmoe_attn_dim),
            nn.Linear(mmoe_attn_dim, num_classes)
        )

        # One fusion projection per task, keyed by the task index as a string.
        self.feature_fusions = nn.ModuleDict([])
        for i, task_modalities in enumerate(modalities_num):
            self.feature_fusions[str(i)] = nn.Linear(latent_dim * task_modalities, mmoe_attn_dim, bias = True)

        self.blocks = nn.ModuleList([
            MultiModalityTransformerBlock(mmoe_attn_dim, 8, 2, True, args = args, drop = ff_dropout, attn_drop=attn_dropout) for _ in range(cross_depth)
        ])

    def itera_add(self, args):
        """Debugging helper: overwrite ``args`` on the experts of every block.

        Requires every block's attention to expose ``qkv.experts`` and its MLP
        to expose ``experts``.

        Args:
            args: object assigned to each ``experts.args``.
        """
        for block in self.blocks:
            block.attn.qkv.experts.args = args
            block.mlp.experts.args = args

    def gate_loss(self, task_idx):
        """Return the sum of the blocks' gate losses, skipping blocks that report ``None``.

        Returns ``None`` when no block reports a loss.
        """
        g_loss = None
        for block in self.blocks:
            block_loss = block.gate_loss(task_idx)
            if block_loss is None:
                continue
            g_loss = block_loss if g_loss is None else g_loss + block_loss
        return g_loss

    def forward(self, multi_modality_data: Dict[str, Tensor], mask=None, use_recon=False, task_id = 0):
        """
        :param multi_modality_data: a dictionary where keys are modality names and Tensor contain a batch
        of modality input data. Modalities are processed in sorted-name order.
        :param mask: forwarded to every cross-attention layer.
        :param use_recon: unused.
        :param task_id: selects the fusion layer and is passed to the MoE blocks.
        :return: logits computed from the last latent position.
        """
        num_modalities = len(self.modalities)
        batch_size = None
        latentout = []
        self.attns = {}
        for modality_name in sorted(multi_modality_data.keys()):
            data = multi_modality_data[modality_name]
            found = findmodalityandindex(self.modalities, modality_name)
            if found is None:
                raise ValueError(f"modality {modality_name} was not defined in constructor")
            modality, modality_index = found

            b, *axis, _ = data.shape
            device = data.device
            assert len(
                axis) == modality.input_axis, f'input data must have the right number of  for modality {modality_name}. ' \
                                              f'Expected {modality.input_axis} while forward argument offered {len(axis),data.shape,b, axis}'
            # Remember the first batch size and compare every later modality to
            # it. The previous set-based check was emptied inside the loop and
            # therefore never fired.
            if batch_size is None:
                batch_size = b
            assert b == batch_size, "batch size must be the same across all modalities"

            # calculate fourier encoded positions in the range of [-1, 1], for all axis
            axis_pos = [torch.linspace(-1., 1., steps=size, device=device) for size in axis]
            # indexing='ij' is the historical default; stating it avoids the deprecation warning.
            pos = torch.stack(torch.meshgrid(*axis_pos, indexing='ij'), dim=-1)
            enc_pos = fourier_encode(pos,
                                     modality.max_freq, modality.num_freq_bands, modality.freq_base)
            enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
            enc_pos = repeat(enc_pos, '... -> b ...', b=b)

            # Figure out padding for this modality, given max dimension across all modalities:
            padding_size = self.max_modality_dim - modality.input_dim - self.modality_encoding_dim
            padding = torch.zeros(*data.shape[:-1], padding_size, device=device)

            modality_encodings = modality_encoding(b, axis, modality_index, num_modalities, embed=self.embed, device=device)

            # concat to channels of data and flatten axis
            data = torch.cat((data, padding, enc_pos, modality_encodings), dim=-1)
            data = rearrange(data, 'b ... d -> b (...) d')

            # Every modality starts from the same initial latents.
            x = repeat(self.latents, 'n d -> b n d', b=b)

            for cross_attn, cross_ff, latent_transformer in self.layers:
                x = cross_attn(x, context=data, mask=mask) + x
                # Keep the attention map of the last cross-attention layer per modality.
                self.attns[modality_name] = cross_attn.fn.printattn
                x = cross_ff(x) + x
                x = latent_transformer(x) + x
            latentout.append(x)

        concat_feature = torch.cat(latentout, dim = -1)
        fusion_feature = self.feature_fusions[str(task_id)](concat_feature)
        for block in self.blocks:
            fusion_feature = block(fusion_feature, task_id)

        task_feature = fusion_feature[:, -1]
        return self.to_logits(task_feature)
    
class MultiModalitySequenceMoETransformer(nn.Module):
    """Perceiver-style multi-modal encoder followed by a stack of transformer blocks.

    Every modality is concatenated with padding, Fourier position features and a
    modality encoding, then cross-attended into a copy of the shared latent array
    by ``self.layers``.  The per-modality latents are subsequently processed by
    ``self.blocks`` and the last latent token of each modality is handed to
    ``self.to_logits``.
    """

    def __init__(
            self,
            *,
            modalities: "Iterable[InputModality]",
            depth,
            modalities_num=None,
            mmoe_attn_dim=64,
            num_latents=512,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=None,
            attn_dropout=0.,
            ff_dropout=0.,
            embed=False,
            embed_size=10,
            weight_tie_layers=False,
            num_latent_blocks_per_layer=1,
            use_gelu: bool = False,
            cross_depth=2,
            cross_cross_heads=4,
            recon=None,
            args: "MultiModalityConfig" = None
    ):
        """
        :param modalities: Modality descriptors; ``input_dim`` and ``name`` are read here,
               ``input_axis`` and the frequency settings are read in ``forward``.
        :param depth: Number of times the perceiver will perform cross-attention between latent and input.
        :param modalities_num: Per-task number of modalities, indexed by ``int(task_id)``.
               Defaults to ``[2]``.
        :param mmoe_attn_dim: Accepted for interface compatibility; not used by this constructor.
        :param num_latents: Number of rows of the shared latent array.
        :param latent_dim: Width of each latent vector.
        :param cross_heads: Heads of the input cross-attention.
        :param latent_heads: Heads of the latent self-attention.
        :param cross_dim_head: Per-head width of the cross-attention.
        :param latent_dim_head: Per-head width of the latent self-attention.
        :param num_classes: Output width of the final ``nn.Linear`` in ``to_logits``.  It is passed
               to ``nn.Linear`` unchanged, so ``None`` is not a usable value here.
        :param attn_dropout: Dropout used by the attention modules.
        :param ff_dropout: Dropout used by the feed-forward modules.
        :param embed: True: use a learned modality embedding of size ``embed_size`` instead of
               a one-hot modality encoding.
        :param embed_size: Width of the learned modality embedding.
        :param weight_tie_layers: True: share weights across layers, False no shared weights.
        :param num_latent_blocks_per_layer: Number of blocks in the latent transformer.
        :param use_gelu: Use GELU activation like the Perceiver preprint indicates. False,
               with Lucidrains' GEGLU activation in feed forward instead.
        :param cross_depth: Number of entries of ``self.blocks``.
        :param cross_cross_heads: Heads of the latent-to-latent cross-attention factory
               (that factory is built but not consumed by this constructor).
        :param recon: Accepted for interface compatibility; not used by this constructor.
        :param args: Run configuration; required despite the ``None`` default.
        :raises ValueError: if ``args`` is not supplied.
        """
        super().__init__()
        if args is None:
            # Fail with an actionable message instead of an AttributeError on ``None`` below.
            raise ValueError("MultiModalitySequenceMoETransformer requires `args` (a MultiModalityConfig)")

        self.modalities = modalities
        self.embed_size = embed_size
        # we encode modality with one hot encoding, so need one dim per modality
        # (or ``embed_size`` dims when a learned embedding is used instead):
        num_modalities = sum(1 for _ in modalities)
        modality_encoding_dim = embed_size if embed else num_modalities
        self.modality_encoding_dim = modality_encoding_dim
        # input_dim is the maximum dimension over all input modalities:
        input_dim = max(modality.input_dim for modality in modalities) + modality_encoding_dim
        self.max_modality_dim = input_dim
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        ff_type = FeedForwardGELU if use_gelu else FeedForward
        self.embed = None
        if embed:
            self.embed = torch.nn.Parameter(torch.randn(num_modalities, embed_size))

        if args.cross_attn_use_moe:
            get_cross_attn = lambda: PreNorm(
                latent_dim,
                AttentionWithPrintMoE(latent_dim,
                                      input_dim,
                                      heads=cross_heads,
                                      dim_head=cross_dim_head,
                                      dropout=attn_dropout,
                                      args=args),
                context_dim=input_dim)
        else:
            get_cross_attn = lambda: PreNorm(
                latent_dim,
                AttentionWithPrint(latent_dim, input_dim, heads=cross_heads,
                                   dim_head=cross_dim_head, dropout=attn_dropout),
                context_dim=input_dim)
        # NOTE(review): this factory is wrapped below but never handed to
        # build_perceiver_layers; kept so the constructor's call sequence is unchanged.
        get_cross_cross_attn = lambda: PreNorm(
            latent_dim,
            AttentionWithPrint(latent_dim, latent_dim, heads=cross_cross_heads,
                               dim_head=cross_dim_head, dropout=attn_dropout),
            context_dim=latent_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout))
        get_latent_attn = lambda: PreNorm(
            latent_dim,
            AttentionWithPrint(latent_dim, heads=latent_heads, dim_head=latent_dim_head,
                               dropout=attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout))

        get_cross_attn, get_cross_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
            cache_by_name_fn,
            (get_cross_attn, get_cross_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])

        build_perceiver_layers(self.layers, depth, get_cross_attn, get_cross_ff,
                               get_latent_attn, get_latent_ff,
                               weight_tie_layers,
                               num_latent_blocks_per_layer=num_latent_blocks_per_layer)
        self.to_logits = nn.Sequential(
            nn.LayerNorm(latent_dim * 2),
            nn.Linear(latent_dim * 2, num_classes)
        )
        if args.conditional_weight:
            self.weights_net = nn.ParameterList(
                [nn.Parameter(torch.tensor(-4.5, dtype=torch.float32)) for _ in range(args.num_tasks)])

        # A fresh list per instance; a list literal as the default would be shared by all instances.
        self.modalities_num = [2] if modalities_num is None else modalities_num

        # Even positions hold MoE blocks, odd positions hold vanilla blocks.  The two
        # configurations differ only in these two positional arguments, whose meaning is
        # defined by the block classes themselves.
        if args.equal_dense:
            moe_third_positional, vanilla_dim = 8, latent_dim * 2
        else:
            moe_third_positional, vanilla_dim = 2, latent_dim
        self.blocks = nn.ModuleList([
            MultiModalityTransformerBlock(latent_dim, 8, moe_third_positional, True, args=args,
                                          attn_top_k=args.attn_top_k, mlp_top_k=args.mlp_top_k)
            if i % 2 == 0 else
            VanillaTransformerBlock(vanilla_dim, 8, 2, True, args=args)
            for i in range(cross_depth)
        ])
        self.expert_info = None
        self.args = args
        if self.args.padding_prompt:
            # One learned vector per modality that fills the channels a narrower
            # modality lacks relative to the widest one.
            self.modality_prompt = nn.ParameterDict({
                md.name: nn.Parameter(
                    torch.randn(self.max_modality_dim - md.input_dim - self.modality_encoding_dim))
                for md in self.modalities
            })

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def set_expert_capacity(self, capacity):
        """Store ``capacity`` as ``args.capacity_per_expert``."""
        self.args.capacity_per_expert = capacity

    def gate_logits(self, task_idx):
        """Return ``{block_index: block.gate_logits(task_idx)}`` for blocks that expose ``gate_logits``."""
        return {
            i: block.gate_logits(task_idx)
            for i, block in enumerate(self.blocks)
            if hasattr(block, 'gate_logits')
        }

    def gate_topk_logits(self, task_idx, modality_name):
        """Return ``{block_index: block.gate_topk_logits(...)}`` for blocks that expose it."""
        return {
            i: block.gate_topk_logits(task_idx, modality_name)
            for i, block in enumerate(self.blocks)
            if hasattr(block, 'gate_topk_logits')
        }

    def gate_loss(self, task_idx = None, modalies_name = None):
        """Sum the gate losses of all expert-bearing blocks (and MoE cross-attention layers).

        :param task_idx: Forwarded to every ``gate_loss`` call.
        :param modalies_name: Forwarded to the blocks only when ``args.attn_modality_specific``
               or ``args.mlp_modality_specific`` is set; always forwarded to the cross-attention
               layers when ``args.cross_attn_use_moe`` is set.
        :return: The accumulated loss, or ``None`` when nothing contributed one.
        """
        modality_specific = self.args.attn_modality_specific or self.args.mlp_modality_specific
        g_loss = None
        for block in self.blocks:
            if not hasattr(block, 'get_expert_info'):
                continue
            if modality_specific:
                block_loss = block.gate_loss(task_idx, modalies_name)
            else:
                block_loss = block.gate_loss(task_idx)
            # Out-of-place add: never mutate a tensor that a block handed back to us.
            g_loss = block_loss if g_loss is None else g_loss + block_loss
        if self.args.cross_attn_use_moe:
            for cross_attn, _, _ in self.layers:
                layer_loss = cross_attn.fn.gate_loss(task_idx=task_idx, modality_name=modalies_name)
                # g_loss is still None here when no block above contributed.
                g_loss = layer_loss if g_loss is None else g_loss + layer_loss
        return g_loss

    def forward(self, multi_modality_data: Dict[str, Tensor], task_id = 0, mask=None, modality_topk = None):
        """
        :param multi_modality_data: a dictionary where keys are modality names and Tensor contain a batch
               of modality input data.  Modalities are processed in sorted key order.
        :param task_id: Selects the entry of ``self.modalities_num`` (and, when
               ``args.use_individual_latent_dim`` is set, the number of latents); also forwarded
               to the blocks.
        :param mask: Forwarded to every input cross-attention.
        :param modality_topk: Forwarded to blocks exposing ``set_modality_topk`` (joint path only).
        :return: ``self.to_logits`` applied to the last-latent features of the task's modalities.
        """
        batch_sizes = set()
        num_modalities = len(self.modalities)
        latentout = []
        self.attns = {}

        one_step_info = None
        if self.args.load_expert_count:
            one_step_info = {}
        else:
            self.expert_info = None

        modalities_name = sorted(multi_modality_data.keys())
        for modality_name in modalities_name:
            data = multi_modality_data[modality_name]
            modality, modality_index = findmodalityandindex(self.modalities, modality_name)
            b, *axis, _, device = *data.shape, data.device
            assert len(
                axis) == modality.input_axis, f'input data must have the right number of  for modality {modality_name}. ' \
                                              f'Expected {modality.input_axis} while forward argument offered {len(axis),data.shape,b, axis}'
            # The set is deliberately kept across iterations (never popped) so that a
            # modality with a different batch size actually trips this assertion.
            batch_sizes.add(b)
            assert len(batch_sizes) == 1, "batch size must be the same across all modalities"

            # calculate fourier encoded positions in the range of [-1, 1], for all axis
            axis_pos = [torch.linspace(-1., 1., steps=size, device=device) for size in axis]
            pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
            enc_pos = fourier_encode(pos,
                                     modality.max_freq, modality.num_freq_bands, modality.freq_base)
            enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
            enc_pos = repeat(enc_pos, '... -> b ...', b=b)

            # Padding for this modality, given max dimension across all modalities:
            leading_shape = data.shape[:-1]
            if self.args.padding_prompt:
                # Broadcast the learned prompt over every leading axis, innermost first.
                padding = self.modality_prompt[modality_name]
                for size in reversed(leading_shape):
                    padding = repeat(padding, '... -> b ...', b=size)
            else:
                padding_size = self.max_modality_dim - modality.input_dim - self.modality_encoding_dim
                padding = torch.zeros(size=leading_shape + (padding_size,), device=device)

            # concat to channels of data and flatten axis
            modality_encodings = modality_encoding(b, axis, modality_index, num_modalities,
                                                   embed=self.embed, device=device)
            data = torch.cat((data, padding, enc_pos, modality_encodings), dim=-1)
            data = rearrange(data, 'b ... d -> b (...) d')

            if self.args.use_individual_latent_dim:
                latents = self.latents[:self.args.individual_latent_dim[int(task_id)], :]
            else:
                latents = self.latents
            x = repeat(latents, 'n d -> b n d', b=b)

            for cross_attn, cross_ff, latent_transformer in self.layers:
                if self.args.cross_attn_use_moe:
                    x = cross_attn(x, context=data, mask=mask, modality_name=modality_name, task_id=task_id) + x
                else:
                    x = cross_attn(x, context=data, mask=mask) + x
                # Overwritten on every layer, so the last layer's value is what remains.
                self.attns[modality_name] = cross_attn.fn.printattn
                x = cross_ff(x) + x
                x = latent_transformer(x) + x
            latentout.append(x)

        num_task_modalities = self.modalities_num[int(task_id)]

        # ``== False`` is kept on purpose: a ``None`` flag must keep selecting the joint path.
        if self.args.co_input == False:
            task_feature = []
            for j in range(num_task_modalities):
                for i, block in enumerate(self.blocks):
                    # NOTE(review): every block receives latentout[j] rather than the previous
                    # block's output, so only the last block determines fusion_feature.
                    # Behaviour preserved as-is; confirm whether chaining was intended.
                    fusion_feature = block(latentout[j], task_id, modalities_name[j])
                    if one_step_info is not None:
                        per_modality = one_step_info.setdefault(j, {})
                        if hasattr(block, 'get_expert_info'):
                            per_modality[i] = block.get_expert_info()
                task_feature.append(fusion_feature[:, -1])
            if one_step_info is not None:
                # task_id, modality, layer, layer_name
                self.expert_info = {task_id: one_step_info}
            return self.to_logits(torch.cat(task_feature, dim=-1))

        # Joint path: all modalities go through the blocks together.
        for block in self.blocks:
            if hasattr(block, 'set_capacity'):
                block.set_capacity(latentout[0].shape[0], latentout[0].shape[1],
                                   num_task_modalities, int(task_id))

        # Stack along the token axis for cross-modality attention, else along the batch axis.
        fusion_feature = torch.cat(latentout, dim=1 if self.args.cross_modality_attn else 0)
        for block in self.blocks:
            if hasattr(block, 'set_modality_topk'):
                block.set_modality_topk(modality_topk)
            fusion_feature = block(fusion_feature, task_id, modalities_name)

        if self.args.cross_modality_attn:
            # Undo the token-axis stacking so each modality's last latent can be read below.
            B, _, D = fusion_feature.shape
            fusion_feature = fusion_feature.reshape(B, num_task_modalities, -1, D).unbind(dim=1)
            fusion_feature = torch.cat(fusion_feature, dim=0)
        task_feature = fusion_feature[:, -1]
        dim_feature = task_feature.shape[-1]
        task_feature = task_feature.reshape(num_task_modalities, -1, dim_feature).unbind(0)
        # NOTE(review): a tuple of per-modality features is passed here; presumably
        # ``to_logits`` is replaced by a head that accepts a sequence on this path
        # (e.g. SM3TaskHeads) -- confirm against the training script.
        return self.to_logits(task_feature)
        
class SM3TaskHeads(nn.Module):
    """Prediction heads for one task: a joint multi-modal head plus one head per modality.

    The joint head consumes the per-modality features concatenated on the last axis;
    each single-modality head consumes one feature and its output is scaled by that
    modality's weight.
    """

    # task name -> (number of concatenated per-modality features the joint head expects,
    #               output width).  ``None`` as width marks the 'push' task, whose width
    #               is ``push_seq_length * 2`` and whose output is reshaped to
    #               ``[-1, push_seq_length, 2]``.
    _TASK_SPECS = {
        'push': (4, None),
        'enrico': (2, 20),
        'av_mnist': (2, 10),
        'mosei': (3, 2),
        'humor': (3, 2),
        'mimic': (2, 2),
        'vt': (5, 2),
    }

    def __init__(self, task, modalities, push_seq_length, device, latent_dim=64) -> None:
        """
        :param task: One of the keys of ``_TASK_SPECS``.
        :param modalities: Modality names; stored sorted, and features passed to ``forward``
               are matched to them by position in that sorted order.
        :param push_seq_length: Sequence length of the 'push' task output; unused otherwise.
        :param device: Accepted for interface compatibility; not used by this constructor.
        :param latent_dim: Width of a single modality feature (64 reproduces the former
               hard-coded value).
        :raises ValueError: if ``task`` is not a known task name.
        """
        super().__init__()
        if task not in self._TASK_SPECS:
            # Previously an unknown task silently produced empty heads that only
            # failed later, inside forward, with a KeyError.
            raise ValueError(f"unknown task {task!r}; expected one of {sorted(self._TASK_SPECS)}")

        self.modalities = sorted(modalities)
        self.modality_dict = nn.ModuleDict()
        self.modality_weight = {mm: 1 for mm in self.modalities}

        num_joint_features, out_dim = self._TASK_SPECS[task]
        is_push = out_dim is None
        if is_push:
            out_dim = push_seq_length * 2

        def head_layers(in_dim):
            # LayerNorm -> Linear, plus the reshape to (batch, push_seq_length, 2) for 'push'.
            layers = [torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, out_dim)]
            if is_push:
                layers.append(Reshape([-1, push_seq_length, 2]))
            return layers

        # Joint layers are created before the per-modality heads, and the joint head is
        # registered after ``modality_dict`` -- both orders are kept so parameter
        # initialisation and state_dict layout match the previous implementation.
        module_list = head_layers(latent_dim * num_joint_features)
        for mm in self.modalities:
            self.modality_dict[mm] = nn.Sequential(*head_layers(latent_dim))

        self.multi_modal_head = torch.nn.Sequential(*module_list)

    def set_modality_weight(self, weights):
        """Replace the per-modality output weights with a deep copy of ``weights``."""
        self.modality_weight = copy.deepcopy(weights)

    def forward(self, feature):
        """
        :param feature: Sequence of per-modality feature tensors, ordered like ``self.modalities``.
        :return: ``(mm_result, sm_result)`` -- the joint head output and a dict mapping each
                 modality name to its weighted single-modality head output.
        """
        mm_result = self.multi_modal_head(torch.cat(feature, dim=-1))
        sm_result = {}
        for i, modality_feature in enumerate(feature):
            mm = self.modalities[i]
            sm_result[mm] = self.modality_dict[mm](modality_feature) * self.modality_weight[mm]

        return mm_result, sm_result
        
class CNNFeatures(nn.Module):
    """Per-modality convolutional feature extractor.

    Each modality gets its own stack: Conv2d(3->32, k=3), MaxPool2d(3),
    Conv2d(32->32, k=3), MaxPool2d(3).
    """

    def __init__(self, modalities) -> None:
        """
        :param modalities: Iterable of modality names; one extractor is built per name.
        """
        super().__init__()
        extractors = {}
        for key in modalities:
            extractors[key] = nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3),
                nn.MaxPool2d(3),
                nn.Conv2d(32, 32, kernel_size=3),
                nn.MaxPool2d(3),
            )
        self.features = nn.ModuleDict(extractors)

    def forward(self, multi_modality_data: Dict[str, Tensor]):
        """
        Replace every entry of ``multi_modality_data`` with its extracted features.

        The dictionary is updated in place (in sorted key order) and the same
        dictionary object is returned.
        """
        for name in sorted(multi_modality_data):
            extractor = self.features[name]
            multi_modality_data[name] = extractor(multi_modality_data[name])
        return multi_modality_data
    

