
# Environment setup for the extract_merge1 experiments.
# Every later section runs ./scripts1/... relative to CODE_DIR and derives its
# output directories from DATA_DIR, so run this section first.
# NOTE(review): `workon` is a virtualenvwrapper shell function; it only exists
# if virtualenvwrapper.sh has been sourced into the current shell.

VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1

cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. No trailing comma: it would become part of the
# last path entry. ${PYTHONPATH:+...} avoids a stray leading ':' (empty entry)
# when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"
# Exported once, after `workon`, so the activated environment sees this value.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets


###################################################################################
###################################################################################

# MNLI: per-example Fishers for the fine-tuned prajjwal1/bert-small-mnli.
# Directories are derived from DATA_DIR (setup section), like the later
# sections; the resulting paths are the same as the former hard-coded ones.
FISHER_DIR="${DATA_DIR}/fishers0"
PER_EXAMPLE_FISHERS_DIR="${DATA_DIR}/per_example_fishers0"

TASK='mnli'
FINETUNED_MODEL="prajjwal1/bert-small-mnli"
PRETRAINED_MODEL="prajjwal1/bert-small"

# Precomputed sparse Fisher, passed to the sparse_static run via --sparse_fisher.
METRIC_VAR_FISHER="bert_small_mnli_sparse_fisher_variances_32k.sp05.metric.131k.h5"
# UNIFORM_VAR_FISHER="bert_small_mnli_sparse_fisher_variances_32k.sp05.uniform.131k.h5"


#######################################

# sparse_static flavor over 4096 examples (".4k" in the output name).
PER_EXAMPLES_FISHERS=dev_bert_small_mnli.no_embeddings.sparse_static.md.4k.h5

CUDA_VISIBLE_DEVICES=3 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="$FINETUNED_MODEL" \
    --task="$TASK" \
    --n_examples=4096 \
    --batch_size=128 \
    --expectation_wrt_logits=true \
    --flavor=sparse_static \
    --sparse_fisher="${FISHER_DIR}/${METRIC_VAR_FISHER}" \
    --include_embeddings=false
#######################################

# Dynamic flavors: output names encode
# <n_fisher_values_per_example>.<n_examples> (e.g. 16k.32k).
# PER_EXAMPLES_FISHERS=dev_bert_small_mnli.no_embeddings.sparse_dynamic_raw.8k.64k.h5
# PER_EXAMPLES_FISHERS=dev_bert_small_mnli.no_embeddings.sparse_dynamic_raw.16k.64k.h5
PER_EXAMPLES_FISHERS=dev_bert_small_mnli.no_embeddings.sparse_dynamic_raw.16k.32k.h5

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="$FINETUNED_MODEL" \
    --task="$TASK" \
    --n_examples=32768 \
    --batch_size=16 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=16384 \
    --include_embeddings=false

#######################################

# The metric-derived flavor additionally receives the pretrained checkpoint.
PER_EXAMPLES_FISHERS=dev_bert_small_mnli.no_embeddings.sparse_dynamic_metric_derived.16k.32k.h5

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="$FINETUNED_MODEL" \
    --pretrained_model="$PRETRAINED_MODEL" \
    --task="$TASK" \
    --n_examples=32768 \
    --batch_size=16 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_metric_derived \
    --n_fisher_values_per_example=16384 \
    --include_embeddings=false

###################################################################################
###################################################################################

# QQP: per-example Fishers for TehranNLP-org/roberta-base-qqp-2e-5-42.
# Directories are derived from DATA_DIR (setup section); the resulting paths
# are the same as the former hard-coded ones.
FISHER_DIR="${DATA_DIR}/fishers0"
PER_EXAMPLE_FISHERS_DIR="${DATA_DIR}/per_example_fishers0"

TASK='qqp'
FINETUNED_MODEL="TehranNLP-org/roberta-base-qqp-2e-5-42"
PRETRAINED_MODEL="roberta-base"

#######################################

