#!/bin/bash

#SBATCH ...

# Strict mode: stop if any command fails (-e), on use of an unset variable (-u),
# and when any stage of a pipeline fails rather than only the last (pipefail).
set -euo pipefail

# ==========================================
# 0. Configuration & Array Mapping
# ==========================================

# The sweep is defined by four whitespace-separated lists read from the
# environment (presumably exported by the submitting script — confirm):
#   STR_MODEL_CONFIGS     entries "PREFIX:N_LAYER:N_EMBD:N_HEAD:LR:BASE_TOKENS"
#   STR_MULTIPLIERS       multipliers applied to BASE_TOKENS
#   STR_QUANT_SETUPS      entries "GROUP_DIM:SCALE_DTYPE:UNBIASED:SCALE_OVERRIDE"
#   STR_BACKWARD_SCHEMES  values passed to --backward-scheme
# Fail fast with a readable message instead of set -u's bare "unbound variable".
: "${STR_MODEL_CONFIGS:?must be set to a space-separated list of model configs}"
: "${STR_MULTIPLIERS:?must be set to a space-separated list of token multipliers}"
: "${STR_QUANT_SETUPS:?must be set to a space-separated list of quant setups}"
: "${STR_BACKWARD_SCHEMES:?must be set to a space-separated list of backward schemes}"

# Split each list on whitespace into an array. Pathname expansion is disabled
# while splitting so an entry containing *, ? or [ is never replaced by file
# names that happen to match in the current directory.
set -f
# shellcheck disable=SC2206  # word-splitting is intentional; globbing is off
MODEL_CONFIGS=($STR_MODEL_CONFIGS)
# shellcheck disable=SC2206
MULTIPLIERS=($STR_MULTIPLIERS)
# shellcheck disable=SC2206
QUANT_SETUPS=($STR_QUANT_SETUPS)
# shellcheck disable=SC2206
BACKWARD_SCHEMES=($STR_BACKWARD_SCHEMES)
set +f

