#!/bin/bash
set -e

# Default configuration.
# Every assignment below is unconditional: a value inherited from the calling
# environment is replaced here. Overrides are supplied as KEY=VALUE arguments
# on the command line, which the script parses after these defaults.

# Experiment / selection settings.
export SEED=1 \
       NUM_LANDMARKS=4096 \
       WEIGHT_BY_INFLUENCE=false \
       TASK=squad \
       PHASE_LENS=0_10000 \
       FORCE_NO_OPTIM=false \
       NORMALIZE_GRADS=true \
       POOL_SIZE=50k

# Run-control switches.
export RESTART_AFTER_WARMUP=true \
       DRY_RUN=false

# Short model label used in the run name and the default index file name.
export MODEL_NAME=llama2-7b

# Settings forwarded to the influence-computation / training commands.
export DISTILL_MODEL=meta-llama/Llama-2-7b-hf \
       DISTILL_TAG=_distill_llama2 \
       DISTILL_PROJ_DIM=131072 \
       DISTILL_PROJ_TYPE=had \
       DISTILL_MASK_NUMEL=1000000000 \
       DISTILL_PROB=false \
       DISTILL_PARAM_REGEX=down_proj

# Root directories for training data, checkpoints and embeddings.
export BASE_DATA_PATH=./data/training_data \
       BASE_CKPT_PATH=./checkpoints \
       BASE_EMBDS_PATH=./embeddings

# Left empty on purpose: an empty value means "use the default index path".
export INDEX_PATH=

