#!/usr/bin/env bash
# Launcher for the "inheritance" experiment sweep.
# Configures the shell and the environment inherited by every torchrun
# invocation started later in this script.

# Strict mode: abort on the first failing command (-e), on any use of an
# unset variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Weights & Biases: project to log into, with online syncing.
export WANDB_PROJECT='inheritance_new'
export WANDB_MODE=online
# Expose only GPUs 0 and 1 to child processes.
export CUDA_VISIBLE_DEVICES=0,1
# Turn on the TorchInductor FX graph cache.
export TORCHINDUCTOR_FX_GRAPH_CACHE=1
# Cap OpenMP at 4 threads per process.
export OMP_NUM_THREADS=4

# Two-stage sweep, repeated for every random seed:
#   Stage 1: train on each listed task with 2 torchrun processes.
#   Stage 2: resume from a fixed checkpoint path and train once per
#            --freeze_layers value with 1 torchrun process.
#
# Every expansion is quoted. The layer specs ("[]", "[0]", "[1]") are glob
# bracket expressions, so an unquoted $layers would be subject to pathname
# expansion against files in the current directory (and would be dropped or
# rejected under nullglob/failglob).
for seed in 43 44 45 46 47; do
    # --- Stage 1: base training -------------------------------------------
    # Disabled alternative task: addition/no_carry_only_carry_no_add
    for task in \
        copy/capitalize_and_reverse \
    ; do
        # The W&B run group is set for this one command only.
        WANDB_RUN_GROUP="${task}-nano-llama" torchrun --nproc_per_node=2 run.py \
            --args \
                experiments/inheritance/common/data.yaml \
                "experiments/inheritance/${task}.yaml" \
                experiments/inheritance/common/NanoLlama.yaml \
            --train_args experiments/inheritance/common/train_args_base.yaml \
            --seed="${seed}" \
            --resume_from_checkpoint=True \
            --max_steps=3000 \
            --lr_scheduler_kwargs='{"num_decay_steps": 1000}' \
            --save_only_model=True \
            --eval_steps=1000 \
            --save_steps=0.5 \
            --num_layers=2
    done

    # --- Stage 2: grafting with frozen layers -----------------------------
    # Disabled wider sweep:
    # for layers in "[0]" "[1]" "[2]" "[3]" "[4]" "[5]" "[0,1]" "[1,2]" "[2,3]" "[3,4]" "[4,5]" "[0,1,2]" "[1,2,3]" "[2,3,4]" "[3,4,5]"; do
    #
    # Each value is passed verbatim to --freeze_layers; presumably a list of
    # layer indices, with "[]" freezing nothing -- confirm against run.py.
    for layers in "[]" "[0]" "[1]"; do
        # Debug aid (disabled): echo "Layers: $layers"
        # Disabled alternative task: addition/reverse_add
        # Disabled alternative checkpoint:
        # --resume_from_checkpoint='out/-llama-384-6-6-1024-rope-reverse_add_no_carry-la=1_33-lb=1_33-reverse_add_only_carry-la=1_33-lb=1_33-SFT-seed-'$seed'/checkpoint-20000' \
        #
        # NOTE(review): the task here is copy/capitalize_reverse while stage 1
        # uses copy/capitalize_and_reverse -- confirm the difference is intended.
        # NOTE(review): checkpoint-3000 matches stage 1's --max_steps=3000, so
        # this path is presumably stage 1's output directory -- confirm.
        for task in \
            copy/capitalize_reverse \
        ; do
            WANDB_RUN_GROUP="${task}-nano-llama-grafting" torchrun --nproc_per_node=1 run.py \
                --args \
                    experiments/inheritance/common/data.yaml \
                    "experiments/inheritance/${task}.yaml" \
                    experiments/inheritance/common/NanoLlama.yaml \
                --train_args experiments/inheritance/common/train_args_base.yaml \
                --run_name_prefix="grafting-${layers}" \
                --resume_from_checkpoint="out/-llama-384-6-2-1024-rope-capitalize-l=6_33-reverse-l=6_33-SFT-seed-${seed}/checkpoint-3000" \
                --max_steps=4000 \
                --lr_scheduler_kwargs='{"num_decay_steps": 1000}' \
                --freeze_layers="${layers}" \
                --seed="${seed}" \
                --save_total_limit=1 \
                --num_layers=2
        done
    done
done
