###################################################################################
###################################################################################


# Shell-session setup for the extract_merge1 project.
# NOTE(review): `workon` is not defined in this file — presumably it comes from
# virtualenvwrapper, so these lines need a shell that already has it loaded.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The Python entry points below are addressed with relative ./scripts1/... paths,
# so the working directory has to be the code checkout.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append the checkout to PYTHONPATH. Entries are colon-separated, so the
# directory must be the final entry verbatim: any extra trailing character
# (such as a comma) would become part of the path. No separator is emitted when
# PYTHONPATH is empty or unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"
# Exported a second time after `workon` (assumption: in case activating the
# virtualenv changed it — TODO confirm the repeat is still needed).
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

###################################################################################
###################################################################################

# Experiment directory layout: everything for this experiment series is rooted
# at ${DATA_DIR}/ll1, with one sub-directory per artifact type.
EXPS_DIR=${DATA_DIR}/ll1
DATASETS_DIR=${EXPS_DIR}/datasets
MODELS_DIR=${EXPS_DIR}/models
FISHER_DIR=${EXPS_DIR}/fishers
PER_EXAMPLE_FISHERS_DIR=${EXPS_DIR}/per_example_fishers

###################################################################################


# Make HANS dataset.

# Maybe


###################################################################################


# MNLI heuristic cluster model numbers: 15, 25, 26, 42, 44, 46, 56, 61, 63, 70, 73, 89

# Run save_per_example_fishers_to_disk.py on the glue/mnli validation split for
# one feather_berts checkpoint.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number (selects connectivity/feather_berts_<n>).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.mnli.no_embeddings.16k.16k.h5 (16384 examples, 16384
#   Fisher values per example, embeddings excluded — per the flags below).
compute_pefs_mnli() {
    # Everything is `local` so that a call does not leave device/model_num/
    # MODEL/PER_EXAMPLES_FISHERS behind in (or overwrite them in) the caller's shell.
    local device=$1
    local model_num=$2

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.mnli.no_embeddings.16k.16k.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}"  \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=glue/mnli \
        --split=validation \
        --n_examples=16384 \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example=16384 \
        --sequence_length=128 \
        --include_embeddings=false
}


# MNLI per-example Fisher runs. Arguments are <gpu_index> <model_num>; each
# call names a different GPU index.

# Generalizing cluster.
compute_pefs_mnli 0 0
compute_pefs_mnli 1 1

# Heuristic cluster.
compute_pefs_mnli 2 15
compute_pefs_mnli 3 25

###################################################################################

# Run save_per_example_fishers_to_disk.py on the hans/default validation split
# for one feather_berts checkpoint.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number (selects connectivity/feather_berts_<n>).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans.no_embeddings.16k.16k.h5
compute_pefs_hans() {
    # Everything is `local` so that a call does not leave device/model_num/
    # MODEL/PER_EXAMPLES_FISHERS behind in (or overwrite them in) the caller's shell.
    local device=$1
    local model_num=$2

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.hans.no_embeddings.16k.16k.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}"  \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=hans/default \
        --split=validation \
        --n_examples=16384 \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example=16384 \
        --sequence_length=128 \
        --include_embeddings=false
}


# HANS per-example Fisher runs. Arguments are <gpu_index> <model_num>; each
# call names a different GPU index.

# Generalizing cluster.
compute_pefs_hans 0 0
compute_pefs_hans 1 1

# Heuristic cluster.
compute_pefs_hans 2 15
compute_pefs_hans 3 25

###################################################################################


# Run run_per_subset_nmf.py over the MNLI per-example Fisher file of one
# feather_berts checkpoint, using 256 NMF components.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE: beta=2 and tol=1e-9 differ from the author's previous work. The loss
# still seemed to be decreasing when NMF stopped, so a smaller tol might be OK.
nmf_mnli() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=256

    local pef_file="feather_berts_${ckpt}.mnli.no_embeddings.16k.16k.h5"
    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.16k.16k.256.${pef_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=16384
        --start_fisher_index=0
        --end_fisher_index=16384
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}

