#!/usr/bin/env bash
# Sweep launcher for the "inheritance" experiments: runs run.py through
# torchrun for several seeds and tasks.
#
# Bash is required (the script uses bash-only test syntax), hence the explicit
# interpreter line instead of relying on whatever shell starts the file.
#
# Strict mode: stop at the first failing command so a broken run does not let
# the rest of the sweep continue, treat unset variables as errors, and make a
# pipeline fail if any stage fails.
set -euo pipefail

# Weights & Biases project and mode used by every run started below.
export WANDB_PROJECT='inheritance_new'
export WANDB_MODE=online
# Two devices are exposed, matching the --nproc_per_node=2 used by the
# torchrun invocations in this file.
export CUDA_VISIBLE_DEVICES=0,1
# NOTE(review): presumably turns on TorchInductor's FX graph cache so repeated
# runs can reuse compiled graphs - confirm against the PyTorch version in use.
export TORCHINDUCTOR_FX_GRAPH_CACHE=1

#######################################
# Launch one run of run.py under torchrun with 2 processes.
#
# All settings shared by the sweeps live here; callers pass only what differs.
#
# Arguments:
#   $1   - task path relative to experiments/inheritance, without ".yaml"
#          (e.g. addition/reverse_sub); also used to build WANDB_RUN_GROUP
#   $2   - seed, passed as both --seed and --data_seed
#   $3   - value for --max_steps
#   $4   - num_decay_steps placed in --lr_scheduler_kwargs
#   $5.. - extra arguments appended verbatim after the shared ones
# Returns:
#   torchrun's exit status.
#######################################
run_experiment() {
    local task=$1 seed=$2 max_steps=$3 decay_steps=$4
    shift 4

    WANDB_RUN_GROUP="${task}-smollm-360m" torchrun --nproc_per_node=2 run.py \
        --args experiments/inheritance/common/data.yaml "experiments/inheritance/${task}.yaml" experiments/inheritance/common/SmolLM.yaml \
        --train_args experiments/inheritance/common/train_args_base.yaml \
        --seed="${seed}" \
        --data_seed="${seed}" \
        --max_steps="${max_steps}" \
        --learning_rate=5e-5 \
        --lr_scheduler_kwargs="{\"num_decay_steps\": ${decay_steps}}" \
        --per_device_train_batch_size=32 \
        --per_device_eval_batch_size=128 \
        --gradient_accumulation_steps=4 \
        --num_eval=512 \
        --use_iterable_dataset=True \
        --resume_from_checkpoint=False \
        --save_total_limit=1 \
        --eval_steps=0.5 \
        --save_steps=0.5 \
        --fsdp="full_shard auto_wrap" \
        "$@"
}

for seed in 44 43 45; do
    for task in \
        addition/reverse_sub \
        addition/no_carry_only_carry \
        addition/copy_first_op \
    ; do
        # The bracketed lists are quoted on purpose: unquoted, the shell reads
        # [5,49,2] as a glob character class, so a matching file name in the
        # working directory (or nullglob/failglob) would alter the argument.
        eval_overrides=(
            '--eval_data.A.kwargs.la=[5,49,2]'
            '--eval_data.B.kwargs.la=[5,49,2]'
        )
        # Only addition/no_carry_only_carry also receives the eval_data.C override.
        if [[ "$task" == "addition/no_carry_only_carry" ]]; then
            eval_overrides+=('--eval_data.C.kwargs.la=[5,49,2]')
        fi

        run_experiment "$task" "$seed" 2500 1000 "${eval_overrides[@]}"
    done

    for task in \
        copy/copy_capitalize_reverse \
        copy/copy_reverse \
    ; do
        # Quoted for the same reason as above: keep [7,65,4] out of globbing.
        eval_overrides=(
            '--eval_data.A.kwargs.l=[7,65,4]'
            '--eval_data.B.kwargs.l=[7,65,4]'
        )
        # Every copy task except copy/copy_reverse also receives the
        # eval_data.C override.
        if [[ "$task" != "copy/copy_reverse" ]]; then
            eval_overrides+=('--eval_data.C.kwargs.l=[7,65,4]')
        fi

        run_experiment "$task" "$seed" 1000 500 "${eval_overrides[@]}"
    done

    # Disabled sweeps below are kept commented out, unchanged.

    # for mult_len in 5 7 9; do
    # for task in \
    #     mult/reverse_mult_reverse_add \
    # ; do
    #     WANDB_RUN_GROUP=$task-qwen-15 python run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/Qwen_0.5B.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --max_steps=5000 \
    #         --learning_rate=5e-5 \
    #         --lr_scheduler_kwargs='{"num_decay_steps": 1000}' \
    #         --per_device_train_batch_size=64 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=1 \
    #         --num_eval=256 \
    #         --use_iterable_dataset=True \
    #         --train_data.B.kwargs.la=[1,$mult_len] \
    #         --train_data.B.kwargs.lb=[1,$mult_len]
    # done
    # done

    # for task in \
    #     maze/maze_dfs \
    #     maze/maze \
    # ; do
    #     WANDB_RUN_GROUP=$task-qwen-15 python run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/Qwen_0.5B.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --max_steps=5000 \
    #         --learning_rate=5e-5 \
    #         --lr_scheduler_kwargs='{"num_decay_steps": 1000}' \
    #         --per_device_train_batch_size=64 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=1 \
    #         --num_eval=256 \
    #         --use_iterable_dataset=True
    # done
done
