---
data:
  # Active dataset directory. Previously used alternative, kept for reference:
  # vectorized_data_dir: "datasets/vectorized_last_smoltalk_single_round_sae_llama3b_layers_14"
  vectorized_data_dir: "datasets/vectorized_last_20K_augmented_sae_llama3b_layers_14"

  # Train / validation / test ratios (0.89 + 0.1 + 0.01 = 1.0).
  train_ratio: 0.89
  val_ratio: 0.1
  test_ratio: 0.01

model:
  sae_path: "sae/sae_llama3b_layers_14.pth"
  # Canonical lowercase booleans (yamllint `truthy`).
  fine_tune: false
  normalize: false
  # If > 0, an MLP structure is used: (4096*32) -> hidden_layer -> 4096.
  hidden_layer: 16384

training:
  # Reduced batch size, originally chosen so several processes could share one
  # GPU. NOTE(review): num_gpus below is 1 (single process) — revisit if needed.
  batch_size: 32
  # Floats are written with a mantissa dot on purpose: YAML 1.1 loaders such as
  # PyYAML only resolve floats containing "." and would load a bare "5e-5" /
  # "1e-5" as a string, not a number.
  learning_rate: 5.0e-5
  num_epochs: 3
  weight_decay: 1.0e-5
  warmup_steps: 1000
  max_grad_norm: 0.5
  # One of "fp16", "bf16", "no" ("no" = plain fp32). Keep it quoted: an
  # unquoted `no` is parsed as boolean false by YAML 1.1 loaders.
  mixed_precision: "no"
  loss_type: "cosine"  # "mse" or "cosine"

  # GPU configuration
  gpu_ids: "7"  # GPU(s) to use (a single RTX 4090); quoted so it loads as a string, not an int
  num_gpus: 1   # number of processes (single process)

  # Data-loader / load-balancing settings
  dataloader_pin_memory: false
  dataloader_num_workers: 0

  output_dir: "prompt_decoder"
  intermediate_eval_frequency: 400  # run an intermediate validation every 400 batches

  # Test configuration: both the fixed-subset test and the random test are now
  # always run.

wandb:
  use_wandb: false
  project: "prompt-decoder-smoltalk"
  # run_name is generated automatically from fine_tune, batch_size and learning_rate.