#!/bin/bash

#######################################
# Export the KEY=VALUE pairs found in a dotenv-style file.
#
# Replaces `export $(cat .env | xargs)`, which word-splits and globs values,
# fails on comment lines, and - when the file is missing - runs a bare
# `export` that prints the whole environment to stdout.
#
# Supported syntax, one assignment per line:
#   KEY=value            unquoted; a trailing " # comment" is dropped
#   KEY="two words"      surrounding double quotes are removed
#   KEY='two words'      surrounding single quotes are removed
#   export KEY=value     a leading "export " is accepted
# Blank lines and lines starting with "#" are ignored. Values are taken
# literally: no escape sequences or $variable references are expanded.
#
# Arguments: $1 - path of the file to load (default: .env)
# Outputs:   warnings about skipped lines on stderr
# Returns:   0, also when the file does not exist
#######################################
load_env_file() {
    local env_file=${1:-.env}
    local -r key_pattern='^[A-Za-z_][A-Za-z0-9_]*$'
    local line key value

    # A missing file is not an error: everything has defaults or may already
    # be present in the environment.
    if [[ ! -f "$env_file" ]]; then
        return 0
    fi

    # `|| [[ -n "$line" ]]` keeps a final line that lacks a trailing newline.
    while IFS= read -r line || [[ -n "$line" ]]; do
        line=${line%$'\r'}                              # tolerate CRLF files
        line="${line#"${line%%[![:space:]]*}"}"         # trim leading blanks

        if [[ -z "$line" || "$line" == '#'* ]]; then
            continue
        fi

        line=${line#export }
        if [[ "$line" != *=* ]]; then
            continue
        fi

        key=${line%%=*}
        key="${key%"${key##*[![:space:]]}"}"            # trim trailing blanks
        value=${line#*=}
        value="${value#"${value%%[![:space:]]*}"}"
        value="${value%"${value##*[![:space:]]}"}"

        if [[ ! "$key" =~ $key_pattern ]]; then
            printf 'warning: %s: skipping invalid variable name "%s"\n' \
                "$env_file" "$key" >&2
            continue
        fi

        if [[ ${#value} -ge 2 && "$value" == \"*\" ]]; then
            value=${value:1:${#value}-2}
        elif [[ ${#value} -ge 2 && "$value" == \'*\' ]]; then
            value=${value:1:${#value}-2}
        else
            # Unquoted value: drop an inline comment introduced by " #".
            value=${value%%[[:space:]]\#*}
            value="${value%"${value##*[![:space:]]}"}"
        fi

        export "$key=$value"
    done < "$env_file"

    return 0
}

load_env_file .env

# ---------------------------------------------------------------------------
# Run configuration.
# Each setting is read from the upper-case environment variable named in its
# expansion; the value after `:-` is used when that variable is unset or
# empty. The values are forwarded to train.py by the torchrun command below.
# ---------------------------------------------------------------------------

# Forget-gate / memory settings.
forget_gate=${FORGET_GATE:-"fg4"}
memory_size=${MEMORY_SIZE:-512}
forget_weight=${FORGET_WEIGHT:-2.0}
fgbias_init=${FG_BIAS_INIT:-8.0}
fg_dropout=${FG_DROPOUT:-0.0}
skip_layers=${SKIP_LAYERS:-0}

# Optimisation schedule and objective.
warmup_steps=${WARMUP_STEPS:-100}
base_loss=${BASE_LOSS:-"ntp"}
training_max_length=${TRAINING_MAX_LENGTH:-32768}  # Maximum sequence length for training

# Trainable parameter selection (passed as --trainable_params) and the short
# tag for it that is embedded in the run name. Alternatives kept for reference.
trainable_params=${TRAINABLE_PARAMS:-"self_attn.f_proj|self_attn.forget_gate"}
trainable_params_short=${TRAINABLE_PARAMS_SHORT:-"ff"}
# trainable_params=${TRAINABLE_PARAMS:-"self_attn.f_proj|self_attn.forget_gate|self_attn.v_proj|self_attn.q_proj|self_attn.k_proj"}
# trainable_params_short=${TRAINABLE_PARAMS_SHORT:-"ffqkv"}
# trainable_params=${TRAINABLE_PARAMS:-"self_attn.q_proj"}
# trainable_params_short=${TRAINABLE_PARAMS_SHORT:-"q"}

lr=${LR:-2e-4}
resume_from_checkpoint=${RESUME:-None}
weight_decay=${WEIGHT_DECAY:-0.01}
gpus=${GPUS:-1}  # number of processes torchrun starts on this node

# NOTE: "~" is not expanded inside double quotes, so the previous default
# "~/scratch/..." reached train.py as a literal "~" path. Use $HOME instead.
output_dir=${OUTPUT_DIR:-"${HOME}/scratch/jpmc/outputs/models"}
steps=${STEPS:-22000}  # -1 means no limit

# bs is the per-device batch size; ebs is the target effective batch size from
# which the number of gradient-accumulation steps is derived.
bs=${BS:-1}
ebs=${EBS:-4}
logit_block_size=${LOGIT_BLOCK_SIZE:-8192}

dataset_name=${DATASET_NAME:-"openr1_math"}
# dataset_name=${DATASET_NAME:-"synth_long,booksum,buddhi"}
# dataset_name=${DATASET_NAME:-"booksum"}
dataset_path=${DATASET_PATH:-"${HOME}/scratch/jpmc/datasets"}

# Base model: a model identifier or a local path. Alternatives kept for reference.
# base_model=${BASE_MODEL:-"/gpfs/radev/home/nhb25/scratch/jpmc/outputs/models/Qwen3-1.7B/test_model_v0"}
# resume_from_checkpoint="/gpfs/radev/home/nhb25/scratch/jpmc/outputs/models/Qwen3-1.7B/rnsa_openr1_Qwen3-1.7B_fg2_rnsa_flex_memory512_fw5.0_ff_ebs8_wd0.01_lr5e-4/checkpoint-6534"
base_model=${BASE_MODEL:-"Qwen/Qwen3-8B"}
# base_model=${BASE_MODEL:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
# base_model=${BASE_MODEL:-'deepseek-ai/DeepSeek-R1-0528-Qwen3-8B'}
# base_model=${BASE_MODEL:-'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'}

# Last path component of the model id/path; used in the output directory and
# in the run name.
base_name=$(basename "$base_model")
prefix=${PREFIX:-"rnsa"}
ds_config=${DS_CONFIG:-"ds_config/stage2.json"}
attn_impl=${ATTN_IMPL:-"rnsa_flex"}

#######################################
# Compute the number of gradient-accumulation steps needed to reach an
# effective batch size with the given process count and per-device batch.
#
# The plain expression $((ebs/(gpus * bs))) aborts with "division by 0" when
# gpus or bs is 0/empty and silently yields 0 when ebs < gpus * bs. This
# helper never prints less than 1 and reports every adjustment on stderr.
#
# Arguments: $1 - effective batch size
#            $2 - number of processes (GPUs)
#            $3 - per-device batch size
# Outputs:   the step count on stdout; warnings on stderr
# Returns:   0
#######################################
compute_grad_accum_steps() {
    local effective_bs=$1 num_gpus=$2 per_device_bs=$3
    local -r positive_int='^[1-9][0-9]*$'

    if [[ ! "$effective_bs" =~ $positive_int \
        || ! "$num_gpus" =~ $positive_int \
        || ! "$per_device_bs" =~ $positive_int ]]; then
        printf 'warning: EBS (%s), GPUS (%s) and BS (%s) must be positive integers; using gradient_accumulation_steps=1\n' \
            "$effective_bs" "$num_gpus" "$per_device_bs" >&2
        printf '1\n'
        return 0
    fi

    local micro_batch=$(( num_gpus * per_device_bs ))
    local accum_steps=$(( effective_bs / micro_batch ))

    if (( accum_steps < 1 )); then
        printf 'warning: EBS (%s) is smaller than GPUS*BS (%s); using gradient_accumulation_steps=1\n' \
            "$effective_bs" "$micro_batch" >&2
        accum_steps=1
    elif (( effective_bs % micro_batch != 0 )); then
        printf 'warning: EBS (%s) is not a multiple of GPUS*BS (%s); effective batch size will be %s\n' \
            "$effective_bs" "$micro_batch" "$(( accum_steps * micro_batch ))" >&2
    fi

    printf '%d\n' "$accum_steps"
}

gradient_accumulation_steps=$(compute_grad_accum_steps "$ebs" "$gpus" "$bs")
gc=${GC:-"True"}  # passed as --gradient_checkpointing
debug=${DEBUG:-0}
max_samples=${MAX_SAMPLES:--1}  # -1 means no limit

# Always pick a pseudo-random rendezvous port in 10000-10099 so that several
# runs started on the same node are unlikely to collide. No check is made
# that the chosen port is actually free.
master_port=$((10000 + RANDOM % 100))

# .env is already loaded at the top of this script; the second, identical
# load that used to sit here was redundant and has been removed.

# Build the run identifier: the fields below joined with "_", in this order.
# It names both the output sub-directory and the run passed via --run_name.
run_name_fields=(
    "${prefix}"
    "${base_name}"
    "${dataset_name}"
    "${base_loss}"
    "${training_max_length}"
    "${forget_gate}"
    "${attn_impl}"
    "memory${memory_size}"
    "skip${skip_layers}"
    "fw${forget_weight}"
    "bias${fgbias_init}"
    "${fg_dropout}"
    "${trainable_params_short}"
    "ebs${ebs}"
    "wd${weight_decay}"
    "lr${lr}"
)
# printf repeats the format once per field, leaving one surplus "_" at the end.
printf -v run_name '%s_' "${run_name_fields[@]}"
run_name=${run_name%_}
unset run_name_fields

# Select debug or normal mode.
#
# Debug mode (DEBUG=1) disables Weights & Biases reporting, shortens the run
# (10 steps, 100 samples), raises training_max_length to 131072 and writes to
# a separate debug output directory.
#
# NOTE(review): run_name was assembled before this point, so in debug mode it
# still carries the pre-override training_max_length - confirm that is wanted.
if [[ "$debug" -eq 1 ]]; then
    WANDB_MODE="disabled"
    report_to="none"
    run_name="${run_name}_debug"
    steps=10
    training_max_length=131072
    max_samples=100
    # "~" is not expanded inside double quotes; use $HOME so that train.py
    # does not receive a literal "~" path.
    output_dir="${HOME}/scratch/jpmc/outputs/debug_models"
    echo "Running in debug mode, steps set to 10 and training_max_length set to ${training_max_length}."
else
    WANDB_MODE="online"
    report_to="wandb"
    echo "Running in normal mode."
fi

# A plain assignment is only visible inside this shell. Export the variable so
# that the processes started by torchrun inherit it.
export WANDB_MODE

printf 'Run name: %s\n' "${run_name}"

# Arguments for train.py. Every expansion is quoted so that values containing
# spaces or glob characters (paths, "a|b" patterns, JSON) reach train.py as a
# single, unmodified argument.
train_args=(
    --base_model "$base_model"
    --bf16 True
    --output_dir "${output_dir}/${base_name}/${run_name}"
    --dataset_name "$dataset_name"
    --dataset_path "$dataset_path"
    --training_max_length "$training_max_length"
    --num_train_epochs 1
    --per_device_train_batch_size "$bs"
    --gradient_accumulation_steps "$gradient_accumulation_steps"
    --resume_from_checkpoint "$resume_from_checkpoint"
    --overwrite_output_dir True
    --save_steps 1000
    --save_total_limit 1
    --gradient_checkpointing "$gc"
    --eval_strategy "no"
    --save_strategy "steps"
    --learning_rate "$lr"
    --weight_decay "$weight_decay"
    --warmup_steps "$warmup_steps"
    --lr_scheduler_type "cosine_with_min_lr"
    --lr_scheduler_kwargs '{"min_lr":1e-6}'
    --trainable_params "$trainable_params"
    --base_loss "$base_loss"
    --forget_gate "$forget_gate"
    --forget_gate_bias_init "$fgbias_init"
    --fg_dropout "$fg_dropout"
    --attn_impl "$attn_impl"
    --memory_size "$memory_size"
    --forget_weight "$forget_weight"
    --skip_layers "$skip_layers"
    --logit_block_size "$logit_block_size"
    --logging_steps 1
    --deepspeed "$ds_config"
    --run_name "$run_name"
    --max_steps "$steps"
    --max_samples "$max_samples"
    --report_to "$report_to"
)

# Extra command-line arguments given to this script are appended last.
# "$@" (quoted) keeps each of them intact; a bare $@ would re-split arguments
# containing whitespace and expand glob characters.
# This stays the final command so the script exits with torchrun's status.
torchrun --nproc_per_node="$gpus" --master_port="$master_port" train.py \
    "${train_args[@]}" \
    "$@"
