###################################################################################
###################################################################################


# Experiment environment: virtualenv name, code/data roots and the TFDS cache.
# NOTE(review): `workon` is not defined in this file — presumably the
# virtualenvwrapper shell function, so this is meant to run in a shell where
# it is loaded. Confirm before running non-interactively.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# Every ./scripts1/... path used later is relative to CODE_DIR.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. The previous line ended in a stray ",",
# which made the appended entry "<CODE_DIR>," (a non-existent path). The
# ${VAR:+...} form also avoids an empty leading entry when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"
# Exported a second time after `workon`, as in the original (same value).
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets


###################################################################################
###################################################################################

# Output locations for this experiment, all rooted under DATA_DIR.
# NOTE(review): the directory is spelled "winogrange1" (not "winogrande");
# it is a path, so the spelling is kept exactly as is.
EXPS_DIR="$DATA_DIR/winogrange1"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"

###################################################################################
###################################################################################


# Run finetune.py with only CUDA device 0 visible: start from the
# textattack/bert-base-uncased-MNLI checkpoint, task winogrande/xl, 4 epochs,
# batch size 32, sequence length 50. Output goes under MODELS_DIR.
# --force_n_labels=3 is passed only for the MNLI-initialised runs; presumably
# it keeps the 3-way MNLI head (see the TODO further down) — confirm.
CUDA_VISIBLE_DEVICES=0 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/bert_base_mnli_to_winogrande_xl_4_epochs_01" \
    --model=textattack/bert-base-uncased-MNLI \
    --from_pt=true \
    --tokenizer=bert-base-uncased \
    --task=winogrande/xl \
    --batch_size=32 \
    --n_epochs=4 \
    --sequence_length=50 \
    --learning_rate=2e-5 \
    --clipnorm=0.1 \
    --force_n_labels=3


# Run finetune.py with only CUDA device 1 visible: start from plain
# bert-base-uncased (no MNLI checkpoint), task winogrande/xl, 4 epochs.
# Same hyper-parameters as the MNLI-initialised run, but without
# --force_n_labels.
CUDA_VISIBLE_DEVICES=1 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/bert_base_to_winogrande_xl_4_epochs_01" \
    --model=bert-base-uncased \
    --from_pt=true \
    --tokenizer=bert-base-uncased \
    --task=winogrande/xl \
    --batch_size=32 \
    --n_epochs=4 \
    --sequence_length=50 \
    --learning_rate=2e-5 \
    --clipnorm=0.1


# Run finetune.py with only CUDA device 2 visible: same as the 4-epoch
# MNLI-initialised run except --n_epochs=8 and the output directory name.
CUDA_VISIBLE_DEVICES=2 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/bert_base_mnli_to_winogrande_xl_8_epochs_01" \
    --model=textattack/bert-base-uncased-MNLI \
    --from_pt=true \
    --tokenizer=bert-base-uncased \
    --task=winogrande/xl \
    --batch_size=32 \
    --n_epochs=8 \
    --sequence_length=50 \
    --learning_rate=2e-5 \
    --clipnorm=0.1 \
    --force_n_labels=3


###################################################################################
###################################################################################


# TODO: Need to find a way to go from the trained MNLI model (3 classes) to 2 classes.
# Probably not hard, but it still has to be done.


# Dump per-example Fishers for the 4-epoch MNLI-initialised model (device 0).
# Reads the fine-tuned model from MODELS_DIR/MODEL and writes one .h5 file
# under PER_EXAMPLE_FISHERS_DIR. No --split is passed, so the script's default
# split is used (presumably train — confirm in the script).
MODEL=bert_base_mnli_to_winogrande_xl_4_epochs_01
PER_EXAMPLES_FISHERS="${MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/${MODEL}"  \
    --tokenizer="bert-base-uncased" \
    --from_pt_trained=false \
    --task=winogrande/xl \
    --n_examples=32768 \
    --batch_size=4 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --sequence_length=50 \
    --include_embeddings=false



