#!/bin/bash
#
# Resumable experiment pipeline on the GSM8K data splits. Stages, in order:
#   HOLDOUT       teacher traces on the hold-out split
#   STUDENT GRAD  proxy-student gradients from the hold-out traces
#   for every (tau, lam, eps) line printed by grid.py:
#       AD SAMPLING -> DISTILLATION -> EVAL
# A stage is skipped when its output path (sentinel) already exists, so the
# script can simply be re-run after an interruption.
#
# Environment:
#   EXP_DIR  output directory (default: ads_jvp_sanity_check_lp_neurips_oaigsm8k)

# -e: abort on the first failing command; pipefail: a pipeline fails when any
# of its stages fails, not only the last one.
set -eo pipefail

# ANSI colour codes for log lines (expanded by `echo -e` / printf).
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
CYAN='\033[0;36m'
MAGENTA='\033[0;35m'
WHITE='\033[1;37m'
RESET='\033[0m'

# Bash increments SECONDS once per second; zero it to time the whole run.
SECONDS=0

seed=42
dataset=gsm8k
exp_dir="${EXP_DIR:-ads_jvp_sanity_check_lp_neurips_oaigsm8k}"
mkdir -p -- "${exp_dir}"
echo -e "${YELLOW}Experiment directory: ${exp_dir}${RESET}"

# Launcher prefix for every stage. It stays a plain string because the stage
# commands are run through eval: `time` is a shell keyword and must be parsed
# by the shell rather than passed along as an argument.
PY="time uv run accelerate launch --config_file acc_config.yaml"

# Ask grid.py for this host's share of the (tau, lam, eps) grid, one
# whitespace-separated triple per line. Capturing stdout directly, instead of
# going through a fixed params_temp.txt in the working directory, leaves no
# stray file behind and lets concurrent runs share a directory safely. A
# grid.py failure still aborts under `set -e`, because a plain assignment
# takes the exit status of its command substitution.
grid_output=$(python grid.py "$(hostname)")

# Keep only non-blank lines. A blank line would otherwise start a run whose
# tau/lam/eps are all empty.
declare -a taulamepss=()
while IFS= read -r line; do
    [[ -n "${line//[[:space:]]/}" ]] || continue
    taulamepss+=("$line")
done <<< "$grid_output"

hostname
echo "TAU      LAM      EPS"
for item in "${taulamepss[@]}"; do
    # Quoted: print the line verbatim (no word splitting or globbing).
    printf '%s\n' "$item"
done

# Model identifiers, passed verbatim to the Python entry points (presumably
# Hugging Face Hub repo ids).
teacher="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"   # teacher= of the HOLDOUT stage
proxy_student="Qwen/Qwen2.5-3B"                     # --proxy_student of STUDENT GRAD
# NOTE(review): `student` is not referenced by any stage command in this
# script; distill.py presumably selects the student model itself — confirm.
student="meta-llama/Llama-3.2-3B"

# PyTorch CUDA caching-allocator setting: use expandable segments, intended to
# reduce memory fragmentation (see the PyTorch CUDA memory-management docs).
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

#######################################
# Run one pipeline stage unless its sentinel already exists. Re-running the
# script therefore skips every stage whose output is already on disk.
# Globals:   RED, YELLOW, CYAN, GREEN, RESET (read; ANSI colour codes)
# Arguments: $1 - stage label used in log lines
#            $2 - sentinel path; any existing file or directory there means
#                 "already done"
#            $3 - command line to execute, passed through eval (expected to
#                 be a single command)
# Outputs:   progress lines on stdout, a failure line on stderr
# Returns:   0 when skipped or when the command succeeds, otherwise the
#            command's exit status. Called as a plain statement under set -e,
#            a non-zero return aborts the script.
#######################################
run_stage() {
    local stage="$1"
    local sentinel="$2"
    local cmd="$3"

    if [[ -e "${sentinel}" ]]; then
        echo -e "${YELLOW}⏭️ Skipping ${stage}: ${sentinel} already exists.${RESET}"
        return 0
    fi

    # One-line rendering of the (backslash-continued, indented) command for
    # the log. Declared separately so `local` does not mask the status of the
    # substitution.
    local clean_cmd
    clean_cmd=$(tr '\n' ' ' <<< "$cmd" | sed 's/  */ /g')
    echo -e "${CYAN}🌀 ${stage}:\n> ${clean_cmd}${RESET}"

    # Check the status explicitly. When run_stage is called from a
    # conditional context set -e is suspended, and a failed command would
    # otherwise still be reported as completed.
    local status=0
    eval "$cmd" || status=$?
    if (( status != 0 )); then
        echo -e "${RED}❌ ${stage} failed (exit ${status}).${RESET}" >&2
        return "$status"
    fi
    echo -e "${GREEN}✅ ${stage} completed.${RESET}"
    return 0
}

