import torch
import torch.nn as nn
import numpy
import math
from .deepseekMoE import MoEGate, AddAuxiliaryLoss, GroupedDeepSeekMoE
import torch.nn.functional as F

def FeedForward(dim, mult=4):
    """Build a pre-normalised two-layer MLP.

    Args:
        dim: feature width of both the input and the output.
        mult: expansion factor; the hidden width is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` of LayerNorm -> Linear -> GELU -> Linear, where
        both linear layers are created without bias.
    """
    hidden_width = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_width, bias=False),
        nn.GELU(),
        nn.Linear(hidden_width, dim, bias=False),
    ]
    return nn.Sequential(*stages)
   
def reshape_tensor(x, heads):
    """Split the feature axis of ``x`` into attention heads.

    Args:
        x: tensor of shape (bs, length, width).
        heads: number of heads the width is divided across.

    Returns:
        Tensor of shape (bs, heads, length, width // heads). The batch and
        head axes stay separate; they are not merged into one axis.
    """
    batch, seq_len, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head)
    per_head = x.view(batch, seq_len, heads, -1)
    # (bs, length, heads, dim_per_head) -> (bs, heads, length, dim_per_head)
    head_major = per_head.transpose(1, 2)
    # The result keeps the 4-D (bs, heads, length, dim_per_head) layout.
    return head_major.reshape(batch, heads, seq_len, -1)

class PerceiverAttention(nn.Module):
    """Multi-head attention where latent queries attend over features and latents.

    Queries are projected from the latents only; keys and values are projected
    from the concatenation of the input features and the latents.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        """
        Args:
            dim: feature width of both ``x`` and ``latents``.
            dim_head: width of each attention head.
            heads: number of attention heads.
        """
        super().__init__()
        # ``scale`` is stored but not read by ``forward``, which derives its
        # own scaling factor from ``dim_head``.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        projected_width = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, projected_width, bias=False)
        self.to_kv = nn.Linear(dim, projected_width * 2, bias=False)
        self.to_out = nn.Linear(projected_width, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor of shape (b, n2, D): the attended latents (no
            residual connection is added here).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        batch, n_latents, _ = latents.shape

        queries = reshape_tensor(self.to_q(latents), self.heads)

        # Keys/values see the image features followed by the latents.
        context = torch.cat((x, latents), dim=-2)
        keys, values = (
            reshape_tensor(part, self.heads)
            for part in self.to_kv(context).chunk(2, dim=-1)
        )

        # Scale q and k each by dim_head ** -0.25 before the product; this is
        # more stable in f16 than dividing the scores afterwards.
        root_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (queries * root_scale) @ (keys * root_scale).transpose(-2, -1)
        # Softmax is evaluated in float32, then cast back to the score dtype.
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
        attended = probs @ values

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        merged = attended.permute(0, 2, 1, 3).reshape(batch, n_latents, -1)

        return self.to_out(merged)

class Resampler(nn.Module):
    """Resample a variable-length feature sequence into a fixed set of tokens.

    A learned bank of ``num_queries`` latents is refined by ``depth`` rounds
    of (PerceiverAttention, FeedForward), each with a residual connection,
    then projected to ``output_dim`` and layer-normalised.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        """
        Args:
            dim: internal width of the latents and attention layers.
            depth: number of (attention, feed-forward) pairs.
            dim_head: width of each attention head.
            heads: number of attention heads.
            num_queries: number of learned latent tokens (output length).
            embedding_dim: feature width of the incoming sequence.
            output_dim: feature width of the returned tokens.
            ff_mult: hidden expansion factor of each feed-forward block.
        """
        super().__init__()

        # Learned queries, shape (1, num_queries, dim), scaled by 1/sqrt(dim).
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """
        Args:
            x: (b, n, embedding_dim)

        Returns:
            Tensor of shape (b, num_queries, output_dim).
        """
        batch = x.size(0)
        latents = self.latents.repeat(batch, 1, 1)  # (b, num_queries, dim)
        x = self.proj_in(x)  # (b, n, dim)

        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        projected = self.proj_out(latents)
        return self.norm_out(projected)

