python prune.py \
  --device cuda:1 \
  --config_path configs/llama2/7b.yaml \
  --test_before_train \
  --instantation_model \
  --instantation_freq 10 \
  --instantation_test_mask \
  --K_inner 8 \
  --initialize_method wanda \
  --ppl_during_train \
  --ppl_freq 500 