#!/bin/bash
# Train pick_place trajectory encoder with normalized EE input
#
# Dataset: 160 trajectories (80 REACH + 80 CARRY), each 64 steps
# The encoder learns to embed trajectory shapes into a 2D latent space
#
# KEY DESIGN - solving mode collapse while ensuring REACH/CARRY same-CP overlap:
#
# 1. Encoder input: Trajectory-frame normalized EE (3D: progress, perp1, perp2)
#    - This is IDENTICAL for REACH and CARRY with same CP params
#    - So encoder will produce IDENTICAL z for same-CP trajectories
#
# 2. Decoder output: Full state (22D) + action (8D) trajectory
#    - This is MUCH HARDER than reconstructing 3D curve
#    - Forces z to be discriminative (can't get away with collapsed z)
#
# 3. Decoder context: s0 + trajectory_length + time
#    - s0: WHERE the trajectory starts
#    - trajectory_length: HOW FAR it goes
#    - time: WHEN in the trajectory
#    - z must encode HOW it curves (the mode/CP)
#
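# Output is written under --save_dir (set below); the trainer may place it in a
# subdirectory such as z2_normalized_ee/ — confirm against the training script.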

# Fail fast: abort on any error, unset variable, or failed pipeline stage, so a
# bad `cd` or a broken path is not silently ignored.
set -euo pipefail

# Directory containing train_normalized_ee_encoder.py. It can be overridden from
# the environment, e.g. ENCODER_DIR=/other/checkout ./this_script.sh
ENCODER_DIR=${ENCODER_DIR:-/home/placeholder/dsrl/dppo/RLBench_pick_place/encoder}

# Root of the dataset / metadata / checkpoint tree used below. It can be
# overridden from the environment to train on another task or variation.
# NOTE(review): the header describes pick_place, but this default points at
# stack_blocks/variation0 — confirm this is the intended dataset.
DATA_ROOT=${DATA_ROOT:-/scratch4/workspace/placeholder-hdp/dppo/data/stack_blocks/variation0}

# Without this check, a failed cd would launch training from the caller's
# current directory.
cd -- "$ENCODER_DIR" || {
  printf 'error: cannot cd to %s\n' "$ENCODER_DIR" >&2
  exit 1
}

# wandb logging is on by default; uncomment the next line to disable it.
# export WANDB_MODE=disabled

# Any extra arguments passed to this script are appended to the trainer's
# command line (e.g. --device=cuda:1). Whether a repeated flag overrides the
# default above depends on the trainer's argument parser — presumably argparse
# (last occurrence wins); confirm.
python train_normalized_ee_encoder.py \
  --horizon=64 \
  --latent_dim=2 \
  --hidden_dim=128 \
  --num_layers=4 \
  --num_heads=4 \
  --batch_size=32 \
  --num_epochs=5000 \
  --lr=1e-4 \
  --kl_weight=0.001 \
  --dtw_weight=1.0 \
  --vis_freq=100 \
  --save_freq=500 \
  --dataset_path="${DATA_ROOT}/processed/train_raw.npz" \
  --metadata_path="${DATA_ROOT}/train" \
  --save_dir="${DATA_ROOT}/encoder" \
  --device=cuda:0 \
  --wandb_project=pick-place-normalized-ee-encoder \
  "$@"
