#!/bin/bash
#SBATCH --time=18:00:00
#SBATCH --mem=12G
#SBATCH --account=<your_allocation>
# Log file per array task: %A = array master job ID, %a = array task index,
# %x = job name. Without %a every task of the array resolves to the same path
# and they overwrite each other's logs. Slurm does not create the out/
# directory; it must exist before submission.
#SBATCH --output=out/%A_%a_%x.out
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1
# One array task per seed: the task index becomes --seed below.
#SBATCH --array=1-10
#SBATCH --mail-type=FAIL
#SBATCH --mail-user=<your_email@example.com>
#SBATCH --gpus-per-node=1
#
# Usage: sbatch <this_script> [alg] [game] [local_window_size] [lle_batch_size]
#          [lle_learning_rate_W] [lle_learning_rate_Phi]
#          [lle_loss_reduction_threshold_W] [lle_loss_reduction_threshold_Phi]
#          [lle_learning_rate_trunk] [version] [additional_args]
# All positional arguments are optional; see the defaults below.

# Toolchain modules and the Python virtualenv used by the training scripts.
# module load httpproxy
module load opencv/4.8.1
module load mujoco/3.1.6
# module load libffi
# module load python/3.10
# Fail fast: without this guard a missing/broken virtualenv would let the job
# run on whatever python3 happens to be on PATH.
source ~/lle-rl/bin/activate || {
    echo "failed to activate virtualenv ~/lle-rl" >&2
    exit 1
}
# Turn Weights & Biases into a no-op for these runs.
export WANDB_MODE=disabled

# Drops the GPU mask Slurm exports for the allocated GPU(s), so CUDA sees every
# device the node/cgroup exposes to this job.
# NOTE(review): confirm this is intended — on clusters without cgroup device
# constraints this can expose GPUs that belong to other jobs.
unset CUDA_VISIBLE_DEVICES

# Positional arguments; a missing or empty argument takes the default shown.
#   $1  algorithm variant                      (sac_lle)
#   $2  environment id                         (Lift-Panda)
#   $3  LLE local window size                  (5)
#   $4  LLE batch size                         (2048)
#   $5  learning rate for W                    (1e-2)
#   $6  learning rate for Phi                  (1e-3)
#   $7  loss-reduction threshold for W         (1e-15)
#   $8  loss-reduction threshold for Phi       (1e-10)
#   $9  learning rate for the trunk            (1e-3)
#   $10 version tag prefixed to the run name   (v0)
#   $11 extra CLI flags for the python script  (empty)
alg="${1:-sac_lle}"
game="${2:-Lift-Panda}"
local_window_size="${3:-5}"
lle_batch_size="${4:-2048}"
lle_learning_rate_W="${5:-1e-2}"
lle_learning_rate_Phi="${6:-1e-3}"
lle_loss_reduction_threshold_W="${7:-1e-15}"
lle_loss_reduction_threshold_Phi="${8:-1e-10}"
lle_learning_rate_trunk="${9:-1e-3}"
version="${10:-v0}"
additional_args="${11:-}"

# Inside a Slurm array each task uses its index as the seed; otherwise seed 0.
seed="${SLURM_ARRAY_TASK_ID:-0}"

# Run name encoding the version, algorithm, env and every LLE hyperparameter
# (the seed is not part of it).
printf -v exp_name '%s_%s_%s_lws%s_bs%s_lrW%s_lrPhi%s_gsW%s_gsPhi%s_lrTr%s' \
    "$version" "$alg" "$game" \
    "$local_window_size" "$lle_batch_size" \
    "$lle_learning_rate_W" "$lle_learning_rate_Phi" \
    "$lle_loss_reduction_threshold_W" "$lle_loss_reduction_threshold_Phi" \
    "$lle_learning_rate_trunk"

# Launch the training script that matches $alg.
#
# Every variant gets the same base flags. The LLE variants also get the LLE
# hyperparameters. The free-form $additional_args is forwarded last to every
# variant, so its flags come after the generated ones.
#
# $additional_args is split on whitespace into an array. This keeps the
# word-splitting an unquoted expansion would do, without its pathname (glob)
# expansion. An empty string yields an empty array, which expands to nothing.
read -r -a extra_args <<< "$additional_args"

common_args=(
    --cuda
    --exp_name "$exp_name"
    --seed "$seed"
    --env-id "$game"
    --total-timesteps 600000
    --wandb-project-name lle-gymcontrol
)

# Full LLE configuration shared by sac_lle, sac_recon_lle and sac_joint_lle.
lle_args=(
    --local_window_size "$local_window_size"
    --lle_batch_size "$lle_batch_size"
    --lle_learning_rate_W "$lle_learning_rate_W"
    --lle_learning_rate_Phi "$lle_learning_rate_Phi"
    --lle_loss_reduction_threshold_W "$lle_loss_reduction_threshold_W"
    --lle_loss_reduction_threshold_Phi "$lle_loss_reduction_threshold_Phi"
    --lle_learning_rate_trunk "$lle_learning_rate_trunk"
    --use_lle_projection
    --train_trunk
)

case "$alg" in
    sac_lle)
        python3 sac_continuous_action_robosuite_lle.py \
            "${common_args[@]}" "${lle_args[@]}" "${extra_args[@]}"
        ;;
    sac_lcr)
        # LCR passes only a subset of the LLE settings and neither
        # --use_lle_projection nor --train_trunk.
        python3 sac_continuous_action_robosuite_lle.py \
            "${common_args[@]}" --use_lcr \
            --local_window_size "$local_window_size" \
            --lle_batch_size "$lle_batch_size" \
            --lle_learning_rate_W "$lle_learning_rate_W" \
            --lle_learning_rate_Phi "$lle_learning_rate_Phi" \
            --lle_loss_reduction_threshold_W "$lle_loss_reduction_threshold_W" \
            "${extra_args[@]}"
        ;;
    sac_recon_lle)
        python3 sac_continuous_action_robosuite_attention.py \
            "${common_args[@]}" "${lle_args[@]}" --lle_epochs 0 "${extra_args[@]}"
        ;;
    sac_joint_lle)
        python3 sac_continuous_action_robosuite_lle_joint.py \
            "${common_args[@]}" "${lle_args[@]}" "${extra_args[@]}"
        ;;
    sac_recon-*)
        # "sac_recon-<method>" selects the SSL method by suffix.
        ssl_method=${alg#sac_recon-}
        python3 sac_continuous_action_robosuite_joint.py \
            "${common_args[@]}" --ssl_method "$ssl_method" "${extra_args[@]}"
        ;;
    sac_recon)
        python3 sac_continuous_action_robosuite_joint.py \
            "${common_args[@]}" --ssl_method recon "${extra_args[@]}"
        ;;
    sac_spr)
        python3 sac_continuous_action_robosuite_spr.py \
            "${common_args[@]}" --ssl_method spr "${extra_args[@]}"
        ;;
    sac_dbc)
        python3 sac_continuous_action_robosuite_dbc.py \
            "${common_args[@]}" "${extra_args[@]}"
        ;;
    sac)
        python3 sac_continuous_action_robosuite.py \
            "${common_args[@]}" "${extra_args[@]}"
        ;;
    *)
        # Exit non-zero so Slurm records the job as FAILED (and --mail-type=FAIL
        # fires) instead of reporting a successful no-op run.
        printf 'algorithm undefined: %s\n' "$alg" >&2
        exit 1
        ;;
esac
