#!/bin/bash
#
# Launch train_dist.py with a fixed list of key=value command-line overrides.
#
# Usage: run from the directory that contains train_dist.py; the script takes
# no arguments of its own.
#
# NOTE(review): the values `xxx` (wandb.project, exp_name,
# bandit.encoder_name_or_path) look like placeholders -- confirm and replace
# them before a real run.
# NOTE(review): the group headings below are derived from the option names
# only; check train_dist.py's configuration for the exact meaning of each key.

set -euo pipefail

# The arguments live in an array so that they can be grouped and commented;
# each element is passed to train_dist.py as exactly one argument, in order.
train_args=(
  # --- model / data / objective ---
  model=mistral_7b
  # Quoted on purpose: an unquoted [ultrachat] is a glob bracket expression
  # and could be expanded against files in the current directory (or dropped /
  # rejected under nullglob / failglob) instead of being passed literally.
  'datasets=[ultrachat]'
  loss=sft
  lr=2e-5

  # --- run naming / logging ---
  wandb.project=xxx
  exp_name=xxx

  # --- batching / trainer ---
  gradient_accumulation_steps=4
  batch_size=32
  eval_batch_size=32
  trainer=FSDPTrainer
  sample_during_eval=false
  model.fsdp_policy_mp=bfloat16

  # --- bandit.* options ---
  bandit.train_only=true
  bandit.norm_metrics=true
  bandit.use_pool=true
  bandit.pool_size=40000
  bandit.pool_use_num=32
  bandit.lr=0.0001
  bandit.concat_wl=false
  bandit.f1_reward_type=add_rm_logp
  bandit.encoder_name_or_path=xxx
  bandit.enable_bandit=true
  bandit.layer_encoder=true
  bandit.use_encoder=true
  bandit.load_bandit=false
  bandit.forward_only=false
  bandit.enable_avg_reward=true
  bandit.train_step=1
  bandit.f1_only=false
  bandit.f1_num_layers=16
  bandit.f2_num_layers=16
  bandit.f1_num_epochs=2
  bandit.f1_num_batch=2
  bandit.f2_num_epochs=2
  bandit.f2_num_batch=2
  bandit.f2_weight=0.01

  # --- checkpointing / lengths / run size ---
  save_every=20000000
  max_length=1024
  max_prompt_length=512
  n_epochs=1
  n_examples=null
  debug=false
)

# -u: run Python with unbuffered stdout/stderr so log lines appear immediately.
python -u train_dist.py "${train_args[@]}"