class CopyExpert(nn.Module):
    """Parameter-free expert that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` itself (the same object, no copy is made)."""
        return inputs

class ZeroExpert(nn.Module):
    """Parameter-free expert that outputs zeros shaped like its input."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return a zero tensor with the shape, dtype and device of ``inputs``."""
        # zeros_like already inherits dtype and device from ``inputs``.
        return torch.zeros_like(inputs)
    

class ClipExpert(nn.Module):
    """Expert that blends hidden states with resampled CLIP features.

    A two-way softmax gate, computed per token from the hidden states,
    weights the original hidden states against the resampled CLIP tokens.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``resampler_intermediate_size``,
                ``resampler_depth``, ``resampler_dim_head``,
                ``resampler_heads``, ``resampler_num_tokens``,
                ``resampler_ff_mult`` and ``hidden_size``.
        """
        super().__init__()
        self.resampler = Resampler(
            dim=config.resampler_intermediate_size,
            depth=config.resampler_depth,
            dim_head=config.resampler_dim_head,
            heads=config.resampler_heads,
            num_queries=config.resampler_num_tokens,
            embedding_dim=512,  # width of the CLIP embedding fed to the resampler
            output_dim=config.hidden_size,
            ff_mult=config.resampler_ff_mult,
        )

        # Per-token gate: channel 0 keeps the input, channel 1 injects CLIP.
        self.wg = torch.nn.Linear(config.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs, clip_emb):
        """
        Args:
            inputs: hidden states, (B, L, hidden_size).
            clip_emb: CLIP embedding; a length-1 sequence axis is inserted at
                dim 1 before resampling (presumably (B, 512) - confirm with
                the caller).

        Returns:
            Gate-weighted mix of ``inputs`` and the resampled CLIP tokens. The
            resampled tokens must be broadcastable against ``inputs``.
        """
        gate = self.softmax(self.wg(inputs))  # (B, L, 2)
        keep_gate = gate[:, :, :1]
        inject_gate = gate[:, :, 1:2]

        clip_tokens = self.resampler(clip_emb.unsqueeze(1))

        return keep_gate * inputs + inject_gate * clip_tokens

class FaceExpert(nn.Module):
    """Expert that blends hidden states with resampled face features.

    A two-way softmax gate, computed per token from the hidden states,
    weights the original hidden states against the resampled face tokens.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``resampler_intermediate_size``,
                ``resampler_depth``, ``resampler_dim_head``,
                ``resampler_heads``, ``resampler_num_tokens``,
                ``resampler_ff_mult`` and ``hidden_size``.
        """
        super().__init__()
        self.resampler = Resampler(
            dim=config.resampler_intermediate_size,
            depth=config.resampler_depth,
            dim_head=config.resampler_dim_head,
            heads=config.resampler_heads,
            num_queries=config.resampler_num_tokens,
            embedding_dim=512,  # width of the face embedding fed to the resampler
            output_dim=config.hidden_size,
            ff_mult=config.resampler_ff_mult,
        )
        # Per-token gate: channel 0 keeps the input, channel 1 injects face tokens.
        self.wg = torch.nn.Linear(config.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs, face_emb):
        """
        Args:
            inputs: hidden states, (B, L, hidden_size).
            face_emb: face embedding; a length-1 sequence axis is inserted at
                dim 1 before resampling (presumably (B, 512) - confirm with
                the caller).

        Returns:
            Gate-weighted mix of ``inputs`` and the resampled face tokens. The
            resampled tokens must be broadcastable against ``inputs``.
        """
        gate = self.softmax(self.wg(inputs))  # (B, L, 2)
        keep_gate = gate[:, :, :1]
        inject_gate = gate[:, :, 1:2]

        face_tokens = self.resampler(face_emb.unsqueeze(1))

        return keep_gate * inputs + inject_gate * face_tokens

