# For Topic-XICL
# 1. Prepare clustered training data and test data.
# DATA_PATH is the data directory; the later stages of this script receive
# it through their --file_path option.
DATA_PATH="./data/"
# Pass the variable's value. The bare word DATA_PATH would hand the script
# the literal string "DATA_PATH" instead of the directory.
python src/A_data_preprocess/data.py "${DATA_PATH}"

# Directory passed to the training runs through --tensorize_dir.
TENSORIZE_DIR="./src/B_train_Topic_model/output/tensorizer/tensorized_bloomz1b7"
# 2. Train the latent topic model.
## for xnli
# One training run per data seed; each run gets its own output directory
# named "${OUTPUT_PATH}<seed>".
# NOTE(review): train_xnli.py is referenced without a directory, so it must
# be resolvable from the current working directory - confirm where this
# script is meant to be launched from.
OUTPUT_PATH="./src/B_train_Topic_model/output/checkpoints-xnli-"
for seed in 32 44 100; do
  # CUDA_VISIBLE_DEVICES=0 is set only in the environment of this command.
  # Every expansion is double-quoted so that a path containing whitespace
  # stays one argument and an empty value cannot shift the following flags.
  CUDA_VISIBLE_DEVICES=0 python train_xnli.py --dataset xnli \
    --file_path "${DATA_PATH}" \
    --tensorize_dir "${TENSORIZE_DIR}" \
    --n_prefix_tokens 10 \
    --n_clusters 20 \
    --max_length_per_example 256 \
    --max_length 256 \
    --lr 1e-6 \
    --warmup_steps 100 \
    --gradient_accumulation_steps 8 \
    --num_training_steps 500 \
    --save_period 1000 \
    --data_seed "${seed}" \
    --out_dir "${OUTPUT_PATH}${seed}"
done


## for xcopa
# One training run per data seed; each run gets its own output directory
# named "${OUTPUT_PATH}<seed>".
# NOTE(review): the xcopa run launches train_xnli.py with --dataset xcopa;
# presumably that script handles both datasets - confirm.
OUTPUT_PATH="./src/B_train_Topic_model/output/checkpoints-xcopa-"
for seed in 32 44 100; do
  # Expansions are double-quoted so that a path containing whitespace stays
  # one argument and an empty value cannot shift the following flags.
  CUDA_VISIBLE_DEVICES=0 python train_xnli.py --dataset xcopa \
    --file_path "${DATA_PATH}" \
    --tensorize_dir "${TENSORIZE_DIR}" \
    --n_prefix_tokens 15 \
    --n_clusters 5 \
    --max_length_per_example 256 \
    --max_length 256 \
    --lr 1e-6 \
    --warmup_steps 20 \
    --gradient_accumulation_steps 8 \
    --num_training_steps 250 \
    --save_period 500 \
    --data_seed "${seed}" \
    --out_dir "${OUTPUT_PATH}${seed}"
done

## for tydiqa
# One training run per data seed; each run gets its own output directory
# named "${OUTPUT_PATH}<seed>". Both length limits are 512 for this run.
OUTPUT_PATH="./src/B_train_Topic_model/output/checkpoints-tydiqa-"
for seed in 32 44 100; do
  # Expansions are double-quoted so that a path containing whitespace stays
  # one argument and an empty value cannot shift the following flags.
  CUDA_VISIBLE_DEVICES=0 python train_tydiqa.py --dataset tydiqa \
    --file_path "${DATA_PATH}" \
    --tensorize_dir "${TENSORIZE_DIR}" \
    --n_prefix_tokens 15 \
    --n_clusters 20 \
    --max_length_per_example 512 \
    --max_length 512 \
    --lr 1e-6 \
    --warmup_steps 100 \
    --gradient_accumulation_steps 8 \
    --num_training_steps 500 \
    --save_period 1000 \
    --data_seed "${seed}" \
    --out_dir "${OUTPUT_PATH}${seed}"
done


# 3. Construct in-context learning data by topic inference.
## for xnli
# One inference run per seed, pointed at the directory
# "${OUTPUT_PATH}<seed>" (the same path pattern the xnli training loop uses
# for --out_dir).
OUTPUT_PATH="./src/B_train_Topic_model/output/checkpoints-xnli-"
for seed in 32 44 100; do
  # Built once and reused for both --concept_dir and --prefix_embed_file.
  ckpt_dir="${OUTPUT_PATH}${seed}"
  # NOTE(review): soft_embeddings-4000.pt is presumably the checkpoint the
  # matching training run writes - confirm the file name against its output.
  # Expansions are double-quoted so paths with whitespace stay one argument.
  CUDA_VISIBLE_DEVICES=0 python inference_xnli.py --dataset xnli \
    --file_path "${DATA_PATH}" \
    --n_prefix_tokens 10 \
    --n_clusters 20 \
    --max_length_per_example 256 \
    --max_length 256 \
    --seed "${seed}" \
    --data_name cluster_1b7 \
    --concept_dir "${ckpt_dir}" \
    --prefix_embed_file "${ckpt_dir}/soft_embeddings-4000.pt"
