#!/usr/bin/bash

### Task name
#SBATCH --account=xxxxxxx
#SBATCH --job-name=train_gphyt

### Output file
#SBATCH --output=results/slrm_logs/train_gphyt_%j.out


### Number of nodes to allocate (this job runs on a single node)
#SBATCH --nodes=1

### Number of tasks per node (Slurm allocates one CPU core per task by default); --exclusive reserves the whole node
#SBATCH --ntasks-per-node=96
#SBATCH --exclusive

### Mail notification configuration
#SBATCH --mail-type=ALL
#SBATCH --mail-user=email@example.com

### Maximum wall-clock runtime of the job
#SBATCH --time=72:00:00

### Number of GPUs to allocate per node
#SBATCH --gres=gpu:4

### Optional job array (disabled by the double '#'): 10 jobs run one after another (%1 = at most one at a time), each limited by --time above
##SBATCH --array=1-10%1

### Wall-clock budget forwarded to the training script through --time_limit,
### intended to allow a graceful shutdown.
### Keep it below the time limit of the partition.
### NOTE(review): currently identical to "#SBATCH --time" above, which leaves
### no headroom before Slurm ends the job - confirm this is intended.
### Format: HH:MM:SS
time_limit='72:00:00'

#####################################################################################
############################# Setup #################################################
#####################################################################################

# Activate the "gphyt" conda environment from the Miniforge installation in $HOME.
# Abort early if that fails, so the job does not start training with the wrong
# Python interpreter.
export CONDA_ROOT="$HOME/miniforge3"
# shellcheck source=/dev/null
if ! source "$CONDA_ROOT/etc/profile.d/conda.sh"; then
    echo "ERROR: could not source $CONDA_ROOT/etc/profile.d/conda.sh" >&2
    exit 1
fi
export PATH="$CONDA_ROOT/bin:$PATH"
if ! conda activate gphyt; then
    echo "ERROR: could not activate conda environment 'gphyt'" >&2
    exit 1
fi

######################################################################################
############################# Set paths ##############################################
######################################################################################
# Debug switch: uncomment to have the sim_dir deleted before the run starts.
# debug=true

# Repository root and the locations derived from it.
# All paths are relative, so they resolve against the job's working directory.
base_dir="General-Physics-Transformer"
data_dir="$base_dir/data/datasets"
log_dir="$base_dir/results"
python_exec="$base_dir/gphyt/run/train.py"
base_config_file="$base_dir/gphyt/run/scripts/config.yaml"

# Name of this simulation (same as the wandb id) and its directory below log_dir
sim_name="sim-name"
sim_dir="$log_dir/$sim_name"

# Launch topology. Only ngpus_per_node is passed to torchrun further down;
# nnodes is not referenced in the rest of this script.
nnodes=1
ngpus_per_node=4
# Rule of thumb: (num cpu - num_workers) / num_gpus
OMP_NUM_THREADS=1
export OMP_NUM_THREADS

# Checkpoint to load: "best_model", the number of an epoch directory, or
# "last_checkpoint" to use the most recent checkpoint.
checkpoint_name="last_checkpoint"
# true  -> continue from the checkpoint with a new config file (learning rate, etc.)
# false -> continue the last training
new_training=false
# File name of the config used when new_training is true; it is looked up inside sim_dir
new_config_name="config_cooldown.yaml"

#######################################################################################
############################# Setup sim dir and config file ###########################
#######################################################################################

# In debug mode, start from scratch by deleting the sim_dir.
# The ":?" expansions abort the script instead of expanding to nothing: an
# empty sim_name would make sim_dir point at the whole log_dir.
if [[ "${debug:-}" == true ]]; then
    : "${sim_name:?sim_name must not be empty, refusing to delete the log dir}"
    rm -rf -- "${sim_dir:?}"
fi

# Create the sim_dir if it doesn't exist; nothing below can work without it.
if ! mkdir -p -- "$sim_dir"; then
    echo "ERROR: could not create sim dir $sim_dir" >&2
    exit 1
fi

# Decide which config file to use and whether this run is a restart.
# Sets: config_file, restart
if [[ "$new_training" == true ]]; then
    # Use the config named by new_config_name inside the sim_dir.
    # Nothing is copied here; this script only builds the path.
    config_file="${sim_dir}/${new_config_name}"
    restart=false
    echo "Using checkpoint to continue training with new config file..."
    if [[ ! -f "$config_file" ]]; then
        echo "WARNING: config file $config_file does not exist" >&2
    fi
else
    # A config.yaml already present in the sim_dir marks an earlier run.
    restart_config_file="${sim_dir}/config.yaml"
    if [[ -f "$restart_config_file" ]]; then
        echo "Config file found in $sim_dir, attempting restart..."
        restart=true
        config_file=$restart_config_file
    else
        echo "No config file found in $sim_dir, starting new training..."
        # Copy the base config to the exact path used as config_file, so the
        # two cannot diverge if the base config is ever renamed.
        if ! cp -- "$base_config_file" "$restart_config_file"; then
            echo "ERROR: could not copy $base_config_file to $restart_config_file" >&2
            exit 1
        fi
        config_file=$restart_config_file
        restart=false
    fi
fi

#####################################################################################
############################# Training GPM ##########################################
#####################################################################################
# Print a summary of the run configuration to the job log.
echo "--------------------------------"
echo "Starting GPhyT training..."
echo "config_file: $config_file"
echo "sim_dir: $sim_dir"
echo "restart: $restart"
echo "new_training: $new_training"
echo "using checkpoint: $checkpoint_name"
echo "--------------------------------"

# Arguments for train.py, kept in an array so that values containing spaces
# or glob characters are passed through unchanged.
exec_args=(
    --config_file "$config_file"
    --sim_name "$sim_name"
    --log_dir "$log_dir"
    --data_dir "$data_dir"
    --time_limit "$time_limit"
    --checkpoint_name "$checkpoint_name"
)

# Add --restart if the restart flag is true
if [[ "$restart" == true ]]; then
    exec_args+=(--restart)
fi
if [[ "$new_training" == true ]]; then
    exec_args+=(--new_training)
fi

# Launch the training script through torchrun with ngpus_per_node processes.
# Its output goes straight to the job log; only the exit status is kept, so
# it can be reported to Slurm after the log has been moved.
torchrun --standalone --nproc_per_node="$ngpus_per_node" "$python_exec" "${exec_args[@]}"
train_status=$?

# Move the Slurm output file into the sim_dir.
# The path must match "#SBATCH --output" at the top of this script
# (results/slrm_logs/train_gphyt_%j.out, where %j is the job id).
slurm_log="results/slrm_logs/train_gphyt_${SLURM_JOB_ID:-}.out"
if [[ -n "${SLURM_JOB_ID:-}" && -f "$slurm_log" ]]; then
    mv -- "$slurm_log" "$sim_dir"
else
    echo "WARNING: Slurm log $slurm_log not found, not moving it" >&2
fi

# Exit with torchrun's status so a failed training is not reported as success.
exit "$train_status"