# 12,13, 16,17, 20,21, 22,23
# NOTE(review): the list above is presumably a set of subset indices of
# particular interest — confirm; the calls below pass all of 0..24.

# MNLI NMF runs. Arguments are <gpu_index> <model_num> <subset_indices>.

# Generalizing cluster.
nmf_mnli 0 0 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_mnli 1 1 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24

# Heuristic cluster.
nmf_mnli 2 15 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_mnli 3 25 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24

###################################################################################


# Run run_per_subset_nmf.py over the HANS per-example Fisher file of one
# feather_berts checkpoint, using 256 NMF components.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE: beta=2 and tol=1e-9 differ from the author's previous work. The loss
# still seemed to be decreasing when NMF stopped, so a smaller tol might be OK.
nmf_hans() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=256

    local pef_file="feather_berts_${ckpt}.hans.no_embeddings.16k.16k.h5"
    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.16k.16k.256.${pef_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=16384
        --start_fisher_index=0
        --end_fisher_index=16384
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}

# HANS NMF runs. Arguments are <gpu_index> <model_num> <subset_indices>; every
# call passes subset indices 0..24.

# Generalizing cluster.
nmf_hans 0 0 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans 1 1 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24

# Heuristic cluster.
nmf_hans 2 15 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans 3 25 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


###################################################################################
###################################################################################
###################################################################################
###################################################################################
###################################################################################
###################################################################################
###################################################################################
###################################################################################

# Experiment directory layout (re-stated so this section can be run on its
# own): everything is rooted at ${DATA_DIR}/ll1.
EXPS_DIR=${DATA_DIR}/ll1
DATASETS_DIR=${EXPS_DIR}/datasets
MODELS_DIR=${EXPS_DIR}/models
FISHER_DIR=${EXPS_DIR}/fishers
PER_EXAMPLE_FISHERS_DIR=${EXPS_DIR}/per_example_fishers

# Run save_per_example_fishers_to_disk.py on the glue/mnli validation split for
# one feather_berts checkpoint. (A function of the same name is also defined
# earlier in this file; whichever definition ran last is the one in effect.)
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number (selects connectivity/feather_berts_<n>).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.mnli.no_embeddings.16k.16k.h5
compute_pefs_mnli() {
    # Everything is `local` so that a call does not leave device/model_num/
    # MODEL/PER_EXAMPLES_FISHERS behind in (or overwrite them in) the caller's shell.
    local device=$1
    local model_num=$2

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.mnli.no_embeddings.16k.16k.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}"  \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=glue/mnli \
        --split=validation \
        --n_examples=16384 \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example=16384 \
        --sequence_length=128 \
        --include_embeddings=false
}

# Run save_per_example_fishers_to_disk.py on the hans/default validation split
# for one feather_berts checkpoint. (A function of the same name is also defined
# earlier in this file; whichever definition ran last is the one in effect.)
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number (selects connectivity/feather_berts_<n>).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans.no_embeddings.16k.16k.h5
compute_pefs_hans() {
    # Everything is `local` so that a call does not leave device/model_num/
    # MODEL/PER_EXAMPLES_FISHERS behind in (or overwrite them in) the caller's shell.
    local device=$1
    local model_num=$2

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.hans.no_embeddings.16k.16k.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}"  \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=hans/default \
        --split=validation \
        --n_examples=16384 \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example=16384 \
        --sequence_length=128 \
        --include_embeddings=false
}

# Run run_per_subset_nmf.py over the MNLI per-example Fisher file of one
# feather_berts checkpoint, using 256 NMF components. (A function of the same
# name is also defined earlier in this file; the last one sourced wins.)
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE: beta=2 and tol=1e-9 differ from the author's previous work. The loss
# still seemed to be decreasing when NMF stopped, so a smaller tol might be OK.
nmf_mnli() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=256

    local pef_file="feather_berts_${ckpt}.mnli.no_embeddings.16k.16k.h5"
    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.16k.16k.256.${pef_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=16384
        --start_fisher_index=0
        --end_fisher_index=16384
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}

