###################################################################################
###################################################################################


# Name of the virtualenvwrapper environment activated by `workon` below.
VIRTUALENV_NAME=extract_merge1

# Code checkout and data root of the extract_merge1 project.
CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
# Exported so that child processes (the python scripts launched below) see it.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

cd "$CODE_DIR"
workon "$VIRTUALENV_NAME"
# Append the checkout to PYTHONPATH. The original line had a stray "," after
# the closing quote, which made the last entry "<CODE_DIR>," (a path that does
# not exist). The ${PYTHONPATH:+...} form also avoids a leading empty entry
# when PYTHONPATH was previously unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"
# Re-exported after `workon`, as in the original; presumably to override
# anything the virtualenv activation hooks may have set -- TODO confirm.
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

###################################################################################
###################################################################################

# Root of everything produced by the "imagenet1" experiments (under DATA_DIR).
EXPS_DIR="$DATA_DIR/imagenet1"

# Sub-directories of EXPS_DIR. FISHER_DIR and PER_EXAMPLE_FISHERS_DIR are the
# ones the launcher functions in this file read from and write to.
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"
FISHER_DIR="$EXPS_DIR/fishers"
MODELS_DIR="$EXPS_DIR/models"
DATASETS_DIR="$EXPS_DIR/datasets"

###################################################################################
###################################################################################


# Runs ./scripts1/data_gen/compute_pefs2.py for the resnet50_imagenet model and
# writes the resulting per-example-fishers .h5 file into
# PER_EXAMPLE_FISHERS_DIR.
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory the output file is written to.
# Arguments:
#   $1 device         - value assigned to CUDA_VISIBLE_DEVICES for the run.
#   $2 n_vals_pe      - forwarded as --n_fisher_values_per_example.
#   $3 n_ex           - forwarded as --n_examples.
#   $4 min_prob_class - forwarded as --min_prob_class.
#   $5 split          - dataset split, forwarded as --split.
# Returns:
#   2 when fewer than five arguments are given, otherwise python's exit status.
# Note: the script path is relative, so the current directory matters.
compute_pefs() {
    # All five values end up in the output file name; refuse to start a run
    # whose name would silently contain empty fields.
    if [ "$#" -lt 5 ]; then
        echo "usage: compute_pefs DEVICE N_VALS_PE N_EX MIN_PROB_CLASS SPLIT" >&2
        return 2
    fi

    local device=$1
    local n_vals_pe=$2
    local n_ex=$3
    local min_prob_class=$4
    local split=$5

    local MODEL="resnet:resnet50_imagenet"
    local PER_EXAMPLES_FISHERS="resnet50_imagenet.imagenet_${split}.all_vars.${n_ex}ex.nvpe${n_vals_pe}.mpc${min_prob_class}.h5"

    CUDA_VISIBLE_DEVICES=$device python ./scripts1/data_gen/compute_pefs2.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${PER_EXAMPLES_FISHERS}" \
        --model="${MODEL}" \
        --task=imagenet/resnet \
        --split="${split}" \
        --n_examples="${n_ex}" \
        --batch_size=1 \
        --flavor=sparse_dynamic_raw \
        --n_fisher_values_per_example="${n_vals_pe}" \
        --sequence_length=224 \
        --fisher_algorithm=fast_compile \
        --min_prob_class="${min_prob_class}"
}

# Train-split run (GPU 3, 131072 values per example, 20000 examples), currently
# commented out. Later functions hard-code the file name this call would
# produce, so it was presumably run earlier -- TODO confirm.
# compute_pefs 3 131072 20000 3e-3 train

# Validation-split run: GPU 3, 131072 values per example, 30000 examples,
# min_prob_class 3e-3.
compute_pefs 3 131072 30000 3e-3 validation



# fit_nmf (below) uses paths relative to the cuda_nmf1 checkout
# (mains/, build_scripts/, build/), hence the directory change.
cd /fruitbasket/users/m/project_code/cuda_nmf1

