#!/usr/bin/env bash
# Trains on $TASK_NAME and evaluates out-of-distribution on $OOD_TASK for two
# DVAE variants (DVAE-D and DVAE-P) using run_vae_mc.py.
#
# Configure `accelerate` first (e.g. run `accelerate config`).
# You may need to set the model path and data path manually. Every setting
# below can also be overridden from the environment, for example:
#   MODEL_NAME=/path/to/model SEED=1 bash this_script.sh
#
# Strict mode: stop on the first failing command (so evaluation never runs
# against a checkpoint that failed to train), treat unset variables as errors,
# and propagate failures through pipelines.
set -euo pipefail

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
export TASK_NAME="${TASK_NAME:-swag}"
export OOD_TASK="${OOD_TASK:-hellaswag}"
export LEARNING_RATE="${LEARNING_RATE:-1e-5}"
export EVAL_SPLIT="${EVAL_SPLIT:-val}"
export MODEL_NAME="${MODEL_NAME:-roberta-base}"
export SEED="${SEED:-42}"

# DVAE-D: the MLM task (--mlm_task) is the training dataset itself ($TASK_NAME).
mlm_task=$TASK_NAME
mlm_prob=0.15
temperature=1
# Build the run identifier once, so the training --output_dir and the
# evaluation --ckpt_path always point at the same directory.
run_name="${MODEL_NAME}_vae_temp=${temperature}_mlm=${mlm_task}-${mlm_prob}_seed=${SEED}"
ckpt_dir="./outputs/ckpts/${TASK_NAME}/${run_name}"

# Train
if ! accelerate launch run_vae_mc.py \
  --model_name_or_path "$MODEL_NAME" \
  --dataset_name "$TASK_NAME" \
  --max_length 256 \
  --per_device_train_batch_size 32 \
  --learning_rate "$LEARNING_RATE" \
  --weight_decay 0.1 \
  --num_train_epochs 3 \
  --do_train \
  --eval_split "$EVAL_SPLIT" \
  --seed "$SEED" \
  --mlm_task "$mlm_task" \
  --mlm_prob "$mlm_prob" \
  --temperature "$temperature" \
  --output_dir "$ckpt_dir" \
  --conf_dir "./outputs/conf/${TASK_NAME}/${EVAL_SPLIT}/${run_name}"; then
  # Evaluating would only load a missing or partial checkpoint, so stop here.
  printf 'error: DVAE-D training failed (%s); aborting.\n' "$ckpt_dir" >&2
  exit 1
fi

# OOD Eval: load the checkpoint trained above and evaluate on $OOD_TASK.
python run_vae_mc.py \
  --model_name_or_path "$MODEL_NAME" \
  --dataset_name "$OOD_TASK" \
  --eval_split "$EVAL_SPLIT" \
  --max_length 256 \
  --per_device_train_batch_size 32 \
  --per_device_eval_batch_size 32 \
  --ckpt_path "$ckpt_dir" \
  --conf_dir "./outputs/conf/${OOD_TASK}/${EVAL_SPLIT}/${run_name}"

# DVAE-P: the MLM task (--mlm_task) is the wikitext-103-raw-v1 corpus, and
# --distill is passed to both training and OOD evaluation.
mlm_task=wikitext-103-raw-v1
mlm_prob=0.15
temperature=1
# Build the run identifier once, so the training --output_dir and the
# evaluation --ckpt_path always point at the same directory.
run_name="${MODEL_NAME}_vae_temp=${temperature}_mlm=${mlm_task}-${mlm_prob}_seed=${SEED}"
ckpt_dir="./outputs/ckpts/${TASK_NAME}/${run_name}"

# Train
if ! accelerate launch run_vae_mc.py \
  --model_name_or_path "$MODEL_NAME" \
  --dataset_name "$TASK_NAME" \
  --max_length 256 \
  --per_device_train_batch_size 32 \
  --learning_rate "$LEARNING_RATE" \
  --weight_decay 0.1 \
  --num_train_epochs 3 \
  --do_train \
  --eval_split "$EVAL_SPLIT" \
  --seed "$SEED" \
  --mlm_task "$mlm_task" \
  --mlm_prob "$mlm_prob" \
  --temperature "$temperature" \
  --distill \
  --output_dir "$ckpt_dir" \
  --conf_dir "./outputs/conf/${TASK_NAME}/${EVAL_SPLIT}/${run_name}"; then
  # Evaluating would only load a missing or partial checkpoint, so stop here.
  printf 'error: DVAE-P training failed (%s); aborting.\n' "$ckpt_dir" >&2
  exit 1
fi

# OOD Eval: load the checkpoint trained above and evaluate on $OOD_TASK.
python run_vae_mc.py \
  --model_name_or_path "$MODEL_NAME" \
  --dataset_name "$OOD_TASK" \
  --eval_split "$EVAL_SPLIT" \
  --max_length 256 \
  --per_device_train_batch_size 32 \
  --per_device_eval_batch_size 32 \
  --distill \
  --ckpt_path "$ckpt_dir" \
  --conf_dir "./outputs/conf/${OOD_TASK}/${EVAL_SPLIT}/${run_name}"