#!/usr/bin/env bash
#
# Launch DQN training by running "$PYTHON_SCRIPT train" with the settings in
# the Configuration section below, which are passed as command-line flags.
#
# Usage: edit the values in the Configuration section, then run this script
# (it takes no arguments). On failure it exits with the Python process's
# exit code.
set -euo pipefail

# ============================================
# Configuration - Edit these values as needed
# ============================================

# Environment parameters
ALPHA=0.05
WORLD="random"          # beta, beta_mixture, or random (50/50 each rollout batch)
CONC=6.0                # used when conc randomization is off
CONC_MIN=0.1            # lower bound for conc randomization
CONC_MAX=11.0           # upper bound for conc randomization

# Training hyperparameters
EPISODES=960000
BUFFER_CAPACITY=3200000
BATCH_SIZE=512
TARGET_UPDATE_INTERVAL=1000
# Set TAU to empty string or comment out to use None (hard updates)
# Set TAU to a float value (e.g., 0.005) to use soft updates
TAU=""
MIN_BUFFER_SIZE=40000
NUM_ENVS=16
TRAIN_FREQ=16
ACTOR_UPDATE_INTERVAL=1000
EXPLORE_EPS_START=1.0
EXPLORE_EPS_END=0.02
EXPLORE_EPS_DECAY=0.99998

# File paths. SUFFIX is appended to every output name as "_<SUFFIX>"
# (e.g. "best_dqn_policy_dqn.pt"). Set SUFFIX="" to drop it entirely
# (e.g. "best_dqn_policy.pt") instead of leaving a dangling underscore.
SUFFIX="dqn"
SUFFIX_TAG="${SUFFIX:+_${SUFFIX}}"   # "_<SUFFIX>", or "" when SUFFIX is empty
CHECKPOINT_PATH="best_dqn_policy${SUFFIX_TAG}.pt"
TRAINING_PLOT="training_returns${SUFFIX_TAG}.png"
REJECTION_PLOT="rejection_curves${SUFFIX_TAG}.png"
EPS_PLOT="dqn_eps_on_logwealth${SUFFIX_TAG}.png"
MODAL_EPS_PLOT="modal_eps_grid${SUFFIX_TAG}.png"
MODAL_CONFIDENCE_PLOT="modal_confidence_grid${SUFFIX_TAG}.png"
MIN_VISITS=5
MODAL_GRID_TRIALS=5000
T_BIN_WIDTH=5

# Domain randomization parameters
# Set DOMAIN_RANDOMIZE to 1 to enable domain randomization, 0 to disable
DOMAIN_RANDOMIZE=1
N_RANGE_MIN=100
N_RANGE_MAX=350
M_RANGE_MIN=0.01
M_RANGE_MAX=0.99
DIFFICULTY_RANGE_MIN=0.70
DIFFICULTY_RANGE_MAX=1.30
MU_CLIP_MIN=0.01
MU_CLIP_MAX=0.99
LCAP=100.0

# Learning rate
LEARNING_RATE=3e-4

# Python entrypoint (wrapper around the modular dqn/ package)
PYTHON_SCRIPT="dqn_entrypoint.py"

# Greedy eval (Option 3) for checkpoint selection
EVAL_EPISODES=4000
EVAL_SEED=12345
EVAL_BATCH_SIZE=512

# ============================================
# Run training
# ============================================

#######################################
# Print a human-readable summary of the training configuration.
# Globals:   reads the configuration variables set at the top of the file;
#            TAU may be unset (it is then treated as empty / hard updates)
# Arguments: none
# Outputs:   writes the summary to stdout
#######################################
print_config() {
    echo "========================================="
    echo "DQN Training Configuration:"
    echo "========================================="
    echo "Environment:"
    echo "  alpha: $ALPHA"
    echo "  world: $WORLD"
    echo "  conc (base): $CONC"
    echo ""
    echo "Training:"
    echo "  Episodes: $EPISODES"
    echo "  Buffer capacity: $BUFFER_CAPACITY"
    echo "  Batch size: $BATCH_SIZE"
    echo "  Target update interval: $TARGET_UPDATE_INTERVAL"
    # ${TAU:-} keeps `set -u` from aborting when TAU has been commented out,
    # which the configuration comments explicitly allow.
    if [[ -z "${TAU:-}" ]]; then
        echo "  Tau (soft update): None (hard updates)"
    else
        echo "  Tau (soft update): $TAU"
    fi
    echo "  Min buffer size: $MIN_BUFFER_SIZE"
    echo "  Num envs: $NUM_ENVS"
    echo "  Train freq: $TRAIN_FREQ"
    echo "  Actor update interval: $ACTOR_UPDATE_INTERVAL"
    echo "  Explore eps start: $EXPLORE_EPS_START"
    echo "  Explore eps end: $EXPLORE_EPS_END"
    echo "  Explore eps decay: $EXPLORE_EPS_DECAY"
    echo "  Learning rate: $LEARNING_RATE"
    echo ""
    # The eval settings are forwarded to the entrypoint as well, so they belong
    # in the summary alongside everything else.
    echo "Greedy evaluation:"
    echo "  Eval episodes: $EVAL_EPISODES"
    echo "  Eval seed: $EVAL_SEED"
    echo "  Eval batch size: $EVAL_BATCH_SIZE"
    echo ""
    echo "Domain Randomization:"
    if [[ "$DOMAIN_RANDOMIZE" == "1" ]]; then
        echo "  Enabled: Yes"
        echo "  N range: [$N_RANGE_MIN, $N_RANGE_MAX]"
        echo "  m range: [$M_RANGE_MIN, $M_RANGE_MAX]"
        echo "  Difficulty range: [$DIFFICULTY_RANGE_MIN, $DIFFICULTY_RANGE_MAX]"
        echo "  μ clip range: [$MU_CLIP_MIN, $MU_CLIP_MAX]"
        echo "  Lambda cap: $LCAP"
        echo "  conc range: [$CONC_MIN, $CONC_MAX]"
    else
        echo "  Enabled: No"
        echo "  conc (fixed): $CONC"
    fi
    echo ""
    echo "Output files:"
    echo "  Checkpoint: $CHECKPOINT_PATH"
    echo "  Training plot: $TRAINING_PLOT"
    echo "  Rejection plot: $REJECTION_PLOT"
    echo "  Epsilon plot: $EPS_PLOT"
    echo "  Modal epsilon grid plot: $MODAL_EPS_PLOT"
    echo "  Modal confidence grid plot: $MODAL_CONFIDENCE_PLOT"
    echo "  Min visits for masking: $MIN_VISITS"
    echo "  Modal grid evaluation trials: $MODAL_GRID_TRIALS"
    echo "  Time bin width: $T_BIN_WIDTH"
    echo "========================================="
    echo ""
}

