#!/bin/bash

# Step 1: Specify Miniconda installation directory (default is $HOME/miniconda)
# The first positional argument, if non-empty, overrides the default path.
INSTALL_DIR=${1:-$HOME/miniconda}

# Step 2: Install Miniconda
# Skipped when a conda executable already exists in INSTALL_DIR, so the script
# can be re-run (the batch installer refuses to write into an existing prefix).
if [ -x "$INSTALL_DIR/bin/conda" ]; then
    echo "Miniconda already installed in $INSTALL_DIR, skipping installation."
else
    echo "Installing Miniconda in $INSTALL_DIR..."
    # Pick the installer for this CPU architecture; on x86_64 this is the same
    # Miniconda3-latest-Linux-x86_64.sh URL as before.
    MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-$(uname -m).sh"
    # Download into a private temp file (not a fixed path in $HOME) and remove
    # it afterwards regardless of the outcome.
    MINICONDA_SH=$(mktemp) || { echo "Could not create a temporary file!" >&2; exit 1; }
    wget "$MINICONDA_URL" -O "$MINICONDA_SH" && bash "$MINICONDA_SH" -b -p "$INSTALL_DIR"
    MINICONDA_STATUS=$?
    rm -f -- "$MINICONDA_SH"
    if [ "$MINICONDA_STATUS" -ne 0 ]; then
        echo "Miniconda installation failed!" >&2
        exit 1
    fi
fi

# Load conda's shell functions into this shell so `conda activate` works later.
# The hook output is captured first so a failing conda binary is detected
# (eval of an empty string would otherwise "succeed").
CONDA_HOOK=$("$INSTALL_DIR/bin/conda" shell.bash hook) || { echo "Could not run $INSTALL_DIR/bin/conda!" >&2; exit 1; }
eval "$CONDA_HOOK"
# Register conda in the shell startup files for future interactive sessions.
conda init
# ~/.bashrc is intentionally not sourced here: a typical ~/.bashrc returns
# early in non-interactive shells, and the hook above already set up conda.

# Step 3: Create Conda Environment
# Creation is skipped when an environment named arc_env is already listed by
# `conda env list`, so re-running the script does not prompt about or rebuild it.
if conda env list | awk '$1 == "arc_env" { found = 1 } END { exit !found }'; then
    echo "Conda environment 'arc_env' already exists, skipping creation."
else
    echo "Creating Conda environment 'arc_env'..."
    if ! conda create -n arc_env python=3.10 -y; then
        echo "Failed to create Conda environment 'arc_env'!" >&2
        exit 1
    fi
fi

eval "$(conda shell.bash hook)"
# Abort if activation fails; otherwise the pip installs that follow would
# silently target whichever Python environment happens to be active.
if ! conda activate arc_env; then
    echo "Failed to activate Conda environment 'arc_env'!" >&2
    exit 1
fi

# Step 4: Install Dependencies from requirements.txt
# requirements.txt is looked up in the current working directory. The script
# stops if the file is missing or if pip fails to install any package.
echo "Installing packages from requirements.txt..."
if [ ! -f requirements.txt ]; then
    echo "requirements.txt file not found!" >&2
    exit 1
fi
# `python3 -m pip` guarantees pip belongs to the active (arc_env) interpreter.
if ! python3 -m pip install -r requirements.txt; then
    echo "Failed to install packages from requirements.txt!" >&2
    exit 1
fi

# Step 5: Export current directory to PYTHONPATH
# The ':' separator is added only when PYTHONPATH is already non-empty: a
# leading empty entry would be interpreted as "the current directory", adding
# whatever directory Python is launched from to its module search path.
echo "Exporting current working directory to PYTHONPATH..."
PROJECT_DIR=$(pwd)
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$PROJECT_DIR"

# Optional: Save the PYTHONPATH to .bashrc for future sessions
# The single-quoted format keeps ${PYTHONPATH...} literal so it expands when
# ~/.bashrc is read; %q shell-quotes the directory so spaces/quotes survive.
# shellcheck disable=SC2016
printf -v PYTHONPATH_LINE 'export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}"%q' "$PROJECT_DIR"
# Append only if this exact line is not present yet, so re-runs don't pile up
# duplicate entries.
if ! grep -qxF -- "$PYTHONPATH_LINE" ~/.bashrc 2>/dev/null; then
    printf '%s\n' "$PYTHONPATH_LINE" >> ~/.bashrc
fi

# Step 6: Install JAX with CUDA support
# No hardware detection is performed: this step always installs the CUDA 12
# build of JAX (jax[cuda12]) together with requirements_gpu.txt, and the
# messages below say so.
# NOTE(review): the -f link points at the libtpu wheel index, which looks
# unused by jax[cuda12] — confirm and drop it if so.
if [ ! -f requirements_gpu.txt ]; then
    echo "requirements_gpu.txt file not found!" >&2
    exit 1
fi
echo "Installing JAX with CUDA 12 support..."
if ! python3 -m pip install -U -r requirements_gpu.txt "jax[cuda12]" \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; then
    echo "JAX installation failed!" >&2
    exit 1
fi

# Step 7: Final confirmation message (only reached if every step above succeeded)
echo "GPU setup complete! Conda environment 'arc_env' is activated, JAX installed, and PYTHONPATH is set."

# Ensure GitHub credentials are stored after the first time they are entered.
# The "store" helper saves them in plain text in ~/.git-credentials (protected
# only by file permissions) and reuses them for later git operations.
# --global applies this to every repository of the current user.
# NOTE(review): on shared machines consider a cache or OS-keychain helper instead.
git config --global credential.helper store