#!/bin/bash
#SBATCH -A p31796
#SBATCH -p gengpu
#SBATCH --gres=gpu:a100:1
#SBATCH -N 1
#SBATCH -n 1
#SBATCH -t 1:00:00
#SBATCH --mem=32G

# Build a prefix-based conda environment holding a CUDA 12.4 build of PyTorch,
# install the extra Python packages into it, and print the GPU configuration
# that PyTorch reports on the allocated node.
#
# Every step that later steps depend on is checked explicitly: without that, a
# failed create/activate would let the pip install land in whatever Python
# happens to be first on PATH, and the script would still claim success.
# (No `set -e`/`set -u` here: module and conda activation scripts are sourced
# into this shell and are commonly not written to be safe under them.)

# Print an error message to stderr and abort the job with a non-zero status.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

# Define environment name/path (relative, so it is resolved against the
# directory the job runs in).
ENV_NAME="./node-cuda-12-4"

# Load required modules
module purge
module load mamba/24.3.0
command -v mamba >/dev/null 2>&1 \
  || die "mamba is not available after loading module mamba/24.3.0"

# Create a new conda environment with CUDA-enabled PyTorch.
# CONDA_OVERRIDE_CUDA is passed only to this one command.
# The package spec is quoted: unquoted, `[build=*cuda12.4*]` is a shell glob
# and could be expanded against file names in the current directory instead of
# reaching mamba literally.
echo "Creating conda environment with CUDA support..."
CONDA_OVERRIDE_CUDA="12.4" mamba create --prefix "$ENV_NAME" \
  -c nvidia -c pytorch "pytorch[build=*cuda12.4*]" -y \
  || die "failed to create conda environment at $ENV_NAME"

# Activate the environment
source activate "$ENV_NAME" \
  || die "failed to activate conda environment at $ENV_NAME"

# Install other dependencies.
# `python -m pip` ties the install to the interpreter that is now first on
# PATH, and fails loudly if that interpreter has no pip, rather than silently
# using some other pip executable.
echo "Installing additional dependencies..."
python -m pip install numpy matplotlib seaborn pandas plotly torchdyn \
  || die "pip install of additional dependencies failed"

# Verify GPU access and PyTorch CUDA configuration
echo "Testing GPU configuration..."
python - <<'PY' || die "PyTorch GPU verification script failed"
import torch

print('CUDA available:', torch.cuda.is_available())
print('CUDA device count:', torch.cuda.device_count())
print('CUDA device name:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU available')
print('PyTorch version:', torch.__version__)
PY

echo "Environment setup complete!"