#!/bin/bash

# -----------------------------------------
# Setup script for a multi-framework GPU/CPU conda environment
# Supports:
#   - Linux + CUDA 11.8 (JAX and PyTorch on GPU; TensorFlow CPU-only)
#   - macOS (CPU-only JAX, PyTorch, and TensorFlow; targets osx-arm64)
# -----------------------------------------

# Exit immediately on any unhandled error.
# NOTE: -u (nounset) is deliberately not enabled: conda's activation shell
# functions are not guaranteed to be nounset-safe.
set -e

# Detect platform (the install section branches on "Linux" / "Darwin")
OS=$(uname)
ENV_NAME=gradpca_env

echo "🔍 Detected OS: $OS"

# Ensure conda is available in the script environment.
# `eval` of an empty string returns 0, so a missing conda would NOT trip
# `set -e` on its own; check explicitly and fail with a clear message.
if ! command -v conda >/dev/null 2>&1; then
  echo "❌ conda not found on PATH; install Miniconda/Anaconda first" >&2
  exit 1
fi
# Capture the hook first: a failing command substitution in an assignment
# does trigger `set -e`, unlike one nested inside eval's arguments.
conda_hook=$(conda shell.bash hook)
eval "$conda_hook"

# Create or reuse the environment.
# Compare the env name exactly against the first column of the env listing.
# A substring grep would also match names like "gradpca_env_old" or any env
# path containing the name; creation would then be skipped and
# `conda activate` would fail.
if conda info --envs \
  | awk -v name="$ENV_NAME" '$1 == name { found = 1 } END { exit !found }'; then
  echo "✅ Conda environment '$ENV_NAME' already exists"
else
  echo "📦 Creating conda environment '$ENV_NAME'"
  conda create -n "$ENV_NAME" python=3.10 -y
fi

conda activate "$ENV_NAME"
# Run pip through the activated interpreter so every install below targets
# this environment rather than whichever `pip` happens to be first on PATH.
python -m pip install --upgrade pip

# Install shared dependencies. requirements.txt is read from the current
# working directory and is assumed not to pin jaxlib or torch (those are
# installed per-platform below).
python -m pip install -r requirements.txt

# Platform-specific framework installs.
# Package versions are pinned identically in both branches; only the wheel
# variants differ (CUDA 11 local-version wheels on Linux, plain wheels on
# macOS).
case "$OS" in
  Linux)
    echo "📦 Installing TensorFlow (CPU-only)"
    python -m pip install tensorflow-cpu==2.16.2 tensorflow_datasets

    echo "🚀 Installing JAX with CUDA 11.8"
    # The +cuda11.cudnn86 jaxlib build is resolved via the find-links page.
    python -m pip install jax==0.4.23 flax==0.7.5 optax==0.1.7 \
      ml-dtypes==0.3.2 tensorstore jaxlib==0.4.23+cuda11.cudnn86 \
      -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

    echo "🔥 Installing PyTorch with CUDA 11.8"
    # The +cu118 torch/torchvision builds are resolved via the find-links page.
    python -m pip install detectors==0.1.11 timm==1.0.15 \
      torch==2.1.2+cu118 torchvision==0.16.2+cu118 \
      -f https://download.pytorch.org/whl/torch_stable.html
    ;;

  Darwin)
    echo "🍏 Installing CPU-only JAX, Torch, and TensorFlow for macOS"
    python -m pip install tensorflow==2.16.2 tensorflow_datasets
    python -m pip install jax==0.4.23 jaxlib==0.4.23 flax==0.7.5 \
      optax==0.1.7 ml-dtypes==0.3.2 tensorstore
    python -m pip install torch==2.1.2 torchvision==0.16.2 \
      detectors==0.1.11 timm==1.0.15
    ;;

  *)
    # Diagnostics go to stderr so they are not lost when stdout is captured.
    echo "❌ Unsupported OS: $OS" >&2
    exit 1
    ;;
esac
