""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from functools import partial

from .task_moe import TaskMoE

import torch
import torch.nn as nn
from functools import partial

from .task_moe import TaskMoE
from .task_moe_refine import TaskMoEGate, TaskMoEFFN, ParallelExperts
import math


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: input tensor; dimension 0 is treated as the per-sample dimension.
        drop_prob: probability of dropping a sample's path. ``None`` or ``0`` disables
            dropping (``DropPath`` defaults to ``None``).
        training: dropping only happens when True.

    Returns:
        ``x`` itself when dropping is disabled; otherwise a tensor in which each sample
        is either zeroed or scaled by ``1 / (1 - drop_prob)``. With ``drop_prob == 1``
        every sample is zeroed (no division by zero).
    """
    # `not drop_prob` also covers None, which previously crashed on `1 - None`.
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One random value per sample, broadcast over all remaining dims (works for any rank).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob == 0.0:
        # Every path is dropped; dividing by keep_prob would yield inf * 0 = NaN.
        return x * random_tensor
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Module form of ``drop_path`` (stochastic depth on the residual main path).

    The drop probability is fixed when the module is built; whether any path is
    actually dropped follows the module's ``training`` flag.

    Args:
        drop_prob: per-sample probability of dropping the path (``None`` by default).
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional implementation, gated by train/eval mode.
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)

class Mlp(nn.Module):
    """Two-layer feed-forward network used inside transformer blocks.

    Computes ``drop(fc2(drop(act(fc1(x)))))``. ``hidden_features`` and
    ``out_features`` fall back to ``in_features`` when not given (or falsy), and
    the same dropout module is applied after both layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention (standard ViT form).

    Args:
        dim: embedding size of each token; split evenly across heads
            (assumed divisible by ``num_heads``; the reshape in ``forward`` fails otherwise).
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: optional override of the attention-logit scale; when ``None``
            (or 0) the usual ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # qk_scale used to be accepted but silently ignored; honour it the same way
        # MixtureOfExpertsAttention does. Default (None) behaviour is unchanged.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [B, N, C]; returns [B, N, C]."""
        B, N, C = x.shape
        # [B, N, 3C] -> [3, B, num_heads, N, C // num_heads]
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # [B, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # [B, num_heads, N, C // num_heads] -> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x



