#!/bin/bash
# Train pick_place trajectory encoder with endpoint direction conditioning
#
# Dataset: 160 trajectories (80 REACH + 80 CARRY), each 64 steps
# The encoder learns to embed trajectory shapes into a 2D latent space
#
# KEY DESIGN - solving mode collapse while handling varying s0:
#
# 1. Encoder input: Full state+action trajectory (30D per timestep)
#    - Like close_drawer, encodes the full trajectory shape
#    - Forces z to be discriminative (can't get away with collapsed z)
#
# 2. Decoder input: z + endpoint_direction + time
#    - endpoint_direction: normalized (end_ee - start_ee), 3D
#    - This tells decoder WHERE to go, not HOW
#    - z must encode HOW (the trajectory curve/mode)
#
# 3. Why endpoint_direction instead of full s0:
#    - Full s0 differs for REACH vs CARRY → decoder could overfit
#    - endpoint_direction is minimal: just the direction of travel
#    - Forces z to carry trajectory shape information
#
# Output will be saved to: encoder/z2_endpoint/

# Launch endpoint-conditioned trajectory encoder training.
#
# Steps:
#   1. Change into the encoder source directory so the relative script path
#      train_endpoint_encoder.py resolves.
#   2. Run the trainer with the fixed hyperparameters and data paths below.
#
# Exit status: 1 if the encoder directory cannot be entered; otherwise the
# exit status of the python process (it is the last command in the script).

# Directory that holds train_endpoint_encoder.py.
encoder_dir=/home/placeholder/dsrl/dppo/RLBench_pick_place/encoder

# Common prefix of the dataset, metadata and checkpoint locations.
# NOTE(review): the file header describes a pick_place dataset, but these
# paths point at stack_blocks/variation0 - confirm this is the intended data.
data_root=/scratch4/workspace/placeholder-hdp/dppo/data/stack_blocks/variation0

# Abort if the directory is missing: continuing would run python from the
# caller's working directory, where train_endpoint_encoder.py either does not
# exist or is a different file.
cd "${encoder_dir}" || {
    printf 'error: cannot cd to %s\n' "${encoder_dir}" >&2
    exit 1
}

# Uncomment the next line to turn wandb logging off for this run.
# export WANDB_MODE=disabled

python train_endpoint_encoder.py \
    --horizon=64 \
    --latent_dim=2 \
    --hidden_dim=128 \
    --num_layers=4 \
    --num_heads=4 \
    --batch_size=32 \
    --num_epochs=5000 \
    --lr=1e-4 \
    --kl_weight=0.001 \
    --dtw_weight=1.0 \
    --vis_freq=100 \
    --save_freq=500 \
    --dataset_path="${data_root}/processed/train_raw.npz" \
    --metadata_path="${data_root}/train" \
    --save_dir="${data_root}/encoder" \
    --device=cuda:0 \
    --wandb_project=pick-place-endpoint-encoder
