#!/usr/bin/sh


# get rotation parameters
LLAMA2_PATH="./modelzoo/llama2/llama-2-7b"
LLAMA2_OUTPUT_PATH="./output/llama2_w1a4kv16"

torchrun --nnodes 1 --nproc_per_node 8 --master-addr localhost --master-port 8902 main.py \
	--output_dir $LLAMA2_OUTPUT_PATH --model $LLAMA2_PATH  \
	--loss_type=kl_top --post_attn=True \
	--rotate_ov=True --rotate_post_rope=False --online_qk_hadamard=True --smooth_qk=True --smooth_ov=True --smooth_up_down=True --smooth_norm_linear=True \
	--bf16=True --lm_eval=True --per_device_train_batch_size=4 \
	--max_steps=100 --w_bits=1 --a_bits=4 --v_bits=4 --k_bits=4 --down_bits=4 \
	--train_enable_wquant=False --sub_mean False  --distribute=True --use_klt

# btc-llm
ROTATE_OUTPUT_PATH=$LLAMA2_OUTPUT_PATH + "/model.bin"
VECTOR_LENGTH=8
NUM_CENTROIDS=128
LOG_PATH=$LLAMA2_OUTPUT_PATH + "/logs/llama2-7b-arb-v$VECTOR_LENGTH-c$NUM_CENTROIDS.log"


python main.py \
	--output_dir $LLAMA2_OUTPUT_PATH --model $LLAMA2_PATH  \
	--loss_type=kl_top --post_attn=True \
	--rotate_ov=True --rotate_post_rope=False --online_qk_hadamard=False --smooth_qk=True --smooth_ov=True --smooth_up_down=True --smooth_norm_linear=True \
	--bf16=True --lm_eval=True --per_device_train_batch_size=4 \
	--max_steps=100 --a_bits=16 --v_bits=16 --k_bits=16 --down_bits=16 --w_clip=True\
	--train_enable_wquant=True --sub_mean False --distribute=True \
	--resume_path=$ROTATE_OUTPUT_PATH \
	--low_quant_method arb --nsamples 128 --blocksize 128  --salient_metric hessian --num_p 2 --order2_group \
	--use_last_iter_quantization --vector_length $VECTOR_LENGTH --num_centroids $NUM_CENTROIDS --max_iter 2 \
	> $LOG_PATH 2>&1 &

tail -f $LOG_PATH