# Compiles mains/em/run_nmf_on_pefs.cu with nvcc into build/run_nmf_on_pefs and
# runs that binary on the train-split per-example-fishers file, writing the
# decomposition into PER_EXAMPLE_FISHERS_DIR.
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the input file and receives the output.
# Arguments:
#   $1 devices                  - value assigned to CUDA_VISIBLE_DEVICES (e.g. 1,2,3).
#   $2 n_examples               - forwarded as --n_examples.
#   $3 n_vals_pe                - forwarded as --n_fisher_values.
#   $4 n_comps                  - forwarded as --nmf_n_components.
#   $5 min_values_per_parameter - forwarded as --min_values_per_parameter.
# Returns:
#   2 when fewer than five arguments are given, 1 when the nvcc build fails,
#   otherwise the exit status of build/run_nmf_on_pefs.
# Note: all source/build paths are relative; run from the cuda_nmf1 checkout.
fit_nmf() {
    if [ "$#" -lt 5 ]; then
        echo "usage: fit_nmf DEVICES N_EXAMPLES N_VALS_PE N_COMPS MIN_VALUES_PER_PARAMETER" >&2
        return 2
    fi

    local devices=$1
    local n_examples=$2
    local n_vals_pe=$3
    local n_comps=$4
    local min_values_per_parameter=$5

    local nmf_max_iter=2500

    local pef_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local output_name="nmf_decomp.c${n_comps}_${nmf_max_iter}Iters_${n_vals_pe}pe_mvpp${min_values_per_parameter}_${n_examples}ex.${pef_name}"

    local NVCC=/usr/local/cuda/bin/nvcc
    local COMPUTE_CUDA_VERSION_SH=./build_scripts/compute_cuda_version.sh
    local OX=O3

    # Stop if the build fails: otherwise a stale build/run_nmf_on_pefs from an
    # earlier compile would be run and its result saved under the new name.
    # The $(...) after -gencode is left unquoted so that whatever words the
    # helper script prints reach nvcc as separate arguments, as before.
    if ! "$NVCC" mains/em/run_nmf_on_pefs.cu \
        -gencode $(sh "$COMPUTE_CUDA_VERSION_SH") \
        "-${OX}" \
        -I./src \
        -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib \
        -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs
    then
        echo "fit_nmf: nvcc build failed; not running build/run_nmf_on_pefs" >&2
        return 1
    fi

    CUDA_VISIBLE_DEVICES=$devices ./build/run_nmf_on_pefs \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --nmf_n_components="${n_comps}" \
        --n_fisher_values="${n_vals_pe}" \
        --n_examples="${n_examples}" \
        --nmf_max_iter="${nmf_max_iter}" \
        --min_values_per_parameter="${min_values_per_parameter}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}"
}

# Fit a 512-component NMF on GPUs 1,2,3 with 20000 examples, 65536 fisher
# values and min_values_per_parameter=6.
fit_nmf 1,2,3 20000 65536 512 6



# Runs scripts1/sparse/sparsify_nmf.py on an "nmf_decomp..." file inside
# PER_EXAMPLE_FISHERS_DIR and writes the result next to it with an "spH."
# prefix.
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory of both the input and the output.
# Arguments ($1-$4 select the input file by name):
#   $1 number of examples, $2 values per example, $3 number of components,
#   $4 min values per parameter.
# Returns: python's exit status.
sparsify_nmf() {
    local example_count=$1
    local values_per_example=$2
    local component_count=$3
    local mvpp=$4

    # Same iteration count that fit_nmf embeds in its output file name.
    local iters=2500

    local train_pefs="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local decomp="nmf_decomp.c${component_count}_${iters}Iters_${values_per_example}pe_mvpp${mvpp}_${example_count}ex.${train_pefs}"

    local src_path="${PER_EXAMPLE_FISHERS_DIR}/${decomp}"
    local dst_path="${PER_EXAMPLE_FISHERS_DIR}/spH.${decomp}"

    # CUDA_VISIBLE_DEVICES is set to the empty string for this process;
    # presumably to keep the script off the GPUs.
    CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
        --nmf_path="${src_path}" \
        --output_path="${dst_path}" \
        --H_threshold=1e-10
}

# Sparsify the decomposition produced by `fit_nmf 1,2,3 20000 65536 512 6`
# (same 20000 / 65536 / 512 / 6 values select the file by name).
# NOTE(review): sparsify_nmf uses a relative scripts1/ path; if this file is run
# top to bottom the working directory here is still the cuda_nmf1 checkout --
# confirm which directory this call is meant to run from.
sparsify_nmf 20000 65536 512 6




# fit_nmf_validation_set (below) uses build_scripts/ and build/ relative to the
# cuda_nmf1 checkout.
cd /fruitbasket/users/m/project_code/cuda_nmf1

