###################################################################################
###################################################################################


# Name of the virtualenv activated below with `workon` (presumably
# virtualenvwrapper's command, which only exists in a shell that sourced it).
VIRTUALENV_NAME=extract_merge1

# Project checkout and project data roots.
CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
# Exported so that the child python processes launched later inherit it.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The python scripts further down are addressed relative to the checkout.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Add the checkout as its own PYTHONPATH entry. Nothing may follow the closing
# quote: a trailing character (e.g. a comma) would become part of the path.
# The ${VAR:+...} form avoids a leading empty entry when PYTHONPATH was unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

###################################################################################
###################################################################################

# Directory layout of the m_npeff1 experiment, rooted at DATA_DIR.
# Every other directory hangs off EXPS_DIR, so it must be assigned first.
EXPS_DIR=${DATA_DIR}/m_npeff1
DATASETS_DIR=${EXPS_DIR}/datasets
MODELS_DIR=${EXPS_DIR}/models
FISHER_DIR=${EXPS_DIR}/fishers
# Where the compute_* helpers in this file write their .h5 outputs.
PER_EXAMPLE_FISHERS_DIR=${EXPS_DIR}/per_example_fishers

###################################################################################
###################################################################################


#######################################
# Run scripts1/data_gen/compute_lrm_pefs.py for the
# textattack/bert-base-uncased-QQP model on one split of glue/qqp.
# The script path is relative, so the current directory matters.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory the output .h5 is written to.
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES while the script runs
#   $2 - value for --n_fisher_values_per_example
#   $3 - value for --n_examples
#   $4 - split name passed as --split (also part of the output file name)
# Returns:
#   2 when fewer than four arguments are given, otherwise python's status.
#######################################
compute_pefs() {
    # Refuse to launch a job with empty flag values / a malformed output name.
    if [ "$#" -lt 4 ]; then
        printf 'usage: compute_pefs DEVICE N_VALS_PER_EXAMPLE N_EXAMPLES SPLIT\n' >&2
        return 2
    fi

    local device=$1
    local n_vals_pe=$2
    local n_ex=$3
    local split=$4

    local model="textattack/bert-base-uncased-QQP"
    # The output name encodes split, example count and values per example.
    local pef="textattack_bert_qqp.${split}.${n_ex}ex.${n_vals_pe}.h5"

    # All expansions are quoted so each flag stays a single argument.
    CUDA_VISIBLE_DEVICES="${device}" python ./scripts1/data_gen/compute_lrm_pefs.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
        --model="${model}" \
        --from_pt=true \
        --tokenizer="bert-base-uncased" \
        --task=glue/qqp \
        --split="${split}" \
        --n_examples="${n_ex}" \
        --batch_size=1 \
        --n_fisher_values_per_example="${n_vals_pe}" \
        --sequence_length=128
}

# PEF runs for the three splits. The calls execute one after another, each
# with a different CUDA_VISIBLE_DEVICES value (1, 2, 3) and the same number of
# Fisher values per example.
pef_n_fisher_values=65536
compute_pefs 1 "${pef_n_fisher_values}" 50000 train
compute_pefs 2 "${pef_n_fisher_values}" 40430 validation
compute_pefs 3 "${pef_n_fisher_values}" 50000 test


###################################################################################


#######################################
# Run scripts1/data_gen/compute_lrm_pefs_wrong_only.py for the
# textattack/bert-base-uncased-QQP model on one split of glue/qqp.
# The script path is relative, so the current directory matters.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory the output .h5 is written to.
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES while the script runs
#   $2 - value for --n_fisher_values_per_example
#   $3 - value for --n_examples_to_process
#   $4 - split name passed as --split (also part of the output file name)
# Returns:
#   2 when fewer than four arguments are given, otherwise python's status.
#######################################
compute_pefs_wrong_only() {
    # Refuse to launch a job with empty flag values / a malformed output name.
    if [ "$#" -lt 4 ]; then
        printf 'usage: compute_pefs_wrong_only DEVICE N_VALS_PER_EXAMPLE N_EXAMPLES SPLIT\n' >&2
        return 2
    fi

    local device=$1
    local n_vals_pe=$2
    local n_ex=$3
    local split=$4

    local model="textattack/bert-base-uncased-QQP"
    # "ex_prcsd": the count in the name is the number handed to
    # --n_examples_to_process, not the number of examples stored.
    local pef="textattack_bert_qqp.wrongs_only.${split}.${n_ex}ex_prcsd.${n_vals_pe}.h5"

    # All expansions are quoted so each flag stays a single argument.
    CUDA_VISIBLE_DEVICES="${device}" python ./scripts1/data_gen/compute_lrm_pefs_wrong_only.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
        --model="${model}" \
        --from_pt=true \
        --tokenizer="bert-base-uncased" \
        --task=glue/qqp \
        --split="${split}" \
        --n_examples_to_process="${n_ex}" \
        --batch_size=1 \
        --n_fisher_values_per_example="${n_vals_pe}" \
        --sequence_length=128
}