# Output names encode <n_fisher_values_per_example>.<n_examples> (16k.8k).
# The metric-derived flavor additionally receives the pretrained checkpoint.
PER_EXAMPLES_FISHERS=dev_roberta_qqp.no_embeddings.sparse_dynamic_metric_derived.16k.8k.h5

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="$FINETUNED_MODEL" \
    --pretrained_model="$PRETRAINED_MODEL" \
    --task="$TASK" \
    --n_examples=8192 \
    --batch_size=8 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_metric_derived \
    --n_fisher_values_per_example=16384 \
    --include_embeddings=false

#######################################

PER_EXAMPLES_FISHERS=dev_roberta_qqp.no_embeddings.sparse_dynamic_raw.16k.8k.h5

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="$FINETUNED_MODEL" \
    --task="$TASK" \
    --n_examples=8192 \
    --batch_size=8 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=16384 \
    --include_embeddings=false

###################################################################################
###################################################################################

# Math dataset (math_dataset/original_true_false) with prajjwal1/bert-small:
# fine-tune, per-example Fishers, NMF decompositions, and a full Fisher.
EXPS_DIR="${DATA_DIR}/math_datasets_dev1"
MODELS_DIR="${EXPS_DIR}/models1"
FISHER_DIR="${EXPS_DIR}/fishers0"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers0"

#######################################

# Earlier 10k-step run, kept for reference; its checkpoint is what the
# commented-out per-example Fisher job below points at.
# CUDA_VISIBLE_DEVICES=0 python  ./scripts1/training/finetune.py  \
#     --output_path="${MODELS_DIR}/og_tf__bert_small" \
#     --model="prajjwal1/bert-small" \
#     --task=math_dataset/original_true_false \
#     --batch_size=32 \
#     --learning_rate=4e-5 \
#     --n_steps=10_000

# 100k-step fine-tune; the live jobs below all read this checkpoint (directly
# or through the per-example Fisher files computed from it).
CUDA_VISIBLE_DEVICES=3 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/og_tf__bert_small__100k_steps" \
    --model="prajjwal1/bert-small" \
    --task=math_dataset/original_true_false \
    --batch_size=32 \
    --learning_rate=2e-5 \
    --n_steps=100_000

########################

# Per-example Fishers. Output names encode
# <n_fisher_values_per_example>.<n_examples>.

# PER_EXAMPLES_FISHERS=og_tf__bert_small.no_embeddings.sparse_dynamic_raw.16k.32k.h5

# CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
#     --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
#     --trained_model="${MODELS_DIR}/og_tf__bert_small" \
#     --tokenizer="prajjwal1/bert-small" \
#     --from_pt_trained=false \
#     --task=math_dataset/original_true_false \
#     --n_examples=32768 \
#     --batch_size=128 \
#     --expectation_wrt_logits=true \
#     --flavor=sparse_dynamic_raw \
#     --n_fisher_values_per_example=16384 \
#     --include_embeddings=false


PER_EXAMPLES_FISHERS=og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/og_tf__bert_small__100k_steps" \
    --tokenizer="prajjwal1/bert-small" \
    --from_pt_trained=false \
    --task=math_dataset/original_true_false \
    --n_examples=32768 \
    --batch_size=128 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --include_embeddings=false


# Same as above, additionally passing --include_pooler=false ("no_pooler").
PER_EXAMPLES_FISHERS=og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5

CUDA_VISIBLE_DEVICES=1 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/og_tf__bert_small__100k_steps" \
    --tokenizer="prajjwal1/bert-small" \
    --from_pt_trained=false \
    --task=math_dataset/original_true_false \
    --n_examples=32768 \
    --batch_size=128 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --include_embeddings=false \
    --include_pooler=false

########################

# NMF decompositions of the no-pooler per-example Fishers. Output names encode
# <n_examples>.<end_fisher_index>.<nmf_n_components>[.reduced_<reduce_threshold>].
PER_EXAMPLES_FISHERS=og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5
DECOMP_FILENAME="nmf_decomp.8k.4k.128.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=8192 \
    --start_fisher_index=0 \
    --end_fisher_index=4096 \
    --nmf_n_components=128 \
    --nmf_tol=1e-6

