###################################################################################
###################################################################################


# Project locations and Python environment for the extract_merge1 experiments.
# NOTE(review): `workon` is presumably virtualenvwrapper's shell function, so
# this section has to run in a shell where virtualenvwrapper is loaded.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The relative ./scripts1/... paths used below are resolved from CODE_DIR.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. The value must end at the closing quote: any
# trailing character (such as a stray comma) becomes part of the last path
# entry. The ${VAR:+...} form avoids an empty leading entry when PYTHONPATH
# is unset or empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"
# Same value as exported above; kept after `workon` as in the original flow.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets


###################################################################################
###################################################################################

# Output locations for this experiment, all rooted under DATA_DIR.
EXPS_DIR="${DATA_DIR}/anli_correct1"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"


###################################################################################
###################################################################################



# Run save_per_example_fishers_to_disk.py for task anli/r3 with the
# MNLI-finetuned BERT checkpoint, writing into PER_EXAMPLE_FISHERS_DIR.
MODEL="textattack/bert-base-uncased-MNLI"
PER_EXAMPLES_FISHERS="textattack_bert_base_uncased_MNLI.anli_r3.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

# Flags are gathered in an array so each sits on its own line without
# backslash continuations.
anli_r3_pef_args=(
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}"
    --trained_model="${MODEL}"
    --tokenizer="bert-base-uncased"
    --from_pt_trained=true
    --task=anli/r3
    --n_examples=32768
    --batch_size=4
    --expectation_wrt_logits=true
    --flavor=sparse_dynamic_raw
    --n_fisher_values_per_example=32768
    --include_embeddings=false
)

# Only GPU 3 is made visible to the process.
CUDA_VISIBLE_DEVICES=3 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${anli_r3_pef_args[@]}"



# Output file name for the NMF decomposition, derived from the per-example
# Fishers file name that is current at this point in the script.
DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.${PER_EXAMPLES_FISHERS}"

#######################################
# Launch ./scripts1/decomp/run_per_subset_nmf.py for a set of sub-blocks.
# Globals (read when the function is CALLED, not when it is defined, so
# later reassignments of these variables affect subsequent calls):
#   PER_EXAMPLE_FISHERS_DIR, DECOMP_FILENAME, PER_EXAMPLES_FISHERS, MODEL
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES (may be the empty string)
#   $2 - comma-separated sub-block indices, forwarded as --subset_indices
# Returns:
#   2 if fewer than two arguments are given; otherwise python's exit status.
#######################################
run_per_subset_nmf () {
    # Refuse to start the job with a missing argument; otherwise an empty
    # --subset_indices= would be passed through silently.
    if [ "$#" -lt 2 ]; then
        printf 'usage: run_per_subset_nmf DEVICE SUBSET_INDICES\n' >&2
        return 2
    fi

    local device=$1
    local subset_indices=$2

    CUDA_VISIBLE_DEVICES="${device}" python ./scripts1/decomp/run_per_subset_nmf.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --n_examples=16384 \
        --start_fisher_index=0 \
        --end_fisher_index=16384 \
        --nmf_n_components=256 \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8 \
        --pef_embeddings=false \
        --subset_style=per_sub_block \
        --subset_indices="${subset_indices}" \
        --model="${MODEL}" \
        --from_pt=true
}

# There are 25 sub-blocks in total (indices 0 through 24); run them all on GPU 3.
# Build the comma-separated index list "0,1,...,24" and drop the trailing comma.
printf -v all_sub_blocks '%s,' {0..24}
run_per_subset_nmf 3 "${all_sub_blocks%,}"

# NOTE: it looks like the pooler sub-block produced NaNs.


###########################################################


MODEL="textattack/bert-base-uncased-MNLI"

#######################################
# Launch ./scripts1/data_gen/save_per_example_fishers_to_disk.py on the
# validation split of an annotated_anli task.
# Globals (read at call time):
#   PER_EXAMPLE_FISHERS_DIR, MODEL
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES (may be the empty string)
#   $2 - task suffix, used as --task=annotated_anli/<$2> and in the output
#        file name (e.g. r1, r2, r3)
#   $3 - number of examples, forwarded as --n_examples
# Returns:
#   2 if fewer than three arguments are given; otherwise python's exit status.
#######################################
compute_annotated_anli_pe_fishers () {
    # Refuse to start the job with a missing argument; otherwise empty
    # --task / --n_examples values would be passed through silently.
    if [ "$#" -lt 3 ]; then
        printf 'usage: compute_annotated_anli_pe_fishers DEVICE TASK N_EXAMPLES\n' >&2
        return 2
    fi

    local device=$1
    local task=$2
    local n_examples=$3

    local per_examples_fishers="textattack_bert_base_uncased_MNLI.annotated_anli_${task}.no_embeddings.sparse_dynamic_raw.all.32k.h5"

    CUDA_VISIBLE_DEVICES="${device}" python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${per_examples_fishers}" \
        --trained_model="${MODEL}" \
        --tokenizer="bert-base-uncased" \
        --from_pt_trained=true \
        --task="annotated_anli/${task}" \
        --n_examples="${n_examples}" \
        --batch_size=4 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example=32768 \
        --include_embeddings=false \
        --split=validation
}

# compute_annotated_anli_pe_fishers 3 r1 1000
# compute_annotated_anli_pe_fishers 2 r2 1000
# compute_annotated_anli_pe_fishers 1 r3 1198
###################################################################################