# Run run_per_subset_nmf.py over the HANS per-example Fisher file of one
# feather_berts checkpoint, using 256 NMF components. (A function of the same
# name is also defined earlier in this file; the last one sourced wins.)
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE: beta=2 and tol=1e-9 differ from the author's previous work. The loss
# still seemed to be decreasing when NMF stopped, so a smaller tol might be OK.
nmf_hans() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=256

    local pef_file="feather_berts_${ckpt}.hans.no_embeddings.16k.16k.h5"
    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.16k.16k.256.${pef_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=16384
        --start_fisher_index=0
        --end_fisher_index=16384
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}

###################################################################################


# Full pipeline, grouped per model: compute the MNLI and HANS per-example
# Fisher files first, then run the NMF steps that read those files. Within a
# group every step is given the same GPU index (first argument).

# Model 0, GPU index 0.
compute_pefs_mnli 0 0
compute_pefs_hans 0 0
nmf_mnli 0 0 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans 0 0 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24



# Model 1, GPU index 1.
compute_pefs_mnli 1 1
compute_pefs_hans 1 1
nmf_mnli 1 1 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans 1 1 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24




# Model 15, GPU index 2.
compute_pefs_mnli 2 15
compute_pefs_hans 2 15
nmf_mnli 2 15 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans 2 15 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24



# Model 25, GPU index 3.
compute_pefs_mnli 3 25
compute_pefs_hans 3 25
nmf_mnli 3 25 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans 3 25 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


###################################################################################
###################################################################################
###################################################################################
###################################################################################






# Run save_per_example_fishers_to_disk.py on the hans/lexical_overlap_ne
# validation split for one feather_berts checkpoint.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans_lone.no_embeddings.5k.32k.h5 (5000 examples, 32768
#   Fisher values per example, sequence length 64 — per the flags below).
compute_pefs_hans_lone() {
    local gpu=$1
    local ckpt=$2

    local -a pef_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_${ckpt}.hans_lone.no_embeddings.5k.32k.h5"
        --trained_model="connectivity/feather_berts_${ckpt}"
        --from_pt_trained=true
        --tokenizer=bert-base-uncased
        --task=hans/lexical_overlap_ne
        --split=validation
        --n_examples=5000
        --batch_size=4
        --expectation_wrt_logits=true
        --flavor=sparse_dynamic_raw
        --n_fisher_values_per_example=32768
        --sequence_length=64
        --include_embeddings=false
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"
}



# Run run_per_subset_nmf.py over the hans_lone per-example Fisher file of one
# feather_berts checkpoint, using 256 NMF components.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE: beta=2 and tol=1e-9 differ from the author's previous work. The loss
# still seemed to be decreasing when NMF stopped, so a smaller tol might be OK.
nmf_hans_lone() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=256

    local pef_file="feather_berts_${ckpt}.hans_lone.no_embeddings.5k.32k.h5"
    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.5k.32k.256.${pef_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=5000
        --start_fisher_index=0
        --end_fisher_index=32768
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}


# hans_lone pipeline, one pair of calls per model: first write the per-example
# Fisher file, then run NMF on it (the NMF step reads the file written by the
# call just above it). Arguments start with <gpu_index> <model_num>.
compute_pefs_hans_lone 0 0
nmf_hans_lone 0 0 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


compute_pefs_hans_lone 1 1
nmf_hans_lone 1 1 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


compute_pefs_hans_lone 2 15
nmf_hans_lone 2 15 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


compute_pefs_hans_lone 3 25
nmf_hans_lone 3 25 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


###################################################################################


