# 模型参数
input_dim: 1
output_dim: 1
input_len: 100
output_len: 50
d_model: 256
n_layer: 4



# 训练参数
batch_size: 2048
lr: 1.0e-4
num_epochs: 100
train_ratio: 0.8
val_ratio: 0.1


# 数据参数
data:
  # type: "two_phase_gbm"  # 可选: "gbm", "sde"
  # type: "generate_three_peak_switching_sde_data"  # 可选: "gbm", "sde"
  # type: "gbm"  # 可选: "gbm", "sde"
  # type: "gbm_multi"  # 可选: "gbm", "sde"
  type: "generate_gbm_data_three"  # 可选: "gbm", "sde"
  num_samples: 50000
  save_dir: "./data"  # 数据保存的根目录


# 保存配置
save:
  model_name: "mamba_sde"  # 模型名称（用于文件前缀）
  save_dir: "./checkpoints"  # 保存目录
  log_dir: "./logs"  # 日志目录
  keep_models: 3  # 保存最好的前N个模型

# 其他配置
device: "cuda:0"  # 使用 GPU 0,1
test_gpu: 0  # 指定用于测试的GPU编号

# 或者
# device: "cuda:1,3"    # 只使用 GPU 1和3

wandb:
  use_wandb: False  # 是否使用 wandb
  project: "ssm_sde"  # wandb 项目名称
  entity: "yuanshuai"  # wandb 实体名称
  name: "ssm_sde_xiuaga_001"  # 实验名称