###################################################################################
###################################################################################


# Environment for the extract_merge1 experiments. Presumably sourced/pasted
# into an interactive shell: workon is a virtualenvwrapper shell function.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

cd "$CODE_DIR" || printf 'warning: cannot cd to %s\n' "$CODE_DIR" >&2

# workon only exists once virtualenvwrapper.sh has been sourced into the shell.
if command -v workon >/dev/null 2>&1; then
    workon "$VIRTUALENV_NAME"
else
    printf 'warning: workon not found; virtualenv %s not activated\n' "$VIRTUALENV_NAME" >&2
fi

# Append CODE_DIR as its own ':'-separated entry. No trailing separator, and no
# leading empty entry (which Python treats as the cwd) when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

###################################################################################
###################################################################################

EXPS_DIR="${DATA_DIR}/pi1"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"

###################################################################################
###################################################################################


# Write per-example Fishers (PEFs) for the first 20000 glue/qqp validation
# examples using the QQP-finetuned textattack BERT-base.
# Despite "qnli" in the name, the task run here is glue/qqp.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - --n_fisher_values_per_example (also embedded in the output file name)
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_qnli_val_first_20k() {
    local device=$1
    local n_values_per_ex=$2

    local MODEL="textattack/bert-base-uncased-QQP"
    local PER_EXAMPLES_FISHERS="bert_base_qqp.qqp_val.all_vars.first_20k.${n_values_per_ex}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=glue/qqp \
            --split=validation \
            --n_examples=20000 \
            --batch_size=8 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_values_per_ex" \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_pefs_qnli_val_first_20k 0 131072



# Write PEFs for the first N examples of the glue/qqp test split using the
# QQP-finetuned textattack BERT-base ("qnli" in the name notwithstanding).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - --n_examples (embedded in the file name as first_<N>ex)
#   $3 - --n_fisher_values_per_example (embedded in the file name)
#   $4 - --batch_size
#   $5 - --n_sub_batches
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_qnli_test_set() {
    local device=$1
    local n_examples=$2
    local n_values_per_ex=$3
    local batch_size=$4
    local n_sub_batches=$5

    local MODEL="textattack/bert-base-uncased-QQP"
    local PER_EXAMPLES_FISHERS="bert_base_qqp.qqp_test.all_vars.first_${n_examples}ex.${n_values_per_ex}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=glue/qqp \
            --split=test \
            --n_examples="$n_examples" \
            --batch_size="$batch_size" \
            --n_sub_batches="$n_sub_batches" \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_values_per_ex" \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_pefs_qnli_test_set 1 100000 65536 16 2



###################################################################################

# Compute the (dataset-level) Fisher of the QQP-finetuned textattack BERT-base
# on the first 20000 glue/qqp validation examples ("qnli" name is historical).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
# Globals (read): CODE_DIR (defaults to the current dir), FISHER_DIR
compute_fisher_qnli_val_first_20k() {
    local device=$1

    local MODEL="textattack/bert-base-uncased-QQP"
    local FISHER="bert_base_qqp.qqp_val.all_vars.first_20k.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
            --fisher_path="${FISHER_DIR}/${FISHER}" \
            --model="${MODEL}" \
            --tokenizer="bert-base-uncased" \
            --task=glue/qqp \
            --split=validation \
            --n_examples=20000 \
            --batch_size=8 \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_fisher_qnli_val_first_20k 1


###################################################################################


# Run scripts1/data_gen/enrich_pefs.py on CPU (CUDA_VISIBLE_DEVICES is empty)
# over the first-20k QQP validation PEF file. It passes
# --example_subtype=incorrect_prediction with the requested --desired_fraction.
# Arguments:
#   $1 - values-per-example count of the input PEF file name
#   $2 - digits after "0." of --desired_fraction (5 -> 0.5, 25 -> 0.25)
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
enrich_incorrects_qnli_pef() {
    local n_values_per_ex=$1
    local desired_fraction_decimals=$2

    local PER_EXAMPLES_FISHERS="bert_base_qqp.qqp_val.all_vars.first_20k.${n_values_per_ex}.h5"
    local ENRICHED_PEFS="enriched_incorrects_to_0_${desired_fraction_decimals}.${PER_EXAMPLES_FISHERS}"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python ./scripts1/data_gen/enrich_pefs.py \
            --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --example_subtype=incorrect_prediction \
            --desired_fraction="0.${desired_fraction_decimals}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${ENRICHED_PEFS}"
    )
}

enrich_incorrects_qnli_pef 131072 5
enrich_incorrects_qnli_pef 131072 25



###################################################################################

cd /fruitbasket/users/m/project_code/cuda_nmf1


# Compile run_nmf_on_pefs and run it on an enriched-incorrects QQP PEF file
# (enriched_incorrects_to_0_<D>.<first-20k 131072-value PEFs>).
# Must be called from the cuda_nmf1 checkout: the .cu source, -I./src and
# build/ are relative paths.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2,3
#   $2 - digits after "0." of the enriched fraction (5 -> 0.5, 25 -> 0.25)
#   $3 - --nmf_n_components
#   $4 - --n_fisher_values
#   $5 - --min_values_per_parameter
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: nvcc's status if the build fails, otherwise the binary's status.
run_enriched_incorrects_qnli_nmf() {
    local devices=$1
    local desired_fraction_decimals=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="enriched_incorrects_to_0_${desired_fraction_decimals}.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
    local output_name="nmf_decomp.c${n_comps}_2kIters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"

    printf '%s\n' "${pef_dir}/${pef_name}"

    # Stop on a failed build instead of silently running a stale binary.
    nvcc mains/em/run_nmf_on_pefs.cu \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --nmf_max_iter=2000 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

# It appears that the number of components was too high given the number of examples.
run_enriched_incorrects_qnli_nmf 0,1,2,3 5 1024 65536 8

run_enriched_incorrects_qnli_nmf 0,1,2,3 25 1024 65536 8



###################################################################################


# Run scripts1/sparse/sparsify_nmf.py on CPU (CUDA_VISIBLE_DEVICES is empty)
# with --H_threshold=1e-8. The input is the c1024/65536pe/mvpp8 NMF of an
# enriched-incorrects QQP PEF file; the output is written as "spH.<nmf file>"
# in the same directory.
# Arguments:
#   $1 - optional; digits after "0." of the enriched fraction (default 25,
#        the enriched_incorrects_to_0_25 file)
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
sparsify_qnli_nmf() {
    local desired_fraction_decimals=${1:-25}

    # Alternative target (NMF of the non-enriched 20k PEFs):
    # local nmf_name="nmf_decomp.c1024_2kIters_65536pe_20000ex_mvpp8.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
    # local output_name="spH.${nmf_name}"

    local pef_filename="enriched_incorrects_to_0_${desired_fraction_decimals}.bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
    local nmf_name="nmf_decomp.c1024_2kIters_65536pe_mvpp8.${pef_filename}"
    local output_name="spH.${nmf_name}"

    # Run from CODE_DIR (scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
            --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
            --H_threshold=1e-8
    )
}

sparsify_qnli_nmf

# NOTE(review): no function named sparsify_nmf is defined in this file.
# sparsify_nmf 0 65536


###################################################################################
###################################################################################

# Compute the Fisher of the QQP-finetuned textattack BERT-base on the
# paws/final train split (--n_examples=49_401).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
# Globals (read): CODE_DIR (defaults to the current dir), FISHER_DIR
compute_fisher_paws_train_set() {
    local device=$1

    local MODEL="textattack/bert-base-uncased-QQP"
    local FISHER="bert_base_qqp.paws_final_train.all_vars.all_ex.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
            --fisher_path="${FISHER_DIR}/${FISHER}" \
            --model="${MODEL}" \
            --tokenizer="bert-base-uncased" \
            --task=paws/final \
            --split=train \
            --n_examples=49_401 \
            --batch_size=8 \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_fisher_paws_train_set 1



###################################################################################

# Write PEFs of the QQP-finetuned textattack BERT-base on the paws/final train
# split (--n_examples=49_401).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - --n_fisher_values_per_example (also embedded in the output file name)
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_paws_train_set() {
    local device=$1
    local n_values_per_ex=$2

    local MODEL="textattack/bert-base-uncased-QQP"
    local PER_EXAMPLES_FISHERS="bert_base_qqp.paws_final_train.all_vars.all_ex.${n_values_per_ex}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=paws/final \
            --split=train \
            --n_examples=49_401 \
            --batch_size=8 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_values_per_ex" \
            --sequence_length=64 \
            --include_classifier=true
    )
}



compute_pefs_paws_train_set 3 131072


###################################################################################

cd /fruitbasket/users/m/project_code/cuda_nmf1


# Compile run_nmf_on_pefs and run it (2000 iterations) on the PAWS train PEFs
# written by compute_pefs_paws_train_set with 131072 values per example.
# Must be called from the cuda_nmf1 checkout (relative source/include/build paths).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2,3
#   $2 - --nmf_n_components
#   $3 - --n_fisher_values
#   $4 - --min_values_per_parameter
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: nvcc's status if the build fails, otherwise the binary's status.
run_paws_nmf() {
    local devices=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="bert_base_qqp.paws_final_train.all_vars.all_ex.131072.h5"
    local output_name="nmf_decomp.c${n_comps}_2kIters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"

    # Stop on a failed build instead of silently running a stale binary.
    nvcc mains/em/run_nmf_on_pefs.cu \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --nmf_max_iter=2000 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

run_paws_nmf 0,1,2,3 1024 65536 16

###################################################################################


# Run scripts1/sparse/sparsify_nmf.py on CPU (CUDA_VISIBLE_DEVICES is empty)
# with --H_threshold=1e-8 on the c1024/65536pe/mvpp16 NMF of the PAWS train
# PEFs. The output is written as "spH.<nmf file>" in the same directory.
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
sparsify_paws_nmf() {
    local nmf_name="nmf_decomp.c1024_2kIters_65536pe_mvpp16.bert_base_qqp.paws_final_train.all_vars.all_ex.131072.h5"
    local output_name="spH.${nmf_name}"

    # Run from CODE_DIR (scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
            --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
            --H_threshold=1e-8
    )
}

sparsify_paws_nmf

###################################################################################
###################################################################################
###################################################################################
###################################################################################


# Compute the Fisher of the RTE-finetuned textattack BERT-base on the
# sci_tail/default train split (23097 examples).
# NOTE(review): the file name says "scit_tail_final_train", while the PEF file
# of the same split uses "sci_tail_train". It is kept as-is because existing
# files and their consumers presumably use this exact name.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
# Globals (read): CODE_DIR (defaults to the current dir), FISHER_DIR
compute_fisher_sci_tail_train_set() {
    local device=$1

    local MODEL="textattack/bert-base-uncased-RTE"
    local FISHER="bert_base_rte.scit_tail_final_train.all_vars.all_ex.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
            --fisher_path="${FISHER_DIR}/${FISHER}" \
            --model="${MODEL}" \
            --tokenizer="bert-base-uncased" \
            --task=sci_tail/default \
            --split=train \
            --n_examples=23097 \
            --batch_size=8 \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_fisher_sci_tail_train_set 1


# Write PEFs of the RTE-finetuned textattack BERT-base on the sci_tail/default
# train split (23097 examples).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - --n_fisher_values_per_example (also embedded in the output file name)
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_sci_tail_train_set() {
    local device=$1
    local n_values_per_ex=$2

    local MODEL="textattack/bert-base-uncased-RTE"
    local PER_EXAMPLES_FISHERS="bert_base_rte.sci_tail_train.all_vars.all_ex.${n_values_per_ex}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=sci_tail/default \
            --split=train \
            --n_examples=23097 \
            --batch_size=8 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_values_per_ex" \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_pefs_sci_tail_train_set 1 65536



###################################################################################

cd /fruitbasket/users/m/project_code/cuda_nmf1

# Compile run_nmf_on_pefs and run it (2000 iterations) on the SciTail train
# PEFs written by compute_pefs_sci_tail_train_set with 65536 values per example.
# Must be called from the cuda_nmf1 checkout (relative source/include/build paths).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2,3
#   $2 - --nmf_n_components
#   $3 - --n_fisher_values
#   $4 - --min_values_per_parameter
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: nvcc's status if the build fails, otherwise the binary's status.
run_scitail_nmf() {
    local devices=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="bert_base_rte.sci_tail_train.all_vars.all_ex.65536.h5"
    local output_name="nmf_decomp.c${n_comps}_2kIters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"

    # Stop on a failed build instead of silently running a stale binary.
    nvcc mains/em/run_nmf_on_pefs.cu \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --nmf_max_iter=2000 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

run_scitail_nmf 0,1,2,3 1024 65536 8

###################################################################################

# Run scripts1/sparse/sparsify_nmf.py on CPU (CUDA_VISIBLE_DEVICES is empty)
# with --H_threshold=1e-8 on an NMF of the SciTail train PEFs (65536 values per
# example). The output is written as "spH.<nmf file>" in the same directory.
# Arguments:
#   $1 - number of NMF components in the NMF file name
#   $2 - values-per-example in the NMF file name
#   $3 - min_values_per_parameter in the NMF file name
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
sparsify_scitail_nmf() {
    local n_comps=$1
    local n_vals_pe=$2
    local min_values_per_parameter=$3

    local pef_name="bert_base_rte.sci_tail_train.all_vars.all_ex.65536.h5"
    local nmf_name="nmf_decomp.c${n_comps}_2kIters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"
    local output_name="spH.${nmf_name}"

    # Run from CODE_DIR (scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
            --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
            --H_threshold=1e-8
    )
}

sparsify_scitail_nmf 1024 65536 8



###################################################################################
###################################################################################

# Write PEFs of the RTE-finetuned textattack BERT-base on the
# hans/lexical_overlap_ne_with_flipped validation split (first 10000 examples,
# 131072 Fisher values per example).
# NOTE(review): the output name has an empty field ("all_ex..h5") where other
# PEF files carry the values-per-example count. run_hans_rte_nmf reads this
# exact name, so change both together or neither.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_rte_hans() {
    local device=$1

    local MODEL="textattack/bert-base-uncased-RTE"
    local PER_EXAMPLES_FISHERS="bert_base_rte.hans_lone_with_flipped.all_vars.all_ex..h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=hans/lexical_overlap_ne_with_flipped \
            --split=validation \
            --n_examples=10000 \
            --batch_size=8 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example=131072 \
            --sequence_length=64 \
            --include_classifier=true
    )
}

compute_pefs_rte_hans 1


# Compile run_nmf_on_pefs and run it (2000 iterations) on the HANS PEFs written
# by compute_pefs_rte_hans.
# Must be called from the cuda_nmf1 checkout (relative source/include/build
# paths). This section has no cd of its own; when the file runs top to bottom,
# the cwd comes from the cd before run_scitail_nmf.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2,3
#   $2 - --nmf_n_components
#   $3 - values-per-example; used only in the output file name (see note below)
#   $4 - --min_values_per_parameter
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: nvcc's status if the build fails, otherwise the binary's status.
run_hans_rte_nmf() {
    local devices=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="bert_base_rte.hans_lone_with_flipped.all_vars.all_ex..h5"
    local output_name="nmf_decomp.c${n_comps}_2kIters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"

    # Stop on a failed build instead of silently running a stale binary.
    nvcc mains/em/run_nmf_on_pefs.cu \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    # CUDA_VISIBLE_DEVICES=$devices ./build/run_nmf_on_pefs \
    #     --output_path="/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/${output_name}" \
    #     --nmf_n_components=$n_comps \
    #     --n_fisher_values=$n_vals_pe \
    #     --nmf_max_iter=2000 \
    #     --min_values_per_parameter=$min_values_per_parameter \
    #     --per_example_fishers="/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/${pef_name}"

    # NOTE(review): unlike the commented-out variant above, --n_fisher_values is
    # not passed here, so the binary uses its own default. The output name still
    # says "${n_vals_pe}pe"; confirm the two agree.
    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --nmf_max_iter=2000 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

run_hans_rte_nmf 0,1,2,3 1024 131072 4


###################################################################################
###################################################################################

# Write PEFs of connectivity/feather_berts_<N> on a glue/mnli split
# (batch size 4, sequence length 128).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - feather_berts model number
#   $3 - --n_fisher_values_per_example (embedded in the file name)
#   $4 - --n_examples (embedded in the file name as <N>ex)
#   $5 - split name, e.g. train or validation_matched
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_mnli() {
    local device=$1
    local model_num=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.mnli_${split}.all_vars.${n_ex}ex.${n_vals_pe}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=glue/mnli \
            --split="$split" \
            --n_examples="$n_ex" \
            --batch_size=4 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_vals_pe" \
            --sequence_length=128 \
            --include_classifier=true
    )
}

# compute_pefs_mnli 1 0 131072 9815 validation_matched
compute_pefs_mnli 1 0 131072 100000 train


# Write PEFs of the MNLI model connectivity/feather_berts_<N> on an snli/default
# split (batch size 4, sequence length 128).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - feather_berts model number
#   $3 - --n_fisher_values_per_example (embedded in the file name)
#   $4 - --n_examples (embedded in the file name as <N>ex)
#   $5 - split name, e.g. train or validation
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_mnli_snli() {
    local device=$1
    local model_num=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.snli_${split}.all_vars.${n_ex}ex.${n_vals_pe}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=snli/default \
            --split="$split" \
            --n_examples="$n_ex" \
            --batch_size=4 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_vals_pe" \
            --sequence_length=128 \
            --include_classifier=true
    )
}

# compute_pefs_mnli_snli 1 0 65536 50000 train
compute_pefs_mnli_snli 0 0 65536 10000 validation



# Same as compute_pefs_mnli_snli, but also passes --skip=<K> to the script and
# embeds "skip<K>" in the output file name.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - feather_berts model number
#   $3 - --n_fisher_values_per_example (embedded in the file name)
#   $4 - --n_examples (embedded in the file name as <N>ex)
#   $5 - split name
#   $6 - --skip value
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
compute_pefs_mnli_snli2() {
    local device=$1
    local model_num=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5
    local skip=$6

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.snli_${split}.all_vars.skip${skip}.${n_ex}ex.${n_vals_pe}.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=snli/default \
            --split="$split" \
            --n_examples="$n_ex" \
            --batch_size=4 \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_vals_pe" \
            --skip="$skip" \
            --sequence_length=128 \
            --include_classifier=true
    )
}
compute_pefs_mnli_snli2 3 0 131072 250000 train 50000