# Builds build/fit_coeffs_to_sparse_H via its build script and runs it with the
# sparsified H file ("spH....") and the validation-split per-example-fishers
# file, writing a "fit_w.<n_vals_pe>vpe.<H name>" file into
# PER_EXAMPLE_FISHERS_DIR.
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds both inputs and receives the output.
# Arguments:
#   $1 device    - value assigned to CUDA_VISIBLE_DEVICES.
#   $2 n_vals_pe - forwarded as --n_fisher_values and embedded in the output name.
# Returns:
#   2 when fewer than two arguments are given, 1 when the build script fails,
#   otherwise the exit status of build/fit_coeffs_to_sparse_H.
# Note: build paths are relative; run from the cuda_nmf1 checkout.
fit_nmf_validation_set() {
    if [ "$#" -lt 2 ]; then
        echo "usage: fit_nmf_validation_set DEVICE N_VALS_PE" >&2
        return 2
    fi

    local device=$1
    local n_vals_pe=$2

    local og_pef_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local h_name="spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex.${og_pef_name}"

    local pef_name="resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"
    local output_name="fit_w.${n_vals_pe}vpe.${h_name}"

    # Stop if the build fails: otherwise a stale binary from an earlier build
    # would be run and its result stored under the new output name.
    if ! sh ./build_scripts/em/fit_coeffs_to_sparse_H.sh; then
        echo "fit_nmf_validation_set: build failed; not running build/fit_coeffs_to_sparse_H" >&2
        return 1
    fi

    CUDA_VISIBLE_DEVICES=$device ./build/fit_coeffs_to_sparse_H \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${output_name}" \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --H_path="${PER_EXAMPLE_FISHERS_DIR}/${h_name}" \
        --nmf_max_iter=12000 \
        --n_fisher_values="${n_vals_pe}" \
        --n_splits_sparse_matmul=2048 \
        --n_row_splits_pefs=50
}

# Fit coefficients for the validation-split PEFs on GPU 2 with 65536 fisher values.
fit_nmf_validation_set 2 65536


###################################################################################


# Runs ./scripts1/data_gen/compute_fisher2.py for the resnet50_imagenet model
# and writes the resulting .h5 file into FISHER_DIR.
#
# Globals:
#   FISHER_DIR (read) - directory the output file is written to.
# Arguments:
#   $1 device         - value assigned to CUDA_VISIBLE_DEVICES.
#   $2 n_ex           - forwarded as --n_examples.
#   $3 min_prob_class - forwarded as --min_prob_class.
#   $4 split          - dataset split, forwarded as --split.
# Returns:
#   2 when fewer than four arguments are given, otherwise python's exit status.
# Note: the script path is relative, so the current directory matters.
compute_dataset_fisher() {
    # Every argument is part of the output file name; reject incomplete calls
    # instead of producing a name with empty fields.
    if [ "$#" -lt 4 ]; then
        echo "usage: compute_dataset_fisher DEVICE N_EX MIN_PROB_CLASS SPLIT" >&2
        return 2
    fi

    local device=$1
    local n_ex=$2
    local min_prob_class=$3
    local split=$4

    local MODEL="resnet:resnet50_imagenet"
    local FISHER="resnet50_imagenet.imagenet_${split}.all_vars.${n_ex}ex.mpc${min_prob_class}.h5"

    CUDA_VISIBLE_DEVICES=$device python ./scripts1/data_gen/compute_fisher2.py \
        --output_path="${FISHER_DIR}/${FISHER}" \
        --model="${MODEL}" \
        --task=imagenet/resnet \
        --split="${split}" \
        --n_examples="${n_ex}" \
        --batch_size=1 \
        --sequence_length=224 \
        --fisher_algorithm=fast_compile \
        --min_prob_class="${min_prob_class}"
}

# Train-split fisher on GPU 0 with 20000 examples and min_prob_class 3e-3.
# NOTE(review): compute_dataset_fisher uses a relative ./scripts1/ path, while
# the most recent `cd` in this file goes to the cuda_nmf1 checkout -- confirm
# the intended working directory for this call.
compute_dataset_fisher 0 20000 3e-3 train


###################################################################################
###################################################################################


