#! /bin/bash

# Data-mixture sweep launched from previously saved checkpoints.
#
# For every combination of LOAD_STEP, DIRECTION1 and DIRECTION2 that passes the
# filter below, this script
#   1. rewrites DATA_PATH_CONFIG with the sampling weights selected by
#      DIRECTION2, and
#   2. launches pretrain_gpt.py, loading from
#      ${LOAD_PATH}/model_step02_load<LOAD_STEP>_direction<DIRECTION1>
#      with --train-iters set to LOAD_STEP + 4000 + STEPS.
#
# Variables read here but not defined in this file (they must come from the
# caller's environment): LOAD_PATH, CHECKPOINT_PATH, TENSORBOARD_PATH,
# DISTRIBUTED_ARGS, LOG_INTERVAL, NUM_LAYERS, HIDDEN_SIZE, NUM_HEADS,
# MICRO_BATCH_SIZE, GLOBAL_BATCH_SIZE, SEQ_LENGTH, LR_DECAY_ITERS,
# EVAL_DATA_PATH, INPUT_KEYS, LANGUAGES, VOCAB_FILE, TOKENIZER_TYPE, LR,
# LR_MIN, WEIGHT_DECAY, LR_WARMUP_FRACTION, ADAM_BETA_1, ADAM_BETA_2,
# INIT_METHOD_STD, ds_args.
#
# NOTE(review): the literal 4000 in the checkpoint/tensorboard names and in
# TRAIN_ITERS is presumably the length of the stage that produced the loaded
# checkpoint -- confirm before changing STEPS or that literal.

STEPS=4000

DATA_PATH_CONFIG=config/configs/tmp.json
DATA_PATH="#config ${DATA_PATH_CONFIG}"
SAVE_INTERVAL=2000

# Make sure the directory for the generated config exists; a failure here is
# caught by the write check inside the loop.
mkdir -p -- "$(dirname -- "$DATA_PATH_CONFIG")"

# LOAD_STEP takes the values 2000, 8000, 14000, 20000.
for (( LOAD_STEP = 2000; LOAD_STEP <= 25000; LOAD_STEP += 6000 )); do
  for DIRECTION1 in {1..7}; do
    for DIRECTION2 in {1..7}; do
      # Only run pairs whose sum is even and lies within [4, 12].
      direction_sum=$(( DIRECTION1 + DIRECTION2 ))
      if (( direction_sum % 2 == 1 || direction_sum > 12 || direction_sum < 4 )); then
        continue
      fi
      echo "Starting experiment with direction $DIRECTION1 $DIRECTION2"

      # Sampling weights: DIRECTION2 = 1 -> 0.95/0.05, 4 -> 0.5/0.5,
      # 7 -> 0.05/0.95. Dividing by 3.0 (not 3) keeps the result correct even
      # if `python` is a Python 2 interpreter, where 3 would integer-divide.
      if ! weight_first=$(python -c "print(0.5 - (${DIRECTION2} - 4) / 3.0 * 0.45)") ||
         ! weight_second=$(python -c "print(0.5 + (${DIRECTION2} - 4) / 3.0 * 0.45)"); then
        echo "Failed to compute sampling weights; skipping direction $DIRECTION1 $DIRECTION2" >&2
        continue
      fi

      # Never launch training against a stale or missing config: skip the
      # experiment if the file cannot be written.
      if ! cat <<EOT > "$DATA_PATH_CONFIG"
{
  "data_prefixes": [
    "/inputs/data_url_0/en_c4",
    "/inputs/data_url_2/ru_c4mc4"
  ],
  "weights": [
    ${weight_first},
    ${weight_second}
  ],
  "sampling_strategy": {
    "type": null
  }
}
EOT
      then
        echo "Failed to write ${DATA_PATH_CONFIG}; skipping direction $DIRECTION1 $DIRECTION2" >&2
        continue
      fi

      LOAD_PATH_ACTUAL=${LOAD_PATH}/model_step02_load${LOAD_STEP}_direction${DIRECTION1}
      CHECKPOINT_PATH_ACTUAL=${CHECKPOINT_PATH}_load${LOAD_STEP}_direction${DIRECTION1}_4000_direction${DIRECTION2}
      TENSORBOARD_PATH_ACTUAL=${TENSORBOARD_PATH}_load${LOAD_STEP}_direction${DIRECTION1}_4000_direction${DIRECTION2}
      TRAIN_ITERS=$(( LOAD_STEP + 4000 + STEPS ))

      # Variables defined in this script are quoted. Externally supplied ones
      # (DISTRIBUTED_ARGS, ds_args, EVAL_DATA_PATH, INPUT_KEYS, LANGUAGES, ...)
      # and DATA_PATH are deliberately left unquoted so they keep splitting
      # into multiple arguments exactly as before.
      # NOTE(review): DATA_PATH expands to two words ("#config" and the config
      # path) -- confirm that is what pretrain_gpt.py expects.
      python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
             --tensorboard-dir "$TENSORBOARD_PATH_ACTUAL" \
             --log-validation-ppl-to-tensorboard \
             --log-timers-to-tensorboard \
             --log-batch-size-to-tensorboard \
             --tensorboard-log-interval $LOG_INTERVAL \
             --num-layers $NUM_LAYERS \
             --hidden-size $HIDDEN_SIZE \
             --num-attention-heads $NUM_HEADS \
             --micro-batch-size $MICRO_BATCH_SIZE \
             --global-batch-size $GLOBAL_BATCH_SIZE \
             --seq-length $SEQ_LENGTH \
             --max-position-embeddings $SEQ_LENGTH \
             --train-iters "$TRAIN_ITERS" \
             --lr-decay-iters $LR_DECAY_ITERS \
             --save "$CHECKPOINT_PATH_ACTUAL" \
             --load "$LOAD_PATH_ACTUAL" \
             --data-path $DATA_PATH \
             --eval-dataset $EVAL_DATA_PATH \
             --input-keys $INPUT_KEYS \
             --languages $LANGUAGES \
             --vocab-file $VOCAB_FILE \
             --tokenizer-type $TOKENIZER_TYPE \
             --data-impl mmap \
             --split 998,1,1 \
             --distributed-backend nccl \
             --lr $LR \
             --lr-decay-style cosine \
             --min-lr $LR_MIN \
             --weight-decay $WEIGHT_DECAY \
             --clip-grad 1.0 \
             --lr-warmup-fraction $LR_WARMUP_FRACTION \
             --checkpoint-activations \
             --log-interval 100 \
             --save-interval "$SAVE_INTERVAL" \
             --eval-interval 100 \
             --eval-iters 10 \
             --fp16 \
             --loss-scale 12 \
             --adam-beta1 $ADAM_BETA_1 \
             --adam-beta2 $ADAM_BETA_2 \
             --init-method-std $INIT_METHOD_STD \
             $ds_args
    done
  done
done