from neuralop.models import FNO
import torch
import torch.nn as nn
import torch.nn.functional as F

class oldMoEGate(nn.Module):
    """Legacy dense MoE gate producing one expert-weight vector per sample.

    Superseded by ``MoEGate`` (which scores each patch separately). Each
    sample is flattened and fed through a two-layer MLP ending in a softmax
    over experts.

    Args:
        input_size: Per-sample input description; the original usage was
            ``[10, 1, 64, 64]``. The flattened size is
            ``input_size[0] * input_size[1] * input_size[2] * input_size[3]``;
            for a 3-element list ``input_size[2]`` is reused as the width
            (square domain), matching the previous behaviour.
        num_experts: Number of experts to score.
        topk: Number of experts kept in the returned mask.
        ini_prior: Optional per-expert prior added to the softmax output
            before a second softmax; ``None`` disables it.
    """

    def __init__(self, input_size: list,
                 num_experts: int,
                 topk: int,
                 ini_prior: list):
        # Bug fix: this used to call super(MoEGate, self), which raises
        # TypeError because oldMoEGate is not a subclass of MoEGate.
        super().__init__()

        print("***ini_gate_input_size", input_size)  # originally expected [10, 1, 64, 64]
        # input_size[2] used to be multiplied twice (square-only); use the
        # real width when it is provided.
        width = input_size[3] if len(input_size) > 3 else input_size[2]
        input_dim = input_size[0] * input_size[1] * input_size[2] * width
        self.top_k = topk
        self.gate = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, 512),
            nn.ELU(),  # ELU activation after the 512-wide hidden layer
            nn.Linear(512, num_experts),
            nn.Softmax(dim=1),
        )
        # Non-persistent buffer instead of a hard .cuda(): it follows
        # module.to(device), works on CPU-only hosts, and adds no state_dict
        # key, so existing checkpoints still load.
        prior = (None if ini_prior is None
                 else torch.tensor(ini_prior, dtype=torch.float32))
        self.register_buffer("ini_prior", prior, persistent=False)

    def forward(self, x):
        """Score the experts for each sample.

        Args:
            x: Batch whose per-sample element count equals the ``input_dim``
                computed in ``__init__`` (it is flattened from dim 1 on).

        Returns:
            ``(weights, mask, topk_indices)``. ``weights`` is
            ``[batch, num_experts]``. ``mask`` has the same shape, with 1 at
            the ``top_k`` largest weights per row and 0 elsewhere.
            ``topk_indices`` is ``[batch, top_k]`` in descending weight order.
            ``weights`` is also cached on ``self.weights`` for
            ``load_balancing_loss``.
        """
        self.weights = self.gate(x)
        if self.ini_prior is not None:
            # Add the prior (broadcast over the batch, matched to the weights'
            # device/dtype), then renormalise with a second softmax.
            self.weights = F.softmax(self.weights + self.ini_prior.to(self.weights), dim=1)

        _, topk_indices = self.weights.topk(self.top_k, dim=1, largest=True, sorted=True)
        mask = torch.zeros_like(self.weights).scatter_(1, topk_indices, 1)
        return self.weights, mask, topk_indices

    def load_balancing_loss(self):
        """Return the entropy of the batch-averaged expert weights.

        Uses the weights cached by the most recent ``forward`` call.
        Minimising the entropy pushes towards sparse expert usage; it can be
        maximised early in training to spread usage instead.
        """
        expert_usage = self.weights.mean(0)
        return -torch.sum(expert_usage * torch.log(expert_usage + 1e-10))