# Wrong-predictions-only PEFs for the train split with CUDA_VISIBLE_DEVICES=3,
# passing --n_examples_to_process=363846.
wrongs_only_n_examples=363846
compute_pefs_wrong_only 3 65536 "${wrongs_only_n_examples}" train



###################################################################################
###################################################################################

# Switch to the cuda_m_npeff checkout: the ./build/mains/* binaries used from
# here on are addressed relative to it.
cd /fruitbasket/users/m/project_code/cuda_m_npeff
export PATH="/home/m/.local/bin:${PATH}"

# Copy the train-split PEFs from /fruitbasket to /playpen (presumably local
# scratch), then run run_m_npeff on the copy with 256 components and GPUs
# 0-3 visible. The decomposition is written back under /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
train_pefs=textattack_bert_qqp.train.50000ex.65536.h5

cp "${pef_store}/${train_pefs}" "${pef_scratch}/"

CUDA_VISIBLE_DEVICES=0,1,2,3 ./build/mains/run_m_npeff \
    --pef_filepath="${pef_scratch}/${train_pefs}" \
    --min_nonzero_per_col=8 \
    --n_components=256 \
    --learning_rate_G_G_only=3e-5 \
    --learning_rate_G=3e-4 \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train.256comps.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint=1250


# Fit coefficients of the 256-component train decomposition to the validation
# PEFs. Both inputs are first copied from /fruitbasket to /playpen; the
# binaries are rebuilt before the fit, and the result goes back to /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
validation_pefs=textattack_bert_qqp.validation.40430ex.65536.h5
train_decomposition=textattack_bert_qqp.mnpeff.train.256comps.001.h5

for staged_file in "${validation_pefs}" "${train_decomposition}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

cmake --build build

CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${validation_pefs}" \
    --decomposition_filepath="${pef_scratch}/${train_decomposition}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_to_validation.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=21000 \
    --truncate_when_G_index_map_mismatch=true





# Copy the test-split PEFs from /fruitbasket to /playpen, then run run_m_npeff
# on the copy with 256 components and GPUs 0-2 visible. The decomposition is
# written back under /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
test_pefs=textattack_bert_qqp.test.50000ex.65536.h5

cp "${pef_store}/${test_pefs}" "${pef_scratch}/"

CUDA_VISIBLE_DEVICES=0,1,2 ./build/mains/run_m_npeff \
    --pef_filepath="${pef_scratch}/${test_pefs}" \
    --min_nonzero_per_col=8 \
    --n_components=256 \
    --learning_rate_G_G_only=3e-5 \
    --learning_rate_G=3e-4 \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.test.256comps.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint=1250



# Fit coefficients of the 256-component test decomposition to the validation
# PEFs. Both inputs are first copied from /fruitbasket to /playpen; the result
# goes back to /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
validation_pefs=textattack_bert_qqp.validation.40430ex.65536.h5
test_decomposition=textattack_bert_qqp.mnpeff.test.256comps.001.h5

for staged_file in "${validation_pefs}" "${test_decomposition}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${validation_pefs}" \
    --decomposition_filepath="${pef_scratch}/${test_decomposition}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.test_to_validation.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=21000 \
    --truncate_when_G_index_map_mismatch=true


###################################################################################

# Work from the cuda_m_npeff checkout (the ./build/mains/* paths are relative).
cd /fruitbasket/users/m/project_code/cuda_m_npeff
export PATH="/home/m/.local/bin:${PATH}"

# Copy the test-split PEFs from /fruitbasket to /playpen, then run run_m_npeff
# on the copy with 1024 components and GPUs 0-3 visible. The decomposition is
# written back under /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
test_pefs=textattack_bert_qqp.test.50000ex.65536.h5

