#!/usr/bin/env bash
# Launcher for the "inheritance" training sweep: sets the shared environment,
# then loops over seeds and tasks, starting one training run per combination.
#
# Shell options:
#   -e  abort the sweep as soon as any launch exits non-zero
#   -u  treat expansion of an unset variable as an error, so a typo in a
#       variable name cannot silently become an empty command-line argument
set -eu

# Weights & Biases: project that receives the runs, and online logging mode.
export WANDB_PROJECT='inheritance_new'
export WANDB_MODE=online
# Expose GPUs 0 and 1 to the launched processes.
export CUDA_VISIBLE_DEVICES=0,1
# Turn on TorchInductor's FX graph cache (see the PyTorch compile-caching docs).
export TORCHINDUCTOR_FX_GRAPH_CACHE=1
# export OMP_NUM_THREADS=4

for seed in 43 44 45 46 47; do
    # for task in \
    #     addition/no_carry \
    #     addition/only_carry \
    #     addition/no_carry_only_carry \
    #     addition/reverse_add_trans \
    #     addition/reverse_sub \
    #     addition/reverse_add \
    #     addition/copy_first_op \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama python run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --architecture='llama-nope' \
    #         --learning_rate=5e-4
    # done

    # for task in \
    #     copy/copy_MQAR \
    #     copy/copy_reverse \
    #     copy/copy_capitalize_reverse \
    #     copy/copy \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama python run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --max_steps=5000 \
    #         --lr_scheduler_kwargs='{"num_stable_steps": 2500, "num_decay_steps": 2000}' \
    #         --architecture='llama-nope' \
    #         --learning_rate=5e-4
    # done

    # for task in \
    #     mult/reverse_mult_reverse_add \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama python run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=512 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=2 \
    #         --num_eval=512 \
    #         --use_iterable_dataset=True \
    #         --lr_scheduler_kwargs='{"num_decay_steps": 10000}' \
    #         --architecture='llama-nope' \
    #         --learning_rate=5e-4
    # done

    # for task in \
    #     maze/maze_dfs \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=256 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=2 \
    #         --num_eval=512
    # done

    # for task in \
    #     maze/maze_dfs \
    # ; do
    #     WANDB_RUN_GROUP=$task-nano-llama torchrun --nproc_per_node=2 run.py \
    #         --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml experiments/inheritance/common/NanoLlama.yaml \
    #         --train_args experiments/inheritance/common/train_args_base.yaml \
    #         --seed=$seed \
    #         --per_device_train_batch_size=256 \
    #         --per_device_eval_batch_size=256 \
    #         --gradient_accumulation_steps=2 \
    #         --num_eval=512 \
    #         --train_data.A.kwargs.padding=True \
    #         --train_data.B.kwargs.padding=True \
    #         --eval_data.A.kwargs.padding=True \
    #         --eval_data.B.kwargs.padding=True
    # done

    # Maze sweep over maze/maze_dfs and maze/maze_sp: one 2-process torchrun
    # launch of run.py per task, using the current value of $seed.
    #   - The W&B run group is "<task>-nano-llama".
    #   - --args receives the shared data config, the per-task config and the
    #     NanoLlama config; --train_args receives the base training config.
    #   - randomize=True is set in the kwargs of datasets A and B, for both
    #     train_data and eval_data.
    # NOTE(review): --nproc_per_node=2 presumably matches the two GPUs listed in
    # CUDA_VISIBLE_DEVICES at the top of this script -- keep the two in sync.
    # Expansions are quoted so a task name or seed can never be word-split or
    # glob-expanded into extra arguments.
    for task in \
        maze/maze_dfs \
        maze/maze_sp \
    ; do
        WANDB_RUN_GROUP="${task}-nano-llama" torchrun --nproc_per_node=2 run.py \
            --args \
                experiments/inheritance/common/data.yaml \
                "experiments/inheritance/${task}.yaml" \
                experiments/inheritance/common/NanoLlama.yaml \
            --train_args experiments/inheritance/common/train_args_base.yaml \
            --seed="${seed}" \
            --per_device_train_batch_size=256 \
            --per_device_eval_batch_size=256 \
            --gradient_accumulation_steps=1 \
            --train_data.A.kwargs.randomize=True \
            --train_data.B.kwargs.randomize=True \
            --eval_data.A.kwargs.randomize=True \
            --eval_data.B.kwargs.randomize=True \
            --architecture='llama-nope'
    done

    # Maze sweep for maze/maze: one 2-process torchrun launch of run.py per
    # task, using the current value of $seed.
    #   - The W&B run group is "<task>-nano-llama".
    #   - --args receives the shared data config, the per-task config and the
    #     NanoLlama config; --train_args receives the base training config.
    #   - Only dataset A is configured here: randomize=True is set in its
    #     kwargs for both train_data and eval_data, and --num_eval is 1024.
    # NOTE(review): --nproc_per_node=2 presumably matches the two GPUs listed in
    # CUDA_VISIBLE_DEVICES at the top of this script -- keep the two in sync.
    # Expansions are quoted so a task name or seed can never be word-split or
    # glob-expanded into extra arguments.
    for task in \
        maze/maze \
    ; do
        WANDB_RUN_GROUP="${task}-nano-llama" torchrun --nproc_per_node=2 run.py \
            --args \
                experiments/inheritance/common/data.yaml \
                "experiments/inheritance/${task}.yaml" \
                experiments/inheritance/common/NanoLlama.yaml \
            --train_args experiments/inheritance/common/train_args_base.yaml \
            --seed="${seed}" \
            --per_device_train_batch_size=256 \
            --per_device_eval_batch_size=256 \
            --gradient_accumulation_steps=1 \
            --num_eval=1024 \
            --train_data.A.kwargs.randomize=True \
            --eval_data.A.kwargs.randomize=True \
            --architecture='llama-nope'
    done
done