# --- HOLDOUT: teacher traces on the hold-out split ---------------------------
# traces/holdout doubles as this stage's completion sentinel and as the
# hold-out trace location handed to the later stages.
trace_name="holdout"
stage="HOLDOUT"
holdout_sentinel="${exp_dir}/traces/${trace_name}"
holdout_args=(
    gentraces.py
    "hydra.run.dir=${exp_dir}/metadata/holdout"
    "teacher=${teacher}"
    "exp_dir=${exp_dir}"
    "seed=${seed}"
    "data_split=${dataset}_holdout"
    "trace_name=${trace_name}"
)
# Joined with single spaces into the string run_stage evals.
cmd="$PY ${holdout_args[*]}"
run_stage "$stage" "$holdout_sentinel" "$cmd"

# --- STUDENT GRAD: proxy-student gradients from the hold-out traces ----------
# The gradient file itself is the sentinel: once it exists the stage is done.
stage="STUDENT GRAD"
grad_path="${exp_dir}/student_grads.pt"
grad_sentinel="$grad_path"
printf -v cmd '%s save_grad.py %s.yaml --proxy_student=%s' \
    "$PY" "$holdout_sentinel" "$proxy_student"
run_stage "$stage" "$grad_sentinel" "$cmd"

# The AD-sampling batch size is the same for every grid point, so resolve it
# once. The effective value has always been a fixed 48. A lam-dependent value
# (512 when lam == 0.0e+00, else 192) used to be computed per iteration but
# never reached the command line. AD_BATCH_SIZE overrides the default.
ad_batch_size="${AD_BATCH_SIZE:-48}"

# For each (tau, lam, eps) triple from grid.py, run three resumable stages:
#   AD SAMPLING   train-split traces (use_jvp=true, with the student gradients)
#   DISTILLATION  on those traces, with the hold-out traces alongside
#   EVAL          the distilled model generates test-split traces (W&B on)
for taulameps in "${taulamepss[@]}"; do
    read -r tau lam eps <<< "$taulameps"
    # Suffix naming this grid point's traces, model and metadata dirs.
    trace_name="tau${tau}_lam${lam}_eps${eps}"

    stage="AD SAMPLING TAU=${tau}, LAM=${lam}, EPS=${eps}"
    ad_sentinel="${exp_dir}/traces/${trace_name}"
    cmd="$PY \
        gentraces.py \
        use_jvp=true \
        hydra.run.dir=${exp_dir}/metadata/train/${trace_name} \
        exp_dir=${exp_dir} \
        seed=${seed} \
        data_split=${dataset}_train \
        grad_path=${grad_path} \
        batch_size=${ad_batch_size} \
        tau=${tau} \
        lam=${lam} \
        eps=${eps} \
        trace_name=${trace_name}"
    run_stage "$stage" "$ad_sentinel" "$cmd"

    stage="DISTILLATION TAU=${tau}, LAM=${lam}, EPS=${eps}"
    model_name="llama-3b-base_${trace_name}"
    model_path="${exp_dir}/models/${model_name}"
    distill_sentinel="${model_path}/final"
    # NOTE(review): `student` is not passed here; distill.py presumably picks
    # the student model from its own config — confirm.
    cmd="$PY \
        distill.py \
        hydra.run.dir=${exp_dir}/metadata/distill/${model_name} \
        exp_dir=${exp_dir} \
        train_traces=${ad_sentinel} \
        holdout_traces=${holdout_sentinel} \
        model_name=${model_name}"
    run_stage "$stage" "$distill_sentinel" "$cmd"

    stage="EVAL TAU=${tau}, LAM=${lam}, EPS=${eps}"
    eval_traces="eval_${trace_name}"
    eval_sentinel="${exp_dir}/traces/${eval_traces}"
    # The distilled checkpoint fills gentraces.py's teacher slot, so the
    # traces generated here come from the distilled model.
    cmd="$PY \
        gentraces.py \
        hydra.run.dir=${exp_dir}/metadata/eval/${trace_name} \
        teacher=${distill_sentinel} \
        teacher_cfg=${model_path}.yaml \
        use_wandb=true \
        exp_dir=${exp_dir} \
        seed=${seed} \
        data_split=${dataset}_test \
        trace_name=${eval_traces}"
    run_stage "$stage" "$eval_sentinel" "$cmd"
done

# Report total wall-clock time from bash's SECONDS counter. The colour codes
# go through %b so they are not part of the format string.
duration=$SECONDS
printf "%b\n🎯 All processes completed in %02dh:%02dm:%02ds%b\n" \
  "$WHITE" \
  "$(( duration / 3600 ))" \
  "$(( duration % 3600 / 60 ))" \
  "$(( duration % 60 ))" \
  "$RESET"