# Runs ./scripts1/comp_measure/coeff_kl_relationship2.py on the validation-split
# per-example-fishers file, the "fit_w.65536vpe.spH..." file and the train-split
# fisher file, writing into EXPS_DIR/coeff_kl_relationships/<outname>.
#
# Globals:
#   EXPS_DIR (read)                - parent of the output directory.
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the --pef_path and --nmf_path inputs.
#   FISHER_DIR (read)              - holds the --retaining_fisher_path input.
# Arguments:
#   $1 device                       - value assigned to CUDA_VISIBLE_DEVICES.
#   $2 start_component_index        - forwarded as --start_component_index.
#   $3 n_components                 - forwarded as --n_components.
#   $4 outname                      - name of the output sub-directory.
#   $5 n_kl_range_targeter_examples - forwarded as --n_kl_range_targeter_examples.
# Returns:
#   2 when fewer than five arguments are given, mkdir's status when the output
#   directory cannot be created, otherwise python's exit status.
# Note: the script path is relative, so the current directory matters.
coeff_kl_relationship2() {
    if [ "$#" -lt 5 ]; then
        echo "usage: coeff_kl_relationship2 DEVICE START_COMPONENT_INDEX N_COMPONENTS OUTNAME N_KL_RANGE_TARGETER_EXAMPLES" >&2
        return 2
    fi

    local device=$1
    local start_component_index=$2
    local n_components=$3
    local outname=$4
    local n_kl_range_targeter_examples=$5

    # Quoted everywhere below: outname comes from the caller and an unquoted
    # expansion would split a name containing whitespace into several paths.
    local output_dir="${EXPS_DIR}/coeff_kl_relationships/${outname}"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local h_name="spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex.${og_pef_name}"
    local nmf_name="fit_w.65536vpe.${h_name}"

    local pef_name="resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"

    local fisher_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.mpc3e-3.h5"

    CUDA_VISIBLE_DEVICES=$device python ./scripts1/comp_measure/coeff_kl_relationship2.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="resnet:resnet50_imagenet" \
        --output_dir="${output_dir}" \
        --output_prefix="" \
        --sort_by_coeff_mag=true \
        --start_component_index="${start_component_index}" \
        --n_components="${n_components}" \
        --ds_task=imagenet/resnet \
        --ds_split=validation \
        --ds_n_examples=30000 \
        --ds_sequence_length=224 \
        --min_target_kl=0.05 \
        --max_target_kl=0.2 \
        --n_kl_range_targeter_examples="${n_kl_range_targeter_examples}" \
        --kl_range_targeter_max_iters=100 \
        --n_kl_range_finds=6 \
        --n_sign_guides_per_kl_range_find=3 \
        --eval_batch_size=256
}

# GPU 0, components starting at index 0, n_components=100, output sub-directory
# "kl_resnet_imagenet_validation_01", 256 KL-range-targeter examples.
coeff_kl_relationship2 0 0 100 "kl_resnet_imagenet_validation_01" 256


###################################################################################

# Runs ./scripts1/comp_measure/run_guided_kl_ablations2.py on the
# validation-split per-example-fishers file, the "fit_w.65536vpe.spH..." file
# and the train-split fisher file, writing into
# EXPS_DIR/guided_kl_ablations/attempt01 with output prefix "test_".
#
# Globals:
#   EXPS_DIR (read)                - parent of the output directory.
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the --pef_path and --nmf_path inputs.
#   FISHER_DIR (read)              - holds the --retaining_fisher_path input.
# Arguments:
#   $1 device                - value assigned to CUDA_VISIBLE_DEVICES.
#   $2 start_component_index - forwarded as --start_component_index.
#   $3 n_components          - forwarded as --n_components.
# Returns:
#   2 when fewer than three arguments are given, mkdir's status when the output
#   directory cannot be created, otherwise python's exit status.
# Note: the script path is relative, so the current directory matters.
guided_kl_ablations() {
    if [ "$#" -lt 3 ]; then
        echo "usage: guided_kl_ablations DEVICE START_COMPONENT_INDEX N_COMPONENTS" >&2
        return 2
    fi

    local device=$1
    local start_component_index=$2
    local n_components=$3

    # Quoted so an EXPS_DIR containing whitespace is not split into several paths.
    local output_dir="${EXPS_DIR}/guided_kl_ablations/attempt01"
    mkdir -p -- "$output_dir" || return

    local og_pef_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local h_name="spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex.${og_pef_name}"
    local nmf_name="fit_w.65536vpe.${h_name}"

    local pef_name="resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"

    local fisher_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.mpc3e-3.h5"

    local n_kl_range_finds=6

    CUDA_VISIBLE_DEVICES=$device python ./scripts1/comp_measure/run_guided_kl_ablations2.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${nmf_name}" \
        --retaining_fisher_path="${FISHER_DIR}/${fisher_name}" \
        --model="resnet:resnet50_imagenet" \
        --output_dir="${output_dir}" \
        --output_prefix="test_" \
        --sort_by_coeff_mag=true \
        --start_component_index="${start_component_index}" \
        --n_components="${n_components}" \
        --ds_task=imagenet/resnet \
        --ds_split=validation \
        --ds_n_examples=30000 \
        --ds_sequence_length=224 \
        --min_target_kl=0.25 \
        --max_target_kl=0.35 \
        --n_selected_examples=128 \
        --n_kl_range_finds="${n_kl_range_finds}" \
        --eval_batch_size=256 \
        --ablation_exp_types=component_examples_H,random_examples_H \
        --skip_existing_results
}