# PER_EXAMPLES_FISHERS=og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5
# DECOMP_FILENAME="nmf_decomp.1k.1k.128.${PER_EXAMPLES_FISHERS}"

# CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
#     --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
#     --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
#     --n_examples=1024 \
#     --start_fisher_index=0 \
#     --end_fisher_index=1024 \
#     --nmf_n_components=128 \
#     --nmf_tol=1e-6

PER_EXAMPLES_FISHERS=og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5
# The component count in the name must match --nmf_n_components (64 here).
DECOMP_FILENAME="nmf_decomp.16k.8k.64.reduced_1.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=16384 \
    --reduce_threshold=1 \
    --start_fisher_index=0 \
    --end_fisher_index=8192 \
    --nmf_n_components=64 \
    --nmf_tol=1e-8

########################


PER_EXAMPLES_FISHERS=og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5
DECOMP_FILENAME="nmf_decomp.8k.4k.256.reduced_1.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=8192 \
    --reduce_threshold=1 \
    --start_fisher_index=0 \
    --end_fisher_index=4096 \
    --nmf_n_components=256 \
    --nmf_max_iter=1000 \
    --nmf_tol=1e-8

########################

# compute_fisher.py on the 100k-step checkpoint over 32768 examples
# (output named ".dense.32k").
TRAINED_CKPT="${MODELS_DIR}/og_tf__bert_small__100k_steps"
FISHER_PATH="${FISHER_DIR}/og_tf__bert_small__100k_steps.dense.32k.h5"

CUDA_VISIBLE_DEVICES=3 python ./scripts1/ogmm/compute_fisher.py  \
    --model="$TRAINED_CKPT" \
    --from_pt=false \
    --tokenizer="prajjwal1/bert-small" \
    --task=math_dataset/original_true_false \
    --fisher_path="$FISHER_PATH" \
    --batch_size=128 \
    --n_examples=32768


###################################################################################
###################################################################################

# Math dataset with a 178-token SentencePiece vocabulary and a model configured
# by --hidden_size=512 / --num_hidden_layers=8 (no --model is passed).
EXPS_DIR="${DATA_DIR}/math_datasets_dev1"
VOCABS_DIR="${EXPS_DIR}/vocabs1"
MODELS_DIR="${EXPS_DIR}/models1"
FISHER_DIR="${EXPS_DIR}/fishers0"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers0"

SPM_FILENAME="math_ds_og_tf.178.model"
SPM_FILEPATH="${VOCABS_DIR}/${SPM_FILENAME}"
# NOTE(review): --tokenizer receives the SPM path without the ".model" suffix;
# presumably the loader appends it — confirm against finetune.py.
SPM_TOKENIZER="${VOCABS_DIR}/math_ds_og_tf.178"

###########################################################

# Train the SPM (vocab_size=178, matching ".178" in its filename).
# The empty CUDA_VISIBLE_DEVICES hides all GPUs from this process.
CUDA_VISIBLE_DEVICES= python ./scripts1/data_gen/make_math_spm.py  \
    --output_path="$SPM_FILEPATH" \
    --dataset=og_true_false \
    --n_examples=32768 \
    --vocab_size=178

###########################################################


CUDA_VISIBLE_DEVICES=3 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/og_tf__h512_l8_spm__100k_steps" \
    --hidden_size=512 \
    --num_hidden_layers=8 \
    --tokenizer="$SPM_TOKENIZER" \
    --task=math_dataset/original_true_false \
    --batch_size=64 \
    --learning_rate=3e-5 \
    --n_steps=100_000

###########################################################

# Per-example Fishers; name encodes <n_fisher_values_per_example>.<n_examples>.
# Embeddings are not excluded here (no --include_embeddings=false, no
# "no_embeddings" in the name).
PER_EXAMPLES_FISHERS=og_tf__h512_l8_spm__100k_steps.sparse_dynamic_raw.16k.32k.h5

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/og_tf__h512_l8_spm__100k_steps" \
    --tokenizer="${SPM_TOKENIZER}" \
    --from_pt_trained=false \
    --task=math_dataset/original_true_false \
    --n_examples=32768 \
    --batch_size=64 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=16384