# Compute the Fisher of the MNLI model connectivity/feather_berts_<N> on an
# snli/default split (batch size 4, sequence length 128).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - feather_berts model number
#   $3 - --n_examples (embedded in the file name as <N>ex)
#   $4 - split name
# Globals (read): CODE_DIR (defaults to the current dir), FISHER_DIR
compute_fisher_mnli_snli() {
    local device=$1
    local model_num=$2
    local n_ex=$3
    local split=$4

    local MODEL="connectivity/feather_berts_${model_num}"
    local FISHER="feather_berts_${model_num}.mnli_snli_${split}.all_vars.${n_ex}ex.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
            --fisher_path="${FISHER_DIR}/${FISHER}" \
            --model="${MODEL}" \
            --tokenizer="bert-base-uncased" \
            --task=snli/default \
            --split="$split" \
            --n_examples="$n_ex" \
            --batch_size=4 \
            --sequence_length=128 \
            --include_classifier=true
    )
}

# compute_fisher_mnli_snli 1 0 50000 train
compute_fisher_mnli_snli 0 0 10000 validation


###################################################################################

###################################################################################

# Compute the Fisher of connectivity/feather_berts_<N> on a glue/mnli split
# (batch size 4, sequence length 128).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - feather_berts model number
#   $3 - --n_examples (embedded in the file name as <N>ex)
#   $4 - split name
# Globals (read): CODE_DIR (defaults to the current dir), FISHER_DIR
compute_fisher_mnli() {
    local device=$1
    local model_num=$2
    local n_ex=$3
    local split=$4

    local MODEL="connectivity/feather_berts_${model_num}"
    local FISHER="feather_berts_${model_num}.mnli_${split}.all_vars.${n_ex}ex.h5"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
            --fisher_path="${FISHER_DIR}/${FISHER}" \
            --model="${MODEL}" \
            --tokenizer="bert-base-uncased" \
            --task=glue/mnli \
            --split="$split" \
            --n_examples="$n_ex" \
            --batch_size=4 \
            --sequence_length=128 \
            --include_classifier=true
    )
}

