#!/bin/bash




# Load global parameters. This file is expected to define at least CACHE,
# PREPROCESSED_PILE_DIR and MODEL_OUTPUT_DIR, which are used below.
# NOTE(review): bash `source` looks the name up in PATH and then in the
# current working directory, so launch this script from the directory that
# contains constants.sh.
source constants.sh

# Every cache/temp location below is derived from CACHE. An empty CACHE would
# make `mkdir -p` fail and leave HF/torch caches and TMPDIR pointing at "",
# so abort before touching the filesystem.
: "${CACHE:?CACHE must be set (expected from constants.sh)}"

# Redirect HuggingFace, torch-extension, temp and W&B files into CACHE.
export HF_HOME=$CACHE
export TRANSFORMERS_CACHE=$CACHE
export HF_DATASETS_CACHE=$CACHE
# 0 disables copying datasets into RAM; they stay memory-mapped from the cache.
export HF_DATASETS_IN_MEMORY_MAX_SIZE=0
export TORCH_EXTENSIONS_DIR=$CACHE
export TMPDIR=$CACHE
export WANDB_DIR=${CACHE}/wandb

# Create the cache root and the W&B directory up front; stop if that fails
# rather than launching a multi-GPU job with unusable cache/temp paths.
mkdir -p -- "$CACHE" "$WANDB_DIR" || exit 1

# Alias for the preprocessed Pile location defined in constants.sh.
PREPROCESSED_DATA=${PREPROCESSED_PILE_DIR}



# Run name: passed as --run_name and used as the output subdirectory.
NAME=pile_main_1B_ref:70m

# Fail fast instead of writing checkpoints to "/${NAME}" (empty
# MODEL_OUTPUT_DIR) or handing the trainer an empty --dataset_dir, which
# unquoted would have swallowed the following flag as its value.
: "${CACHE:?CACHE must be set (see constants.sh)}"
: "${MODEL_OUTPUT_DIR:?MODEL_OUTPUT_DIR must be set (see constants.sh)}"
: "${PREPROCESSED_DATA:?PREPROCESSED_PILE_DIR must be set (see constants.sh)}"

# Options for `accelerate launch` itself: one machine, 4 processes.
launcher_args=(
  --config_file accelerate_config.yml
  --num_machines 1
  --num_processes 4
  --multi_gpu
  # NOTE(review): port 0 presumably asks for an OS-assigned free port;
  # confirm with the installed accelerate/torch versions.
  --main_process_port 0
)

# Options for draw/train.py. Order is kept identical to the original command.
train_args=(
  --dataset_name pile
  --model_type gpt_flash
  --tokenizer_name togethercomputer/RedPajama-INCITE-Base-7B-v0.1
  # Evaluation-only run: --do_eval is on and training is switched off.
  --do_eval
  --do_train false
  # Downstream evaluation on BoolQ, zero-shot.
  --downstream_datasets "boolq"
  --downstream_num_shots 0
  --cache_dir "${CACHE}"
  --dataset_dir "${PREPROCESSED_DATA}"
  --output_dir "${MODEL_OUTPUT_DIR}/${NAME}"
  --dirichlet_params_path ./configs/main_2048E_seed_42_alpha.json
  --domain_weight_update_steps 10
  --max_token_length 1024
  # Batch/step/optimizer settings below are presumably only exercised when
  # training is enabled; kept so the command mirrors the training config.
  --per_device_train_batch_size 2
  --gradient_accumulation_steps 16
  --dataloader_num_workers 1
  --max_steps 5000
  --save_strategy steps
  --save_steps 200
  --evaluation_strategy steps
  --eval_steps 200
  --per_device_eval_batch_size 4
  --remove_unused_columns=False
  --learning_rate 1e-3
  --lr_end 1e-4
  --weight_decay 0.01
  --max_grad_norm 1.0
  --adam_epsilon 1e-8
  --lr_scheduler_name linear_warmup_exponential
  --warmup_ratio 0.06
  --run_name "${NAME}"
  --seed 111
  --skip_perplexity_eval
  --logging_strategy steps
  --logging_steps 100
  --logging_first_step
  --report_to wandb
  --optim adamw_torch_fused
  --adam_beta1 0.9
  --adam_beta2 0.99
  --bf16
  --shuffle
  # Architecture overrides for the model config (comma-separated key=value).
  --config_overrides="n_positions=1024,n_embd=2048,n_layer=12,n_head=32,rotary_emb_fraction=0.25,tie_word_embeddings=True,scale_attn_by_inverse_layer_idx=False,embd_pdrop=0.0,resid_pdrop=0.0,attn_pdrop=0.0,eos_token_id=0,bos_token_id=0,max_position_embeddings=0,vocab_size=256000"
)

# ADDITIONAL_ARGS (optional, set outside this script) is intentionally left
# unquoted so it can carry several whitespace-separated extra flags. It stays
# last so its flags come after the ones above.
# shellcheck disable=SC2086
accelerate launch "${launcher_args[@]}" draw/train.py "${train_args[@]}" ${ADDITIONAL_ARGS:-}