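# Commands for experiments on the simple divisibility dataset: training the
# models, computing per-example Fishers, and running NMF on them, as part of
# my attempts to mechanistically understand the resulting NMF components.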

###################################################################################
###################################################################################


VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# Session setup: enter the code checkout and activate the virtualenv.
# NOTE: `workon` is not defined in this file; presumably it is
# virtualenvwrapper's command and must already be loaded in this shell.
cd "$CODE_DIR" || printf 'warning: cannot cd to %s\n' "$CODE_DIR" >&2
workon "$VIRTUALENV_NAME"

# Append CODE_DIR to PYTHONPATH. There must be no trailing comma here: it
# would become part of the path entry ("<CODE_DIR>,"), which does not exist.
# ${PYTHONPATH:+...} only adds the ':' separator when PYTHONPATH is already
# non-empty, so no leading empty entry is created.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"


###################################################################################
###################################################################################

# Output layout for the divisibility experiments, rooted under DATA_DIR.
EXPS_DIR=${DATA_DIR}/divis_mech1
# Trained model .h5 files.
MODELS_DIR=${EXPS_DIR}/models
# NOTE(review): not referenced by the commands visible in this section.
FISHER_DIR=${EXPS_DIR}/fishers
# Per-example Fisher .h5 files and the NMF decompositions computed from them.
PER_EXAMPLE_FISHERS_DIR=${EXPS_DIR}/per_example_fishers


#######################################

# Train the base divisibility model on GPU 0: 128-dim embeddings, 12 layers
# of ffw_res_nln:2048,8192, divisors up to 9.
# NOTE(review): unlike the smaller variants below, this run passes no
# --min_divisor / --min_dividend / --max_dividend, so train_divis_model.py's
# defaults apply — confirm they match the variants' 2 / 100 / 999_999_999.
CUDA_VISIBLE_DEVICES=0 python ./scripts1/mech/divis/train_divis_model.py  \
    --output_path="${MODELS_DIR}/misc_divis_model_001.h5" \
    --embeddings_size=128 \
    --n_layers=12 \
    --layer_config=ffw_res_nln:2048,8192 \
    --activation_fn=relu  \
    --steps_per_epoch=256 \
    --curriculum_epochs=10,8,8,8,8,8 \
    --batch_size=4096 \
    --max_divisor=9


#######################################
# Train one of the smaller divisibility-model variants.
#
# Every variant shares the same data range (divisors 2..9, dividends
# 100..999_999_999) and training schedule; only the GPU, output name and
# architecture differ, so those are the arguments.
# Globals:
#   MODELS_DIR (read) - the model is written to ${MODELS_DIR}/<name>.h5
# Arguments:
#   $1 - GPU index, passed to python as CUDA_VISIBLE_DEVICES
#   $2 - model name (without the .h5 extension)
#   $3 - --embeddings_size
#   $4 - --n_layers
#   $5 - --layer_config, e.g. ffw_res_nln:1024,4096
# Returns: 2 if arguments are missing, else train_divis_model.py's status.
#######################################
train_divis_variant () {
    if (( $# < 5 )); then
        printf 'usage: train_divis_variant DEVICE NAME EMBEDDINGS_SIZE N_LAYERS LAYER_CONFIG\n' >&2
        return 2
    fi
    local device=$1
    local model_name=$2
    local embeddings_size=$3
    local n_layers=$4
    local layer_config=$5
    CUDA_VISIBLE_DEVICES=$device python ./scripts1/mech/divis/train_divis_model.py  \
        --output_path="${MODELS_DIR}/${model_name}.h5" \
        --embeddings_size="${embeddings_size}" \
        --n_layers="${n_layers}" \
        --layer_config="${layer_config}" \
        --activation_fn=relu  \
        --steps_per_epoch=256 \
        --curriculum_epochs=10,8,8,8,8,8 \
        --batch_size=4096 \
        --min_divisor=2 \
        --max_divisor=9 \
        --min_dividend=100 \
        --max_dividend=999_999_999
}

# Variants of misc_divis_model_001 with 64-dim embeddings and progressively
# narrower layers ("half" down to "semi_sixteenth"), plus a 12-layer
# "deep_semi_sixteenth". Each call runs in the foreground on the given GPU.
# Args: GPU, model name, embeddings size, n_layers, layer_config.
train_divis_variant 0 misc_divis_model_001_half 64 6 ffw_res_nln:1024,4096

train_divis_variant 1 misc_divis_model_001_semi_quarter 64 6 ffw_res_nln:512,2048

train_divis_variant 0 misc_divis_model_001_semi_eighth 64 6 ffw_res_nln:256,1024

train_divis_variant 3 misc_divis_model_001_semi_sixteenth 64 6 ffw_res_nln:128,512

train_divis_variant 3 misc_divis_model_001_deep_semi_sixteenth 64 12 ffw_res_nln:128,512

#######################################

# Per-example Fishers for the "half" model on GPU 0.
# The ".32k.32k" filename tag appears to encode --n_examples (32768) and
# --n_fisher_values_per_example (32768); keep it in sync if either changes.
MODEL="misc_divis_model_001_half"
PER_EXAMPLES_FISHERS="${MODEL}.sparse_dynamic_raw.32k.32k.h5"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/mech/divis/save_divis_per_example_fishers.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --model="${MODELS_DIR}/${MODEL}.h5" \
    --min_divisor=2 \
    --max_divisor=9 \
    --min_dividend=100 \
    --max_dividend=999_999_999 \
    --n_examples=32768 \
    --n_fisher_values_per_example=32768 \
    --batch_size=16


# Per-example Fishers for the "deep_semi_sixteenth" model on GPU 0.
# The ".64k.32k" filename tag appears to encode --n_examples (65536) and
# --n_fisher_values_per_example (32768); keep it in sync if either changes.
MODEL="misc_divis_model_001_deep_semi_sixteenth"
PER_EXAMPLES_FISHERS="${MODEL}.sparse_dynamic_raw.64k.32k.h5"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/mech/divis/save_divis_per_example_fishers.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --model="${MODELS_DIR}/${MODEL}.h5" \
    --min_divisor=2 \
    --max_divisor=9 \
    --min_dividend=100 \
    --max_dividend=999_999_999 \
    --n_examples=65536 \
    --n_fisher_values_per_example=32768 \
    --batch_size=16


#######################################

# Whole-model NMF on the "half" model's per-example Fishers (GPU 0):
# 16384 examples, Fisher indices 0..8192 (--start/--end_fisher_index),
# 64 components. The "nmf_decomp.16k.8k.64" prefix appears to encode those
# three settings; the result is written next to the Fisher file it reads.
MODEL="misc_divis_model_001_half"
PER_EXAMPLES_FISHERS="${MODEL}.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME="nmf_decomp.16k.8k.64.${PER_EXAMPLES_FISHERS}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/decomp/run_nmf.py  \
    --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
    --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
    --n_examples=16384 \
    --start_fisher_index=0 \
    --end_fisher_index=8192 \
    --nmf_n_components=64 \
    --reduce_threshold=1 \
    --nmf_max_iter=3000 \
    --nmf_tol=1e-8



###################################################################################
###################################################################################


VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

# Session setup: enter the code checkout and activate the virtualenv.
# NOTE: `workon` is not defined in this file; presumably it is
# virtualenvwrapper's command and must already be loaded in this shell.
cd "$CODE_DIR" || printf 'warning: cannot cd to %s\n' "$CODE_DIR" >&2
workon "$VIRTUALENV_NAME"

# Append CODE_DIR to PYTHONPATH. There must be no trailing comma here: it
# would become part of the path entry ("<CODE_DIR>,"), which does not exist.
# ${PYTHONPATH:+...} only adds the ':' separator when PYTHONPATH is already
# non-empty, so no leading empty entry is created.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

##############################################

# Output layout for the divisibility experiments, rooted under DATA_DIR.
EXPS_DIR=${DATA_DIR}/divis_mech1
# Trained model .h5 files.
MODELS_DIR=${EXPS_DIR}/models
# NOTE(review): not referenced by the commands visible in this section.
FISHER_DIR=${EXPS_DIR}/fishers
# Per-example Fisher .h5 files and the NMF decompositions computed from them.
PER_EXAMPLE_FISHERS_DIR=${EXPS_DIR}/per_example_fishers



# Model whose per-example Fishers are decomposed per subset below.
# The filename tags appear to encode run settings: ".64k.32k" ~ 65536
# examples x 32768 Fisher values per example; "32k.8k.384" ~ 32768 examples,
# end Fisher index 8192, 384 NMF components.
MODEL=misc_divis_model_001_deep_semi_sixteenth
PER_EXAMPLES_FISHERS=${MODEL}.sparse_dynamic_raw.64k.32k.h5
DECOMP_FILENAME=nmf_decomp.32k.8k.384.${PER_EXAMPLES_FISHERS}


# With --subset_style=per_layer this model has 27 subset indices (0-26) in total.

#######################################
# Run per-subset NMF on the current model's per-example Fishers.
# Globals (read): PER_EXAMPLE_FISHERS_DIR, DECOMP_FILENAME,
#                 PER_EXAMPLES_FISHERS, MODELS_DIR, MODEL
# Arguments:
#   $1 - GPU index, passed to python as CUDA_VISIBLE_DEVICES (required)
#   $2 - comma-separated subset indices, e.g. 0,4,8 (required)
#   $3 - --n_examples (optional, default 32768)
#   $4 - --nmf_n_components (optional, default 384)
# Note: --output_path always uses DECOMP_FILENAME; overriding $3/$4 does not
# change that name.
# Outputs: prints the device index, then whatever run_per_subset_nmf.py prints.
# Returns: 2 if a required argument is missing, else python's exit status.
#######################################
run_per_subset_nmf () {
    local device=$1
    local subset_indices=$2
    local n_examples=${3:-32768}
    local n_components=${4:-384}
    # An empty CUDA_VISIBLE_DEVICES hides every GPU (the framework may then
    # fall back to CPU), so refuse to start without both required arguments.
    if [[ -z "$device" || -z "$subset_indices" ]]; then
        printf 'usage: run_per_subset_nmf DEVICE SUBSET_INDICES [N_EXAMPLES] [N_COMPONENTS]\n' >&2
        return 2
    fi
    printf '%s\n' "$device"
    CUDA_VISIBLE_DEVICES=$device python ./scripts1/mech/divis/run_per_subset_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --subset_style=per_layer \
        --model="${MODELS_DIR}/${MODEL}.h5" \
        --subset_indices="${subset_indices}" \
        --n_examples="${n_examples}" \
        --start_fisher_index=0 \
        --end_fisher_index=8192 \
        --nmf_n_components="${n_components}" \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8
}


# Subset 1 is run on its own. Per the original note it was not computed with
# the other subsets because it OOMed, so this run uses --n_examples=16384 and
# --nmf_n_components=128 instead of the 32768 / 384 used for the others.
# TODO: the original note called for a follow-up fix ("create something for
# it tomorrow"); this reduced-size run is presumably that stopgap.
# NOTE(review): --output_path reuses DECOMP_FILENAME, whose "32k.8k.384" tag
# does not match these settings. If run_per_subset_nmf.py stores subsets in
# a shared output, subset 1's 128-component result will sit under that
# 384-component name — confirm this is intended.
CUDA_VISIBLE_DEVICES=3 python ./scripts1/mech/divis/run_per_subset_nmf.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${DECOMP_FILENAME}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --subset_style=per_layer \
        --model="${MODELS_DIR}/${MODEL}.h5" \
        --subset_indices=1 \
        --n_examples=16384 \
        --start_fisher_index=0 \
        --end_fisher_index=8192 \
        --nmf_n_components=128 \
        --reduce_threshold=1 \
        --nmf_max_iter=3000 \
        --nmf_tol=1e-8

# Remaining per-layer subsets split across GPUs 0-3: GPU d gets the indices
# congruent to d mod 4, with subset 1 left out (it is run separately above).
# Each call runs in the foreground; presumably each line was launched from
# its own terminal so the four GPUs run concurrently.
#
# Commented-out split that still included subset 1 on GPU 1:
# run_per_subset_nmf 1 1,5,9,13,17,21,25

run_per_subset_nmf 0 0,4,8,12,16,20,24

# Commented-out split that additionally left out subsets 5, 2 and 3:
# run_per_subset_nmf 1 9,13,17,21,25
# run_per_subset_nmf 2 6,10,14,18,22,26
# run_per_subset_nmf 3 7,11,15,19,23

run_per_subset_nmf 1 5,9,13,17,21,25
run_per_subset_nmf 2 2,6,10,14,18,22,26
run_per_subset_nmf 3 3,7,11,15,19,23