class TransformerModel(nn.Module):
    """Transformer gate network producing per-patch class (expert) probabilities.

    Each ``h x w`` patch at each time step is flattened and linearly
    embedded. A Transformer encoder then processes the time sequence of every
    (sample, patch) pair. The result is averaged over time and projected to
    ``num_classes`` softmax probabilities.

    Args:
        num_classes: Number of output classes (experts).
        embed_dim: Embedding width (``d_model``).
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        spatial: Patch side length; the embedding expects
            ``h * w == spatial * spatial``.
        batch_first: ``True`` (default) makes the encoder attend over the
            time steps of each (sample, patch) sequence. ``False``
            reproduces the previous behaviour: the encoder then treated
            ``batch * patches`` as the sequence axis and mixed information
            across samples and patches.
    """

    def __init__(self, num_classes=4,
                 embed_dim=64,
                 num_heads=4,
                 num_layers=2,
                 spatial=64,
                 batch_first=True):
        super().__init__()
        # Flatten each patch and embed it.
        self.patch_embedding = nn.Linear(spatial * spatial, embed_dim)
        # Bug fix: the layer used to rely on the default batch_first=False,
        # although it is fed (batch * patches, time, embed) tensors.
        self.transformer_layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=batch_first),
            num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)  # output head

    def forward(self, x):
        """Return class probabilities for every patch.

        Args:
            x: Tensor of shape ``[batch, time, patches, h, w]``.

        Returns:
            Tensor of shape ``[batch, patches, num_classes]`` whose last axis
            sums to 1.
        """
        batch_size, time_steps, num_patches, _, _ = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(batch_size, time_steps, num_patches, -1)  # (batch, time, patches, h*w)
        x = self.patch_embedding(x)                             # (batch, time, patches, embed)
        x = x.permute(0, 2, 1, 3).flatten(0, 1)                 # (batch*patches, time, embed)
        x = self.transformer_layers(x)                          # (batch*patches, time, embed)
        x = x.mean(dim=1)                                       # aggregate over time
        x = x.reshape(batch_size, num_patches, -1)              # (batch, patches, embed)
        x = self.fc(x)                                          # (batch, patches, num_classes)
        return F.softmax(x, dim=-1)
class MoEGate(nn.Module):
    """Per-patch Mixture-of-Experts gate backed by ``TransformerModel``.

    Earlier versions gated one patch at a time. This version scores all
    patches in a single call.

    Input:  ``[batch, times, patches, H, W]``.
    Output: ``[batch, patches, num_experts]`` weights, plus a top-k mask and
    the top-k indices.

    Args:
        input_size: Gate input description; ``input_size[2]`` is used as the
            patch side length (originally ``[10, 4, 64, 64]``).
        num_experts: Number of experts to score.
        topk: Number of experts selected per patch.
        ini_prior: Optional per-expert prior added to the gate probabilities
            before a second softmax; ``None`` disables it.
    """

    def __init__(self, input_size: list,
                 num_experts: int,
                 topk: int,
                 ini_prior: list):
        super().__init__()
        print("***ini_gate_input_size", input_size)  # originally expected [10, 4, 64, 64]
        self.top_k = topk
        self.gate = TransformerModel(num_classes=num_experts,
                                     spatial=input_size[2])
        # Non-persistent buffer instead of a hard .cuda(): it follows
        # module.to(device), works without a GPU, and keeps state_dict keys
        # unchanged.
        prior = (None if ini_prior is None
                 else torch.tensor(ini_prior, dtype=torch.float32))
        self.register_buffer("ini_prior", prior, persistent=False)

    def forward(self, x):
        """Score the experts for every patch.

        Args:
            x: Tensor of shape ``[batch, times, patches, H, W]``.

        Returns:
            ``(weights, mask, topk_indices)``. ``weights`` is
            ``[batch, patches, num_experts]``. ``mask`` has the same shape,
            with 1 at each patch's ``top_k`` largest weights and 0 elsewhere.
            ``topk_indices`` is ``[batch, patches, top_k]`` in descending
            weight order. ``weights`` is cached on ``self.weights`` for
            ``load_balancing_loss``.
        """
        self.weights = self.gate(x)  # [batch, patches, num_experts]
        if self.ini_prior is not None:
            # The prior broadcasts over batch and patches (and is matched to
            # the weights' device/dtype); renormalise over the expert axis.
            self.weights = F.softmax(self.weights + self.ini_prior.to(self.weights), dim=2)

        _, topk_indices = self.weights.topk(self.top_k, dim=2, largest=True, sorted=True)
        # 1 at every top-k position, 0 elsewhere.
        mask = torch.zeros_like(self.weights).scatter_(2, topk_indices, 1)
        return self.weights, mask, topk_indices

    def load_balancing_loss(self):
        """Return the batch-mean entropy of per-sample expert usage.

        Expert usage is each expert's weight averaged over patches and
        renormalised to sum to 1. It is computed from the weights cached by
        the most recent ``forward`` call. Minimising the entropy encourages
        sparse expert usage; it can be maximised early in training instead.

        Returns:
            Scalar tensor.
        """
        expert_usage = self.weights.mean(dim=1)  # [batch, num_experts]
        # The epsilons guard against division by zero and log(0).
        usage = expert_usage / (expert_usage.sum(dim=1, keepdim=True) + 1e-10)
        entropy = -torch.sum(usage * torch.log(usage + 1e-10), dim=1)  # [batch]
        return entropy.mean()

