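# Gradient merging experiments: compute per-example Fishers for an
# MNLI-finetuned RoBERTa model, then run an NMF decomposition on them.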

###################################################################################
###################################################################################


# Environment for the extract_merge1 project: virtualenv name, code/data
# roots and the TFDS dataset location. `workon` is virtualenvwrapper's
# activation command, so this is meant to be run in an interactive shell.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# The ./scripts1/... paths used below are relative to the code directory.
cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Put the code dir on the Python path. No trailing characters after the
# path: anything there (e.g. a comma) becomes part of the entry and makes it
# point at a non-existent directory. ${PYTHONPATH:+...} avoids a leading ':'
# (an empty entry, i.e. the current directory) when PYTHONPATH is empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"


###################################################################################
###################################################################################

# Experiment output layout; everything lives under DATA_DIR/gradient_merging1.
# (Assignment right-hand sides are not word-split or globbed, so no quotes.)
# NOTE(review): MODELS_DIR and FISHER_DIR are not referenced in this section;
# presumably used by other commands — confirm before removing.
EXPS_DIR=${DATA_DIR}/gradient_merging1
MODELS_DIR=$EXPS_DIR/models
FISHER_DIR=$EXPS_DIR/fishers
PER_EXAMPLE_FISHERS_DIR=$EXPS_DIR/per_example_fishers


#######################################

# Step 1: compute per-example Fisher values for the MNLI-finetuned RoBERTa
# checkpoint on GPU 0 and write them to PER_EXAMPLE_FISHERS_DIR.
# Filename tokens mirror the flags below: "no_embeddings" <-> 
# --include_embeddings=false, "sparse_dynamic_raw" <-> --flavor, and
# "32k.32k" <-> --n_examples / --n_fisher_values_per_example (both 32768).
# Keep them in sync when changing either side.
FINETUNED_MODEL="prajjwal1/roberta-base-mnli"
# PRETRAINED_MODEL="roberta-base"

PER_EXAMPLES_FISHERS="roberta_base_mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5"

fisher_args=(
  --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}"
  --trained_model="${FINETUNED_MODEL}"
  --tokenizer="roberta-base"
  --from_pt_trained=true
  --task=glue/mnli
  --n_examples=32768
  --batch_size=8
  --expectation_wrt_logits=true
  --flavor=sparse_dynamic_raw
  --n_fisher_values_per_example=32768
  --include_embeddings=false
)
CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py "${fisher_args[@]}"


# Step 2: run NMF (on GPU 1) over the per-example Fishers written in step 1,
# saving the decomposition next to them in PER_EXAMPLE_FISHERS_DIR.
# PER_EXAMPLES_FISHERS is re-declared, presumably so this section can be run
# on its own. The "8k.16k.64" in DECOMP_FILENAME mirrors --n_examples (8192),
# --end_fisher_index (16384) and --nmf_n_components (64).
PER_EXAMPLES_FISHERS="roberta_base_mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.8k.16k.64.${PER_EXAMPLES_FISHERS}"

nmf_args=(
  --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}"
  --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}"
  --n_examples=8192
  --start_fisher_index=0
  --end_fisher_index=16384
  --nmf_n_components=64
  --reduce_threshold=1
  --nmf_max_iter=3000
  --nmf_tol=1e-8
)
CUDA_VISIBLE_DEVICES=1 python ./scripts1/decomp/run_nmf.py "${nmf_args[@]}"