# guided_kl_ablations 0 0 50
# guided_kl_ablations 1 50 100


###################################################################################


# Runs ./scripts1/comp_measure/compute_pef_norm_ratios.py on the validation-split
# per-example-fishers file and the "fit_w.65536vpe.spH..." file, both located in
# PER_EXAMPLE_FISHERS_DIR.
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding both input files.
# Arguments:
#   $1 - forwarded unchanged as --top_ks (the call in this file passes a
#        comma-separated list).
# Returns: python's exit status.
compute_pef_norm_ratios() {
    local ks=$1

    # Input file names, assembled from the innermost name outwards.
    local train_pefs="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local sparse_h="spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex.${train_pefs}"
    local fitted="fit_w.65536vpe.${sparse_h}"
    local validation_pefs="resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"

    local pef_file="${PER_EXAMPLE_FISHERS_DIR}/${validation_pefs}"
    local nmf_file="${PER_EXAMPLE_FISHERS_DIR}/${fitted}"

    # CUDA_VISIBLE_DEVICES is set to the empty string for this process.
    CUDA_VISIBLE_DEVICES= python ./scripts1/comp_measure/compute_pef_norm_ratios.py \
        --pef_path="${pef_file}" \
        --nmf_path="${nmf_file}" \
        --top_ks="${ks}"
}

# The comma-separated list is passed through as a single --top_ks value.
compute_pef_norm_ratios 8,16,32,64,128,256,512

###################################################################################






# Runs ./em/projects/imagenet/mains/top_examples/make_top_example_aux_infos.py
# on the validation-split per-example-fishers file and the
# "fit_w.65536vpe.spH..." file, writing into a fixed directory under
# /fruitbasket/users/m/tmp (created if missing).
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding both input files.
# Arguments: none.
# Returns: python's exit status.
make_top_example_aux_infos() {
    # Values forwarded as --n_rows / --n_cols / --k_classes.
    local rows=8
    local cols=12
    local top_classes=5

    local out_dir="/fruitbasket/users/m/tmp/imagenet_comp_aux_infos1"
    mkdir -p "$out_dir"

    # Input file names, assembled from the innermost name outwards.
    local train_pefs="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
    local sparse_h="spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex.${train_pefs}"
    local fitted="fit_w.65536vpe.${sparse_h}"
    local validation_pefs="resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"

    # CUDA_VISIBLE_DEVICES is set to the empty string for this process.
    CUDA_VISIBLE_DEVICES= python ./em/projects/imagenet/mains/top_examples/make_top_example_aux_infos.py \
        --pef_path="${PER_EXAMPLE_FISHERS_DIR}/${validation_pefs}" \
        --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${fitted}" \
        --n_rows="${rows}" \
        --n_cols="${cols}" \
        --k_classes="${top_classes}" \
        --ds_split=validation \
        --output_dir="${out_dir}"
}

# Generate the aux infos into /fruitbasket/users/m/tmp/imagenet_comp_aux_infos1.
make_top_example_aux_infos



# Copy that directory from banana.cs.unc.edu over ssh into a folder under
# $HOME/Desktop. (-a already implies -r.)
# NOTE(review): the source is a remote host and the destination is a Desktop
# path, so this looks like it is meant to be run on a local workstation rather
# than on the machine that ran the steps above -- confirm.
rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/imagenet_comp_aux_infos1/" \
    "$HOME/Desktop/projects_data/extract_merge1/imagenet1/save_bin/imagenet_comp_aux_infos1/"











###################################################################################
###################################################################################