###########################################################

# NMF; name encodes <n_examples>.<end_fisher_index>.<nmf_n_components>.
PER_EXAMPLES_FISHERS=og_tf__h512_l8_spm__100k_steps.sparse_dynamic_raw.16k.32k.h5
DECOMP_FILENAME="nmf_decomp.8k.4k.64.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=8192 \
    --start_fisher_index=0 \
    --end_fisher_index=4096 \
    --nmf_n_components=64 \
    --nmf_tol=1e-6

###################################################################################
###################################################################################

# Math dataset with bert-base-uncased at --sequence_length=64.
EXPS_DIR="${DATA_DIR}/math_datasets_dev1"
MODELS_DIR="${EXPS_DIR}/models1"
FISHER_DIR="${EXPS_DIR}/fishers0"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers0"

#######################################

# Two 50k-step fine-tunes that differ only in batch size and learning rate
# (512 @ 1e-4 vs 256 @ 5e-5); both use clipnorm=0.1.
CUDA_VISIBLE_DEVICES=3 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/og_tf__bert_base_uncased__seqlen64__512_batch_50k_steps" \
    --model="bert-base-uncased" \
    --task=math_dataset/original_true_false \
    --batch_size=512 \
    --learning_rate=1e-4 \
    --n_steps=50_000 \
    --sequence_length=64 \
    --clipnorm=0.1


CUDA_VISIBLE_DEVICES=1 python ./scripts1/training/finetune.py  \
    --output_path="${MODELS_DIR}/og_tf__bert_base_uncased__seqlen64__256_batch_50k_steps" \
    --model="bert-base-uncased" \
    --task=math_dataset/original_true_false \
    --batch_size=256 \
    --learning_rate=5e-5 \
    --n_steps=50_000 \
    --sequence_length=64 \
    --clipnorm=0.1


#######################################

# Per-example Fishers from the "_epoch9" checkpoint of the 256-batch run.
# NOTE(review): presumably finetune.py writes per-epoch checkpoints with this
# suffix under MODELS_DIR — confirm.
MODEL=og_tf__bert_base_uncased__seqlen64__256_batch_50k_steps_epoch9
PER_EXAMPLES_FISHERS="${MODEL}.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5"

CUDA_VISIBLE_DEVICES=1 python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --trained_model="${MODELS_DIR}/${MODEL}" \
    --tokenizer="bert-base-uncased" \
    --from_pt_trained=false \
    --task=math_dataset/original_true_false \
    --n_examples=32768 \
    --batch_size=4 \
    --expectation_wrt_logits=true \
    --flavor=sparse_dynamic_raw \
    --n_fisher_values_per_example=32768 \
    --include_embeddings=false \
    --include_pooler=false


# NMF runs with --reduce_threshold=1. Their names carry ".reduced_1" like the
# other reduced runs, so they cannot collide with an unreduced decomposition of
# the same sizes (<n_examples>.<end_fisher_index>.<nmf_n_components>).
MODEL=og_tf__bert_base_uncased__seqlen64__256_batch_50k_steps_epoch9
PER_EXAMPLES_FISHERS="${MODEL}.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.8k.4k.64.reduced_1.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=8192 \
    --start_fisher_index=0 \
    --end_fisher_index=4096 \
    --nmf_n_components=64 \
    --reduce_threshold=1 \
    --nmf_max_iter=1000 \
    --nmf_tol=1e-8


MODEL=og_tf__bert_base_uncased__seqlen64__256_batch_50k_steps_epoch9
PER_EXAMPLES_FISHERS="${MODEL}.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.8k.16k.64.reduced_1.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=1 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=8192 \
    --start_fisher_index=0 \
    --end_fisher_index=16384 \
    --nmf_n_components=64 \
    --reduce_threshold=1 \
    --nmf_max_iter=3000 \
    --nmf_tol=1e-8

# nmf loss 633 @ 10