# Iterative digit-curriculum driver: each pass bumps the digit count by one,
# regenerates data on every GPU, writes config.json and runs supervised_train.py.
# The loop continues while training exits 0; exit status 100 retries the same
# digit count, any other status stops the loop.
#
# Usage: <script> UUID NUMDIGITS [TRAINTYPE] [SKIPGEN] [SIZE] [ORIGNUMDIGITS] [GOFAST] [BATCH_SIZE]
#   UUID           run id; names the checkpoint directory (required)
#   NUMDIGITS      digit count in the starting checkpoint's file name; the first
#                  pass trains on NUMDIGITS + 1 (required)
#   TRAINTYPE      train type in the starting checkpoint's file name (default: supervised)
#   SKIPGEN        "true" skips data generation on the first pass only (default: false)
#   SIZE           forwarded as --size to both python scripts (default: small)
#   ORIGNUMDIGITS  forwarded as --generate_digit_start (default: NUMDIGITS + 1)
#   GOFAST         "true" adds --fast to both python scripts (default: false)
#   BATCH_SIZE     written into config.json (default: 64)
uuid=$1
numdigits=$2

traintype=${3:-supervised}
skipgen=${4:-false}
SIZE=${5:-small}
orignumdigits=${6:-$((numdigits + 1))}
gofast=${7:-false}
batch_size=${8:-64}

# Print all input variables. Values are quoted so they are echoed verbatim
# rather than word-split / glob-expanded.
printf '%s\n' \
    "UUID: $uuid" \
    "NUMDIGITS: $numdigits" \
    "TRAINTYPE: $traintype" \
    "SKIPGEN: $skipgen" \
    "SIZE: $SIZE" \
    "ORIGNUMDIGITS: $orignumdigits" \
    "GOFAST: $gofast" \
    "BATCH_SIZE: $batch_size"

# Refuse to run without the required arguments. With an empty uuid the working
# directory below would be the shared .../checkpoints/supervised/ directory
# itself, and this run's master.log would be appended there.
if [ -z "$uuid" ]; then
    echo "usage: $0 UUID NUMDIGITS [TRAINTYPE] [SKIPGEN] [SIZE] [ORIGNUMDIGITS] [GOFAST] [BATCH_SIZE]" >&2
    exit 2
fi
case $numdigits in
    '' | *[!0-9]*)
        echo "NUMDIGITS must be a non-negative integer, got '$numdigits'" >&2
        exit 2
        ;;
esac

host_name=$(hostname)
# Per-run directory on this cluster node's mounted code share; holds master.log.
WORKING_DIR="/mnt/batch/tasks/shared/LS_root/mounts/clusters/$host_name/code/checkpoints/supervised/$uuid"
# Recorded in master.log for reference only; nothing in this script writes to it.
REGULAR_LOG="/mnt/batch/tasks/shared/LS_root/mounts/clusters/DIRNAME1/code/Users/DIRNAME/addition/logs/supervised_train/${uuid}_supervisedtrain.log"
master_log="$WORKING_DIR/master.log"

echo "Working directory: $WORKING_DIR"
if ! mkdir -p "$WORKING_DIR"; then
    echo "cannot create working directory $WORKING_DIR" >&2
    exit 1
fi

# Append the run header to master.log in one redirection.
{
    echo "UUID: $uuid"
    echo "MASTERLOG: $master_log"
    echo "REGULARLOG: $REGULAR_LOG"
    echo "SIZE: $SIZE"
    echo "Git Hash: $(git rev-parse HEAD)"
} >> "$master_log"

# code $WORKING_DIR"/master.log"

# SEED is passed to supervised_train.py and bumped every pass; EXIT holds the
# last training exit status (the training loop runs while it is 0).
SEED=0
EXIT=0

# Checkpoints live in the run's working directory. Reuse WORKING_DIR instead of
# rebuilding the same literal path so the two can never drift apart.
CHECKPOINTDIR=$WORKING_DIR
# Starting checkpoint, named after the initial train type and digit count.
CHECKPOINT=$CHECKPOINTDIR/model-$traintype-$numdigits-digits.ckpt

# Count GPUs only when nvidia-smi succeeds: on failure it may print a diagnostic
# on stdout, which a plain `| wc -l` would count as a GPU.
if gpu_list=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader); then
    NUMGPUS=$(printf '%s\n' "$gpu_list" | grep -c .)
else
    NUMGPUS=0
fi
if [ "$NUMGPUS" -eq 0 ]; then
    echo "WARNING: nvidia-smi reported no GPUs; per-GPU data generation will not run" >&2
fi

echo "CHECKPOINT IS AT: $CHECKPOINT"

