

# --- Environment --------------------------------------------------------------
# virtualenvwrapper environment activated with `workon` below.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

cd "$CODE_DIR"
# NOTE(review): `workon` is normally a virtualenvwrapper shell function, so this
# file is presumably pasted/sourced into an interactive shell — TODO confirm.
workon "$VIRTUALENV_NAME"

# Make CODE_DIR importable. Nothing may follow the path inside the quotes (a
# stray ',' would become part of the directory name), and `${VAR:+...}` avoids
# an empty leading entry when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

###################################################################################
###################################################################################

# Root of this experiment's artifacts, then one sub-directory per artifact kind.
# (Assignments are not word-split or globbed, so no quoting is needed here.)
EXPS_DIR=${DATA_DIR}/snli3og_lrm_npeff

DATASETS_DIR=$EXPS_DIR/datasets
MODELS_DIR=$EXPS_DIR/models
FISHER_DIR=$EXPS_DIR/fishers
PER_EXAMPLE_FISHERS_DIR=$EXPS_DIR/per_example_fishers
PERTURBATIONS_DIR=$EXPS_DIR/perturbations
TCAV_DIR=$EXPS_DIR/tcav

###################################################################################
###################################################################################


#######################################
# Compute per-example Fishers (PEFs) on snli/default for the
# connectivity/feather_berts_0 checkpoint via compute_lrm_pefs.py.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory the .h5 output is written to
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES
#   $2 - passed as --n_fisher_values_per_example
#   $3 - passed as --n_examples
#   $4 - dataset split passed as --split; split specs may contain glob
#        characters (e.g. `train[-100000:-50000]`), so every expansion below
#        is quoted to stop the shell from pathname-expanding it
# Outputs:
#   ${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.<split>.<n_ex>ex.<n_vals_pe>.h5
# Returns:
#   the exit status of the python script
#######################################
compute_pefs() {
    local device=$1
    local n_vals_pe=$2
    local n_ex=$3
    local split=$4

    local model="connectivity/feather_berts_0"
    local pef="feather_berts_0.${split}.${n_ex}ex.${n_vals_pe}.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/compute_lrm_pefs.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
        --model="${model}" \
        --from_pt=true \
        --tokenizer="bert-base-uncased" \
        --task=snli/default \
        --split="${split}" \
        --n_examples="${n_ex}" \
        --batch_size=2 \
        --n_fisher_values_per_example="${n_vals_pe}" \
        --sequence_length=128
}

# Disabled: PEFs for the `train` split on GPU 1.
# compute_pefs 1 65536 50000 train
# PEFs for the train_skip_50k split: 50000 examples, 65536 values each, GPU 3.
compute_pefs 3 65536 50000 "train_skip_50k"


###################################################################################

# --- Stage 2: M-NPEFF decomposition with the CUDA implementation ---------------
# Switches to the cuda_m_npeff checkout and re-declares the path variables
# (presumably so this section also works when pasted into a fresh shell).
cd /fruitbasket/users/m/project_code/cuda_m_npeff
export PATH="/home/m/.local/bin:$PATH"


DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
EXPS_DIR="${DATA_DIR}/snli3og_lrm_npeff"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"

# Staging copy of the PEF files on /playpen. Keep the trailing '/': with it,
# `cp src "$PLAYPLAN_PEFS_DIR"` fails if the directory is missing instead of
# creating a regular file with that name.
PLAYPLAN_PEFS_DIR="/playpen/users/m/project_data/snli3og_lrm_npeff/per_example_fishers/"


# Stage the 50k-example training PEFs, then fit a 512-component decomposition
# to them on GPUs 0, 1 and 3.
cp -- "${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"

CUDA_VISIBLE_DEVICES=0,1,3 ./build/mains/run_m_npeff \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/feather_berts_0.train.50000ex.65536.h5" \
    --min_nonzero_per_col=14 \
    --n_components=512 \
    --learning_rate_G_G_only=1e-4 \
    --learning_rate_G=3e-4 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.h5" \
    --n_preprocess_cpu_threads=64 \
    --n_iters_G_only=100 \
    --n_iters_joint=1500 \
    --pefs_col_offsets_non_cumulative=false


# Stage the decomposition and the train_skip_50k PEFs, then fit coefficients
# for the train_skip_50k PEFs using that decomposition (GPU 1).
cp -- "${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.h5" "$PLAYPLAN_PEFS_DIR"
cp -- "${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train_skip_50k.50000ex.65536.h5" "$PLAYPLAN_PEFS_DIR"

