# Command line calls for the signal peptide protein experiments.
###################################################################################
###################################################################################


# Project / environment configuration. This file is meant to be sourced into an
# interactive shell: `workon` (presumably virtualenvwrapper) must be available.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1

cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH. `${PYTHONPATH:+...:}` only adds the separator
# when PYTHONPATH is non-empty (no empty leading entry), and nothing may follow
# CODE_DIR so the entry names the directory exactly.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"
# Exported once, after `workon`, so the value is in effect in the activated env.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

###################################################################################
###################################################################################

# Root directory for the signal-peptide experiments and one subdirectory per
# kind of artifact stored under it.
EXPS_DIR="$DATA_DIR/pb/signal_peptide"
DATASETS_DIR="$EXPS_DIR/datasets"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"



# Call this when on guava (or on watermelon once it has been updated).
# NOTE(review): set_cuda_ld_library_path and set_tmux_ld_library_path are not
# defined in this file; presumably they are helper functions from the user's
# shell startup files that adjust LD_LIBRARY_PATH — confirm.
set_cuda_ld_library_path
# set_tmux_ld_library_path


###################################################################################
###################################################################################


# python3 ./scripts1/training/finetune.py  \
#     --output_path=/tmp/finetune_test \
#     --model=Rostlab/prot_bert \
#     --task=signal_peptide/sp6_binary \
#     --batch_size=32 \
#     --sequence_length=72 \
#     --n_epochs=10


#######################################
# Run scripts1/training/finetune.py on Rostlab/prot_bert for the
# signal_peptide/sp6_binary task (batch 32, sequence length 72, 10 epochs),
# writing to ${MODELS_DIR}/prot_bert.sp6_binary.test01.
# Globals:   MODELS_DIR (read)
# Arguments: $1 - value for CUDA_VISIBLE_DEVICES
# Returns:   exit status of finetune.py
#######################################
finetune_signal_peptide_binary() {
    local device=$1
    local -a finetune_flags=(
        --output_path="${MODELS_DIR}/prot_bert.sp6_binary.test01"
        --model=Rostlab/prot_bert
        --task=signal_peptide/sp6_binary
        --batch_size=32
        --sequence_length=72
        --n_epochs=10
    )

    CUDA_VISIBLE_DEVICES=$device python ./scripts1/training/finetune.py \
        "${finetune_flags[@]}"
}


# NOTE(review): this call is left uncommented, so loading this file re-runs the
# 10-epoch fine-tune into ${MODELS_DIR}/prot_bert.sp6_binary.test01 (its results
# are already logged below). Comment it out if those checkpoints already exist.
finetune_signal_peptide_binary 0

# Candidate checkpoints: end of epoch 0, end of epoch 1, and the epoch with the lowest validation loss.
#  ==> Use _epoch0 and _epoch7 (epoch 7 has the lowest val_loss, 0.0136, in the log below).

###########################################################################################################################
# _epoch0 # loss: 0.2195 - sparse_categorical_accuracy: 0.9147 - val_loss: 0.0539 - val_sparse_categorical_accuracy: 0.9868
# _epoch1 # loss: 0.0474 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.0345 - val_sparse_categorical_accuracy: 0.9912
# _epoch2 # loss: 0.0467 - sparse_categorical_accuracy: 0.9871 - val_loss: 0.0257 - val_sparse_categorical_accuracy: 0.9927
# _epoch3 # loss: 0.0245 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.0184 - val_sparse_categorical_accuracy: 0.9956
# _epoch4 # loss: 0.0153 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.0176 - val_sparse_categorical_accuracy: 0.9966
# _epoch5 # loss: 0.0188 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.0191 - val_sparse_categorical_accuracy: 0.9946
# _epoch6 # loss: 0.0135 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.0178 - val_sparse_categorical_accuracy: 0.9961
# _epoch7 # loss: 0.0108 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.0136 - val_sparse_categorical_accuracy: 0.9956
# _epoch8 # loss: 0.0108 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.0191 - val_sparse_categorical_accuracy: 0.9961
# _epoch9 # loss: 0.0067 - sparse_categorical_accuracy: 0.9982 - val_loss: 0.0234 - val_sparse_categorical_accuracy: 0.9956
###########################################################################################################################

