#!/bin/bash

# ATTN_IMPLEMENTATION=flash_attention_2 \
# USE_ATTN_POSTFIX=0 \
# python src/hip_research/main/long_eval_experimental.py \
#     --batch-size 1 \
#     --long-ce-k 1024 \
#     --seq-len 98304 \
#     --long-ce-gamma 5 \
#     --dataset pg19-longqa \
#     --model meta-llama/Llama-3.2-1B-Instruct \

# Model evaluated by every run in this script. Can be overridden from the
# environment, e.g. `TARGET_MODEL=meta-llama/Llama-3.2-1B-Instruct ./run.sh`;
# defaults to the previously hard-coded model.
target_model=${TARGET_MODEL:-meta-llama/Llama-3.1-8B-Instruct}

# Attention postfix configurations to sweep. Each active (uncommented) entry
# is passed verbatim as USE_ATTN_POSTFIX to one recompute_dense run below.
# Commented-out entries are earlier/alternative configurations kept for
# reference; uncomment to include them in the sweep.
postfix_options=(
    # recompute_dense-window_0-diff_1-w_64-decode_dense
    # recompute_dense-window_0-diff_0-w_64-decode_dense
    # recompute_dense-window_0-diff_0-w_64-decode_dense-JUST_RETURN
    
    # recompute_dense-window_0-diff_1-w_8-decode_dense
    # recompute_dense-window_0-diff_1-w_16-decode_dense
    # recompute_dense-window_0-diff_1-w_32-decode_dense
    # recompute_dense-window_0-diff_1-w_128-decode_dense
    # recompute_dense-window_0-diff_1-w_256-decode_dense

    recompute_dense-window_4096-diff_1-w_8-decode_dense
    recompute_dense-window_4096-diff_1-w_16-decode_dense
    recompute_dense-window_4096-diff_1-w_32-decode_dense
    recompute_dense-window_4096-diff_1-w_128-decode_dense
    recompute_dense-window_4096-diff_1-w_256-decode_dense

    # recompute_dense-window_1024-diff_1-w_64-decode_dense
    # recompute_dense-window_1024-diff_0-w_64-decode_dense
    # recompute_dense-window_1024-diff_0-w_64-decode_dense-JUST_RETURN
    # recompute_dense-window_2048-diff_1-w_64-decode_dense
    # recompute_dense-window_2048-diff_0-w_64-decode_dense
    # recompute_dense-window_2048-diff_0-w_64-decode_dense-JUST_RETURN
    # recompute_dense-window_4096-diff_1-w_64-decode_dense
    # recompute_dense-window_4096-diff_0-w_64-decode_dense
    # recompute_dense-window_4096-diff_0-w_64-decode_dense-JUST_RETURN
    
    # recompute_dense-window_4096-diff_1-w_128-decode_dense
    # recompute_dense-window_4096-diff_0-w_128-decode_dense
    # recompute_dense-window_4096-diff_0-w_128-decode_dense-JUST_RETURN
    # recompute_dense-window_4096-diff_1-w_256-decode_dense
    # recompute_dense-window_4096-diff_0-w_256-decode_dense
    # recompute_dense-window_4096-diff_0-w_256-decode_dense-JUST_RETURN
)

# Evaluation entry point and the arguments shared by every run below; only
# --recompute-n (and the attention env vars) differ between runs. Keeping
# them in one array ensures the baseline and the sweep stay comparable.
eval_script=src/hip_research/main/long_eval_experimental.py
common_args=(
    --batch-size 1
    --long-ce-k 1024
    --seq-len 98304
    --long-ce-gamma 5
    --dataset pg19-longqa
    --model "$target_model"
)

# Note: this script does not enable `set -e`, so a failing run is reported by
# python but does not stop the remaining runs of the sweep.

# Baseline: FlashAttention-2 with the same evaluation settings.
echo "starting flash attention 2"
PYTHONPATH=./src/hip_research \
ATTN_IMPLEMENTATION=flash_attention_2 \
python -u "$eval_script" "${common_args[@]}" --recompute-n 1024

# Sweep: one recompute_dense run per (postfix, recompute-n) pair.
# for rn in 128 256 512 1024; do
# for rn in 2048 4096 8192 16384 32768 65536; do
for postfix in "${postfix_options[@]}"; do
    for rn in 1024; do
        echo "starting recompute n: ${rn}, postfix: ${postfix}"
        # CUDA_VISIBLE_DEVICES=5 \
        # PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
        # WANDB_MODE=disabled \
        # PRESET=default \
        PYTHONPATH=./src/hip_research \
        ATTN_IMPLEMENTATION=recompute_dense \
        USE_ATTN_POSTFIX="$postfix" \
        python -u "$eval_script" "${common_args[@]}" --recompute-n "$rn"
    done
done

# # for rn in 128 256 512 1024 2048 4096 8192 16384 32768 65536; do
# for rn in 0 131072; do
#     echo "starting recompute n: ${rn}"
#     WANDB_MODE=disabled \
#     PYTHONPATH=./src/hip_research \
#     ORACLE_RECOMPUTE=0 \
#     RANDOM_RECOMPUTE=0 \
#     PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#     PRESET=default \
#     CUDA_VISIBLE_DEVICES=2 \
#     ATTN_IMPLEMENTATION=hip_attention \
#     USE_ATTN_POSTFIX=1 \
#     python src/hip_research/main/long_eval_experimental.py \
#         --batch-size 1 \
#         --long-ce-k 1024 \
#         --recompute-n $rn \
#         --seq-len 98304 \
#         --long-ce-gamma 5 \
#         --dataset pg19-longqa \
#         --model meta-llama/Llama-3.2-1B-Instruct
# done
# 
# # for rn in 128 256 512 1024 2048 4096 8192 16384 32768 65536; do
# for rn in 0 131072; do
#     echo "starting recompute n: ${rn}"
#     WANDB_MODE=disabled \
#     PYTHONPATH=./src/hip_research \
#     ORACLE_RECOMPUTE=0 \
#     RANDOM_RECOMPUTE=1 \
#     PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#     PRESET=default \
#     CUDA_VISIBLE_DEVICES=2 \
#     ATTN_IMPLEMENTATION=hip_attention \
#     USE_ATTN_POSTFIX=1 \
#     python src/hip_research/main/long_eval_experimental.py \
#         --batch-size 1 \
#         --long-ce-k 1024 \
#         --recompute-n $rn \
#         --seq-len 98304 \
#         --long-ce-gamma 5 \
#         --dataset pg19-longqa \
#         --model meta-llama/Llama-3.2-1B-Instruct
# done
# 
# # for rn in 128 256 512 1024 2048 4096 8192 16384 32768 65536; do
# for rn in 0 131072; do
#     echo "starting recompute n: ${rn}"
#     WANDB_MODE=disabled \
#     PYTHONPATH=./src/hip_research \
#     ORACLE_RECOMPUTE=1 \
#     RANDOM_RECOMPUTE=0 \
#     PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#     PRESET=default \
#     CUDA_VISIBLE_DEVICES=2 \
#     ATTN_IMPLEMENTATION=hip_attention \
#     USE_ATTN_POSTFIX=1 \
#     python src/hip_research/main/long_eval_experimental.py \
#         --batch-size 1 \
#         --long-ce-k 1024 \
#         --recompute-n $rn \
#         --seq-len 98304 \
#         --long-ce-gamma 5 \
#         --dataset pg19-longqa \
#         --model meta-llama/Llama-3.2-1B-Instruct
# done