class ShowoImageExpert(nn.Module):
    """Expert that blends hidden states with resampled image embeddings.

    Unlike the CLIP/face experts, the embeddings are passed to the resampler
    as given, without inserting a sequence axis.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``resampler_intermediate_size``,
                ``resampler_depth``, ``resampler_dim_head``,
                ``resampler_heads``, ``resampler_num_tokens``,
                ``resampler_ff_mult`` and ``hidden_size``.
        """
        super().__init__()
        self.resampler = Resampler(
            dim=config.resampler_intermediate_size,
            depth=config.resampler_depth,
            dim_head=config.resampler_dim_head,
            heads=config.resampler_heads,
            num_queries=config.resampler_num_tokens,
            embedding_dim=512,  # feature width expected from the image embeddings
            output_dim=config.hidden_size,
            ff_mult=config.resampler_ff_mult,
        )
        # Per-token gate: channel 0 keeps the input, channel 1 injects image tokens.
        self.wg = torch.nn.Linear(config.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs, showo_img_embs):
        """
        Args:
            inputs: hidden states, (B, L, hidden_size).
            showo_img_embs: image embeddings handed to the resampler unchanged
                (presumably (B, n, 512) - confirm with the caller).

        Returns:
            Gate-weighted mix of ``inputs`` and the resampled image tokens. The
            resampled tokens must be broadcastable against ``inputs``.
        """
        gate = self.softmax(self.wg(inputs))
        keep_gate = gate[:, :, :1]
        inject_gate = gate[:, :, 1:2]

        image_tokens = self.resampler(showo_img_embs)

        return keep_gate * inputs + inject_gate * image_tokens

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    The timestep is first expanded into sinusoidal features of width
    ``frequency_embedding_size`` and then mapped through a two-layer MLP
    (Linear -> SiLU -> Linear) to ``hidden_size``.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, silu=True):
        """
        :param hidden_size: width of the returned embedding.
        :param frequency_embedding_size: width of the sinusoidal features.
        :param silu: accepted for interface compatibility; it is not read and
                     the MLP always uses SiLU.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) float32 Tensor of positional embeddings, laid out
                 as [cos | sin] with one zero column appended when D is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd widths so the output has exactly ``dim`` columns.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """
        :param t: a 1-D Tensor of N timesteps (floating point or integer).
        :return: an (N, hidden_size) Tensor of timestep embeddings.
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # Follow the caller's precision (e.g. fp16/bf16) only when ``t`` is a
        # floating tensor. Casting to an integer dtype would truncate the
        # sin/cos features and make the Linear layers fail on a dtype
        # mismatch, so integer timesteps keep the float32 features.
        if t.is_floating_point():
            t_freq = t_freq.to(t.dtype)
        return self.mlp(t_freq)

