#!/usr/bin/env bash
#
# Launch eval_sae_aware_lsh.py for a Gemma model inside the `jax` conda env.
# The run passes --no_load_sae_weights and --no_lsh, evaluates 10000 samples,
# and points the script at the SAE model, SAE code, and dual-map MLP files
# configured below.
#
# Requires: conda on PATH with an environment named `jax`; HOME set.

die() {
    printf 'error: %s\n' "$*" >&2
    exit 1
}

: "${HOME:?HOME must be set}"

# Capture the hook separately: in `eval "$(cmd)"` the exit status is eval's
# (0 for an empty string), so a missing/broken conda would go unnoticed.
conda_hook=$(conda shell.bash hook) || die "'conda shell.bash hook' failed (is conda installed and on PATH?)"
eval "$conda_hook"
conda activate jax || die "could not activate conda environment 'jax'"

# Strict mode is enabled only after activation so that conda's own shell
# functions/activation scripts are not run under -u/-e.
# NOTE(review): assumption based on known conda/set -u incompatibilities.
set -euo pipefail

# Do not let JAX/XLA reserve a large chunk of GPU memory up front; allocate
# on demand instead (see JAX "GPU memory allocation" docs).
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

# Changeable parameters. Kept exported in case the Python side reads them
# from the environment.
export GEMMA_MODEL_NAME_SHORT="gemma-2-2b"
export GEMMA_MODEL_NAME="google/${GEMMA_MODEL_NAME_SHORT}"
export SAE_MODEL_PATH="${HOME}/${GEMMA_MODEL_NAME_SHORT}-sae/k5_final_sae_model.pkl"
export SAE_CODE_PATH="${HOME}/${GEMMA_MODEL_NAME_SHORT}-sae/k5_whole_sae_final_z.npy"
export MLP_MODEL_PATH="${HOME}/dual-map/model/${GEMMA_MODEL_NAME_SHORT}/dual_map_${GEMMA_MODEL_NAME_SHORT}.pt"

# All expansions are quoted so paths containing spaces/glob chars stay intact.
# The script's exit status is python's (last command).
python "${HOME}/src/eval/sae-softmax/eval_sae_aware_lsh.py" \
    --no_load_sae_weights \
    --no_lsh \
    --total_samples 10000 \
    --model_name "$GEMMA_MODEL_NAME" \
    --sae_model_path "$SAE_MODEL_PATH" \
    --sae_code_path "$SAE_CODE_PATH" \
    --mlp_model_path "$MLP_MODEL_PATH"
