#!/bin/bash
#SBATCH --job-name=HSC
#SBATCH --output=output_%A_%a.out
#SBATCH --error=error_%A_%a.out
#SBATCH --array=1-5
#SBATCH --ntasks=1
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=1
#SBATCH --mem-per-cpu=16G
#SBATCH --time=04:00:00

# --- Environment -----------------------------------------------------------
# Load the user's shell start-up file and switch to the "upfi" conda
# environment (presumably ~/.bashrc is what defines the `conda` shell
# function on this cluster -- confirm).
source ~/.bashrc
conda activate upfi

# --- Per-task seed ---------------------------------------------------------
# Each element of the job array uses its own array index as the seed; it is
# later passed to --seed and embedded in the output suffix.
seed="${SLURM_ARRAY_TASK_ID}"

# --- Experiment settings ---------------------------------------------------
# N, c and betamax are interpolated into the dataset file name given to
# train_cmd.py; reg is forwarded to the --train_*_reg options.
N=500
c=0.5
betamax=5.5
reg=vf

# --- Training schedule -----------------------------------------------------
train_iters=10000
teacherforcing=true

# Teacher-forcing iteration count handed to the --train_*_teacherforcing_its
# options: the full training length when teacher forcing is enabled,
# otherwise zero.
if [[ "$teacherforcing" == true ]]; then
  _teacherforcingits="$train_iters"
else
  _teacherforcingits=0
fi

# Run the training entry point for this seed.
#
# Reads (set earlier in this script): N, c, betamax, seed, reg, train_iters,
# _teacherforcingits, teacherforcing.
# Enables the score, PFI and UPFI training stages via --train_score,
# --train_pfi and --train_upfi. Hyperparameters for the "ode" and "tigon"
# models are supplied as well; whether they are used is decided inside
# train_cmd.py.
# Weights are written under ./weights and figures under ./plots, relative to
# the working directory of the job.
#
# On failure the job exits with train_cmd.py's exit status so that the
# follow-up stages do not run against missing or stale weights.

# Value for --sigma_anneal_iters: the literal string "None" when teacher
# forcing is on (presumably parsed by train_cmd.py as "no annealing" --
# confirm), otherwise 2000.
if [[ "$teacherforcing" == true ]]; then
  _sigma_anneal_iters=None
else
  _sigma_anneal_iters=2000
fi

# Arguments are collected in an array so every expansion stays a single,
# correctly quoted word.
_train_args=(
  --data "sim_HSC_N_${N}_T_10_c_${c}_beta_${betamax}.pkl"
  --suffix "default_c_${c}_seed_${seed}"
  --train_score --train_pfi --train_upfi

  # score model
  --train_score_iters 10000 --train_score_batch 256 --train_score_lr 1e-2
  --hidden_sizes_score 128 128 128

  # UPFI model
  --train_upfi_iters "$train_iters" --train_upfi_batch 256 --train_upfi_lr 3e-3
  --train_upfi_reg "$reg" --train_upfi_teacherforcing_its "$_teacherforcingits"
  --hidden_sizes_upfi 128 128 128 --hidden_sizes_upfi_g 128

  # PFI model
  --train_pfi_iters "$train_iters" --train_pfi_batch 256 --train_pfi_lr 3e-3
  --train_pfi_reg "$reg" --train_pfi_teacherforcing_its "$_teacherforcingits"
  --hidden_sizes_pfi 128 128 128

  # ODE model
  --train_ode_iters "$train_iters" --train_ode_batch 256 --train_ode_lr 3e-3
  --train_ode_teacherforcing_its "$_teacherforcingits"
  --hidden_sizes_ode 128 128 128

  # TIGON model (teacher forcing is always disabled here)
  --train_tigon_iters "$train_iters" --train_tigon_batch 256
  --train_tigon_reg "$reg" --train_tigon_lr 3e-3
  --train_tigon_teacherforcing_its 0
  --hidden_sizes_tigon 128 128 128

  # shared settings
  --reach 5.0 --reg_wfr 0.001 --alpha_wfr 1
  --score_logsigma_min -2 --score_logsigma_max 0
  --sigma_anneal_iters "$_sigma_anneal_iters"
  --D 0.25
  --outdir weights --figdir plots --seed "$seed"
)

python ../../src/train_cmd.py "${_train_args[@]}" || {
  _train_status=$?
  echo "train_cmd.py failed with exit status ${_train_status}; aborting job" >&2
  exit "$_train_status"
}

# Post-training stages: run_mult.sh followed by run_evals.sh, both looked up
# in the job's working directory.
#
# SLURM_ARRAY_TASK_ID is placed explicitly in each child's environment so the
# helper scripts see the same array index even if the variable is not
# exported in this shell.
#
# Each stage must succeed before the next one starts; on failure the job
# exits with the failing stage's status, so SLURM records the task as failed
# instead of running evaluation on incomplete results.
echo "Done running train_cmd, now starting run_mult"
SLURM_ARRAY_TASK_ID="$SLURM_ARRAY_TASK_ID" bash run_mult.sh || {
  _stage_status=$?
  echo "run_mult.sh failed with exit status ${_stage_status}; skipping eval" >&2
  exit "$_stage_status"
}

echo "Done running run_mult.sh, now starting eval"
SLURM_ARRAY_TASK_ID="$SLURM_ARRAY_TASK_ID" bash run_evals.sh || {
  _stage_status=$?
  echo "run_evals.sh failed with exit status ${_stage_status}" >&2
  exit "$_stage_status"
}
