#!/bin/bash
#
#SBATCH -N 1
#SBATCH -t 0-02:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err
#SBATCH -a 0-5

# Load env.sh (resolved relative to the current working directory) into this
# shell, so anything it defines is visible to the rest of the script.
. ./env.sh

# Hyper-parameter defaults.
#
# Every variable below keeps a non-empty value it already has (from the
# caller's environment or from the sourced env.sh); otherwise it receives the
# default shown here. `${VAR:=default}` is used instead of `[ -z $VAR ]` so
# that values containing whitespace or glob characters are handled safely.

# Passed as --lr.
: "${LR:=1e-4}"

# SHUFFLE is a switch: any non-empty value is normalised to the literal
# --shuffle flag, while an unset/empty value becomes the empty string.
SHUFFLE=${SHUFFLE:+--shuffle}

# Passed as --epoch.
: "${EPOCH:=150}"

# Passed as --n_rounds.
: "${N_ROUNDS:=300}"

# Passed as --eps_episodes and --eps_steps.
: "${EPS_EPISODES:=0.8}"
: "${EPS_STEPS:=0.4}"

# Passed as --victim_iters and --victim_lr.
: "${VICTIM_ITERS:=30}"
: "${VICTIM_LR:=0.00003}"

# Passed as --attacker_iters and --attacker_lr.
: "${ATTACKER_ITERS:=20}"
: "${ATTACKER_LR:=0.03}"

# Passed as --max_poison_diff.
: "${BUDGET:=10.0}"

# Attacker variants. Each array task evaluates exactly one of them; the six
# entries line up with the `#SBATCH -a 0-5` range declared in the header.
ATTS=(dpt dpt_frozen npg ppo clean unifrand)

# Index of the variant handled by this task. When the script is not run as a
# SLURM array task the variable is unset, so fall back to the first entry
# explicitly rather than relying on an empty subscript.
ATT_I=${SLURM_ARRAY_TASK_ID:-0}

# Refuse indices that do not name an entry of ATTS (e.g. the array range was
# overridden at submission time); otherwise the attacker name would silently
# expand to the bare prefix "ppo-".
if ! [[ "$ATT_I" =~ ^[0-9]+$ ]] || (( 10#$ATT_I >= ${#ATTS[@]} )); then
    printf 'error: array task id "%s" is not a valid index into ATTS (0-%d)\n' \
        "$ATT_I" "$(( ${#ATTS[@]} - 1 ))" >&2
    exit 1
fi
# Normalise to plain decimal so a leading zero is never read as octal.
ATT_I=$(( 10#$ATT_I ))

# Arguments are collected in arrays so that every value reaches python3 as
# exactly one argument, without word splitting or pathname expansion.
setting_params=(
    --env darkroom
    --n_states 25
    # Contributes one word when SHUFFLE is non-empty and nothing otherwise.
    ${SHUFFLE:+"$SHUFFLE"}
    --n_epochs 300
    --epoch "$EPOCH"
    --lr "$LR"
)
adv_params=(
    --n_envs_eval 200
    --n_steps_eval 200
    --n_rounds "$N_ROUNDS"
    --eps_episodes "$EPS_EPISODES"
    --eps_steps "$EPS_STEPS"
    --victim_iters "$VICTIM_ITERS"
    --victim_lr "$VICTIM_LR"
    --attacker_iters "$ATTACKER_ITERS"
    --attacker_lr "$ATTACKER_LR"
    --max_poison_diff "$BUDGET"
)

# Last command of the script, so its exit status becomes the job's status.
python3 mdp_eval_online_against_one.py \
    "${setting_params[@]}" \
    "${adv_params[@]}" \
    --attacker_against "ppo-${ATTS[ATT_I]}" \
    --n_seeds 10
