###################################################################################
# Commands for (roughly) my second round of HANS ablation experiments:
# per-example Fishers, Fishers, and an NMF decomposition of the per-example Fishers.
###################################################################################


VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1

# All ./scripts1/... paths used below are relative to CODE_DIR, so stop here if
# it is unavailable (`return` when this file is sourced, `exit` when executed).
cd "$CODE_DIR" || return 1 2>/dev/null || exit 1
# Presumably virtualenvwrapper's `workon`; it must already be loaded in this shell.
workon "$VIRTUALENV_NAME"
# Put CODE_DIR on Python's module search path (no trailing ',' -- that turned
# the entry into the nonexistent directory "${CODE_DIR},").
# NOTE(review): when PYTHONPATH starts out empty this leaves a leading ':'
# (an empty entry); "${PYTHONPATH:+${PYTHONPATH}:}" would avoid that.
export PYTHONPATH="${PYTHONPATH}:${CODE_DIR}"
# TensorFlow Datasets data directory; exported once, after activating the env.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

###################################################################################
###################################################################################

# Output layout for this experiment; everything lives under ${DATA_DIR}/pi1.
# (Assignments are never word-split, so the right-hand sides need no quotes.)
EXPS_DIR=${DATA_DIR}/pi1
FISHER_DIR=${EXPS_DIR}/fishers
PER_EXAMPLE_FISHERS_DIR=${EXPS_DIR}/per_example_fishers
DATASETS_DIR=${EXPS_DIR}/datasets
MODELS_DIR=${EXPS_DIR}/models

###################################################################################
###################################################################################