class NoiseExpert(nn.Module):
    """Expert that conditions hidden states on the noise level ``sigmas``.

    The noise level is embedded with a ``TimestepEmbedder`` and then injected
    in one of two ways, selected by ``config.get('use_scale_noise_moe')``:

    * scale/shift modulation of the hidden states, or
    * resampling the noise embedding into tokens.

    In both modes a per-token two-way softmax gate mixes the untouched hidden
    states with the noise-conditioned branch.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``hidden_size``, a ``get(key, default)``
                method, and (when the resampler branch is used) the
                ``resampler_*`` settings.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.sigma_map = TimestepEmbedder(
            hidden_size,
        )
        # Per-token gate: channel 0 keeps the input, channel 1 takes the
        # noise-conditioned branch.
        self.wg = torch.nn.Linear(hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.use_scale_noise_moe = config.get('use_scale_noise_moe', False)
        # ``sigma_map`` emits ``hidden_size`` features, so the layers that
        # consume it are sized from ``hidden_size`` rather than a fixed 2048
        # (identical when hidden_size == 2048, the only width that worked).
        if self.use_scale_noise_moe:
            # Produces the concatenated (shift, scale) vectors.
            self.scale_modulation = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        else:
            self.resampler = Resampler(
                dim=config.resampler_intermediate_size,
                depth=config.resampler_depth,
                dim_head=config.resampler_dim_head,
                heads=config.resampler_heads,
                num_queries=config.resampler_num_tokens,
                embedding_dim=hidden_size,
                output_dim=hidden_size,
                ff_mult=config.resampler_ff_mult,
            )

    def forward(self, inputs, sigmas):
        """
        Args:
            inputs: hidden states, (B, L, hidden_size).
            sigmas: 1-D tensor with one noise level per batch element.

        Returns:
            Gate-weighted mix of ``inputs`` and the noise-conditioned branch.
            In the resampler mode the resampled tokens must be broadcastable
            against ``inputs``.
        """
        weight = self.softmax(self.wg(inputs))  # (B, L, 2)
        keep_gate = weight[:, :, 0:1]
        noise_gate = weight[:, :, 1:2]

        noise_emb = self.sigma_map(sigmas)  # (B, hidden_size)

        if self.use_scale_noise_moe:
            scale_noise = self.scale_modulation(F.silu(noise_emb))
            # Add a sequence axis, then split the features into shift | scale.
            shift, scale = scale_noise[:, None].chunk(2, dim=2)
            noise_branch = inputs * (1.0 + scale) + shift
        else:
            # Treat the embedding as a length-1 sequence for the resampler.
            noise_branch = self.resampler(noise_emb.unsqueeze(1))

        return keep_gate * inputs + noise_gate * noise_branch



class ShowoTextExpert(nn.Module):
    """Expert that blends hidden states with resampled text embeddings.

    The embeddings are passed to the resampler as given, without inserting a
    sequence axis.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``resampler_intermediate_size``,
                ``resampler_depth``, ``resampler_dim_head``,
                ``resampler_heads``, ``resampler_num_tokens``,
                ``resampler_ff_mult`` and ``hidden_size``.
        """
        super().__init__()
        self.resampler = Resampler(
            dim=config.resampler_intermediate_size,
            depth=config.resampler_depth,
            dim_head=config.resampler_dim_head,
            heads=config.resampler_heads,
            num_queries=config.resampler_num_tokens,
            embedding_dim=512,  # feature width expected from the text embeddings
            output_dim=config.hidden_size,
            ff_mult=config.resampler_ff_mult,
        )

        # Per-token gate: channel 0 keeps the input, channel 1 injects text tokens.
        self.wg = torch.nn.Linear(config.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs, showo_text_embs):
        """
        Args:
            inputs: hidden states, (B, L, hidden_size).
            showo_text_embs: text embeddings handed to the resampler unchanged
                (presumably (B, n, 512) - confirm with the caller).

        Returns:
            Gate-weighted mix of ``inputs`` and the resampled text tokens. The
            resampled tokens must be broadcastable against ``inputs``.
        """
        gate = self.softmax(self.wg(inputs))
        keep_gate = gate[:, :, :1]
        inject_gate = gate[:, :, 1:2]

        text_tokens = self.resampler(showo_text_embs)

        return keep_gate * inputs + inject_gate * text_tokens

class InstanceGate(nn.Module):
    """Sequence-level router: one expert choice per sample, not per token.

    Each sequence is pooled to a single vector (attention pooling or a plain
    mean over tokens), scored against every expert, and the ``top_k`` softmax
    probabilities are returned together with their expert indices.
    """

    def __init__(
        self,
        d_model,
        num_experts,
        use_attention_pooling=False,
        top_k=1,
    ):
        """
        Args:
            d_model: feature width of the hidden states.
            num_experts: number of experts to score.
            use_attention_pooling: pool tokens with learned attention weights
                instead of a mean.
            top_k: number of experts selected per sequence.
        """
        super().__init__()
        self.num_experts = num_experts
        self.use_attention_pooling = use_attention_pooling

        if self.use_attention_pooling:
            # Scores each token with a scalar used as its pooling logit.
            self.token_attention_pooling = nn.Linear(d_model, 1)

        self.gate_network = nn.Linear(d_model, num_experts)

        self.top_k = top_k

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: (B, L, D)

        Returns:
            Tuple ``(routing_weights, routing_idxs)``, both of shape
            (B, top_k). The weights are the selected softmax probabilities and
            are not renormalised to sum to one.
        """
        _batch, _length, _width = hidden_states.shape  # input must be 3-D

        if not self.use_attention_pooling:
            pooled = hidden_states.mean(dim=1)  # (B, D)
        else:
            token_scores = self.token_attention_pooling(hidden_states)
            token_weights = F.softmax(token_scores, dim=1)  # (B, L, 1)
            # (B, 1, L) x (B, L, D) -> (B, 1, D) -> (B, D)
            pooled = torch.bmm(token_weights.transpose(1, 2), hidden_states).squeeze(1)

        expert_probs = F.softmax(self.gate_network(pooled), dim=-1)  # (B, num_experts)

        top_probs, top_experts = torch.topk(expert_probs, k=self.top_k, dim=-1)  # (B, top_k)
        return top_probs, top_experts

