# @package _global_

# train.py arguments
project_name: regress-lm-multi-task-reinforce
# The timestamp is composed from the same ${now:...} patterns that save_dir
# uses. Hydra registers the `now` resolver with caching, so identical patterns
# resolve to the same value. A separate %Y-%m-%d_%H-%M-%S pattern would be
# resolved on its own and could differ from the checkpoint directory by a
# second. The rendered format (YYYY-MM-DD_HH-MM-SS) is unchanged.
experiment_name: regress-lm-multi-task-reinforce-${now:%Y-%m-%d}_${now:%H-%M-%S}
use_wandb: false  # Disabled to avoid log confusion during multi-task training
log_dir: logs
seed: 42
batch_size: 16  # Reduced to accommodate REINFORCE memory requirements
num_epochs: 50  # Reduced to speed up testing
# Written as 1.0e-4 rather than 1e-4: YAML 1.1 loaders (e.g. PyYAML) only
# recognize floats that contain a decimal point and would load "1e-4" as a
# string. OmegaConf happens to accept both, but this form is portable.
learning_rate: 1.0e-4
save_every_n_epochs: 5

# Step-based checkpoint saving
# Save a checkpoint every N steps; set to null or 0 to disable (disabled here).
save_every_n_steps: 0

# Learning-rate scheduler
scheduler:
  # Schedule shape; one of: cosine, linear, constant
  type: cosine
  # Number of warmup steps (kept small for this config)
  warmup_steps: 50
  # Lowest learning rate the schedule reaches, as a fraction of learning_rate
  min_lr_ratio: 0.1

# Model hyperparameters (as defined by train.py)
model:
  # Input/output sizing.
  # NOTE(review): max_num_objs presumably caps the number of regression
  # objectives per example — confirm against the model code.
  max_input_len: 1024
  max_num_objs: 1
  # Encoder-decoder sizing. The key names match a standard transformer's
  # arguments; presumably d_model must be divisible by nhead (256 / 4 = 64)
  # — TODO confirm.
  d_model: 256
  num_encoder_layers: 3
  num_decoder_layers: 3
  nhead: 4
  dim_feedforward: 1024
  dropout: 0.1

# REINFORCE configuration
reinforce:
  enabled: true
  # Sampling temperature; controls how much the policy explores
  temperature: 1.0
  # Number of samples (reduced to speed up training)
  num_samples: 8
  # Scaling factor applied to rewards
  reward_scale: 1.0
  # Baseline type: 'mean' or 'min'
  baseline_type: 'mean'
  # Weight of the REINFORCE loss (reduced to avoid excessive exploration)
  weight: 0.5
  # KL regularization weight; keeps the policy from drifting too far from
  # the reference model
  kl_weight: 0.1

# Dataset parameters
dataset:
  # Placeholder default; expected to be overridden on the command line
  name: 'dataset-name'
  # Explicit train/validation paths (both unset here)
  path:
    train: null
    val: null
  params:
    data_dir: regression_data

# Optional checkpoint (pretrained or previously trained) to initialize from.
# The path may be any of:
#   - a model.pt file itself;
#   - a directory that contains model.pt;
#   - a parent directory holding several checkpoint_* subdirectories, in
#     which case the latest one is selected automatically.
# null means no checkpoint is specified.
init_checkpoint: null

# Other configurations
# Hydra's per-run output directory, timestamped by date and time
hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}

# Checkpoint directory, timestamped with the same ${now:...} patterns as
# hydra.run.dir in this section.
# NOTE(review): this path is relative. If Hydra changes the working directory
# into hydra.run.dir (hydra.job.chdir), it resolves beneath that directory —
# confirm that is intended.
save_dir: outputs/multi-task-reinforce/${now:%Y-%m-%d}/${now:%H-%M-%S}/checkpoints
# NOTE(review): the meaning of if_ntl is not visible in this file — confirm
# against train.py.
if_ntl: false
