# The CNO2d code has been modified from a tutorial featured in the 
# ETH Zurich course "AI in the Sciences and Engineering."
# Git page for this course: https://github.com/bogdanraonic3/AI_Science_Engineering 

# For up/downsampling, the antialias interpolation functions from the 
# torch library are utilized, limiting the ability to design
# your own low-pass filters at present.

# While acknowledging this suboptimal setup, the performance of CNO2d remains commendable. 
# Additionally, a training script is available, offering a solid foundation for personal projects.


import torch
import torch.nn as nn
import os
import numpy as np
import torch.nn.functional as F


# CNO LReLu activation function
# CNO building block (CNOBlock) → Conv2d - BatchNorm - Activation
# Lift/Project Block (Important for embeddings)
# Residual Block → Conv2d - BatchNorm - Activation - Conv2d - BatchNorm - Skip Connection
# ResNet → Stacked ResidualBlocks (several blocks applied iteratively)


#---------------------
# Activation Function:
#---------------------

class CNO_LReLu(nn.Module):
    """LeakyReLU evaluated on a finer grid, then resampled.

    The input is interpolated (bicubic, anti-aliased) to a square grid of side
    ``2 * in_size``, passed through ``nn.LeakyReLU`` and interpolated again to
    a square grid of side ``out_size``.

    Args:
        in_size: Side length from which the intermediate resolution is derived.
        out_size: Side length of the square output grid.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.act = nn.LeakyReLU()

    @staticmethod
    def _resample(x, side):
        """Bicubic, anti-aliased interpolation of ``x`` to ``(side, side)``."""
        return F.interpolate(x, size=(side, side), mode="bicubic", antialias=True)

    def forward(self, x):
        upsampled = self._resample(x, 2 * self.in_size)
        activated = self.act(upsampled)
        return self._resample(activated, self.out_size)

#--------------------
# CNO Block:
#--------------------

class CNOBlock(nn.Module):
    """Basic CNO block: 3x3 convolution -> optional batch norm -> CNO activation.

    Any change of spatial resolution (``in_size`` -> ``out_size``) is carried
    out by the activation module, not by the convolution (which uses
    ``padding=1`` with a 3x3 kernel).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        in_size: Spatial side length handed to the activation as ``in_size``.
        out_size: Spatial side length handed to the activation as ``out_size``.
        use_bn: When true a ``BatchNorm2d`` follows the convolution, otherwise
            an ``Identity`` is used in its place.
    """

    def __init__(self, in_channels, out_channels, in_size, out_size, use_bn = True):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.in_size, self.out_size = in_size, out_size

        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = CNO_LReLu(in_size=in_size, out_size=out_size)

    def forward(self, x):
        features = self.convolution(x)
        normalized = self.batch_norm(features)
        return self.act(normalized)
    
#--------------------
# Lift/Project Block:
#--------------------

class LiftProjectBlock(nn.Module):
    """Lifting / projection block: a CNOBlock (no batch norm) followed by a 3x3 convolution.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        size: Spatial side length; used as both ``in_size`` and ``out_size``
            of the inner CNOBlock.
        latent_dim: Channel width between the inner block and the final
            convolution.
    """

    def __init__(self, in_channels, out_channels, size, latent_dim = 64):
        super().__init__()

        # Inner block keeps the resolution and never uses batch norm.
        self.inter_CNOBlock = CNOBlock(in_channels, latent_dim,
                                       in_size=size, out_size=size,
                                       use_bn=False)
        self.convolution = nn.Conv2d(latent_dim, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.convolution(self.inter_CNOBlock(x))

#--------------------
# Residual Block:
#--------------------

class ResidualBlock(nn.Module):
    """Residual block: Conv -> BN -> CNO activation -> Conv -> BN, added to the input.

    Both convolutions are 3x3 with ``padding=1`` and keep the channel count;
    the activation is built with ``in_size == out_size == size``.

    Args:
        channels: Number of channels (identical for input and output).
        size: Spatial side length passed to the activation.
        use_bn: Use ``BatchNorm2d`` after each convolution when true,
            ``Identity`` otherwise.
    """

    def __init__(self, channels, size, use_bn = True):
        super().__init__()

        self.channels = channels
        self.size = size

        def same_width_conv():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def norm_layer():
            return nn.BatchNorm2d(channels) if use_bn else nn.Identity()

        # Creation order (conv1, conv2, norm1, norm2, act) is kept on purpose.
        self.convolution1 = same_width_conv()
        self.convolution2 = same_width_conv()
        self.batch_norm1 = norm_layer()
        self.batch_norm2 = norm_layer()
        self.act = CNO_LReLu(in_size=size, out_size=size)

    def forward(self, x):
        residual = self.batch_norm1(self.convolution1(x))
        residual = self.act(residual)
        residual = self.batch_norm2(self.convolution2(residual))
        return x + residual


class ResNet(nn.Module):
    """A stack of ``num_blocks`` ResidualBlocks applied one after another.

    Args:
        channels: Channel count handed to every residual block.
        size: Spatial side length handed to every residual block.
        num_blocks: Number of residual blocks in the stack.
        use_bn: Forwarded to every residual block.
    """

    def __init__(self, channels, size, num_blocks, use_bn = True):
        super().__init__()

        self.channels = channels
        self.size = size
        self.num_blocks = num_blocks

        blocks = [ResidualBlock(channels=channels, size=size, use_bn=use_bn)
                  for _ in range(num_blocks)]
        self.res_nets = nn.Sequential(*blocks)

    def forward(self, x):
        for block in self.res_nets:
            x = block(x)
        return x
    
class CNO2d(nn.Module):
    """U-Net shaped convolutional operator built from CNO blocks.

    Pipeline implemented by ``forward``:
    lift -> N_layers x [ResNet (kept as skip) + down block] -> neck ResNet
    -> N_layers x [expansion block (+ concat with skip) + up block]
    -> concat with the first skip -> project.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        size: Spatial side length of input and output (required).
        N_layers: Number of down (D) blocks, and equally of up (U) blocks.
        N_res: Residual blocks per level outside the neck.
        N_res_neck: Residual blocks in the neck.
        channel_multiplier: Base channel width; level ``i`` of the encoder
            outputs ``2**i * channel_multiplier`` channels and the lifted
            input has ``channel_multiplier // 2`` channels.
        use_bn: Use batch norm in encoder/decoder/ResNet blocks (the
            lift/project blocks never use it).
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 size,
                 N_layers,
                 N_res = 4,
                 N_res_neck = 4,
                 channel_multiplier = 16,
                 use_bn = True,
                 ):
        super().__init__()

        self.N_layers = int(N_layers)
        self.lift_dim = channel_multiplier // 2
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.channel_multiplier = channel_multiplier
        depth = self.N_layers

        # ---- channel counts per level -------------------------------------
        self.encoder_features = [self.lift_dim] + [
            (2 ** level) * self.channel_multiplier for level in range(depth)
        ]
        # Every decoder stage except the first receives a concatenated skip
        # tensor, hence twice the channels.
        self.decoder_features_in = [
            width if stage == 0 else 2 * width
            for stage, width in enumerate(reversed(self.encoder_features[1:]))
        ]
        self.decoder_features_out = list(reversed(self.encoder_features[:-1]))

        # ---- spatial sizes per level ---------------------------------------
        self.encoder_sizes = [size // 2 ** level for level in range(depth + 1)]
        self.decoder_sizes = [size // 2 ** (depth - level) for level in range(depth + 1)]

        # ---- lift / project --------------------------------------------------
        self.lift = LiftProjectBlock(in_channels=in_dim,
                                     out_channels=self.encoder_features[0],
                                     size=size)
        self.project = LiftProjectBlock(
            in_channels=self.encoder_features[0] + self.decoder_features_out[-1],
            out_channels=out_dim,
            size=size)

        # ---- encoder, encoder/decoder linker, decoder ----------------------
        down_blocks = []
        for level in range(depth):
            down_blocks.append(CNOBlock(in_channels=self.encoder_features[level],
                                        out_channels=self.encoder_features[level + 1],
                                        in_size=self.encoder_sizes[level],
                                        out_size=self.encoder_sizes[level + 1],
                                        use_bn=use_bn))
        self.encoder = nn.ModuleList(down_blocks)

        # Blocks that bring each encoder-level tensor to the matching decoder
        # resolution before it is consumed by the decoder.
        link_blocks = []
        for level in range(depth + 1):
            link_blocks.append(CNOBlock(in_channels=self.encoder_features[level],
                                        out_channels=self.encoder_features[level],
                                        in_size=self.encoder_sizes[level],
                                        out_size=self.decoder_sizes[depth - level],
                                        use_bn=use_bn))
        self.ED_expansion = nn.ModuleList(link_blocks)

        up_blocks = []
        for stage in range(depth):
            up_blocks.append(CNOBlock(in_channels=self.decoder_features_in[stage],
                                      out_channels=self.decoder_features_out[stage],
                                      in_size=self.decoder_sizes[stage],
                                      out_size=self.decoder_sizes[stage + 1],
                                      use_bn=use_bn))
        self.decoder = nn.ModuleList(up_blocks)

        # ---- ResNets: one per encoder level plus the neck -------------------
        self.N_res = int(N_res)
        self.N_res_neck = int(N_res_neck)

        # The per-level ResNets are instantiated before the neck, and the
        # neck is registered before the per-level container.
        level_resnets = [ResNet(channels=self.encoder_features[level],
                                size=self.encoder_sizes[level],
                                num_blocks=self.N_res,
                                use_bn=use_bn)
                         for level in range(depth)]

        self.res_net_neck = ResNet(channels=self.encoder_features[depth],
                                   size=self.encoder_sizes[depth],
                                   num_blocks=self.N_res_neck,
                                   use_bn=use_bn)

        self.res_nets = nn.Sequential(*level_resnets)

    def forward(self, x):
        depth = self.N_layers

        x = self.lift(x)

        # Encoder: keep the ResNet output of each level, then downsample x.
        skips = []
        for level in range(depth):
            skips.append(self.res_nets[level](x))
            x = self.encoder[level](x)

        # Bottleneck.
        x = self.res_net_neck(x)

        # Decoder: the first stage only passes through the linker block;
        # later stages concatenate the linked skip tensor along channels.
        for stage in range(depth):
            level = depth - stage
            link = self.ED_expansion[level]
            if stage == 0:
                x = link(x)
            else:
                x = torch.cat((x, link(skips[level])), 1)
            x = self.decoder[stage](x)

        # Final concat with the top-level skip, then projection.
        x = torch.cat((x, self.ED_expansion[0](skips[0])), 1)
        return self.project(x)
    
class TransformerModel(nn.Module):
    """Transformer gate that maps a time series of patches to per-patch class weights.

    Each ``h x w`` patch is flattened and linearly embedded; for every
    (sample, patch) pair the sequence of its time steps is processed by a
    transformer encoder, averaged over time and mapped to ``num_classes``
    softmax weights.

    Args:
        num_classes: Size of the softmax output per patch.
        embed_dim: Embedding width of the transformer.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        spatial: Side length of a (square) patch; the embedding layer expects
            ``spatial * spatial`` input features.
    """

    def __init__(self, num_classes=4,
                 embed_dim=64,
                 num_heads=4,
                 num_layers=2,
                 spatial=64):
        super(TransformerModel, self).__init__()
        # Flatten each patch and embed it.
        self.patch_embedding = nn.Linear(spatial * spatial, embed_dim)
        # batch_first=True: forward() feeds (batch * patches, time, embed_dim),
        # so the sequence axis must be dim 1. With the default
        # (batch_first=False) attention would run across samples/patches
        # instead of across time steps.
        self.transformer_layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True),
            num_layers)
        # Output layer.
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return softmax weights of shape ``(batch, patches, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, time, patches, h, w)`` with
                ``h * w == spatial * spatial``.
        """
        batch_size, time_steps, num_patches, h, w = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(batch_size, time_steps, num_patches, h * w)  # (batch, time, patches, h*w)
        x = self.patch_embedding(x)  # (batch, time, patches, embed_dim)
        x = x.permute(0, 2, 1, 3)  # (batch, patches, time, embed_dim)
        x = x.flatten(0, 1)  # (batch * patches, time, embed_dim)

        x = self.transformer_layers(x)  # (batch * patches, time, embed_dim)

        # Aggregate over the time dimension.
        x = x.mean(dim=1)  # (batch * patches, embed_dim)
        x = x.view(batch_size, num_patches, -1)  # (batch, patches, embed_dim)

        x = self.fc(x)  # (batch, patches, num_classes)
        return F.softmax(x, dim=-1)  # normalise over the class/expert axis
class MoEGate(nn.Module):
    """Mixture-of-experts gate producing per-patch expert weights and a top-k mask.

    The whole ``[batch, times, patches, h, w]`` tensor is routed in one call
    (instead of patch by patch) and yields ``[batch, patches, num_experts]``
    weights.

    Args:
        input_size: Sequence whose element at index 2 is the patch side
            length handed to the underlying ``TransformerModel``.
        num_experts: Number of experts, i.e. size of the weight vector.
        topk: Number of experts kept per patch in the returned mask.
        ini_prior: Optional per-expert prior (length ``num_experts``) added to
            the gate output before a second softmax; ``None`` disables it.
    """

    def __init__(self, input_size:list,
                 num_experts:int,
                 topk:int,
                 ini_prior:list):
        super(MoEGate, self).__init__()

        print("***ini_gate_input_size",input_size)
        self.top_k = topk
        self.gate = TransformerModel(num_classes=num_experts,
                                     spatial=input_size[2])

        # Keep the prior as a non-persistent buffer: it follows .to()/.cuda()
        # together with the module (no hard-coded CUDA requirement) while the
        # state_dict keys stay exactly as before.
        prior = None
        if ini_prior is not None:
            prior = torch.tensor(ini_prior, dtype=torch.float32)
        self.register_buffer("ini_prior", prior, persistent=False)

    def forward(self, x):
        """Compute gate weights, the top-k mask and the top-k expert indices.

        Returns:
            Tuple ``(weights, mask, topk_indices)`` where ``weights`` and
            ``mask`` share the shape of the gate output (mask is 1 at the
            top-k positions along dim 2, 0 elsewhere) and ``topk_indices``
            holds the selected indices along dim 2. ``weights`` is also
            stored on ``self.weights`` for ``load_balancing_loss``.
        """
        weights = self.gate(x)  # [batch, patch, experts]
        if self.ini_prior is not None:
            # Add the prior (broadcast over batch and patch) and renormalise.
            prior = self.ini_prior.to(device=weights.device, dtype=weights.dtype)
            weights = F.softmax(weights + prior, dim=2)
        self.weights = weights

        _, topk_indices = weights.topk(self.top_k, dim=2, largest=True, sorted=True)
        # Positions of the top-k experts are set to 1, everything else to 0.
        mask = torch.zeros_like(weights).scatter_(2, topk_indices, 1)
        return weights, mask, topk_indices

    def load_balancing_loss(self):
        """Entropy of the average expert usage, averaged over the batch.

        Uses ``self.weights`` from the most recent ``forward`` call
        (``[batch, patch, experts]``): weights are averaged over the patch
        axis, normalised to sum to one per sample, and the entropy of that
        distribution is returned as a scalar.
        """
        # Mean usage of every expert over all patches -> [batch, experts].
        expert_usage = self.weights.mean(dim=1)

        # Normalise; the epsilon guards against division by zero.
        usage_total = expert_usage.sum(dim=1, keepdim=True) + 1e-10
        expert_usage_normalized = expert_usage / usage_total

        # Entropy per sample -> [batch].
        entropy = -torch.sum(
            expert_usage_normalized * torch.log(expert_usage_normalized + 1e-10),
            dim=1)

        return entropy.mean()

class CNO_Class(nn.Module):
    """CNO baseline wrapper: one ``CNO2d`` or several experts with optional MoE gating.

    Keyword Args:
        N_layers: Forwarded to ``CNO2d`` as ``N_layers``.
        N_res: Forwarded to ``CNO2d`` as ``N_res`` for the single-operator and
            the shared ("KS") operator.
        spatial: Forwarded to ``CNO2d`` as ``size``.
        in_channels: Forwarded to ``CNO2d`` as ``in_dim``.
        out_channels: Forwarded to ``CNO2d`` as ``out_dim``.
        scales: String such as ``"1x1_DC"`` (the default). The integer before
            ``x`` is the side of the patch grid; the token after the first
            ``_`` is the scale type (``"DC"``: independent experts, ``"KS"``:
            one expert instance shared by all patches). MoE gating is enabled
            when the string contains ``"MOE"``.
        n_modes: Either a single value (a string is evaluated) that is
            repeated once per patch, or a flat list of string fragments that
            is parsed pairwise into integer tuples. Here it only determines
            the number of experts and their names.
        ini_prior, moe_gate_input_size, moe_topk: Required when MoE gating is
            enabled; forwarded to ``MoEGate``.
    """

    def __init__(self, **kwargs):
        super(CNO_Class, self).__init__()
        self.base_model = "CNO"

        N_layers = kwargs["N_layers"]
        N_res = kwargs["N_res"]
        size = kwargs["spatial"]
        in_channels = kwargs["in_channels"]
        out_channels = kwargs["out_channels"]
        self.scale_value = kwargs.get('scales', "1x1_DC")
        self.int_scale_value = int(self.scale_value.split('x')[0])
        self.scale_type = self.scale_value.split('_')[1]
        n_patches = self.int_scale_value * self.int_scale_value

        raw_modes = kwargs["n_modes"]
        modes_list = [0 for _ in range(n_patches)]
        print(raw_modes)
        print("***mode_list",modes_list)
        if isinstance(raw_modes, list):
            # Flat list of fragments such as ["(32", "32)", "(128", "128)"]:
            # consume two entries at a time and strip the parentheses,
            # giving e.g. [(32, 32), (128, 128)].
            parsed = []
            for i in range(0, len(raw_modes), 2):
                first_elem = raw_modes[i].replace('(', '').replace(')', '')
                second_elem = raw_modes[i + 1].replace('(', '').replace(')', '')
                parsed.append((int(first_elem), int(second_elem)))
            modes_list = parsed
        else:
            # One setting repeated for every patch.
            # SECURITY NOTE(review): eval() executes arbitrary code; if this
            # string can come from untrusted configuration, switch to
            # ast.literal_eval.
            modes = eval(raw_modes) if isinstance(raw_modes, str) else raw_modes
            modes_list = [modes] * n_patches

        print("test_modes_list",modes_list)

        # One expert per entry of modes_list.
        self.experts_nums = len(modes_list)
        self.experts_name = []

        self.MOE_EN = "MOE" in self.scale_value
        print("**MOE_en",self.MOE_EN)
        print("***scale_type",self.scale_type)

        if self.experts_nums == 1:
            # A single operator working on the unsplit input.
            self._operator = CNO2d(in_dim = in_channels, out_dim = out_channels,
                                   size = size, N_layers = N_layers, N_res = N_res)
        else:
            self.ms_operator = nn.ModuleList()

            if self.scale_type == "DC":
                # Independent experts, one per entry.
                # NOTE(review): N_res is not forwarded here, so these experts
                # use CNO2d's default; kept as is to preserve existing
                # architectures/checkpoints - confirm intent.
                for i in range(self.experts_nums):
                    self.experts_name.append(f"FNO_{modes_list[i]}")
                    self.ms_operator.append(CNO2d(in_dim = in_channels, out_dim = out_channels,
                                                  size = size, N_layers = N_layers))
            elif self.scale_type == "KS":
                # One shared instance registered once per patch. It is built
                # with CNO2d's own keyword names (the previous FNO-style
                # keywords raised a TypeError).
                shared_operator = CNO2d(in_dim = in_channels, out_dim = out_channels,
                                        size = size, N_layers = N_layers, N_res = N_res)
                for _ in range(n_patches):
                    self.ms_operator.append(shared_operator)

        print("***Paras_",self.count_paras())

        if self.MOE_EN:
            ini_prior = kwargs["ini_prior"]
            # The prior must provide one value per expert.
            assert len(ini_prior) == self.experts_nums, f"Expected {self.experts_nums} experts, but got {len(ini_prior)} in ini_prior."
            print("MOE_Gate_paras_input",kwargs["moe_gate_input_size"])
            gate_size = kwargs["moe_gate_input_size"]
            self.gate = MoEGate(input_size = gate_size,
                                num_experts = self.experts_nums,
                                topk = kwargs["moe_topk"],
                                ini_prior = ini_prior)

            # Gate weights of the first sample in the batch, kept for
            # inspection at evaluation time: [patches, experts].
            self.weight_map = torch.zeros(n_patches, self.experts_nums)

    @property
    def operator(self):
        """The single ``CNO2d`` (only set when there is exactly one expert)."""
        return self._operator

    @property
    def Mscale_operator(self):
        """The ``nn.ModuleList`` of experts (only set when there are several)."""
        return self.ms_operator

    def parallel_operator(self, re_order):
        """Placeholder; not implemented and returns ``None``.

        This used to be declared as a property, which cannot take the
        ``re_order`` argument and therefore raised on every access.
        """
        return None

    def oldMOE_operator(self, x) -> tuple:
        """Legacy dense MoE: every expert processes the patch with its own index.

        Args:
            x: Split input; the code slices dim 2 with the expert index
                (presumably ``[batch, steps, patches, h, w]`` - confirm).

        Returns:
            ``(final_output, load_balancing_loss)``.

        Raises:
            RuntimeError: If MoE gating is not enabled on this model.
        """
        print("moe_in",x.shape)
        if not (self.MOE_EN and hasattr(self, 'gate')):
            # Without a gate there is nothing to return; fail explicitly
            # instead of hitting unbound local variables.
            raise RuntimeError("oldMOE_operator requires MoE gating to be enabled")

        self.weights, self.mask, indices = self.gate(x)

        out = [expert(x[:, :, index, :, :])
               for index, expert in enumerate(self.ms_operator)]
        outputs = torch.stack(out, dim=1)

        # Only the 0/1 top-k mask is applied (the soft weights are not), so
        # every selected expert contributes with weight 1.
        masked_weights = 1 * self.mask
        weighted_outputs = outputs * masked_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        final_output = weighted_outputs.sum(dim=1)  # sum over the expert axis
        load_balancing_loss = self.gate.load_balancing_loss()

        return final_output, load_balancing_loss

    @property
    def load_balancing_loss(self):
        """Load-balancing loss of the gate for its most recent forward pass."""
        return self.gate.load_balancing_loss()

    @property
    def Moe_weight_info(self):
        """``(weight_map, mask)`` as recorded by the last MoE call."""
        return self.weight_map, self.mask

    def forward(self, x):
        """Not implemented; returns ``None``."""
        pass

    def count_paras(self):
        """Total number of elements over all registered parameters."""
        return sum(p.numel() for p in self.parameters())

    def MOE_operator0922(self, x, policy="Stragey_1"):
        """Legacy per-patch MoE routing.

        NOTE(review): each 4-D patch is passed to the gate, while the
        ``TransformerModel`` behind ``MoEGate`` in this file unpacks a 5-D
        input - confirm this path is still usable.

        Args:
            x: Tensor of shape ``[batch, steps, patches, x_domain, y_domain]``.
            policy: ``'Stragey_1'`` (weights as is) or ``'Stragey_2'``
                (squared weights).

        Returns:
            ``(final_output, load_balancing_loss)`` when MoE gating is
            enabled, otherwise ``None``.
        """
        batch, steps, patches, x_domain, y_domain = x.shape
        load_balancing_loss = 0

        if not (self.MOE_EN and hasattr(self, 'gate')):
            return None

        final_output = []
        for index in range(patches):
            # Current patch; the time axis is not split.
            patch = x[:, :, index, :, :]  # [batch, steps, x_domain, y_domain]

            self.weights, self.mask, indices = self.gate(patch)

            # Record the first sample for evaluation; detached so the
            # record does not hold on to the autograd graph.
            self.weight_map[index, :] = self.weights[0, :].detach()

            experts_number = self.gate.top_k
            masked_weights = self.weights * self.mask  # keep top-k weights only

            # Accumulator for the mixed output; a single output step.
            experts_output = torch.zeros(batch, 1, x_domain, y_domain, device=x.device)

            # Every expert processes the whole batch; stack along dim 1.
            all_expert_outputs = torch.stack(
                [expert(patch) for expert in self.ms_operator], dim=1)

            if policy == "Stragey_1":
                for idx in range(experts_number):
                    expert_indices = indices[:, idx]
                    expert_weights = masked_weights[:, idx]

                    expert_indices = expert_indices.unsqueeze(1)  # [batch, 1]

                    # Pick, per sample, the output of the selected expert.
                    selected_expert_outputs = torch.gather(
                        all_expert_outputs, 1,
                        expert_indices.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, x_domain, y_domain))

                    experts_output += selected_expert_outputs.squeeze(1) * expert_weights.view(-1, 1, 1, 1)

            elif policy == "Stragey_2":
                # Same selection, but with squared weights.
                for idx in range(experts_number):
                    expert_indices = indices[:, idx]
                    expert_weights = masked_weights[:, idx] ** 2
                    selected_expert_outputs = torch.gather(
                        all_expert_outputs, 1,
                        expert_indices.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, steps, x_domain, y_domain))
                    experts_output += selected_expert_outputs.squeeze(1) * expert_weights.view(-1, 1, 1, 1)

            final_output.append(experts_output)

            # Accumulate the gate's load-balancing loss per patch.
            load_balancing_loss += self.gate.load_balancing_loss()

        # Stack the per-patch outputs along a new patch axis (dim 2).
        final_output = torch.stack(final_output, dim=2)

        return final_output, load_balancing_loss

    def MOE_operator(self, x, policy="Strategy_1"):
        """Sparse MoE forward: each expert only processes the patches routed to it.

        Args:
            x: Tensor of shape ``[batch, steps, patches, x_domain, y_domain]``.
            policy: Currently unused.

        Returns:
            ``(weighted_outputs, load_balancing_loss)`` where
            ``weighted_outputs`` has shape
            ``[batch, 1, patches, x_domain, y_domain]``.
            NOTE(review): the returned loss is always ``0`` here; the gate's
            loss is available through the ``load_balancing_loss`` property.
        """
        batch, steps, patches, x_domain, y_domain = x.shape
        load_balancing_loss = 0

        # Gate output: weights/mask of shape [batch, patches, num_experts].
        patch_weights, self.mask, indices = self.gate(x)

        # Record the first sample's weights for inspection; detached so the
        # record does not keep the autograd graph alive.
        self.weight_map = patch_weights[0, :, :].detach()

        outputs = torch.zeros(batch, patches, x_domain, y_domain, self.experts_nums, device=x.device)

        # One row per (sample, patch) pair: [batch*patches, steps, x, y].
        x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(batch * patches, steps, x_domain, y_domain)

        for i in range(self.experts_nums):
            expert_mask_flat = self.mask[:, :, i].reshape(-1)  # [batch * patches]
            selected_indices = torch.nonzero(expert_mask_flat).squeeze(1)

            if selected_indices.numel() == 0:
                continue  # nothing routed to this expert

            x_selected = x_reshaped[selected_indices]
            output_selected = self.ms_operator[i](x_selected)

            # Map the flat row index back to (sample, patch).
            batch_indices = selected_indices // patches
            patch_indices = selected_indices % patches

            # Only the first output channel is stored.
            outputs[batch_indices, patch_indices, :, :, i] = output_selected[:, 0, :, :]

        # Weighted sum over the expert axis:
        # [batch, patches, x, y, experts] x [batch, patches, experts].
        weighted_outputs = torch.einsum('bpxyi,bpi->bpxy', outputs, patch_weights)

        # Insert the step axis -> [batch, 1, patches, x_domain, y_domain].
        weighted_outputs = weighted_outputs.unsqueeze(1)

        return weighted_outputs, load_balancing_loss