###################################################################################

#######################################
# Run scripts1/data_gen/save_per_example_fishers_to_disk.py for one fine-tuned
# checkpoint and write the per-example Fishers to an .h5 file named
# prot_bert.epoch_<EPOCH>.<SPLIT>.<N_EX>ex.<N_VALS_PE>.h5.
# Globals:   MODELS_DIR, PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 - value for CUDA_VISIBLE_DEVICES
#            $2 - checkpoint epoch (reads prot_bert.sp6_binary.test01_epoch<N>)
#            $3 - --n_fisher_values_per_example
#            $4 - --n_examples
#            $5 - dataset split (e.g. train, validation)
# Returns:   2 if arguments are missing, otherwise the python script's status.
#######################################
compute_pefs() {
    if (( $# < 5 )); then
        printf 'usage: compute_pefs DEVICE EPOCH N_VALS_PE N_EXAMPLES SPLIT\n' >&2
        return 2
    fi

    local device=$1
    local epoch=$2
    local n_vals_pe=$3
    local n_ex=$4
    local split=$5

    local model="prot_bert.sp6_binary.test01_epoch${epoch}"

    local pef="prot_bert.epoch_${epoch}.${split}.${n_ex}ex.${n_vals_pe}.h5"

    # Expansions are quoted so each flag reaches python as exactly one argument.
    CUDA_VISIBLE_DEVICES=$device python ./scripts1/data_gen/save_per_example_fishers_to_disk.py  \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
        --trained_model="${MODELS_DIR}/${model}"   \
        --from_pt_trained=false \
        --tokenizer="Rostlab/prot_bert" \
        --task=signal_peptide/sp6_binary \
        --split="$split" \
        --n_examples="$n_ex" \
        --batch_size=1 \
        --expectation_wrt_logits=true \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="$n_vals_pe" \
        --sequence_length=72 \
        --include_classifier=true
}

# compute_pefs 0 0 131072 20290 train
# compute_pefs 1 0 131072 8811 validation
# compute_pefs 2 7 131072 20290 train
# compute_pefs 3 7 131072 8811 validation

###################################################################################


# Switch to the CUDA NMF checkout: fit_nmf below compiles and runs using paths
# relative to this directory (mains/, src/, build_scripts/, build/).
# This cd changes the working directory for every later command in this file,
# so functions that use relative scripts1/... paths resolve them from here.
cd /fruitbasket/users/m/project_code/cuda_nmf1

#######################################
# Build build/run_nmf_on_pefs with nvcc, then run it to fit an NMF
# decomposition to a saved per-example-Fisher (PEF) .h5 file.
# Expects the current directory to contain mains/, src/, build_scripts/ and
# build/ (the cuda_nmf1 checkout).
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 - CUDA_VISIBLE_DEVICES value (e.g. "0" or "1,2")
#            $2 - epoch of the checkpoint the PEF file was computed from
#            $3 - --n_examples
#            $4 - --n_fisher_values
#            $5 - --nmf_n_components
#            $6 - --min_values_per_parameter
#            $7 - --nmf_max_iter
# Returns:   2 if arguments are missing, 1 if the nvcc build fails,
#            otherwise the exit status of run_nmf_on_pefs.
#######################################
fit_nmf() {
    if (( $# < 7 )); then
        printf 'usage: fit_nmf DEVICES EPOCH N_EXAMPLES N_VALS_PE N_COMPS MIN_VALUES_PER_PARAMETER NMF_MAX_ITER\n' >&2
        return 2
    fi

    local devices=$1
    local epoch=$2
    local n_examples=$3
    local n_vals_pe=$4
    local n_comps=$5
    local min_values_per_parameter=$6
    local nmf_max_iter=$7

    # The input file is always the 20290-example, 131072-value training dump;
    # $3 and $4 are passed to run_nmf_on_pefs separately.
    local pef_name="prot_bert.epoch_${epoch}.train.20290ex.131072.h5"
    local output_name="nmf_decomp.c${n_comps}_${nmf_max_iter}Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    local NVCC=/usr/local/cuda/bin/nvcc
    local COMPUTE_CUDA_VERSION_SH=./build_scripts/compute_cuda_version.sh
    local OX=O3

    # Stop if the build fails; otherwise a stale build/run_nmf_on_pefs left
    # from an earlier build would run silently.
    # The -gencode value is left unquoted in case the helper script prints more
    # than one word — TODO confirm and quote if it is always a single token.
    # shellcheck disable=SC2046
    if ! "$NVCC" mains/em/run_nmf_on_pefs.cu \
        -gencode $(sh "$COMPUTE_CUDA_VERSION_SH") \
        "-${OX}" \
        -I./src -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs; then
        printf 'fit_nmf: nvcc build failed; not running build/run_nmf_on_pefs\n' >&2
        return 1
    fi

    CUDA_VISIBLE_DEVICES=$devices ./build/run_nmf_on_pefs \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --nmf_n_components="$n_comps" \
        --n_fisher_values="$n_vals_pe" \
        --n_examples="$n_examples" \
        --nmf_max_iter="$nmf_max_iter" \
        --min_values_per_parameter="$min_values_per_parameter" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}"
}



# fit_nmf 0 0 20290 65536 128 4 2500
# # Uses 18632 MiB of GPU RAM on a single A6000 and takes about 7 s/step.

# fit_nmf 1,2 7 20290 65536 128 4 2500


###################################################################################


#######################################
# Run scripts1/sparse/sparsify_nmf.py (with --H_threshold=1e-10) on the NMF
# file that fit_nmf writes for the same EPOCH..NMF_MAX_ITER arguments, and
# write the result next to it as spH.<nmf file name>.
# Globals:   PER_EXAMPLE_FISHERS_DIR (read)
# Arguments: $1 - ignored placeholder (mirrors fit_nmf's DEVICES slot)
#            $2..$7 - EPOCH N_EXAMPLES N_VALS_PE N_COMPS
#                     MIN_VALUES_PER_PARAMETER NMF_MAX_ITER, as for fit_nmf
# Returns:   exit status of sparsify_nmf.py
#######################################
sparsify_nmf() {
    local epoch=$2 n_examples=$3 n_vals_pe=$4 n_comps=$5
    local min_values_per_parameter=$6 nmf_max_iter=$7

    local pef_name="prot_bert.epoch_${epoch}.train.20290ex.131072.h5"
    local nmf_name
    nmf_name="nmf_decomp.c${n_comps}_${nmf_max_iter}Iters_${n_vals_pe}pe"
    nmf_name+="_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    local nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}"
    local output_path="${PER_EXAMPLE_FISHERS_DIR}/spH.${nmf_name}"

    # NOTE(review): the script path is relative to the current directory. In
    # this file the top-level `cd .../cuda_nmf1` runs before this is called,
    # while the other scripts1/ programs are run from CODE_DIR — confirm that
    # scripts1/sparse/sparsify_nmf.py resolves from where this is invoked.
    # An empty CUDA_VISIBLE_DEVICES hides all GPUs from the script.
    CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
        --nmf_path="$nmf_path" \
        --output_path="$output_path" \
        --H_threshold=1e-10
}

# sparsify_nmf dummy 0 20290 65536 128 4 2500
# Current step: sparsify the epoch-7 NMF decomposition. "dummy" fills the
# unused first slot so the arguments line up with the fit_nmf call above.
sparsify_nmf dummy 7 20290 65536 128 4 2500