compute_fisher_mnli 1 0 100000 train


###################################################################################

cd /fruitbasket/users/m/project_code/cuda_nmf1

# Compile run_nmf_on_pefs and run it (1250 iterations) on the feather_berts
# MNLI train PEFs (100000 examples, 131072 values per example).
# Must be called from the cuda_nmf1 checkout (relative source/include/build paths).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2,3
#   $2 - feather_berts model number
#   $3 - --nmf_n_components
#   $4 - --n_fisher_values
#   $5 - --min_values_per_parameter
#   $6 - --n_examples
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: nvcc's status if the build fails, otherwise the binary's status.
run_mnli_nmf() {
    local devices=$1
    local model_num=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_examples=$6

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="feather_berts_${model_num}.mnli_train.all_vars.100000ex.131072.h5"
    local output_name="nmf_decomp.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    # Stop on a failed build instead of silently running a stale binary.
    nvcc mains/em/run_nmf_on_pefs.cu \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --n_examples="$n_examples" \
        --nmf_max_iter=1250 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

run_mnli_nmf 0,1,2,3 0 1024 65536 16 50000



cd /fruitbasket/users/m/project_code/cuda_nmf1

# Compile run_nmf_on_pefs and run it (1250 iterations) on the feather_berts
# SNLI train PEFs (50000 examples, 65536 values per example).
# Must be called from the cuda_nmf1 checkout (relative source/include/build paths).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2,3
#   $2 - feather_berts model number
#   $3 - --nmf_n_components
#   $4 - --n_fisher_values
#   $5 - --min_values_per_parameter
#   $6 - --n_examples
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: nvcc's status if the build fails, otherwise the binary's status.
run_mnli_snli_nmf() {
    local devices=$1
    local model_num=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_examples=$6

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local output_name="nmf_decomp.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    # Stop on a failed build instead of silently running a stale binary.
    nvcc mains/em/run_nmf_on_pefs.cu \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --n_examples="$n_examples" \
        --nmf_max_iter=1250 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

run_mnli_snli_nmf 0,1,2,3 0 512 65536 10 50000


# Like run_mnli_snli_nmf, but builds with /usr/local/cuda/bin/nvcc at -O3. It
# passes -gencode with whatever ./build_scripts/compute_cuda_version.sh prints,
# and writes "nmf_decomp2.*" output files.
# Must be called from the cuda_nmf1 checkout (relative script/source/build paths).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES list, e.g. 0,1,2
#   $2 - feather_berts model number
#   $3 - --nmf_n_components
#   $4 - --n_fisher_values
#   $5 - --min_values_per_parameter
#   $6 - --n_examples
# Globals (read): PER_EXAMPLE_FISHERS_DIR (falls back to the pi1 path)
# Returns: the helper's/nvcc's status if the build fails, 1 if the helper
#          printed nothing, otherwise the binary's status.
run_mnli_snli_nmf2() {
    local devices=$1
    local model_num=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_examples=$6

    local pef_dir="${PER_EXAMPLE_FISHERS_DIR:-/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers}"
    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local output_name="nmf_decomp2.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    local NVCC=/usr/local/cuda/bin/nvcc
    local COMPUTE_CUDA_VERSION_SH=./build_scripts/compute_cuda_version.sh
    local OX=O3

    # Resolve the -gencode value first. If the helper fails or prints nothing,
    # nvcc would otherwise get a dangling -gencode.
    local gencode
    gencode=$(sh "$COMPUTE_CUDA_VERSION_SH") || return
    if [ -z "$gencode" ]; then
        printf 'error: %s printed no -gencode value\n' "$COMPUTE_CUDA_VERSION_SH" >&2
        return 1
    fi

    # $gencode stays unquoted to keep the original word splitting of the
    # helper's output. Stop on a failed build instead of running a stale binary.
    # shellcheck disable=SC2086
    "$NVCC" mains/em/run_nmf_on_pefs.cu \
        -gencode $gencode \
        "-${OX}" \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs || return

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --n_examples="$n_examples" \
        --nmf_max_iter=1250 \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

run_mnli_snli_nmf2 0,1,2 0 512 65536 10 50000



# Build fit_coeffs_to_sparse_H via ./build_scripts/em/fit_coeffs_to_sparse_H.sh
# and run it on the SNLI train PEFs against the sparsified H ("spH.nmf_decomp2.*",
# from sparsify_mnli_snli_nmf2). Output is written as "refit_w.<H file>".
# Must be called from the cuda_nmf1 checkout (relative build script/binary paths).
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - feather_berts model number
#   $3 - number of NMF components in the H file name
#   $4 - --n_fisher_values
#   $5 - min_values_per_parameter in the H file name
#   $6 - --n_examples
# Globals (read): PER_EXAMPLE_FISHERS_DIR
# Returns: the build script's status if it fails, otherwise the binary's status.
refit_coeffs_mnli_snli2() {
    local device=$1
    local model_num=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_examples=$6

    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local h_name="spH.nmf_decomp2.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"
    local output_name="refit_w.${h_name}"

    # Stop on a failed build instead of silently running a stale binary.
    sh ./build_scripts/em/fit_coeffs_to_sparse_H.sh || return

    # For a dry run, use --output_path="/dev/null" instead.
    CUDA_VISIBLE_DEVICES="$device" ./build/fit_coeffs_to_sparse_H \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --H_path="${PER_EXAMPLE_FISHERS_DIR}/${h_name}" \
        --nmf_max_iter=12000 \
        --n_examples="$n_examples" \
        --n_fisher_values="$n_vals_pe" \
        --n_splits_sparse_matmul=2048 \
        --n_row_splits_pefs=10
}

refit_coeffs_mnli_snli2 2 0 512 65536 10 50000



# fit_coeffs_mnli_snli2() {
#     local device=$1
#     local model_num=$2
#     local n_comps=$3
#     local n_vals_pe=$4
#     local min_values_per_parameter=$5
#     local n_examples=$6

#     local pef_name=$7

#     local nmf_name="nmf_decomp2.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

#     sh ./build_scripts/em/fit_coeffs_to_sparse_H.sh

#     CUDA_VISIBLE_DEVICES=$device ./build/fit_coeffs_to_sparse_H \

# }






# Run scripts1/sparse/sparsify_nmf.py on CPU (CUDA_VISIBLE_DEVICES is empty)
# with --H_threshold=1e-10 on a "nmf_decomp.*" NMF of the feather_berts SNLI
# train PEFs. The output is written as "spH.<nmf file>" in the same directory.
# Arguments:
#   $1 - feather_berts model number
#   $2 - number of NMF components in the NMF file name
#   $3 - values-per-example in the NMF file name
#   $4 - min_values_per_parameter in the NMF file name
#   $5 - number of examples in the NMF file name
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
sparsify_mnli_snli_nmf() {
    local model_num=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4
    local n_examples=$5

    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local nmf_name="nmf_decomp.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"
    local output_name="spH.${nmf_name}"

    # Run from CODE_DIR (scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
            --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
            --H_threshold=1e-10
    )
}

sparsify_mnli_snli_nmf 0 512 65536 10 50000



# Same as sparsify_mnli_snli_nmf but for the "nmf_decomp2.*" files written by
# run_mnli_snli_nmf2. Runs on CPU (CUDA_VISIBLE_DEVICES is empty) with
# --H_threshold=1e-10 and writes "spH.<nmf file>" in the same directory.
# Arguments:
#   $1 - feather_berts model number
#   $2 - number of NMF components in the NMF file name
#   $3 - values-per-example in the NMF file name
#   $4 - min_values_per_parameter in the NMF file name
#   $5 - number of examples in the NMF file name
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
sparsify_mnli_snli_nmf2() {
    local model_num=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4
    local n_examples=$5

    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local nmf_name="nmf_decomp2.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"
    local output_name="spH.${nmf_name}"

    # Run from CODE_DIR (scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
            --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
            --H_threshold=1e-10
    )
}

sparsify_mnli_snli_nmf2 0 512 65536 10 50000



###################################################################################

# Run scripts1/sparse/sparsify_nmf.py on CPU (CUDA_VISIBLE_DEVICES is empty)
# with --H_threshold=1e-8 on an NMF of the feather_berts MNLI train PEFs
# (100000 examples, 131072 values per example). The output is written as
# "spH.<nmf file>" in the same directory.
# Arguments:
#   $1 - feather_berts model number
#   $2 - number of NMF components in the NMF file name
#   $3 - values-per-example in the NMF file name
#   $4 - min_values_per_parameter in the NMF file name
#   $5 - number of examples in the NMF file name
# Globals (read): CODE_DIR (defaults to the current dir), PER_EXAMPLE_FISHERS_DIR
sparsify_mnli_nmf() {
    local model_num=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4
    local n_examples=$5

    local pef_name="feather_berts_${model_num}.mnli_train.all_vars.100000ex.131072.h5"
    local nmf_name="nmf_decomp.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"
    local output_name="spH.${nmf_name}"

    # Run from CODE_DIR (scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
            --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
            --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
            --H_threshold=1e-8
    )
}