# Run save_per_example_fishers_to_disk.py on the hans/lexical_overlap_ye
# validation split for one feather_berts checkpoint.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans_loye.no_embeddings.5k.32k.h5
compute_pefs_hans_loye() {
    local gpu=$1
    local ckpt=$2

    local -a pef_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_${ckpt}.hans_loye.no_embeddings.5k.32k.h5"
        --trained_model="connectivity/feather_berts_${ckpt}"
        --from_pt_trained=true
        --tokenizer=bert-base-uncased
        --task=hans/lexical_overlap_ye
        --split=validation
        --n_examples=5000
        --batch_size=4
        --expectation_wrt_logits=true
        --flavor=sparse_dynamic_raw
        --n_fisher_values_per_example=32768
        --sequence_length=64
        --include_embeddings=false
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"
}


# Run transform_using_nmf.py on the hans_loye per-example Fisher file of one
# checkpoint, passing the existing hans_lone NMF decomposition (256 components)
# of the same checkpoint as --decomposition.
#
# Arguments:
#   $1 - feather_berts model number.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds both inputs and receives the output.
# CUDA_VISIBLE_DEVICES is set to the empty string for the Python process, so no
# GPU index is handed to it.
# Output file:
#   nmf_transformed.per_sub_block.5k.32k.256.feather_berts_<n>.hans_loye.no_embeddings.5k.32k.h5
transform_loye_using_nmf() {
    local ckpt=$1
    local rank=256

    local loye_pefs="feather_berts_${ckpt}.hans_loye.no_embeddings.5k.32k.h5"
    local lone_pefs="feather_berts_${ckpt}.hans_lone.no_embeddings.5k.32k.h5"
    local -a xform_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_transformed.per_sub_block.5k.32k.256.${loye_pefs}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${loye_pefs}"
        --decomposition="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.5k.32k.256.${lone_pefs}"
        --n_examples=5000
        --start_fisher_index=0
        --end_fisher_index=32768
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --subset_style=per_sub_block
        --pef_embeddings=false
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES='' python ./scripts1/decomp/transform_using_nmf.py "${xform_args[@]}"
}


# hans_loye per-example Fishers, all on GPU index 3. The model-0 run is
# commented out.
# NOTE(review): transform_loye_using_nmf 0 below still reads the model-0
# hans_loye file, so that file presumably already exists — confirm.
# compute_pefs_hans_loye 3 0
compute_pefs_hans_loye 3 1
compute_pefs_hans_loye 3 15
compute_pefs_hans_loye 3 25


# Transform each model's hans_loye file using its hans_lone decomposition.
transform_loye_using_nmf 0


transform_loye_using_nmf 1
transform_loye_using_nmf 15
transform_loye_using_nmf 25

###################################################################################


# Variant of the hans_lone NMF run in which the number of NMF components is a
# parameter instead of a fixed 256.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
#   $4 - number of NMF components (also embedded in the output file name).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE(review): the output file name depends on the model number and component
# count only, not on the subset indices, so runs that differ only in $3 target
# the same path — confirm the Python script handles that as intended.
#
# NOTE: beta=2 and tol=1e-9 differ from the author's previous work. The loss
# still seemed to be decreasing when NMF stopped, so a smaller tol might be OK.
nmf_hans_lone2() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=$4

    local pef_file="feather_berts_${ckpt}.hans_lone.no_embeddings.5k.32k.h5"
    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/nmf_decomp.per_sub_block.5k.32k.${rank}.${pef_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=5000
        --start_fisher_index=0
        --end_fisher_index=32768
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}

# Component-count sweep for model 0 on GPU index 3: n_components in
# {2, 4, 8, 16, 32, 64, 128, 512}, first for subset index 17, then for 18.
# NOTE(review): 256 is absent — with n_components=256 the output path would be
# the same one nmf_hans_lone writes for model 0; presumably skipped for that
# reason. Also, the 17 and 18 runs with equal n_components share an output
# path (the file name does not include the subset index) — confirm intended.
nmf_hans_lone2 3 0 17 2
nmf_hans_lone2 3 0 17 4
nmf_hans_lone2 3 0 17 8
nmf_hans_lone2 3 0 17 16
nmf_hans_lone2 3 0 17 32
nmf_hans_lone2 3 0 17 64
nmf_hans_lone2 3 0 17 128
nmf_hans_lone2 3 0 17 512