# No leading '/' before the variable: it is already absolute, and a path that
# starts with exactly "//" has implementation-defined meaning in POSIX.
CUDA_VISIBLE_DEVICES=1 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/feather_berts_0.train_skip_50k.50000ex.65536.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=25000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false

###################################################################################


#######################################
# Run filter_lrm_pefs_to_wrong_predictions.py on one PEF file. Judging by the
# script name, this keeps only wrongly-predicted examples. GPUs are hidden
# (CUDA_VISIBLE_DEVICES is set to the empty string).
# Arguments:
#   $1 - input PEF .h5 path   (--pef_path)
#   $2 - output .h5 path      (--output_path)
# NOTE(review): --special_processing=HF_MNLI is passed although the PEFs in this
# file are computed on snli/default — confirm this is intended.
#######################################
filter_pefs_to_incorrects() {
    local in_path=$1 out_path=$2

    CUDA_VISIBLE_DEVICES='' python \
        ./em/projects/m_npeff/mains/filter_lrm_pefs_to_wrong_predictions.py \
        --pef_path="$in_path" \
        --output_path="$out_path" \
        --special_processing=HF_MNLI
}

# Derive the "wrongs_only" PEF file from the 50k-example training PEFs.
filter_pefs_to_incorrects \
    "$PER_EXAMPLE_FISHERS_DIR/feather_berts_0.train.50000ex.65536.h5" \
    "$PER_EXAMPLE_FISHERS_DIR/feather_berts_0.train.50000ex.65536.wrongs_only.h5"


###################################################################################


# --- Stage 3: wrong-only PEFs -> coefficient fit -> expansion -> refit ---------
# Fit the 512-component decomposition to the wrongs_only PEFs (GPU 0).
cp -- "${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.h5" "$PLAYPLAN_PEFS_DIR"
cp -- "${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.wrongs_only.h5" "$PLAYPLAN_PEFS_DIR"


# No leading '/' before the variable: it is already absolute, and a path that
# starts with exactly "//" has implementation-defined meaning in POSIX.
CUDA_VISIBLE_DEVICES=0 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/feather_berts_0.train.50000ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.h5" \
    --n_preprocess_cpu_threads=32 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=25000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false



cp -- "${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.h5" "$PLAYPLAN_PEFS_DIR"

# Add 64 components to the wrongs_only decomposition (GPUs 0-3).
# NOTE(review): the decomposition was just copied to $PLAYPLAN_PEFS_DIR, but this
# command reads it (and the wrongs_only PEFs) from $PER_EXAMPLE_FISHERS_DIR, so
# that copy is not used here — confirm which location is intended.
CUDA_VISIBLE_DEVICES=0,1,2,3 ./build/mains/run_m_npeff_expansion \
    --pef_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.wrongs_only.h5" \
    --decomposition_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.h5" \
    --n_additional_components=64 \
    --learning_rate_G_G_only=1e-3 \
    --learning_rate_G=3e-3 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.h5" \
    --n_preprocess_cpu_threads=64 \
    --n_iters_G_only=250 \
    --n_iters_joint_expansion_only=1000 \
    --n_iters_joint=0 \
    --pefs_col_offsets_non_cumulative=false \
    --use_W_from_decomposition=true


# Fit coefficients of the expanded decomposition for the train_skip_50k PEFs.
CUDA_VISIBLE_DEVICES=0 ./build/mains/fit_m_npeff_coeffs \
    --pef_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train_skip_50k.50000ex.65536.h5" \
    --decomposition_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.h5" \
    --n_preprocess_cpu_threads=64 \
    --n_iters=500 \
    --n_columns_per_chunk=2500000 \
    --n_examples_per_chunk=25000 \
    --mu_eps=1e-7 \
    --pefs_col_offsets_non_cumulative=false


###################################################################################

