#!/usr/bin/env bash
# Sparse-attention calibration pipeline:
#   1. export downsampled attention maps      -> 2. build the permute plan
#   3. export permuted attention maps          -> 4. build the sparse plan
#   5. run inference with the exported sparse mask
# followed by a combined sparse + quantized inference run.
#
# Usage: bash <this-script> [GPU_ID]    (GPU_ID defaults to 6)

# Each step consumes files produced by the previous one, so stop at the first
# failing command instead of running later steps on missing data; also treat
# references to unset variables as errors.
set -eu

GPU_ID=${1:-6}
EXP_NAME="gen_sparse"

# Steps 1-2: attention maps (1_5_5 downsampled) for the permute plan.
N_TIMESTEP=8
PROMPT_PATH="prompts_2_b"
CALIB_DATA_NAME1="${EXP_NAME}/attn_155_t${N_TIMESTEP}_${PROMPT_PATH}"

# Steps 3-4: permuted attention maps (64 downsampled) for the sparse plan.
N_TIMESTEP_2=30
# Step 3 reads ${PROMPT_PATH_2}.txt. The variable was previously never set,
# which turned the prompt argument into ".txt". Use the same prompt set that
# CALIB_DATA_NAME3 is named after.
PROMPT_PATH_2="${PROMPT_PATH}"
CALIB_DATA_NAME3="${EXP_NAME}/permute_attn_64_t${N_TIMESTEP_2}_${PROMPT_PATH}"

# Step 1: run inference with ./configs/fp.yaml (presumably full precision) for
# N_TIMESTEP sampling steps over ${PROMPT_PATH}.txt and export the downsampled
# attention maps used to build the permute plan.
# The export name is passed explicitly via --export-calib-data
# (CALIB_DATA_NAME1). Step 2 expects the result at
# ./visualization/calib_data/${CALIB_DATA_NAME1}.pth.
CUDA_VISIBLE_DEVICES="$GPU_ID" python quant_inference.py \
	--quant-config ./configs/fp.yaml \
	--log "./logs/calib_data/${EXP_NAME}" \
	--num-sampling-steps "$N_TIMESTEP" \
	--prompt "${PROMPT_PATH}.txt" \
	--export-calib-data "$CALIB_DATA_NAME1"

# Step 2: build the permute plan from the attention maps exported by step 1
# (1_5_5 downsampled), using ./configs/permute.yaml.
# Where the plan is written is decided by get_permute_plan.py (presumably
# under the --log directory; TODO confirm).
CUDA_VISIBLE_DEVICES="${GPU_ID}" python get_permute_plan.py \
	--config ./configs/permute.yaml \
	--calib_data "./visualization/calib_data/${CALIB_DATA_NAME1}.pth" \
	--log "./logs/calib_data/${EXP_NAME}"

# Step 3: re-run inference with ./configs/permute.yaml for N_TIMESTEP_2
# sampling steps and export the downsampled, permuted attention maps used to
# build the sparse plan.
# Do not use sparse.yaml here: no sparse_mask exists yet at this point.
# PROMPT_PATH_2 used to be undefined, which produced "--prompt .txt". Fall back
# to PROMPT_PATH, the prompt set that CALIB_DATA_NAME3 is named after.
CUDA_VISIBLE_DEVICES="$GPU_ID" python quant_inference.py \
	--quant-config ./configs/permute.yaml \
	--log "./logs/calib_data/${EXP_NAME}" \
	--num-sampling-steps "$N_TIMESTEP_2" \
	--prompt "${PROMPT_PATH_2:-$PROMPT_PATH}.txt" \
	--export-calib-data "$CALIB_DATA_NAME3"

# Step 4: build the sparse plan from the permuted attention maps exported by
# step 3 (64 downsampled; the last dimension may also be reduced; TODO confirm).
# Unless the config names one explicitly, the permute plan is read from the
# default location <args.log>/permute_plan.pth.
# To also produce plots, append `--plot` to the command below.
CUDA_VISIBLE_DEVICES="${GPU_ID}" python get_sparse_plan.py \
	--config ./configs/sparse.yaml \
	--calib_data "./visualization/calib_data/${CALIB_DATA_NAME3}.pth" \
	--log "./logs/calib_data/${EXP_NAME}"

# Step 5: final inference with ./configs/sparse.yaml, using the sparse mask
# exported by the earlier steps, for N_TIMESTEP_2 sampling steps over the
# prompts in prompts.txt.
CUDA_VISIBLE_DEVICES="${GPU_ID}" python quant_inference.py \
	--quant-config ./configs/sparse.yaml \
	--log "./logs/calib_data/${EXP_NAME}" \
	--num-sampling-steps "${N_TIMESTEP_2}" \
	--prompt prompts.txt

# --------------------------------------------------------
# Final: inference with both sparsity and quantization enabled, using
# ./configs/final/sparse_0.5_int8.yaml (presumably 50% sparsity with INT8
# quantization; TODO confirm). EXP_NAME is reassigned so that this run logs
# to its own directory, ./logs/calib_data/final_0.5/.
CFG_NAME="final/sparse_0.5_int8.yaml"
EXP_NAME="final_0.5/"

CUDA_VISIBLE_DEVICES="${GPU_ID}" python quant_inference.py \
	--quant-config "./configs/${CFG_NAME}" \
	--log "./logs/calib_data/${EXP_NAME}" \
	--num-sampling-steps 30 \
	--prompt prompts.txt