nmf_hans_lone2 3 0 18 2
nmf_hans_lone2 3 0 18 4
nmf_hans_lone2 3 0 18 8
nmf_hans_lone2 3 0 18 16
nmf_hans_lone2 3 0 18 32
nmf_hans_lone2 3 0 18 64
nmf_hans_lone2 3 0 18 128
nmf_hans_lone2 3 0 18 512


###################################################################################
###################################################################################

# Run compute_fisher.py on the hans/lexical_overlap_ne validation split for one
# feather_berts checkpoint.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
# Globals:
#   FISHER_DIR (read) - directory that receives the output file.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
# Output file:
#   feather_berts_<n>.hans_lone.validation.h5
compute_dense_fisher() {
    set -u

    local device=$1
    local model_num=$2

    local MODEL="connectivity/feather_berts_${model_num}"
    local FISHER_FILENAME="feather_berts_${model_num}.hans_lone.validation.h5"

    # --model is quoted like every other expansion here so the value always
    # reaches Python as a single argument (no word splitting or globbing).
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py  \
        --fisher_path="${FISHER_DIR}/${FISHER_FILENAME}" \
        --model="${MODEL}" \
        --tokenizer="bert-base-uncased" \
        --task=hans/lexical_overlap_ne \
        --split=validation \
        --n_examples=5000 \
        --batch_size=4 \
        --sequence_length=64

}


# Fisher computation on hans_lone for the four models.
# Arguments are <gpu_index> <model_num>.
compute_dense_fisher 0 0
compute_dense_fisher 1 1
compute_dense_fisher 2 15
compute_dense_fisher 3 25


###################################################################################
###################################################################################



# Run save_per_example_fishers_to_disk.py on the hans/lexical_overlap_ne
# validation split with --flavor=sparse_dynamic_metric_derived, additionally
# passing bert-base-uncased as --pretrained_model.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans_lone.no_embeddings.metric_derived.5k.131k.h5
#   (5000 examples, 131072 Fisher values per example — per the flags below).
compute_pefs_hans_lone_md() {
    local gpu=$1
    local ckpt=$2

    local -a pef_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_${ckpt}.hans_lone.no_embeddings.metric_derived.5k.131k.h5"
        --trained_model="connectivity/feather_berts_${ckpt}"
        --from_pt_trained=true
        --tokenizer=bert-base-uncased
        --task=hans/lexical_overlap_ne
        --split=validation
        --n_examples=5000
        --batch_size=4
        --expectation_wrt_logits=true
        --flavor=sparse_dynamic_metric_derived
        --n_fisher_values_per_example=131072
        --sequence_length=64
        --include_embeddings=false
        --pretrained_model=bert-base-uncased
        --from_pt_pretrained=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"
}

# Metric-derived hans_lone per-example Fishers for the four models.
# Arguments are <gpu_index> <model_num>.
compute_pefs_hans_lone_md 0 0
compute_pefs_hans_lone_md 1 1
compute_pefs_hans_lone_md 2 15
compute_pefs_hans_lone_md 3 25



# Run run_per_subset_nmf.py over the metric-derived hans_lone per-example
# Fisher file of one checkpoint, with --use_metric_derived_distances=true and
# 256 NMF components.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - comma-separated subset indices, forwarded verbatim as --subset_indices.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
#
# NOTE: the original note here says this uses a beta of 2 and tol of 1e-12,
# which differs from the author's previous work, and that a smaller tol might
# be OK (loss seemed to be still decreasing at stop).
# NOTE(review): the command below actually passes --nmf_tol=1e-9, not 1e-12 —
# confirm which value the "v2" outputs are meant to use.
nmf_hans_lone_md() {
    set -u

    local gpu=$1
    local ckpt=$2
    local subsets=$3
    local rank=256

    local pef_file="feather_berts_${ckpt}.hans_lone.no_embeddings.metric_derived.5k.131k.h5"
    # Earlier output name: nmf_decomp.per_sub_block.5k.131k.256.<pef_file>
    # Changed normalization and tol for v2, hence the ".v2" tag in the name.
    local out_file="nmf_decomp.v2.per_sub_block.5k.131k.256.${pef_file}"

    local -a nmf_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${out_file}"
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_file}"
        --n_examples=5000
        --start_fisher_index=0
        --end_fisher_index=131072
        --nmf_n_components="${rank}"
        --nmf_beta=2.0
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-9
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="${subsets}"
        --model="connectivity/feather_berts_${ckpt}"
        --from_pt=true
        --use_metric_derived_distances=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_args[@]}"
}

