#!/bin/bash
#
#SBATCH -N 1
#SBATCH -t 0-01:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err
#SBATCH -a 0-9

# AGAINST is mandatory: it is forwarded to the trainer as --attacker_against.
# It is checked before env.sh is sourced, so it has to come from the job's
# environment at submission time.
#
# The expansion is quoted inside [[ ]] so that a value containing whitespace
# or glob characters cannot turn the test into a syntax error (which would
# let the script fall through as if the variable were valid).
if [[ -z "${AGAINST:-}" ]]; then
    # Diagnostics belong on stderr (captured by the #SBATCH -e log file).
    echo "Missing AGAINST arg" >&2
    # 64 is EX_USAGE in sysexits.h: the job was invoked incorrectly.
    exit 64
fi

# Pull the project's environment setup into this shell. The path is relative,
# so the file is looked up in the job's current working directory.
# NOTE(review): env.sh is not visible here; presumably it activates the Python
# environment used by the trainer -- confirm.
. ./env.sh

# Hyperparameter defaults.
#
# Every variable below can be overridden by giving it a non-empty value before
# this point is reached; the default is applied only when the variable is
# unset or empty. `${VAR:=default}` is used instead of an unquoted
# `[ -z $VAR ]` test, which mis-handles values containing whitespace or glob
# characters. The trailing comment names the trainer option each value feeds.
: "${N_ROUNDS:=20}"            # --n_rounds
: "${EPS_EPISODES:=0.8}"       # --eps_episodes
: "${EPS_STEPS:=0.4}"          # --eps_steps
: "${VICTIM_ITERS:=20}"        # --victim_iters
: "${VICTIM_LR:=0.00003}"      # --victim_lr
: "${ATTACKER_ITERS:=20}"      # --attacker_iters
: "${ATTACKER_LR:=0.03}"       # --attacker_lr
: "${BUDGET:=3.0}"             # --max_poison_diff

# Convert the LOG_ROUND_REWARDS switch into the literal command-line flag.
#
#   unset or empty  -> ""                     (flag omitted from the command)
#   any other value -> "--log_round_rewards"
#
# Note that the test is "non-empty", not "truthy": LOG_ROUND_REWARDS=0 or
# LOG_ROUND_REWARDS=false still enables the flag.
# `${VAR:+word}` replaces the unquoted `[ -z $VAR ]` test, which breaks on
# values containing whitespace or glob characters.
LOG_ROUND_REWARDS=${LOG_ROUND_REWARDS:+--log_round_rewards}

# The array task index doubles as the random seed, so the tasks of one array
# submission (#SBATCH -a 0-9) run with seeds 0 through 9.
# NOTE(review): SLURM_ARRAY_TASK_ID is empty when the script is not run as an
# array task, in which case --seed receives an empty value.
SEED=$SLURM_ARRAY_TASK_ID

# Arguments are collected in arrays rather than flat strings so that each
# value reaches the trainer as exactly one argument, with no word-splitting
# or pathname expansion applied to it.

# Fixed experiment setting.
setting_params=(
    --env bandit
    --arch 1
    --variance 0.3
    --context_len 500
    --n_actions 5
    --n_epochs 500
    --epoch 400
)

# Adversarial-training options; all but --n_envs_eval come from the
# overridable variables defined earlier in this script.
adv_params=(
    --n_envs_eval 200
    --n_rounds "$N_ROUNDS"
    --eps_episodes "$EPS_EPISODES"
    --eps_steps "$EPS_STEPS"
    --victim_iters "$VICTIM_ITERS"
    --victim_lr "$VICTIM_LR"
    --attacker_iters "$ATTACKER_ITERS"
    --attacker_lr "$ATTACKER_LR"
    --max_poison_diff "$BUDGET"
)

# LOG_ROUND_REWARDS is either empty or the literal flag; the `:+` form makes
# it vanish entirely when empty instead of passing an empty argument.
# This stays the last command so the job's exit status is the trainer's.
python3 bandit_train_adv.py \
    "${setting_params[@]}" \
    "${adv_params[@]}" \
    --attacker_against "$AGAINST" \
    ${LOG_ROUND_REWARDS:+"$LOG_ROUND_REWARDS"} \
    --seed "$SEED"
