#!/usr/bin/env bash
#
# Run the sae-jax pipeline for a Gemma model: preprocess, then train.
#
# Requires a conda installation providing an environment named "jax".
#
# Optional environment overrides:
#   GEMMA_MODEL_NAME  value passed as --model_name to train_jax.py
#                     (default: google/gemma-2-2b)
#   SAE_JAX_DIR       directory containing train_jax.py
#                     (default: $HOME/src/sae-jax)

# Stop at the first failing command so training never runs after a failed
# preprocessing step. -u is deliberately not enabled: conda's activation
# scripts are known to reference unset variables.
set -eo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

# Load conda's shell integration. Capture the hook output first: running
# `eval "$(...)"` directly would silently succeed on an empty string when
# conda is missing or broken.
conda_hook=$(conda shell.bash hook) || die "error: could not load conda shell hook"
eval "$conda_hook"
conda activate jax || die "error: could not activate conda environment 'jax'"

# JAX GPU memory: do not preallocate a large fraction of GPU memory up front;
# allocate on demand through the platform allocator instead.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

export GEMMA_MODEL_NAME="${GEMMA_MODEL_NAME:-google/gemma-2-2b}"
readonly sae_jax_dir="${SAE_JAX_DIR:-$HOME/src/sae-jax}"
readonly train_script="$sae_jax_dir/train_jax.py"
[[ -f "$train_script" ]] || die "error: $train_script not found"

# The two stages run strictly in order; a preprocess failure aborts the script.
python "$train_script" --model_name="$GEMMA_MODEL_NAME" --mode="preprocess"
python "$train_script" --model_name="$GEMMA_MODEL_NAME" --mode="train"