class InstanceMoE(nn.Module):
    """Mixture of heterogeneous experts routed once per sequence.

    Every configured expert is evaluated on the whole batch; an
    ``InstanceGate`` then picks ``top_k`` experts per sample and the selected
    outputs are combined with the gate's weights.
    """

    def __init__(
        self,
        config,
        print_freq,
    ):
        """
        Args:
            config: object exposing ``hidden_size``, a ``get(key, default)``
                method for the ``n_*_experts`` counts and
                ``num_experts_per_instance``, plus whatever the enabled
                expert types read from it.
            print_freq: when truthy, ``forward`` also returns the routing
                indices so the caller can track expert usage.
        """
        super().__init__()

        self.n_copy_experts = config.get('n_copy_experts', 0)
        self.n_zero_experts = config.get('n_zero_experts', 0)
        self.n_text_experts = config.get('n_text_experts', 0)
        self.n_noise_experts = config.get('n_noise_experts', 0)
        self.n_clip_experts = config.get('n_clip_experts', 0)
        self.n_face_experts = config.get('n_face_experts', 0)
        self.n_showo_image_experts = config.get('n_showo_image_experts', 0)

        # (count, factory) pairs. The order fixes each expert's index in
        # ``self.experts`` and therefore the meaning of the gate's outputs.
        expert_plan = (
            (self.n_copy_experts, CopyExpert),
            (self.n_zero_experts, ZeroExpert),
            (self.n_text_experts, lambda: ShowoTextExpert(config)),
            (self.n_noise_experts, lambda: NoiseExpert(config)),
            (self.n_clip_experts, lambda: ClipExpert(config)),
            (self.n_face_experts, lambda: FaceExpert(config)),
            (self.n_showo_image_experts, lambda: ShowoImageExpert(config)),
        )

        self.n_routed_experts = self.n_copy_experts + \
                                self.n_zero_experts + \
                                self.n_text_experts + \
                                self.n_noise_experts + \
                                self.n_clip_experts + \
                                self.n_face_experts + \
                                self.n_showo_image_experts

        self.experts = nn.ModuleList()
        for count, build in expert_plan:
            if count > 0:
                self.experts.extend([build() for _ in range(count)])

        self.top_k = config.get('num_experts_per_instance', 1)

        self.gate = InstanceGate(
            d_model=config.hidden_size,
            num_experts=self.n_routed_experts,
            use_attention_pooling=True,
            top_k=self.top_k,
        )

        self.print_freq = print_freq

    def _run_expert(self, expert, hidden_states, clip_features, face_features, sigmas):
        """Call ``expert`` with the side input that its type expects."""
        if isinstance(expert, (CopyExpert, ZeroExpert)):
            return expert(hidden_states)
        if isinstance(expert, NoiseExpert):
            return expert(hidden_states, sigmas)
        if isinstance(expert, ClipExpert):
            return expert(hidden_states, clip_features)
        if isinstance(expert, FaceExpert):
            return expert(hidden_states, face_features)
        if isinstance(expert, (ShowoTextExpert, ShowoImageExpert)):
            # These experts can be constructed but have no inputs wired here.
            raise NotImplementedError(
                f"{type(expert).__name__} is not supported by InstanceMoE.forward"
            )
        raise ValueError(f"Invalid expert type: {type(expert)}")

    def forward(
        self,
        hidden_states,
        clip_features=None,
        face_features=None,
        sigmas=None,
    ):
        """
        Args:
            hidden_states: (B, L, D)
            clip_features: side input for ``ClipExpert`` instances.
            face_features: side input for ``FaceExpert`` instances.
            sigmas: side input for ``NoiseExpert`` instances.

        Returns:
            ``y`` of shape (B, L, D), or ``(y, routing_idx)`` with
            ``routing_idx`` of shape (B, top_k) when ``print_freq`` is truthy.

        Raises:
            NotImplementedError: a text or showo-image expert is configured.
            ValueError: an expert of an unknown type is present.
        """
        # Every expert runs on the full batch; all outputs must share one
        # shape so they can be stacked and gathered below.
        expert_outs = torch.stack(
            [
                self._run_expert(expert, hidden_states, clip_features, face_features, sigmas)
                for expert in self.experts
            ],
            dim=1,
        )  # (B, num_experts, L, D)

        routing_weights, routing_idx = self.gate(hidden_states)  # (B, top_k)

        B, L, D = hidden_states.shape
        # Broadcast each selected expert index over the (L, D) axes so gather
        # pulls out whole expert outputs.
        gather_index = routing_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, L, D)  # (B, top_k, L, D)
        top_k_expert_outs = torch.gather(expert_outs, dim=1, index=gather_index)
        routing_weights = routing_weights.unsqueeze(-1).unsqueeze(-1)  # (B, top_k, 1, 1)

        y = (top_k_expert_outs * routing_weights).sum(dim=1)  # (B, L, D)

        if self.print_freq:
            return y, routing_idx
        return y

