#!/bin/bash
#
# Resume-training sweep: each run in the loop below continues from a saved
# checkpoint with a different en/ru data-mixing ratio.

# Checkpoint save frequency, in iterations (passed as --save-interval).
SAVE_INTERVAL=2000

# Number of additional iterations trained on top of each loaded checkpoint
# (the loop sets --train-iters to LOAD_STEP + STEPS).
STEPS=4000

# JSON data config that is regenerated before every run. pretrain_gpt.py gets
# it through --data-path; DATA_PATH is expanded unquoted there, so the
# launcher receives the two words "#config" and the file path.
DATA_PATH_CONFIG=config/configs/tmp.json
DATA_PATH="#config ${DATA_PATH_CONFIG}"
# Resume-training sweep.
#
# For every starting checkpoint LOAD_STEP in 2000, 8000, 14000, 20000 (start
# 2000, step 6000, last value <= 25000) and every mixing DIRECTION in 1..7:
#   1. rewrite DATA_PATH_CONFIG with en_c4 / ru_mc4ru sampling weights
#        en = 0.5 - (DIRECTION - 4) / 3 * 0.45
#        ru = 0.5 + (DIRECTION - 4) / 3 * 0.45
#      i.e. DIRECTION 1 -> 0.95/0.05, 4 -> 0.5/0.5, 7 -> 0.05/0.95;
#   2. overwrite the pointer files inside LOAD_PATH so they name LOAD_STEP
#      (LOAD_PATH is therefore left pointing at the last LOAD_STEP used);
#   3. train up to LOAD_STEP + STEPS iterations, saving into per-run checkpoint
#      and TensorBoard directories suffixed with _load<STEP>_direction<DIR>.
#
# Runs are sequential. A failed run is reported on stderr and the sweep moves
# on to the next combination.
#
# Environment (not defined in this file): LOAD_PATH, CHECKPOINT_PATH and
# TENSORBOARD_PATH are required, and so are the model/optimizer settings
# passed to pretrain_gpt.py below (NUM_LAYERS, LR, ...). DISTRIBUTED_ARGS,
# DATA_PATH, EVAL_DATA_PATH, INPUT_KEYS, LANGUAGES and ds_args are expanded
# unquoted on purpose, so each can carry several words (or none).

# Fail fast instead of writing pointer files to "/latest..." or saving into
# stray relative "_load..." directories when these are unset or empty.
: "${LOAD_PATH:?LOAD_PATH must name the checkpoint directory to resume from}"
: "${CHECKPOINT_PATH:?CHECKPOINT_PATH must be set (prefix for per-run save dirs)}"
: "${TENSORBOARD_PATH:?TENSORBOARD_PATH must be set (prefix for per-run TensorBoard dirs)}"

for (( LOAD_STEP = 2000; LOAD_STEP <= 25000; LOAD_STEP += 6000 )); do
  for (( DIRECTION = 1; DIRECTION <= 7; DIRECTION++ )); do
    # "/ 3.0" keeps the division floating-point even when `python` is
    # Python 2, where (2 - 4) / 3 would floor to -1. Under Python 3 the
    # printed values are unchanged.
    WEIGHT_EN=$(python -c "print(0.5 - (${DIRECTION} - 4) / 3.0 * 0.45)")
    WEIGHT_RU=$(python -c "print(0.5 + (${DIRECTION} - 4) / 3.0 * 0.45)")
    cat <<EOT > "$DATA_PATH_CONFIG"
{
  "data_prefixes": [
    "/inputs/data_url_1/en_c4",
    "/inputs/data_url_2/ru_mc4ru"
  ],
  "weights": [
    ${WEIGHT_EN},
    ${WEIGHT_RU}
  ],
  "sampling_strategy": {
    "type": null
  }
}
EOT
    # Select which checkpoint --load resumes from. These are presumably
    # Megatron's tracker file and DeepSpeed's `latest` tag; confirm.
    printf '%s\n' "$LOAD_STEP" > "${LOAD_PATH}/latest_checkpointed_iteration.txt"
    printf 'global_step%s\n' "$LOAD_STEP" > "${LOAD_PATH}/latest"
    CHECKPOINT_PATH_ACTUAL=${CHECKPOINT_PATH}_load${LOAD_STEP}_direction${DIRECTION}
    TENSORBOARD_PATH_ACTUAL=${TENSORBOARD_PATH}_load${LOAD_STEP}_direction${DIRECTION}
    TRAIN_ITERS=$(( LOAD_STEP + STEPS ))
    # NOTE(review): in Megatron-LM --loss-scale is a *static* loss scale
    # (power-of-two values recommended). 12 is not a power of two, so confirm
    # whether 4096 (2^12) or dynamic scaling was intended.
    # List-valued variables below are intentionally left unquoted (SC2086).
    # shellcheck disable=SC2086
    python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
           --tensorboard-dir "$TENSORBOARD_PATH_ACTUAL" \
           --log-validation-ppl-to-tensorboard \
           --log-timers-to-tensorboard \
           --log-batch-size-to-tensorboard \
           --tensorboard-log-interval "$LOG_INTERVAL" \
           --num-layers "$NUM_LAYERS" \
           --hidden-size "$HIDDEN_SIZE" \
           --num-attention-heads "$NUM_HEADS" \
           --micro-batch-size "$MICRO_BATCH_SIZE" \
           --global-batch-size "$GLOBAL_BATCH_SIZE" \
           --seq-length "$SEQ_LENGTH" \
           --max-position-embeddings "$SEQ_LENGTH" \
           --train-iters "$TRAIN_ITERS" \
           --lr-decay-iters "$LR_DECAY_ITERS" \
           --save "$CHECKPOINT_PATH_ACTUAL" \
           --load "$LOAD_PATH" \
           --data-path $DATA_PATH \
           --eval-dataset $EVAL_DATA_PATH \
           --input-keys $INPUT_KEYS \
           --languages $LANGUAGES \
           --vocab-file "$VOCAB_FILE" \
           --tokenizer-type "$TOKENIZER_TYPE" \
           --data-impl mmap \
           --split 998,1,1 \
           --distributed-backend nccl \
           --lr "$LR" \
           --lr-decay-style cosine \
           --min-lr "$LR_MIN" \
           --weight-decay "$WEIGHT_DECAY" \
           --clip-grad 1.0 \
           --lr-warmup-fraction "$LR_WARMUP_FRACTION" \
           --checkpoint-activations \
           --log-interval 100 \
           --save-interval "$SAVE_INTERVAL" \
           --eval-interval 100 \
           --eval-iters 10 \
           --fp16 \
           --loss-scale 12 \
           --adam-beta1 "$ADAM_BETA_1" \
           --adam-beta2 "$ADAM_BETA_2" \
           --init-method-std "$INIT_METHOD_STD" \
           $ds_args \
      || printf 'run failed: load_step=%s direction=%s\n' "$LOAD_STEP" "$DIRECTION" >&2
  done
done