#!/bin/bash

# Grid Engine options (lines prefixed with #$)
# Runtime limit of 10 hours:
#$ -l h_rt=10:00:00
#
# Set working directory to the directory where the job is submitted from:
#$ -cwd
#
# Request one GPU in the gpu queue:
#$ -q gpu 
#$ -pe gpu-a100 1
#
# Request 175 GB system RAM per requested GPU:
# the total system RAM available to the job is the value specified here multiplied by
# the number of requested GPUs (above)
#$ -l h_vmem=175G

# Initialise the environment modules and load CUDA.
# No CUDA version is pinned here, so whichever version the module system
# resolves for "cuda" is used; a pinned 12.1.1 alternative is kept commented out.
. /etc/profile.d/modules.sh
module load cuda
#module load cuda/12.1.1
module load anaconda 
# Add scratch-space locations to conda's envs_dirs and pkgs_dirs settings,
# then activate the environment used for training.
conda config --add envs_dirs /exports/eddie/scratch/s2593541/anaconda/envs
conda config --add pkgs_dirs /exports/eddie/scratch/s2593541/anaconda/pkgs
conda activate lrd3_new
# Dependency installs kept for reference (disabled).
#pip install tiktoken
#pip install einops
#pip install pytest
# Print the GPU status into the job output before training starts.
nvidia-smi
# Environment variables exported to the training process: a dataset cache
# location on scratch, tokenizer parallelism switched off, and log level ERROR.
# NOTE(review): the names suggest the Hugging Face datasets/tokenizers
# libraries read the first two — confirm which component reads LOGLEVEL.
export HF_DATASETS_CACHE="/exports/eddie/scratch/s2593541/cache/lm_eval"
export TOKENIZERS_PARALLELISM=false
export LOGLEVEL=ERROR

# Default hyper-parameters for the training runs launched at the bottom of this
# script. Each value is forwarded to train.py through the flag named in its
# comment. Sections further down may override these by re-assigning them.
MAX_LEN=256                    # --max_length
NUM_TRAIN_SAMPLES=3000         # --num_train_samples
DISTILL_MODE="hs_last"         # --distill_mode; empty selects the "pretrain" experiment name
LR=1e-2                        # --lr
COMP_VALUES=(0.90 0.85 0.80)   # --target_param_ratio; one training run per entry, in order
EPOCHS=30                      # --epochs (was assigned twice with the same value; kept once)
SCALE_COMP=1.                  # --scale_compression
eval_freq_steps=0 # for pre-training mode only
# NOTE(review): TV_SCALE is not referenced by the training command in this
# script (which passes a literal --tv_loss=1.) — confirm whether it should be.
TV_SCALE=1. # scale for tv loss
EVAL_BS=8                      # --eval_batch_size
BATCH_SIZE=4                   # --batch_size

# Model presets. Only the llama-2-7b-hf preset is active; the others are kept
# commented out for reference. Assignments are evaluated top to bottom, so if
# more than one preset is uncommented the last assignment of each variable wins,
# and any value set here overrides the defaults assigned earlier in the script.

# LLAMA3 8B
#MODEL=meta-llama/Meta-Llama-3-8B
#CACHE_DIR=/exports/eddie/scratch/s2593541/lrd/cache_train_llama
#SCHED_DIST=4
#eval_freq_steps=0
#SCHED_DIST="4"
#SCALE_COMP=1.0
#TV_SCALE=0.7

# gemma
#MODEL=google/gemma-7b
#CACHE_DIR=/exports/eddie/scratch/s2593541/lrd/cache_train_gemma
#SCHED_DIST="2"
#SCALE_COMP=15.0
#TV_SCALE=0.7

# llama-2-7b-hf
# MODEL is passed as --model_name, CACHE_DIR as --cache_dir and SCHED_DIST as
# --schedule_distillation; SCHED_DIST is also appended to the distillation
# experiment name.
MODEL=meta-llama/Llama-2-7b-hf
CACHE_DIR=/exports/eddie/scratch/s2593541/lrd/cache_train_llama2
SCHED_DIST=3

## Pre-Train
# Uncommenting these overrides the values above. An empty DISTILL_MODE makes
# the loop below use the "<model>_pretrain_<ratio>" experiment name.
#NUM_TRAIN_SAMPLES=50000
#DISTILL_MODE=""
#MODEL=meta-llama/Llama-2-7b-hf
#CACHE_DIR=/exports/eddie/scratch/s2593541/lrd/cache_train_llama2
#MODEL=meta-llama/Meta-Llama-3-8B
#eval_freq_steps=1000
#EPOCHS=1

# Launch one train.py run per target parameter ratio in COMP_VALUES, in order.
# The argument list is built once per iteration so the first and later runs
# cannot drift apart; the only difference between them is the pair of
# cache-loading flags appended after the first iteration.
for i in "${!COMP_VALUES[@]}"; do
    COMP=${COMP_VALUES[$i]}

    # Experiment name: the model id with its leading "org/" component removed,
    # tagged as pre-training when DISTILL_MODE is empty and as distillation
    # (with the schedule value appended) otherwise.
    if [[ -z "$DISTILL_MODE" ]]; then
        EXP_NAME="${MODEL#*/}_pretrain_${COMP}"
    else
        EXP_NAME="${MODEL#*/}_distill_${COMP}_v2_${SCHED_DIST}"
    fi

    # NOTE(review): --tv_loss is hard-coded to 1. while TV_SCALE is defined
    # earlier in the script and never used — confirm which one is intended.
    train_args=(
        --eval_full --tv_loss=1. --bias_init --alpha=0.5 --lr_schedule="plateau"
        --schedule_distillation="$SCHED_DIST" --scale_compression="$SCALE_COMP"
        --target_param_ratio="$COMP" --mask_eval_type="threshold" --act_aware=activation
        --model_name="$MODEL" --epochs="$EPOCHS" --eval_freq=3 --distill_mode="$DISTILL_MODE"
        --batch_size="$BATCH_SIZE" --lr="$LR" --num_train_samples="$NUM_TRAIN_SAMPLES"
        --exp_name="$EXP_NAME"
        --max_length="$MAX_LEN" --cache_dir="$CACHE_DIR" --save_model=reconstruct
        --eval_batch_size="$EVAL_BS" --eval_freq_steps="$eval_freq_steps"
    )

    # The first run gets no extra arguments; every later run additionally loads
    # the distillation and activation caches (presumably those written by the
    # first run under CACHE_DIR — confirm against train.py).
    if (( i > 0 )); then
        train_args+=(--load_distill_cache --load_act_cache)
    fi

    # Kept as the last command of the iteration so the loop's exit status is
    # still that of the final training run.
    python train.py "${train_args[@]}"
done