# Metric-derived NMF runs. Arguments are <gpu_index> <model_num> <subset_indices>.
# Note the GPU/model pairing here: model 15 is given GPU index 1 and model 1 is
# given GPU index 2, the reverse of the pairing used in the other sections.
nmf_hans_lone_md 0 0 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans_lone_md 1 15 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans_lone_md 2 1 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
nmf_hans_lone_md 3 25 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24


###################################################################################
###################################################################################
###################################################################################


# Run save_per_example_fishers_to_disk.py on the hans/lexical_overlap_ne
# validation split for the prajjwal1/bert-mini-mnli model (batch size 16).
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   bert_mini_mnli.hans_lone.no_embeddings.5k.131k.h5
compute_pefs_hans_lone_mini() {
    local gpu=$1

    local -a pef_args=(
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/bert_mini_mnli.hans_lone.no_embeddings.5k.131k.h5"
        --trained_model=prajjwal1/bert-mini-mnli
        --from_pt_trained=true
        --tokenizer=bert-base-uncased
        --task=hans/lexical_overlap_ne
        --split=validation
        --n_examples=5000
        --batch_size=16
        --expectation_wrt_logits=true
        --flavor=sparse_dynamic_raw
        --n_fisher_values_per_example=131072
        --sequence_length=64
        --include_embeddings=false
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"
}


# Compute the bert-mini hans_lone per-example Fishers, passing GPU index 0.
compute_pefs_hans_lone_mini 0

###################################################################################

# Run run_nmf.py over the bert-mini hans_lone per-example Fisher file.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - number of NMF components (also embedded in the output file name).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
# Output file:
#   nmf_decomp.full.5k.32k.<n_components>.bert_mini_mnli.hans_lone.no_embeddings.5k.131k.h5
#   The input file is the "131k" one; the "32k" in the output name matches
#   --end_fisher_index=32768 below.
nmf_hans_lone_mini () {
    set -u

    local device=$1
    local n_components=$2

    # No model identifier is declared here: run_nmf.py is not passed a --model
    # flag, so the previously declared MODEL local was never used.
    local PER_EXAMPLES_FISHERS="bert_mini_mnli.hans_lone.no_embeddings.5k.131k.h5"

    local DECOMP_FILENAME="nmf_decomp.full.5k.32k.${n_components}.${PER_EXAMPLES_FISHERS}"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/decomp/run_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples=5000 \
        --start_fisher_index=0 \
        --end_fisher_index=32768 \
        --nmf_n_components="${n_components}" \
        --nmf_beta=2.0 \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-11
}

# NMF on the bert-mini per-example Fishers: GPU index 0, 256 components.
nmf_hans_lone_mini 0 256

###################################################################################
###################################################################################


# Run run_nmf.py over the hans_lone per-example Fisher file of one feather_berts
# checkpoint (the "full" variant: no subset flags are passed).
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - number of NMF components (also embedded in the output file name).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Side effect:
#   `set -u` is switched on and stays on in the calling shell afterwards.
# Output file:
#   nmf_decomp.full.5k.32k.<n_components>.feather_berts_<n>.hans_lone.no_embeddings.5k.32k.h5
#
# NOTE: this uses a beta of 2 and a tol of 1e-11 (see --nmf_tol below), which
# differs from the author's previous work. It might be OK to set the tol to a
# smaller value (loss seemed to be still decreasing at stop).
nmf_hans_lone_full () {
    set -u

    local device=$1
    local model_num=$2
    local n_components=$3

    # No model identifier is declared: run_nmf.py is not passed a --model flag.
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.hans_lone.no_embeddings.5k.32k.h5"
    local DECOMP_FILENAME="nmf_decomp.full.5k.32k.${n_components}.${PER_EXAMPLES_FISHERS}"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/decomp/run_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples=5000 \
        --start_fisher_index=0 \
        --end_fisher_index=32768 \
        --nmf_n_components="${n_components}" \
        --nmf_beta=2.0 \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-11
}


# "Full" NMF runs with 256 components for the four models.
# Arguments are <gpu_index> <model_num> <n_components>.
nmf_hans_lone_full 0 0 256
nmf_hans_lone_full 1 1 256
nmf_hans_lone_full 2 15 256
nmf_hans_lone_full 3 25 256



###################################################################################
###################################################################################


# Run save_per_example_fishers_to_disk.py on the hans/lexical_overlap_ne
# validation split with --include_classifier=true and a caller-chosen number of
# Fisher values per example. (Unlike the "no_embeddings" wrappers, no
# --include_embeddings flag is passed here.)
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - value for --n_fisher_values_per_example (also embedded in the file name).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans_lone.all_vars.5k.<$3>.h5
compute_pefs_hans_lone_all_vars() {
    local device=$1
    local model_num=$2
    local n_values_per_ex=$3

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.hans_lone.all_vars.5k.${n_values_per_ex}.h5"

    # $n_values_per_ex is quoted so it always arrives as part of one argument.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}"  \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=hans/lexical_overlap_ne \
        --split=validation \
        --n_examples=5000 \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="${n_values_per_ex}" \
        --sequence_length=64 \
        --include_classifier=true
}

# "all_vars" per-example Fishers for the four models. Every run is given GPU
# index 3 and requests 262144 Fisher values per example.
compute_pefs_hans_lone_all_vars 3 0 262144
compute_pefs_hans_lone_all_vars 3 15 262144

compute_pefs_hans_lone_all_vars 3 1 262144
compute_pefs_hans_lone_all_vars 3 25 262144



###################################################################################


# Run sparsify_nmf.py (with --H_threshold=1e-8) on an existing NMF file for the
# hans_lone "all_vars" per-example Fishers.
#
# Arguments:
#   $1 - feather_berts model number.
#   $2 - value that appears in the "<$2>pe" part of the NMF file name.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# CUDA_VISIBLE_DEVICES is set to the empty string for the Python process.
# Input file:
#   nmf_decomp.c1024_2kIters_<$2>pe.feather_berts_<$1>.hans_lone.all_vars.5k.262144.h5
# Output file: the same name prefixed with "spH.".
# NOTE: a second function named sparsify_nmf is defined later in this file and
# replaces this one once it has been sourced.
sparsify_nmf() {
    local ckpt=$1
    local per_ex=$2

    local src_name="nmf_decomp.c1024_2kIters_${per_ex}pe.feather_berts_${ckpt}.hans_lone.all_vars.5k.262144.h5"
    local -a sparsify_args=(
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${src_name}"
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/spH.${src_name}"
        --H_threshold=1e-8
    )

    CUDA_VISIBLE_DEVICES='' python scripts1/sparse/sparsify_nmf.py "${sparsify_args[@]}"
}

# Sparsify the hans_lone "all_vars" NMF files. Arguments are
# <model_num> <n_values_per_ex>. These calls rely on the sparsify_nmf
# definition directly above; a later redefinition in this file targets
# different file names.

# n_values_per_ex = 65536
sparsify_nmf 0 65536
sparsify_nmf 15 65536

sparsify_nmf 1 65536
sparsify_nmf 25 65536


# n_values_per_ex = 131072
sparsify_nmf 0 131072
sparsify_nmf 15 131072

sparsify_nmf 1 131072
sparsify_nmf 25 131072



###################################################################################
###################################################################################


# Run save_per_example_fishers_to_disk.py on the
# hans/lexical_overlap_ne_with_flipped validation split (10000 examples) with
# --include_classifier=true and a caller-chosen number of Fisher values per
# example.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
#   $3 - value for --n_fisher_values_per_example (also embedded in the file name).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans_lone_with_flipped.all_vars.10k.<$3>.h5
compute_pefs_hans_lone_with_flipped_all_vars() {
    local device=$1
    local model_num=$2
    local n_values_per_ex=$3

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.hans_lone_with_flipped.all_vars.10k.${n_values_per_ex}.h5"

    # $n_values_per_ex is quoted so it always arrives as part of one argument.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}"  \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=hans/lexical_overlap_ne_with_flipped \
        --split=validation \
        --n_examples=10000 \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="${n_values_per_ex}" \
        --sequence_length=64 \
        --include_classifier=true
}

# "with_flipped" per-example Fishers, 131072 values per example.
# Models 0 and 1 are given GPU index 2; models 15 and 25 are given GPU index 3.
compute_pefs_hans_lone_with_flipped_all_vars 2 0 131072
compute_pefs_hans_lone_with_flipped_all_vars 2 1 131072



compute_pefs_hans_lone_with_flipped_all_vars 3 15 131072
compute_pefs_hans_lone_with_flipped_all_vars 3 25 131072


###################################################################################

# Run sparsify_nmf.py (with --H_threshold=1e-8) on an existing NMF file for the
# hans_lone_with_flipped "all_vars" per-example Fishers.
# NOTE: this redefines sparsify_nmf; an earlier definition in this file targets
# the hans_lone ".5k.262144" file names instead.
#
# Arguments:
#   $1 - feather_berts model number.
#   $2 - value that appears in the "<$2>pe" part of the NMF file name.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# CUDA_VISIBLE_DEVICES is set to the empty string for the Python process.
# Input file:
#   nmf_decomp.c1024_2kIters_<$2>pe.feather_berts_<$1>.hans_lone_with_flipped.all_vars.10k.131072.h5
# Output file: the same name prefixed with "spH.".
sparsify_nmf() {
    local ckpt=$1
    local per_ex=$2

    local src_name="nmf_decomp.c1024_2kIters_${per_ex}pe.feather_berts_${ckpt}.hans_lone_with_flipped.all_vars.10k.131072.h5"
    local -a sparsify_args=(
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${src_name}"
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/spH.${src_name}"
        --H_threshold=1e-8
    )

    CUDA_VISIBLE_DEVICES='' python scripts1/sparse/sparsify_nmf.py "${sparsify_args[@]}"
}

# Sparsify the hans_lone_with_flipped NMF files (n_values_per_ex = 65536).
# These calls use the sparsify_nmf redefinition directly above.
sparsify_nmf 0 65536
sparsify_nmf 15 65536

sparsify_nmf 1 65536
sparsify_nmf 25 65536

###################################################################################

# Run compute_fisher.py on the hans/lexical_overlap_ne_with_flipped validation
# split (10000 examples) with --include_classifier=true for one feather_berts
# checkpoint.
#
# Arguments:
#   $1 - GPU index, exported to the Python process as CUDA_VISIBLE_DEVICES.
#   $2 - feather_berts model number.
# Globals:
#   FISHER_DIR (read) - directory that receives the output file.
# Output file:
#   feather_berts_<n>.hans_lone_with_flipped.all_vars.h5
compute_fisher_hans_lone_with_flipped_all_vars() {
    local gpu=$1
    local ckpt=$2

    local -a fisher_args=(
        --fisher_path="${FISHER_DIR}/feather_berts_${ckpt}.hans_lone_with_flipped.all_vars.h5"
        --model="connectivity/feather_berts_${ckpt}"
        --tokenizer=bert-base-uncased
        --task=hans/lexical_overlap_ne_with_flipped
        --split=validation
        --n_examples=10000
        --batch_size=4
        --sequence_length=64
        --include_classifier=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/ogmm/compute_fisher.py "${fisher_args[@]}"
}


# Fisher computation on the with_flipped split for model 0, passing GPU index 0.
compute_fisher_hans_lone_with_flipped_all_vars 0 0


