#!/bin/bash
#SBATCH --output=outputs/%x/%a-%N-%j.out
#SBATCH --error=outputs/%x/%a-%N-%j.err
#SBATCH --partition=gpuA40x4
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --constraint=scratch
#SBATCH --cpus-per-task=16
#SBATCH --gpus-per-node=1
#SBATCH --gpu-bind=closest
#SBATCH --account=bdkj-delta-gpu
#SBATCH --no-requeue
#SBATCH --time=10:00:00
#SBATCH --mem=200G

# Runs one generation job: activates the `rlhf` conda env, enters ./code, and
# calls generate.py with key=value overrides built from the arguments.
#
# Usage:   sbatch --array=<ids> <this-script> <task> <policy>
#   $1 (task)   -> passed to generate.py as task=<task>
#   $2 (policy) -> passed to generate.py as policy=<policy>
# The Slurm array task ID is passed as sampling.seed, so each array element
# runs with a different seed.
#
# The script exits with generate.py's status, so a failed run is reported to
# Slurm as a failure instead of being masked by the final echo.

# Read the arguments before sourcing anything: a sourced file sees (and can
# modify) this script's positional parameters.
if (( $# < 2 )); then
  echo "usage: sbatch --array=<ids> ${0##*/} <task> <policy>" >&2
  exit 2
fi
task=$1
policy=$2

if ! source /u/audreyh/workspace/setup_conda.sh; then
  echo "error: failed to source /u/audreyh/workspace/setup_conda.sh" >&2
  exit 1
fi
if ! source activate rlhf; then
  echo "error: failed to activate conda env 'rlhf'" >&2
  exit 1
fi
export XDG_CACHE_HOME="/work/hdd/bdkj/audreyh/.cache"
export OUTLINES_CACHE_DIR="/u/audreyh/workspace/test-code"

# Relative path: resolved against the job's working directory (by default the
# directory sbatch was invoked from).
cd code || { echo "error: cannot cd into 'code' from $(pwd)" >&2; exit 1; }

# Outside an array job SLURM_ARRAY_TASK_ID is unset; keep the previous
# behaviour (seed 0) but say so, since it may collide with array task 0.
if [[ -z "${SLURM_ARRAY_TASK_ID:-}" ]]; then
  echo "warning: SLURM_ARRAY_TASK_ID is unset (not an array job?); using seed 0" >&2
fi
seed=$(( ${SLURM_ARRAY_TASK_ID:-0} ))

python generate.py "task=${task}" "policy=${policy}" "sampling.seed=${seed}"
status=$?
if (( status != 0 )); then
  echo "generate.py exited with status ${status}" >&2
  exit "$status"
fi

echo "All processes have finished"