class FNO_Class(nn.Module):
    """FNO baseline, optionally split into several experts behind an MoE gate.

    Keyword Args:
        hidden_channels: FNO hidden width (e.g. 32 or 64).
        in_channels: FNO input channels.
        out_channels: FNO output channels.
        n_modes: Fourier modes. Either a string such as ``"(64,64)"``, which
            is evaluated and reused for every expert, or a flat list of
            strings read pairwise. Each pair, once parentheses are stripped,
            must be two integers, e.g. ``["(32", "32)", "(16", "16)"]`` ->
            ``[(32, 32), (16, 16)]``, with one entry per expert.
        scales: Scale spec ``"<k>x<k>_<type>..."`` (default ``"1x1_DC"``).
            ``<type>`` is ``DC`` (divide and conquer: independent experts) or
            ``KS`` (one FNO shared by every expert slot). If the string
            contains ``"MOE"``, an ``MoEGate`` is built.
        ini_prior: Per-expert gate prior (MoE only); its length must equal
            the number of experts.
        moe_gate_input_size: ``input_size`` passed to ``MoEGate`` (MoE only).
        moe_topk: Experts selected per patch (MoE only).
    """

    def __init__(self, **kwargs):
        """Build the expert operator(s) and, if requested, the MoE gate.

        Raises:
            ValueError: if ``n_modes`` is a list of odd length, or if several
                experts are requested with a scale type other than DC/KS.
            AssertionError: if ``ini_prior`` has the wrong length (MoE only).
        """
        super().__init__()
        self.base_model = "FNO"
        hidden_channels = kwargs["hidden_channels"]  # e.g. 32 or 64
        in_channels = kwargs["in_channels"]
        out_channels = kwargs["out_channels"]
        self.scale_value = kwargs.get('scales', "1x1_DC")  # default: 1x1, independent
        self.int_scale_value = int(self.scale_value.split('x')[0])
        # "DC" = divide and conquer (independent experts); "KS" = shared FNO.
        self.scale_type = self.scale_value.split('_')[1]
        num_scales = self.int_scale_value * self.int_scale_value

        print(kwargs["n_modes"])
        print("***mode_list", [0] * num_scales)
        modes_list = self._parse_n_modes(kwargs["n_modes"], num_scales)
        print("test_modes_list", modes_list)

        self.experts_nums = len(modes_list)  # e.g. 4 or 9
        self.experts_name = []  # one "FNO_<modes>" label per DC expert

        self.MOE_EN = "MOE" in self.scale_value
        print("**MOE_en", self.MOE_EN)
        print("***scale_type", self.scale_type)

        if self.experts_nums == 1:
            # A single operator; inputs are used unchanged.
            self._operator = FNO(n_modes=modes_list[0], hidden_channels=hidden_channels,
                                 in_channels=in_channels, out_channels=out_channels)
        else:
            self.ms_operator = nn.ModuleList()
            if self.scale_type == "DC":
                # One independent FNO per expert.
                for modes in modes_list:
                    self.experts_name.append(f"FNO_{modes}")
                    self.ms_operator.append(FNO(n_modes=modes, hidden_channels=hidden_channels,
                                                in_channels=in_channels, out_channels=out_channels))
            elif self.scale_type == "KS":
                # One FNO instance shared by every expert slot. Bug fix: this
                # used the name `modes`, which is undefined when n_modes is a
                # list; modes_list[0] is the same value in the string case.
                sharing_fno = FNO(n_modes=modes_list[0], hidden_channels=hidden_channels,
                                  in_channels=in_channels, out_channels=out_channels)
                for _ in range(self.experts_nums):
                    self.ms_operator.append(sharing_fno)
            else:
                # Previously this silently left ms_operator empty.
                raise ValueError(
                    f"Unknown scale type {self.scale_type!r} in scales={self.scale_value!r}; "
                    "expected 'DC' or 'KS'")

        print("***Paras_", self.count_paras())

        if self.MOE_EN:
            ini_prior = kwargs["ini_prior"]
            # The prior must provide one entry per expert.
            assert len(ini_prior) == self.experts_nums, f"Expected {self.experts_nums} experts, but got {len(ini_prior)} in ini_prior."
            print("MOE_Gate_paras_input", kwargs["moe_gate_input_size"])
            self.gate = MoEGate(input_size=kwargs["moe_gate_input_size"],
                                num_experts=self.experts_nums,
                                topk=kwargs["moe_topk"],
                                ini_prior=ini_prior)
            # Expert weights recorded for inspection during evaluation (the
            # MoE operators overwrite it with the first sample's weights).
            self.weight_map = torch.zeros(num_scales, self.experts_nums)

    @staticmethod
    def _parse_n_modes(n_modes, num_scales):
        """Return the per-expert list of FNO mode tuples from the ``n_modes`` kwarg."""
        if isinstance(n_modes, list):
            if len(n_modes) % 2:
                raise ValueError(f"n_modes list must contain pairs, got {len(n_modes)} items")
            # Consecutive string pairs such as "(32", "32)" -> (32, 32).
            return [
                (int(first.replace('(', '').replace(')', '')),
                 int(second.replace('(', '').replace(')', '')))
                for first, second in zip(n_modes[0::2], n_modes[1::2])
            ]
        # NOTE(security): eval() runs arbitrary code from the config value;
        # switch to ast.literal_eval if configs can come from untrusted users.
        modes = eval(n_modes)  # e.g. "(64,64)" -> (64, 64)
        return [modes] * num_scales

    @property
    def operator(self):
        """The single FNO (present only when there is exactly one expert)."""
        return self._operator

    @property
    def Mscale_operator(self):
        """``nn.ModuleList`` of expert FNOs (present only with several experts)."""
        return self.ms_operator

    def parallel_operator(self, re_order=None):
        """Placeholder for a parallel expert dispatch; not implemented.

        This used to be declared as a ``@property`` that took an argument,
        which made every access raise ``TypeError``.
        """
        raise NotImplementedError("parallel_operator is not implemented")

    def oldMOE_operator(self, x) -> tuple:
        """Legacy dense MoE pass: expert ``i`` processes patch ``x[:, :, i]``.

        Expert outputs are stacked on dim 1 and summed under the gate's top-k
        mask. The raw gate weights are deliberately not applied. Written for
        a gate returning a ``[batch, num_experts]`` mask (the ``oldMoEGate``
        interface); ``x`` was originally ``[batch, 10, 4, 64, 64]``.

        Returns:
            ``(final_output, load_balancing_loss)``.

        Raises:
            RuntimeError: if MoE is disabled or no gate exists (previously an
                ``UnboundLocalError``).
        """
        print("moe_in", x.shape)
        if not (self.MOE_EN and hasattr(self, 'gate')):
            raise RuntimeError("oldMOE_operator requires an MoE-enabled model with a gate")

        self.weights, self.mask, _ = self.gate(x)
        outputs = torch.stack(
            [expert(x[:, :, index, :, :]) for index, expert in enumerate(self.ms_operator)],
            dim=1)
        # Only the 0/1 mask is used: multiplying by small gate weights left
        # the experts effectively untrained (weak gradients).
        masked_weights = 1 * self.mask
        weighted_outputs = outputs * masked_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        final_output = weighted_outputs.sum(dim=1)  # sum over experts
        load_balancing_loss = self.gate.load_balancing_loss()
        return final_output, load_balancing_loss

    @property
    def load_balancing_loss(self):
        """The gate's load-balancing (entropy) loss from its last forward."""
        return self.gate.load_balancing_loss()

    @property
    def Moe_weight_info(self):
        """``(weight_map, mask)`` recorded by the most recent MoE pass."""
        return self.weight_map, self.mask

    def forward(self, x):
        """Placeholder that returns ``None``.

        NOTE(review): callers presumably use ``operator`` / ``Mscale_operator``
        / ``MOE_operator`` directly — confirm before relying on ``forward``.
        """
        pass

    def count_paras(self):
        """Return the total number of parameters of this module."""
        return sum(p.numel() for p in self.parameters())

    def MOE_operator0922(self, x, policy="Stragey_1"):
        """Legacy per-patch MoE pass, kept for reference.

        Calls the gate once per patch with ``x[:, :, index]``, runs every
        expert on every patch, and sums the top-k expert outputs (first
        output step only) weighted by the selected gate weights. Written for
        a gate returning ``[batch, num_experts]`` weights.

        Args:
            x: Tensor of shape ``[batch, steps, patches, x_domain, y_domain]``.
            policy: ``"Stragey_1"`` (linear weights) or ``"Stragey_2"``
                (squared weights); any other value yields zero outputs.

        Returns:
            ``(final_output, load_balancing_loss)`` with ``final_output`` of
            shape ``[batch, 1, patches, x_domain, y_domain]``, or ``None``
            when MoE is disabled or there is no gate.
        """
        batch, steps, patches, x_domain, y_domain = x.shape
        load_balancing_loss = 0
        if not (self.MOE_EN and hasattr(self, 'gate')):
            return None

        final_output = []
        for index in range(patches):
            patch = x[:, :, index, :, :]  # [batch, steps, x_domain, y_domain]
            self.weights, self.mask, indices = self.gate(patch)
            # Record the first sample's weights for evaluation.
            self.weight_map[index, :] = self.weights[0, :]

            experts_number = self.gate.top_k
            masked_weights = self.weights * self.mask  # keep only top-k weights

            # Roll-out: a single output step per patch.
            experts_output = torch.zeros(batch, 1, x_domain, y_domain, device=x.device)
            # [batch, num_experts, out_steps, x_domain, y_domain]
            all_expert_outputs = torch.stack([expert(patch) for expert in self.ms_operator], dim=1)

            if policy in ("Stragey_1", "Stragey_2"):
                power = 1 if policy == "Stragey_1" else 2  # Stragey_2 squares the weights
                for idx in range(experts_number):
                    expert_indices = indices[:, idx].unsqueeze(1)  # [batch, 1]
                    # Bug fix: use the weight of the idx-th *selected* expert;
                    # masked_weights[:, idx] was the weight of expert number
                    # idx, which is 0 whenever that expert was not selected.
                    expert_weights = torch.gather(masked_weights, 1, expert_indices).squeeze(1) ** power
                    # Bug fix (Stragey_2): gather step 0 like Stragey_1; the old
                    # expand(-1, ...) on a new leading dim always raised.
                    gather_index = expert_indices[:, :, None, None, None].expand(-1, -1, 1, x_domain, y_domain)
                    selected = torch.gather(all_expert_outputs, 1, gather_index)  # [batch, 1, 1, x, y]
                    experts_output += selected.squeeze(1) * expert_weights.view(-1, 1, 1, 1)

            final_output.append(experts_output)
            load_balancing_loss += self.gate.load_balancing_loss()

        # [batch, 1, patches, x_domain, y_domain]
        return torch.stack(final_output, dim=2), load_balancing_loss

    def MOE_operator_0927(self, x, policy="Strategy_1"):
        """Legacy sparse MoE pass that keeps only each expert's first output step.

        Routing matches ``MOE_operator``. ``policy`` is unused.

        Args:
            x: Tensor of shape ``[batch, steps, patches, x_domain, y_domain]``.

        Returns:
            ``(weighted_outputs, 0)`` with ``weighted_outputs`` of shape
            ``[batch, 1, patches, x_domain, y_domain]``.
        """
        batch, steps, patches, x_domain, y_domain = x.shape
        load_balancing_loss = 0

        patch_weights, self.mask, _ = self.gate(x)  # [batch, patches, num_experts]
        self.weight_map = patch_weights[0, :, :]
        # dtype follows the gate weights so the einsum below does not mix dtypes.
        outputs = torch.zeros(batch, patches, x_domain, y_domain, self.experts_nums,
                              device=x.device, dtype=patch_weights.dtype)
        print("out", outputs.shape)
        # [batch * patches, steps, x_domain, y_domain]
        x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(batch * patches, steps, x_domain, y_domain)
        print("x_reshape", x_reshaped.shape)

        for i in range(self.experts_nums):
            # Flattened (batch * patches) positions routed to expert i.
            selected_indices = torch.nonzero(self.mask[:, :, i].reshape(-1)).squeeze(1)
            if selected_indices.numel() == 0:
                continue  # no patch routed to this expert
            output_selected = self.ms_operator[i](x_reshaped[selected_indices])
            batch_indices = selected_indices // patches
            patch_indices = selected_indices % patches
            outputs[batch_indices, patch_indices, :, :, i] = output_selected[:, 0, :, :]

        weighted_outputs = torch.einsum('bpxyi,bpi->bpxy', outputs, patch_weights)
        return weighted_outputs.unsqueeze(1), load_balancing_loss

    def MOE_operator(self, x, policy="Strategy_1"):
        """Sparse per-patch Mixture-of-Experts pass.

        The gate scores every patch. Each expert runs only on the patches in
        its top-k mask, and expert outputs are summed weighted by the gate
        weights (unselected experts contribute zero).

        Args:
            x: Tensor of shape ``[batch, steps, patches, x_domain, y_domain]``.
            policy: Currently unused; kept for interface compatibility.

        Returns:
            ``(weighted_outputs, load_balancing_loss)``. ``weighted_outputs``
            has shape ``[batch, steps, patches, x_domain, y_domain]``.
            ``load_balancing_loss`` is always ``0`` here; the gate's loss is
            available via the ``load_balancing_loss`` property.

        Note:
            Each expert must map ``[n, steps, x_domain, y_domain]`` to a
            tensor of the same shape.
        """
        batch, steps, patches, x_domain, y_domain = x.shape
        load_balancing_loss = 0

        patch_weights, self.mask, _ = self.gate(x)  # [batch, patches, num_experts]
        self.weight_map = patch_weights[0, :, :]

        # Work in a flattened (batch * patches) layout so every patch has a
        # single row index. dtype follows the gate weights so the einsum
        # below does not mix dtypes.
        x_flat = x.permute(0, 2, 1, 3, 4).reshape(batch * patches, steps, x_domain, y_domain)
        outputs = torch.zeros(batch * patches, steps, x_domain, y_domain, self.experts_nums,
                              device=x.device, dtype=patch_weights.dtype)

        for i in range(self.experts_nums):
            # Flattened positions routed to expert i.
            selected = torch.nonzero(self.mask[:, :, i].reshape(-1)).squeeze(1)
            if selected.numel() == 0:
                continue  # no patch routed to this expert
            outputs[selected, :, :, :, i] = self.ms_operator[i](x_flat[selected])

        # Back to [batch, steps, patches, x_domain, y_domain, num_experts].
        outputs = outputs.reshape(batch, patches, steps, x_domain, y_domain,
                                  self.experts_nums).permute(0, 2, 1, 3, 4, 5)
        # Weighted sum over the expert axis.
        weighted_outputs = torch.einsum('bspxyi,bpi->bspxy', outputs, patch_weights)
        return weighted_outputs, load_balancing_loss
