# Paste bin of commands and snippets used for the SNLI experiments in the paper.

###################################################################################
###################################################################################

# NOTE(review): the assignments below are written in Python syntax (f-strings,
# os.path.join) and are NOT valid shell. Under bash, the f"..." / f'...' lines
# store literal text (a leading "f" and unexpanded {...}), and the
# os.path.join(...) lines are syntax errors. Treat them as notes to paste into
# Python, not as commands; this file is not meant to be run/sourced as a whole.

# Model used by the experiments below (same value passed as --model there).
MODEL="connectivity/feather_berts_0"

# Per-example Fisher (PEF) file the NMF was fit on; same name as the input
# pef_name that run_mnli_snli_nmf2 builds for model 0.
og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
# run_mnli_snli_nmf2 output name (512 comps, 65536pe, mvpp10, 50000ex) with an
# "spH." prefix — NOTE(review): presumably added by a later step; confirm.
og_h_name=f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{og_pef_name}"
# File passed as --nmf_path in the comp_measure functions below.
# TODO confirm which script writes the "fit_w.skip50000..." file.
NMF_NAME=f"fit_w.skip50000.50000ex.65536vpe.{og_h_name}"

# File passed as --pef_path in the comp_measure functions below.
PEF_NAME="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

# File passed as --retaining_fisher_path in the ablation functions below.
FISHER_NAME="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

###################################################################################

# Result location pattern for coeff_kl_relationship_DEV01 (its output_dir and
# "test_" prefix match); presumably {component_index} is filled per component.
RESULTS_DIR=f'{EXPS_DIR}/coeff_kl_relationships/attempt01'
RESULT_FILEPATH=os.path.join(RESULTS_DIR, 'test_coeff_kl_relationship.comp{component_index}.h5')

###################################################################################

# Same pattern for guided_kl_ablations_DEV01 results.
RESULTS_DIR=f'{EXPS_DIR}/guided_kl_ablations/attempt01'
RESULT_FILEPATH=os.path.join(RESULTS_DIR, 'test_guided_kl_ablation.comp{component_index}.h5')

###################################################################################

# pdf is mnli_snli2_new_train_ex_0_512c_65536pe_10mvpp_50000ex.pdf

###################################################################################
###################################################################################