cp "${pef_store}/${test_pefs}" "${pef_scratch}/"

CUDA_VISIBLE_DEVICES=0,1,2,3 ./build/mains/run_m_npeff \
    --pef_filepath="${pef_scratch}/${test_pefs}" \
    --min_nonzero_per_col=8 \
    --n_components=1024 \
    --learning_rate_G_G_only=3e-5 \
    --learning_rate_G=3e-4 \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.test.1024comps.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint=1250



# Fit coefficients of the 1024-component test decomposition to the validation
# PEFs. Both inputs are first copied from /fruitbasket to /playpen; the result
# goes back to /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
validation_pefs=textattack_bert_qqp.validation.40430ex.65536.h5
test_decomposition=textattack_bert_qqp.mnpeff.test.1024comps.001.h5

for staged_file in "${validation_pefs}" "${test_decomposition}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${validation_pefs}" \
    --decomposition_filepath="${pef_scratch}/${test_decomposition}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.test_to_validation.1024comps.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=21000 \
    --truncate_when_G_index_map_mismatch=true




###################################################################################
###################################################################################

#######################################
# Run filter_lrm_pefs_to_wrong_predictions.py on one PEF file.
# The script path is relative, so the current directory matters.
# Arguments:
#   $1 - path of the input PEF file (passed as --pef_path)
#   $2 - path of the file to write (passed as --output_path)
# Returns:
#   The exit status of python.
#######################################
filter_pefs_to_incorrects() {
    local src_pefs="$1"
    local dst_pefs="$2"
    local filter_script=./em/projects/m_npeff/mains/filter_lrm_pefs_to_wrong_predictions.py

    # CUDA_VISIBLE_DEVICES is set to the empty string for this call only,
    # which should hide every GPU from the process.
    CUDA_VISIBLE_DEVICES='' python "${filter_script}" \
        --pef_path="${src_pefs}" \
        --output_path="${dst_pefs}"
}

# Produce a ".wrongs_only.h5" sibling for the train PEF file and then for the
# validation PEF file, both under /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
for pef_stem in \
    textattack_bert_qqp.train.50000ex.65536 \
    textattack_bert_qqp.validation.40430ex.65536
do
    filter_pefs_to_incorrects \
        "${pef_store}/${pef_stem}.h5" \
        "${pef_store}/${pef_stem}.wrongs_only.h5"
done



###################################################################################

# Fit coefficients of the 256-component train decomposition to the
# wrongs-only validation PEFs (GPU 0). Both inputs are first copied from
# /fruitbasket to /playpen. Unlike the other fit_m_npeff_coeffs calls in this
# file, no --n_examples_per_chunk is passed here.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
validation_wrongs=textattack_bert_qqp.validation.40430ex.65536.wrongs_only.h5
train_decomposition=textattack_bert_qqp.mnpeff.train.256comps.001.h5

for staged_file in "${validation_wrongs}" "${train_decomposition}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

CUDA_VISIBLE_DEVICES=0 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${validation_wrongs}" \
    --decomposition_filepath="${pef_scratch}/${train_decomposition}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_to_validation.wrongs_only.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --truncate_when_G_index_map_mismatch=true


# Run run_m_npeff_expansion (32 additional components, GPUs 0-1) on the
# wrongs-only validation PEFs, starting from the coeffs_fit001 file, which is
# first copied from /fruitbasket to /playpen. The PEF file itself is read
# from /playpen and is expected to be there already.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
coeffs_fit=textattack_bert_qqp.mnpeff.train_to_validation.wrongs_only.coeffs_fit001.h5

cp "${pef_store}/${coeffs_fit}" "${pef_scratch}/"

CUDA_VISIBLE_DEVICES=0,1 ./build/mains/run_m_npeff_expansion \
    --pef_filepath="${pef_scratch}/textattack_bert_qqp.validation.40430ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${pef_scratch}/${coeffs_fit}" \
    --n_additional_components=32 \
    --learning_rate_G_G_only=1e-3 \
    --learning_rate_G=1e-4 \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_to_validation.wrongs_only.coeffs_fit001.expansion_001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint_expansion_only=500 \
    --n_iters_joint=1000 \
    --use_W_from_decomposition=true



# Refit the expanded (expansion_001) decomposition to the full validation
# PEFs on GPU 3. Both inputs are first copied from /fruitbasket to /playpen;
# the result goes back to /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
validation_pefs=textattack_bert_qqp.validation.40430ex.65536.h5
expansion=textattack_bert_qqp.mnpeff.train_to_validation.wrongs_only.coeffs_fit001.expansion_001.h5

