###################################################################################
###################################################################################


# Session setup: paths, dataset location, virtualenv and PYTHONPATH.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# Work from the code root so the relative ./scripts1/... paths below resolve.
cd "$CODE_DIR"
# NOTE(review): `workon` is assumed to be virtualenvwrapper's, already loaded
# in this shell.
workon "$VIRTUALENV_NAME"
# Append the code dir to PYTHONPATH. No trailing punctuation here: anything
# after $CODE_DIR would become part of the path entry. ${PYTHONPATH:+...}
# avoids producing an empty leading entry (":...") when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"


###################################################################################
###################################################################################

# Per-experiment directory layout; all outputs live under EXPS_DIR.
# NOTE(review): "winogrange1" looks like a misspelling of "winogrande1". It is
# kept as-is because existing outputs presumably live at this path; confirm
# before renaming.
EXPS_DIR="$DATA_DIR/winogrange1"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"

###################################################################################
###################################################################################


# Fine-tune textattack/bert-base-uncased-MNLI on winogrande/custom for 8 epochs
# on GPU 0. The output directory name is the MODEL value used by the steps below.
finetune_args=(
    --output_path="$MODELS_DIR/bert_base_mnli_to_winogrande_custom_8_epochs_01"
    --model=textattack/bert-base-uncased-MNLI
    --from_pt=true
    --tokenizer=bert-base-uncased
    --task=winogrande/custom
    --batch_size=32
    --n_epochs=8
    --sequence_length=50
    --learning_rate=2e-5
    --clipnorm=0.1
    --force_n_labels=3
)
CUDA_VISIBLE_DEVICES=0 python ./scripts1/training/finetune.py "${finetune_args[@]}"

###################################################################################

# Per-example Fishers (sparse_dynamic_raw, 32768 values/example, no embeddings)
# for 10k examples of the winogrande/custom heldout split, on GPU 0.
# NOTE(review): the later runs encode n_examples in the first size tag
# (20k, 60k), but this filename says "32k" with n_examples=10_000. It is kept
# unchanged because the NMF steps below read this exact name.
# NOTE(review): unlike the later per-example-Fisher runs, this one does not
# pass --ds_force_deterministic=true; confirm if example order matters.
MODEL=bert_base_mnli_to_winogrande_custom_8_epochs_01
PER_EXAMPLES_FISHERS="$MODEL.winogrande_heldout.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

pef_args=(
    --output_path="$PER_EXAMPLE_FISHERS_DIR/$PER_EXAMPLES_FISHERS"
    --trained_model="$MODELS_DIR/$MODEL"
    --tokenizer=bert-base-uncased
    --from_pt_trained=false
    --task=winogrande/custom
    --n_examples=10_000
    --batch_size=4
    --expectation_wrt_logits=true
    --flavor=sparse_dynamic_raw
    --n_fisher_values_per_example=32768
    --sequence_length=50
    --include_embeddings=false
    --split=heldout
)
CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"


###################################################################################

# Session setup (repeated so this section can be pasted into a fresh shell).
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# Work from the code root so the relative ./scripts1/... paths below resolve.
cd "$CODE_DIR"
# NOTE(review): `workon` is assumed to be virtualenvwrapper's, already loaded
# in this shell.
workon "$VIRTUALENV_NAME"
# Append the code dir to PYTHONPATH. No trailing punctuation here: anything
# after $CODE_DIR would become part of the path entry. ${PYTHONPATH:+...}
# avoids producing an empty leading entry (":...") when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"



# Per-experiment directory layout; all outputs live under EXPS_DIR.
# NOTE(review): "winogrange1" looks like a misspelling of "winogrande1". It is
# kept as-is because existing outputs presumably live at this path; confirm
# before renaming.
EXPS_DIR="$DATA_DIR/winogrange1"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"


# Inputs and output name for the per-sub-block NMF decomposition.
MODEL="bert_base_mnli_to_winogrande_custom_8_epochs_01"
PER_EXAMPLES_FISHERS="$MODEL.winogrande_heldout.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

# The 10k.16k.256 tags mirror run_per_subset_nmf's --n_examples=10_000,
# --end_fisher_index=16384 and --nmf_n_components=256.
DECOMP_FILENAME="nmf_decomp.per_sub_block.10k.16k.256.$PER_EXAMPLES_FISHERS"

#######################################
# Run run_per_subset_nmf.py on a chosen set of sub-blocks, pinned to one GPU.
# Globals (read): PER_EXAMPLE_FISHERS_DIR, DECOMP_FILENAME,
#                 PER_EXAMPLES_FISHERS, MODELS_DIR, MODEL
# Arguments:
#   $1 - value exported as CUDA_VISIBLE_DEVICES for the python process
#   $2 - comma-separated sub-block indices, passed as --subset_indices
# Returns: the exit status of the python process.
#######################################
run_per_subset_nmf () {
    local gpu=$1
    local indices=$2
    local -a nmf_flags=(
        --output_path="$PER_EXAMPLE_FISHERS_DIR/$DECOMP_FILENAME"
        --per_example_fishers="$PER_EXAMPLE_FISHERS_DIR/$PER_EXAMPLES_FISHERS"
        --n_examples=10_000
        --start_fisher_index=0
        --end_fisher_index=16384
        --nmf_n_components=256
        --reduce_threshold=1
        --nmf_max_iter=3000
        --nmf_tol=1e-8
        --pef_embeddings=false
        --subset_style=per_sub_block
        --subset_indices="$indices"
        --model="$MODELS_DIR/$MODEL"
        --from_pt=false
    )
    CUDA_VISIBLE_DEVICES=$gpu python ./scripts1/decomp/run_per_subset_nmf.py "${nmf_flags[@]}"
}