#######################################
# Compute per-example Fisher values (PEFs) for one SNLI split and save them to
# an HDF5 file. Despite the "mnli_snli" name, the task passed is snli/default.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory the .h5 file is written to.
# Arguments:
#   $1 - CUDA device(s), exported as CUDA_VISIBLE_DEVICES
#   $2 - model number N, selecting connectivity/feather_berts_N
#   $3 - number of Fisher values kept per example
#   $4 - number of examples
#   $5 - dataset split (e.g. train, validation)
# Returns:
#   2 if arguments are missing, otherwise the python script's exit status.
#######################################
compute_pefs_mnli_snli() {
    if (( $# < 5 )); then
        printf 'usage: compute_pefs_mnli_snli DEVICE MODEL_NUM N_VALS_PE N_EX SPLIT\n' >&2
        return 2
    fi
    local device=$1
    local model_num=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5

    local model="connectivity/feather_berts_${model_num}"
    # Output name encodes model, split, example count and values-per-example.
    local pef_name="feather_berts_${model_num}.snli_${split}.all_vars.${n_ex}ex.${n_vals_pe}.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --trained_model="$model" \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=snli/default \
        --split="$split" \
        --n_examples="$n_ex" \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="$n_vals_pe" \
        --sequence_length=128 \
        --include_classifier=true
}

# Args: device model_num n_vals_pe n_ex split.
# compute_pefs_mnli_snli 1 0 65536 50000 train
compute_pefs_mnli_snli 0 0 65536 10000 validation

###################################################################################

#######################################
# Compile the CUDA NMF binary and run an NMF decomposition over a saved
# per-example-Fisher (PEF) file, writing the result next to the input.
# Globals:
#   NVCC (read, optional) - nvcc to use; defaults to /usr/local/cuda/bin/nvcc.
# Arguments:
#   $1 - CUDA device list, exported as CUDA_VISIBLE_DEVICES (e.g. 0,1,2)
#   $2 - model number N (input file feather_berts_N.snli_train...)
#   $3 - number of NMF components
#   $4 - number of Fisher values per example
#   $5 - min values per parameter
#   $6 - number of examples
# Returns:
#   2 on missing arguments; 1 if the CUDA-arch query or the compilation fails
#   (the binary is then NOT run, so a stale build/run_nmf_on_pefs is never
#   used); otherwise the exit status of run_nmf_on_pefs.
#######################################
run_mnli_snli_nmf2() {
    if (( $# < 6 )); then
        printf 'usage: run_mnli_snli_nmf2 DEVICES MODEL_NUM N_COMPS N_VALS_PE MIN_VALUES_PER_PARAMETER N_EXAMPLES\n' >&2
        return 2
    fi
    local devices=$1
    local model_num=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_examples=$6

    # The iteration count appears in both the output name and --nmf_max_iter;
    # a single variable keeps them in sync.
    local -r max_iter=1250
    # NOTE(review): the other functions in this file read this directory from
    # PER_EXAMPLE_FISHERS_DIR; the input PEF name is also fixed to
    # 50000ex/65536 regardless of $4/$6 — confirm that is intended.
    local -r pef_dir="/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers"
    local pef_name="feather_berts_${model_num}.snli_train.all_vars.50000ex.65536.h5"
    local output_name="nmf_decomp2.c${n_comps}_${max_iter}Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    local nvcc=${NVCC:-/usr/local/cuda/bin/nvcc}
    local -r compute_cuda_version_sh=./build_scripts/compute_cuda_version.sh
    local -r ox=O3

    # Query the -gencode value first so a failure is not silently turned into
    # "-gencode -O3".
    local gencode
    gencode=$(sh "$compute_cuda_version_sh") || {
        printf 'run_mnli_snli_nmf2: %s failed\n' "$compute_cuda_version_sh" >&2
        return 1
    }

    # $gencode is left unquoted to keep the original word-splitting of the
    # script's output.
    # shellcheck disable=SC2086
    "$nvcc" mains/em/run_nmf_on_pefs.cu \
        -gencode $gencode \
        "-${ox}" \
        -I./src -I/usr/local/cuda/include -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 -o build/run_nmf_on_pefs || {
        printf 'run_mnli_snli_nmf2: compilation failed; not running build/run_nmf_on_pefs\n' >&2
        return 1
    }

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${pef_dir}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --n_examples="$n_examples" \
        --nmf_max_iter="$max_iter" \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${pef_dir}/${pef_name}"
}

# Args: devices model_num n_comps n_vals_pe min_values_per_parameter n_examples.
run_mnli_snli_nmf2 0,1,2 0 512 65536 10 50000

###################################################################################
###################################################################################

#######################################
# Run scripts1/comp_measure/coeff_kl_relationship.py for a contiguous range of
# NMF components (start index + count) of the 512-component SNLI NMF.
# Globals:
#   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments:
#   $1 - CUDA device, exported as CUDA_VISIBLE_DEVICES
#   $2 - first component index
#   $3 - number of components
# Returns:
#   2 if arguments are missing, otherwise the python script's exit status.
#######################################
coeff_kl_relationship_DEV01() {
    if (( $# < 3 )); then
        printf 'usage: coeff_kl_relationship_DEV01 DEVICE START_COMPONENT_INDEX N_COMPONENTS\n' >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/coeff_kl_relationships/attempt01"
    # NOTE(review): directory creation is disabled here (the guided_kl_ablations
    # functions do create theirs); re-enable if the script does not create it.
    # mkdir -p -- "$output_dir"

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/coeff_kl_relationship.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --n_evaluation_examples=50000 \
        --min_target_kl=0.05 \
        --max_target_kl=0.2 \
        --n_kl_range_targeter_examples=2048 \
        --kl_range_targeter_max_iters=250 \
        --n_kl_range_finds=6 \
        --n_sign_guides_per_kl_range_find=3
}

# Args: device start_component_index n_components — each device takes a block
# of 25 components starting at the given index.
# coeff_kl_relationship_DEV01 0 0 25
# coeff_kl_relationship_DEV01 1 25 25
# coeff_kl_relationship_DEV01 2 50 25
coeff_kl_relationship_DEV01 3 75 25

###################################################################################

#######################################
# Run scripts1/comp_measure/run_guided_kl_ablations.py for a contiguous range
# of NMF components (start index + count), writing to .../attempt01.
# Globals:
#   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments:
#   $1 - CUDA device, exported as CUDA_VISIBLE_DEVICES
#   $2 - first component index
#   $3 - number of components
# Returns:
#   2 on missing arguments, mkdir's status if the output directory cannot be
#   created (python is then not run), otherwise the python script's status.
#######################################
guided_kl_ablations_DEV01() {
    if (( $# < 3 )); then
        printf 'usage: guided_kl_ablations_DEV01 DEVICE START_COMPONENT_INDEX N_COMPONENTS\n' >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6
}


# Args: device start_component_index n_components. The third argument is a
# component COUNT (passed as --n_components), so each device takes a disjoint
# block of 50 components starting at 0, 50, 100, 150. Passing the block end
# index (50/100/150/200) here instead would make the device ranges overlap.
# guided_kl_ablations_DEV01 0 0 50
# guided_kl_ablations_DEV01 1 50 50
# guided_kl_ablations_DEV01 2 100 50
guided_kl_ablations_DEV01 3 150 50

###################################################################################

#######################################
# Run scripts1/comp_measure/run_guided_kl_ablations.py with
# --ablation_exp_types=random_examples_H for a contiguous range of NMF
# components (start index + count), writing to .../attempt01_02.
# Globals:
#   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments:
#   $1 - CUDA device, exported as CUDA_VISIBLE_DEVICES
#   $2 - first component index
#   $3 - number of components
# Returns:
#   2 on missing arguments, mkdir's status if the output directory cannot be
#   created (python is then not run), otherwise the python script's status.
#######################################
guided_kl_ablations_DEV01_02() {
    if (( $# < 3 )); then
        printf 'usage: guided_kl_ablations_DEV01_02 DEVICE START_COMPONENT_INDEX N_COMPONENTS\n' >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01_02"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6 \
        --ablation_exp_types=random_examples_H
}

############################################

#######################################
# Run scripts1/comp_measure/run_guided_kl_ablations.py with
# --ablation_exp_types=random_examples_H for a contiguous range of NMF
# components (start index + count), writing to .../attempt01_02_02.
# Globals:
#   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments:
#   $1 - CUDA device, exported as CUDA_VISIBLE_DEVICES
#   $2 - first component index
#   $3 - number of components
# Returns:
#   2 on missing arguments, mkdir's status if the output directory cannot be
#   created (python is then not run), otherwise the python script's status.
#######################################
guided_kl_ablations_DEV01_02_02() {
    if (( $# < 3 )); then
        printf 'usage: guided_kl_ablations_DEV01_02_02 DEVICE START_COMPONENT_INDEX N_COMPONENTS\n' >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01_02_02"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6 \
        --ablation_exp_types=random_examples_H
}


# guided_kl_ablations_DEV01_02_02 0 0 50
# guided_kl_ablations_DEV01_02_02 1 50 100
# guided_kl_ablations_DEV01_02_02 2 100 150




###################################################################################

#######################################
# Run scripts1/comp_measure/run_guided_kl_ablations.py with
# --ablation_exp_types=component_examples,component_examples_H for a
# contiguous range of NMF components (start index + count), writing to
# .../attempt01_03.
# Globals:
#   EXPS_DIR, PER_EXAMPLE_FISHERS_DIR, FISHER_DIR (read)
# Arguments:
#   $1 - CUDA device, exported as CUDA_VISIBLE_DEVICES
#   $2 - first component index
#   $3 - number of components
# Returns:
#   2 on missing arguments, mkdir's status if the output directory cannot be
#   created (python is then not run), otherwise the python script's status.
#######################################
guided_kl_ablations_DEV01_03() {
    if (( $# < 3 )); then
        printf 'usage: guided_kl_ablations_DEV01_03 DEVICE START_COMPONENT_INDEX N_COMPONENTS\n' >&2
        return 2
    fi
    local device=$1
    local start_component_index=$2
    local n_components=$3

    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01_03"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local fisher_name="feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_guided_kl_ablations.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="connectivity/feather_berts_0" \
        --tokenizer="bert-base-uncased" \
        --output_dir="$output_dir" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="$start_component_index" \
        --n_components="$n_components" \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds=6 \
        --ablation_exp_types=component_examples,component_examples_H
}


# guided_kl_ablations_DEV01_03 0 0 50
# guided_kl_ablations_DEV01_03 1 50 100
# guided_kl_ablations_DEV01_03 2 100 150

###################################################################################

#######################################
# Run scripts1/comp_measure/compute_pef_norm_ratios.py on the held-out SNLI PEF
# file and the fitted NMF file, with CUDA_VISIBLE_DEVICES set to empty for the
# script (presumably to keep it off the GPUs).
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding both input files.
# Arguments:
#   $1 - value passed through as --top_ks (e.g. 8,16,32)
# Returns:
#   the python script's exit status.
#######################################
compute_pef_norm_ratios() {
    local ks="$1"

    local -r base_pef="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local -r h_file="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${base_pef}"
    local -r w_file="fit_w.skip50000.50000ex.65536vpe.${h_file}"
    local -r eval_pef="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    local -a script_args=(
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${eval_pef}"
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${w_file}"
        --top_ks="${ks}"
    )

    CUDA_VISIBLE_DEVICES='' python ./scripts1/comp_measure/compute_pef_norm_ratios.py "${script_args[@]}"
}

# Arg: comma-separated list passed through as --top_ks.
compute_pef_norm_ratios 8,16,32,64,128,256,512


###################################################################################


#######################################
# Run make_top_example_interpretation_set_up.py on the held-out SNLI PEF file
# and the fitted NMF file, writing into
# <output_root>/snli_comp_pdfs_<n_examples>ex_test. CUDA_VISIBLE_DEVICES is
# set to empty for the script.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding both input files.
# Arguments:
#   $1 - number of examples (--n_examples)
#   $2 - optional output root; defaults to /fruitbasket/users/m/tmp
# Returns:
#   2 on missing arguments, mkdir's status if the output directory cannot be
#   created (python is then not run), otherwise the python script's status.
#######################################
make_top_example_interpretation_set_up() {
    if (( $# < 1 )); then
        printf 'usage: make_top_example_interpretation_set_up N_EXAMPLES [OUTPUT_ROOT]\n' >&2
        return 2
    fi
    local n_examples=$1
    local output_root=${2:-/fruitbasket/users/m/tmp}

    local output_dir="${output_root}/snli_comp_pdfs_${n_examples}ex_test"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
    local og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
    local nmf_name="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

    local pef_name="feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

    CUDA_VISIBLE_DEVICES= python ./em/projects/pi/exps/mains/top_examples/make_top_example_interpretation_set_up.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --tokenizer='bert-base-uncased' \
        --n_examples="$n_examples" \
        --output_dir="$output_dir" \
        --components_fontsize=Large
}

# Build the 32-example set-up (writes under /fruitbasket/users/m/tmp by default).
make_top_example_interpretation_set_up 32


# Pull the generated directory from banana.cs.unc.edu into the local Desktop
# copy (presumably run from a different machine). --archive already implies
# --recursive.
rsync --archive --rsh=ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/snli_comp_pdfs_32ex_test/" \
    "$HOME/Desktop/projects_data/extract_merge1/ll/pdfs/snli_indy_comp_pdfs/"



###################################################################################


