###################################################################################
###################################################################################


# Shared environment for the experiment runners below.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The python entry points are invoked through ./scripts1/... relative paths,
# so work from the code checkout inside its virtualenv
# (`workon` is presumably virtualenvwrapper's -- TODO confirm).
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. No trailing comma: it would become part of the
# path entry (".../extract_merge1,") so CODE_DIR would never be importable.
# ${PYTHONPATH:+...} avoids an empty leading entry when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

###################################################################################
###################################################################################

# Experiment directory layout under DATA_DIR.
EXPS_DIR="${DATA_DIR}/pi1"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"

###################################################################################
###################################################################################



# Compute per-example Fishers of connectivity/feather_berts_<model_num> on an
# SNLI split via scripts1/data_gen/save_per_example_fishers_to_disk.py and
# write them to ${PER_EXAMPLE_FISHERS_DIR} as an .h5 file.
#
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 device     - value for CUDA_VISIBLE_DEVICES
#            $2 model_num  - feather_berts model index
#            $3 n_vals_pe  - --n_fisher_values_per_example
#            $4 n_ex       - --n_examples
#            $5 split      - SNLI split name (e.g. train, validation)
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
compute_pefs_mnli_snli() {
    if (( $# < 5 )); then
        printf 'usage: %s device model_num n_vals_pe n_ex split\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local model_num=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.snli_${split}.all_vars.${n_ex}ex.${n_vals_pe}.h5"

    # Make sure the destination directory exists, as the other runners in this
    # file do for their output dirs.
    mkdir -p -- "$PER_EXAMPLE_FISHERS_DIR" || return 1

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}" \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=snli/default \
        --split="$split" \
        --n_examples="$n_ex" \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="$n_vals_pe" \
        --sequence_length=128 \
        --include_classifier=true
}

# Args: device model_num n_vals_pe n_ex split
# compute_pefs_mnli_snli 1 0 65536 50000 train
compute_pefs_mnli_snli 0 0 65536 10000 validation



# Compute per-example Fishers of connectivity/feather_berts_<model_num> on an
# SNLI split, skipping the first <skip> examples (passed as --skip), via
# scripts1/data_gen/save_per_example_fishers_to_disk.py. The skip count is
# recorded in the output .h5 file name under ${PER_EXAMPLE_FISHERS_DIR}.
#
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 device     - value for CUDA_VISIBLE_DEVICES
#            $2 model_num  - feather_berts model index
#            $3 n_vals_pe  - --n_fisher_values_per_example
#            $4 n_ex       - --n_examples
#            $5 split      - SNLI split name (e.g. train)
#            $6 skip       - --skip
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
compute_pefs_mnli_snli2() {
    if (( $# < 6 )); then
        printf 'usage: %s device model_num n_vals_pe n_ex split skip\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local model_num=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5
    local skip=$6

    local MODEL="connectivity/feather_berts_${model_num}"
    local PER_EXAMPLES_FISHERS="feather_berts_${model_num}.snli_${split}.all_vars.skip${skip}.${n_ex}ex.${n_vals_pe}.h5"

    # Make sure the destination directory exists, as the other runners in this
    # file do for their output dirs.
    mkdir -p -- "$PER_EXAMPLE_FISHERS_DIR" || return 1

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --trained_model="${MODEL}" \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=snli/default \
        --split="$split" \
        --n_examples="$n_ex" \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="$n_vals_pe" \
        --skip="$skip" \
        --sequence_length=128 \
        --include_classifier=true
}
# Args: device model_num n_vals_pe n_ex split skip
compute_pefs_mnli_snli2 3 0 131072 250000 train 50000



# Build ./build/fit_coeffs_to_sparse_H and run it on the saved PEF file and the
# sparse H file named below, writing the result as refit_w.<h_name> in
# ${PER_EXAMPLE_FISHERS_DIR}.
#
# NOTE(review): the build script and binary are referenced via relative
# ./build_scripts and ./build paths, so this must run from the checkout that
# contains them (presumably cuda_nmf1) -- confirm.
#
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 device                    - CUDA_VISIBLE_DEVICES value
#            $2 model_num                 - feather_berts model index
#            $3 n_comps                   - component count in the H file name
#            $4 n_vals_pe                 - --n_fisher_values; also in the H file name
#            $5 min_values_per_parameter  - mvpp value in the H file name
#            $6 n_examples                - --n_examples; also in the H file name
# Returns:   2 on missing arguments, 1 if the build fails (the binary is then
#            not run), otherwise the binary's exit status.
refit_coeffs_mnli_snli2() {
    if (( $# < 6 )); then
        printf 'usage: %s device model_num n_comps n_vals_pe min_values_per_parameter n_examples\n' \
            "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local model_num=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_examples=$6

    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local h_name="spH.nmf_decomp2.c${n_comps}_1250Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"
    local output_name="refit_w.${h_name}"

    # Rebuild first; if that fails, do not fall through to a stale or missing binary.
    if ! sh ./build_scripts/em/fit_coeffs_to_sparse_H.sh; then
        printf '%s: building fit_coeffs_to_sparse_H failed\n' "${FUNCNAME[0]}" >&2
        return 1
    fi
    CUDA_VISIBLE_DEVICES="$device" ./build/fit_coeffs_to_sparse_H \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --H_path="${PER_EXAMPLE_FISHERS_DIR}/${h_name}" \
        --nmf_max_iter=12000 \
        --n_examples="$n_examples" \
        --n_fisher_values="$n_vals_pe" \
        --n_splits_sparse_matmul=2048 \
        --n_row_splits_pefs=10
}

# fit_coeffs_to_sparse_H is built and run via relative ./build_scripts and
# ./build paths; those presumably live in the cuda_nmf1 checkout (the
# refit_coeffs_mnli_snli2_2 run further down also runs from there), so switch
# directories before this refit rather than after it -- TODO confirm.
cd /fruitbasket/users/m/project_code/cuda_nmf1

# Args: device model_num n_comps n_vals_pe min_values_per_parameter n_examples
refit_coeffs_mnli_snli2 2 0 512 65536 10 50000

# Build ./build/fit_coeffs_to_sparse_H and run it on the skip50000 PEF file,
# using the sparse H file derived from the original 50000-example PEFs. The
# result is written as fit_w.skip50000.<n_examples>ex.<n_vals_pe>vpe.<h_name>
# in ${PER_EXAMPLE_FISHERS_DIR}.
#
# NOTE(review): relies on the cwd containing ./build_scripts and ./build
# (the caller cd's into cuda_nmf1 first) -- confirm.
#
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 device      - CUDA_VISIBLE_DEVICES value
#            $2 n_vals_pe   - --n_fisher_values; also in the output file name
#            $3 n_examples  - --n_examples; also in the output file name
# Returns:   2 on missing arguments, 1 if the build fails (the binary is then
#            not run), otherwise the binary's exit status.
refit_coeffs_mnli_snli2_2() {
    if (( $# < 3 )); then
        printf 'usage: %s device n_vals_pe n_examples\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local n_vals_pe=$2
    local n_examples=$3

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
    local output_name="fit_w.skip50000.${n_examples}ex.${n_vals_pe}vpe.${h_name}"

    # Rebuild first; if that fails, do not fall through to a stale or missing binary.
    if ! sh ./build_scripts/em/fit_coeffs_to_sparse_H.sh; then
        printf '%s: building fit_coeffs_to_sparse_H failed\n' "${FUNCNAME[0]}" >&2
        return 1
    fi
    CUDA_VISIBLE_DEVICES="$device" ./build/fit_coeffs_to_sparse_H \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --H_path="${PER_EXAMPLE_FISHERS_DIR}/${h_name}" \
        --nmf_max_iter=12000 \
        --n_examples="$n_examples" \
        --n_fisher_values="$n_vals_pe" \
        --n_splits_sparse_matmul=2048 \
        --n_row_splits_pefs=50
}

# Args: device n_vals_pe n_examples
refit_coeffs_mnli_snli2_2 3 65536 50000



###############################################################################

# For tmux:
# export LD_LIBRARY_PATH=


# Re-load the interactive shell config (presumably what defines `workon` in a
# fresh tmux pane -- TODO confirm), then re-establish the experiment environment.
source ~/.bashrc


VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The python entry points are invoked through ./scripts1/... relative paths,
# so work from the code checkout inside its virtualenv.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. No trailing comma: it would become part of the
# path entry (".../extract_merge1,") so CODE_DIR would never be importable.
# ${PYTHONPATH:+...} avoids an empty leading entry when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

###################################################################################
###################################################################################

# Experiment directory layout under DATA_DIR.
EXPS_DIR="${DATA_DIR}/pi1"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"


###################################################################################

# export LD_LIBRARY_PATH=""


# Run scripts1/comp_measure/coeff_kl_relationship.py on feather_berts_0 for
# <n_components> components starting at <start_component_index>, writing
# results under ${EXPS_DIR}/coeff_kl_relationships/<outname>.
#
# Globals:   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments: $1 device                           - CUDA_VISIBLE_DEVICES value
#            $2 start_component_index            - --start_component_index
#            $3 n_components                     - --n_components
#            $4 outname                          - output sub-directory name
#            $5 ablating_fisher_top_k            - --ablating_fisher_top_k
#            $6 kl_range_targeter_examples_type  - --kl_range_targeter_examples_type
#            $7 n_kl_range_targeter_examples     - --n_kl_range_targeter_examples
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
coeff_kl_relationship1() {
    if (( $# < 7 )); then
        printf 'usage: %s device start_component_index n_components outname ablating_fisher_top_k kl_range_targeter_examples_type n_kl_range_targeter_examples\n' \
            "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3
    local outname=$4
    local ablating_fisher_top_k=$5
    local kl_range_targeter_examples_type=$6
    local n_kl_range_targeter_examples=$7

    local output_dir="${EXPS_DIR}/coeff_kl_relationships/${outname}"
    mkdir -p -- "$output_dir" || return 1

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/coeff_kl_relationship.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --n_evaluation_examples=50000 \
        --min_target_kl=0.05 \
        --max_target_kl=0.2 \
        --n_kl_range_targeter_examples="$n_kl_range_targeter_examples" \
        --kl_range_targeter_max_iters=250 \
        --n_kl_range_finds=6 \
        --n_sign_guides_per_kl_range_find=3 \
        --ablating_fisher_top_k="$ablating_fisher_top_k" \
        --kl_range_targeter_examples_type="$kl_range_targeter_examples_type"
}

# Args: device start_component_index n_components outname
#       ablating_fisher_top_k kl_range_targeter_examples_type
#       n_kl_range_targeter_examples
# coeff_kl_relationship1 0 0 512 "kl_hc_2048ex" "-1" highest_coeffs 2048
# coeff_kl_relationship1 1 0 512 "kl_hc_256ex" "-1" highest_coeffs 256
# coeff_kl_relationship1 2 0 512 "kl_hc_2048ex__tkc_1024" 1024 highest_coeffs 2048
coeff_kl_relationship1 3 0 512 "kl_hc_256ex__tkc_1024" 1024 highest_coeffs 256




###################################################################################

# Development run of scripts1/comp_measure/coeff_kl_relationship.py on
# feather_berts_0 for <n_components> components starting at
# <start_component_index>; outputs go to
# ${EXPS_DIR}/coeff_kl_relationships/attempt01 with the "test_" prefix.
#
# Globals:   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments: $1 device                 - CUDA_VISIBLE_DEVICES value
#            $2 start_component_index  - --start_component_index
#            $3 n_components           - --n_components
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
coeff_kl_relationship_DEV01() {
    if (( $# < 3 )); then
        printf 'usage: %s device start_component_index n_components\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/coeff_kl_relationships/attempt01"
    # Create the output dir like every other runner here; mkdir -p is a no-op
    # when it already exists.
    mkdir -p -- "$output_dir" || return 1

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"


    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/coeff_kl_relationship.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --n_evaluation_examples=50000 \
        --min_target_kl=0.05 \
        --max_target_kl=0.2 \
        --n_kl_range_targeter_examples=2048 \
        --kl_range_targeter_max_iters=250 \
        --n_kl_range_finds=6 \
        --n_sign_guides_per_kl_range_find=3
}




# Args: device start_component_index n_components. The four calls below split
# the first 100 components into 25-component chunks on devices 0-3.
# coeff_kl_relationship_DEV01 0 0 25
# coeff_kl_relationship_DEV01 1 25 25
# coeff_kl_relationship_DEV01 2 50 25
coeff_kl_relationship_DEV01 3 75 25



###################################################################################

# Development run of scripts1/comp_measure/run_guided_kl_ablations.py on
# feather_berts_0 for <n_components> components starting at
# <start_component_index>; outputs go to
# ${EXPS_DIR}/guided_kl_ablations/attempt01 with the "test_" prefix.
#
# Globals:   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments: $1 device                 - CUDA_VISIBLE_DEVICES value
#            $2 start_component_index  - --start_component_index
#            $3 n_components           - --n_components (a count, not an end index)
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
guided_kl_ablations_DEV01() {
    if (( $# < 3 )); then
        printf 'usage: %s device start_component_index n_components\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01"
    mkdir -p -- "$output_dir" || return 1

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"


    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6
}


# Args: device start_component_index n_components
# NOTE(review): the third argument is passed as --n_components (a count), but
# these calls read like start/end pairs (0-50, 50-100, ...), so as written the
# component ranges overlap -- confirm whether e.g. `3 150 50` was intended.
# guided_kl_ablations_DEV01 0 0 50
# guided_kl_ablations_DEV01 1 50 100
# guided_kl_ablations_DEV01 2 100 150
guided_kl_ablations_DEV01 3 150 200

###################################################################################

# Variant of the attempt01 guided-KL-ablation run that passes
# --ablation_exp_types=random_examples_H to
# scripts1/comp_measure/run_guided_kl_ablations.py; outputs go to
# ${EXPS_DIR}/guided_kl_ablations/attempt01_02 with the "test_" prefix.
#
# Globals:   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments: $1 device                 - CUDA_VISIBLE_DEVICES value
#            $2 start_component_index  - --start_component_index
#            $3 n_components           - --n_components (a count, not an end index)
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
guided_kl_ablations_DEV01_02() {
    if (( $# < 3 )); then
        printf 'usage: %s device start_component_index n_components\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01_02"
    mkdir -p -- "$output_dir" || return 1

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"


    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6 \
        --ablation_exp_types=random_examples_H
}


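# NOTE(review): the third argument is --n_components (a count), not an end
# index, so these two calls overlap on components 100-199 -- confirm.
# guided_kl_ablations_DEV01_02 0 0 100
# guided_kl_ablations_DEV01_02 1 100 200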

###################################################################################


# Re-establish the experiment environment.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The python entry points are invoked through ./scripts1/... relative paths,
# so work from the code checkout inside its virtualenv.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. No trailing comma: it would become part of the
# path entry (".../extract_merge1,") so CODE_DIR would never be importable.
# ${PYTHONPATH:+...} avoids an empty leading entry when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

# Experiment directory layout under DATA_DIR.
EXPS_DIR="${DATA_DIR}/pi1"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"


# Guided-KL-ablation run that passes --ablating_variable_style=gradient to
# scripts1/comp_measure/run_guided_kl_ablations.py; outputs go to
# ${EXPS_DIR}/guided_kl_ablations/attempt02_abl_var_grad with the "test_" prefix.
#
# Globals:   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments: $1 device                 - CUDA_VISIBLE_DEVICES value
#            $2 start_component_index  - --start_component_index
#            $3 n_components           - --n_components (a count, not an end index)
# Returns:   2 on missing arguments, 1 if the output directory cannot be
#            created, otherwise the python script's exit status.
guided_kl_ablations_DEV02() {
    if (( $# < 3 )); then
        printf 'usage: %s device start_component_index n_components\n' "${FUNCNAME[0]}" >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt02_abl_var_grad"
    mkdir -p -- "$output_dir" || return 1

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"


    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6 \
        --ablating_variable_style=gradient
}


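# NOTE(review): the third argument is --n_components (a count), not an end
# index, so these calls overlap as written -- confirm (e.g. `1 50 50`).
# guided_kl_ablations_DEV02 0 0 50
# guided_kl_ablations_DEV02 1 50 100
# guided_kl_ablations_DEV02 2 100 150
# guided_kl_ablations_DEV02 3 150 200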

###################################################################################
###################################################################################

# Idea: use HANS as an "application" -> look for ablations that change predictions.
# Treat both the grouping of examples (W) and the component Fishers (H) as
# outputs of NPEFF, and use controls for each.
#
# Could do something similar for SNLI, constructing the example groups manually
# (e.g., by certain predictions, or by high overlap between premise and hypothesis).
