


from re import A, T
from sympy import fraction
import torch

## this is for experiments ##
import yaml 
from abc import abstractmethod
import argparse
import torch
from torch import nn
import os
import sys 
import numpy as np
from tqdm import tqdm 
import wandb

import gc

import torch.nn.functional as F

class SimAM(nn.Module):
    """Parameter-free SimAM attention over the spatial axes of a feature map.

    Every axis after the first two (batch, channel) is treated as spatial, so
    both ``(B, C, L)`` and ``(B, C, H, W)`` inputs are accepted.

    Args:
        lambda_value: Regulariser added to the per-channel variance before
            the division.
    """

    def __init__(self, lambda_value):
        super(SimAM, self).__init__()
        self.lambda_value = lambda_value

    def forward(self, X):
        """Return ``X`` re-weighted element-wise by ``sigmoid(E_inv)``.

        Raises:
            ValueError: If ``X`` has no spatial axis (fewer than 3 dims).
        """
        if X.dim() < 3:
            raise ValueError(
                f"SimAM expects at least 3 dims (batch, channel, spatial...), got {X.dim()}"
            )
        spatial_dims = tuple(range(2, X.dim()))
        # Number of spatial positions minus one; clamped so that a single
        # spatial position does not produce 0 / 0 = NaN below.
        n = max(torch.Size(X.shape[2:]).numel() - 1, 1)
        # Squared deviation of each position from its channel mean.
        d = (X - X.mean(dim=spatial_dims, keepdim=True)).pow(2)
        # d.sum() / n is the per-channel variance.
        v = d.sum(dim=spatial_dims, keepdim=True) / n
        # E_inv scores the importance of every position.
        E_inv = d / (4 * (v + self.lambda_value)) + 0.5
        return X * torch.sigmoid(E_inv)
    
# Append the current working directory to sys.path; presumably this is what
# lets the project-local `Torch_utils` package (imported just below) resolve
# when the script is launched from the repository root.
current_working_directory = os.getcwd()
print("***now work dir:***\n", current_working_directory) # log the directory for debugging
sys.path.append(current_working_directory)
import shutil
import time


from Torch_utils.utils3 import save_checkpoint,load_checkpoint
from Torch_utils.utils3 import LpLoss,Acc_Error,inference_unit_time
import matplotlib.pyplot as plt 


def gaussian_blur(img, kernel_size=5, sigma=1):
    """Blur every channel of ``img`` with the same 2-D Gaussian kernel.

    Args:
        img: Tensor indexed as (batch, channels, height, width); the channel
            count is read from ``img.shape[1]``.
        kernel_size: Side length of the square kernel. With an odd value the
            reflect padding below keeps the spatial size unchanged.
        sigma: Standard deviation of the Gaussian.

    Returns:
        The blurred tensor, with the same dtype and device as ``img``.
    """
    from math import pi

    half = kernel_size // 2
    # Build the kernel with the image's dtype/device so conv2d also accepts
    # float64 or CUDA inputs (a CPU float32 kernel would raise there).
    kernel_dtype = img.dtype if img.is_floating_point() else torch.float32
    x_coords = torch.arange(kernel_size, dtype=kernel_dtype, device=img.device) - half
    x_grid = x_coords.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    dist_sq = x_grid ** 2 + y_grid ** 2
    gaussian_kernel = torch.exp(-dist_sq / (2 * sigma ** 2)) / (2 * pi * sigma ** 2)
    # Normalise so the weights sum to one.
    gaussian_kernel = gaussian_kernel / gaussian_kernel.sum()

    # One copy of the kernel per channel; groups=channels makes the
    # convolution depthwise.
    channels = img.shape[1]
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    # Reflect-pad so the output keeps the input's spatial size.
    padded_img = F.pad(img, [half, half, half, half], mode='reflect')
    return F.conv2d(padded_img, gaussian_kernel, groups=channels)

def forward_1(self,input_tensors,
              static_tensors = None,
              boundary_tensors = None,
              pred_steps=[1],
              multi_sclale=1,**kwargs):
    """Encode the input window, evolve the latent, and decode predictions.

    For each entry of ``pred_steps`` the current latent is decoded directly
    (the reconstruction) and decoded after one application of the evolution
    operator (the prediction); the evolved latent becomes the next latent.

    Args:
        input_tensors: 4-D ``(batch, time, nx, ny)`` or 3-D
            ``(batch, time, nx)`` input window.
        static_tensors: Optional tensor concatenated in front of the encoder
            output along the last axis.
        boundary_tensors: Accepted for interface compatibility; unused here.
        pred_steps: Iterable whose length sets how many evolution steps are
            produced. It is only read, never mutated.
        multi_sclale: When > 1, the input is split into
            ``multi_sclale x multi_sclale`` spatial tiles which are passed
            through ``conv_grid_n0`` separately and then concatenated.
        **kwargs: ``simam=True`` registers a ``SimAM`` module on
            ``conv_grid_n0`` under the name ``'simam'``.

    Returns:
        ``(outputs, info)``. ``outputs`` holds one slice per prediction step
        along dim 1; ``info["recon"]`` is the reconstruction from the last
        loop iteration and ``info["latent_tensors"]`` is ``None``.
    """
    if input_tensors.dim() == 4:
        batch, _, nx, ny = input_tensors.shape
    elif input_tensors.dim() == 3:
        batch, _, nx = input_tensors.shape
        self.encoder.latent_shape = [512, 3]  # value found by testing (per the original note)

    # encoder
    conv_grid_n0 = self._modules['encoder'].conv_grid_n0

    # Optionally append a SimAM module to the existing Sequential block.
    # add_module with a fixed name replaces any module registered earlier
    # under that name, so repeated calls do not stack copies.
    if kwargs.get("simam", False):
        conv_grid_n0.add_module('simam', SimAM(lambda_value=1e-3))

    # Regular (single-scale) processing.
    processed_tensor = conv_grid_n0(input_tensors)

    if multi_sclale > 1:
        # Multi-scale path: split into tiles, process each, concatenate.
        domain = int(nx / multi_sclale)
        sub_tensors = [
            input_tensors[:, :, i * domain:(i + 1) * domain, j * domain:(j + 1) * domain]
            for i in range(multi_sclale)
            for j in range(multi_sclale)
        ]
        processed_sub_tensors = [conv_grid_n0(sub_tensor) for sub_tensor in sub_tensors]

        # NOTE(review): tiles that share a dim-2 slice index are joined along
        # dim 2 and the groups along dim 3, so the tile grid comes out
        # transposed relative to the input layout. Kept as-is because models
        # may have been trained with it -- confirm whether this is intended.
        row_tensors = [
            torch.cat(processed_sub_tensors[i * multi_sclale:(i + 1) * multi_sclale], dim=2)
            for i in range(multi_sclale)
        ]
        processed_tensor = torch.cat(row_tensors, dim=3)

    dy_encoder_out = self._modules['encoder'](processed_tensor)  # [B, 256] per the original note
    # evolution operator
    evo_Op = self._modules['evolution_op']

    # decoder
    decoder_list = self._modules['decoder_n0'].deconv_list
    decoder_n0 = self._modules['decoder_n0'].fc  # output [B, 4096] per the original note

    # info records latent-related by-products
    info = {"recon": None, "latent_tensors": None}

    if static_tensors is not None:
        # torch.cat takes a sequence of tensors; passing the second tensor
        # positionally (as before) put it in the `dim` slot and always raised.
        # Assumes static features belong on the last (feature) axis --
        # TODO confirm against the decoder's expected input width.
        cat_tensors = torch.cat((static_tensors, dy_encoder_out), dim=-1)
    else:
        cat_tensors = dy_encoder_out

    outputs = []

    for _ in pred_steps:
        latent = cat_tensors

        # Decoding without the evolution operator gives the reconstruction.
        recon = decoder_n0(latent)

        Latent_op = evo_Op(latent)
        output = decoder_n0(Latent_op)

        if input_tensors.dim() == 3:
            output = output.reshape(batch, *self.encoder.latent_shape)  # [batch, 512, 3]
            recon = recon.reshape(batch, *self.encoder.latent_shape)
        elif input_tensors.dim() == 4:
            output = output.reshape(batch, self.latent_size, *self.encoder.latent_shape)
            recon = recon.reshape(batch, self.latent_size, *self.encoder.latent_shape)

        for module in decoder_list:
            output = module(output)
            recon = module(recon)

        cat_tensors = Latent_op  # the evolved latent feeds the next step
        outputs.append(output)

    # Stack per-step outputs on a new step axis.
    outputs = torch.stack(outputs, dim=1)

    if input_tensors.dim() == 4:
        _, _, _, o_x, o_y = outputs.shape
        recon = recon.reshape(-1, 1, o_x, o_y)

        outputs = outputs.squeeze(2)  # drop the decoder's channel axis
        if output.shape[-1] != ny or output.shape[-2] != nx:
            # Bilinear resize back to the input resolution (4-D tensors only).
            recon = F.interpolate(recon, size=(nx, ny), mode='bilinear', align_corners=False)
            outputs = F.interpolate(outputs, size=(nx, ny), mode='bilinear', align_corners=False)

        recon = recon.reshape(-1, 1, nx, ny)

    elif input_tensors.dim() == 3:
        _, _, _, o_x = outputs.shape

        recon = recon.reshape(-1, 1, 1, o_x)
        if output.shape[-1] != nx:
            recon = F.interpolate(recon, size=(1, nx), mode='bilinear', align_corners=False)
            outputs = F.interpolate(outputs, size=(1, nx), mode='bilinear', align_corners=False)

        recon = recon.reshape(-1, 1, nx)
        outputs = outputs.squeeze(2)

    info["recon"] = recon

    assert outputs.shape[1] == len(pred_steps)  # e.g. [batch, 1, 64, 64]

    return outputs, info