###################################################################################
###################################################################################

###################################################################################
###################################################################################

###################################################################################
###################################################################################

###################################################################################
###################################################################################





# dummy_nmf (below) uses mains/, build_scripts/ and build/ relative to the
# cuda_nmf1 checkout.
cd /fruitbasket/users/m/project_code/cuda_nmf1

# Same build-and-run sequence as fit_nmf, but the binary's --output_path is
# /dev/null, so no result file is kept. The notes after the calls in this file
# record ms/step figures, so this appears to be used for timing runs.
#
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - holds the train-split input file.
# Arguments:
#   $1 devices                  - value assigned to CUDA_VISIBLE_DEVICES (e.g. 0,1,2).
#   $2 n_examples               - forwarded as --n_examples.
#   $3 n_vals_pe                - forwarded as --n_fisher_values.
#   $4 n_comps                  - forwarded as --nmf_n_components.
#   $5 min_values_per_parameter - forwarded as --min_values_per_parameter.
# Returns:
#   2 when fewer than five arguments are given, 1 when the nvcc build fails,
#   otherwise the exit status of build/run_nmf_on_pefs.
# Note: all source/build paths are relative; run from the cuda_nmf1 checkout.
dummy_nmf() {
    if [ "$#" -lt 5 ]; then
        echo "usage: dummy_nmf DEVICES N_EXAMPLES N_VALS_PE N_COMPS MIN_VALUES_PER_PARAMETER" >&2
        return 2
    fi

    local devices=$1
    local n_examples=$2
    local n_vals_pe=$3
    local n_comps=$4
    local min_values_per_parameter=$5

    local nmf_max_iter=2500

    local pef_name="resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"

    local NVCC=/usr/local/cuda/bin/nvcc
    local COMPUTE_CUDA_VERSION_SH=./build_scripts/compute_cuda_version.sh
    local OX=O3

    # Stop if the build fails: otherwise a stale build/run_nmf_on_pefs from an
    # earlier compile would be run (and timed) instead of the current source.
    # The $(...) after -gencode is left unquoted so that whatever words the
    # helper script prints reach nvcc as separate arguments, as before.
    if ! "$NVCC" mains/em/run_nmf_on_pefs.cu \
        -gencode $(sh "$COMPUTE_CUDA_VERSION_SH") \
        "-${OX}" \
        -I./src \
        -I/usr/local/cuda/include \
        -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include \
        -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib \
        -L/usr/local/cuda/lib64 \
        -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 \
        -o build/run_nmf_on_pefs
    then
        echo "dummy_nmf: nvcc build failed; not running build/run_nmf_on_pefs" >&2
        return 1
    fi

    CUDA_VISIBLE_DEVICES=$devices ./build/run_nmf_on_pefs \
        --output_path="/dev/null" \
        --nmf_n_components="${n_comps}" \
        --n_fisher_values="${n_vals_pe}" \
        --n_examples="${n_examples}" \
        --nmf_max_iter="${nmf_max_iter}" \
        --min_values_per_parameter="${min_values_per_parameter}" \
        --per_example_fishers="${PER_EXAMPLE_FISHERS_DIR}/${pef_name}"
}

# Earlier configuration (GPUs 0,1,2, 8000 examples), kept for reference:
# dummy_nmf 0,1,2 8000 16384 512 6
# Current run: GPU 1 only, 16000 examples, 16384 fisher values, 512 components,
# min_values_per_parameter=6.
dummy_nmf 1 16000 16384 512 6

#####################################
# These are the A5000:

#####################################
# dummy_nmf _ 8000 16384 512 6
#
# PEFS read in.
# n_cols before reduction: 25583592
# n_cols after reduction: 3392588 [density=0.00410231]

# 1 GPU:
# step 10: 72.0731 [1758.71 ms/step]
# step 20: 60.7151 [1765.46 ms/step]
# step 30: 55.6719 [1770.36 ms/step]
# step 40: 53.6176 [1768.97 ms/step]

# 2 GPU:
# step 10: 72.07 [1280.03 ms/step]
# step 20: 60.738 [1282.34 ms/step]
# step 30: 55.7532 [1283.6 ms/step]

# 3 GPU:
# step 10: 72.073 [1104.06 ms/step]
# step 20: 60.7042 [1105.86 ms/step]
# step 30: 55.7123 [1108.2 ms/step]


#####################################
# dummy_nmf _ 16000 16384 512 6
#
