#!/bin/bash
#SBATCH --time=96:00:00
#SBATCH --mem=64G
#SBATCH --output=out/%A_%x.out
#SBATCH --cpus-per-task=1
#SBATCH --gpus-per-node=1
#SBATCH --account=<your_allocation>
#SBATCH --mail-type=FAIL
#SBATCH --mail-user=<your_email@example.com>

# SLURM batch script: prepare the MuJoCo / off-screen rendering environment,
# then launch one training run selected by the positional arguments.
# (#SBATCH directives must stay above the first executable command.)

module load opencv
module load qt/5.15.11
module load mujoco/3.1
export MUJOCO_PATH=$EBROOTMUJOCO
export MUJOCO_PLUGIN_PATH=$MUJOCO_PATH/lib/mujoco/plugin
# Append extra library dirs. ${VAR:+...} only emits the ':' separator when
# LD_LIBRARY_PATH is non-empty; a leading empty entry would make the dynamic
# loader also search the current working directory.
# NOTE(review): mujoco210 is the MuJoCo 2.1 layout (presumably for mujoco-py),
# while the module above provides MuJoCo 3.1 — confirm both are still needed.
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$HOME/.mujoco/mujoco210/bin:/usr/lib/nvidia
# Software (off-screen) OpenGL rendering through OSMesa for MuJoCo and PyOpenGL.
export MUJOCO_GL=osmesa
export PYOPENGL_PLATFORM=osmesa
# Keep W&B logging offline (runs are stored locally).
export WANDB_MODE=offline

# NOTE(review): SLURM normally sets CUDA_VISIBLE_DEVICES to the GPU(s) granted
# by --gpus-per-node; unsetting it exposes every GPU the process can reach.
# Presumably intentional (e.g. device cgroups already restrict access) — confirm.
unset CUDA_VISIBLE_DEVICES

# Positional arguments. Every one is optional; an unset or empty argument
# falls back to the default written after ':-'.
alg=${1:-sac_lle}                                # algorithm key for the dispatch
env_type=${2:-dexgym}                            # robosuite | dexgym | anything else -> sac_image*.py
env_id=${3:-EggHandOver-v0}
local_window_size=${4:-10}
lle_batch_size=${5:-2048}
lle_learning_rate_W=${6:-1e-2}
lle_learning_rate_Phi=${7:-1e-3}
lle_loss_reduction_threshold_W=${8:-1e-10}
lle_loss_reduction_threshold_Phi=${9:-1e-5}
lle_learning_rate_trunk=${10:-1e-3}
version=${11:-v1}                                # prefix of exp_name
seed=${12:-1}
additional_args=${13:-"--track"}                 # extra CLI flags, whitespace-separated
encoder_name=${14:-"facebook/dinov2-small"}      # only passed for sac_lle_pretrained
encoder_dim=${15:-384}                           # only passed for sac_lle_pretrained

# Activate the project virtualenv. Fail fast if it is missing: otherwise the
# python3 calls below would silently fall back to whatever interpreter is on
# PATH and the job would crash later (or run against the wrong packages).
venv_activate=$HOME/lle-rl/bin/activate
if [[ ! -f $venv_activate ]]; then
    echo "Virtualenv activation script not found: $venv_activate" >&2
    exit 1
fi
# shellcheck source=/dev/null
source "$venv_activate"

# Run name passed to every training script as --exp_name. It encodes the
# version, algorithm, environment, all LLE hyper-parameters and the seed, so
# sweeps over any of them get distinct names.
# NOTE(review): encoder_name / encoder_dim are not included, so
# sac_lle_pretrained runs that differ only in encoder share one exp_name.
printf -v exp_name '%s_%s_%s_%s_lws%s_bs%s_lrW%s_lrPhi%s_gsW%s_gsPhi%s_lrTr%s_seed%s' \
    "$version" "$alg" "$env_type" "$env_id" \
    "$local_window_size" "$lle_batch_size" \
    "$lle_learning_rate_W" "$lle_learning_rate_Phi" \
    "$lle_loss_reduction_threshold_W" "$lle_loss_reduction_threshold_Phi" \
    "$lle_learning_rate_trunk" "$seed"

readonly TOTAL_STEPS=10000000

