# Hyper-parameters for the data-selection / pretraining pipeline below.
# Each value is forwarded to the command-line flag (or environment variable)
# named in its comment.

# --- run identity ---
model_name="llama-134M"       # --model_name for the selection and scoring scripts
method="bilevel"              # --method for select_data.py; method= for pretraining
round="2"                     # --round for all selection scripts; also scales checkpoint ids
num_gpu="4"                   # --devices, number of scoring shards, --data_shards

# --- learning rates ---
lr_lower="1e-5"               # --learning_rate
lr_upper="1e-5"               # --learning_rate_influence
lr_z="2e-1"                   # --lr_z

# --- bilevel selection settings ---
gamma="1e-1"                  # --gamma
lamb="5e-2"                   # NOTE(review): not referenced by any command visible in this script
zLoop="3"                     # --z_loops
inner_steps="16"              # --inner_steps
inner_accumulation_steps="4"  # --inner_accumulation_steps
max_steps="1000"              # --max_steps (selection) and --iter_num (scoring)
micro_batch_size="4"          # --micro_batch_size


# Stage 1: run the bilevel selection script with the hyper-parameters defined
# at the top of this file.
# The run is checked explicitly: without the check a failed run would be
# ignored and every following stage would still be launched.
python src/select_data/bilevel_selection_llama.py \
    --round "$round" --devices "$num_gpu" --gamma "$gamma" \
    --learning_rate "$lr_lower" --z_loops "$zLoop" \
    --inner_steps "$inner_steps" \
    --inner_accumulation_steps "$inner_accumulation_steps" \
    --learning_rate_influence "$lr_upper" --max_steps "$max_steps" \
    --lr_z "$lr_z" \
    --model_name "$model_name" --micro_batch_size "$micro_batch_size" \
    || { echo "bilevel_selection_llama.py failed; aborting pipeline" >&2; exit 1; }


# Stage 2: score the data with predict_data_influence.py, one background
# process per GPU, then run select_data.py over all shards.
# Each process is pinned to GPU $s via CUDA_VISIBLE_DEVICES, and --shard is
# given the shard index followed by the total shard count ($num_gpu).
# --iter_num reuses $max_steps from the configuration at the top of the file.
shard_pids=""
s=0
while [ "$s" -lt "$num_gpu" ]; do
    echo "$s"
    CUDA_VISIBLE_DEVICES="$s" python src/select_data/predict_data_influence.py \
        --shard "$s" "$num_gpu" --round "$round" --model_name "$model_name" \
        --iter_num "$max_steps" &
    shard_pids="$shard_pids $!"
    s=$((s + 1))
done

# A bare `wait` always reports success, so each shard is waited on by PID to
# notice one that crashed. All shards are reaped before deciding to abort.
shard_failed=0
for pid in $shard_pids; do
    wait "$pid" || shard_failed=1
done
if [ "$shard_failed" -ne 0 ]; then
    echo "predict_data_influence.py failed on at least one shard; aborting pipeline" >&2
    exit 1
fi

python src/select_data/select_data.py --data_shards "$num_gpu" --round "$round" \
    --model_name "$model_name" --method "$method" \
    || { echo "select_data.py failed; aborting pipeline" >&2; exit 1; }

# Stage 3: pretrain llama-0.5B via scripts/pretrain_lma0.5B.sh in two phases.
# All settings are handed to the script as per-command environment variables;
# $round, $num_gpu and $method come from the configuration at the top of the
# file. The model_name=llama-0.5B prefix only applies to the launched command.
# NOTE(review): data_model_name is the literal llama-134M, which matches the
# selection model configured above — confirm whether it should track it.
ckpt_stride=80000                       # checkpoint-id offset per round
data_ckpt=$((round * ckpt_stride))

# Phase 1: decay=false, with ckpt equal to data_ckpt.
ckpt=$data_ckpt
model_name=llama-0.5B method="$method" decay=false ckpt="$ckpt" \
    data_ckpt="$data_ckpt" round="$round" devices="$num_gpu" \
    data_model_name=llama-134M bash scripts/pretrain_lma0.5B.sh \
    || { echo "pretraining phase 1 (decay=false) failed; aborting pipeline" >&2; exit 1; }

# Phase 2: decay=true, with ckpt one stride past data_ckpt. It is only started
# when phase 1 succeeded.
ckpt=$((data_ckpt + ckpt_stride))
model_name=llama-0.5B method="$method" decay=true ckpt="$ckpt" \
    data_ckpt="$data_ckpt" round="$round" devices="$num_gpu" \
    data_model_name=llama-134M bash scripts/pretrain_lma0.5B.sh \
    || { echo "pretraining phase 2 (decay=true) failed; aborting pipeline" >&2; exit 1; }
# Stage 4: evaluate llama-0.5B via scripts/eval_llama.sh.
# CUDA_VISIBLE_DEVICES is exported, so it applies to the evaluation script and
# to anything run after it in this shell.
export CUDA_VISIBLE_DEVICES=1

# Checkpoint id is round * 80000 + 81600, zero-padded to six digits.
# NOTE(review): the 81600 offset presumably matches the last checkpoint written
# by the decay phase — confirm against scripts/pretrain_lma0.5B.sh.
ckpt=$(printf '%06d' "$((round * 80000 + 81600))")

# Use the configured $method, as the selection and first pretraining commands
# do, instead of a hard-coded name.
model_name=llama-0.5B method="$method" ckpt="$ckpt" bash scripts/eval_llama.sh