class GroupedInstanceMoE(nn.Module):
    """Instance-level MoE wrapper for batches mixing T2I and MMU sequences.

    Token layouts handled by ``forward``:

    * T2I: ``<t2i> text_tokens <soi> image_tokens <eoi>``
    * MMU: ``<mmu> <soi> image_tokens <eoi> text_tokens``

    NOTE: expert routing is currently disabled. ``self.mmu_image_moe`` is
    constructed but never invoked by ``forward``; every region is sliced out
    and concatenated back in its original order.
    """

    # Settings copied from the top-level config onto each sub-config so the
    # experts built from a sub-config can read them directly.
    _SHARED_CONFIG_KEYS = (
        'resampler_intermediate_size',
        'resampler_depth',
        'resampler_dim_head',
        'resampler_heads',
        'resampler_num_tokens',
        'resampler_ff_mult',
        'hidden_size',
    )

    def __init__(
        self,
        config,
        print_freq=True,
        layer_idx=None,
    ):
        """
        Args:
            config: object exposing ``num_image_tokens``, the shared
                ``resampler_*`` / ``hidden_size`` settings, and the
                sub-configs ``t2i.text``, ``t2i.image``, ``mmu.text``,
                ``mmu.image`` and ``t2i.use_scale_noise_moe``. The
                sub-configs are mutated in place.
            print_freq: when truthy, expert-usage counters are created and
                the inner MoE is asked to return its routing indices.
            layer_idx: layer index, stored only when ``print_freq`` is truthy.
        """
        super().__init__()
        self.num_image_tokens = config.num_image_tokens
        config_t2i_text = config.t2i.text
        config_t2i_image = config.t2i.image
        config_mmu_text = config.mmu.text
        config_mmu_image = config.mmu.image

        for config_sub in (config_t2i_text, config_t2i_image, config_mmu_text, config_mmu_image):
            for key in self._SHARED_CONFIG_KEYS:
                setattr(config_sub, key, getattr(config, key))

        config_t2i_image.use_scale_noise_moe = config.t2i.use_scale_noise_moe

        # Only the MMU image MoE is built; the T2I counterpart
        # (``InstanceMoE(config_t2i_image, print_freq)``) is disabled.
        self.mmu_image_moe = InstanceMoE(config_mmu_image, print_freq)

        self.print_freq = print_freq
        if print_freq:
            self.layer_idx = layer_idx
            # Cumulative activation counts and frequencies per expert.
            self.t2i_freq = [0.0] * 4
            self.t2i_times = [0] * 4
            self.mmu_freq = [0.0] * self.mmu_image_moe.n_routed_experts
            self.mmu_times = [0] * self.mmu_image_moe.n_routed_experts

    @staticmethod
    def _accumulate_counts(times, topk_idx):
        """Add the expert hits in ``topk_idx`` to ``times``.

        Returns ``(new_times, new_freq)``; ``new_freq`` is ``None`` when the
        running total is still zero.
        """
        # Count activations per expert over the flattened index tensor.
        hits = torch.bincount(topk_idx.view(-1), minlength=len(times)).cpu().tolist()
        # zip stops at len(times), so counts for out-of-range indices are
        # ignored rather than growing the list.
        new_times = [seen + hit for seen, hit in zip(times, hits)]
        total = sum(new_times)
        new_freq = [count / total for count in new_times] if total > 0 else None
        return new_times, new_freq

    def compute_freq(self, topk_idx_mmu, topk_idx_t2i):
        """Update the running expert-usage statistics.

        Args:
            topk_idx_mmu: integer tensor of expert indices selected by the
                MMU routing, or ``None`` to skip. Any shape; it is flattened.
            topk_idx_t2i: integer tensor of expert indices selected by the
                T2I routing, or ``None`` to skip. Any shape; it is flattened.

        Side effects:
            Updates ``self.mmu_times`` / ``self.mmu_freq`` and
            ``self.t2i_times`` / ``self.t2i_freq``. Frequencies are taken
            over the cumulative counts, not only the current call.
        """
        # MMU expert statistics
        if topk_idx_mmu is not None:
            self.mmu_times, mmu_freq = self._accumulate_counts(self.mmu_times, topk_idx_mmu)
            if mmu_freq is not None:
                self.mmu_freq = mmu_freq

        # T2I expert statistics
        if topk_idx_t2i is not None:
            self.t2i_times, t2i_freq = self._accumulate_counts(self.t2i_times, topk_idx_t2i)
            if t2i_freq is not None:
                self.t2i_freq = t2i_freq

    def _rebuild_mmu(self, mmu_hidden_states):
        """Split an MMU batch into its regions and concatenate them back."""
        n = self.num_image_tokens
        mmu_prefix = mmu_hidden_states[:, :2, :]  # first two tokens (<mmu>, <soi>)
        mmu_image_hidden_states = mmu_hidden_states[:, 2:2 + n, :]
        mmu_center = mmu_hidden_states[:, 2 + n:2 + n + 1, :]  # single token after the image
        mmu_text_hidden_states = mmu_hidden_states[:, 2 + n + 1:, :]

        # Routing hook (currently disabled):
        #   mmu_image_hidden_states = self.mmu_image_moe(
        #       mmu_image_hidden_states,
        #       clip_features=clip_features, face_features=face_features)
        #   and, when print_freq is set, self.compute_freq(topk_idx_mmu, None)

        return torch.cat(
            [mmu_prefix, mmu_image_hidden_states, mmu_center, mmu_text_hidden_states],
            dim=1,
        )

    def _rebuild_t2i(self, t2i_hidden_states):
        """Split a T2I batch into its regions and concatenate them back."""
        n = self.num_image_tokens
        seq_len = t2i_hidden_states.shape[1]
        t2i_prefix = t2i_hidden_states[:, :1, :]  # first token (<t2i>)
        t2i_text_hidden_states = t2i_hidden_states[:, 1:seq_len - 1 - n - 1, :]
        t2i_center = t2i_hidden_states[:, -(n + 2):-(n + 1), :]  # middle token (<soi>)
        t2i_image_hidden_states = t2i_hidden_states[:, -(n + 1):-1, :]
        t2i_suffix = t2i_hidden_states[:, -1:, :]  # last token (<eoi>)

        # Routing hook (currently disabled): a T2I image MoE would transform
        # ``t2i_image_hidden_states`` here, taking ``sigmas`` as well.

        return torch.cat(
            [t2i_prefix, t2i_text_hidden_states, t2i_center, t2i_image_hidden_states, t2i_suffix],
            dim=1,
        )

    def forward(self, hidden_states, clip_features, face_features, bsz_t2i, sigmas):
        """
        Args:
            hidden_states: (B, L, D); the first ``bsz_t2i`` rows are T2I
                sequences and the remaining rows are MMU sequences.
            clip_features: currently unused (routing disabled).
            face_features: currently unused (routing disabled).
            bsz_t2i: number of leading T2I samples in the batch.
            sigmas: currently unused (routing disabled).

        Returns:
            Tensor of shape (B, L, D) rebuilt from the per-region slices; with
            routing disabled its values equal ``hidden_states``.
        """
        if bsz_t2i == 0:
            # Pure MMU batch.
            return self._rebuild_mmu(hidden_states)

        if bsz_t2i == hidden_states.shape[0]:
            # Pure T2I batch.
            return self._rebuild_t2i(hidden_states)

        # Mixed batch: T2I rows first, MMU rows after.
        updated_t2i_hidden_states = self._rebuild_t2i(hidden_states[:bsz_t2i])
        updated_mmu_hidden_states = self._rebuild_mmu(hidden_states[bsz_t2i:])
        return torch.cat([updated_t2i_hidden_states, updated_mmu_hidden_states], dim=0)
    
