# 模型参数
input_dim: 1
output_dim: 1
input_len: 100
output_len: 50
d_model: 256
n_layer: 4



# 训练参数
batch_size: 2048
lr: 1.0e-4
num_epochs: 100
train_ratio: 0.8
val_ratio: 0.1


# 数据参数
data:
  type: "generate_gbm_data_multi_two"
  num_samples: 50000
  save_dir: "./data"
  # GBM参数
  mu1: -0.0027
  sigma1: 0.049
  mu2: 0.0149
  sigma2: 0.


save:
  model_name: "gbm_dist_μ1=-0.0027_σ1=0.049_μ2=0.0149_σ2=0.010"
  save_dir: "./checkpoints"
  log_dir: "./logs"
  keep_models: 3

device: "cuda:0"
test_gpu: 0
# 或者
# device: "cuda:1,3"    # 只使用 GPU 1和3

wandb:
  use_wandb: False  # 是否使用 wandb
  project: "ssm_sde"  # wandb 项目名称
  entity: "yuanshuai"  # wandb 实体名称
  name: "ssm_sde_xiuaga_001"  # 实验名称