# Copyright (C) king.com Ltd 2025
# License: Apache 2.0

# ---- Experiment configuration ----

# PPO data collection: online interaction budget and the seed used for
# every PPO run.
ppo_online_interactions=100000
ppo_seed=0

# PDT training: evaluation environment, number of training iterations,
# model type, training seeds and batch size.
dt_eval_env_name="CircleStopEnv-randomAngle_randomRad-v0"
dt_train_epochs=501
dt_model="traj_pdt"
dt_seeds=(0)
dt_batchsize=8

# PPO training environments: every combination of 20 angles
# (0.0, 0.1, ..., 1.9) and 3 radii (0.9, 1.9, 2.9) -> 60 envs.
# Ordering is angle-major (all radii for angle 0.0, then angle 0.1, ...).
# PPO runs, and therefore the dataset dirs handed to PDT, follow this
# order, so keep it stable.
# `{0..1}.{0..9}` expands to 0.0 0.1 ... 0.9 1.0 ... 1.9.
ppo_env_names=()
for angle in {0..1}.{0..9}; do
  for radius in 0.9 1.9 2.9; do
    ppo_env_names+=("CircleStopEnv-angle${angle}_radius${radius}-v0")
  done
done
unset angle radius

# ---- Stage 1: PPO data collection ----
# Train one PPO agent per environment in ppo_env_names and record, in the
# same order, the path each run prints as its final stdout line. These
# paths are later passed to train_pdt.py via --dataset_dirs.
# NOTE(review): an earlier note described that line as the path to the .gz
# trajectory, while the log text calls it the exp dir; confirm which.
ppo_exp_dirs=()
for env_name in "${ppo_env_names[@]}"; do
  echo "============================================="
  echo "Running PPO on $env_name"
  # Mirror the full training log to the terminal (tee /dev/tty) while
  # capturing only the last line. Without the explicit exit, the
  # substitution would report tail's status and a crashed training run would
  # go unnoticed. The subshell therefore exits with python's own status
  # (PIPESTATUS[0]). A failure of tee itself (e.g. no controlling terminal)
  # is deliberately not fatal.
  if ! ppo_exp_dir=$(
    python3 \
      train_ppo.py \
      --env_id "$env_name" \
      --total_timesteps "$ppo_online_interactions" \
      --seed "$ppo_seed" \
      --exp_str "seed$ppo_seed" \
      | tee /dev/tty | tail -n 1
    exit "${PIPESTATUS[0]}"
  ); then
    # Abort: a missing run would silently shift the positions of all later
    # dataset dirs handed to the PDT stage.
    echo "PPO training failed on $env_name; aborting." >&2
    exit 1
  fi
  if [[ -z "$ppo_exp_dir" ]]; then
    echo "PPO on $env_name printed no exp dir; aborting." >&2
    exit 1
  fi
  echo "PPO training finished, exp dir: $ppo_exp_dir"

  ppo_exp_dirs+=("$ppo_exp_dir")
done # env loop

echo "============================================="
echo "Training PDT on data collected by PPO"

# Indices passed to --train_dataset_subset_idxs: 0..48, i.e. the first 49 of
# the PPO runs. NOTE(review): presumably the remaining runs are held out on
# purpose; confirm the split.
pdt_train_subset_idxs=({0..48})
# Fixed PDT hyper-parameters shared by every run in the sweep.
pdt_context_len=5
pdt_prompt_h=3
# Values of --traj_prompt_j swept for every seed.
pdt_prompt_js=(1 2 4)

# Labels ("seed=S j=J") of runs whose train_pdt.py exited non-zero.
failed_pdt_runs=()

#######################################
# Train one PDT model on the collected PPO datasets.
# Globals (read): dt_eval_env_name, ppo_exp_dirs, pdt_train_subset_idxs,
#   dt_train_epochs, dt_model, dt_batchsize, pdt_context_len, pdt_prompt_h
# Arguments: $1 - training seed; $2 - trajectory prompt length (j)
# Outputs: streams train_pdt.py output to /dev/tty and logs the last line it
#   prints (reported as the exp dir) to stdout.
# Returns: 0 on success, 1 if train_pdt.py exited non-zero.
#######################################
run_pdt_training() {
  local seed=$1
  local prompt_j=$2
  local label="seed=$seed, j=$prompt_j, h=$pdt_prompt_h, ctx=$pdt_context_len"
  local dt_exp_dir

  echo "Training PDT with $label"
  # As in the PPO stage: exit with python's status rather than tail's, so
  # failures are detected, while a tee /dev/tty failure stays non-fatal.
  if ! dt_exp_dir=$(
    python3 \
      train_pdt.py \
      --seed "$seed" \
      --env "$dt_eval_env_name" \
      --dataset_dirs "${ppo_exp_dirs[@]}" \
      --train_dataset_subset_idxs "${pdt_train_subset_idxs[@]}" \
      --use_state_dims 0 1 \
      --pdt_use_sparse_reward \
      --max_train_iters "$dt_train_epochs" \
      --pred_mlp_dropout_p "0.1" \
      --norm_obs \
      --exp_str "ctx${pdt_context_len}_j${prompt_j}_h${pdt_prompt_h}" \
      --transformer_dropout_p "0.1" \
      --model "$dt_model" \
      --context_len "$pdt_context_len" \
      --traj_prompt_j "$prompt_j" \
      --traj_prompt_h "$pdt_prompt_h" \
      --rtg_targets 10 -10 \
      --rtg_scale 1 \
      --batch_size "$dt_batchsize" \
      --dataset_every_nth_traj 2 \
      | tee /dev/tty | tail -n 1
    exit "${PIPESTATUS[0]}"
  ); then
    echo "PDT training failed ($label)" >&2
    echo "============================================="
    return 1
  fi
  echo "DT training finished, exp dir: $dt_exp_dir"
  echo "============================================="
}

# Runs are independent, so a failed run is recorded and the sweep continues.
for seed in "${dt_seeds[@]}"; do
  for prompt_j in "${pdt_prompt_js[@]}"; do
    run_pdt_training "$seed" "$prompt_j" \
      || failed_pdt_runs+=("seed=$seed j=$prompt_j")
  done
done  # seed loop

if (( ${#failed_pdt_runs[@]} > 0 )); then
  echo "The following PDT runs failed:" >&2
  printf '  %s\n' "${failed_pdt_runs[@]}" >&2
  exit 1
fi


