#!/bin/bash

echo "running"


# Sync the code from the shared root into the home working copy (the
# top-level wandb directory is excluded), then work from there. Every later
# relative path depends on this cd, so abort if it fails.
rsync -av 'ROOTDIR/titer/' 'HOMEDIR/titer' --exclude '/wandb'
cd 'HOMEDIR/titer/' || { echo "cannot cd to HOMEDIR/titer/" >&2; exit 1; }

# ************************* Parse command line arguments *************************
# Options:
#   -c <devices>   value for CUDA_VISIBLE_DEVICES
#   -t <mode>      training mode: sgd | er | ewc
#   -e <lambda>    EWC penalty weight (only kept for modes containing "ewc")
#   -b <size>      replay buffer size (only kept for modes containing "er")
#   -d <dataset>   dataset name (directory under data/)
#   -h <heads>     number of attention heads
#   -l <layers>    number of attention layers
#   -p <prefix>    prefix for the checkpoint directory name
# Initialize every option so nothing leaks in from the caller's environment.
CUDA=""
TRAIN_MODE=""
DATA_SET=""
ATTENTION_HEAD=""
ATTENTION_LAYER=""
PREFIX=""

# Default values
EWC_LAMBDA=0.0
BUFFER_SIZE=0

# The leading ':' selects getopts' silent mode. In that mode OPTARG holds
# the offending option character in the error arms below, instead of being
# unset. A bad command line aborts rather than launching a misconfigured run.
while getopts ":c:t:e:b:d:h:l:p:" opt; do
  case "$opt" in
    c) CUDA=$OPTARG ;;
    t) TRAIN_MODE=$OPTARG ;;
    e) EWC_LAMBDA=$OPTARG ;;
    b) BUFFER_SIZE=$OPTARG ;;
    d) DATA_SET=$OPTARG ;;
    h) ATTENTION_HEAD=$OPTARG ;;
    l) ATTENTION_LAYER=$OPTARG ;;
    p) PREFIX=$OPTARG ;;
    :)
      echo "Option -$OPTARG requires an argument" >&2
      exit 2
      ;;
    *)
      echo "Invalid option -$OPTARG" >&2
      exit 2
      ;;
  esac
done

echo "$CUDA"

# Handle default values in case of inconsistency in command line arguments:
# the EWC weight is only meaningful when TRAIN_MODE contains 'ewc'.
if [[ "$TRAIN_MODE" != *ewc* ]]; then
  echo "not ewc"
  EWC_LAMBDA=0.0
fi

# If TRAIN_MODE does not contain 'er', set BUFFER_SIZE to 0
if [[ "$TRAIN_MODE" != *er* ]]; then
  echo "not er"
  BUFFER_SIZE=0
fi

# Set the dataset-specific schedule values TIME_SPAN, END and VALID_EPOCH.
# These are later passed to continual_main.py as --time_span, --end and
# --valid_epoch. Names are matched by substring, in order, so a name
# containing "GDELT" takes the GDELT settings. The dataset name is echoed
# for every dataset except GDELT.
# Arguments: $1 - dataset name
# Globals:   TIME_SPAN, END, VALID_EPOCH (written)
configure_dataset() {
  local dataset=$1
  case "$dataset" in
    *GDELT*)
      TIME_SPAN=15
      END=15
      VALID_EPOCH=60
      ;;
    *ICEWS1807*)
      echo "$dataset"
      TIME_SPAN=24
      END=15
      VALID_EPOCH=30
      ;;
    *)
      echo "$dataset"
      TIME_SPAN=24
      END=33
      VALID_EPOCH=30
      ;;
  esac
}

configure_dataset "$DATA_SET"


# ************************* Set directory paths *************************
# Without a dataset name, DATA_DIR would be "data/" and the copy below would
# pull in the entire datasets directory, so refuse to continue.
if [[ -z "$DATA_SET" ]]; then
  echo "No dataset given (use -d <dataset>)" >&2
  exit 1
fi