#######################################
# Saves per-example Fisher values of textattack/bert-base-uncased-RTE on the
# HANS lexical_overlap subset (n_examples=10000, 131072 values per example,
# sequence length 64) to an .h5 file in PER_EXAMPLE_FISHERS_DIR.
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 - CUDA_VISIBLE_DEVICES value for the job (may be empty)
#            $2 - HANS split, e.g. train or validation (required)
# Returns:   2 on a usage error, otherwise the exit status of the python script.
#######################################
compute_pefs_rte_hans() {
    if (( $# < 2 )) || [[ -z "$2" ]]; then
        printf 'usage: compute_pefs_rte_hans DEVICES SPLIT\n' >&2
        return 2
    fi
    local device=$1
    local split=$2

    local model="textattack/bert-base-uncased-RTE"
    local pefs_file="bert_base_rte.lexical_overlap.${split}.all_vars.all_ex.131072.h5"

    # The split is quoted so an unexpected value cannot be word-split or
    # globbed into extra arguments.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pefs_file}" \
        --trained_model="$model" \
        --from_pt_trained=true \
        --tokenizer="bert-base-uncased" \
        --task=hans/lexical_overlap \
        --split="$split" \
        --n_examples=10000 \
        --batch_size=8 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example=131072 \
        --sequence_length=64 \
        --include_classifier=true
}

# compute_pefs_rte_hans 1 train
# compute_pefs_rte_hans 3 validation


###################################################################################

#######################################
# Computes the Fisher of textattack/bert-base-uncased-RTE on the HANS
# lexical_overlap subset (n_examples=10000, sequence length 64) and writes it
# to an .h5 file in FISHER_DIR.
# Globals:   FISHER_DIR (read)
# Arguments: $1 - CUDA_VISIBLE_DEVICES value for the job (may be empty)
#            $2 - HANS split, e.g. train or validation (required)
# Returns:   2 on a usage error, otherwise the exit status of the python script.
#######################################
compute_fisher_rte_hans() {
    if (( $# < 2 )) || [[ -z "$2" ]]; then
        printf 'usage: compute_fisher_rte_hans DEVICES SPLIT\n' >&2
        return 2
    fi
    local device=$1
    local split=$2

    local model="textattack/bert-base-uncased-RTE"
    local fisher="bert_base_rte.lexical_overlap.${split}.all_vars.all_ex.h5"

    # The split is quoted so an unexpected value cannot be word-split or
    # globbed into extra arguments.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
        --fisher_path="${FISHER_DIR}/${fisher}" \
        --model="$model" \
        --tokenizer="bert-base-uncased" \
        --task=hans/lexical_overlap \
        --split="$split" \
        --n_examples=10000 \
        --batch_size=16 \
        --sequence_length=64 \
        --include_classifier=true
}

# compute_fisher_rte_hans 2 train
# Fisher of the RTE model on HANS lexical_overlap, validation split, with
# CUDA_VISIBLE_DEVICES=2.
compute_fisher_rte_hans 2 validation


#######################################
# Computes the Fisher of textattack/bert-base-uncased-RTE on the HANS
# lexical_overlap_ye task (n_examples=5000, sequence length 64) and writes it
# to an .h5 file in FISHER_DIR.
# Globals:   FISHER_DIR (read)
# Arguments: $1 - CUDA_VISIBLE_DEVICES value for the job (may be empty)
#            $2 - HANS split, e.g. train or validation (required)
# Returns:   2 on a usage error, otherwise the exit status of the python script.
#######################################
compute_fisher_rte_hans_ye() {
    if (( $# < 2 )) || [[ -z "$2" ]]; then
        printf 'usage: compute_fisher_rte_hans_ye DEVICES SPLIT\n' >&2
        return 2
    fi
    local device=$1
    local split=$2

    local model="textattack/bert-base-uncased-RTE"
    local fisher="bert_base_rte.lexical_overlap_ye.${split}.all_vars.all_ex.h5"

    # The split is quoted so an unexpected value cannot be word-split or
    # globbed into extra arguments.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/ogmm/compute_fisher.py \
        --fisher_path="${FISHER_DIR}/${fisher}" \
        --model="$model" \
        --tokenizer="bert-base-uncased" \
        --task=hans/lexical_overlap_ye \
        --split="$split" \
        --n_examples=5000 \
        --batch_size=16 \
        --sequence_length=64 \
        --include_classifier=true
}

# compute_fisher_rte_hans_ye 2 train
# Fisher of the RTE model on HANS lexical_overlap_ye, validation split, with
# CUDA_VISIBLE_DEVICES=1.
compute_fisher_rte_hans_ye 1 validation


#######################################
# Fisher of textattack/bert-base-uncased-RTE on the GLUE RTE train split
# (2490 examples, sequence length 128), saved as an .h5 file in FISHER_DIR.
# Globals:   FISHER_DIR (read)
# Arguments: $1 - CUDA_VISIBLE_DEVICES value for the job
# Returns:   the exit status of the python script.
#######################################
compute_fisher_rte() {
    local gpu=$1
    local fisher_file=bert_base_rte.rte.train.all_vars.all_ex.h5

    # Collect the flags in an array so each one stays a single argument.
    local -a flags=(
        --fisher_path="${FISHER_DIR}/${fisher_file}"
        --model="textattack/bert-base-uncased-RTE"
        --tokenizer="bert-base-uncased"
        --task=glue/rte
        --split=train
        --n_examples=2490
        --batch_size=8
        --sequence_length=128
        --include_classifier=true
    )

    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/ogmm/compute_fisher.py "${flags[@]}"
}

# Fisher of the RTE model on the GLUE RTE train split, CUDA_VISIBLE_DEVICES=0.
compute_fisher_rte 0



###################################################################################

# run_rte_hans_nmf below compiles and runs code from the cuda_nmf1 checkout via
# relative paths, so stop if that directory is unavailable (`return` when this
# file is sourced, `exit` when it is executed directly).
cd /fruitbasket/users/m/project_code/cuda_nmf1 || return 1 2>/dev/null || exit 1


#######################################
# Compiles mains/em/run_nmf_on_pefs.cu with nvcc, then runs the binary on the
# HANS lexical_overlap per-example Fishers for SPLIT (the
# "...131072.h5" file named like compute_pefs_rte_hans's output), writing the
# NMF decomposition next to it in PER_EXAMPLE_FISHERS_DIR.
# Expects the current directory to be the cuda_nmf1 checkout: the .cu source,
# ./build_scripts/ and ./build/ are all relative paths.
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 - CUDA_VISIBLE_DEVICES value (e.g. 0,1,2,3)
#            $2 - split of the per-example Fishers file
#            $3 - value for --nmf_n_components
#            $4 - value for --n_fisher_values
#            $5 - value for --min_values_per_parameter
#            $6 - value for --nmf_max_iter (optional, default 2500)
# Returns:   2 on a usage error; 1 if the nvcc build fails (the previously
#            built binary is then NOT run); otherwise the binary's status.
#######################################
run_rte_hans_nmf() {
    if (( $# < 5 )); then
        printf 'usage: run_rte_hans_nmf DEVICES SPLIT N_COMPS N_VALS_PE MIN_VALUES_PER_PARAM [N_ITERS]\n' >&2
        return 2
    fi
    local devices=$1
    local split=$2
    local n_comps=$3
    local n_vals_pe=$4
    local min_values_per_parameter=$5
    local n_iters=${6:-2500}

    local pef_name="bert_base_rte.lexical_overlap.${split}.all_vars.all_ex.131072.h5"
    local output_name="nmf_decomp.c${n_comps}_${n_iters}Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"

    local nvcc=/usr/local/cuda/bin/nvcc
    local cuda_version_sh=./build_scripts/compute_cuda_version.sh
    local opt_level=O3
    local -a nvcc_flags=(
        -I./src
        -I/usr/local/cuda/include
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib
        -L/usr/local/cuda/lib64
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5
    )

    # The build script's output is deliberately left unquoted so it is split
    # into words exactly as on the original command line.
    # shellcheck disable=SC2046
    if ! "$nvcc" mains/em/run_nmf_on_pefs.cu \
        -gencode $(sh "$cuda_version_sh") \
        "-${opt_level}" \
        "${nvcc_flags[@]}" \
        -o build/run_nmf_on_pefs; then
        # Without this check a failed build would silently run a stale binary.
        printf 'run_rte_hans_nmf: nvcc build failed; not running build/run_nmf_on_pefs\n' >&2
        return 1
    fi

    CUDA_VISIBLE_DEVICES="$devices" ./build/run_nmf_on_pefs \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --nmf_max_iter="$n_iters" \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}"
}

# NMF of the validation-split per-example Fishers: 256 components, 131072
# Fisher values per example, min_values_per_parameter=1, on GPUs 0-3.
run_rte_hans_nmf 0,1,2,3 validation 256 131072 1


###################################################################################

#######################################
# Runs scripts1/sparse/sparsify_nmf.py on the NMF decomposition whose file name
# matches run_rte_hans_nmf's output for the same settings, writing the result
# with an "spH." prefix in PER_EXAMPLE_FISHERS_DIR (--H_threshold=1e-10).
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 - split
#            $2 - number of NMF components
#            $3 - Fisher values per example
#            $4 - min values per parameter
#            $5 - NMF iterations (optional, default 2500)
#            These must match the run_rte_hans_nmf call, since they are part
#            of the NMF file name.
# Returns:   2 on a usage error, 1 if the NMF file does not exist, otherwise
#            the python script's exit status.
#######################################
sparsify_rte_hans_nmf() {
    if (( $# < 4 )); then
        printf 'usage: sparsify_rte_hans_nmf SPLIT N_COMPS N_VALS_PE MIN_VALUES_PER_PARAM [N_ITERS]\n' >&2
        return 2
    fi
    local split=$1
    local n_comps=$2
    local n_vals_pe=$3
    local min_values_per_parameter=$4
    local n_iters=${5:-2500}

    local pef_name="bert_base_rte.lexical_overlap.${split}.all_vars.all_ex.131072.h5"
    local nmf_name="nmf_decomp.c${n_comps}_${n_iters}Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}.${pef_name}"
    local nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}"
    local output_name="spH.${nmf_name}"

    # Fail with a clear message instead of a python traceback when the NMF
    # step has not produced its output (or was run with other settings).
    if [[ ! -f "$nmf_path" ]]; then
        printf 'sparsify_rte_hans_nmf: NMF file not found: %s\n' "$nmf_path" >&2
        return 1
    fi

    # NOTE(review): the script path is relative to the current directory. When
    # this file runs top to bottom that is the cuda_nmf1 checkout (after the
    # earlier `cd`), whereas every other scripts1/ script here runs from
    # CODE_DIR -- confirm which checkout contains scripts1/sparse/sparsify_nmf.py.
    # The empty CUDA_VISIBLE_DEVICES keeps the python process off the GPUs.
    CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
        --nmf_path="$nmf_path" \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --H_threshold=1e-10
}

# Sparsify the NMF output of the run_rte_hans_nmf call above; the arguments
# must match that call so the NMF input file name lines up.
sparsify_rte_hans_nmf validation 256 131072 1