# 25 sub-blocks total (indices 0-24), dealt round-robin to GPUs 0-3
# (sub-block i goes to GPU i mod 4). Each entry is "<gpu>:<indices>".
# NOTE(review): the calls run one after another, and every call uses the same
# --output_path. If they are meant to run concurrently, check that the script
# tolerates concurrent writes to that file first.
for gpu_and_subsets in \
    0:0,4,8,12,16,20,24 \
    1:1,5,9,13,17,21 \
    2:2,6,10,14,18,22 \
    3:3,7,11,15,19,23; do
    run_per_subset_nmf "${gpu_and_subsets%%:*}" "${gpu_and_subsets#*:}"
done

###################################################################################

# Directory layout (repeated so this section can be pasted on its own).
EXPS_DIR="$DATA_DIR/winogrange1"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"

# Inputs: the heldout per-example Fishers and their per-sub-block NMF decomposition.
MODEL="bert_base_mnli_to_winogrande_custom_8_epochs_01"
PER_EXAMPLES_FISHERS="$MODEL.winogrande_heldout.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.per_sub_block.10k.16k.256.$PER_EXAMPLES_FISHERS"

# The "incorrect" tag matches --tuning_indicator=incorrect below.
OUTPUT_FILENAME="nmf_components_fishers.incorrect.per_sub_block.10k.16k.256.$PER_EXAMPLES_FISHERS"

components_args=(
    --pef_embeddings=false
    --output_path="$FISHER_DIR/$OUTPUT_FILENAME"
    --model="$MODELS_DIR/$MODEL"
    --from_pt=false
    --tokenizer=bert-base-uncased
    --per_example_fishers="$PER_EXAMPLE_FISHERS_DIR/$PER_EXAMPLES_FISHERS"
    --decomposition="$PER_EXAMPLE_FISHERS_DIR/$DECOMP_FILENAME"
    --subset_style=per_sub_block
    --shift_labels=true
    --selection_coeff_factors=0.3,0.4,0.5
    --selection_frac_thresholds=0.5,0.6,0.7
    --selection_p_value_thresholds=0.05,0.01
    --also_coalesce_via_batch_fishers=true
    --batch_fisher_batch_size=16
    --tuning_indicator=incorrect
)
CUDA_VISIBLE_DEVICES=0 python ./scripts1/projects/wino/create_nmf_components_fishers.py "${components_args[@]}"

###################################################################################
###################################################################################

# MODEL=bert_base_mnli_to_winogrande_xl_4_epochs_01
# DECOMP_PE_FISHERS="${MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.${DECOMP_PE_FISHERS}"

# PER_EXAMPLES_FISHERS="${MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

# TRANSFORMED_NMF_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.transformed_from_nmf_train.${PER_EXAMPLES_FISHERS}"

# CUDA_VISIBLE_DEVICES= python ./scripts1/decomp/transform_using_nmf.py  \
#     --output_path="${PER_EXAMPLE_FISHERS_DIR}/${TRANSFORMED_NMF_FILENAME}" \
#     --decomposition="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
#     --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
#     --start_fisher_index=0 \
#     --end_fisher_index=16384 \
#     --pef_embeddings=false \
#     --subset_style=per_sub_block \
#     --model="${MODELS_DIR}/${MODEL}" \
#     --from_pt=false


# Per-example Fishers for 60796 examples of the winogrande/custom train split
# (tagged "60k"), on GPU 1, with --ds_force_deterministic=true.
MODEL=bert_base_mnli_to_winogrande_custom_8_epochs_01
PER_EXAMPLES_FISHERS="$MODEL.winogrande_custom_train.no_embeddings.sparse_dynamic_raw.60k.32k.h5"

pef_args=(
    --output_path="$PER_EXAMPLE_FISHERS_DIR/$PER_EXAMPLES_FISHERS"
    --trained_model="$MODELS_DIR/$MODEL"
    --tokenizer=bert-base-uncased
    --from_pt_trained=false
    --task=winogrande/custom
    --n_examples=60796
    --batch_size=4
    --expectation_wrt_logits=true
    --flavor=sparse_dynamic_raw
    --n_fisher_values_per_example=32768
    --sequence_length=50
    --include_embeddings=false
    --split=train
    --ds_force_deterministic=true
)
CUDA_VISIBLE_DEVICES=1 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"



###################################################################################
###################################################################################

# Per-example Fishers for 20k examples of the winogrande/custom heldout split,
# on GPU 2, with --ds_force_deterministic=true.
MODEL=bert_base_mnli_to_winogrande_custom_8_epochs_01
PER_EXAMPLES_FISHERS="$MODEL.winogrande_heldout.no_embeddings.sparse_dynamic_raw.20k.32k.h5"

pef_args=(
    --output_path="$PER_EXAMPLE_FISHERS_DIR/$PER_EXAMPLES_FISHERS"
    --trained_model="$MODELS_DIR/$MODEL"
    --tokenizer=bert-base-uncased
    --from_pt_trained=false
    --task=winogrande/custom
    --n_examples=20_000
    --batch_size=4
    --expectation_wrt_logits=true
    --flavor=sparse_dynamic_raw
    --n_fisher_values_per_example=32768
    --sequence_length=50
    --include_embeddings=false
    --split=heldout
    --ds_force_deterministic=true
)
CUDA_VISIBLE_DEVICES=2 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${pef_args[@]}"