def rollout_preds(self, dyn_input, steps:int,**kwargs):
    """Autoregressively roll the model forward for ``steps`` steps.

    Each iteration calls ``self.forward`` on the current window, appends the
    prediction, then slides the window: the oldest time step is dropped and
    the prediction is appended along dim 1.

    Args:
        dyn_input: Initial input window with time on dim 1, e.g.
            ``(b, 10, 64, 64)`` in the 2-D case.
        steps: Number of rollout steps.
        **kwargs: ``static`` and ``boundary`` are accepted, but both are
            currently disabled: ``None`` is always passed to ``self.forward``.

    Returns:
        ``(pred_aggre, recon_aggre, recon_true)``: the predictions, the
        ``info["recon"]`` tensors, and the last input frame of every step,
        each concatenated along dim 1.
    """
    predictions = []
    recons = []
    recon_true = []
    input_seq = dyn_input

    # The original code fetched kwargs["static"] / kwargs["boundary"] and
    # immediately overwrote them with None, so conditioning is switched off.
    static_tensors = None
    boundary_tensor = None

    for _ in range(steps):
        # The last frame of the current window is the reconstruction target.
        recon_true.append(input_seq[:, -1].unsqueeze(1))  # [b, 1, 64, 64]

        prediction, info = self.forward(input_seq, static_tensors, boundary_tensor)

        if self.base_model == "KNO":
            # Presumably KNO returns a sequence whose first item is
            # channels-last; move that axis to dim 1 -- TODO confirm.
            prediction = prediction[0]
            prediction = prediction.permute(0, 3, 1, 2)

        predictions.append(prediction)  # nominally [batch, 1, 64, 64]
        recons.append(info["recon"])

        # Slide the window: drop the oldest step, append the prediction.
        input_seq = torch.cat((input_seq[:, 1:], prediction), dim=1)

    pred_aggre = torch.cat(predictions, dim=1)
    recon_aggre = torch.cat(recons, dim=1)
    recon_true = torch.cat(recon_true, dim=1)

    # Pre-split multi-scale data: predictions are [batch, time, scales, H, W]
    # and the per-scale patches are merged back into one field here.
    if input_seq.dim() == 5:
        n_steps = pred_aggre.shape[1]
        scales = pred_aggre.shape[2]
        # Height is the second-to-last axis, width the last; reading them the
        # other way round scrambled non-square data in the reshape below.
        height = pred_aggre.shape[-2]
        width = pred_aggre.shape[-1]
        scale = int(np.sqrt(scales))  # e.g. 2
        patch_h = int(height / scale)  # e.g. 64 / 2 = 32
        patch_w = int(width / scale)

        # Fold every (batch, time, scale) slice into the batch axis so one
        # interpolate call handles all of them.
        combined_tensor = pred_aggre.reshape(-1, 1, height, width)
        resized_tensor = F.interpolate(combined_tensor, size=(patch_h, patch_w),
                                       mode='bilinear', align_corners=False)
        resized_tensor = resized_tensor.reshape(-1, n_steps, scale * scale, patch_h, patch_w)

        pred_aggre = reshape_and_concat_patches(resized_tensor, num_patches=scale * scale)
        # The reconstructions are intentionally left unmerged.

    return pred_aggre, recon_aggre, recon_true

def multi_field_rollout_preds(self,dyn_input,control_u,steps:int,**kwargs):
    """Roll out two physical fields independently and pool them into an identification output.

    Both fields go through the same ``self.rollout_preds`` (same encoder);
    the two rollouts are stacked on a new axis 2 and passed to the module
    registered as ``identification_layer``.

    Args:
        dyn_input: 5-D tensor whose axis 2 indexes the two fields, e.g.
            ``(b, 100, 2, 31, 31)``.
        control_u: Control signal, e.g. ``(b, 100, 1)``; not used in this
            function.
        steps: Number of rollout steps.
        **kwargs: ``static`` and ``boundary`` are accepted but currently
            disabled: ``None`` is always forwarded to ``rollout_preds``.

    Returns:
        ``([roll_1, recon_1, recon_true_1], [roll_2, recon_2, recon_true_2],
        identification_rollout)``.

    Raises:
        ValueError: If ``dyn_input`` is not 5-D.
    """
    if dyn_input.dim() != 5:
        # Previously this fell through to a NameError on the unbound fields.
        raise ValueError(
            f"dyn_input must be 5-D (batch, time, field, x, y), got {dyn_input.dim()}-D"
        )
    field_1 = dyn_input[:, :, 0, :, :]  # e.g. [b, 100, 31, 31]
    field_2 = dyn_input[:, :, 1, :, :]

    # The original code read kwargs["static"] / kwargs["boundary"] and then
    # overwrote them with None. `boundary` was additionally left unbound when
    # no boundary kwarg was given, which crashed the calls below.
    static = None
    boundary = None

    pool_layer = self._modules["identification_layer"]

    # Both fields share the same encoder/rollout.
    field_1_roll, field_1_recon_aggre, field_1_recon_true = self.rollout_preds(
        field_1, steps=steps, static=static, boundary=boundary)
    field_2_roll, field_2_recon_aggre, field_2_recon_true = self.rollout_preds(
        field_2, steps=steps, static=static, boundary=boundary)

    # Re-join the fields on a new axis: [b, n, 2, 31, 31].
    field_out = torch.stack((field_1_roll, field_2_roll), dim=2)
    identification_rollout = pool_layer(field_out)

    return ([field_1_roll, field_1_recon_aggre, field_1_recon_true],
            [field_2_roll, field_2_recon_aggre, field_2_recon_true],
            identification_rollout)
def visualize_patches(patches, title="Patches"):
    """Save a one-row figure showing every patch of the first sample and time step.

    Args:
        patches: 5-D tensor ``(batch, time, num_patches, height, width)``.
        title: Figure title; the image is written to ``"<title>.png"`` in the
            current working directory.
    """
    num_patches = patches.shape[2]
    # squeeze=False always returns a 2-D array of Axes, so indexing also
    # works when there is only a single patch.
    fig, axes = plt.subplots(1, num_patches, figsize=(15, 5), squeeze=False)

    for i in range(num_patches):
        # Only batch index 0 and time index 0 are drawn.
        patch = patches[0, 0, i].detach().cpu().numpy()
        ax = axes[0][i]
        ax.imshow(patch, cmap='jet')
        ax.set_title(f'Patch {i}')
        ax.axis('off')

    plt.suptitle(title)
    plt.savefig(f"{title}.png")
    plt.close()

# 对插值前后的patch进行可视化
import math
def reshape_and_concat_patches(resized_tensor, num_patches):
    """Tile the patches stored on axis 2 into one image per (batch, time).

    Patches are laid out row-major: patch ``r * side + c`` goes to grid row
    ``r`` and grid column ``c``, where ``side = sqrt(num_patches)``.

    Args:
        resized_tensor: 5-D tensor ``(batch, time, patch, h, w)``.
        num_patches: Number of patches to tile; must be a perfect square.

    Returns:
        Tensor ``(batch, time, side * h, side * w)``.
    """
    side = int(math.sqrt(num_patches))
    assert side * side == num_patches, "num_patches 必须是完全平方数"

    grid_rows = []
    for row_idx in range(side):
        start = row_idx * side
        # Patches of one grid row sit next to each other along the width axis.
        row_patches = [resized_tensor[:, :, k, :, :] for k in range(start, start + side)]
        grid_rows.append(torch.cat(row_patches, dim=3))

    # Grid rows are stacked along the height axis.
    return torch.cat(grid_rows, dim=2)
 
def reshape_and_concat_patchesnew(resized_tensor, num_patches):
    """Loop-free version of the patch tiling: reshape + permute instead of ``torch.cat``.

    Patches are laid out row-major on a ``sqrt(scale) x sqrt(scale)`` grid,
    where ``scale`` is the size of axis 2.

    Args:
        resized_tensor: 5-D tensor ``(batch, time, scale, h, w)``.
        num_patches: Unused; the patch count is read from axis 2. Kept so
            the signature stays unchanged.

    Returns:
        Tensor ``(batch, time, sqrt(scale) * h, sqrt(scale) * w)``.
    """
    batch_size, time_steps, scale, patch_h, patch_w = resized_tensor.shape

    sqrt_scale = int(math.sqrt(scale))
    assert sqrt_scale * sqrt_scale == scale, "scale 必须是完全平方数"

    # Split the patch axis into (grid_row, grid_col). reshape (rather than
    # view) also accepts non-contiguous inputs.
    grid = resized_tensor.reshape(batch_size, time_steps, sqrt_scale, sqrt_scale, patch_h, patch_w)

    # Order the axes as (grid_row, h, grid_col, w) so the reshape below joins
    # whole patches. The previous order (h, grid_row, w, grid_col) interleaved
    # pixels of different patches instead of tiling them.
    grid = grid.permute(0, 1, 2, 4, 3, 5)

    return grid.reshape(batch_size, time_steps, sqrt_scale * patch_h, sqrt_scale * patch_w)


def baseline_forward_1(self,input_tensors,
                           static_tensors = None,
                           boundary_tensors = None,
                           pred_steps=[1],
                           multi_sclale=1,**kwargs):
    """Forward pass used to evaluate baseline operators.

    When ``self.scale_value`` contains ``"DC"`` or ``"KS"`` the result is
    stored in ``self.out``: either ``self._operator`` on the whole input
    (``int_scale_value == 1``) or a divide-and-conquer pass in which the
    input is cut into ``scale x scale`` patches, each patch is upsampled to
    full resolution, run through its own ``self.ms_operator[i]``, downsampled
    again and tiled back together.

    Returns:
        ``(self.out, info)``; ``info["recon"]`` is an all-zero tensor shaped
        like the (possibly permuted) input because reconstruction is not
        considered for baselines.
    """
    # KNO wants time last: [batch, 64, 64, 10].
    if self.base_model == "KNO":
        input_tensors = input_tensors.permute(0, 2, 3, 1)

    divide_and_conquer = "DC" in self.scale_value or "KS" in self.scale_value

    if divide_and_conquer and self.int_scale_value == 1:
        # Single scale: one operator sees the full field.
        self.out = self._operator(input_tensors)

    elif divide_and_conquer:
        batch, n_steps, height, width = input_tensors.shape
        scale = self.int_scale_value
        if not (height % scale == 0 and width % scale == 0):
            raise ValueError("Image dimensions must be divisible by scale value")

        patch_size = height // scale
        n_patches = scale * scale

        # Cut non-overlapping patch_size x patch_size windows with unfold,
        # e.g. [20, 10, 4, 32, 32] for a 64x64 input and scale 2.
        windows = input_tensors.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        windows = windows.contiguous().reshape(batch, n_steps, n_patches, patch_size, patch_size)
        # Fold batch/time/patch into one axis so a single interpolate call suffices.
        windows = windows.view(-1, 1, patch_size, patch_size)

        upsampled_patches = F.interpolate(
            windows, size=(height, width), mode='bicubic', align_corners=False
        ).reshape(batch, n_steps, n_patches, height, width)

        # One dedicated operator per patch; each sees a full-resolution input.
        per_patch_out = [
            self.ms_operator[idx](upsampled_patches[:, :, idx, :, :])
            for idx in range(n_patches)
        ]

        # Stack on a new patch axis (e.g. [batch, 1, 4, 64, 64]), then shrink
        # every patch back to its original size.
        stacked = torch.stack(per_patch_out, dim=2).reshape(-1, 1, height, width)
        shrunk = F.interpolate(
            stacked, size=(patch_size, patch_size), mode='bicubic', align_corners=False
        ).reshape(-1, 1, n_patches, patch_size, patch_size)  # [batch, 1, 4, 32, 32]

        self.out = reshape_and_concat_patches(shrunk, num_patches=n_patches)

    # Reconstruction is not considered for baselines.
    info = {
        "recon": torch.zeros_like(input_tensors, device=input_tensors.device),
        "latent_tensors": None,
    }
    return self.out, info
                                                         