DATA_DIR="data/$DATA_SET"
PARENT_OUTPUT_DIR="titer_ours_outputs/$DATA_SET"
CHECKPOINT_SAVE_PATH="$PARENT_OUTPUT_DIR/checkpoints_${PREFIX}${TRAIN_MODE}_${EWC_LAMBDA}_${BUFFER_SIZE}"
echo "$CHECKPOINT_SAVE_PATH"

# Stage the dataset locally the first time it is needed. The existence of
# DATA_DIR is what marks the data as cached. If the copy fails, remove the
# directory created here so the next run retries instead of training on an
# empty directory.
if [[ ! -d "$DATA_DIR" ]]; then
  echo 'copying data'
  mkdir -p "$DATA_DIR"
  if ! cp -r "ROOTDIR/datasets/${DATA_SET}" "HOMEDIR/titer/data/"; then
    rm -rf -- "${DATA_DIR:?}"
    echo "Failed to copy dataset ${DATA_SET}" >&2
    exit 1
  fi
fi

mkdir -p "$PARENT_OUTPUT_DIR"

# Restore checkpoints previously saved to the shared directory. The training
# command loads from this same path. On a first run there is nothing to copy.
if [[ ! -d "$CHECKPOINT_SAVE_PATH" && -d "ROOTDIR/$CHECKPOINT_SAVE_PATH" ]]; then
  echo 'copying checkpoints'
  cp -r "ROOTDIR/$CHECKPOINT_SAVE_PATH" "$PARENT_OUTPUT_DIR/"
fi



export CUDA_LAUNCH_BLOCKING=1
# ************************* Run training *************************
# The model size has no fallback here; without it the command line below
# would be malformed.
: "${ATTENTION_LAYER:?number of layers is required (-l)}"
: "${ATTENTION_HEAD:?number of heads is required (-h)}"

# Arguments shared by every training mode. They are built as an array so
# that each value stays a single, correctly quoted argument.
common_args=(
  --data_path "$DATA_DIR"
  --do_train
  --cuda
  --train_on_test
  --time_span "$TIME_SPAN"
  --num_layers "$ATTENTION_LAYER"
  --nheads "$ATTENTION_HEAD"
  --end "$END"
  --save_path "$CHECKPOINT_SAVE_PATH"
  --load_path "$CHECKPOINT_SAVE_PATH"
  --valid_epoch "$VALID_EPOCH"
)
# Debug flags used during development (append to mode_args when needed):
#   --max_epochs 2, --valid_epoch 1, --IM, --start_epoch 120

case "$TRAIN_MODE" in
  sgd)
    # Basic Fine-tuning
    echo 'running'
    mode_args=(--prefix "$PREFIX")
    ;;
  er)
    # Experience Replay Fine-tuning
    mode_args=(--er --max_buffer_size "$BUFFER_SIZE")
    ;;
  ewc)
    # EWC Fine-tuning
    mode_args=(--ewc --ewc_lambda "$EWC_LAMBDA")
    ;;
  *)
    # Previously an unknown mode silently skipped training and went on to
    # copy outputs.
    echo "Unknown train mode '$TRAIN_MODE' (expected sgd, er or ewc)" >&2
    exit 1
    ;;
esac

CUDA_VISIBLE_DEVICES="$CUDA" python continual_main.py "${common_args[@]}" "${mode_args[@]}"

# Publish this run's checkpoint directory to the shared output location,
# which is the same place the checkpoint restore step copies from.
echo "Copying code outputs to the shared directory"
SHARED_OUTPUT_DIR="ROOTDIR/titer_ours_outputs/$DATA_SET"
mkdir -p "$SHARED_OUTPUT_DIR"
if [[ -d "$CHECKPOINT_SAVE_PATH" ]]; then
  cp -r "$CHECKPOINT_SAVE_PATH" "$SHARED_OUTPUT_DIR/"
else
  # Keep a nonzero exit status so a run that produced nothing is visible.
  echo "No checkpoint directory at '$CHECKPOINT_SAVE_PATH'; nothing to copy" >&2
  exit 1
fi