#!/bin/bash
# shellcheck disable=SC2090,SC2086,SC2089,SC1091
# Location of the project checkout; override with -p / --project_path.
PROJECT_PATH="$HOME/projects/repo"

# Normalise the command line with util-linux `getopt` (needed for the long
# option) and abort on any malformed or unknown option.
OPTIONS=$(getopt -o p: --long project_path: -n 'parse-options' -- "$@") || {
	echo "install_hpc_env.sh: Error parsing options" >&2
	exit 1
}

# Re-split getopt's quoted output back into positional parameters.
eval set -- "$OPTIONS"

# Consume recognised options; getopt always terminates them with `--`.
while (($# > 0)); do
	case "$1" in
	-p | --project_path) PROJECT_PATH="$2"; shift 2 ;;
	--) shift; break ;;
	*) break ;;
	esac
done
printf '%s\n' "install_hpc_env.sh: Install env in PROJECT_PATH=$PROJECT_PATH"
#! Rebuild the module environment from scratch to be sure everything works.
#! `modules.sh` enables the `module` command in this shell.
# shellcheck disable=SC1091
. /etc/profile.d/modules.sh
#! Drop every module still loaded from the calling shell.
module purge
#! Base modules, loaded in this order.
for hpc_module in singularity/current rhel8/slurm dot rhel8/global; do
	module load "$hpc_module"
done
#! `-s` is presumably the silent flag; this openmpi build might not be
#! locatable on every system, so a failure here is tolerated.
module -s load openmpi/4.1.1/gcc-9.4.0-epagguv
#! CUDA 12.1 with its matching cuDNN, then the additional modules.
for hpc_module in cuda/12.1 cudnn/8.9_cuda-12.1 ceuadmin/gettext/0.20 vgl/2.5.1/64; do
	module load "$hpc_module"
done
# Do not leak the loop variable into a sourcing shell.
unset hpc_module
#! Install `UV` if it is not already available.
# `uv --version` prints e.g. "uv X.Y.Z". stderr is discarded so that a missing
# binary does not print "command not found" noise before the install.
UV_VER_OUTPUT=$(uv --version 2>/dev/null)
if [[ $UV_VER_OUTPUT == *"uv "* ]]; then
	echo "UV is already installed."
else
	#! Getting `UV`
	curl -LsSf https://astral.sh/uv/install.sh | sh
	# The installer runs in a child `sh`, so it cannot change THIS shell's PATH.
	# Expose its default install dirs (newer releases: ~/.local/bin, older:
	# ~/.cargo/bin); otherwise every later `uv` call in this script fails.
	export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
	if ! command -v uv >/dev/null 2>&1; then
		echo "install_hpc_env.sh: uv installation failed or uv is not on PATH. Exiting..." >&2
		exit 1
	fi
fi

#! Entering the project folder
cd "$PROJECT_PATH" || {
	echo "install_hpc_env.sh: Cannot enter PROJECT_PATH=$PROJECT_PATH. Exiting..." >&2
	exit 1
}
#! Install the UV env no matter what
# GIT_LFS_SKIP_SMUDGE is necessary to avoid downloading large files
if ! GIT_LFS_SKIP_SMUDGE=true uv sync -q; then
	echo "install_hpc_env.sh: uv sync failed. Exiting..." >&2
	exit 1
fi
#! Extra tools go on top of the synced env, and must come AFTER `uv sync`:
#! `uv pip install` needs a virtual env to target (the project's `.venv`,
#! discovered from the current directory). Also, `uv sync` is exact by default
#! and would remove packages that are not project dependencies.
#! Upgrade pip (this also makes `pip` available inside the env)
uv pip install --upgrade pip
#! Monitoring utilities
# NOTE(review): requires sudo rights and snapd, which HPC nodes often lack;
# this may prompt for a password or fail - confirm it is wanted here.
sudo snap install bpytop
uv pip install nvitop
#! Install cmake
uv pip install cmake

#! Activate the project environment created by `uv sync`.
# `uv sync` does not set VIRTUAL_ENV, so it cannot be used to locate the env.
# uv places the project env at $UV_PROJECT_ENVIRONMENT when set, otherwise at
# `.venv` in the project root (the current directory at this point).
# NOTE(review): the activation only persists for the caller if this file is sourced.
UV_VENV_DIR="${UV_PROJECT_ENVIRONMENT:-.venv}"
if [[ ! -f "$UV_VENV_DIR/bin/activate" ]]; then
	echo "install_hpc_env.sh: No virtual environment found at $UV_VENV_DIR. Exiting..." >&2
	exit 1
fi
# shellcheck disable=SC1091
. "$UV_VENV_DIR/bin/activate"
#! Check the output of `nvcc -V`
# stderr is discarded: a missing nvcc simply yields the "not detected" branch.
NVCC_OUTPUT=$(nvcc -V 2>/dev/null)
if [[ $NVCC_OUTPUT == *"release 12.1"* ]]; then
	echo "install_hpc_env.sh: CUDA 12.1 is detected."
else
	echo "install_hpc_env.sh: CUDA 12.1 not detected. Please install CUDA 12.1. Exiting..."
	exit 1
fi
#! Install `flash-attn` unless it is already present in the project env.
# `uv pip freeze` reads the env directly. It neither needs `pip` inside the env
# nor triggers the project sync that `uv run` performs. The anchored,
# case-insensitive match accepts both `flash-attn` and `flash_attn` spellings
# and avoids matching unrelated packages that merely contain the substring.
if uv pip freeze 2>/dev/null | grep -qiE '^flash[-_]attn([=@ ]|$)'; then
	echo "install_hpc_env.sh: flash-attn is already installed."
else
	echo "install_hpc_env.sh: Installing flash-attn..."
	# --no-build-isolation builds against packages already in the env
	# (presumably torch - confirm).
	if ! uv pip install -q flash-attn==2.3.2 --no-build-isolation; then
		echo "install_hpc_env.sh: flash-attn installation failed. Exiting..." >&2
		exit 1
	fi
fi
#! Quieten Python warnings (the CSD3 default is 'debug').
#! Filter syntax: https://docs.python.org/3/using/cmdline.html#envvar-PYTHONWARNINGS
#! and https://docs.python.org/3/library/warnings.html#describing-warning-filters
#! Each comma-separated entry is an `action::category` filter.
PYTHONWARNINGS="ignore::DeprecationWarning,ignore::ResourceWarning"
export PYTHONWARNINGS
#! Require the exact interpreter version this project expects (3.10.13);
#! `python` resolves to whatever is first on PATH at this point.
PYTHON_OUTPUT=$(python --version)
case "$PYTHON_OUTPUT" in
*"3.10.13"*)
	echo "install_hpc_env.sh: Python 3.10.13 is detected."
	;;
*)
	echo "install_hpc_env.sh: Python 3.10.13 not detected. Please install Python 3.10.13. Exiting..."
	exit 1
	;;
esac
#! Final message
echo "install_hpc_env.sh: Environment is ready."

# The Triton cache location is system dependent, so pin it under
# '/home/<username>' for the current user and export it to child processes.
printf -v TRITON_CACHE_DIR '/home/%s/.triton_cache' "$(whoami)"
export TRITON_CACHE_DIR