#######################################
# Run scripts1/comp_measure/run_m_npeff_perturbations_kl.py for one
# decomposition/PEF pair on snli/default with connectivity/feather_berts_0.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding both input files
#   PERTURBATIONS_DIR (read)       - directory the output is written to
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES
#   $2 - dataset split (--split); split specs may contain glob characters such
#        as `train[-100000:-50000]`, so every expansion below is quoted
#   $3 - --perturbation_magnitude
#   $4 - --max_abs_cos_sim
#   $5 - decomposition file name, relative to PER_EXAMPLE_FISHERS_DIR
#   $6 - PEF file name, relative to PER_EXAMPLE_FISHERS_DIR
#   $7 - output file name, relative to PERTURBATIONS_DIR
# Returns:
#   the exit status of the python script
#######################################
run_perturbations_kl() {
    local device=$1
    local split=$2
    local perturbation_magnitude=$3
    local max_abs_cos_sim=$4
    local nmf_filename=$5
    local pef_filename=$6
    local output_name=$7

    local n_top_examples=128
    local n_total_examples=10000

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_m_npeff_perturbations_kl.py \
        --nmf_filepath="${PER_EXAMPLE_FISHERS_DIR}/${nmf_filename}" \
        --pef_filepath="${PER_EXAMPLE_FISHERS_DIR}/${pef_filename}" \
        --model="connectivity/feather_berts_0" \
        --from_pt=true \
        --tokenizer=bert-base-uncased \
        --sequence_length=128 \
        --task=snli/default \
        --split="${split}" \
        --output_filepath="${PERTURBATIONS_DIR}/${output_name}" \
        --n_top_examples="${n_top_examples}" \
        --n_total_examples="${n_total_examples}" \
        --perturbation_magnitude="${perturbation_magnitude}" \
        --max_abs_cos_sim="${max_abs_cos_sim}"
}

# Disabled KL-perturbation runs with max_abs_cos_sim=0.35:
#
# run_perturbations_kl 0 train_skip_50k 1e-1 0.35 \
#     feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5 \
#     feather_berts_0.train_skip_50k.50000ex.65536.h5 \
#     kl_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.json
#
# run_perturbations_kl 3 train_skip_50k 1e-1 0.35 \
#     feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.h5 \
#     feather_berts_0.train_skip_50k.50000ex.65536.h5 \
#     kl_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.json

# Base decomposition with max_abs_cos_sim=-1. The output name ("no_semiorth")
# suggests -1 turns off a semi-orthogonality filter — TODO confirm.
run_perturbations_kl 0 train_skip_50k 1e-1 -1 \
    feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5 \
    feather_berts_0.train_skip_50k.50000ex.65536.h5 \
    kl_perturbations.no_semiorth.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.json



