#!/usr/bin/env bash
# Environment bootstrap: installs the Python dependencies, prints GPU status
# and exports process-wide settings used by the commands that follow.
#
# The rest of this script relies on bash-only syntax (`let`, `for (( ... ))`),
# hence the explicit bash shebang above.
#
# Strict mode is left disabled, so a failing command does not abort the run.
#set -e

# Dependencies listed in requirements.txt first, then the project itself in
# editable mode.
pip install -r requirements.txt
pip install -e .

# Print GPU / driver status to the job output.
nvidia-smi

# NOTE(review): an empty CURL_CA_BUNDLE is presumably here to bypass TLS
# certificate verification in HTTP clients that honour this variable. That
# weakens transport security -- confirm it is still required.
export CURL_CA_BUNDLE=""
# Silence Python warnings in every Python process started from here on.
export PYTHONWARNINGS="ignore"
# Dataset cache directory; the same path is deleted again after each training
# run further down in this script.
export HF_DATASETS_CACHE="/app/datasets_cache"

# Runs after the exports above, so this pip invocation inherits them.
pip install --upgrade wandb

# Restore state left behind by a previous run.
# NOTE(review): presumably this is what (re)creates .start_data_idx.txt --
# confirm in restore_state.py.
python restore_state.py

# Dataset start index recorded in .start_data_idx.txt (later passed to
# load_dataset.py as --start-idx). Falls back to 0 when the file is missing
# or empty.
START_DATA_IDX="$(cat .start_data_idx.txt)"
START_DATA_IDX="${START_DATA_IDX:-0}"

# By default the recorded index is discarded and the dataset is restarted
# from index 0. Run with RESTART_DATASET=0 to resume from the recorded index
# instead; any other value (or leaving it unset) restarts.
if [ "${RESTART_DATASET:-1}" != "0" ]; then
  START_DATA_IDX=0
fi
export START_DATA_IDX

# Run-wide settings. They are exported, so child processes can read them from
# the environment in addition to the command-line flags built from them.
export MAX_LENGTH=256
export EXPAND_FACTOR=1
export TASK="mlm"
export T_MAX=10

#######################################
# Run load_dataset.py for the "c4" dataset with the "c-tokenizer" tokenizer.
# Globals:   MAX_LENGTH, EXPAND_FACTOR (read)
# Arguments: $1 - value passed as --num-proc
#            $2 - value passed as --num-shards
#            $3 - value passed as --start-idx
# Returns:   exit status of load_dataset.py
#######################################
prepare_shards() {
  python3 load_dataset.py \
    --num-proc "$1" \
    --load-num-proc 1 \
    --num-shards "$2" \
    --start-idx "$3" \
    --tokenizer-name "c-tokenizer" \
    --dataset-name "c4" \
    --max-length "$MAX_LENGTH" \
    --expand-factor "$EXPAND_FACTOR"
}

#######################################
# Start train.py through `accelerate launch` (multi-GPU, bf16 mixed
# precision) with the fixed hyper-parameters of this run. Every training
# round uses exactly this command line.
# Globals:   T_MAX, MAX_LENGTH, TASK (read)
# Arguments: none
# Returns:   exit status of `accelerate launch`
#######################################
launch_training() {
  accelerate launch \
    --multi_gpu \
    --mixed_precision bf16 \
    train.py \
    --batch-size 2048 \
    --learning-rate 1e-4 \
    --t-max "$T_MAX" \
    --t-min 1 \
    --max-train-steps 1000000 \
    --max-length "$MAX_LENGTH" \
    --task "$TASK" \
    --tokenizer-name "c-tokenizer" \
    --hidden-size 1024 \
    --num-hidden-layers 8 \
    --num-attention-heads 8 \
    --clip-grad-norm \
    --clip-grad-norm-value 10 \
    --gradient-accumulation-steps 1 \
    --max-number-of-spans 8 \
    --compile \
    --use-scheduler \
    --flash-attention
}

#######################################
# Delete the dataset cache directory so the next round starts from an empty
# cache.
# Globals:   HF_DATASETS_CACHE (read; exported near the top of this script)
# Arguments: none
#######################################
clear_dataset_cache() {
  # ${VAR:?} stops the script instead of handing an empty path to `rm -rf`.
  rm -rf -- "${HF_DATASETS_CACHE:?}"
}

# First round: one shard at the starting index. A failed load is not retried
# here; training is launched regardless.
prepare_shards 64 1 "$START_DATA_IDX"
launch_training
clear_dataset_cache

# Step past the shard used by the first round.
START_DATA_IDX=$((START_DATA_IDX + 1))

# Remaining rounds: 32 shards at a time up to (but excluding) index 1024.
for (( i = START_DATA_IDX; i < 1024; i += 32 )); do
  # The loop iterates on i; START_DATA_IDX is advanced as well because it is
  # exported, so the child processes started below see the updated value.
  START_DATA_IDX=$((START_DATA_IDX + 32))
  echo "loading and preparing dataset..."
  echo "$i"
  # Retry the load every 3 seconds until it succeeds (no upper bound).
  until prepare_shards 100 32 "$i"; do
    sleep 3
  done
  echo "done"
  launch_training
  clear_dataset_cache
  #lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill {}
done
