
###################################################################################
###################################################################################


# Per-session environment for the extract_merge1 experiments.
# NOTE(review): `workon` is a virtualenvwrapper shell function, so this file is
# presumably meant to be sourced / pasted into a shell where virtualenvwrapper
# is already loaded — confirm.

VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. Nothing may follow the closing quote: any
# trailing character (e.g. a stray comma) becomes part of the path entry and
# points Python at a nonexistent directory. ${PYTHONPATH:+...} avoids adding a
# leading empty entry when PYTHONPATH was previously unset/empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"


###################################################################################
###################################################################################

# Directory layout of the "transfer1" experiment: everything lives under
# EXPS_DIR, split by artifact kind into the subdirectories below.
EXPS_DIR="$DATA_DIR"/transfer1
MODELS_DIR="$EXPS_DIR"/models                            # train_frozen.py --output_path root
FISHER_DIR="$EXPS_DIR"/fishers                           # compute_fisher.py .fisher.h5 outputs
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR"/per_example_fishers  # per-example Fishers and their NMF decomps

#######################################

# Train FROZEN_MODEL on TASK starting from PRETRAINED_MODEL, writing it under
# MODELS_DIR. That output directory is what the Fisher step reads via --model.
# NOTE(review): what exactly is "frozen" is defined in train_frozen.py — confirm.
FROZEN_MODEL=frozen_bert_base_rte_001
PRETRAINED_MODEL="bert-base-uncased"

TASK="glue/rte"

# Every expansion is quoted so values containing spaces/globs stay one argument.
CUDA_VISIBLE_DEVICES=0 python ./scripts1/transfer1/train_frozen.py \
    --output_path="${MODELS_DIR}/${FROZEN_MODEL}" \
    --pretrained_model="${PRETRAINED_MODEL}" \
    --task="${TASK}" \
    --n_cache_train_examples=10_000 \
    --cache_batch_size=512 \
    --n_val_examples=4_000 \
    --learning_rate=1e-3 \
    --clipnorm=0.1 \
    --batch_size=512 \
    --n_epochs=20 \
    --steps_per_epoch=2048

#######################################

# Run compute_fisher.py on the model trained into ${MODELS_DIR}/${FROZEN_MODEL}
# for TASK, writing the result to FISHER_DIR.
# The output filename is derived from FROZEN_MODEL (instead of repeating the
# model name literally) so it cannot drift when FROZEN_MODEL changes.
# NOTE(review): --from_pt=false presumably because the checkpoint saved by
# train_frozen.py is not a PyTorch one — confirm.
CUDA_VISIBLE_DEVICES=0 python ./scripts1/ogmm/compute_fisher.py \
    --model="${MODELS_DIR}/${FROZEN_MODEL}" \
    --from_pt=false \
    --tokenizer="${PRETRAINED_MODEL}" \
    --task="${TASK}" \
    --fisher_path="${FISHER_DIR}/${FROZEN_MODEL}.fisher.h5" \
    --batch_size=16

#######################################

# Already fine-tuned model and its task, consumed by the per-example Fisher runs.
# NOTE(review): presumably a Hugging Face Hub id (those runs pass
# --from_pt_trained=true) — confirm.
FINETUNED_MODEL='textattack/bert-base-uncased-MNLI'
FINETUNED_TASK='glue/mnli'

####################


# Per-example Fishers of FINETUNED_MODEL on FINETUNED_TASK, "sparse_dynamic_raw"
# flavor. Filename encodes the settings: 32k = --n_examples (32768),
# 16k = --n_fisher_values_per_example (16384).
PER_EXAMPLES_FISHERS="bert_base_mnli.sparse_dynamic_raw.32k.16k.h5"

CUDA_VISIBLE_DEVICES=3 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${FINETUNED_MODEL}" \
    --from_pt_trained=true \
    --task="${FINETUNED_TASK}" \
    --n_examples=32768 \
    --batch_size=8 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=16384 \
    --include_embeddings=true


# NMF decomposition of the file written above. Filename prefix encodes the
# settings: 16k = --n_examples (16384), 8k = --end_fisher_index (8192, starting
# at 0), 64 = --nmf_n_components, reduced_1 = --reduce_threshold=1.
DECOMP_FILENAME="nmf_decomp.16k.8k.64.reduced_1.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=16384 \
    --reduce_threshold=1 \
    --start_fisher_index=0 \
    --end_fisher_index=8192 \
    --nmf_n_components=64 \
    --nmf_max_iter=3000 \
    --nmf_tol=1e-8


####################


# Per-example Fishers of FINETUNED_MODEL on FINETUNED_TASK,
# "sparse_dynamic_metric_derived" flavor. Filename encodes the settings:
# 32k = --n_examples (32768), 16k = --n_fisher_values_per_example (16384).
# Unlike the sparse_dynamic_raw run, this one also passes PRETRAINED_MODEL
# (set in the RTE section) with --from_pt_pretrained=true.
# NOTE(review): presumably the metric-derived flavor needs the pretrained
# weights as well — confirm in save_per_example_fishers_to_disk.py.
PER_EXAMPLES_FISHERS="bert_base_mnli.sparse_dynamic_metric_derived.32k.16k.h5"

CUDA_VISIBLE_DEVICES=2 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${FINETUNED_MODEL}" \
    --pretrained_model="${PRETRAINED_MODEL}" \
    --from_pt_trained=true \
    --from_pt_pretrained=true \
    --task="${FINETUNED_TASK}" \
    --n_examples=32768 \
    --batch_size=8 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_metric_derived \
    --n_fisher_values_per_example=16384 \
    --include_embeddings=true


# NMF decomposition of the file written above, with the same settings as the
# sparse_dynamic_raw run. Filename prefix: 16k = --n_examples (16384),
# 8k = --end_fisher_index (8192, starting at 0), 64 = --nmf_n_components,
# reduced_1 = --reduce_threshold=1.
DECOMP_FILENAME="nmf_decomp.16k.8k.64.reduced_1.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=1 python ./scripts1/decomp/run_nmf.py \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=16384 \
    --reduce_threshold=1 \
    --start_fisher_index=0 \
    --end_fisher_index=8192 \
    --nmf_n_components=64 \
    --nmf_max_iter=3000 \
    --nmf_tol=1e-8