# Copy the perturbation outputs from mango.cs.unc.edu into a directory under
# $HOME/Desktop. The remote '*' stays quoted so the local shell does not expand
# it; it is expanded on the remote side. (--archive already implies --recursive.)
rsync --archive --rsh=ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/project_data/extract_merge1/snli3og_lrm_npeff/perturbations/*" \
    "$HOME/Desktop/projects_data/extract_merge1/neurips2023/perturbation_kl_jsons/snli3og_lrm_npeff/"




#######################################
# Run scripts1/comp_measure/run_m_npeff_perturbations_kl.py on snli/default
# with connectivity/feather_berts_0. Only the given component indices are
# perturbed, and --nli_label_swapping=true is passed. The "acc" name presumably
# refers to an accuracy measurement — TODO confirm.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding both input files
#   PERTURBATIONS_DIR (read)       - directory the output is written to
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES
#   $2 - dataset split (--split); split specs may contain glob characters such
#        as `train[-100000:-50000]`, so every expansion below is quoted
#   $3 - --perturbation_magnitude
#   $4 - --max_abs_cos_sim
#   $5 - decomposition file name, relative to PER_EXAMPLE_FISHERS_DIR
#   $6 - PEF file name, relative to PER_EXAMPLE_FISHERS_DIR
#   $7 - output file name, relative to PERTURBATIONS_DIR
#   $8 - comma-separated component indices (--component_indices)
# Returns:
#   the exit status of the python script
#######################################
run_perturbations_acc() {
    local device=$1
    local split=$2
    local perturbation_magnitude=$3
    local max_abs_cos_sim=$4
    local nmf_filename=$5
    local pef_filename=$6
    local output_name=$7
    local comp_inds=$8

    local n_top_examples=128
    local n_total_examples=10000

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_m_npeff_perturbations_kl.py \
        --nmf_filepath="${PER_EXAMPLE_FISHERS_DIR}/${nmf_filename}" \
        --pef_filepath="${PER_EXAMPLE_FISHERS_DIR}/${pef_filename}" \
        --model="connectivity/feather_berts_0" \
        --from_pt=true \
        --tokenizer=bert-base-uncased \
        --sequence_length=128 \
        --task=snli/default \
        --split="${split}" \
        --output_filepath="${PERTURBATIONS_DIR}/${output_name}" \
        --n_top_examples="${n_top_examples}" \
        --n_total_examples="${n_total_examples}" \
        --perturbation_magnitude="${perturbation_magnitude}" \
        --max_abs_cos_sim="${max_abs_cos_sim}" \
        --component_indices="${comp_inds}" \
        --nli_label_swapping=true
}


# The split spec is single-quoted: unquoted, `train[-100000:-50000]` is a glob
# (bracket expression) and would be replaced by any matching file name in the
# current directory (e.g. a file called `train1`).
# NOTE(review): these runs use the split train[-100000:-50000] while the PEF
# file (and the KL runs) use train_skip_50k — confirm they cover the same examples.

# Base 512-component decomposition fit to train_skip_50k; selected components.
run_perturbations_acc 0 'train[-100000:-50000]' 3e-1 0.35 \
    feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5 \
    feather_berts_0.train_skip_50k.50000ex.65536.h5 \
    acc_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.json \
    "15,26,102,119,128,129,134,143,146,162,168,177,189,200,212,221,242,277,282,379,404,426,436,471,491"



# Expanded wrongs_only decomposition fit to train_skip_50k; selected components.
run_perturbations_acc 1 'train[-100000:-50000]' 3e-1 0.35 \
    feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.h5 \
    feather_berts_0.train_skip_50k.50000ex.65536.h5 \
    acc_perturbations.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.json \
    "0,1,2,4,9,14,17,18,19,21,22,24,27,28,30,34,38,40,42,45,50,54,55,59,62,63,79,90,166,193,198,207,210,226,232,241,253,276,285,335,341,346,443,468,490,535,547,555"


# Baseline total accuracy is probably around 78.5%.

# 79,90,166,193,198,207,210,226,232,241,253,276,285,335,341,346,443,468,490,535,547,555

# Component 18 -- new type of wrong heuristic
# Component 21 -- Flawed heuristic with clothing
# Component 34 -- interesting new type of wrong heuristic
# Component 38 -- Interesting wrong/flawed heuristic involving something like counting.
# Component 39 -- flawed heuristic
# Component 44 -- maybe flawed heuristic
# Component 50 -- clothing-related heuristic; seems outright wrong here rather than merely flawed
# Component 62 -- maybe a flawed application of "common sense knowledge" about people's feelings/intentions, given a description of their actions.
# Component 90 -- kinda similar to component 62
# Component 193 -- flawed clothing heuristic

# 18, 34, 62

#######################################
# Run scripts1/comp_measure/run_bert_tcav_exp.py for an NPEFF decomposition
# against a precomputed activations file, with connectivity/feather_berts_0.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding the decomposition
#   TCAV_DIR (read)                - directory the output is written to
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES
#   $2 - decomposition file name, relative to PER_EXAMPLE_FISHERS_DIR
#   $3 - activations file name, relative to the m_npeff1 directory below
#   $4 - output file name, relative to TCAV_DIR
#######################################
run_npeff_tcav() {
    local device=$1 decomposition_filename=$2 activations_filename=$3 output_filename=$4

    # Activations are read from the m_npeff1 experiment, not from EXPS_DIR.
    local -r cls_acts_dir=/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers

    local -r n_runs=500 n_top_examples=32 n_negative_examples=128 n_scoring_examples=5000

    CUDA_VISIBLE_DEVICES=$device python ./scripts1/comp_measure/run_bert_tcav_exp.py \
        --decomposition_filepath="$PER_EXAMPLE_FISHERS_DIR/$decomposition_filename" \
        --decomposition_type=NPEFF \
        --activations_filepath="$cls_acts_dir/$activations_filename" \
        --model=connectivity/feather_berts_0 \
        --from_pt=true \
        --output_filepath="$TCAV_DIR/$output_filename" \
        --n_runs="$n_runs" \
        --n_top_examples="$n_top_examples" \
        --n_negative_examples="$n_negative_examples" \
        --n_scoring_examples="$n_scoring_examples"
}

# TCAV for the base 512-component decomposition fit to train_skip_50k, using the
# train_skip_50k activations file; runs on GPU 0.
run_npeff_tcav 0 \
    "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5" \
    feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5 \
    "tcav.feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5"





# Run compute_pef_norm_ratios.py on the train_skip_50k PEFs and the
# 512-component decomposition fit to them, for top-k in {128, 64, 32, 16}.
# CUDA_VISIBLE_DEVICES is set empty, which hides all GPUs from the script.
# No leading '/' before the variable: it is already absolute, and a path that
# starts with exactly "//" has implementation-defined meaning in POSIX.
CUDA_VISIBLE_DEVICES= python ./scripts1/comp_measure/compute_pef_norm_ratios.py \
    --pef_path="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train_skip_50k.50000ex.65536.h5" \
    --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5" \
    --top_ks=128,64,32,16