for staged_file in "${validation_pefs}" "${expansion}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${validation_pefs}" \
    --decomposition_filepath="${pef_scratch}/${expansion}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_to_validation.wrongs_only.coeffs_fit001.expansion_001.refit_to_validation.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=21000 \
    --truncate_when_G_index_map_mismatch=true


###################################################################################
###################################################################################



# Fit coefficients of the 256-component train decomposition to the
# wrongs-only train PEFs (the 363846ex_prcsd file) on GPU 3. Both inputs are
# first copied from /fruitbasket to /playpen; the result goes back to
# /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
train_wrongs=textattack_bert_qqp.wrongs_only.train.363846ex_prcsd.65536.h5
train_decomposition=textattack_bert_qqp.mnpeff.train.256comps.001.h5

for staged_file in "${train_wrongs}" "${train_decomposition}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${train_wrongs}" \
    --decomposition_filepath="${pef_scratch}/${train_decomposition}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_fit_to_all_train_wrongs_only.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=21000 \
    --truncate_when_G_index_map_mismatch=true


# Run run_m_npeff_expansion (32 additional components, GPUs 2-3) on the
# wrongs-only train PEFs, starting from the train_fit_to_all_train_wrongs_only
# file, which is first copied from /fruitbasket to /playpen. The PEF file
# itself is read from /playpen and is expected to be there already.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
train_wrongs_fit=textattack_bert_qqp.mnpeff.train_fit_to_all_train_wrongs_only.h5

cp "${pef_store}/${train_wrongs_fit}" "${pef_scratch}/"

CUDA_VISIBLE_DEVICES=2,3 ./build/mains/run_m_npeff_expansion \
    --pef_filepath="${pef_scratch}/textattack_bert_qqp.wrongs_only.train.363846ex_prcsd.65536.h5" \
    --decomposition_filepath="${pef_scratch}/${train_wrongs_fit}" \
    --n_additional_components=32 \
    --learning_rate_G_G_only=1e-3 \
    --learning_rate_G=1e-4 \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_fit_to_all_train_wrongs_only.expansion_001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint_expansion_only=500 \
    --n_iters_joint=1500 \
    --use_W_from_decomposition=true




# Refit the expanded all-train-wrongs decomposition (expansion_001) to the
# full validation PEFs on GPU 3. Both inputs are first copied from
# /fruitbasket to /playpen; the result goes back to /fruitbasket.
pef_store=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers
pef_scratch=/playpen/users/m/project_data/m_npeff1/per_example_fishers
validation_pefs=textattack_bert_qqp.validation.40430ex.65536.h5
expansion=textattack_bert_qqp.mnpeff.train_fit_to_all_train_wrongs_only.expansion_001.h5

for staged_file in "${validation_pefs}" "${expansion}"; do
    cp "${pef_store}/${staged_file}" "${pef_scratch}/"
done

CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${pef_scratch}/${validation_pefs}" \
    --decomposition_filepath="${pef_scratch}/${expansion}" \
    --output_filepath="${pef_store}/textattack_bert_qqp.mnpeff.train_fit_to_all_train_wrongs_only.expansion_001.refit_to_validation.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=1500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=21000 \
    --truncate_when_G_index_map_mismatch=true



# TODO: Double-check how many examples the model gets wrong on the train set.
# TODO: Maybe try a smaller model so that it performs worse?


###################################################################################
###################################################################################
###################################################################################


# Debugging commands below (kept commented out):


# cd /fruitbasket/users/m/project_code/cuda_m_npeff
# export PATH=/home/m/.local/bin:$PATH

# cmake --build build; CUDA_VISIBLE_DEVICES=1,2,3 ./build/mains/run_m_npeff \
#     --pef_filepath="/playpen/users/m/project_data/m_npeff1/per_example_fishers/textattack_bert_qqp.test.50000ex.65536.h5" \
#     --min_nonzero_per_col=8 \
#     --n_components=256 \
#     --learning_rate_G_G_only=3e-5 \
#     --learning_rate_G=3e-4 \
#     --output_filepath=/fruitbasket/users/m/tmp/blah.h5 \
#     --n_preprocess_cpu_threads=32 \
#     --n_iters_G_only=10 \
#     --n_iters_joint=10