#######################################
# Launch the training script that matches $env_type / $alg.
#
# robosuite and dexgym share one argument layout and differ only in the script
# prefix (sac_continuous_action_<env_type>*) and the W&B project
# (lle-<env_type>); every other env_type goes to sac_image_lle.py /
# sac_image.py.
#
# Globals (read): alg, env_type, env_id, seed, exp_name, local_window_size,
#   lle_batch_size, lle_learning_rate_W, lle_learning_rate_Phi,
#   lle_loss_reduction_threshold_W, lle_loss_reduction_threshold_Phi,
#   lle_learning_rate_trunk, additional_args, encoder_name, encoder_dim,
#   TOTAL_STEPS
# Outputs: an error message on stderr for an unknown algorithm.
# Returns: the training script's exit status, or 1 for an unknown algorithm.
#######################################
launch_experiment() {
    local -a extra_args lle_args common_args
    local prefix

    # additional_args is one free-form string of extra flags (default
    # '--track'); split it on whitespace into separate words. Unlike an
    # unquoted expansion, this never glob-expands the words.
    read -r -a extra_args <<< "$additional_args"

    # LLE hyper-parameters shared by every LLE/LCR variant. The trunk learning
    # rate is appended per case because the LCR variants do not pass it.
    lle_args=(
        --local_window_size "$local_window_size" --lle_batch_size "$lle_batch_size"
        --lle_learning_rate_W "$lle_learning_rate_W" --lle_learning_rate_Phi "$lle_learning_rate_Phi"
        --lle_loss_reduction_threshold_W "$lle_loss_reduction_threshold_W"
        --lle_loss_reduction_threshold_Phi "$lle_loss_reduction_threshold_Phi"
    )

    case "$env_type" in
        robosuite|dexgym)
            prefix="sac_continuous_action_${env_type}"
            common_args=(
                --cuda --exp_name "$exp_name" --seed "$seed"
                --env-id "$env_id" --total-timesteps "$TOTAL_STEPS"
                --wandb-project-name "lle-${env_type}"
            )
            case "$alg" in
                sac_lle)
                    python3 "${prefix}_lle.py" "${common_args[@]}" "${lle_args[@]}" \
                        --lle_learning_rate_trunk "$lle_learning_rate_trunk" --use_lle_projection --train_trunk \
                        "${extra_args[@]}"
                    ;;
                sac_lcr)
                    python3 "${prefix}_lle.py" "${common_args[@]}" --use_lcr "${lle_args[@]}" \
                        "${extra_args[@]}"
                    ;;
                sac_joint_lle)
                    python3 "${prefix}_lle_joint.py" "${common_args[@]}" "${lle_args[@]}" \
                        --lle_learning_rate_trunk "$lle_learning_rate_trunk" --use_lle_projection --train_trunk \
                        "${extra_args[@]}"
                    ;;
                sac_recon|sac_recon-recon)
                    python3 "${prefix}_joint.py" "${common_args[@]}" --ssl_method recon "${extra_args[@]}"
                    ;;
                sac_recon-next_state)
                    python3 "${prefix}_joint.py" "${common_args[@]}" --ssl_method next_state "${extra_args[@]}"
                    ;;
                sac_recon-reward)
                    python3 "${prefix}_joint.py" "${common_args[@]}" --ssl_method reward "${extra_args[@]}"
                    ;;
                sac_spr)
                    python3 "${prefix}_spr.py" "${common_args[@]}" --ssl_method spr "${extra_args[@]}"
                    ;;
                sac_dbc)
                    python3 "${prefix}_dbc.py" "${common_args[@]}" "${extra_args[@]}"
                    ;;
                sac)
                    python3 "${prefix}.py" "${common_args[@]}" "${extra_args[@]}"
                    ;;
                *)
                    printf 'Incorrect algorithm specified for %s: %s\n' "$env_type" "$alg" >&2
                    return 1
                    ;;
            esac
            ;;
        *)
            common_args=(--cuda --exp_name "$exp_name" --seed "$seed")
            case "$alg" in
                sac_lle)
                    python3 sac_image_lle.py "${common_args[@]}" \
                        --env_type "$env_type" --env_id "$env_id" --wandb-project-name lle \
                        "${lle_args[@]}" \
                        --lle_learning_rate_trunk "$lle_learning_rate_trunk" --no-use_pretrained_encoder \
                        "${extra_args[@]}"
                    ;;
                sac_lcr)
                    # Forwards the Phi threshold as well (as the robosuite/dexgym
                    # LCR runs do), so the gsPhi value recorded in exp_name is the
                    # value the script receives.
                    python3 sac_image_lle.py "${common_args[@]}" \
                        --env_type "$env_type" --env_id "$env_id" --use_lcr --wandb-project-name lle \
                        "${lle_args[@]}" --no-use_pretrained_encoder \
                        "${extra_args[@]}"
                    ;;
                sac_lle_pretrained)
                    python3 sac_image_lle.py "${common_args[@]}" \
                        --env_type "$env_type" --env_id "$env_id" --use_pretrained_encoder \
                        --pretrained_encoder_name "$encoder_name" --encoder_feature_dim "$encoder_dim" \
                        --wandb-project-name lle "${lle_args[@]}" \
                        --lle_learning_rate_trunk "$lle_learning_rate_trunk" \
                        "${extra_args[@]}"
                    ;;
                sac_recon)
                    # NOTE(review): runs sac_image_lle.py with no LLE/SSL flags,
                    # i.e. with that script's defaults — confirm those select recon.
                    python3 sac_image_lle.py "${common_args[@]}" \
                        --env_type "$env_type" --env_id "$env_id" --wandb-project-name lle \
                        --no-use_pretrained_encoder \
                        "${extra_args[@]}"
                    ;;
                sac)
                    # sac_image.py is given a combined "<env_type>:<env_id>" id
                    # and no separate --env_type.
                    python3 sac_image.py "${common_args[@]}" \
                        --env_id "${env_type}:${env_id}" --wandb-project-name lle \
                        "${extra_args[@]}"
                    ;;
                *)
                    printf 'Incorrect algorithm specified: %s\n' "$alg" >&2
                    return 1
                    ;;
            esac
            ;;
    esac
}

# Last command of the script: its status becomes the job's exit status, so a
# failed training run or an unknown algorithm is reported to SLURM as a failure.
launch_experiment