# Dump per-example Fishers for the 4-epoch model initialised from plain
# bert-base-uncased (device 1). Flags are identical to the MNLI-initialised
# dump; only MODEL, and therefore the input/output paths, differ.
MODEL=bert_base_to_winogrande_xl_4_epochs_01
PER_EXAMPLES_FISHERS="${MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

CUDA_VISIBLE_DEVICES=1 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/${MODEL}"  \
    --tokenizer="bert-base-uncased" \
    --from_pt_trained=false \
    --task=winogrande/xl \
    --n_examples=32768 \
    --batch_size=4 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --sequence_length=50 \
    --include_embeddings=false



# Dump per-example Fishers for the 8-epoch MNLI-initialised model (device 2).
# Flags are identical to the 4-epoch dumps; only MODEL differs.
MODEL=bert_base_mnli_to_winogrande_xl_8_epochs_01
PER_EXAMPLES_FISHERS="${MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

CUDA_VISIBLE_DEVICES=2 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/${MODEL}"  \
    --tokenizer="bert-base-uncased" \
    --from_pt_trained=false \
    --task=winogrande/xl \
    --n_examples=32768 \
    --batch_size=4 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --sequence_length=50 \
    --include_embeddings=false



###################################################################################

# NOTE: I think the test split has no labels. That is not an issue here, but it can affect later analysis.

# Dump per-example Fishers for the 4-epoch MNLI-initialised model on the
# validation and test splits (--split=validation,test), 6000 examples, with
# --ds_force_deterministic=true. The output file is tagged
# "test_val_winogrande" so it does not collide with the dump made without
# --split.
MODEL=bert_base_mnli_to_winogrande_xl_4_epochs_01
PER_EXAMPLES_FISHERS="${MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/${MODEL}"  \
    --tokenizer="bert-base-uncased" \
    --from_pt_trained=false \
    --task=winogrande/xl \
    --n_examples=6000 \
    --batch_size=4 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --sequence_length=50 \
    --include_embeddings=false \
    --ds_force_deterministic=true \
    --split=validation,test


# Names consumed by run_per_subset_nmf for the 64-component decomposition of
# the validation+test per-example Fishers.
MODEL='bert_base_mnli_to_winogrande_xl_4_epochs_01'
# Input file (relative to PER_EXAMPLE_FISHERS_DIR).
PER_EXAMPLES_FISHERS="$MODEL.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# Output file: the input name prefixed with the decomposition settings.
DECOMP_FILENAME="nmf_decomp.per_sub_block.6k.16k.64.$PER_EXAMPLES_FISHERS"

#######################################
# Launch ./scripts1/decomp/run_per_subset_nmf.py for a set of sub-blocks with
# a single CUDA device visible.
# Globals (read): PER_EXAMPLE_FISHERS_DIR, DECOMP_FILENAME,
#                 PER_EXAMPLES_FISHERS, MODELS_DIR, MODEL
# Arguments:
#   $1 - value assigned to CUDA_VISIBLE_DEVICES for the launched process
#   $2 - comma-separated sub-block indices, passed as --subset_indices
#   $3 - optional --nmf_n_components (default: 64)
#   $4 - optional --n_examples (default: 6000)
# Returns: 2 if fewer than two arguments are given, otherwise the exit status
#          of the python process.
#######################################
run_per_subset_nmf () {
    if [ "$#" -lt 2 ]; then
        printf 'usage: run_per_subset_nmf DEVICE SUBSET_INDICES [N_COMPONENTS] [N_EXAMPLES]\n' >&2
        return 2
    fi
    local device=$1
    local subset_indices=$2
    # Previously hard-coded; the defaults keep existing two-argument calls
    # producing exactly the same command line.
    local n_components=${3:-64}
    local n_examples=${4:-6000}
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/decomp/run_per_subset_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples="${n_examples}" \
        --start_fisher_index=0 \
        --end_fisher_index=16384 \
        --nmf_n_components="${n_components}" \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8 \
        --pef_embeddings=false \
        --subset_style=per_sub_block \
        --subset_indices="${subset_indices}" \
        --model="${MODELS_DIR}/${MODEL}" \
        --from_pt=false
}