def Moe_forward(self,input_tensors,static_tensors = None,
                           boundary_tensors = None,
                           pred_steps=[1],
                           multi_sclale=1,**kwargs):
    """Forward pass through the mixture-of-experts operator.

    A 5-D input ``(batch, time, patches, h, w)`` is treated as already split
    into patches: it goes to ``self.MOE_operator`` unchanged and the expert
    output is returned unmerged. A 4-D input ``(batch, time, h, w)`` is cut
    into ``int_scale_value ** 2`` patches, each upsampled to full resolution,
    processed, then downsampled and tiled back into a single field.

    Returns:
        ``(self.out, info)``; ``info["recon"]`` is an all-zero tensor shaped
        like the input and ``info["MOE_loss"]`` is the second value returned
        by ``self.MOE_operator``.

    Raises:
        ValueError: If ``"MOE"`` is not in ``self.scale_value``, or if a 4-D
            input's spatial size is not divisible by the scale.
    """
    if "MOE" not in self.scale_value:
        # Without this guard the function ran into an unbound `load_loss`
        # further down; fail with an explicit message instead.
        raise ValueError(
            f"Moe_forward requires 'MOE' in scale_value, got {self.scale_value!r}"
        )

    pre_split = input_tensors.dim() == 5

    if pre_split:
        # Patches were cut beforehand; predictions stay unmerged as well.
        upsampled_patches = input_tensors
    else:
        batch, n_steps, height, width = input_tensors.shape

        scale = self.int_scale_value
        if height % scale != 0 or width % scale != 0:
            raise ValueError("Image dimensions must be divisible by scale value")

        patch_size = height // scale
        # Cut non-overlapping windows with unfold, e.g. [20, 10, 4, 32, 32].
        patches = input_tensors.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches = patches.contiguous().reshape(batch, n_steps, scale * scale, patch_size, patch_size)
        # Fold batch/time/patch into one axis so a single interpolate call suffices.
        patches = patches.view(-1, 1, patch_size, patch_size)

        upsampled_patches = F.interpolate(patches, size=(height, width),
                                          mode='bilinear', align_corners=False)
        upsampled_patches = upsampled_patches.reshape(batch, n_steps, scale * scale, height, width)

    # e.g. moe_out is [batch, 1, 4, 128, 128] per the original note.
    moe_out, load_loss = self.MOE_operator(upsampled_patches)

    if pre_split:
        output = moe_out
    else:
        # Regular path: merge the patches back (noted as slow in the original).
        # NOTE(review): upsampling above is bilinear, downsampling here is
        # bicubic -- kept as in the original; confirm the mismatch is intended.
        combined_tensor = moe_out.reshape(-1, 1, height, width)
        resized_tensor = F.interpolate(combined_tensor, size=(patch_size, patch_size),
                                       mode='bicubic', align_corners=False)
        resized_tensor = resized_tensor.reshape(-1, 1, scale * scale, patch_size, patch_size)
        output = reshape_and_concat_patches(resized_tensor, num_patches=scale * scale)

    self.out = output

    # Reconstruction is not considered here.
    recons = torch.zeros_like(input_tensors, device=input_tensors.device)
    info = {"recon": recons, "latent_tensors": None, "MOE_loss": load_loss}

    return self.out, info
    
      
class TrainingManager:
    """Bookkeeping for checkpoint naming and patience-based early stopping."""

    def __init__(self, save_dir, save_name,
                 best_suffix='_best.pt', device='cuda:0', patience=50,
                 **kwargs):
        """Store the checkpoint settings and reset the early-stopping state.

        Args:
            save_dir: Directory associated with saved checkpoints.
            save_name: Checkpoint/template name, kept as ``save_model_name``.
            best_suffix: Suffix associated with the best checkpoint.
            device: Device string.
            patience: Number of non-improving epochs after which
                ``should_stop`` returns True.
            **kwargs: ``source_file`` is stored if given; otherwise the
                string ``"None"`` is stored.
        """
        # Checkpoint location and naming.
        self.save_dir = save_dir
        self.save_model_name = save_name
        print("***template_name: \n***", self.save_model_name)
        self.best_suffix = best_suffix
        self.device = device

        # Early-stopping state.
        self.patience = patience
        self.best_loss = float('inf')
        self.no_improve_epochs = 0

        self.source_file = kwargs["source_file"] if "source_file" in kwargs else "None"

    def check_improvement(self, loss):
        """Record ``loss``; return True iff it is strictly below the best seen so far."""
        if not (loss < self.best_loss):
            self.no_improve_epochs += 1
            return False

        self.best_loss = loss
        self.no_improve_epochs = 0
        return True

    def should_stop(self):
        """Return True once ``patience`` epochs have passed without improvement."""
        return self.no_improve_epochs >= self.patience

class Expr:
    """Skeleton of an experiment: the hooks below are overridden by subclasses.

    The hooks carry ``@abstractmethod`` markers, but the class does not use
    ``ABCMeta``, so it can still be instantiated and every hook defaults to
    a no-op that returns ``None``.
    """

    @abstractmethod
    def __init__(self):
        """Set up the experiment."""

    @abstractmethod
    def train(self):
        """Run training."""

    @abstractmethod
    def valid(self):
        """Run validation."""

    @abstractmethod
    def test(self):
        """Run testing."""

    @abstractmethod
    def record(self):
        """Record results."""

    @abstractmethod
    def resume(self):
        """Resume from saved state."""

    @abstractmethod
    def _pre_model(self):
        """Prepare the model."""

    @abstractmethod
    def _pre_dataset(self):
        """Prepare the dataset."""

    def run(self):
        """Call ``train``, ``valid`` and ``test`` in that order."""
        for stage in (self.train, self.valid, self.test):
            stage()