# Parse arguments
# Each command-line argument is treated as a KEY=VALUE override: KEY is the
# text before the first '=', VALUE is everything after it (VALUE may itself
# contain '='). An argument without '=' exports KEY with an empty value.
# Every parsed pair is exported and echoed.
# Exits with status 1 if KEY is not a valid shell variable name.
parse_arguments() {
    local _pa_arg _pa_key _pa_value
    for _pa_arg in "$@"; do
        # Split with parameter expansion so the argument is never subject to
        # word splitting or globbing, and no subprocess is needed.
        _pa_key=${_pa_arg%%=*}
        _pa_value=${_pa_arg:${#_pa_key}+1}
        if [[ ! "$_pa_key" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
            echo "Invalid argument '$_pa_arg': expected KEY=VALUE." >&2
            exit 1
        fi
        export "$_pa_key=$_pa_value"
        echo "Argument parsed: $_pa_key=$_pa_value"
    done
}
parse_arguments "$@"

# Path of the pooled training subset for the configured pool size.
export DATA_PATH="${BASE_DATA_PATH}/subset_${POOL_SIZE}.jsonl"

# If no index path was supplied, fall back to the default location derived
# from the embeddings directory, pool size and model name.
if [[ -z "${INDEX_PATH}" ]]; then
    export INDEX_PATH="${BASE_EMBDS_PATH}/subset_${POOL_SIZE}_jvp_${MODEL_NAME}.pt"
fi

# Abort early when the index file is missing. The path is quoted so that a
# value containing whitespace cannot turn the test into a syntax error (which
# would make the condition false and skip this guard).
if [[ ! -f "${INDEX_PATH}" ]]; then
    echo "Index not found at ${INDEX_PATH}, exiting..." >&2
    exit 1
fi

# Turn the true/false settings into optional command-line flags.
# Each *_ARG variable stays empty unless its setting is exactly "true".

# NO_OPTIM currently just mirrors FORCE_NO_OPTIM.
NO_OPTIM=$FORCE_NO_OPTIM

NO_OPTIM_ARG=
case "$NO_OPTIM" in
    true)
        NO_OPTIM_ARG="--no_optim"
        echo "Not using optim for infdist (due to either no warmup or force!)."
        ;;
esac

DISTILL_PROB_ARG=
case "$DISTILL_PROB" in
    true)
        DISTILL_PROB_ARG="--probabilistic"
        echo "Using probabilistic selection."
        ;;
esac

NORMALIZE_GRADS_ARG=
case "$NORMALIZE_GRADS" in
    true)
        NORMALIZE_GRADS_ARG="--normalize_grads"
        echo "Normalizing gradients."
        ;;
esac

# Build a run name that encodes the settings of this run. The trailing $RANDOM
# suffix gives repeated invocations with identical settings different names.
export RUN_NAME=infdist_${MODEL_NAME}_raw${RESTART_AFTER_WARMUP}_${TASK}_nl${NUM_LANDMARKS}_wbi${WEIGHT_BY_INFLUENCE}_tulu_v2_${TASK}_${POOL_SIZE}_phases${PHASE_LENS}_${DISTILL_PROJ_DIM}_${DISTILL_PROJ_TYPE}_${DISTILL_MASK_NUMEL}_prob2${DISTILL_PROB}_${DISTILL_PARAM_REGEX}${DISTILL_TAG}${NO_OPTIM_ARG}_ng${NORMALIZE_GRADS}_seed${SEED}_$RANDOM
echo "RUN_NAME=$RUN_NAME"

# Per-run output directory under the checkpoint root. Quoted so a checkpoint
# root containing whitespace still yields a single directory; `--` guards
# against a path that starts with '-'.
export RUN_DIR="${BASE_CKPT_PATH}/${RUN_NAME}"
mkdir -p -- "${RUN_DIR}"

# Write the run settings to ./config.json (in the current working directory)
# as a flat JSON object. Every value is emitted as a JSON string holding the
# variable's text verbatim; no escaping is applied.
# Table of "json_key=SHELL_VARIABLE" pairs, in output order.
config_entries=(
    task=TASK
    num_landmarks=NUM_LANDMARKS
    weight_by_influence=WEIGHT_BY_INFLUENCE
    phase_lens=PHASE_LENS
    distill_proj_dim=DISTILL_PROJ_DIM
    distill_proj_type=DISTILL_PROJ_TYPE
    distill_mask_numel=DISTILL_MASK_NUMEL
    distill_prob=DISTILL_PROB
    distill_param_regex=DISTILL_PARAM_REGEX
    distill_tag=DISTILL_TAG
    no_optim=NO_OPTIM
    normalize_grads=NORMALIZE_GRADS
    seed=SEED
    restart_after_warmup=RESTART_AFTER_WARMUP
    pool_size=POOL_SIZE
    model_name=MODEL_NAME
)

config_json="{"
config_sep=$'\n'
for config_entry in "${config_entries[@]}"; do
    config_key=${config_entry%%=*}
    config_var=${config_entry#*=}
    # ${!config_var} reads the variable whose name is stored in config_var.
    config_json+="${config_sep}    \"${config_key}\": \"${!config_var}\""
    # From the second entry on, close the previous line with a comma.
    config_sep=$',\n'
done
config_json+=$'\n}'

printf '%s\n' "$config_json" > config.json

# Echo the command string in $1, then execute it with eval unless this is a
# dry run (it only runs when DRY_RUN is exactly "false").
# NOTE: commands are assembled as plain strings and run through eval, so paths
# containing whitespace or shell metacharacters are not supported.
run_cmd() {
    echo "$1"
    if [[ "$DRY_RUN" == "false" ]]; then
        eval "$1"
    fi
}

# Print the largest N among the `checkpoint-N` directories directly inside the
# directory given as $1. Prints nothing when the directory does not exist
# (e.g. during a dry run) or contains no such sub-directory.
latest_checkpoint() {
    [[ -d "$1" ]] || return 0
    find "$1" -maxdepth 1 -type d -name 'checkpoint-*' \
        | sed 's|.*/checkpoint-||' \
        | sort -n \
        | tail -n 1
}

# separate the phase lengths
# PHASE_LENS arrives as an underscore-separated string (e.g. 0_10000) and is
# turned into an array with one entry per phase.
IFS='_' read -r -a PHASE_LENS <<< "$PHASE_LENS"

# Remember the starting model so training can restart from it after warmup.
BASE_MODEL=${DISTILL_MODEL}
for i in "${!PHASE_LENS[@]}"; do
    if [[ "${PHASE_LENS[i]}" -eq 0 ]]; then
        echo "[Phase ${i}]: Skipping phase ${i} with length zero."
        continue
    fi

    export PHASE_DIR=${RUN_DIR}/phase_${i}
    export PHASE_DATA_PATH=${PHASE_DIR}/data_${PHASE_LENS[i]}.jsonl
    export PHASE_MODEL_DIR=${PHASE_DIR}/model

    echo "PHASE_DIR=${PHASE_DIR}"
    echo "PHASE_DATA_PATH=${PHASE_DATA_PATH}"

    # Created even in a dry run.
    mkdir -p -- "${PHASE_DIR}"

    WEIGHT_ARG=
    if [[ "$i" -eq 0 ]]; then
        # Phase 0 is the warmup: its data is the first PHASE_LENS[0] lines of
        # the warmup file.
        echo "[Phase ${i}]: Warmup with length ${PHASE_LENS[i]}."

        if [ ! -f "${PHASE_DATA_PATH}" ]; then
            CMD="head -n ${PHASE_LENS[i]} ${BASE_DATA_PATH}/warmup_10k.jsonl > ${PHASE_DATA_PATH}"
            run_cmd "$CMD"
        else
            echo "Warmup dataset already exists. Skipping..."
        fi
        NUM_EPOCHS=2
    else
        echo "[Phase ${i}]: Length ${PHASE_LENS[i]}"

        # Later phases select their data with the influence script, using the
        # current DISTILL_MODEL (the previous phase's checkpoint, if one ran).
        echo "[Phase ${i}] Computing influence for $TASK."
        CMD="python -m minimal_multitask.compute_influence_infdist \
            --model_name ${DISTILL_MODEL} \
            --proj_dim ${DISTILL_PROJ_DIM} \
            --proj_type ${DISTILL_PROJ_TYPE} \
            --mask_numel ${DISTILL_MASK_NUMEL} \
            --param_regex ${DISTILL_PARAM_REGEX} \
            --seed ${SEED} \
            --train_dataset ${BASE_DATA_PATH}/subset_${POOL_SIZE}.jsonl \
            --eval_dataset ${TASK} \
            --index_path ${INDEX_PATH} \
            --batch_size 1 \
            --num_landmarks ${NUM_LANDMARKS} \
            --num_samples ${PHASE_LENS[i]} \
            --output_file ${PHASE_DATA_PATH} ${NO_OPTIM_ARG} ${NORMALIZE_GRADS_ARG} ${DISTILL_PROB_ARG}"
        run_cmd "$CMD"

        if [[ "$WEIGHT_BY_INFLUENCE" == "true" ]]; then
            WEIGHT_ARG="--weight_by_influence"
            echo "Weighting by influence."
        fi

        NUM_EPOCHS=2
    fi

    # Must stay after the influence computation above: selection for phase 1
    # uses the current model, but training then restarts from the base model.
    if [ "$i" -eq 1 ] && [ "$RESTART_AFTER_WARMUP" == "true" ]; then
        echo "[Phase ${i}]: Resetting the model after warmup."
        DISTILL_MODEL=${BASE_MODEL}
    fi

    # Only the warmup phase reuses an existing model directory.
    if [ "$i" -eq 0 ] && [ -d "${PHASE_MODEL_DIR}" ]; then
        echo "Warmed up model already exists. Skipping..."
    else
        echo "[Phase ${i}]: Running training."
        CMD="python -m minimal_multitask.instruction_tune \
                --model_name ${DISTILL_MODEL} \
                --output_dir ${PHASE_MODEL_DIR} \
                --per_device_train_batch_size 1 \
                --gradient_accumulation_steps 128 \
                --num_train_epochs ${NUM_EPOCHS} \
                --learning_rate 2e-5 \
                --seed ${SEED} \
                --warmup_ratio 0.03 \
                --lr_scheduler_type linear \
                --weight_decay 0. \
                --evaluation_strategy no \
                --save_strategy epoch \
                --logging_steps 1 \
                --is_llama=True \
                --use_hf_auth_token True \
                --train_dataset $PHASE_DATA_PATH $WEIGHT_ARG"
        run_cmd "$CMD"
    fi

    # The next phase continues from the highest-numbered checkpoint here.
    last_checkpoint=$(latest_checkpoint "${PHASE_MODEL_DIR}")
    if [[ -z "$last_checkpoint" && "$DRY_RUN" == "false" ]]; then
        echo "[Phase ${i}]: Warning: no checkpoint-* directory found in ${PHASE_MODEL_DIR}." >&2
    fi
    DISTILL_MODEL=${PHASE_MODEL_DIR}/checkpoint-${last_checkpoint}
    echo "[Phase ${i}]: Model changed to ${DISTILL_MODEL}."
done

# Echo the evaluation command for the model directory of the last phase that
# ran (PHASE_MODEL_DIR), and execute it unless this is a dry run.
# An `if` is used instead of `[ ... ] && eval` because this is the script's
# final command: with the `&&` form a dry run would make the whole script
# exit with status 1.
run_final_eval() {
    CMD="bash run_eval.sh ${PHASE_MODEL_DIR}"
    echo "$CMD"
    if [[ "$DRY_RUN" == "false" ]]; then
        eval "$CMD"
    fi
}
run_final_eval