done


## for xcopa
# One inference run per seed, pointed at the directory
# "${OUTPUT_PATH}<seed>".
# NOTE(review): the xcopa run launches inference_xnli.py with
# --dataset xcopa; presumably that script handles both datasets - confirm.
OUTPUT_PATH="./src/B_train_Topic_model/output/checkpoints-xcopa-"
for seed in 32 44 100; do
  # Built once and reused for both --concept_dir and --prefix_embed_file.
  ckpt_dir="${OUTPUT_PATH}${seed}"
  # NOTE(review): soft_embeddings-2000.pt is presumably the checkpoint the
  # matching training run writes - confirm the file name against its output.
  # Expansions are double-quoted so paths with whitespace stay one argument.
  CUDA_VISIBLE_DEVICES=0 python inference_xnli.py --dataset xcopa \
    --file_path "${DATA_PATH}" \
    --n_prefix_tokens 15 \
    --n_clusters 5 \
    --max_length_per_example 256 \
    --max_length 256 \
    --seed "${seed}" \
    --data_name cluster_1b7 \
    --concept_dir "${ckpt_dir}" \
    --prefix_embed_file "${ckpt_dir}/soft_embeddings-2000.pt"
done

## for tydiqa
# One inference run per seed, pointed at the directory
# "${OUTPUT_PATH}<seed>". Both length limits are 512 for this run.
OUTPUT_PATH="./src/B_train_Topic_model/output/checkpoints-tydiqa-"
for seed in 32 44 100; do
  # Built once and reused for both --concept_dir and --prefix_embed_file.
  ckpt_dir="${OUTPUT_PATH}${seed}"
  # NOTE(review): soft_embeddings-4000.pt is presumably the checkpoint the
  # matching training run writes - confirm the file name against its output.
  # Expansions are double-quoted so paths with whitespace stay one argument.
  CUDA_VISIBLE_DEVICES=0 python inference_tydiqa.py --dataset tydiqa \
    --file_path "${DATA_PATH}" \
    --n_prefix_tokens 15 \
    --n_clusters 20 \
    --max_length_per_example 512 \
    --max_length 512 \
    --seed "${seed}" \
    --data_name cluster_1b7 \
    --concept_dir "${ckpt_dir}" \
    --prefix_embed_file "${ckpt_dir}/soft_embeddings-4000.pt"
done


# 4. Test in-context learning in the LLMs listed in each loop below.
# NOTE(review): this heading used to say "six LLMs", but only three model
# ids are listed here - confirm whether more models should be added.
# tydiqa evaluation: one test_tydiqa.py run per model id.
for model in bigscience/bloom-7b1 meta-llama/Meta-Llama-3.1-8B Qwen/Qwen1.5-7B; do
  # Expansions are double-quoted so that a data path containing whitespace
  # stays one argument and an empty value cannot shift the following flags.
  python test_tydiqa.py --set_up cluster_1b7_4000 --model_path "${model}" \
    --file_path "${DATA_PATH}" --per_max_len 512
done

# xnli evaluation: one test_xnli.py run per model id.
for model in bigscience/bloom-7b1 meta-llama/Meta-Llama-3.1-8B Qwen/Qwen1.5-7B; do
  # Expansions are double-quoted so that a data path containing whitespace
  # stays one argument and an empty value cannot shift the following flags.
  python test_xnli.py --set_up cluster_1b7_4000 --model_path "${model}" \
    --file_path "${DATA_PATH}" --per_max_len 256
done

# xcopa evaluation: one test_xcopa.py run per model id. The set-up name
# ends in 2000 here, matching the soft_embeddings-2000.pt file used by the
# xcopa inference step.
for model in bigscience/bloom-7b1 meta-llama/Meta-Llama-3.1-8B Qwen/Qwen1.5-7B; do
  # Expansions are double-quoted so that a data path containing whitespace
  # stays one argument and an empty value cannot shift the following flags.
  python test_xcopa.py --set_up cluster_1b7_2000 --model_path "${model}" \
    --file_path "${DATA_PATH}" --per_max_len 256
done