#!/bin/bash
#
# K-means baseline pipeline. In order, this script:
#   1. collects embeddings into a memmap datastore,
#   2. trains a k-means clustering model on that datastore,
#   3. maps the text corpus to cluster centroids,
#   4. builds BM25 indices (pyserini),
#   5. evaluates retrieval on the MSMARCO dev set.

# Abort on the first failing command, on use of an unset variable, and when any
# stage of a pipeline fails. Every step consumes the artifacts produced by the
# step before it, so continuing after a failure would only spend compute on
# missing or stale inputs.
set -euo pipefail

# HYPERPARAMETERS
# Passed as --dstore_size to the datastore builder and the k-means trainer.
DATASTORE_SIZE=50000000
# Short form of DATASTORE_SIZE; only used to name artifacts on disk.
DATASTORE_SIZE_HUMAN=50M
# Passed as --max_length to every script that encodes text.
MAX_LENGTH=200
# Model identifier handed to the embedding/retrieval scripts.
EMBEDDING_MODEL=castorini/tct_colbert-msmarco
# Short model name; only used to name artifacts on disk.
EMBEDDING_MODEL_TAG=tct_colbert
# Passed as --dimension to the k-means trainer.
EMBEDDING_DIM=768
# Clustering backend name, passed as --model / --clustering_model.
KMEANS_MODEL=faiss
# Passed as --num_clusters; also part of the model and index path names.
KMEANS_CLUSTERS=32000

# Step 1: Collect embeddings for training the k-means model
# This takes roughly 10 minutes and occupies 30GB of disk space.
DATASTORE_PATH="/home/data/${EMBEDDING_MODEL_TAG}_${DATASTORE_SIZE_HUMAN}.memmap"
# Build the datastore only when the memmap file is not already present.
# NOTE(review): a partially written memmap left behind by an interrupted run
# would also satisfy this check — confirm how build_datastore.py writes it.
if [[ ! -f "$DATASTORE_PATH" ]]; then
  # Stop here on failure: the k-means training step reads this file.
  if ! python scripts/baseline_kmeans/build_datastore.py \
    --model_name "$EMBEDDING_MODEL" \
    --memmap_file "$DATASTORE_PATH" \
    --dstore_size "$DATASTORE_SIZE" \
    --batch_size 128 \
    --max_length "$MAX_LENGTH" \
    --device cuda; then
    printf 'error: building the datastore failed: %s\n' "$DATASTORE_PATH" >&2
    exit 1
  fi
fi

# Step 2: Train Clustering Model
# The trained model is written to KMEANS_MODEL_PATH, which the later steps
# pass on as --clustering_model_path.
KMEANS_MODEL_PATH="/home/data/${EMBEDDING_MODEL_TAG}_${DATASTORE_SIZE_HUMAN}_K${KMEANS_CLUSTERS}.faiss"
# Stop here on failure: the following steps load the model from this path.
if ! python scripts/baseline_kmeans/train_kmeans.py \
  --memmap_file "$DATASTORE_PATH" \
  --dstore_size "$DATASTORE_SIZE" \
  --dimension "$EMBEDDING_DIM" \
  --model "$KMEANS_MODEL" \
  --num_clusters "$KMEANS_CLUSTERS" \
  --n_init 5 \
  --max_iter 100 \
  --random_state 8 \
  --model_path "$KMEANS_MODEL_PATH"; then
  printf 'error: k-means training failed: %s\n' "$KMEANS_MODEL_PATH" >&2
  exit 1
fi

# Step 3: Map Text Corpus to Cluster Centroids
# Output goes under INDEX_OUTPUT_PATH, which the indexing and evaluation steps
# read from afterwards.
INDEX_OUTPUT_PATH="/home/data/indices_${EMBEDDING_MODEL_TAG}_${DATASTORE_SIZE_HUMAN}_K${KMEANS_CLUSTERS}"
# Stop here on failure: BM25 indexing needs the directories this script writes.
if ! python scripts/baseline_kmeans/prepare_text_for_index.py \
  --model_name "$EMBEDDING_MODEL" \
  --clustering_model "$KMEANS_MODEL" \
  --clustering_model_path "$KMEANS_MODEL_PATH" \
  --index_path "$INDEX_OUTPUT_PATH" \
  --max_length "$MAX_LENGTH" \
  --batch_size 512 \
  --device cuda; then
  printf 'error: preparing text for indexing failed: %s\n' "$INDEX_OUTPUT_PATH" >&2
  exit 1
fi

# Step 4: Create BM25 Index (pyserini)
# The previous script should have created 3 directories under $INDEX_OUTPUT_PATH:
# > data, data_documents_token_and_code, data_documents_unique_code
#
# stdout and stderr of the indexing run are shown on the terminal and appended
# to a log file via tee. pipefail is enabled inside a subshell so a failure of
# build_bm25.sh is not hidden behind tee's exit status, while the shell options
# of the rest of the script stay untouched.
if ! (
  set -o pipefail
  bash scripts/baseline_kmeans/build_bm25.sh \
    "$INDEX_OUTPUT_PATH/data" \
    "$INDEX_OUTPUT_PATH/data_documents_token_and_code" \
    "$INDEX_OUTPUT_PATH/data_documents_unique_code" 2>&1 \
    | tee -a "${INDEX_OUTPUT_PATH}/bm25_index_output.log"
); then
  printf 'error: BM25 indexing failed; see %s\n' \
    "${INDEX_OUTPUT_PATH}/bm25_index_output.log" >&2
  exit 1
fi

# Step 5: Evaluation on MSMARCO dev set

#######################################
# Runs retrieve_with_bm25.py on msmarco-passage/dev/small against one index.
# Results are written to the "eval" subdirectory of that index directory.
# Globals:   EMBEDDING_MODEL, KMEANS_MODEL, KMEANS_MODEL_PATH,
#            INDEX_OUTPUT_PATH, MAX_LENGTH (all read)
# Arguments: $1 - index directory name under $INDEX_OUTPUT_PATH
#            $2 - value passed to --encoding
# Returns:   exit status of the python process
#######################################
run_bm25_eval() {
  local index_dir="${INDEX_OUTPUT_PATH}/$1"
  local encoding=$2

  python scripts/baseline_kmeans/retrieve_with_bm25.py \
    --model "$EMBEDDING_MODEL" \
    --clustering_model "$KMEANS_MODEL" \
    --clustering_model_path "$KMEANS_MODEL_PATH" \
    --dataset msmarco-passage/dev/small \
    --index_path "$index_dir" \
    --output_path "$index_dir/eval" \
    --encoding "$encoding" \
    --limit -1 \
    --batch_size 128 \
    --device cuda \
    --max_length "$MAX_LENGTH"
}

# The three evaluations are independent of each other, so all of them are
# attempted even if one fails; failures are counted and reported at the end.
eval_failures=0
# encoding: code
run_bm25_eval data code \
  || eval_failures=$((eval_failures + 1))
# encoding: code+text
run_bm25_eval data_documents_token_and_code code_plus_text \
  || eval_failures=$((eval_failures + 1))
# encoding: unique_code
run_bm25_eval data_documents_unique_code code_unique \
  || eval_failures=$((eval_failures + 1))

if ((eval_failures > 0)); then
  printf 'error: %d of 3 evaluation runs failed\n' "$eval_failures" >&2
  exit 1
fi