sparsify_mnli_nmf 0 1024 65536 16 50000



###################################################################################
###################################################################################
###################################################################################
###################################################################################


# __TEST__run_paws_nmf() {
#     local devices=$1
#     local n_comps=$2
#     local n_vals_pe=$3
#     local min_values_per_parameter=$4

#     local pef_name="bert_base_qqp.paws_final_train.all_vars.all_ex.131072.h5"
#     local output_name="nmf_decomp.c${n_comps}_2kIters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"

#     nvcc mains/em/run_nmf_on_pefs.cu -I./src -I/usr/local/cuda/include -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 -o build/run_nmf_on_pefs

#     CUDA_VISIBLE_DEVICES=$devices ./build/run_nmf_on_pefs \
#         --output_path="/dev/null" \
#         --nmf_n_components=$n_comps \
#         --n_fisher_values=$n_vals_pe \
#         --nmf_max_iter=2000 \
#         --min_values_per_parameter=$min_values_per_parameter \
#         --per_example_fishers="/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/${pef_name}"

# }

# __TEST__run_paws_nmf 1,2,3 1024 65536 16


###################################################################################
###################################################################################




# Throughput check: same PEF computation as compute_pefs_qnli_val_first_20k
# (glue/qqp validation, first 20000 examples), but the output goes to
# /dev/null. The batch size and sub-batch count are configurable.
# Arguments:
#   $1 - CUDA_VISIBLE_DEVICES value
#   $2 - --n_fisher_values_per_example
#   $3 - --batch_size
#   $4 - --n_sub_batches
# Globals (read): CODE_DIR (defaults to the current dir)
compute_pefs_qnli_val_first_20k_test() {
    local device=$1
    local n_values_per_ex=$2
    local batch_size=$3
    local n_sub_batches=$4

    local MODEL="textattack/bert-base-uncased-QQP"

    # Run from CODE_DIR (./scripts1 is relative to it) inside a subshell, so this
    # works after this file's `cd .../cuda_nmf1` lines and keeps the caller's cwd.
    (
        cd "${CODE_DIR:-.}" || exit
        CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
            --output_path="/dev/null" \
            --trained_model="${MODEL}" \
            --from_pt_trained=true \
            --tokenizer="bert-base-uncased" \
            --task=glue/qqp \
            --split=validation \
            --n_examples=20000 \
            --batch_size="$batch_size" \
            --n_sub_batches="$n_sub_batches" \
            --expectation_wrt_logits=true \
            --flavor=sparse_dynamic_raw \
            --n_fisher_values_per_example="$n_values_per_ex" \
            --sequence_length=64 \
            --include_classifier=true
    )
}



# compute_pefs_qnli_val_first_20k_test 2 131072 8 1
compute_pefs_qnli_val_first_20k_test 2 131072 16 2
# compute_pefs_qnli_val_first_20k_test 2 131072 24 3

# 1.0385713577270508
# 1.3054118156433105



# 252 looks like an interesting component to ablate.
# 250
# 245 maybe, low coeff values though.
# 244
# 239
# 231
# 227
# 221


# - Run some ablations on the QQP comps I have now. [Need to sparsify NMFs first.]
# - Create script to make enriched PEFS, specifically for wrong predictions.
# - Run NMF on those enriched datasets.
# - Fix the escape_latex python function.
# - Make top-k faster during PEF computation.

# Make quick plan and email Colin.



# 34