###################################################################################
###################################################################################


# Environment setup: project locations, virtualenv activation, and the
# environment variables exported for the Python entry points below.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

cd "$CODE_DIR"
# `workon` is not defined in this file; presumably the virtualenvwrapper command.
workon "$VIRTUALENV_NAME"
# No trailing comma after the closing quote: anything placed there becomes part
# of the last PYTHONPATH entry, which would then name a "<CODE_DIR>," path.
export PYTHONPATH="$PYTHONPATH:$CODE_DIR"
# Re-exported after activating the virtualenv (same value as above).
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

###################################################################################
###################################################################################

# Directory layout for the snli2_lrm_npeff experiments; every path is derived
# from DATA_DIR via EXPS_DIR.
EXPS_DIR="$DATA_DIR/snli2_lrm_npeff"
DATASETS_DIR="$EXPS_DIR/datasets"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"
PERTURBATIONS_DIR="$EXPS_DIR/perturbations"

###################################################################################
###################################################################################

# Hold out the last 150k train examples. Create 3 held-out sets of 50k each
# from these. We probably won't use all of them, but it is nice to have more.


# 550,152 train examples - 150,000 held out = 400,152 => roughly 12,500 steps per epoch at batch size 32
# Fine-tunes bert-base-uncased on SNLI using the 'train[:-150000]' split and
# writes the result under MODELS_DIR.
#
# Globals:   MODELS_DIR (read)
# Arguments: $1 - value assigned to CUDA_VISIBLE_DEVICES for the training process
train_snli_150k_holdout() {
    local gpu=$1
    local run_name="bert_base_snli_150k_holdout_4_epochs_01"

    # Collect the flags first so the launch line itself stays short.
    local -a finetune_args=(
        --output_path="${MODELS_DIR}/${run_name}"
        --task='snli/default'
        --split='train[:-150000]'
        --batch_size=32
        --model='bert-base-uncased'
        --learning_rate="2e-5"
        --n_steps=50000
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/training/finetune.py "${finetune_args[@]}"
}

# Launch the fine-tuning run with CUDA_VISIBLE_DEVICES=0.
train_snli_150k_holdout 0


# epoch 3 (2 when 0-based) (lowest val loss, approximately equal to highest val acc)
# epoch 5 (4 when 0-based) (highest val acc)

###################################################################################

# Runs compute_lrm_pefs.py for the `..._epoch2` model on one dataset split and
# writes the result as an .h5 file under PER_EXAMPLE_FISHERS_DIR.
#
# Globals:   PER_EXAMPLE_FISHERS_DIR, MODELS_DIR (read)
# Arguments:
#   $1 - value assigned to CUDA_VISIBLE_DEVICES
#   $2 - passed as --n_fisher_values_per_example (also part of the output name)
#   $3 - passed as --n_examples (also part of the output name)
#   $4 - split spec passed as --split, e.g. 'train[-150000:-100000]'
#   $5 - label for the split, used only in the output file name
compute_pefs() {
    local device=$1
    local n_vals_pe=$2
    local n_ex=$3
    local split=$4
    local split_name=$5

    local model="bert_base_snli_150k_holdout_4_epochs_01_epoch2"
    local pef="${model}.${split_name}.${n_ex}ex.${n_vals_pe}.h5"

    # Every expansion is quoted: split specs contain `[...]`, which the shell
    # would otherwise treat as a glob pattern (dropped under nullglob, or
    # replaced by a matching file name in the current directory).
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/compute_lrm_pefs.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
        --model="${MODELS_DIR}/${model}" \
        --from_pt=false \
        --tokenizer="bert-base-uncased" \
        --task=snli/default \
        --split="$split" \
        --n_examples="$n_ex" \
        --batch_size=4 \
        --n_fisher_values_per_example="$n_vals_pe" \
        --sequence_length=128
}

# The split specs contain `[...]`, which the shell treats as a glob bracket
# expression when unquoted; quote them so they reach compute_pefs verbatim.
compute_pefs 2 65536 50000 'train[-150000:-100000]' heldout_from_train_1
compute_pefs 1 65536 50000 'train[-100000:-50000]' heldout_from_train_2

###################################################################################


# Switch to the cuda_m_npeff checkout; the ./build/mains/* commands further
# down are given relative to the current directory.
cd /fruitbasket/users/m/project_code/cuda_m_npeff
export PATH=/home/m/.local/bin:$PATH


# Same experiment directory layout as defined earlier in this file, repeated
# here (presumably so this section can be pasted into a fresh shell — confirm).
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
EXPS_DIR="${DATA_DIR}/snli2_lrm_npeff"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"

# Staging directory under /playpen: input files are copied here before the
# commands below read them.
# NOTE(review): the variable is spelled PLAYPLAN while the path is /playpen, and
# the value ends in "/" so "${PLAYPLAN_PEFS_DIR}/file" contains a double slash.
PLAYPLAN_PEFS_DIR="/playpen/users/m/project_data/snli2_lrm_npeff/per_example_fishers/"


# Stage the heldout_from_train_1 PEFs in the playpen directory, then run
# run_m_npeff on them with 128 components. The output is written back under
# PER_EXAMPLE_FISHERS_DIR.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"

CUDA_VISIBLE_DEVICES=0,1 ./build/mains/run_m_npeff \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.h5" \
    --min_nonzero_per_col=12 \
    --n_components=128 \
    --learning_rate_G_G_only=3e-5 \
    --learning_rate_G=3e-4 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.128comps.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint=1500 \
    --pefs_col_offsets_non_cumulative=false


# Same input as the 128-component run (the copy is repeated so this command can
# be run on its own), but with 512 components and a different GPU pair.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"

CUDA_VISIBLE_DEVICES=0,3 ./build/mains/run_m_npeff \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.h5" \
    --min_nonzero_per_col=12 \
    --n_components=512 \
    --learning_rate_G_G_only=3e-5 \
    --learning_rate_G=3e-4 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint=1500 \
    --pefs_col_offsets_non_cumulative=false

###################################################################################


# Stage the heldout_from_train_2 PEFs and the 128-component decomposition, then
# fit coefficients for those PEFs against that decomposition.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.128comps.001.h5" "$PLAYPLAN_PEFS_DIR"


# PLAYPLAN_PEFS_DIR is already absolute, so no extra "/" is prepended (a leading
# "//" is implementation-defined in POSIX path resolution).
CUDA_VISIBLE_DEVICES=1 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.128comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.128comps.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=25000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false

###################################################################################

# Runs filter_lrm_pefs_to_wrong_predictions.py on a PEF file, with
# CUDA_VISIBLE_DEVICES set to the empty string for that process.
#
# Arguments:
#   $1 - path passed as --pef_path
#   $2 - path passed as --output_path
filter_pefs_to_incorrects() {
    local src_pefs=$1 dst_pefs=$2
    local script=./em/projects/m_npeff/mains/filter_lrm_pefs_to_wrong_predictions.py

    CUDA_VISIBLE_DEVICES= python "$script" --pef_path="$src_pefs" --output_path="$dst_pefs"
}

# Produce the "wrongs_only" variant of the heldout_from_train_1 PEF file; both
# the input and the output live under PER_EXAMPLE_FISHERS_DIR.
filter_pefs_to_incorrects \
    "$PER_EXAMPLE_FISHERS_DIR/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.h5" \
    "$PER_EXAMPLE_FISHERS_DIR/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5"

###################################################################################

# Stage the 128-component decomposition and the wrongs_only PEFs, then fit
# coefficients for the wrongs_only PEFs against that decomposition.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.128comps.001.h5" "$PLAYPLAN_PEFS_DIR"
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5" "$PLAYPLAN_PEFS_DIR"


# PLAYPLAN_PEFS_DIR is already absolute, so no extra "/" is prepended (a leading
# "//" is implementation-defined in POSIX path resolution).
CUDA_VISIBLE_DEVICES=0 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.128comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false



# Stage the 512-component decomposition and the wrongs_only PEFs, then fit
# coefficients for the wrongs_only PEFs against that decomposition.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.h5" "$PLAYPLAN_PEFS_DIR"
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5" "$PLAYPLAN_PEFS_DIR"


# PLAYPLAN_PEFS_DIR is already absolute, so no extra "/" is prepended (a leading
# "//" is implementation-defined in POSIX path resolution).
CUDA_VISIBLE_DEVICES=0 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false


###################################################################################


# Stage the wrongs_only coefficient fit (128-component case) and run
# run_m_npeff_expansion with 32 additional components on the wrongs_only PEFs.
# The wrongs_only PEF file itself is read from the playpen directory and is not
# copied here; it must already have been staged there.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.coeffs_fit001.h5" "$PLAYPLAN_PEFS_DIR"

CUDA_VISIBLE_DEVICES=0,2 ./build/mains/run_m_npeff_expansion \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.coeffs_fit001.h5" \
    --n_additional_components=32 \
    --learning_rate_G_G_only=1e-2 \
    --learning_rate_G=3e-3 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.expansion_001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=100 \
    --n_iters_joint_expansion_only=500 \
    --n_iters_joint=1500 \
    --pefs_col_offsets_non_cumulative=false \
    --use_W_from_decomposition=true



# Stage the wrongs_only coefficient fit (512-component case) and run
# run_m_npeff_expansion with 64 additional components. --n_iters_joint=0 here,
# matching the "no_full_join" tag in the output file name.
# The wrongs_only PEF file itself is read from the playpen directory and is not
# copied here; it must already have been staged there.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.coeffs_fit001.h5" "$PLAYPLAN_PEFS_DIR"

CUDA_VISIBLE_DEVICES=0,2 ./build/mains/run_m_npeff_expansion \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.coeffs_fit001.h5" \
    --n_additional_components=64 \
    --learning_rate_G_G_only=1e-2 \
    --learning_rate_G=1e-2 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.expansion64comps.no_full_join.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters_G_only=250 \
    --n_iters_joint_expansion_only=1000 \
    --n_iters_joint=0 \
    --pefs_col_offsets_non_cumulative=false \
    --use_W_from_decomposition=true

###################################################################################


# Stage the 512+64-component expansion and the heldout_from_train_2 PEFs, then
# fit coefficients for those PEFs against the expansion.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.expansion64comps.no_full_join.001.h5" "$PLAYPLAN_PEFS_DIR"
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"


# PLAYPLAN_PEFS_DIR is already absolute, so no extra "/" is prepended (a leading
# "//" is implementation-defined in POSIX path resolution).
CUDA_VISIBLE_DEVICES=0 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.50000ex.65536.wrongs_only.512comps.expansion64comps.no_full_join.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.512comps.expansion64comps.no_full_join.001.coeffs_fit001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=25000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false

###################################################################################





# Stage the plain 512-component decomposition and the heldout_from_train_2
# PEFs, then fit coefficients for those PEFs against that decomposition.
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.h5" "$PLAYPLAN_PEFS_DIR"
cp "${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"


# PLAYPLAN_PEFS_DIR is already absolute, so no extra "/" is prepended (a leading
# "//" is implementation-defined in POSIX path resolution).
CUDA_VISIBLE_DEVICES=1 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.coeffs_fit_toheldout_from_train_2.001.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=25000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false




###################################################################################
###################################################################################


# Runs run_m_npeff_perturbations_kl.py for the `..._epoch2` model, using 128
# top examples and 10000 total examples.
#
# Globals:   PER_EXAMPLE_FISHERS_DIR, MODELS_DIR, PERTURBATIONS_DIR (read)
# Arguments:
#   $1 - value assigned to CUDA_VISIBLE_DEVICES
#   $2 - split spec passed as --split, e.g. 'train[-100000:-50000]'
#   $3 - passed as --perturbation_magnitude
#   $4 - passed as --max_abs_cos_sim
#   $5 - file name under PER_EXAMPLE_FISHERS_DIR, passed as --nmf_filepath
#   $6 - file name under PER_EXAMPLE_FISHERS_DIR, passed as --pef_filepath
#   $7 - file name under PERTURBATIONS_DIR, passed as --output_filepath
run_perturbations_kl() {
    local device=$1
    local split=$2
    local perturbation_magnitude=$3
    local max_abs_cos_sim=$4
    local nmf_filename=$5
    local pef_filename=$6
    local output_name=$7

    local model="bert_base_snli_150k_holdout_4_epochs_01_epoch2"

    local n_top_examples=128
    local n_total_examples=10000

    # Every expansion is quoted: split specs contain `[...]`, which the shell
    # would otherwise treat as a glob pattern (dropped under nullglob, or
    # replaced by a matching file name in the current directory).
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_m_npeff_perturbations_kl.py \
        --nmf_filepath="${PER_EXAMPLE_FISHERS_DIR}/${nmf_filename}" \
        --pef_filepath="${PER_EXAMPLE_FISHERS_DIR}/${pef_filename}" \
        --model="${MODELS_DIR}/${model}" \
        --from_pt=false \
        --tokenizer=bert-base-uncased \
        --sequence_length=128 \
        --task=snli/default \
        --split="$split" \
        --output_filepath="${PERTURBATIONS_DIR}/${output_name}" \
        --n_top_examples="$n_top_examples" \
        --n_total_examples="$n_total_examples" \
        --perturbation_magnitude="$perturbation_magnitude" \
        --max_abs_cos_sim="$max_abs_cos_sim"
}

# run_perturbations_kl 1 train[-100000:-50000] 1e-1 0.35 \
#     bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.512comps.expansion64comps.no_full_join.001.coeffs_fit001.h5 \
#     bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5 \
#     kl_perturbations.bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.512comps.expansion64comps.no_full_join.001.coeffs_fit001.json


# The split spec is quoted so that its `[...]` is never treated as a glob
# pattern by the shell.
# NOTE(review): the output name ends in .h5, whereas the commented-out call
# above uses a .json output name — confirm which extension is intended.
run_perturbations_kl 1 'train[-100000:-50000]' 1e-1 0.35 \
    bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.coeffs_fit_toheldout_from_train_2.001.h5 \
    bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5 \
    kl_perturbations.bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_1.mnpeff.512comps.001.coeffs_fit_toheldout_from_train_2.001.h5


# Copy everything under the remote perturbations directory into the local
# directory below, using ssh as the transport:
# ~/Desktop/projects_data/extract_merge1/neurips2023/perturbation_kl_jsons
# The remote glob is quoted so the local shell does not try to expand it.
# NOTE(review): the source is a remote host and the destination is under $HOME,
# so this is presumably run from the local machine — confirm.

rsync -ra -e ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/perturbations/*" \
    "$HOME/Desktop/projects_data/extract_merge1/neurips2023/perturbation_kl_jsons/"



# compute_pefs() {
#     local device=$1
#     local n_vals_pe=$2
#     local n_ex=$3
#     local split=$4
#     local split_name=$5

#     local model="bert_base_snli_150k_holdout_4_epochs_01_epoch2"
#     local pef="${model}.${split_name}.${n_ex}ex.${n_vals_pe}.h5"

#     CUDA_VISIBLE_DEVICES=$device python ./scripts1/data_gen/compute_lrm_pefs.py  \
#         --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
#         --model="${MODELS_DIR}/${model}"   \
#         --from_pt=false \
#         --tokenizer="bert-base-uncased" \
#         --task=snli/default \
#         --split=$split \
#         --n_examples=$n_ex \
#         --batch_size=4 \
#         --n_fisher_values_per_example=$n_vals_pe \
#         --sequence_length=128
# }