# 25 sub-blocks total (indices 0-24), dealt round-robin over devices 0-3;
# every index appears in exactly one call below. There is no `&`, so when this
# file is executed as a script the four calls run one after another.

run_per_subset_nmf 0 0,4,8,12,16,20,24
run_per_subset_nmf 1 1,5,9,13,17,21
run_per_subset_nmf 2 2,6,10,14,18,22
run_per_subset_nmf 3 3,7,11,15,19,23

###################################################################################



# Names consumed by run_per_subset_nmf for the 256-component decomposition of
# the validation+test per-example Fishers.
MODEL='bert_base_mnli_to_winogrande_xl_4_epochs_01'
# Input file (relative to PER_EXAMPLE_FISHERS_DIR).
PER_EXAMPLES_FISHERS="$MODEL.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# Output file: the input name prefixed with the decomposition settings.
DECOMP_FILENAME="nmf_decomp.per_sub_block.6k.16k.256.$PER_EXAMPLES_FISHERS"

#######################################
# Launch ./scripts1/decomp/run_per_subset_nmf.py for a set of sub-blocks with
# a single CUDA device visible. This definition replaces any earlier
# run_per_subset_nmf in the same shell.
# Globals (read): PER_EXAMPLE_FISHERS_DIR, DECOMP_FILENAME,
#                 PER_EXAMPLES_FISHERS, MODELS_DIR, MODEL
# Arguments:
#   $1 - value assigned to CUDA_VISIBLE_DEVICES for the launched process
#   $2 - comma-separated sub-block indices, passed as --subset_indices
#   $3 - optional --nmf_n_components (default: 256)
#   $4 - optional --n_examples (default: 6000)
# Returns: 2 if fewer than two arguments are given, otherwise the exit status
#          of the python process.
#######################################
run_per_subset_nmf () {
    if [ "$#" -lt 2 ]; then
        printf 'usage: run_per_subset_nmf DEVICE SUBSET_INDICES [N_COMPONENTS] [N_EXAMPLES]\n' >&2
        return 2
    fi
    local device=$1
    local subset_indices=$2
    # Previously hard-coded; the defaults keep existing two-argument calls
    # producing exactly the same command line.
    local n_components=${3:-256}
    local n_examples=${4:-6000}
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/decomp/run_per_subset_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples="${n_examples}" \
        --start_fisher_index=0 \
        --end_fisher_index=16384 \
        --nmf_n_components="${n_components}" \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8 \
        --pef_embeddings=false \
        --subset_style=per_sub_block \
        --subset_indices="${subset_indices}" \
        --model="${MODELS_DIR}/${MODEL}" \
        --from_pt=false
}


# 25 sub-blocks total (indices 0-24), dealt round-robin over devices 0-3;
# every index appears in exactly one call below. There is no `&`, so when this
# file is executed as a script the four calls run one after another.

run_per_subset_nmf 0 0,4,8,12,16,20,24
run_per_subset_nmf 1 1,5,9,13,17,21
run_per_subset_nmf 2 2,6,10,14,18,22
run_per_subset_nmf 3 3,7,11,15,19,23

###################################################################################
###################################################################################


# Experiment directories, declared again with the same values as earlier in
# the file so this section can be run on its own once DATA_DIR is set.
EXPS_DIR="$DATA_DIR/winogrange1"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"

# Names consumed by run_per_subset_nmf for the 256-component decomposition of
# the per-example Fishers file that has no "test_val" tag.
MODEL='bert_base_mnli_to_winogrande_xl_4_epochs_01'
PER_EXAMPLES_FISHERS="$MODEL.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.$PER_EXAMPLES_FISHERS"