# Run transform_using_nmf.py on the annotated_anli per-example Fishers for
# TASK, using the NMF decomposition file derived from the anli_r3 Fishers name.
MODEL="textattack/bert-base-uncased-MNLI"
TASK=r3

# Decomposition input (named after the Fishers file it was computed from).
DECOMP_PE_FISHERS="textattack_bert_base_uncased_MNLI.anli_r3.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.${DECOMP_PE_FISHERS}"

# Per-example Fishers to transform, and the file the result is written to.
PER_EXAMPLES_FISHERS="textattack_bert_base_uncased_MNLI.annotated_anli_${TASK}.no_embeddings.sparse_dynamic_raw.all.32k.h5"
TRANSFORMED_NMF_FILENAME="transformed_nmf.textattack_bert_base_uncased_MNLI.annotated_anli_${TASK}.no_embeddings.sparse_dynamic_raw.all.32k.h5"

annotated_anli_transform_args=(
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${TRANSFORMED_NMF_FILENAME}"
    --decomposition="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}"
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}"
    --start_fisher_index=0
    --end_fisher_index=16384
    --pef_embeddings=false
    --subset_style=per_sub_block
    --model="${MODEL}"
    --from_pt=true
)

# CUDA_VISIBLE_DEVICES is deliberately set to the empty string here, which
# hides all GPUs from the process (presumably so this step runs on CPU).
CUDA_VISIBLE_DEVICES= python ./scripts1/decomp/transform_using_nmf.py "${annotated_anli_transform_args[@]}"


###################################################################################


# Run save_per_example_fishers_to_disk.py for task glue/mnli with the
# MNLI-finetuned BERT checkpoint, writing into PER_EXAMPLE_FISHERS_DIR.
MODEL="textattack/bert-base-uncased-MNLI"
PER_EXAMPLES_FISHERS="textattack_bert_base_uncased_MNLI.mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

# Flags are gathered in an array so each sits on its own line without
# backslash continuations.
mnli_pef_args=(
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}"
    --trained_model="${MODEL}"
    --tokenizer="bert-base-uncased"
    --from_pt_trained=true
    --task=glue/mnli
    --n_examples=32768
    --batch_size=4
    --expectation_wrt_logits=true
    --flavor=sparse_dynamic_raw
    --n_fisher_values_per_example=32768
    --include_embeddings=false
)

# Only GPU 3 is made visible to the process.
CUDA_VISIBLE_DEVICES=3 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${mnli_pef_args[@]}"


###################################################################################


# Transform the MNLI per-example Fishers using the ANLI r3 NMF decomposition.
# This is done per sub-block (--subset_style=per_sub_block).
MODEL="textattack/bert-base-uncased-MNLI"

# Decomposition input (named after the Fishers file it was computed from).
DECOMP_PE_FISHERS="textattack_bert_base_uncased_MNLI.anli_r3.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.${DECOMP_PE_FISHERS}"

# Per-example Fishers to transform, and the file the result is written to.
PER_EXAMPLES_FISHERS="textattack_bert_base_uncased_MNLI.mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
TRANSFORMED_NMF_FILENAME="transformed_nmf.nmf_anli_r3.textattack_bert_base_uncased_MNLI.mnli.no_embeddings.sparse_dynamic_raw.8k.16k.subblocks.h5"

mnli_transform_args=(
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${TRANSFORMED_NMF_FILENAME}"
    --decomposition="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}"
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}"
    --start_fisher_index=0
    --end_fisher_index=16384
    --n_examples=8192
    --pef_embeddings=false
    --subset_style=per_sub_block
    --model="${MODEL}"
    --from_pt=true
)

# CUDA_VISIBLE_DEVICES is deliberately set to the empty string here, which
# hides all GPUs from the process (presumably so this step runs on CPU).
CUDA_VISIBLE_DEVICES= python ./scripts1/decomp/transform_using_nmf.py "${mnli_transform_args[@]}"


###################################################################################
###################################################################################




###################################################################################
###################################################################################


# MODEL="prajjwal1/roberta-base-mnli"
# PER_EXAMPLES_FISHERS="roberta_base_mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME="nmf_decomp.per_sub_block.16k.16k.256.${PER_EXAMPLES_FISHERS}"


# # TODO: Maybe some scheme for automatically reducing some parameter in case of OOM.


# run_per_subset_nmf () {
#     local device=$1
#     local subset_indices=$2
#     #        
#     CUDA_VISIBLE_DEVICES=$device python ./scripts1/decomp/run_per_subset_nmf.py  \
#         --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
#         --per_example_fishers="${DATA_DIR}/gradient_merging1/per_example_fishers/${PER_EXAMPLES_FISHERS}" \
#         --n_examples=16384 \
#         --start_fisher_index=0 \
#         --end_fisher_index=16384 \
#         --nmf_n_components=256 \
#         --reduce_threshold=1 \
#         --nmf_max_iter=3000 \
#         --nmf_tol=1e-8 \
#         --pef_embeddings=false \
#         --subset_style=per_sub_block \
#         --subset_indices="${subset_indices}" \
#         --model=$MODEL \
#         --from_pt=true
# }

# # 24 subsets

# # run_per_subset_nmf 0 "0,3,6,9,12,15,18,21"
# # run_per_subset_nmf 1 "1,4,7,10,13,16,19,22"
# # run_per_subset_nmf 2 "2,5,8,11,14,17,20,23"