# --fast is forwarded to both python scripts when gofast is "true".
fast_flag=
if [ "$gofast" = true ]; then
    fast_flag=--fast
fi

# Kill every process holding an NVIDIA device node (fuser -k), then pause
# 10 s -- presumably so GPU memory is released before the next job starts.
free_gpus() {
    # The glob is intentionally unquoted so it expands to each /dev/nvidia* node.
    sudo fuser -v /dev/nvidia* -k
    sleep 10
}

# Run model_generate_data.py once per GPU in parallel and wait for each worker.
# Sets gen_failed to the number of workers that exited non-zero; returns
# success only when all of them succeeded.
generate_data() {
    gen_pids=
    gpu=0
    while [ "$gpu" -lt "$NUMGPUS" ]; do
        echo "Starting data generation on GPU $gpu"
        # shellcheck disable=SC2086 # fast_flag is empty or the single word --fast
        python3 model_generate_data.py "$uuid" \
            --primary_num_digits="$numdigits" --device="$gpu" \
            --traintype="$traintype" --size="$SIZE" $fast_flag &
        gen_pids="$gen_pids $!"
        gpu=$((gpu + 1))
    done

    gen_failed=0
    for pid in $gen_pids; do
        wait "$pid" || gen_failed=$((gen_failed + 1))
    done
    [ "$gen_failed" -eq 0 ]
}

# Write this pass's config.json (digit count, starting checkpoint, batch size).
write_config() {
    cat > "$CHECKPOINTDIR/config.json" <<EOF
{
    "num_digits": $numdigits,
    "checkpoint": "$CHECKPOINT",
    "batch_size": $batch_size,
    "total_steps": 0
}
EOF
    echo "Wrote config.json to $CHECKPOINTDIR/config.json"
}

# Run supervised_train.py; the function's status is the trainer's exit status.
run_training() {
    # NOTE(review): wandb is not assigned anywhere in this script, so --wandb_id
    # is empty unless a "wandb" variable is inherited from the environment -- confirm.
    # shellcheck disable=SC2086 # fast_flag is empty or the single word --fast
    python3 supervised_train.py --type=decomp --uuid="$uuid" --size="$SIZE" \
        --seed="$SEED" --wandb_id="$wandb" \
        --read_data_from="$datafile" \
        --generate_digit_start="$orignumdigits" $fast_flag
}

# Loop while the last training run exited 0: each pass bumps the digit count,
# regenerates data (unless skipgen skips it on the first pass), writes
# config.json and trains.
while [ "$EXIT" -eq 0 ]; do
    numdigits=$((numdigits + 1))
    echo "Trying with $numdigits"
    free_gpus

    if [ "$skipgen" = true ]; then
        # skipgen is reset, so at most the first pass skips generation.
        echo "Skipping data generation"
        skipgen=false
    else
        if ! generate_data; then
            echo "Data generation failed on $gen_failed GPU(s) for $numdigits digits" >&2
            echo "Data generation failed on $gen_failed GPU(s) for $numdigits digits" >> "$WORKING_DIR/master.log"
        fi
        free_gpus
    fi

    # SEED is bumped before and after training, so seeds never repeat.
    SEED=$((SEED + 1))
    # The NUMDIGITS / DEVICENUMBER / DATATYPE placeholders are left literal here;
    # presumably supervised_train.py substitutes them -- verify.
    datafile="$CHECKPOINTDIR/data/digits:NUMDIGITS_device:DEVICENUMBER_DATATYPE_dataset.pkl"

    write_config
    run_training
    EXIT=$?

    SEED=$((SEED + 1))
    if command -v code > /dev/null 2>&1; then
        code "logs/supervised_train/${uuid}_supervisedtrain.log"
    fi

    # Subsequent passes start from this pass's self-training checkpoint.
    CHECKPOINT="$CHECKPOINTDIR/model-selftrain-$numdigits-digits.ckpt"
    traintype=selftrain

    # EXIT=100 means that the model was not able to be trained, should retry.
    # numdigits is stepped back so the next pass's increment retries the same
    # count; log it before decrementing so the message names that count.
    # NOTE(review): CHECKPOINT/traintype were already switched above, so the retry
    # starts from model-selftrain-<failed count>-digits.ckpt -- confirm that file
    # exists after a status-100 run.
    if [ "$EXIT" -eq 100 ]; then
        echo "Failed to train model, retrying with $numdigits digits" >> "$WORKING_DIR/master.log"
        numdigits=$((numdigits - 1))
        EXIT=0
    fi
done

echo "Training loop stopped: supervised_train.py exited with $EXIT at $numdigits digits" >> "$WORKING_DIR/master.log"