#######################################
# Launch ./scripts1/decomp/run_per_subset_nmf.py for a set of sub-blocks with
# a single CUDA device visible. This definition replaces any earlier
# run_per_subset_nmf in the same shell.
# Globals (read): PER_EXAMPLE_FISHERS_DIR, DECOMP_FILENAME,
#                 PER_EXAMPLES_FISHERS, MODELS_DIR, MODEL
# Arguments:
#   $1 - value assigned to CUDA_VISIBLE_DEVICES for the launched process
#   $2 - comma-separated sub-block indices, passed as --subset_indices
#   $3 - optional --nmf_n_components (default: 256)
#   $4 - optional --n_examples (default: 16384)
# Returns: 2 if fewer than two arguments are given, otherwise the exit status
#          of the python process.
#######################################
run_per_subset_nmf () {
    if [ "$#" -lt 2 ]; then
        printf 'usage: run_per_subset_nmf DEVICE SUBSET_INDICES [N_COMPONENTS] [N_EXAMPLES]\n' >&2
        return 2
    fi
    local device=$1
    local subset_indices=$2
    # Previously hard-coded; the defaults keep existing two-argument calls
    # producing exactly the same command line.
    local n_components=${3:-256}
    local n_examples=${4:-16384}
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/decomp/run_per_subset_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples="${n_examples}" \
        --start_fisher_index=0 \
        --end_fisher_index=16384 \
        --nmf_n_components="${n_components}" \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8 \
        --pef_embeddings=false \
        --subset_style=per_sub_block \
        --subset_indices="${subset_indices}" \
        --model="${MODELS_DIR}/${MODEL}" \
        --from_pt=false
}


# 25 sub-blocks total (indices 0-24), dealt round-robin over devices 0-3;
# every index appears in exactly one call below. There is no `&`, so when this
# file is executed as a script the four calls run one after another.

run_per_subset_nmf 0 0,4,8,12,16,20,24
run_per_subset_nmf 1 1,5,9,13,17,21
run_per_subset_nmf 2 2,6,10,14,18,22
run_per_subset_nmf 3 3,7,11,15,19,23



# One-off direct invocation for sub-block 22 only. It uses the same flags as
# the run_per_subset_nmf definition just above, except
# --nmf_n_components=192 instead of 256.
# NOTE(review): --output_path still points at DECOMP_FILENAME, whose name
# says "256". Confirm that writing a 192-component result for sub-block 22
# into that file is intended, and how the script handles an existing output.
CUDA_VISIBLE_DEVICES=2 python ./scripts1/decomp/run_per_subset_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples=16384 \
        --start_fisher_index=0 \
        --end_fisher_index=16384 \
        --nmf_n_components=192 \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8 \
        --pef_embeddings=false \
        --subset_style=per_sub_block \
        --subset_indices=22 \
        --model="${MODELS_DIR}/${MODEL}" \
        --from_pt=false


############################################


# Apply an existing NMF decomposition to a different per-example Fishers file.
#   DECOMP_FILENAME      - decomposition computed from the Fishers file that
#                          has no "test_val" tag (DECOMP_PE_FISHERS).
#   PER_EXAMPLES_FISHERS - the validation+test Fishers to be transformed.
# The output name carries "transformed_from_nmf_train", so the untagged file
# is presumably the train split — confirm.
MODEL=bert_base_mnli_to_winogrande_xl_4_epochs_01
DECOMP_PE_FISHERS="${MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.${DECOMP_PE_FISHERS}"

PER_EXAMPLES_FISHERS="${MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

TRANSFORMED_NMF_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.transformed_from_nmf_train.${PER_EXAMPLES_FISHERS}"