print_config

#######################################
# Assemble the argument vector for "$PYTHON_SCRIPT train".
# Globals:   reads the configuration variables; writes CMD_ARGS (array).
#            TAU may be unset (it is then treated as empty).
# Arguments: none
# Returns:   0
#######################################
build_cmd_args() {
    CMD_ARGS=(
        train
        --alpha "$ALPHA"
        --world "$WORLD"
        --conc "$CONC"
        --conc_range "$CONC_MIN" "$CONC_MAX"
        --episodes "$EPISODES"
        --buffer_capacity "$BUFFER_CAPACITY"
        --batch_size "$BATCH_SIZE"
        --target_update_interval "$TARGET_UPDATE_INTERVAL"
    )

    # Only pass --tau when TAU is set and non-empty; leaving it off is how the
    # config requests None (hard updates). ${TAU:-} keeps `set -u` from
    # aborting when TAU has been commented out, as the config comments allow.
    if [[ -n "${TAU:-}" ]]; then
        CMD_ARGS+=(--tau "$TAU")
    fi

    CMD_ARGS+=(
        --min_buffer_size "$MIN_BUFFER_SIZE"
        --num_envs "$NUM_ENVS"
        --train_freq "$TRAIN_FREQ"
        --actor_update_interval "$ACTOR_UPDATE_INTERVAL"
        --explore_eps_start "$EXPLORE_EPS_START"
        --explore_eps_end "$EXPLORE_EPS_END"
        --explore_eps_decay "$EXPLORE_EPS_DECAY"
        --lr "$LEARNING_RATE"
        --eval_episodes "$EVAL_EPISODES"
        --eval_seed "$EVAL_SEED"
        --eval_batch_size "$EVAL_BATCH_SIZE"
        --checkpoint_path "$CHECKPOINT_PATH"
        --training_plot "$TRAINING_PLOT"
        --rejection_plot "$REJECTION_PLOT"
        --eps_plot "$EPS_PLOT"
        --modal_eps_plot "$MODAL_EPS_PLOT"
        --modal_confidence_plot "$MODAL_CONFIDENCE_PLOT"
        --min_visits "$MIN_VISITS"
        --modal_grid_trials "$MODAL_GRID_TRIALS"
        --t_bin_width "$T_BIN_WIDTH"
    )

    # Domain-randomization flags are appended only when enabled.
    if [[ "$DOMAIN_RANDOMIZE" == "1" ]]; then
        CMD_ARGS+=(
            --domain_randomize
            --N_range "$N_RANGE_MIN" "$N_RANGE_MAX"
            --m_range "$M_RANGE_MIN" "$M_RANGE_MAX"
            --difficulty_range "$DIFFICULTY_RANGE_MIN" "$DIFFICULTY_RANGE_MAX"
            --mu_clip "$MU_CLIP_MIN" "$MU_CLIP_MAX"
            --lcap "$LCAP"
        )
    fi
}

build_cmd_args

# Launch training. The script runs under `set -e`, so a bare failing command
# here would abort immediately and the failure banner below would never be
# printed. Capture the exit status explicitly instead.
EXIT_CODE=0
python "$PYTHON_SCRIPT" "${CMD_ARGS[@]}" || EXIT_CODE=$?

if (( EXIT_CODE == 0 )); then
    echo ""
    echo "========================================="
    echo "Training completed successfully!"
    echo "========================================="
else
    # Diagnostics go to stderr; propagate the Python process's exit code.
    {
        echo ""
        echo "========================================="
        echo "Training failed with exit code $EXIT_CODE"
        echo "========================================="
    } >&2
    exit "$EXIT_CODE"
fi

