#!/bin/bash
#
# Benchmark driver. For each of RUNS repetitions it runs, in order:
#   train/train_rl_encoder.py      (encoder training)
#   train/compute_encodings.py     (latent stats)
#   train/train_action_sampler.py  (action sampler)
#   train/train_dynamics.py        (dynamics fine-tuning)
#   train/train_policy_low.py      (low-level policy fine-tuning)
#   train/train_policy_high.py     (high-level policy fine-tuning)
# Each stage gets --exp_name bench/<ENV>_<NAME>_<run index>.
#
# The train/*.py paths are relative, so launch this script from the directory
# that contains train/ (presumably the repository root — confirm).
#
# Exit status: 0 if every run completed all stages, 1 if any run failed.

set -uo pipefail

readonly NAME="0"
readonly ENV="spiral_env"
readonly DATASET="spiral_env"
readonly DATASET_EVAL="spiral_env_eval"
readonly Z_DIM=32
readonly N_STEP_PREDICTION=3

readonly W_RL_LOW=1.0
readonly W_RL_HIGH=0.2
readonly W_DYN=0.001

readonly W_BC_HIGH=0.001
readonly W_BC_EXP_HIGH=0.5

readonly W_BC_LOW=0.001
readonly W_BC_EXP_LOW=0.0
readonly N_QS_LOW=3

readonly LOAD_PRETRAINED_LOW=1
readonly LOAD_PRETRAINED_HIGH=1
readonly LOAD_PRETRAINED_DYN=1

readonly BATCH_SIZE=64
readonly ENC_ITERS=50000
readonly POLICY_LOW_ITERS=25000
readonly POLICY_HIGH_ITERS=25000
readonly DYN_ITERS=25000
readonly ACTION_SAMPLER_ITERS=150000
# NOTE(review): BC_ITERS and IRIS_ITERS are not used by any stage below;
# kept in case other tooling expects them — remove if truly dead.
# shellcheck disable=SC2034
readonly BC_ITERS=150000
# shellcheck disable=SC2034
readonly IRIS_ITERS=150000

readonly RUNS=5

#######################################
# Runs every training stage for one experiment, in order.
# Later stages read what earlier ones wrote under the same --exp_name, so the
# function stops at the first stage that fails instead of running the rest.
# Arguments: $1 - experiment name passed as --exp_name to every stage
# Returns:   0 if all stages succeed, otherwise the failing stage's status
#######################################
run_pipeline() {
  local exp_name=$1

  # train encoder
  python3 train/train_rl_encoder.py \
    --exp_name "$exp_name" \
    --n_iters "$ENC_ITERS" \
    --env "$ENV" \
    --dataset_train "$DATASET" \
    --dataset_eval "$DATASET_EVAL" \
    --z_dim "$Z_DIM" \
    --w_rl_low "$W_RL_LOW" \
    --w_rl_high "$W_RL_HIGH" \
    --w_bc_high "$W_BC_HIGH" \
    --w_bc_exp_high "$W_BC_EXP_HIGH" \
    --w_bc_low "$W_BC_LOW" \
    --w_bc_exp_low "$W_BC_EXP_LOW" \
    --w_dyn "$W_DYN" \
    --T_contr 1.0 \
    --n_step_prediction "$N_STEP_PREDICTION" \
    --discount 0.96 \
    --lr_policy 3e-4 \
    --lr_critic 3e-4 \
    --batch_size "$BATCH_SIZE" || return

  # compute latent stats
  python3 train/compute_encodings.py \
    --exp_name "$exp_name" \
    --plot 0 || return

  # train action sampler
  python3 train/train_action_sampler.py \
    --exp_name "$exp_name" \
    --n_iters "$ACTION_SAMPLER_ITERS" \
    --vae_latent_dim 16 \
    --vae_beta 0.01 || return

  # finetune dynamics
  python3 train/train_dynamics.py \
    --exp_name "$exp_name" \
    --n_iters "$DYN_ITERS" \
    --load_pretrained "$LOAD_PRETRAINED_DYN" || return

  # finetune policy low
  python3 train/train_policy_low.py \
    --n_iters "$POLICY_LOW_ITERS" \
    --w_bc "$W_BC_LOW" \
    --w_bc_exp "$W_BC_EXP_LOW" \
    --n_qs "$N_QS_LOW" \
    --lr_policy 1e-4 \
    --lr_critic 1e-4 \
    --exp_name "$exp_name" \
    --load_pretrained "$LOAD_PRETRAINED_LOW" || return

  # finetune policy high
  python3 train/train_policy_high.py \
    --n_iters "$POLICY_HIGH_ITERS" \
    --w_bc "$W_BC_HIGH" \
    --w_bc_exp "$W_BC_EXP_HIGH" \
    --exp_name "$exp_name" \
    --lr_policy 1e-4 \
    --lr_critic 1e-4 \
    --load_pretrained "$LOAD_PRETRAINED_HIGH" || return
}

failed=0
for ((i = 0; i < RUNS; i++)); do
  expname="bench/${ENV}_${NAME}_${i}"

  printf 'launching experiment %s\n' "$expname"

  # A failed run skips its own remaining stages but does not stop later runs.
  if ! run_pipeline "$expname"; then
    printf 'experiment %s failed; remaining stages skipped\n' "$expname" >&2
    failed=$((failed + 1))
  fi
done

if ((failed > 0)); then
  printf '%d of %d experiments failed\n' "$failed" "$RUNS" >&2
  exit 1
fi