# CUDA_VISIBLE_DEVICES is set to the empty string for this process, which
# hides all CUDA devices from it.
CUDA_VISIBLE_DEVICES= python ./scripts1/decomp/transform_using_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${TRANSFORMED_NMF_FILENAME}" \
    --decomposition="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --start_fisher_index=0 \
    --end_fisher_index=16384 \
    --pef_embeddings=false \
    --subset_style=per_sub_block \
    --model="${MODELS_DIR}/${MODEL}" \
    --from_pt=false


###################################################################################
###################################################################################


# Experiment directories, declared again with the same values as earlier in
# the file; DATA_DIR must already be set.
EXPS_DIR="${DATA_DIR}/winogrange1"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"

# Inputs: the validation+test per-example Fishers and their 256-component
# per-sub-block NMF decomposition (both under PER_EXAMPLE_FISHERS_DIR).
MODEL="bert_base_mnli_to_winogrande_xl_4_epochs_01"
PER_EXAMPLES_FISHERS="${MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.per_sub_block.6k.16k.256.${PER_EXAMPLES_FISHERS}"

# Output: written under FISHER_DIR (not PER_EXAMPLE_FISHERS_DIR).
OUTPUT_FILENAME="nmf_components_fishers.per_sub_block.6k.16k.256.${PER_EXAMPLES_FISHERS}"


# Run create_nmf_components_fishers.py on device 0. Compared with the
# commented-out variant below, this one adds
# --also_coalesce_via_batch_fishers=true and --batch_fisher_batch_size=16
# and uses a GPU instead of an empty CUDA_VISIBLE_DEVICES.
CUDA_VISIBLE_DEVICES=0 python ./scripts1/projects/wino/create_nmf_components_fishers.py  \
    --pef_embeddings=false \
    --output_path="${FISHER_DIR}/${OUTPUT_FILENAME}" \
    --model="${MODELS_DIR}/${MODEL}" \
    --from_pt=false \
    --tokenizer=bert-base-uncased \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --decomposition="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --subset_style=per_sub_block \
    --shift_labels=true \
    --n_labeled=2534 \
    --selection_coeff_factors=0.5,0.4,0.3 \
    --selection_frac_thresholds=0.8,0.85,0.9 \
    --selection_p_value_thresholds=0.2,0.1,0.05,0.01 \
    --also_coalesce_via_batch_fishers=true \
    --batch_fisher_batch_size=16


# CUDA_VISIBLE_DEVICES= python ./scripts1/projects/wino/create_nmf_components_fishers.py  \
#     --pef_embeddings=false \
#     --output_path="${FISHER_DIR}/${OUTPUT_FILENAME}" \
#     --model="${MODELS_DIR}/${MODEL}" \
#     --from_pt=false \
#     --tokenizer=bert-base-uncased \
#     --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
#     --decomposition="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
#     --subset_style=per_sub_block \
#     --shift_labels=true \
#     --n_labeled=2534 \
#     --selection_coeff_factors=0.5,0.4,0.3 \
#     --selection_frac_thresholds=0.8,0.85,0.9 \
#     --selection_p_value_thresholds=0.2,0.1,0.05,0.01



################################################################

# Run compute_fisher.py on device 3 for the textattack/bert-base-uncased-MNLI
# checkpoint itself (not a fine-tuned model from MODELS_DIR), on task
# glue/mnli with 32768 examples and sequence length 128. The result is
# written to FISHER_DIR/bert_base_mnli_fisher.32k.h5.
CUDA_VISIBLE_DEVICES=3 python ./scripts1/ogmm/compute_fisher.py  \
    --model="textattack/bert-base-uncased-MNLI" \
    --from_pt=true \
    --tokenizer=bert-base-uncased \
    --task="glue/mnli" \
    --batch_size=16 \
    --n_examples=32768 \
    --sequence_length=128 \
    --fisher_path="${FISHER_DIR}/bert_base_mnli_fisher.32k.h5"