class Lepde_expr(Expr):
   '''
      THIS IS FOR THE Lepde TASK: AI4SCI
      CONFIG: a dict loaded from a YAML file
      yaml: core settings file
         -dataset_path:"../../.h5"
         -milestones(check)[0,100...] # epochs at which to check the training loss
         -model
            -core parameters # no details needed: just what the input is, what the output is, and what the backbone is
         -log
            project: '' # e.g. hsps/moe/lepde++
            group: '' # e.g. china/dixia
         -train
            -epoch # number of training epochs
            -batch # batch size
            -lr: # decides the optimizer
            -ckpt_path # model and optimizer saved in .pth
            -loss weight # hyperparameters
               -physical loss
               -data loss
               -ic loss
         -test
            -save folder # where to save the loss or figures (torchvision)

   '''
   def __init__(self,config:dict,wandb_en=True,load=False):
      """Build the experiment: read the config, seed, create/restore the model, log in to W&B.

      Args:
         config: Experiment configuration. The keys read here are
            ``data.task``, ``train.device``, ``model.name`` and, when
            ``wandb_en`` is True, ``log_run.dot_key_path``.
            ``train.save_name`` is overwritten.
         wandb_en: When True, load the ``.env`` file and log in to W&B with
            the ``Wandb_api`` environment variable.
         load: When ``False`` a fresh model is built according to
            ``model.name``; otherwise the le_pde model is built and
            ``self.resume()`` is called.
      """
      super(Lepde_expr,self).__init__()
      self.config = config
      self.task = self.config['data']['task']
      self.device = self.config['train']['device']
      self.wandb_en = wandb_en
      # Fix the random seeds.
      self._seed()
      # Build or restore the model.
      if load == False:
         # NOTE(review): if the model name contains neither marker,
         # self.model is never set here -- confirm whether that is intended.
         if "le_pde" in self.config["model"]["name"]:
            self.model = self._pre_le_pde( self.device)
         elif "Baseline" in self.config["model"]["name"]:
            self.model = self._pre_model(self.device)

      else:
         self.model = self._pre_le_pde( self.device)
         self.resume()
         print("we have load from check_point")
      # Rename the checkpoint.
      # NOTE(review): `t_manager` is not defined in this method; presumably a
      # module-level TrainingManager created before this class is
      # instantiated -- confirm.
      print("t_manager_name", t_manager.save_model_name)
      self.config['train']['save_name'] = t_manager.save_model_name.replace('.pt', 'check_load_.pt')
      print("test:",self.config['train']['save_name'])
      # W&B login.
      if wandb_en ==True:
         from dotenv import load_dotenv
         dotenv_path = self.config['log_run']['dot_key_path']

         # Load the .env file so the key is available through os.getenv.
         load_dotenv(dotenv_path)
         api_key = os.getenv("Wandb_api")
         # Never print the key itself: stdout usually ends up in shared logs.
         print("wandb_key found:", api_key is not None)
         wandb.login(key = api_key)

      # Quantitative results collected for the final report.
      self.important_dict = {"model_paras":0,
                             "rmse_best_test":1,
                             "mae_best_test":1,
                             "inference_time(ms)":0,
                             "memo":None
                             }
      
   def curriculum_mask_torch(self,sequence_length, current_epoch, total_epochs, initial_ratio=0.1):
      """
      Create a curriculum learning mask for autoregressive models using PyTorch,
      where the initial true values start from the beginning of the sequence and gradually increase.

      Args:
      sequence_length (int): The full length of the sequence
      current_epoch (int): The current training epoch
      total_epochs (int): Total number of training epochs
      initial_ratio (float): Initial ratio of the sequence to unmask (default: 0.1)

      Returns:
      torch.Tensor: A boolean tensor where True indicates the position should be predicted
      """
      # Calculate the current ratio of the sequence to unmask
      self.current_ratio = initial_ratio + (1 - initial_ratio) * (current_epoch / total_epochs)
      
      # Calculate how many positions to unmask
      unmask_length = int(sequence_length *  self.current_ratio)
      
      # Create the mask
      # Create the mask directly on the specified device
      mask = torch.zeros(sequence_length, dtype=torch.bool, device=self.device)
      mask[:unmask_length] = True


      return mask
      


   
   def _seed(self,seed=42):

      """固定随机种子以确保实验的可重复性"""
   
      torch.manual_seed(seed)  # 为CPU设置PyTorch随机种子
      torch.cuda.manual_seed(seed)  # 为所有CUDA设备设置PyTorch随机种子
      torch.cuda.manual_seed_all(seed)  # 为所有CUDA设备设置相同的随机种子（与上行功能重复，用于兼容）
      torch.backends.cudnn.deterministic = True  # 确保CUDA的卷积算法具有确定性
      torch.backends.cudnn.benchmark = False  # 关闭优化的卷积计算方法，因为它们可能不是确定性的
      

   def _set_optimizer(self)->torch.optim.Optimizer:

      optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config["train"]["base_lr"],betas=(0.9, 0.999)) # optimizer

      return optimizer
   
   
   def train(self):
      print("***train***")
      
      train_loader,test_loader = self._pre_dataset()
      optimizer = self._set_optimizer()
      
      # 保存到制定的文件夹
      # Full path to the file
      save_dir = self.config['train']['save_dir'] 
      os.makedirs(save_dir, exist_ok=True)
      file_path = os.path.join(save_dir, 'sweep_config.yaml')
      # 保存 sweep_config 到文件
      with open(file_path, 'w') as file:
         yaml.dump(self.config, file)

      #shutil the expr.py
      expr_file_path = os.path.realpath(__file__)
      # destination
      expr_destination_file = os.path.join(save_dir, os.path.basename(expr_file_path))
      #copy expr
      shutil.copy(expr_file_path, expr_destination_file)
      print("###we have successfully saved the source expr file！")

      #train data_paraller
 
      if self.config["train"]["train_type"] == "accelerate":
         from accelerate import Accelerator
         # 初始化 Accelerator
         accelerator = Accelerator()
         # 准备模型、优化器和数据加载器
         self.model, optimizer, train_loader,test_loader = accelerator.prepare(self.model, optimizer, train_loader,test_loader)
         print("***accelerate加速！")
   
      else:
      #单独gpu
         self.model = self.model.to(self.device)
         print("single gpu")
      self.model.train()

      if self.config['train']['lr_scheduler'] == 'CosineAnnealingLR':
      
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max= 20, eta_min=1e-5)

      elif self.config['train']['lr_scheduler'] == 'StepLR':
         # 设置学习率调度器，例如 StepLR
         scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.99)


      pbar = range(config['train']['epochs'])
      
      pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.1)

      loss = LpLoss(d=2,p=2,reduction=True)



      for e in pbar:
      # set 0
         all_l2 = 0
         train_loss = 0
         recon_loss = torch.tensor(0.0)
         sampling_loss = torch.tensor(0.0)
         step_loss = torch.tensor(0.0)
         router_loss = torch.tensor(0.0)
         

         for inputs, outputs, g_t, condition,boundary in train_loader:

            if isinstance(inputs, list)==False:
               inputs = inputs
               outputs = outputs
               condition = condition
               boundary = boundary
            
            optimizer.zero_grad()

            # 进行rollout
            if self.task == "2d_Forward_NS_V1e-5_T20_In_10_OUT_1":
               #use PAST 10 stepS to pred  1 STEP,but rollout to the next 10 steps
               
               ## 如果是fno3d，直接一次就是输出为10个是时间
               if self.config["model"]["parameters"]["3d"] == True and  "3d" in self.config["model"]["parameters"]:

                  self.roll_out_steps = 1
               
               else:

                  self.roll_out_steps = outputs.shape[1] #10步
               
            
               model_out,recons,recon_true = self.model.roll_out(inputs,
                                                               steps = self.roll_out_steps)
               
               b,t,nx,ny =model_out.shape
   
            elif self.task =="2d_Forward_PS_T2_In_1_OUT_1":
               #possion 是1步
               self.roll_out_steps = outputs.shape[1]
               
               model_out,recons,recon_true = self.model.roll_out(inputs,
                                                               steps = self.roll_out_steps)
               b,t,nx,ny = model_out.shape
            elif self.task == "2d_Forward_Pollu_T181_In_10_OUT_1" or self.task == "2d_Reverse_Pollu_T181_In_10_OUT_1":
               #use PAST 10 stepS to pred  1 STEP,but rollout to the next 10 steps
           
               self.roll_out_steps = outputs.shape[1]
               # we could remove the initial by minus the first step
               in_t = inputs.shape[1]

               out_t = outputs.shape[1]
               new_inputs = inputs - boundary[:,0:in_t,:,:]
               outputs = outputs - boundary[:,in_t:in_t+out_t,:,:] #output也减去,相当于学习光滑d
 
               model_out,recons,recon_true = self.model.roll_out(new_inputs,
                                                                 steps = self.roll_out_steps,
                                                                 static = condition,
                                                                 boundary = boundary)
               

               b,t,nx,ny = model_out.shape
            elif self.task == "2d_Forward_SST_T52_In_10_OUT_1":
               #use PAST 10 stepS to pred  1 STEP,but rollout to the next 10 steps
               self.roll_out_steps = outputs.shape[1]
               # in sst, the boundary has separate,已经光滑
               in_t = inputs.shape[1]
            
               new_inputs = inputs 
               outputs = outputs
               
               model_out,recons,recon_true = self.model.roll_out(new_inputs,
                                                                  steps = self.roll_out_steps,
                                                                  static = condition,
                                                                  boundary = boundary)
               b,t,nx,ny =model_out.shape
            elif self.task == "1d_Forward_Wu_T250_In_1_OUT_1":
               #1d: use PAST 50 stepS to pred  1 STEP,but rollout to the next 200 steps
               self.roll_out_steps = outputs.shape[1]
               in_t = inputs.shape[1]
               new_inputs = inputs 
               outputs = outputs
               #[a,b,r] is condition
               model_out,recons,recon_true = self.model.roll_out( new_inputs,
                                                                  steps = self.roll_out_steps,
                                                                  static = condition,
                                                                  boundary = boundary)
               b,t,nx =model_out.shape
            elif self.task == "2d_Identification_Cylinder_T800_In_100_OUT_1":
               pass

            elif self.task == "2d_Identification_Physical_SWE_In10_Out_1":
               '''
                  预测出场和H,要加入physcial loss
                  根据物理方程
               '''
               pass
            
            ## currium_sampling
            if self.config["train"]["currium_sampling"]["EN"]:
               '''
               Currium sampling
               '''
               ratio = self.config["train"]["currium_sampling"]["ini_value"]
               
               t = model_out.shape[1] #sequence length
               # mask (int)
               mask = self.curriculum_mask_torch(t,e,config['train']['epochs'],initial_ratio=ratio)

               filtered_outputs = model_out[:, mask]
         
               filtered_targets = outputs[:,mask]

               sampling_loss = loss.rel(filtered_outputs,filtered_targets)
            else:
               # loss about all steps loss
               self.current_ratio = 1

               sampling_loss = loss.rel(model_out,outputs)
            if self.config["train"]["recon_loss"]["EN"]:

               '''
                  recon loss-weight-1
               '''   
               print("***",recon_true.shape)
               print(recons.shape)
               exit()
               recon_loss = loss.rel(recons, recon_true)*1

               recon_loss =  recon_loss
            if self.config["train"]["multi_step_loss"]["EN"]:

               weights_list = self.config["train"]["multi_step_loss"]["value"]

               # 计算 time_loss，假设 time_loss 是形状为 [batch_size, steps, height, width] 的张量
               for index,value in enumerate(weights_list):
                  single_loss = loss.rel(model_out[:, index], outputs[:, index])
                  step_loss += single_loss * value 
            #router 的loss
            if 'MOE_router_loss' in self.config["train"] and self.config["train"]['MOE_router_loss']["EN"]:
               
               self.lambada_value = self.model.lamba
               router_loss = self.model.load_balancing_loss * self.lambada_value 
               
   
            train_loss = sampling_loss+ recon_loss + step_loss + router_loss # loss

            all_l2 += train_loss.item() #all loss
            recon_loss += recon_loss.item()
            sampling_loss+= sampling_loss.item()
            step_loss += step_loss.item()

            #train_loss.backward()
            accelerator.backward(train_loss)
            optimizer.step()
     
           
         ###
         accelerator.wait_for_everyone()  # 确保所有进程到达这一点
         
         all_l2 /= len(train_loader)
         recon_loss /= len(train_loader)
         sampling_loss /= len(train_loader)
         step_loss /= len(train_loader)
         
         ## 嵌入pid的update,让控制器更新
         # 计算 RMSE（MSE的平方根）
         rmse = np.sqrt(all_l2)
         #rmse 要判断是否有update 在pid
         if self.config["model"]["PID"]["Enabled"]:
            self.lambda_value = self.model.update(rmse)
         #scheduler--每个进程
         scheduler.step()

         if accelerator.is_main_process:
            # 只有在主进程中执行除进程
            all_l2 = all_l2 / (  accelerator.num_processes)
            ##wandb ，传入字典
            if self.wandb_en == True:

               wandb.log(
                           {
                              'Train/All loss': all_l2,
                              'Train/Recon loss': recon_loss,
                              'Train/Sample loss': sampling_loss,
                              'Train/step loss': step_loss,
                              "currium_sampling/ratio": self.current_ratio,
                              "Train/router_loss":router_loss
                              
                           }
                           , step=e
               )
               #在单独记录一下pid
               if hasattr(self.model, 'save_dict'):
                  print("**_pid_record")
                  wandb.log(
                              {
                              'PID/target': self.model.target_loss,
                              'PID/lamba_t': self.model.lamba,
                              'PID/K_p': self.model.kp,
                              'PID/K_i': self.model.ki,
                              'PID/error_t':self.model.error
                               }
                              , step=e
                     
                           )
                  
            pbar.set_description(f"epoch{e}")
            
            gc.collect()  # 强制执行垃圾回收

            #test-- #主进程
            if e in self.config["train"]["milestones"]: 
              
               del model_out
               test_result = self.test(test_loader,
                        config,
                        device=self.device,visual=True,epoch=e)
               print("test_result",test_result)
                  
               #show train
               if isinstance(test_result, list):

                  pbar.set_description(
                        (
                           f'Epoch {e}, test rmse loss: {test_result[0]:.4f} '
                        )
                     )
               else:

                  pbar.set_description("test result is not list type")

         
            #early stop by test rmse
            if t_manager.check_improvement(test_result[0]):
               
               pbar.set_description(f"Epoch {e}, New Best Test Loss(rmse): {test_result[0]:.8f}")
               
            
               model_path = save_checkpoint(     config['train']['save_dir'],
                                                config['train']['save_model_name'].replace('.pt', f'_best.pt'),
                                                   self.model, optimizer)
               wandb.save(model_path)
               

            if t_manager.should_stop():

               print("***Early stopping triggered.")
               t_manager.best_loss = float('inf')
               break
            
      #结束时候
      if accelerator.is_main_process:
         wandb.log(
                  {
                     "Test/best_rmse":test_result[0]},step=e
               )


      
      
   def test(self,dataloader,config,device,visual=False,epoch=0):
      
      print(f"***{self.task}\n")

      if '2d' in self.task:

         print("***here is the 2d datasets \n")
         
        
         test_error = self.eval_2d(dataloader,config,visual= visual,epoch=epoch)
      
      elif '1d' in self.task:

         print("***here is the 1d datasets")
         test_error = self.eval_1d(dataloader,config,visual=visual,epoch=epoch)
 
      return test_error

   
   def eval_1d(self,dataloader,config,visual=False,epoch=0)->list:
      '''
      Evaluate the model on a 1D test set with a multi-step rollout.

      Every batch is moved to ``self.device`` and rolled out for
      ``outputs.shape[1]`` steps. Two scores are accumulated per batch and then
      divided by ``len(dataloader)``:

      * ``Test_loss.rel(model_out, outputs)`` with ``LpLoss(d=1, p=2)``
        (reported under the "rmse" name), and
      * the mean absolute difference between prediction and target.

      The averages are logged to wandb. On epoch 0 and on the last epoch
      (``config['train']['epochs'] - 1``) ``inference_unit_time`` is also run on
      a random input shaped like one sample of the last batch.

      Args:
         dataloader: iterable yielding ``(inputs, outputs, g_t, condition, boundary)``.
         config: run configuration; only ``config['train']['epochs']`` is read here.
         visual: when truthy, save and log a comparison figure built from the
            first sample of the last batch.
         epoch: current epoch; used as the wandb step and in the figure name.

      Returns:
         ``[rmse_loss_steps, mae_loss_steps, infer_time]``; ``infer_time`` stays
         ``0`` when the timing was not run.

      Raises:
         ValueError: if ``self.task`` is not a 1D task handled by this method.
      '''
      from Torch_utils.utils3 import LpLoss,inference_unit_time

      self.model.eval()
      Test_loss = LpLoss(d=1,p=2,reduction=True)

      mae_loss_steps = 0
      infer_time = 0
      rmse_loss_steps = 0

      # Evaluation never back-propagates; without no_grad every rollout step
      # would keep its autograd graph alive (eval_2d already does this).
      with torch.no_grad():

         for inputs, outputs, g_t, condition, boundary in dataloader:

            inputs = inputs.to(self.device)
            outputs = outputs.to(self.device)
            condition = condition.to(self.device)
            boundary = boundary.to(self.device)

            if self.task == "1d_Forward_Wu_T250_In_1_OUT_1":
               # roll out as many steps as the target sequence contains
               self.roll_out_steps = outputs.shape[1]
               model_out, _, _ = self.model.roll_out(inputs,
                                                     steps = self.roll_out_steps,
                                                     static = condition,
                                                     boundary = boundary)
            else:
               # Previously an unknown task reached the loss line with
               # `model_out` undefined and failed with a NameError.
               raise ValueError(f"eval_1d does not handle task {self.task!r}")

            _,t,nx = inputs.shape

            # multi-step prediction score
            rmse_loss_steps += Test_loss.rel(model_out,outputs).item()

            # mean absolute error
            mae_loss_steps += torch.mean(torch.abs(model_out - outputs)).item()

      if (epoch == config['train']['epochs']-1 or epoch==0):
         # time the model only at the first and the last epoch
         input_random = torch.randn(1,t,nx).to(self.device)
         infer_time = inference_unit_time(self.model,input_random)

      rmse_loss_steps = rmse_loss_steps/len(dataloader)
      mae_loss_steps = mae_loss_steps/len(dataloader)

      wandb.log(
                  {
                     'Test/infernece_time_avg': infer_time,
                     'Test/rmse_multi-steps': rmse_loss_steps,
                     'Test/mae_multi-steps': mae_loss_steps
                  },step=epoch
               )

      if visual:

         inputs = inputs.cpu().detach().numpy()
         g_outputs = outputs.cpu().detach().numpy()
         model_out = model_out.cpu().detach().numpy()

         fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(10, 6))

         # First panel: the first input step of the first sample. The x axis
         # is sized from the data; it used to be fixed at 50 points, which
         # made the plot call fail for any other spatial resolution.
         x_steps = np.linspace(0, 16, inputs.shape[-1])
         ax1.plot(x_steps,inputs[0,0,:])
         ax1.set_title('Input Data', fontsize=16)
         ax1.set_xlabel("x")
         ax1.set_ylabel("u(x,t)")
         ax1.set_xlim(0, 16)

         im2 = ax2.imshow(model_out[0,:,:],vmin=0,vmax=2,cmap="jet")
         ax2.set_title('Model Out', fontsize=16)
         ax2.axis('off')

         im3 = ax3.imshow(g_outputs[0,:,:],vmin=0,vmax=2,cmap="jet")
         ax3.set_title('Ground Truth', fontsize=16)
         ax3.axis('off')

         im4 = ax4.imshow(np.abs(model_out[0,:,:]-g_outputs[0,:,:]),vmin=0,vmax=0.2,cmap="jet")
         ax4.set_title('Abs Difference', fontsize=16)
         ax4.axis('off')

         # one horizontal colorbar under each image panel
         for image, axis in ((im2, ax2), (im3, ax3), (im4, ax4)):
            fig.colorbar(image, ax=axis, orientation='horizontal', fraction=0.01, pad=0.02)

         # shrink padding and the gaps between panels
         plt.tight_layout(pad=0.0005, w_pad=0.001, h_pad=0.01)
         img_path = f"{self.config['train']['save_dir']}/compare_True_model_test{epoch}.png"
         # make sure the output directory exists (same as eval_2d)
         os.makedirs(self.config['train']['save_dir'], exist_ok=True)
         plt.savefig(img_path,dpi=300,bbox_inches='tight')
         wandb.log({"Test_Image": wandb.Image(img_path)}, step=epoch)
         plt.close()
         print("***We have save fig***")

      return [rmse_loss_steps,mae_loss_steps,infer_time]
         


     
   def eval_2d(self,dataloader,config,visual= False,epoch=0)->list:
      '''
      Evaluate the model on a 2D test set with a multi-step rollout.

      For every batch the model is rolled out according to ``self.task`` and two
      scores are accumulated, then divided by ``len(dataloader)``:

      * ``Test_loss.rel(model_out, outputs)`` with ``LpLoss(d=2, p=2)``
        (reported under the "rmse" name), and
      * the mean absolute difference between prediction and target.

      For the two Pollu tasks the boundary field is subtracted from inputs and
      targets before the rollout, so both scores are computed on the
      boundary-free fields.

      Side effects: logs the scores to wandb; on epoch 0 and the last epoch runs
      ``inference_unit_time`` on the first sample of the last batch; when
      ``visual`` is truthy saves/logs a comparison figure (and, if the MOE flag
      is set in the config, a figure of the MoE weights); calls
      ``self.model.save_pid`` when the model has a ``save_dict`` attribute;
      updates ``self.important_dict`` and rewrites ``important_results.yaml``
      when the score improves on ``self.important_dict["rmse_best_test"]``.

      Args:
         dataloader: iterable yielding ``(inputs, outputs, g_t, condition, boundary)``.
         config: run configuration; only ``config['train']['epochs']`` is read here
            (everything else comes from ``self.config``).
         visual: when truthy, produce the figures described above.
         epoch: current epoch; used as the wandb step and in file names.

      Returns:
         ``[rmse_loss_steps, mae_loss_steps, infer_time]``; ``infer_time`` stays
         ``0`` when the timing was not run.

      Raises:
         ValueError: if ``self.task`` is not a 2D task handled by this method.
      '''
      gc.collect()

      Test_loss = LpLoss(d=2,p=2,reduction=True)

      pollu_tasks = ("2d_Forward_Pollu_T181_In_10_OUT_1", "2d_Reverse_Pollu_T181_In_10_OUT_1")
      sst_task = "2d_Forward_SST_T52_In_10_OUT_1"

      mae_loss_steps = 0
      infer_time = 0
      rmse_loss_steps = 0

      self.model.eval()
      # no gradients are needed during evaluation
      with torch.no_grad():

         for inputs, outputs, g_t, condition, boundary in dataloader:

            if self.task == "2d_Forward_NS_V1e-5_T20_In_10_OUT_1":
               # self.roll_out_steps is deliberately not reset for this task;
               # it has to be set before evaluation (the original note says
               # the training loop does that).
               model_out, _, _ = self.model.roll_out(inputs,
                                                     steps = self.roll_out_steps)

            elif self.task == "2d_Forward_PS_T2_In_1_OUT_1":
               self.roll_out_steps = outputs.shape[1]
               model_out, _, _ = self.model.roll_out(inputs,
                                                     steps = self.roll_out_steps)

            elif self.task in pollu_tasks:
               self.roll_out_steps = outputs.shape[1]

               # Remove the boundary field from inputs and targets so the
               # model works on the remainder.
               in_t = inputs.shape[1]
               out_t = outputs.shape[1]
               new_inputs = inputs - boundary[:,0:in_t,:,:]
               outputs = outputs - boundary[:,in_t:in_t+out_t,:,:]

               model_out, _, _ = self.model.roll_out(new_inputs,
                                                     steps = self.roll_out_steps,
                                                     static = condition,
                                                     boundary = boundary)

            elif self.task == sst_task:
               # SST inputs are used as they are; boundary is only passed along
               self.roll_out_steps = outputs.shape[1]
               model_out, _, _ = self.model.roll_out(inputs,
                                                     steps = self.roll_out_steps,
                                                     static = condition,
                                                     boundary = boundary)

            else:
               # Previously an unknown task reached the loss line with
               # `model_out` undefined and failed with a NameError.
               raise ValueError(f"eval_2d does not handle task {self.task!r}")

            # multi-step prediction score
            rmse_loss_steps += Test_loss.rel(model_out,outputs).item()

            # mean absolute error
            mae_loss_steps += torch.mean(torch.abs(model_out - outputs)).item()

         if (epoch == config['train']['epochs']-1 or epoch==0):
            # time the model at the first and last epoch, using the first
            # sample of the last batch rather than random data
            input_random = inputs[0:1,:]
            infer_time = inference_unit_time(self.model,input_random)

         rmse_loss_steps = rmse_loss_steps/len(dataloader)
         mae_loss_steps = mae_loss_steps/len(dataloader)

         wandb.log(
                     {
                        'Test/infernece_time_avg': infer_time,
                        'Test/rmse_multi-steps': rmse_loss_steps,
                        'Test/mae_multi-steps': mae_loss_steps
                     },step=epoch
                  )

         if visual:

            from matplotlib import gridspec

            if len(inputs.shape) == 5:
               # 5-D inputs carry an extra patch/scale axis at position 2;
               # merge it back so the frames can be shown as images
               scales = inputs.shape[2]
               inputs = reshape_and_concat_patches(inputs,num_patches = scales)

            # both unpackings also enforce that the arrays are 4-D
            _, t_in, _, _ = inputs.shape
            _, t_out, _, _ = model_out.shape
            inputs = inputs.cpu().detach().numpy()
            g_outputs = outputs.cpu().detach().numpy()
            model_out = model_out.cpu().detach().numpy()
            boundary = boundary.cpu().detach().numpy()

            if self.task in pollu_tasks:
               # Add back exactly the slice that was subtracted from the
               # targets above (in_t:in_t+out_t). The previous code added
               # boundary[:, 0:out_t], i.e. the input-time frames.
               restored = boundary[:,in_t:in_t+out_t,:,:]
               g_outputs = g_outputs + restored
               model_out = model_out + restored

            is_reverse = "Reverse" in self.task
            show_mask = self.task == sst_task

            fig = plt.figure(figsize=(28, 8))

            # 4 rows (inputs / truth / prediction / abs error); enough columns
            # for the longer of the two sequences
            gs = gridspec.GridSpec(4, max(t_in, t_out))
            axs_first = [fig.add_subplot(gs[0, i]) for i in range(t_in)]
            axs_sec = [fig.add_subplot(gs[1, i]) for i in range(t_out)]
            axs_third = [fig.add_subplot(gs[2, i]) for i in range(t_out)]
            axs_fourth = [fig.add_subplot(gs[3, i]) for i in range(t_out)]

            def draw_row(axes, frames, title_of, mask_alpha, value_range=None):
               # Draw frames[0, i] on axes[i]; colour limits are per frame
               # unless value_range is given. The last axis gets the colorbar.
               for index, ax in enumerate(axes):
                  frame = frames[0,index,:,:]
                  if value_range is None:
                     vmin, vmax = np.min(frame), np.max(frame)
                  else:
                     vmin, vmax = value_range
                  im = ax.imshow(frame,cmap="jet",vmin=vmin,vmax=vmax)

                  if show_mask:
                     # for SST the boundary array is overlaid in gray
                     ax.imshow(boundary[0,index,:,:],cmap="gray",alpha=mask_alpha)

                  if index == len(axes)-1:
                     fig.colorbar(im, ax = ax, orientation = 'vertical', fraction=0.046)

                  ax.set_title(title_of(index),fontsize=14)
                  ax.axis('off')

            draw_row(axs_first, inputs,
                     lambda i: f"Reverse_IN_T_{13+i}" if is_reverse else f"In_T_{i+1}",
                     0.8)
            draw_row(axs_sec, g_outputs,
                     lambda i: f"Reverse_True_T_{1+i}" if is_reverse else f"True_T_{t_in+i+1}",
                     0.8)
            draw_row(axs_third, model_out,
                     lambda i: f"Reverse_Pred_T_{1+i}" if is_reverse else f"Pred_T_{t_in+i+1}",
                     0.8)
            # the error row uses a fixed colour range and a lighter mask
            draw_row(axs_fourth, np.abs(model_out[0:1] - g_outputs[0:1]),
                     lambda i: f"Reverse_ABS_{1+i}" if is_reverse else f"ABS_T_{t_in+i+1}",
                     0.6, value_range=(0, 0.2))

            # shrink padding and the gaps between panels
            plt.tight_layout(pad=0.001, w_pad=0.002, h_pad=0.01)

            img_path = f"{self.config['train']['save_dir']}/compare_True_model_test{epoch}.png"
            # make sure the output directory exists
            os.makedirs(self.config["train"]['save_dir'], exist_ok=True)
            plt.savefig(img_path,dpi=300,bbox_inches='tight')
            wandb.log({"Test_Image": wandb.Image(img_path)}, step=epoch)
            plt.close()
            print("***We have save fig***")

            if True == self.config["model"]["MOE"]:
               # record the MoE weight matrix as a heat map
               from matplotlib.lines import Line2D
               print("record_moe")
               weight,mask = self.model.Moe_weight_info
               weight = weight.cpu().detach().numpy()
               print("moe_weight:",weight.shape)

               # the plot is transposed: axis 1 of `weight` becomes the rows
               # (labelled as experts), axis 0 the columns (labelled as patches)
               ylabels = [f'Expert_{i+1}' for i in range(weight.shape[1])]
               xlabels = [f'{i+1}' for i in range(weight.shape[0])]
               expert_name = self.model.MOE_name
               # invisible handle, used only to print the expert names in the legend
               handles = [Line2D([0], [0], color='w', label=f'Expert: {expert_name}')]

               plt.imshow(weight[:,:].T, cmap="coolwarm")
               plt.yticks(ticks=np.arange(len(ylabels)), labels=ylabels)
               plt.xticks(ticks=np.arange(len(xlabels)), labels=xlabels)
               plt.legend(handles=handles,loc='upper center', bbox_to_anchor=(0.5, -0.3), frameon=False)

               plt.xlabel('Patchs')
               plt.ylabel('Experts')
               plt.title('MoE Weights Visualization')
               plt.colorbar(shrink=0.4)
               plt.tight_layout()
               img_path = f"{self.config['train']['save_dir']}/moe_weight_{epoch}.png"
               plt.savefig(img_path,dpi=300,bbox_inches='tight')
               plt.close()
               wandb.log(
                           {
                              'Test/moe_weight': weight[0,:],
                           },step=epoch
                        )
               wandb.log({"Test_weight": wandb.Image(img_path)}, step=epoch)

         if hasattr(self.model, 'save_dict'):
            # models exposing `save_dict` (the PID wrapper) dump their own records
            self.model.save_pid( save_path = self.config['train']['save_dir'],epoch=epoch)

      # keep a summary of the quantitative results
      self.important_dict["model_paras"] = self.model.count_paras()
      if infer_time !=0:
         # 0 means the timing was not run this epoch; keep the stored value
         self.important_dict["inference_time(ms)"] = infer_time
      self.important_dict["memo"]= self.config
      if rmse_loss_steps <  self.important_dict["rmse_best_test"]:
         print("***important_update")
         self.important_dict["rmse_best_test"] = rmse_loss_steps
         self.important_dict["mae_best_test"] = mae_loss_steps
         # The directory was only created on the `visual` path before, so the
         # dump could fail when no figure had been written yet.
         os.makedirs(self.config['train']['save_dir'], exist_ok=True)
         filename = f"{self.config['train']['save_dir']}/important_results.yaml"
         with open(filename, 'w') as file:
            yaml.dump(self.important_dict, file, default_flow_style=False)

      return [rmse_loss_steps,mae_loss_steps,infer_time]


   def _pre_model(self,device_name= "cuda:0")->nn.Module:
      '''
      Build the model described in ``self.config["model"]`` and attach the
      ``forward`` / ``roll_out`` methods used by the trainer.

      ``self.config["model"]["name"]`` must be a dotted path of the form
      ``<package>.<python file>.<class>``; the class is imported dynamically and
      instantiated with ``self.config["model"]["parameters"]`` as keyword
      arguments. ``forward`` is then replaced by ``Moe_forward`` when the MOE
      flag equals True, otherwise by ``baseline_forward_1``. If
      ``self.config["model"]["PID"]["Enabled"]`` equals True the model is wrapped
      in ``PI_controller``. ``rollout_preds`` is finally bound as ``roll_out`` on
      the object that is returned.

      Args:
         device_name: device the freshly built model is moved to.

      Returns:
         The model (or its ``PI_controller`` wrapper).

      Raises:
         ValueError: if the model name does not have three dot-separated parts.
         ImportError: if the module cannot be imported.
         AttributeError: if the module does not define the class.
      '''
      import importlib
      import types

      # Read from self.config like the rest of this method; the previous code
      # depended on a module-level `config` variable being defined.
      model_name = self.config['model']['name']
      model_params = self.config['model']['parameters']

      # split from the right so the package part may itself contain dots
      name_parts = model_name.rsplit('.', 2)
      if len(name_parts) != 3:
         raise ValueError(
            f"model name {model_name!r} must look like '<package>.<file>.<class>'"
         )
      base_folder, py_file_name, class_name = name_parts

      print("Base Folder:", base_folder)
      print("Python File Name:", py_file_name)
      print("Class Name:", class_name)

      # Import the module and fetch the class. Failures are re-raised: they
      # used to be printed and swallowed, after which the next line failed
      # with an unrelated UnboundLocalError on `model_class`.
      module_name = f"{base_folder}.{py_file_name}"
      try:
         module = importlib.import_module(module_name)
         model_class = getattr(module, class_name)
      except ImportError as e:
         print(f"Import error: {e}")
         raise
      except AttributeError as e:
         print(f"Failed to find the required attribute or function in {module_name}.{class_name}")
         print(f"Error: {e}")
         raise

      print("Model Class:", model_class)
      print("class_name: {}".format(model_class))
      model = model_class(**model_params).to(device_name)

      # bind the forward implementation onto this instance
      if True == self.config["model"]["MOE"]:
         model.forward = types.MethodType(Moe_forward, model)
         print("***moe_forward")
      else:
         # plain (non-MoE) forward
         model.forward = types.MethodType(baseline_forward_1, model)

      if self.config["model"]["PID"]["Enabled"] == True:
         # wrap the model in the PI controller
         from Baseline.PID import PI_controller
         print("*** this is the PI_MOE")
         print("*** pid paras:",self.config["model"]["PID"]["parameters"])
         model = PI_controller(model=model,
                               **self.config["model"]["PID"]["parameters"])

      # roll_out is bound last so it lands on the wrapper when one is used
      model.roll_out = types.MethodType(rollout_preds, model)

      # confirm that both methods are present
      if hasattr(model, 'forward'):
         print("forward method is successfully added.")
      else:
         print("Failed to forward method.")

      if hasattr(model, 'roll_out'):
         print("roll_out method is successfully added.")
      else:
         print("Failed to add roll_out method.")

      return model

   def _pre_dataset(self):
      '''
      read dataset from h5py
      read the task name in the yaml
      
      '''
      # get dataset path about .h5

      # 根据任务进行分流
      print("***task:",self.task)
      
      if self.task =="2d_Forward_NS_V1e-5_T20_In_10_OUT_1":
         
         if "Pre_scale" in self.config["model"]["parameters"]["scales"]:
            
            from Torch_utils.Datasets_utils.MOE_NS_datasets import MOE_NS_Dataset
            from torch.utils.data import DataLoader,random_split
            
            scale_str = self.config["model"]["parameters"]["scales"]
            scale_value = int(scale_str.split("x")[0]) #2 or 4

            ns = MOE_NS_Dataset(mat_path= self.config["train"]["Dataset_path"],
               task = self.task,scales=scale_value)
            
         
         else:
            from Torch_utils.Datasets_utils.NS_datasets import NS_Dataset
            from torch.utils.data import DataLoader,random_split

            ns = NS_Dataset(mat_path= self.config["train"]["Dataset_path"],
                           task = self.task)
 
         # 数据集的总长度
         dataset_size = len(ns)

         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size

         # 随机划分数据集
         train_dataset, test_dataset = random_split(ns, [train_size, test_size])

         train_loader = DataLoader (train_dataset, batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, batch_size=self.config['train']['batchsize'], shuffle=True)

         print("***NS_Loader has prepared")
         return train_loader,test_loader

      
      
      elif self.task == "NSWE_V1e-2_T101_Train_In_10_OUT_1":
         from Torch_utils.NSWE_datasets import NSWE_Dataset
         from torch.utils.data import DataLoader,random_split

         NSWE = NSWE_Dataset(h5_path= r"E:\study\LE-PDE++\Datasets\SWE_Nonlinear_Sigma_1.h5" ,
                              task = self.task)
         # 数据集的总长度
         dataset_size = len(NSWE)
          # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(NSWE, [train_size, test_size])

      
         train_loader = DataLoader (NSWE, batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(NSWE, batch_size=self.config['train']['batchsize'], shuffle=True)

         print("Loader has prepared")
         return train_loader,test_loader
      
      elif self.task == "2d_Forward_Pollu_T181_In_10_OUT_1":
         
         from torch.utils.data import DataLoader,random_split
         from Torch_utils.Datasets_utils.Pollu_datasets import Pollu_DatasetH5
         
         pollu_dataset = Pollu_DatasetH5(self.config["train"]["Dataset_path"],
                                   task='forward_In_10_OUT_1')
         #数据集的总长度(121)
         dataset_size = len(pollu_dataset)
         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(pollu_dataset, [train_size, test_size])

         train_loader = DataLoader (train_dataset, 
                                    batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, 
                                  batch_size=self.config['train']['batchsize'], shuffle=True)

         return train_loader,test_loader
      
      elif self. task =="2d_Forward_PS_T2_In_1_OUT_1":
         
         #自制多尺度
         if "Pre_scale" in self.config["model"]["parameters"]["scales"]:
            
            from torch.utils.data import DataLoader,random_split
            from Torch_utils.Datasets_utils.MOE_PB_datasets import MOE_PB_Dataset
            
            scale_str = self.config["model"]["parameters"]["scales"]
            scale_value = int(scale_str.split("x")[0]) #2 or 4
            Pb_Dataset = MOE_PB_Dataset(self.config["train"]["Dataset_path"],
                                    task='forward_In_1_OUT_1',scales=scale_value)
             
         else:
         
            from torch.utils.data import DataLoader,random_split
            from Torch_utils.Datasets_utils.PB_dataset import PB_Dataset
            
            Pb_Dataset = PB_Dataset(self.config["train"]["Dataset_path"],
                                    task='forward_In_1_OUT_1')
         #数据集的总长度(121)
         dataset_size = len(Pb_Dataset)
         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(Pb_Dataset, [train_size, test_size])

         train_loader = DataLoader (train_dataset, 
                                    batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, 
                                  batch_size=self.config['train']['batchsize'], shuffle=True)
         
         return train_loader,test_loader
         
      
      elif self.task == "2d_Reverse_Pollu_T181_In_10_OUT_1":
         from torch.utils.data import DataLoader,random_split
         from Torch_utils.Datasets_utils.Pollu_datasets import Pollu_DatasetH5
         
         pollu_dataset = Pollu_DatasetH5(self.config["train"]["Dataset_path"],
                                         task='reverse_In_10_OUT_1')
         #数据集的总长度(121)
         dataset_size = len(pollu_dataset)
         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(pollu_dataset, [train_size, test_size])

         train_loader = DataLoader (train_dataset, 
                                    batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, 
                                  batch_size=self.config['train']['batchsize'], shuffle=True)

         return train_loader,test_loader
      
      elif self.task == "2d_Forward_SST_T52_In_10_OUT_1":
         
         from torch.utils.data import DataLoader,random_split
         from Torch_utils.Datasets_utils.SST_datasets import SST_DatasetH5
         
         sst_dataset = SST_DatasetH5(  self.config["train"]["Dataset_path"],
                                       task='SST_In_10_OUT_1')
         #数据集的总长度(121)
         dataset_size = len(sst_dataset)
         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(sst_dataset, [train_size, test_size])

         train_loader = DataLoader (train_dataset, 
                                    batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, 
                                  batch_size=self.config['train']['batchsize'], shuffle=True)

         return train_loader,test_loader
      
      elif self.task == "1d_Forward_Wu_T250_In_1_OUT_1":
         
         from torch.utils.data import DataLoader,random_split
         from Torch_utils.Datasets_utils.Wu_1d_datasets import Wu1d_Dataset
         
         wu_1d = Wu1d_Dataset(   self.config["train"]["Dataset_path"],
                                       task='Wu1d_In_1_OUT_1')
         #数据集的总长度(1200)
         dataset_size = len(wu_1d)
         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(wu_1d, [train_size, test_size])

         train_loader = DataLoader (train_dataset, 
                                    batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, 
                                  batch_size=self.config['train']['batchsize'], shuffle=True)

         return train_loader,test_loader
      
      elif self.task == "2d_Identification_Cylinder_T800_In_100_OUT_1":
         
         from torch.utils.data import DataLoader,random_split
         from Torch_utils.Datasets_utils.Cylinder_datasets import Cylinder_DatasetH5
         
         Cylinder = Cylinder_DatasetH5(   self.config["train"]["Dataset_path"],
                                       task='Cylinder_In_100_OUT_1')
         #数据集的总长度(8)
         dataset_size = len(Cylinder)
         # 计算训练集和测试集的长度
         train_size = int(0.7 * dataset_size)
         test_size = dataset_size - train_size
          # 随机划分数据集
         train_dataset, test_dataset = random_split(Cylinder, [train_size, test_size])

         train_loader = DataLoader (train_dataset, 
                                    batch_size=self.config['train']['batchsize'], shuffle=True)
         test_loader = DataLoader(test_dataset, 
                                  batch_size=self.config['train']['batchsize'], shuffle=True)
         
         return train_loader,test_loader
         
      
         

         

   def _pre_le_pde(self,device_name= "cuda:0")->nn.Module:
      '''
      Build the LE-PDE ``Contrastive`` model described by ``self.config``.

      Hyper-parameters are read from
      ``config['model']['parameters']['args_from_sota']`` (a string of
      ``key=value`` pairs), the grid layout is selected from ``self.task``,
      and the module-level functions ``forward_1`` / ``rollout_preds`` are
      bound onto the returned instance as ``forward`` / ``roll_out``.

      Args:
         device_name: device the model is moved to (default ``"cuda:0"``).

      Returns:
         The configured model, already moved to ``device_name``.

      Raises:
         ValueError: if ``self.task`` matches none of the supported dataset
            markers (this used to surface as an UnboundLocalError after the
            configuration had already been parsed).
      '''
      import argparse
      import types

      # (task marker, log label, grid shape, static input size).
      # Checked in order; the first marker contained in self.task wins.
      task_layouts = (
         ("NS_", "NS_data", (64, 64), 0),
         ("Pollu_", "Pollu_", (512, 512), 0),
         ("SST_", "SST_", (180, 360), 0),
         ("Wu_", "Wu_1d", (50,), 2),
         ("Identification_Cylinder", "Identification_Cylinder", (31, 31), 0),
      )
      for marker, label, grid_shape, static_size in task_layouts:
         if marker in self.task:
            print(label)
            original_shape = (('n0', grid_shape),)
            static_input_size = {'n0': static_size}
            break
      else:
         # Fail fast with an actionable message instead of an unbound local.
         supported = ", ".join(entry[0] for entry in task_layouts)
         raise ValueError(
            f"Unsupported task {self.task!r} for the LE-PDE model; "
            f"expected the task name to contain one of: {supported}"
         )

      from le_pde.models import Contrastive

      args_str = self.config['model']['parameters']['args_from_sota']
      # SECURITY(review): eval() executes arbitrary code taken from the config
      # file. Only load trusted configs; consider a structured YAML mapping.
      args_dict = eval(f"dict({args_str})")
      args = argparse.Namespace(**args_dict)
      device = device_name

      input_size =  {'n0': self.config['model']['parameters']['input_size']}
      output_size = {'n0':  self.config['model']['parameters']['output_size']}
      grid_keys = ('n0',)
      part_keys = ()

      print("test_origin",original_shape)
      model = Contrastive(
            input_size=input_size,
            output_size=output_size,
            latent_size=args.latent_size,
            encoder_type=args.encoder_type,
            evolution_type=args.evolution_type,
            decoder_type=args.decoder_type,
            encoder_n_linear_layers=args.encoder_n_linear_layers,
            temporal_bundle_steps=args.temporal_bundle_steps,
            n_conv_blocks=args.n_conv_blocks,
            n_latent_levs=args.n_latent_levs,
            n_conv_layers_latent=args.n_conv_layers_latent,
            evo_conv_type=args.evo_conv_type,
            evo_pos_dims=args.evo_pos_dims,
            evo_inte_dims=args.evo_inte_dims,
            is_latent_flatten=args.is_latent_flatten,
            encoder_mode="dense",
            grid_keys=grid_keys,
            part_keys=part_keys,
            no_latent_evo=args.no_latent_evo,
            forward_type=args.forward_type,
            channel_mode=args.channel_mode,
            kernel_size=args.kernel_size,
            stride=args.stride,
            padding=args.padding,
            padding_mode=args.padding_mode,
            output_padding_str=args.output_padding_str,
            evo_groups=args.evo_groups,
            act_name=args.act_name,
            decoder_last_act_name=args.decoder_last_act_name,
            is_pos_transform=args.is_pos_transform,
            normalization_type=args.normalization_type,
            reg_type=args.reg_type,
            loss_type=args.loss_type,
            input_shape=original_shape,
            static_latent_size=args.static_latent_size,
            static_encoder_type=args.static_encoder_type,
            static_input_size=static_input_size,
            decoder_act_name=args.decoder_act_name,
            is_prioritized_dropout=args.is_prioritized_dropout,
            vae_mode=args.vae_mode,
        ).to(device)

      if "Identification_Cylinder" in self.task:
         print("Identification_Cylinder")
         # Extra head for the identification task: global average pool,
         # flatten, then a 2 -> 1 linear map (assumes a 2-channel input to
         # this head -- TODO confirm against forward_1).
         # It is attached after the model was moved, so it must be moved to
         # the target device explicitly or it would stay on the CPU.
         model.identification_layer = nn.Sequential(
                                                   nn.AdaptiveAvgPool2d((1, 1)),
                                                   nn.Flatten(),
                                                   nn.Linear(2, 1)
                                                ).to(device)

         # Rollout variant for multiple fields.
         model.multi_field_rollout =  types.MethodType(multi_field_rollout_preds, model)

      print("model: " , model)

      # Bind the project-level forward / rollout implementations onto this instance.
      model.forward = types.MethodType(forward_1, model)
      model.roll_out = types.MethodType(rollout_preds, model)

      # Confirmation log; both attributes were assigned just above.
      if hasattr(model, 'forward'):
         print("forward method is successfully added.")
      else:
         print("Failed to forward method.")

      if hasattr(model, 'roll_out'):
         print("roll_out method is successfully added.")
      else:
         print("Failed to add roll_out method.")

      return model

      
   
   def resume(self): #load
      """Restore ``self.model`` from the checkpoint named in the config.

      The checkpoint path is the plain string concatenation of
      ``config["train"]["save_dir"]`` and ``config["load"]["ckpt"]`` (no path
      separator is inserted). No optimizer is passed to ``load_checkpoint``.

      Returns:
         Whatever ``load_checkpoint`` returns for ``self.model``.
      """
      train_cfg = self.config["train"]
      load_cfg = self.config['load']
      ckpt_path = train_cfg["save_dir"] + load_cfg['ckpt']
      print(f"load_ckpt_path =:{ckpt_path}\n")

      restored = load_checkpoint(self.model, ckpt_path, optimizer = None)
      print("load_model_success")
      return restored
      
import re

class Para_Optimize():
   """Run a Weights & Biases hyper-parameter sweep around a training experiment.

   ``start`` registers the sweep, runs one training pass with the unmodified
   configuration, then hands ``wandb_train`` to the wandb agent. Each agent
   run copies the sampled values into the experiment config, derives a
   per-run ``save_dir``, re-initialises the experiment and trains it.
   """

   # Text that str() yields for ints/floats, including an optional sign and
   # scientific notation such as "1e-05".
   _NUMBER_PATTERN = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

   def __init__(self,sp_config,train_expr:"Lepde_expr"):
      """Store the sweep definition and the experiment it drives.

      Args:
         sp_config: wandb sweep configuration dict (``name`` is set in ``start``).
         train_expr: experiment object exposing ``config`` and ``train()``.
      """
      self.sp_config = sp_config
      print("sp",self.sp_config)
      # The experiment's config is modified in place by every sweep run.
      self.train_expr = train_expr

   def start(self,sweep_name_suffix="test"):
      """Register the sweep, run one baseline training pass, then start the agent.

      Args:
         sweep_name_suffix: kept for backward compatibility only; the value
            is always replaced by ``config["data"]["task"]``.
      """
      # SECURITY(review): an API key is committed in this file. Prefer the
      # WANDB_API_KEY environment variable; rotate the committed key and
      # delete the fallback once every runner provides the variable.
      api_key = os.environ.get("WANDB_API_KEY") or "546cc27dc3ae3d4e32301ef776cb1f65bd31cce4"
      wandb.login(key = api_key)

      sweep_name_suffix = self.train_expr.config["data"]["task"]
      sweep_name = self.train_expr.config["Sweep_config"]["name"]+ sweep_name_suffix
      self.sp_config["name"] =  sweep_name

      sweep_id = wandb.sweep(sweep=self.sp_config,
                             project = self.train_expr.config["log_run"]["project"])

      # Module-level run counter shared with wandb_train.
      global sweep_counts
      sweep_counts = 0

      # One training pass with the unmodified config before the agent starts.
      with wandb.init():
         self.train_expr.train()

      wandb.agent(sweep_id,
                  function = self.wandb_train,
                  count=self.train_expr.config["train"]["Sweep_count"])

   def generate_pattern_and_suffix(self,run_config):
      """Build the regex for an existing parameter suffix and its replacement.

      For every ``key, value`` pair (in iteration order) the suffix gains
      ``_{key}_{value}`` and the pattern gains a fragment matching
      ``_{key}_<old value>``. Keys are regex-escaped. Real numbers use a
      pattern that also covers signs and scientific notation; everything
      else, bools included (``str(True)`` is not numeric), uses ``[\\w]+``.

      Args:
         run_config: mapping of sampled hyper-parameters.

      Returns:
         ``(pattern, new_suffix)`` as two strings; both are empty for an
         empty mapping.
      """
      pattern_parts = []
      suffix_parts = []

      for key, value in run_config.items():
         is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
         value_pattern = self._NUMBER_PATTERN if is_number else r"[\w]+"
         pattern_parts.append(f"_{re.escape(str(key))}_{value_pattern}")
         suffix_parts.append(f"_{key}_{value}")

      return "".join(pattern_parts), "".join(suffix_parts)

   def _apply_sweep_overrides(self, run_config):
      """Copy sampled values into every config section that already has the key."""
      config = self.train_expr.config
      sections = (
         (config["train"], "## do not need in trian sweep"),
         (config["model"]["PID"]["parameters"], "## do not need in pid sweep"),
         (config["model"]["parameters"], "## do not need in model sweep"),
      )
      for key, value in run_config.items():
         print(f"###***key:{key}")
         for section, skip_message in sections:
            if key in section:
               section[key] = value
               print(f"###sweep:{key}:{value}")
            else:
               print(skip_message)

   def _build_save_dir(self, save_dir, run_config, sweep_count):
      """Return ``save_dir`` with its parameter suffix replaced for this run."""
      pattern, new_suffix = self.generate_pattern_and_suffix(run_config)
      new_suffix = new_suffix + f"_sweep_count_{sweep_count}"
      print(f"New Suffix Name: {new_suffix}")

      if '_raw' in save_dir:
         raw_name = save_dir.split('_raw')[0]
      elif '_sweep_count' in save_dir:
         raw_name = save_dir.split('_sweep_count')[0]
      else:
         # No marker to strip: use the directory name as-is rather than
         # failing on an unbound local.
         raw_name = save_dir
      print("test_raw",raw_name)

      # An empty pattern matches everywhere, so only substitute when there is
      # a real pattern. A callable replacement keeps backslashes literal.
      if pattern and re.search(pattern, raw_name):
         return re.sub(pattern, lambda _match: new_suffix, raw_name)
      return raw_name + new_suffix

   def  wandb_train(self):
      """Agent callback: apply one sampled configuration and train with it."""
      global sweep_counts

      log_cfg = self.train_expr.config["log_run"]
      task_name = self.train_expr.config["data"]["task"]

      # The run name uses the counter value from before the increment below.
      with wandb.init( project = log_cfg["project"],
                       group = log_cfg["group"],
                       tags = [log_cfg["tag"],task_name],
                       notes = log_cfg["note"],
                       mode= "online",
                       name = log_cfg["name"] + f"_sweep_counts{sweep_counts}") as run:

         sweep_counts += 1

         print("###run_config",run.config)
         # t_manager is the module-level TrainingManager set up in __main__.
         wandb.save(t_manager.source_file)

         self._apply_sweep_overrides(run.config)

         self.train_expr.config['train']['save_dir'] = self._build_save_dir(
            self.train_expr.config['train']['save_dir'], run.config, sweep_counts)
         print("###save_dir:",self.train_expr.config['train']['save_dir'])

         # Re-run __init__ with the updated config before training again.
         self.train_expr.__init__(config = self.train_expr.config)
         print("重新初始化")
         self.train_expr.train()
         run.finish()

   def add_suffix_to_filter(self,text, suffix):
      """Insert ``suffix`` directly after every ``_SweepCount`` in ``text``.

      A callable replacement is used so that a suffix starting with a digit
      (or containing a backslash) is inserted literally instead of being
      parsed as part of a group reference.
      """
      return re.sub(r'(_SweepCount)', lambda match: match.group(1) + suffix, text)

if __name__ == "__main__":
   # Script entry point: load the YAML config named on the command line,
   # build the training manager and the experiment, then start the sweep.

   print("This is expr.py")
   if torch.cuda.device_count() > 1:
      print("Let's use", torch.cuda.device_count(), "GPUs!")

   parser = argparse.ArgumentParser(description="Load configuration from YAML")
   # Required: without a path the script would fail later inside open().
   parser.add_argument('--config_yaml', type=str, required=True,
                       help='Path to the configuration YAML file')

   args = parser.parse_args()

   with open(args.config_yaml, 'r', encoding='utf-8') as stream:
      try:
         config = yaml.safe_load(stream)
      except yaml.YAMLError as exc:
         print("exc:",exc)
         # Stop here: continuing would only fail with a NameError on `config`.
         raise SystemExit(f"Could not parse config file: {args.config_yaml}") from exc
   print("***config_yaml***",config)

   # Module-level name on purpose: Para_Optimize.wandb_train reads
   # t_manager.source_file. Original note: early stopping & save yaml.
   t_manager = TrainingManager(save_dir = config["train"]["save_dir"],
                              save_name= config["train"]['save_model_name'],
                              device = config["train"]["device"],
                              patience = config["train"]["patience"],
                              source_file = args.config_yaml)

   expr = Lepde_expr(config,load=config['load']['EN'])
   para_op = Para_Optimize(sp_config=config['Sweep'],train_expr = expr)
   para_op.start()