N_MODELS=${#MODEL_CONFIGS[@]}
N_MULTS=${#MULTIPLIERS[@]}
N_QUANTS=${#QUANT_SETUPS[@]}
N_BACKWARD_SCHEMES=${#BACKWARD_SCHEMES[@]}
N_TOTAL=$(( N_MODELS * N_MULTS * N_QUANTS * N_BACKWARD_SCHEMES ))

# A list consisting only of blanks passes the :? checks above but yields zero
# entries, which would otherwise surface as a division by zero below.
if (( N_TOTAL == 0 )); then
  echo "ERROR: every sweep list needs at least one entry (models=${N_MODELS}, multipliers=${N_MULTS}, quant setups=${N_QUANTS}, backward schemes=${N_BACKWARD_SCHEMES})" >&2
  exit 1
fi

# Valid task IDs are 0..N_TOTAL-1. Reject anything else up front instead of
# failing later on an out-of-range array lookup.
: "${SLURM_ARRAY_TASK_ID:?must be set; submit as a job array with --array=0-$(( N_TOTAL - 1 ))}"
if [[ ! "$SLURM_ARRAY_TASK_ID" =~ ^[0-9]+$ ]] || (( SLURM_ARRAY_TASK_ID >= N_TOTAL )); then
  echo "ERROR: SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID} is outside 0..$(( N_TOTAL - 1 )) (${N_TOTAL} combinations)" >&2
  exit 1
fi

# Decode SLURM_ARRAY_TASK_ID as a mixed-radix number, one digit per sweep
# dimension. The backward scheme varies fastest, then the quant setup, then the
# multiplier; the model varies slowest.
idx_backward_scheme=$(( SLURM_ARRAY_TASK_ID % N_BACKWARD_SCHEMES ))
idx_quant=$(( (SLURM_ARRAY_TASK_ID / N_BACKWARD_SCHEMES) % N_QUANTS ))
idx_mult=$(( (SLURM_ARRAY_TASK_ID / (N_BACKWARD_SCHEMES * N_QUANTS)) % N_MULTS ))
idx_model=$(( SLURM_ARRAY_TASK_ID / (N_BACKWARD_SCHEMES * N_QUANTS * N_MULTS) ))

# Select this task's entry from each dimension.
CURRENT_MODEL_CFG="${MODEL_CONFIGS[$idx_model]}"
CURRENT_MULT="${MULTIPLIERS[$idx_mult]}"
CURRENT_SETUP="${QUANT_SETUPS[$idx_quant]}"
CURRENT_BACKWARD_SCHEME="${BACKWARD_SCHEMES[$idx_backward_scheme]}"

# Parse the colon-separated model config.
IFS=":" read -r MODEL_SIZE_PREFIX N_LAYER N_EMBD N_HEAD LR BASE_TOKENS <<< "$CURRENT_MODEL_CFG"

# Token budget = BASE_TOKENS * CURRENT_MULT. Python is used because Bash
# arithmetic is integer-only. Both values are interpolated into Python source,
# so they are evaluated as Python expressions; int() truncates toward zero.
TOKENS=$(python3 -c "print(int($BASE_TOKENS * $CURRENT_MULT))")

# Parse the colon-separated quantization setup.
IFS=":" read -r GROUP_DIM SCALE_DTYPE UNBIASED SCALE_OVERRIDE <<< "$CURRENT_SETUP"

# ==========================================
# 1. Static Environment Setup
# ==========================================

# Job banner: wall-clock start, host, array index and the resolved sweep point,
# one item per line.
printf '%s\n' \
  "START TIME: $(date)" \
  "Running on host: $(hostname)" \
  "Job Array ID: ${SLURM_ARRAY_TASK_ID}" \
  "Config: ${MODEL_SIZE_PREFIX} | Multiplier: ${CURRENT_MULT} | Tokens: ${TOKENS} | Backward Scheme: ${CURRENT_BACKWARD_SCHEME}" \
  "Quant: Dim=${GROUP_DIM}, Scale=${SCALE_DTYPE}, Unbiased=${UNBIASED}, Scale Override=${SCALE_OVERRIDE}"

# Fixed training settings shared by every task of the array.
export VOCAB_SIZE=32000 \
       BATCH_SIZE=128 \
       ACC_STEPS=4 \
       SEQUENCE_LENGTH=512 \
       DATASET="c4"

# Runtime environment for TorchInductor and Weights & Biases.
export TORCHINDUCTOR_AUTOGRAD_CACHE=0 \
       WANDB_ENTITY=ist

# Move into the project checkout, make sure the schedulefree package is
# installed (this runs on every array task), and point at the dataset buffer.
cd ...
pip install schedulefree
export DATASET_BUFFER="..."

# ==========================================
# 2. Quantization Configuration
# ==========================================

# W_* and A_* quantizers (presumably weights and activations): both NoQuantizer
# at 16 bits with empty kwargs.
export W_QUANT="NoQuantizer" W_BITS=16 W_QUANT_KWARGS="{}"
export A_QUANT="NoQuantizer" A_BITS=16 A_QUANT_KWARGS="{}"

# Gradient quantizer: EdenSRQuantizer at 4 bits. Its kwargs object is filled
# from this task's quant setup (GROUP_DIM, SCALE_DTYPE, UNBIASED, SCALE_OVERRIDE).
# NOTE(review): "unbiased" is emitted as a quoted JSON string while "group_dim"
# and "scale_override" are emitted bare. Confirm the consumer interprets it as
# intended (a non-empty string such as "False" is truthy in Python).
export G_QUANT="EdenSRQuantizer" G_BITS=4
printf -v G_QUANT_KWARGS \
  '{"hadamard_dim": 128, "group_dim": %s, "rerotate": "signs", "scale_dtype": "%s", "unbiased": "%s", "scale_override": %s}' \
  "${GROUP_DIM}" "${SCALE_DTYPE}" "${UNBIASED}" "${SCALE_OVERRIDE}"
export G_QUANT_KWARGS

# Backward-pass scheme selected for this array task; it takes no extra kwargs.
BACKWARD_SCHEME="${CURRENT_BACKWARD_SCHEME}"
BACKWARD_SCHEME_KWARGS="{}"
export BACKWARD_SCHEME BACKWARD_SCHEME_KWARGS

# ==========================================
# 3. Calculation & Execution
# ==========================================

# Tokens per iteration = BATCH_SIZE * ACC_STEPS * SEQUENCE_LENGTH. The integer
# division truncates, so the run may cover slightly fewer than TOKENS tokens.
TOKENS_PER_ITER=$(( BATCH_SIZE * ACC_STEPS * SEQUENCE_LENGTH ))
export ITERATIONS=$(( TOKENS / TOKENS_PER_ITER ))

# A budget smaller than one iteration would launch a run with --iterations 0.
# Stop here instead of spending the allocation on it.
if (( ITERATIONS < 1 )); then
  echo "ERROR: TOKENS=${TOKENS} is less than one iteration (${TOKENS_PER_ITER} tokens); refusing to launch." >&2
  exit 1
fi

# Warm up for the first 10% of iterations (rounded down).
export WARMUP_STEPS=$(( ITERATIONS / 10 ))

# WandB run prefix. It is built from the model prefix, token budget, gradient
# quantizer and bits, quant setup, an optional backward-scheme suffix, and the
# dataset. The scheme "Q(E)Q(Wt)t_Q(Et)Q(Xt)t" gets no suffix (presumably the
# baseline, kept short for naming continuity — confirm). Every other scheme is
# appended as "-<scheme>".
SETUP_STR="${GROUP_DIM};${SCALE_DTYPE};${UNBIASED};${SCALE_OVERRIDE}"
if [[ "$BACKWARD_SCHEME" != "Q(E)Q(Wt)t_Q(Et)Q(Xt)t" ]]; then
  BACKWARD_SCHEME_STR="-${BACKWARD_SCHEME}"
else
  BACKWARD_SCHEME_STR=""
fi
WANDB_PREFIX="${MODEL_SIZE_PREFIX}-TOK${TOKENS}-${G_QUANT}-triton@${G_BITS}@${SETUP_STR}${BACKWARD_SCHEME_STR}-${DATASET}"

echo "Launching torchrun..."

# Build the command as an array with every expansion quoted. Each value then
# reaches main.py as exactly one argument, with no word-splitting or pathname
# expansion (e.g. a DATASET_BUFFER path containing spaces, or a scheme name
# containing glob characters). Set NPROC_PER_NODE in the environment to
# override the default of 4 processes per node.
torchrun_cmd=(
  torchrun --nproc_per_node="${NPROC_PER_NODE:-4}" ./src/main.py
  --distributed-backend nccl
  --dataset "${DATASET}"
  --datasets-dir "${DATASET_BUFFER}"
  --latest-ckpt-interval 1000
  --model llama
  --vocab-size "${VOCAB_SIZE}"
  --compile
  --acc-steps "${ACC_STEPS}"
  --batch-size "${BATCH_SIZE}"
  --wandb
  --wandb-project "backprop-scaling-laws"
  --wandb-run-prefix "${WANDB_PREFIX}"
  --log-interval 1
  --n-layer "${N_LAYER}"
  --n-embd "${N_EMBD}"
  --n-head "${N_HEAD}"
  --warmup-steps "${WARMUP_STEPS}"
  --iterations "${ITERATIONS}"
  --lr "${LR}"
  --w-quant "${W_QUANT}"
  --w-quant-kwargs "${W_QUANT_KWARGS}"
  --a-quant "${A_QUANT}"
  --a-quant-kwargs "${A_QUANT_KWARGS}"
  --g-quant "${G_QUANT}"
  --g-quant-kwargs "${G_QUANT_KWARGS}"
  --backward-scheme "${BACKWARD_SCHEME}"
  --backward-scheme-kwargs "${BACKWARD_SCHEME_KWARGS}"
)

# Log the exact command, shell-quoted, so a failed task can be re-run by hand.
printf 'Command:'
printf ' %q' "${torchrun_cmd[@]}"
printf '\n'

"${torchrun_cmd[@]}"

echo "END TIME: $(date)"
