#!/bin/bash
#
# Fine-tune a model with src/train_id.py, then evaluate the resulting adapter
# with src/eval_model.py on the hexphi dataset. Written for SLURM: when
# SLURM_JOBID is set it is used to derive a per-job MASTER_PORT.
#
# Strict mode: abort on the first failing command, on use of an unset
# variable, and on a failure anywhere in a pipeline, so later steps (e.g.
# evaluation) never run after an earlier step has failed.
set -euo pipefail

# derive_master_port JOB_ID
# Prints 10000 + (last four decimal digits of JOB_ID), i.e. a port in
# [10000, 19999], so that concurrent jobs are unlikely to pick the same port.
# Returns 1 (with a message on stderr) if JOB_ID is not a non-negative integer.
derive_master_port() {
  local job_id=$1
  local last4

  if [[ ! "$job_id" =~ ^[0-9]+$ ]]; then
    printf 'derive_master_port: job id must be numeric, got %q\n' "$job_id" >&2
    return 1
  fi

  # Keep only the last four characters (what `tail -c 4` did) so arbitrarily
  # long ids cannot overflow shell arithmetic.
  last4=$job_id
  if (( ${#last4} > 4 )); then
    last4=${last4: -4}
  fi

  # 10# forces base 10: a suffix such as "0089" would otherwise be parsed as
  # an (invalid) octal literal.
  printf '%d\n' $(( 10000 + 10#$last4 ))
}

export PROJECT_CACHE=save_output
export WANDB_MODE=offline
# Outside SLURM, fall back to this shell's PID so the port is still defined.
# Assignment and export are split so a failure in derive_master_port aborts
# the script under `set -e` instead of being masked by `export`.
MASTER_PORT=$(derive_master_port "${SLURM_JOBID:-$$}")
export MASTER_PORT
export TORCH_DISTRIBUTED_DEBUG=OFF
export HYDRA_FULL_ERROR=1
export HF_DATASETS_OFFLINE=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


# ---- Run configuration ------------------------------------------------------

# What to train on, and which base model config to use.
dataset_name=saferpaca
model=llama3

# Optimisation schedule.
n_epochs=1
batch_size=32
gradient_accumulation_steps=1
grad_norm=1
lr=5e-5

# LoRA hyperparameters.
lora_rank=128
lora_alpha=256

# Handed to eval_model.py as --sparsity_ratio.
sparsity_ratio=0.0

# Handed to train_id.py as the indexfile=... override.
indexfile=src/dataset/copy_from_dense_log128_llama3.json

# Checkpoint tag handed to train_id.py as save_every=... (here "epoch_1").
save_every="epoch_${n_epochs}"

# ---- Derived names and paths ------------------------------------------------

# Experiment name: "<dataset>_<model>/idLora_rank_<r>_alpha_<a>_lr_<lr>_bs_<bs>".
exp_name="${dataset_name}_${model}/idLora"
exp_name+="_rank_${lora_rank}_alpha_${lora_alpha}"
exp_name+="_lr_${lr}_bs_${batch_size}"

# Adapter location handed to eval_model.py. NOTE(review): presumably the
# final-epoch checkpoint written by train_id.py; confirm that "epoch-N" here
# matches the "epoch_N" save_every naming above.
adapter_path="${PROJECT_CACHE}/${exp_name}/epoch-${n_epochs}"
results_path="${PROJECT_CACHE}/${dataset_name}_${model}"


# Launch training. Overrides are passed as key=value arguments, collected in
# an array with every expansion quoted. Quoting matters most for
# "datasets=[...]": unquoted, [...] is a shell bracket glob and would be
# replaced by any matching file name in the working directory.
train_overrides=(
  "model=${model}"
  "datasets=[${dataset_name}]"
  "exp_name=${exp_name}"
  "lr=${lr}"
  "save_every=${save_every}"
  "n_epochs=${n_epochs}"
  "batch_size=${batch_size}"
  "model.fsdp_policy_mp=bfloat16"
  "fsdp_port=${MASTER_PORT}"
  "optimizer=AdamW"
  "grad_norm_strategy=even"
  "max_grad_norm=${grad_norm}"
  "lora_rank=${lora_rank}"
  "lora_alpha=${lora_alpha}"
  "indexfile=${indexfile}"
  "gradient_accumulation_steps=${gradient_accumulation_steps}"
)

# Stop here if training fails so evaluation never runs against a missing or
# stale adapter; propagate training's exit status to the caller (e.g. SLURM).
python -u src/train_id.py "${train_overrides[@]}" || {
  train_status=$?
  printf 'train_id.py failed (exit %d); skipping evaluation.\n' "$train_status" >&2
  exit "$train_status"
}

# Evaluate the trained adapter on the hexphi dataset, with --results_path set
# to $results_path. The arguments are built as a quoted array so paths
# containing spaces or glob characters reach eval_model.py intact.
eval_args=(
  --model_name "$model"
  --adapter_path "$adapter_path"
  --datasets hexphi
  --results_path "$results_path"
  --sparsity_ratio "$sparsity_ratio"
  --verbose
)
python src/eval_model.py "${eval_args[@]}"

