#!/usr/bin/env bash
# Experiment launcher: configures the environment shared by every training run
# started further down in this file.
#
# -e: abort on the first failing command, so a broken run stops the sweep.
# -u: treat unset variables as errors, so a mistyped or forgotten variable
#     cannot silently expand to an empty string inside a config path or flag.
set -eu

# Weights & Biases: project name and logging mode read by the wandb client.
export WANDB_PROJECT='inheritance_new'
export WANDB_MODE=online
# Expose GPUs 0 and 1 to child processes; keep in sync with the
# --nproc_per_node value passed to torchrun below.
export CUDA_VISIBLE_DEVICES=0,1
# Turn on TorchInductor's FX graph cache (reuses compiled graphs across runs).
export TORCHINDUCTOR_FX_GRAPH_CACHE=1
# Cap OpenMP threads per process.
export OMP_NUM_THREADS=4

# Disabled addition tasks, kept for reference (not used by any active loop):
#     addition/copy_first_op \
#     addition/no_carry \
#     addition/only_carry \
#     addition/no_carry_only_carry \
#     addition/reverse_add_trans \
#     addition/reverse_sub \
#     addition/reverse_add \
# Seed sweep: each active task loop inside runs once per seed. The script is
# executed with `set -e`, so the first torchrun that exits non-zero stops the
# remaining runs. Commented-out loops are disabled experiment configurations
# kept for reference; only the copy/* loop is currently active.
for seed in 43 44 45 46 47; do
    # for task in \
    #     boolean/parity_majority \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --max_steps=5000 \
    #         --lr_scheduler_kwargs='{"num_decay_steps": 1000}' \
    #         --learning_rate=5e-4
    # done

    # for task in \
    #     addition/copy_first_op \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed
    # done

    # Disabled copy tasks (move a line into the list below to re-enable it):
        # copy/copy_MQAR \
        # copy/copy_reverse \
        # copy/copy_capitalize_reverse \
        # copy/copy \
        # copy/copy_capitalize_reverse_control \
    for task in \
        copy/copy_capitalize_reverse_capitalize \
        copy/copy_capitalize_reverse_reverse \
    ; do
        # WANDB_RUN_GROUP is set for this command only, so all seeds of one
        # task share a run group. Config files are passed to --args in the
        # order: shared data config, task config, shared model config.
        # NOTE(review): stable (2500) + decay (2000) steps leave 500 of the
        # 5000 max_steps unaccounted for -- presumably warmup; confirm against
        # train_args_base.yaml.
        WANDB_RUN_GROUP="${task}-nano-llama" torchrun --nproc_per_node=2 run.py \
            --args \
                experiments/inheritance/common/data.yaml \
                "experiments/inheritance/${task}.yaml" \
                experiments/inheritance/common/NanoLlama.yaml \
            --train_args experiments/inheritance/common/train_args_base.yaml \
            --seed="${seed}" \
            --max_steps=5000 \
            --lr_scheduler_kwargs='{"num_stable_steps": 2500, "num_decay_steps": 2000}'
    done

    # for task in \
    #     mult/reverse_mult_reverse_add \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=1 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=512 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=2 \
    #         --num_eval=512 \
    #         --use_iterable_dataset=True \
    #         --lr_scheduler_kwargs='{"num_decay_steps": 10000}'
    # done

    # for task in \
    #     maze/maze_dfs \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=256 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=2 \
    #         --num_eval=512
    # done

    # for task in \
    #     maze/maze_dfs \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=256 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=2 \
    #         --num_eval=512 \
    #         --train_data.A.kwargs.padding=True \
    #         --train_data.B.kwargs.padding=True \
    #         --eval_data.A.kwargs.padding=True \
    #         --eval_data.B.kwargs.padding=True
    # done

        # maze/maze_dfs_sp \
        # maze/maze_dfs \

    # Set resume based on seed value
    # NOTE: the two disabled maze loops below read $resume; re-enable this
    # assignment together with them.
    # if [ "$seed" -gt 44 ]; then
    #     resume="False"
    # else
    #     resume="True"
    # fi

    # for task in \
    #     maze/maze_sp_dfs \
    #     maze/maze_dfs_sp \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=256 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=1 \
    #         --train_data.A.kwargs.randomize=True \
    #         --train_data.B.kwargs.randomize=True \
    #         --eval_data.A.kwargs.randomize=True \
    #         --eval_data.B.kwargs.randomize=True \
    #         --resume_from_checkpoint=$resume
    # done

    # for task in \
    #     maze/maze_sp \
    #     maze/maze_dfs \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=256 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=1 \
    #         --num_eval=1024 \
    #         --train_data.A.kwargs.randomize=True \
    #         --eval_data.A.kwargs.randomize=True \
    #         --resume_from_checkpoint=$resume
    # done
done