class GroupTokenInstanceMoE(nn.Module):
    """Chain a token-level MoE with an instance-level MoE.

    Hidden states pass through ``GroupedDeepSeekMoE`` first and the result is
    then handed to ``GroupedInstanceMoE``.
    """

    def __init__(
        self,
        config_token_moe,
        config_instance_moe,
        one_shared_experts,
        layer_idx=None,
    ):
        """
        Args:
            config_token_moe: configuration for the token-level MoE.
            config_instance_moe: configuration for the instance-level MoE.
            one_shared_experts: forwarded to ``GroupedDeepSeekMoE``.
            layer_idx: layer index forwarded to both sub-modules.
        """
        super().__init__()
        self.token_moe = GroupedDeepSeekMoE(
            config_token_moe,
            one_shared_experts=one_shared_experts,
            layer_idx=layer_idx,
        )
        self.instance_moe = GroupedInstanceMoE(
            config_instance_moe,
            layer_idx=layer_idx,
        )

    def forward(self, hidden_states, bsz_t2i, clip_features, face_features, sigmas):
        """
        Args:
            hidden_states: input hidden states.
            bsz_t2i: number of leading T2I samples in the batch.
            clip_features: optional tensor; cast to the hidden-state dtype.
            face_features: optional tensor; cast to the hidden-state dtype.
            sigmas: optional tensor; cast to the hidden-state dtype.

        Returns:
            The output of the instance-level MoE.
        """
        target_dtype = hidden_states.dtype
        # Align every side input that was provided with the hidden-state
        # dtype; inputs given as None stay None.
        clip_features, face_features, sigmas = (
            None if aux is None else aux.to(target_dtype)
            for aux in (clip_features, face_features, sigmas)
        )

        token_routed = self.token_moe(hidden_states, bsz_t2i)

        return self.instance_moe(token_routed, clip_features, face_features, bsz_t2i, sigmas)
    