class MixtureOfExpertsAttention(nn.Module):
    """Mixture-of-experts self-attention with its own task-conditioned router.

    The router (``TaskMoEGate``) and the expert projections live in one module.
    After the gate is called, its routing decision is read from the gate's
    ``batch_index``, ``expert_size`` and ``batch_gates`` attributes.

    Each expert projects a token to Q/K/V of width ``dim // k`` each; values are
    split into ``num_heads`` heads of width ``dim // k // num_heads`` and each head is
    mapped back to ``dim`` by the expert's ``v_out``.

    Args:
        dim: token embedding size.
        k: experts selected per token (also divides the per-expert QKV width).
        num_experts: number of experts.
        task_num: number of tasks, forwarded to the gate.
        expert_bias: forwarded to ``ParallelExperts``.
        num_heads: attention heads per expert.
        qkv_bias: accepted for signature compatibility; not used.
        qk_scale: optional override of the attention scale (default ``head_dim ** -0.5``).
        attn_drop: dropout on attention weights.
        proj_drop: dropout on the projected values.
        w_MI, w_H, w_finetune_MI, noisy_gating: forwarded to ``TaskMoEGate``.
    """

    def __init__(
        self,
        dim,
        k,
        num_experts,
        task_num,
        expert_bias,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        w_MI=0, w_H=0,
        w_finetune_MI=0,
        noisy_gating=True,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads // k
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        # Registered first so the module/parameter order matches previous versions
        # (a dense nn.Linear used to occupy this slot before being overwritten).
        self.qkv = ParallelExperts(num_experts, dim, (dim // k) * 3, expert_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.gating_network = TaskMoEGate(dim,
                                          k, num_experts, task_num,
                                          w_MI, w_H, w_finetune_MI,
                                          noisy_gating)

        self.v_out = ParallelExperts(num_experts, dim // k // num_heads, dim, expert_bias)
        # proj_drop was previously accepted but ignored; applied like AttentionExpert does.
        self.proj_drop = nn.Dropout(proj_drop)

        self.dim = dim
        self.k = k
        self.num_experts = num_experts

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Route tokens of ``x`` [B, N, C] to experts and attend.

        Returns:
            ``(out, loss, probs)`` where ``out`` is [B, N, dim] and ``loss``/``probs``
            are whatever the gate returns.
        """
        B, N, C = x.shape
        head_dim = self.dim // self.k // self.num_heads

        x = x.reshape(-1, C)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)

        # The gate stores its routing decision on itself; read it only after this call.
        loss, probs = self.gating_network(x, task_bh, skip_mask, sample_topk)
        batch_index = self.gating_network.batch_index
        expert_size = self.gating_network.expert_size
        batch_gates = self.gating_network.batch_gates

        # Rows of expert_inputs are grouped by expert: the first expert_size[0] rows
        # belong to expert 0, the next expert_size[1] rows to expert 1, and so on.
        expert_inputs = x[batch_index]
        qkv = self.qkv(expert_inputs, expert_size)
        q, k, v = qkv.split(self.dim // self.k, dim=-1)

        # weight the values by their gate
        v = v * batch_gates[:, None]

        # Splitting every routed row into num_heads rows keeps the per-expert grouping,
        # so expert e now owns expert_size[e] * num_heads consecutive rows. The previous
        # repeat_interleave produced num_experts * num_heads group sizes, which does not
        # match the num_experts groups of self.v_out.
        head_expert_size = expert_size * self.num_heads
        v = v.reshape(-1, head_dim)
        v_out = self.v_out(v, head_expert_size)
        v_out = v_out.reshape(-1, self.dim * self.num_heads)
        v_out = self.proj_drop(v_out)

        # Restore token order. The reshapes assume every token was routed to exactly k experts.
        sorted_indices = batch_index.argsort()
        sorted_v_out = v_out[sorted_indices]
        sorted_v_out = sorted_v_out.reshape(B, N, self.k, self.num_heads, self.dim).permute(0, 2, 3, 1, 4)

        sorted_q = q[sorted_indices]
        sorted_q = sorted_q.reshape(B, N, self.k, self.num_heads, head_dim).permute(0, 2, 3, 1, 4)
        sorted_k = k[sorted_indices]
        sorted_k = sorted_k.reshape(B, N, self.k, self.num_heads, head_dim).permute(0, 2, 3, 1, 4)

        # attn [B, self.k, self.num_heads, N, N]
        attn = (sorted_q @ sorted_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # x [B, self.k, self.num_heads, N, self.dim]
        x = attn @ sorted_v_out

        # sum over expert slots and heads -> [B, N, self.dim]
        x = x.sum(1).sum(1)

        return x, loss, probs
    

class AttentionExpert(nn.Module):
    """Pool of attention experts evaluated on router-selected tokens.

    Each expert owns a QKV projection ``dim -> 3 * (dim // expert_dim_divisor)`` and a
    per-head value projection ``dim // expert_dim_divisor // num_heads -> dim``, both
    held in ``ParallelExperts`` containers. Routing is computed elsewhere and passed
    to :meth:`forward` as ``gate_infos``.

    Args:
        dim: token embedding size.
        k: experts selected per token (used to reshape routed rows and to estimate
            active parameters).
        expert_dim_divisor: each expert's Q/K/V width is ``dim // expert_dim_divisor``.
        num_experts: number of experts in the pool.
        expert_bias: forwarded to ``ParallelExperts``.
        num_heads: attention heads per expert.
        qkv_bias: accepted for signature compatibility; not used.
        qk_scale: optional override of the attention scale; defaults to ``head_dim ** -0.5``.
        attn_drop: dropout on attention weights.
        proj_drop: dropout on the projected values.
    """

    def __init__(
        self,
        dim,
        k,
        expert_dim_divisor,
        num_experts,
        expert_bias,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads // expert_dim_divisor
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights.
        # qk_scale used to be ignored here; honour it like MixtureOfExpertsAttention does.
        self.scale = qk_scale or head_dim**-0.5

        self.attn_drop = nn.Dropout(attn_drop)
        self.qkv = ParallelExperts(num_experts, dim, (dim // expert_dim_divisor) * 3, expert_bias)
        self.v_out = ParallelExperts(num_experts, dim // expert_dim_divisor // num_heads, dim, expert_bias)

        self.k = k
        self.dim = dim
        self.expert_dim_divisor = expert_dim_divisor
        self.num_experts = num_experts

        self.proj_drop = nn.Dropout(proj_drop)

    def count_parameters(self):
        """Return total parameters and an estimate of those active per token (k of num_experts)."""
        total_params = sum(p.numel() for p in self.parameters())
        active_params = total_params / self.num_experts * self.k
        return {
            'total_params': total_params,
            'active_params': active_params
        }

    def add_experts(self, n_extra_experts, freeze_old_experts=True):
        """Grow both expert projections by ``n_extra_experts`` experts."""
        self.qkv.add_experts(n_extra_experts, freeze_old_experts)
        self.v_out.add_experts(n_extra_experts, freeze_old_experts)
        # keep the bookkeeping in sync so count_parameters divides by the real expert count
        self.num_experts += n_extra_experts

    def forward(self, x, gate_infos):
        """Run routed attention.

        Args:
            x: flattened tokens, [B*N, dim] (callers in this file pass ``x.reshape(-1, C)``).
            gate_infos: dict with ``batch_index``, ``expert_size``, ``batch_gates``
                (from the router) plus ``batch_size`` and ``token_length``.

        Returns:
            Tensor of shape [B, N, dim].
        """
        batch_index = gate_infos['batch_index']
        expert_size = gate_infos['expert_size']
        batch_gates = gate_infos['batch_gates']
        B = gate_infos['batch_size']
        N = gate_infos['token_length']
        head_dim = self.dim // self.expert_dim_divisor // self.num_heads

        # expert inputs: B*N*k, C. The first dimension is the batch index for expert 1, expert 2, ...
        # the number of inputs for expert 1, expert 2, ... is given by expert_size
        expert_inputs = x[batch_index]
        qkv = self.qkv(expert_inputs, expert_size)
        q, k, v = qkv.split(self.dim // self.expert_dim_divisor, dim=-1)

        # multiply by gates
        v = v * batch_gates[:, None]
        # each routed row becomes num_heads rows, so expert e owns expert_size[e] * num_heads rows
        head_expert_size = expert_size * self.num_heads
        v = v.reshape(-1, head_dim)
        v_out = self.v_out(v, head_expert_size)
        v_out = v_out.reshape(-1, self.dim * self.num_heads)
        v_out = self.proj_drop(v_out)

        # restore token order; the reshapes assume every token was routed to exactly k experts
        sorted_indices = batch_index.argsort()
        sorted_v_out = v_out[sorted_indices]
        sorted_v_out = sorted_v_out.reshape(B, N, self.k, self.num_heads, self.dim).permute(0, 2, 3, 1, 4)

        sorted_q = q[sorted_indices]
        sorted_q = sorted_q.reshape(B, N, self.k, self.num_heads, head_dim).permute(0, 2, 3, 1, 4)
        sorted_k = k[sorted_indices]
        sorted_k = sorted_k.reshape(B, N, self.k, self.num_heads, head_dim).permute(0, 2, 3, 1, 4)

        # attn [B, self.k, self.num_heads, N, N]
        attn = (sorted_q @ sorted_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # x [B, self.k, self.num_heads, N, self.dim]
        x = attn @ sorted_v_out

        # sum over expert slots and heads -> [B, N, self.dim]
        x = x.sum(1).sum(1)

        return x

    def merge_experts(self):
        """
        Merges new_weight into weight and removes new_weight.
        Should be called after loading the checkpoint of continual training.
        """
        self.qkv.merge_experts()
        self.v_out.merge_experts()


class MoEAttn(nn.Module):
    """Task-routed mixture-of-experts attention layer.

    Pairs a ``TaskMoEGate`` router with an ``AttentionExpert`` pool. ``forward``
    flattens the input to one row per token, asks the router for a routing
    decision, and lets the expert pool attend over the routed rows. No residual
    connection or normalization is applied here.

    Every constructor argument is also kept as an attribute of the same name.
    """

    def __init__(
        self,
        dim,
        k,
        num_experts,
        task_num,
        expert_bias,
        expert_dim_divisor,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        w_MI=0, w_H=0,
        w_finetune_MI=0,
        noisy_gating=True,
    ):
        super().__init__()

        # hyper-parameters, stored under their constructor names
        self.dim, self.k, self.num_experts, self.task_num = dim, k, num_experts, task_num
        self.expert_bias, self.expert_dim_divisor = expert_bias, expert_dim_divisor
        self.num_heads, self.qkv_bias, self.qk_scale = num_heads, qkv_bias, qk_scale
        self.attn_drop, self.proj_drop = attn_drop, proj_drop
        self.w_MI, self.w_H, self.w_finetune_MI = w_MI, w_H, w_finetune_MI
        self.noisy_gating = noisy_gating

        # router first, expert pool second (registration order is preserved)
        self.gating_network = TaskMoEGate(
            dim=dim,
            k=k,
            num_experts=num_experts,
            task_num=task_num,
            w_MI=w_MI,
            w_H=w_H,
            w_finetune_MI=w_finetune_MI,
            noisy_gating=noisy_gating,
        )
        self.attention_expert = AttentionExpert(
            dim=dim,
            k=k,
            expert_dim_divisor=expert_dim_divisor,
            num_experts=num_experts,
            expert_bias=expert_bias,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Route ``x`` [B, N, C] through the attention experts.

        Returns:
            ``(out, loss, probs)``: ``out`` is [B, N, dim]; ``loss`` and ``probs``
            are passed through from the router.
        """
        batch_size, token_length, channels = x.shape
        tokens = x.reshape(-1, channels)
        mask = None if skip_mask is None else skip_mask.view(-1, 1)

        loss, probs = self.gating_network(tokens, task_bh, mask, sample_topk)

        gate_infos = self.gating_network.get_gate_infos()
        gate_infos.update({'batch_size': batch_size, 'token_length': token_length, 'dim': channels})

        out = self.attention_expert(tokens, gate_infos)
        return out, loss, probs
    
    
class FfnExpert(nn.Module):
    """Pool of two-layer feed-forward experts evaluated on router-selected tokens.

    Each expert is ``dim -> hidden_dim // expert_dim_divisor -> dim`` with an
    activation in between; the two layers are held in ``ParallelExperts``
    containers. Routing is computed elsewhere and passed to :meth:`forward`
    as ``gate_infos``.

    Args:
        dim: token embedding size (input and output width).
        hidden_dim: hidden width of the equivalent dense FFN; each expert uses
            ``hidden_dim // expert_dim_divisor``.
        expert_dim_divisor: divisor applied to ``hidden_dim`` per expert.
        num_experts: number of experts.
        task_num: stored only; not used by this module.
        expert_bias: forwarded to ``ParallelExperts``.
        expert_activation: activation class, instantiated once and shared by all experts.
    """

    def __init__(self,  dim, hidden_dim, expert_dim_divisor,
                 num_experts, task_num, expert_bias, expert_activation = nn.GELU):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.expert_dim_divisor = expert_dim_divisor
        self.num_experts = num_experts
        self.task_num = task_num
        self.expert_bias = expert_bias

        self.experts = ParallelExperts(self.num_experts, self.dim, self.hidden_dim // self.expert_dim_divisor, self.expert_bias)
        self.expert_activation = expert_activation()
        self.output_experts = ParallelExperts(self.num_experts, self.hidden_dim // self.expert_dim_divisor, self.dim, self.expert_bias)

    def count_parameters(self):
        """Return the total parameter count; ``active_params`` is left at 0 (callers estimate it)."""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'total_params': total_params,
            'active_params': 0
        }

    def forward(self, x, gate_infos):
        """Run the routed FFN experts and scatter-add their gated outputs per token.

        Args:
            x: flattened tokens, [B*N, dim] (callers in this file pass ``x.reshape(-1, C)``).
            gate_infos: dict with ``batch_index``, ``expert_size``, ``batch_gates``
                (from the router) plus ``batch_size`` and ``token_length``.

        Returns:
            Tensor of shape [B, N, dim].
        """
        batch_index = gate_infos['batch_index']
        expert_size = gate_infos['expert_size']
        batch_gates = gate_infos['batch_gates']
        B = gate_infos['batch_size']
        N = gate_infos['token_length']

        # Routing layout:
        #   expert_size[e] - number of routed rows handled by expert e
        #   batch_index    - token index of each routed row, grouped by expert
        # Example with 3 experts: expert 0 gets tokens [0, 1], expert 1 gets [0, 1, 2],
        # expert 2 gets [2]  ->  expert_size = [2, 3, 1], batch_index = [0, 1, 0, 1, 2, 2]
        expert_inputs = x[batch_index]
        h = self.experts(expert_inputs, expert_size)
        h = self.expert_activation(h)
        expert_outputs = self.output_experts(h, expert_size)

        # Optional per-expert LayerNorm (enabled via add_exp_layernorm): normalize each
        # expert's contiguous slice of rows with that expert's own LayerNorm.
        if hasattr(self, 'exp_layernorm'):
            chunks = torch.split(expert_outputs, expert_size.tolist(), dim=0)
            expert_outputs = torch.cat(
                [self.exp_layernorm[i](chunk) for i, chunk in enumerate(chunks)], dim=0
            )

        expert_outputs = expert_outputs * batch_gates[:, None]

        # Sum the k gated expert outputs belonging to each token.
        zeros = torch.zeros((B * N, self.dim),
            dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, batch_index, expert_outputs)
        y = y.view(B, N, self.dim)

        return y

    def add_experts(self, n_extra_experts, freeze_old_experts=True):
        """Grow both expert layers by ``n_extra_experts`` experts (no-op when <= 0)."""
        if n_extra_experts > 0:
            self.experts.add_experts(n_extra_experts, freeze_old_experts)
            self.output_experts.add_experts(n_extra_experts, freeze_old_experts)
            # keep bookkeeping in sync so add_exp_layernorm builds one norm per expert
            self.num_experts += n_extra_experts

    def merge_experts(self):
        """
        Merges new_weight into weight and removes new_weight.
        Should be called after loading the checkpoint of continual training.
        """
        self.experts.merge_experts()
        self.output_experts.merge_experts()

    def add_exp_layernorm(self):
        """Attach one LayerNorm(dim) per expert, applied to that expert's outputs in forward."""
        self.exp_layernorm = nn.ModuleList([nn.LayerNorm(self.dim) for _ in range(self.num_experts)])



class MoEFfn(nn.Module):
    """Task-routed mixture-of-experts feed-forward layer.

    Pairs a ``TaskMoEGate`` router with an ``FfnExpert`` pool. ``forward``
    flattens the input to one row per token, routes the rows, runs the experts
    and returns ``(output, loss, probs)``. No residual connection is applied here.

    ``num_heads`` is stored for interface parity but not used by this module.
    """

    def __init__(
        self,
        dim,
        hidden_dim,
        k,
        num_experts,
        task_num,
        expert_bias,
        expert_dim_divisor,
        num_heads=8,
        w_MI=0, w_H=0,
        w_finetune_MI=0,
        noisy_gating=True,
    ):
        super().__init__()

        # hyper-parameters, stored under their constructor names
        self.dim, self.hidden_dim = dim, hidden_dim
        self.k, self.num_experts, self.task_num = k, num_experts, task_num
        self.expert_bias, self.expert_dim_divisor = expert_bias, expert_dim_divisor
        self.num_heads = num_heads
        self.w_MI, self.w_H, self.w_finetune_MI = w_MI, w_H, w_finetune_MI
        self.noisy_gating = noisy_gating

        # router first, expert pool second (registration order is preserved)
        self.gating_network = TaskMoEGate(
            dim=dim,
            k=k,
            num_experts=num_experts,
            task_num=task_num,
            w_MI=w_MI,
            w_H=w_H,
            w_finetune_MI=w_finetune_MI,
            noisy_gating=noisy_gating,
        )
        self.ffn_expert = FfnExpert(
            dim=dim,
            hidden_dim=hidden_dim,
            expert_dim_divisor=expert_dim_divisor,
            num_experts=num_experts,
            task_num=task_num,
            expert_bias=expert_bias,
        )

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Route ``x`` [B, N, C] through the FFN experts.

        Returns:
            ``(out, loss, probs)``: ``out`` is [B, N, dim]; ``loss`` and ``probs``
            are passed through from the router.
        """
        batch_size, token_length, channels = x.shape
        tokens = x.reshape(-1, channels)
        mask = None if skip_mask is None else skip_mask.view(-1, 1)

        loss, probs = self.gating_network(tokens, task_bh, mask, sample_topk)

        gate_infos = self.gating_network.get_gate_infos()
        gate_infos.update({'batch_size': batch_size, 'token_length': token_length, 'dim': channels})

        out = self.ffn_expert(tokens, gate_infos)
        return out, loss, probs


class MoETransformerBlock(nn.Module):
    """Transformer block whose attention and FFN sub-layers are both mixtures of experts.

    Uses pre-norm residual connections (the skip path carries the un-normalized input)::

        x = x + drop_path(attention_expert(attn_norm(x)))   # routed by attn_gating_network
        x = x + drop_path(ffn_expert(ffn_norm(x)))          # routed by ffn_gating_network

    When ``shared_routers`` is True no FFN router is built: the FFN experts reuse the
    attention router's gate infos, and the returned FFN loss/probs are ``None``.
    NOTE(review): with shared routers the FFN pool receives the attention router's
    ``expert_size``, so presumably ``attn_num_experts == ffn_num_experts`` is required — confirm.

    Args:
        dim: token embedding size.
        attn_k, ffn_k: experts selected per token for attention / FFN.
        attn_num_experts, ffn_num_experts: expert pool sizes.
        task_num: number of tasks, forwarded to the routers.
        attn_expert_bias, ffn_expert_bias: forwarded to the expert pools.
        attn_expert_dim_divisor, ffn_expert_dim_divisor: per-expert width divisors.
        ffn_hidden_dim: hidden width of the equivalent dense FFN.
        shared_routers: reuse the attention router for the FFN experts.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: attention-expert options.
        w_MI, w_H, w_finetune_MI, noisy_gating: router options.
        drop_path: stochastic-depth probability for both residual branches.
        norm_layer: normalization class used for ``attn_norm`` and ``ffn_norm``.
    """

    def __init__(
        self,
        dim,
        attn_k,
        ffn_k,
        attn_num_experts,
        ffn_num_experts,
        task_num,
        attn_expert_bias,
        ffn_expert_bias,
        attn_expert_dim_divisor,
        ffn_expert_dim_divisor,
        ffn_hidden_dim,
        shared_routers=False,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        w_MI=0, w_H=0,
        w_finetune_MI=0,
        noisy_gating=True,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs,
    ):
        super().__init__()

        self.dim = dim
        self.attn_k = attn_k
        self.ffn_k = ffn_k
        self.attn_num_experts = attn_num_experts
        self.ffn_num_experts = ffn_num_experts
        self.task_num = task_num
        self.attn_expert_bias = attn_expert_bias
        self.ffn_expert_bias = ffn_expert_bias
        self.attn_expert_dim_divisor = attn_expert_dim_divisor
        self.ffn_expert_dim_divisor = ffn_expert_dim_divisor
        self.ffn_hidden_dim = ffn_hidden_dim
        self.shared_routers = shared_routers
        self.num_heads = num_heads
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop
        self.w_MI = w_MI
        self.w_H = w_H
        self.w_finetune_MI = w_finetune_MI
        self.noisy_gating = noisy_gating
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.attn_norm = norm_layer(dim)
        self.ffn_norm = norm_layer(dim)

        # Ensure dim is divisible by (attn_expert_dim_divisor * num_heads) and ffn_expert_dim_divisor
        assert self.dim % (self.attn_expert_dim_divisor * self.num_heads) == 0, (
            f"dim ({self.dim}) must be divisible by "
            f"(attn_expert_dim_divisor ({self.attn_expert_dim_divisor}) * num_heads ({self.num_heads}))."
        )

        assert self.dim % self.ffn_expert_dim_divisor == 0, (
            f"dim ({self.dim}) must be divisible by ffn_expert_dim_divisor ({self.ffn_expert_dim_divisor})."
        )

        self.attn_gating_network = TaskMoEGate(self.dim,
                                               self.attn_k,
                                               self.attn_num_experts,
                                               self.task_num,
                                               self.w_MI,
                                               self.w_H,
                                               self.w_finetune_MI,
                                               self.noisy_gating)

        if not self.shared_routers:
            self.ffn_gating_network = TaskMoEGate(self.dim,
                                                  self.ffn_k,
                                                  self.ffn_num_experts,
                                                  self.task_num,
                                                  self.w_MI,
                                                  self.w_H,
                                                  self.w_finetune_MI,
                                                  self.noisy_gating)

        self.attention_expert = AttentionExpert(self.dim,
                                                self.attn_k,
                                                self.attn_expert_dim_divisor,
                                                self.attn_num_experts,
                                                self.attn_expert_bias,
                                                self.num_heads,
                                                self.qkv_bias,
                                                self.qk_scale,
                                                self.attn_drop,
                                                self.proj_drop)

        self.ffn_expert = FfnExpert(self.dim,
                                    self.ffn_hidden_dim,
                                    self.ffn_expert_dim_divisor,
                                    self.ffn_num_experts,
                                    self.task_num,
                                    self.ffn_expert_bias)

    def count_parameters(self):
        """Return ``{'total_params', 'active_params'}`` summed over routers and expert pools."""
        params = {'total_params': 0, 'active_params': 0}

        components = [self.attn_gating_network]
        if not self.shared_routers:
            components.append(self.ffn_gating_network)
        components.append(self.attention_expert)

        for component in components:
            counts = component.count_parameters()
            params['total_params'] += counts['total_params']
            params['active_params'] += counts['active_params']

        # FfnExpert reports active_params=0, so estimate it from its total and ffn_k.
        ffn_total = self.ffn_expert.count_parameters()['total_params']
        params['total_params'] += ffn_total
        params['active_params'] += ffn_total / self.ffn_num_experts * self.ffn_k

        return params

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Apply the MoE attention and MoE FFN sub-layers to ``x`` [B, N, C].

        Returns:
            ``(x, attn_loss, attn_probs, ffn_loss, ffn_probs)`` where ``x`` is [B, N, C];
            the FFN loss/probs are ``None`` when ``shared_routers`` is True.
        """
        B, N, C = x.shape
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)

        # --- attention sub-layer: normalize a copy, keep x as the residual stream ---
        normed = self.attn_norm(x).reshape(-1, C)
        attn_loss, attn_probs = self.attn_gating_network(normed, task_bh, skip_mask, sample_topk)
        attn_gate_infos = self.attn_gating_network.get_gate_infos()
        attn_gate_infos.update({'batch_size': B, 'token_length': N, 'dim': C})
        x = x + self.drop_path(self.attention_expert(normed, attn_gate_infos))

        # --- FFN sub-layer ---
        normed = self.ffn_norm(x).reshape(-1, C)
        if self.shared_routers:
            ffn_gate_infos = attn_gate_infos
            ffn_loss = None
            ffn_probs = None
        else:
            ffn_loss, ffn_probs = self.ffn_gating_network(normed, task_bh, skip_mask, sample_topk)
            ffn_gate_infos = self.ffn_gating_network.get_gate_infos()
            ffn_gate_infos.update({'batch_size': B, 'token_length': N, 'dim': C})
        x = x + self.drop_path(self.ffn_expert(normed, ffn_gate_infos))

        return x, attn_loss, attn_probs, ffn_loss, ffn_probs

    def add_moe_experts(self, n_new_tasks, n_new_experts, freeze_old=True):
        """Grow routers (attention, plus FFN when not shared) and both expert pools."""
        self.add_attn_routers(n_new_tasks, n_new_experts, freeze_old)
        if not self.shared_routers:
            self.add_ffn_routers(n_new_tasks, n_new_experts, freeze_old)

        self.add_attention_experts(n_new_experts, freeze_old)
        self.add_ffn_experts(n_new_experts, freeze_old)

    def add_attention_experts(self, n_new_experts, freeze_old=True):
        self.attention_expert.add_experts(n_new_experts, freeze_old)
        self.attn_num_experts += n_new_experts

    def add_ffn_experts(self, n_new_experts, freeze_old=True):
        self.ffn_expert.add_experts(n_new_experts, freeze_old)
        # FfnExpert ignores non-positive counts, so only track real growth
        if n_new_experts > 0:
            self.ffn_num_experts += n_new_experts

    def add_attn_routers(self, n_new_tasks, n_new_experts, freeze_old=True):
        self.attn_gating_network.add_gates(n_new_tasks=n_new_tasks, n_new_experts=n_new_experts, freeze_old=freeze_old)

    def add_ffn_routers(self, n_new_tasks, n_new_experts, freeze_old=True):
        self.ffn_gating_network.add_gates(n_new_tasks=n_new_tasks, n_new_experts=n_new_experts, freeze_old=freeze_old)

    def merge_experts(self):
        """
        Merges new_weight into weight and removes new_weight.
        Should be called after loading the checkpoint of continual training.
        """
        self.attention_expert.merge_experts()
        self.ffn_expert.merge_experts()
        
        
class MoEFfnBlock(nn.Module):
    """Pre-norm Transformer block: dense self-attention followed by a task-routed MoE FFN.

    The attention sub-layer is a regular ``Attention`` module. The FFN sub-layer is a
    ``TaskMoEGate`` gating network feeding an ``FfnExpert`` expert bank. Both sub-layers
    are residual branches wrapped in the same stochastic-depth module (``drop_path``).

    Args:
        dim: Token embedding width ``C``; must be divisible by ``ffn_expert_dim_divisor``.
        ffn_k: Top-k value passed to the gate; ``count_parameters`` treats this many
            experts as active.
        ffn_num_experts: Number of FFN experts created at construction time.
        task_num: Number of tasks, forwarded to the gate and the expert bank.
        ffn_expert_bias: Forwarded to ``FfnExpert``.
        ffn_expert_dim_divisor: Forwarded to ``FfnExpert``; ``dim`` must be divisible by it.
        ffn_hidden_dim: Forwarded to ``FfnExpert``.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: Forwarded to ``Attention``.
        w_MI, w_H, w_finetune_MI, noisy_gating: Forwarded to ``TaskMoEGate``.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        norm_layer: Normalization factory used for both pre-norms.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        dim,
        ffn_k, 
        ffn_num_experts, 
        task_num, 
        ffn_expert_bias,
        ffn_expert_dim_divisor,
        ffn_hidden_dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0, 
        w_MI=0, w_H=0, 
        w_finetune_MI=0, 
        noisy_gating=True,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs,
    ):
        super().__init__()

        self.dim = dim
        self.ffn_k = ffn_k
        self.ffn_num_experts = ffn_num_experts
        self.task_num = task_num
        self.ffn_expert_bias = ffn_expert_bias
        self.ffn_expert_dim_divisor = ffn_expert_dim_divisor
        self.ffn_hidden_dim = ffn_hidden_dim
        self.num_heads = num_heads
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.attn_drop = attn_drop
        self.proj_drop = proj_drop
        self.w_MI = w_MI
        self.w_H = w_H
        self.w_finetune_MI = w_finetune_MI
        self.noisy_gating = noisy_gating
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.attn_norm = norm_layer(dim)
        self.ffn_norm = norm_layer(dim)
        
        # The FFN expert bank splits `dim` by ffn_expert_dim_divisor, so it must divide evenly.
        assert self.dim % self.ffn_expert_dim_divisor == 0, (
            f"dim ({self.dim}) must be divisible by ffn_expert_dim_divisor ({self.ffn_expert_dim_divisor})."
        )
        
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

        self.ffn_gating_network = TaskMoEGate(self.dim, 
                                                self.ffn_k, 
                                                self.ffn_num_experts, 
                                                self.task_num, 
                                                self.w_MI, 
                                                self.w_H, 
                                                self.w_finetune_MI, 
                                                self.noisy_gating)

        self.ffn_expert = FfnExpert(self.dim, 
                                    self.ffn_hidden_dim,
                                    self.ffn_expert_dim_divisor, 
                                    self.ffn_num_experts, 
                                    self.task_num, 
                                    self.ffn_expert_bias)
        
    def count_parameters(self):
        """Return ``{'total_params', 'active_params'}`` for the gate, experts and attention.

        Active expert parameters are estimated as ``total / ffn_num_experts * ffn_k``, i.e.
        assuming all experts have the same size. Norm-layer parameters are not counted.

        NOTE(review): ``add_moe_experts`` does not update ``ffn_num_experts`` or ``ffn_k``,
        so after expansion this estimate still uses the construction-time counts — confirm
        whether that is intended.
        """
        # Query each sub-module once instead of once per dictionary key.
        gate_counts = self.ffn_gating_network.count_parameters()
        expert_counts = self.ffn_expert.count_parameters()
        attn_params = sum(p.numel() for p in self.attn.parameters())

        params = {
            'total_params': 0,
            'active_params': 0
        }
        params['total_params'] += gate_counts['total_params']
        params['active_params'] += gate_counts['active_params']
        params['total_params'] += expert_counts['total_params']
        params['active_params'] += expert_counts['total_params'] / self.ffn_num_experts * self.ffn_k

        # Attention is dense: every parameter is active.
        params['total_params'] += attn_params
        params['active_params'] += attn_params
        return params

    def get_gate_probs(self):
        """Return the FFN gating network's ``get_probs()`` result."""
        return self.ffn_gating_network.get_probs()
    
    def get_gate_clean_logits(self):
        """Return the FFN gating network's ``get_clean_logits()`` result."""
        return self.ffn_gating_network.get_clean_logits()
    
    def get_gate_raw_noise_stddev(self):
        """Return the FFN gating network's ``get_raw_noise_stddev()`` result."""
        return self.ffn_gating_network.get_raw_noise_stddev()
    
    def forward(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Apply dense attention, then the task-routed MoE FFN.

        Args:
            x: Token tensor unpacked as ``(B, N, C)``.
            task_bh: Task selector forwarded to the gating network (semantics defined by
                ``TaskMoEGate``).
            skip_mask: Optional mask; reshaped to ``(-1, 1)`` before being passed to the gate.
            sample_topk: Forwarded to the gating network.

        Returns:
            ``(x, None, None, ffn_loss, ffn_probs)``. The two ``None`` slots stand in for the
            attention-MoE loss/probs so the tuple has the same 5-element layout as blocks
            with a routed attention layer.
        """
        B, N, C = x.shape
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)
        
        # Attention sub-layer (dense, pre-norm residual).
        x = x + self.drop_path(self.attn(self.attn_norm(x)))
       
        # FFN sub-layer: the gate and experts operate on a flat (B*N, C) token matrix.
        norm_x = self.ffn_norm(x)
        flatten_x = norm_x.reshape(-1, C)
        
        ffn_loss, ffn_probs = self.ffn_gating_network(flatten_x, task_bh, skip_mask, sample_topk)
        # Routing info is read back from the gate after the call and augmented with the
        # original shape; the expert output is added to `x` directly below, so it has to
        # come back broadcastable to (B, N, C).
        ffn_gate_infos = self.ffn_gating_network.get_gate_infos()
        ffn_gate_infos.update({'batch_size': B, 'token_length': N, 'dim': C})
        
        ffn_output = self.ffn_expert(flatten_x, ffn_gate_infos)
        x = x + self.drop_path(ffn_output)
        
        return x, None, None, ffn_loss, ffn_probs
    
    def add_moe_experts(self, n_new_tasks, n_new_experts, freeze_old=True, noisy_gating=False, topk=-1):
        """Extend the FFN routers for new tasks/experts, then add the new FFN experts."""
        self.add_ffn_routers(n_new_tasks, n_new_experts, freeze_old, noisy_gating, topk)
        self.add_ffn_experts(n_new_experts, freeze_old)
        
    def add_ffn_experts(self, n_new_experts, freeze_old=True):
        """Delegate expert growth to the FFN expert bank's ``add_experts``."""
        self.ffn_expert.add_experts(n_new_experts, freeze_old)
        
    def add_ffn_routers(self, n_new_tasks, n_new_experts, freeze_old=True, noisy_gating=False, topk=-1):
        """Extend the FFN gating network; ``topk`` is forwarded as the gate's ``k``."""
        self.ffn_gating_network.add_gates(n_new_tasks=n_new_tasks, n_new_experts=n_new_experts, freeze_old=freeze_old, noisy_gating=noisy_gating, k=topk)

    def merge_experts(self):
        """
        Merges new_weight into weight and removes new_weight.
        Should be called after loading the checkpoint of continual training.
        """
        self.ffn_expert.merge_experts()
        
    def unfreeze_norm(self):
        """Re-enable gradients for both pre-norm layers (``attn_norm`` and ``ffn_norm``)."""
        for param in self.attn_norm.parameters():
            param.requires_grad = True
        for param in self.ffn_norm.parameters():
            param.requires_grad = True
        

class TaskMoEBlock(nn.Module):
    """Pre-norm Transformer block whose attention and FFN are both task-routed MoE layers.

    ``MixtureOfExpertsAttention`` and ``TaskMoEFFN`` are each called as
    ``layer(tokens, task_id)`` and return ``(output, aux_loss, probs)``. Each output is
    added back onto the residual stream through stochastic depth.

    Args:
        dim: Token embedding width.
        k, num_experts, task_num, expert_bias: Forwarded to both MoE layers.
        shared_routers: Accepted but not used by this block.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: Forwarded to the attention MoE.
        w_MI, w_H, w_finetune_MI, noisy_gating: Forwarded to both MoE layers.
        act_layer: Accepted but not used by this block.
        norm_layer: Normalization factory for the two pre-norms.
        mlp_ratio: Dense-MLP expansion ratio; ``int(dim * mlp_ratio)`` is split evenly over
            ``num_experts`` to get each FFN expert's ``head_size``.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        legacy_residual: If True, reproduce the original wiring, in which the residual
            stream was discarded (each sub-layer output ``y`` became ``y + drop_path(y)``)
            and ``norm2`` was applied to the stream itself. Only use this to reproduce
            checkpoints trained with that wiring.
    """

    def __init__(
        self,
        dim,
        k, 
        num_experts, 
        task_num, 
        expert_bias,
        shared_routers=False,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,         
        w_MI=0, w_H=0, 
        w_finetune_MI=0, 
        noisy_gating=True,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        mlp_ratio=4.0,
        drop_path=0.0,
        legacy_residual=False,
    ):
        super().__init__()
        self.legacy_residual = legacy_residual
        self.norm1 = norm_layer(dim)
        self.attn = MixtureOfExpertsAttention(
            dim=dim,
            k=k,
            num_experts=num_experts,
            task_num=task_num,
            expert_bias=expert_bias,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            w_MI=w_MI, w_H=w_H, w_finetune_MI=w_finetune_MI, noisy_gating=noisy_gating,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.ffn = TaskMoEFFN(
            input_size=dim, 
            head_size=mlp_hidden_dim // num_experts, 
            k=k, 
            num_experts=num_experts, 
            task_num=task_num, 
            expert_bias=expert_bias, 
            w_MI=w_MI, 
            w_H=w_H, 
            w_finetune_MI=w_finetune_MI, 
            noisy_gating=noisy_gating,
        )

    def forward(self, x, task_id):
        """Run attention-MoE then FFN-MoE as pre-norm residual branches.

        Returns:
            ``(x, attn_aux_loss, attn_probs, ffn_aux_loss, ffn_probs)``.
        """
        # getattr: instances unpickled from before this flag existed lack the attribute.
        if getattr(self, "legacy_residual", False):
            return self._forward_legacy(x, task_id)

        attn_out, attn_aux_loss, attn_probs = self.attn(self.norm1(x), task_id)
        x = x + self.drop_path(attn_out)
        ffn_out, ffn_aux_loss, ffn_probs = self.ffn(self.norm2(x), task_id)
        x = x + self.drop_path(ffn_out)
        return x, attn_aux_loss, attn_probs, ffn_aux_loss, ffn_probs

    def _forward_legacy(self, x, task_id):
        """Original wiring (residual stream discarded); kept only for old checkpoints."""
        x, attn_aux_loss, attn_probs = self.attn(self.norm1(x), task_id)
        x = x + self.drop_path(x)
        x = self.norm2(x)
        x, ffn_aux_loss, ffn_probs = self.ffn(x, task_id)
        x = x + self.drop_path(x)
        return x, attn_aux_loss, attn_probs, ffn_aux_loss, ffn_probs



class CrossAttention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``kvx``.

    Args:
        dim: Embedding width shared by ``x`` and ``kvx``; split into ``num_heads`` heads.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections have a bias.
        qk_scale: Optional override for the logit scale (default ``head_dim ** -0.5``).
        attn_drop: Dropout on the attention weights.
        proj_drop: Dropout after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, kvx, src_mask=None):
        """Attend from ``x`` (``(B, N, C)``) to ``kvx`` (``(B, kvN, C)``).

        Args:
            x: Query tokens.
            kvx: Key/value tokens.
            src_mask: Optional ``(B, kvN)`` mask over key positions; positions where it
                equals 0 receive a logit of ``-1e4``. A finite fill (rather than ``-inf``)
                keeps the softmax NaN-free even when a row is fully masked.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        _, kvN, _ = kvx.shape
        head_dim = C // self.num_heads
        kv = (
            self.kv(kvx)
            .reshape(B, kvN, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q = (
            self.q(x)
            .reshape(B, N, 1, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = (
            q[0],
            kv[0],
            kv[1],
        )  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, kvN)
        if src_mask is not None:
            # (B, kvN) -> (B, 1, 1, kvN): masked_fill broadcasts over heads and queries,
            # so there is no need to materialize a (B, heads, N, kvN) copy of the mask.
            key_mask = src_mask.unsqueeze(1).unsqueeze(1)
            attn = attn.masked_fill(key_mask == 0, -1e4)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm Transformer encoder block: self-attention, then an MLP.

    Each sub-layer reads a normalized copy of the stream, and its output is added back
    through stochastic depth (identity when ``drop_path`` is 0).
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        # Submodules are registered in this order: norm1, attn, drop_path, norm2, mlp.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
    
    
class TransformerDecoder(nn.Module):
    """Transformer decoder that maps patch tokens back to an image.

    Tokens go through ``depth`` :class:`Block` layers, a final norm, and a linear head that
    predicts ``patch_size * patch_size * out_channels`` values per token. The per-token
    patches are then tiled row-major into a square image.

    Args:
        in_dim (int): Token embedding dimension produced by the encoder. The decoder keeps
            this width (``hidden_dim == in_dim``), so the input projection is an identity.
        num_patches (int): Maximum number of patch tokens; sizes the positional embedding.
        if_cls_token (bool): If True, the first input token is treated as a class token and
            dropped before decoding.
        patch_size (int): Side length of each square image patch (e.g., 16).
        out_channels (int): Number of output channels (3 for RGB).
        depth (int): Number of Transformer blocks (layers).
        num_heads (int): Number of attention heads in each block.
        mlp_ratio (float): MLP expansion ratio inside each Block (e.g., 4.0).
        qkv_bias (bool): Whether to add bias in QKV projections.
        qk_scale (float or None): Override for QK scale if provided.
        drop (float): Dropout rate applied to projections and MLP layers.
        attn_drop (float): Dropout rate in the attention mechanism.
        drop_path (float): Stochastic depth rate, applied with the same value in every block.
        add_pos_embed (bool): Whether to add a learnable positional embedding for each patch.
        norm_layer (nn.Module): Normalization layer (e.g., nn.LayerNorm).
        act_layer (nn.Module): Activation function (e.g., nn.GELU).
    """
    def __init__(
        self,
        in_dim: int,
        num_patches: int,
        if_cls_token: bool = True,
        patch_size: int = 16,
        out_channels: int = 3,
        depth: int = 4,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale=None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        add_pos_embed: bool = True,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU
    ):
        super().__init__()

        self.in_dim = in_dim
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.depth = depth

        # The decoder width equals the encoder width, as in MAE.
        self.hidden_dim = in_dim
        self.if_cls_token = if_cls_token

        # 1. Optional learnable positional embedding, shape [1, num_patches, hidden_dim].
        self.add_pos_embed = add_pos_embed
        if add_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.hidden_dim))
        else:
            self.pos_embed = None

        # 2. Input projection in_dim -> hidden_dim (identity while the widths match).
        if self.hidden_dim != in_dim:
            self.proj = nn.Linear(in_dim, self.hidden_dim)
        else:
            self.proj = nn.Identity()

        # 3. Stack of Transformer blocks.
        self.blocks = nn.ModuleList([
            Block(
                dim=self.hidden_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
            for _ in range(depth)
        ])

        self.norm = norm_layer(self.hidden_dim)

        # 4. Per-token head: hidden_dim -> patch_size * patch_size * out_channels pixels.
        self.out_proj = nn.Linear(self.hidden_dim, patch_size * patch_size * out_channels)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal positional embedding; Xavier/zero init for the linear layers."""
        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0)
        # self.proj is only a Linear when hidden_dim != in_dim.
        if isinstance(self.proj, nn.Linear):
            nn.init.xavier_uniform_(self.proj.weight)
            nn.init.constant_(self.proj.bias, 0)

    def forward(self, x: torch.Tensor):
        """Decode patch tokens into an image.

        Args:
            x (Tensor): ``[B, T, in_dim]`` token embeddings. If ``if_cls_token`` is True the
                first token is dropped, leaving ``N = T - 1`` patch tokens; otherwise
                ``N = T``. Patch tokens are laid out row-major on a square grid.

        Returns:
            Tensor of shape ``[B, out_channels, H, W]`` with
            ``H = W = sqrt(N) * patch_size``.

        Raises:
            ValueError: If the number of patch tokens ``N`` is not a perfect square.
        """
        if self.if_cls_token:
            x = x[:, 1:, :]
        B, N, _ = x.shape

        # Validate the grid up front instead of failing later inside `view`.
        grid = int(round(math.sqrt(N)))
        if grid * grid != N:
            raise ValueError(
                f"TransformerDecoder expects a square grid of patch tokens, got N={N}."
            )

        x = self.proj(x)  # [B, N, hidden_dim]
        if self.pos_embed is not None:
            x = x + self.pos_embed[:, :N, :]  # [1, N, hidden_dim] broadcasts over B

        for blk in self.blocks:
            x = blk(x)  # [B, N, hidden_dim]

        x = self.norm(x)
        x = self.out_proj(x)  # [B, N, out_channels * p * p]

        # Un-patchify. Each token's vector is read as (out_channels, p, p) and token
        # t = row * grid + col.
        c_out = self.out_channels
        p = self.patch_size
        x = x.view(B, grid, grid, c_out, p, p)          # [B, h_p, w_p, C, p, p]
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()    # [B, C, h_p, p, w_p, p]
        return x.view(B, c_out, grid * p, grid * p)     # [B, C, H, W]
    


# class MoEBlock(nn.Module):
#     def __init__(
#         self,
#         dim,
#         num_heads,
#         mlp_ratio=4.0,
#         qkv_bias=False,
#         qk_scale=None,
#         drop=0.0,
#         attn_drop=0.0,
#         drop_path=0.0,
#         act_layer=nn.GELU,
#         norm_layer=nn.LayerNorm,
#         num_total_experts=16,
#         num_active_experts=8,
#         task_num=1,
#         **kwargs,
#     ):
#         super().__init__()
#         self.norm1 = norm_layer(dim)
#         self.attn = Attention(
#             dim,
#             num_heads=num_heads,
#             qkv_bias=qkv_bias,
#             qk_scale=qk_scale,
#             attn_drop=attn_drop,
#             proj_drop=drop,
#         )
#         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
#         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
#         self.norm2 = norm_layer(dim)
#         mlp_hidden_dim = int(dim * mlp_ratio)
        
#         self.task_moe_layer = TaskMoEFFN(
#             input_size = dim,
#             head_size = mlp_hidden_dim // num_active_experts,
#             num_experts = num_total_experts,
#             k = num_active_experts,
#             expert_bias=True,
#             acc_aux_loss=True,
#             w_MI=0.0005, 
#             task_num=task_num,
#             noisy_gating=False,
#         )
        

#     def forward(self, x, task_id):
#         x = x + self.drop_path(self.attn(self.norm1(x)))
#         x = self.norm2(x)
#         x,aux_loss,probs= self.task_moe_layer(x,task_id)
#         x = x + self.drop_path(x)
#         return x,aux_loss,probs


class MoEBlock(nn.Module):
    """Pre-norm Transformer block with dense attention and a task-routed ``TaskMoE`` FFN.

    ``act_layer`` is accepted but not used: the MoE activation is fixed to GELU. Extra
    keyword arguments are accepted and ignored.

    NOTE: ``forward`` intentionally keeps a residual-wiring quirk in the FFN branch
    because models were trained with it; see the comment in ``forward``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        num_total_experts=16,
        num_active_experts=8,
        task_num=1,
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth shared by both residual branches (identity when the rate is 0).
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)

        self.task_num = task_num
        self.num_total_experts = num_total_experts
        self.num_active_experts = num_active_experts

        # The dense MLP hidden width is divided evenly across the active experts.
        expert_width = int(dim * mlp_ratio) // num_active_experts
        self.task_moe_layer = TaskMoE(
            input_size=dim,
            head_size=expert_width,
            num_experts=num_total_experts,
            k=num_active_experts,
            bias=True,
            acc_aux_loss=True,
            w_MI=0.0005,
            w_finetune_MI=0,
            task_num=task_num,
            activation=nn.Sequential(nn.GELU()),
            noisy_gating=False,
        )

    def count_parameters(self):
        """Return ``{'total_params', 'active_params'}`` for the whole block.

        The active count removes ``(task_num - 1) / task_num`` of the ``f_gate`` parameters
        and ``(num_total_experts - num_active_experts) / num_total_experts`` of the
        remaining MoE parameters. This assumes gate parameters split evenly per task and
        expert parameters split evenly per expert.
        """
        total_params = sum(p.numel() for p in self.parameters())
        moe_params = sum(p.numel() for p in self.task_moe_layer.parameters())
        gate_params = sum(p.numel() for p in self.task_moe_layer.f_gate.parameters())
        expert_params = moe_params - gate_params

        inactive_gate = gate_params / self.task_num * (self.task_num - 1)
        inactive_experts = expert_params / self.num_total_experts * (
            self.num_total_experts - self.num_active_experts
        )
        return {
            'total_params': total_params,
            'active_params': total_params - inactive_gate - inactive_experts,
        }

    def forward(self, x, task_id):
        """Return ``(x, aux_loss, probs)`` after the attention and MoE-FFN sub-layers."""
        # Attention branch: standard pre-norm residual.
        x = x + self.drop_path(self.attn(self.norm1(x)))
        # FFN branch, legacy wiring kept because models were trained with it: the residual
        # stream is dropped and the MoE output is added to itself. The intended form is
        #     moe_out, aux_loss, probs = self.task_moe_layer(self.norm2(x), task_id)
        #     x = x + self.drop_path(moe_out)
        moe_out, aux_loss, probs = self.task_moe_layer(self.norm2(x), task_id)
        return moe_out + self.drop_path(moe_out), aux_loss, probs


class CSABlock(nn.Module):
    """Cross-attention, self-attention, then MLP, each as a pre-norm residual branch.

    ``x`` first attends to ``kvx`` (both normalized), then to itself, and finally passes
    through an MLP. Every branch is added back through the same stochastic-depth module.

    Submodule names (``norm1``, ``norm_kv1``, ``cattn``, ``norm3``, ``attn``, ``norm4``,
    ``mlp2``) are kept as-is because renaming them would change ``state_dict`` keys.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm_kv1 = norm_layer(dim)
        self.cattn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm3 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm4 = norm_layer(dim)
        self.mlp2 = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, kvx, src_mask=None):
        """Apply cross-attention to ``kvx``, self-attention, and the MLP.

        Args:
            x: Query tokens.
            kvx: Key/value tokens for the cross-attention.
            src_mask: Optional key mask forwarded to ``CrossAttention``.

        Returns:
            Updated ``x`` (same shape as the input).
        """
        x = x + self.drop_path(
            self.cattn(self.norm1(x), self.norm_kv1(kvx), src_mask=src_mask)
        )
        x = x + self.drop_path(self.attn(self.norm3(x)))
        # Stochastic depth also covers the MLP branch, consistent with the other branches.
        x = x + self.drop_path(self.mlp2(self.norm4(x)))
        return x

