#!/usr/bin/env bash
# Launches fine-tuning runs through `accelerate launch run.py`.
# The script uses bash-only syntax ([[ ... ]]), hence the explicit bash shebang.

# Strict mode: abort on the first failing command (-e), on use of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Weights & Biases: project that the runs are logged to, with live syncing.
export WANDB_PROJECT='inheritance_new'
export WANDB_MODE=online
# Only expose GPUs 0 and 1 to the launched processes.
export CUDA_VISIBLE_DEVICES=0,1
# Turn on TorchInductor's FX graph cache.
export TORCHINDUCTOR_FX_GRAPH_CACHE=1

# for seed in 42; do
#     model_id=HuggingFaceTB/SmolLM2-360M-intermediate-checkpoints
#     lr=5e-5
#     bs=64
#     gas=2


#         # step-160000 \
#         # step-320000 \
#         # step-640000 \
#         # step-1280000 \
#         # step-2560000 \
#     for revision in \
#         null \
#     ; do

#     if [[ $revision == "null" ]]; then
#         model_id=HuggingFaceTB/SmolLM2-360M
#         from_pretrained=False
#         run_name="step-0"
#         max_steps=5000
#         gas=4
#     else
#         from_pretrained=True
#         max_steps=1500
#     fi

#     echo "Using model $model_id @ $revision with learning rate $lr, batch size $bs, and gradient accumulation steps $gas"
#     for task in \
#         addition/reverse_add \
#     ; do
#         WANDB_RUN_GROUP=$task-scaling accelerate launch run.py \
#             --args experiments/inheritance/common/data.yaml experiments/inheritance/$task.yaml \
#             --train_args experiments/inheritance/common/train_args_base.yaml \
#             --model_id=$model_id \
#             --from_pretrained=$from_pretrained \
#             --revision=$revision \
#             --run_name=$revision \
#             --seed=$seed \
#             --max_steps=$max_steps \
#             --learning_rate=$lr \
#             --lr_scheduler_kwargs='{"num_decay_steps": 1000}' \
#             --per_device_train_batch_size=$bs \
#             --per_device_eval_batch_size=64 \
#             --num_eval=512 \
#             --eval_on_start=False \
#             --gradient_accumulation_steps=$gas \
#             --eval_data.A.kwargs.la=[5,33,2] \
#             --resume_from_checkpoint=False \
#             --save_total_limit=1 \
#             --torch_compile=False \
#             --eval_steps=0.25 \
#             --save_steps=0.5 \
#             --fsdp="full_shard auto_wrap"
#     done
#     done
# done


# Sweep: seed x model revision x task. Each combination starts one
# `accelerate launch run.py` job, grouped in W&B as "<task>-scaling".
for seed in 42; do
    # The two repositories live in their own variables so that choosing one
    # for a revision never leaks into the next loop iteration.
    checkpoint_repo=HuggingFaceTB/SmolLM2-360M-intermediate-checkpoints
    baseline_repo=HuggingFaceTB/SmolLM2-360M
    lr=5e-5
    bs=32

    # Checkpoint revisions that can be added to the list below:
    #     step-160000
    #     step-320000
    #     step-640000
    #     step-1280000
    #     step-2560000

    for revision in \
        null \
    ; do

    # All per-revision settings are assigned in both branches, so the order of
    # the revision list does not matter.
    if [[ "$revision" == "null" ]]; then
        # "null" = no checkpoint revision: use the final-model repo with
        # from_pretrained=False (presumably a from-scratch baseline; confirm
        # against run.py). It gets a longer step budget and is reported as
        # "step-0".
        model_id=$baseline_repo
        from_pretrained=False
        run_name="step-0"
        max_steps=5000
        gas=4
    else
        model_id=$checkpoint_repo
        from_pretrained=True
        run_name=$revision
        max_steps=1500
        gas=4
    fi

    printf '%s\n' "Using model $model_id @ $revision with learning rate $lr, batch size $bs, and gradient accumulation steps $gas"
    for task in \
        maze/maze \
    ; do
        # Arguments are collected in an array so every value stays one word.
        # NOTE(review): the run name now comes from $run_name (previously the
        # assigned "step-0" was ignored and the run was named "null"). If
        # run.py derives its output/checkpoint directory from the run name,
        # an existing "null" run will not be picked up for resuming.
        launch_args=(
            --args experiments/inheritance/common/data.yaml "experiments/inheritance/$task.yaml"
            --train_args experiments/inheritance/common/train_args_base.yaml
            --model_id="$model_id"
            --from_pretrained="$from_pretrained"
            --revision="$revision"
            --run_name="$run_name"
            --seed="$seed"
            --max_steps="$max_steps"
            --learning_rate="$lr"
            --lr_scheduler_kwargs='{"num_decay_steps": 1000}'
            --per_device_train_batch_size="$bs"
            --per_device_eval_batch_size=128
            --num_eval=512
            --eval_on_start=False
            --gradient_accumulation_steps="$gas"
            --save_total_limit=1
            --torch_compile=False
            --eval_steps=0.25
            --save_steps=0.25
            --fsdp="full_shard auto_wrap"
            # Passed exactly once. The command used to carry both =False and a
            # later =True; the later value is kept, assuming the parser lets
            # the last occurrence win (argparse behaviour) - confirm in run.py.
            --resume_from_checkpoint=True
            --train_data.A.kwargs.dfs=False
            --eval_data.A.kwargs.dfs=False
        )
        WANDB_RUN_GROUP="$task-scaling" accelerate launch run.py "${launch_args[@]}"
    done
    done
done
