#!/usr/bin/env bash
#
# Runs ./sdr/train.py on the gsm8k dataset with a single remote model
# (--num_remote 1) and an "mlp" head model.
#
# Site-specific values are taken from environment variables rather than
# angle-bracket placeholders written inline: an unquoted <name> token is parsed
# by the shell as a pair of redirections (read from "name", then write to the
# next word), so running the template unedited could create or truncate files.
#
# Required environment variables:
#   MODEL_PATH   value passed to --model_path
#   SLM_NAME     value passed to --slm_name; also selects data/$SLM_NAME/
#   LLM_NAME     selects data/$LLM_NAME/ for --remote_answer_paths
#   EMBED_PATH   value passed to --embed_path
#   OUTPUT_DIR   value passed to --output_dir
#
# Optional environment variables:
#   LLM_NAMES    value passed to --llm_names (defaults to $LLM_NAME)
#
# The script path and the data/ paths are relative, so the working directory
# must be the one that contains both sdr/ and data/.

set -euo pipefail

model_path=${MODEL_PATH:?set MODEL_PATH to the value for --model_path}
slm_name=${SLM_NAME:?set SLM_NAME to the value for --slm_name}
llm_name=${LLM_NAME:?set LLM_NAME to the remote model directory under data/}
embed_path=${EMBED_PATH:?set EMBED_PATH to the value for --embed_path}
output_dir=${OUTPUT_DIR:?set OUTPUT_DIR to the value for --output_dir}

# NOTE(review): the original template used a separate "llm_names" placeholder
# for --llm_names and "llm_name" for the data paths. With --num_remote 1 these
# are presumably the same single name -- confirm, or set LLM_NAMES explicitly.
llm_names=${LLM_NAMES:-$llm_name}

# Arguments are collected in an array so that no line depends on a trailing
# backslash and every value is passed as exactly one word.
train_args=(
  --model_path "$model_path"
  --dataset_name gsm8k
  --num_remote 1
  --slm_name "$slm_name"
  --llm_names "$llm_names"
  --local_answer_paths
  "data/${slm_name}/gsm8k-train.parquet"
  "data/${slm_name}/gsm8k-test.parquet"
  --remote_answer_paths
  "data/${llm_name}/gsm8k-train.parquet"
  "data/${llm_name}/gsm8k-test.parquet"
  --embed_path "$embed_path"
  --output_dir "$output_dir"
  --head_model mlp
  --per_device_train_batch_size 32
  --per_device_eval_batch_size 2
  --gradient_accumulation_steps 1
  --lr_scheduler_type cosine
  --logging_steps 10
  --eval_steps 100
  --save_steps 100
  --warmup_steps 100
  --learning_rate 5e-5
  # NOTE(review): 0.01 epochs looks like a smoke-test value -- confirm before
  # using this script for a real training run.
  --num_train_epochs 0.01
  --bf16
  --seed 42
  --save_total_limit 2
  --eval_ratio 0.2
  --multi_remote_strategy head
)

python ./sdr/train.py "${train_args[@]}"