#!/usr/bin/env bash
#
# Sweep launcher: for every seed and every (l_aux, l_main) pair with
# l_aux != l_main, start one 2-process torchrun job of run.py per task.
#
# Must be run with bash (uses [[ ]] and arrays) from the directory that
# contains run.py and the experiments/ tree.
set -euo pipefail

export WANDB_PROJECT='inheritance_new'
export WANDB_MODE=online
export CUDA_VISIBLE_DEVICES=0,1
export TORCHINDUCTOR_FX_GRAPH_CACHE=1
export OMP_NUM_THREADS=4

# Tasks to sweep over; each maps to experiments/inheritance/<task>.yaml.
tasks=(
    addition/no_carry_only_carry
)

for seed in 43 44 45; do
    for l_aux in 8 16 32 64 128 256; do
        for l_main in 4 8 16 32 64 128; do
            # Only run configurations where the two lengths differ.
            if [[ "$l_aux" -eq "$l_main" ]]; then
                continue
            fi

            # Larger of the two lengths (they are never equal here).
            max_l=$((l_aux > l_main ? l_aux : l_main))
            printf 'l_main: %s, l_aux: %s, max_l: %s\n' "$l_main" "$l_aux" "$max_l"

            # At most 16 evaluation points; eval_step uses integer division.
            num_eval=$((max_l < 16 ? max_l : 16))
            eval_step=$((max_l / num_eval))

            # Bracketed list values are built as quoted strings: left unquoted,
            # "[...]" is a glob pattern and could be rewritten by pathname
            # expansion (or rejected under nullglob/failglob).
            train_main="[1,$((l_main + 1))]"
            train_aux="[1,$((l_aux + 1))]"
            # NOTE(review): looks like a [start, stop, step] triple - confirm
            # how run.py interprets it.
            eval_range="[$((eval_step + 1)),$((max_l + eval_step + 1)),${eval_step}]"

            for task in "${tasks[@]}"; do
                run_args=(
                    --args
                    experiments/inheritance/common/data.yaml
                    "experiments/inheritance/${task}.yaml"
                    experiments/inheritance/common/NanoLlama.yaml
                    --train_args experiments/inheritance/common/train_args_base.yaml
                    "--seed=${seed}"
                    --max_steps=20000
                    --lr_scheduler_kwargs='{"num_decay_steps": 4000}'
                    --eval_strategy='epoch'
                    --learning_rate=5e-4
                    --auto_find_batch_size=True
                    # Dataset A uses l_main; datasets B and C use l_aux.
                    "--train_data.A.kwargs.la=${train_main}"
                    "--train_data.A.kwargs.lb=${train_main}"
                    "--train_data.B.kwargs.la=${train_aux}"
                    "--train_data.B.kwargs.lb=${train_aux}"
                    "--train_data.C.kwargs.la=${train_aux}"
                    "--train_data.C.kwargs.lb=${train_aux}"
                    # All three eval datasets share the same la range.
                    "--eval_data.A.kwargs.la=${eval_range}"
                    "--eval_data.B.kwargs.la=${eval_range}"
                    "--eval_data.C.kwargs.la=${eval_range}"
                    --resume_from_checkpoint=True
                    --save_total_limit=1
                    --torch_compile=False
                )

                # WANDB_RUN_GROUP is set only in the environment of this launch.
                WANDB_RUN_GROUP="${task}-nano-llama-sweep" \
                    torchrun --nproc_per_node=2 run.py "${run_args[@]}"
            done
        done
    done
done
