###################################################################################
###################################################################################


# Environment for running the extract_merge1 Python scripts.
VIRTUALENV_NAME=extract_merge1

CODE_DIR=/fruitbasket/users/m/project_code/extract_merge1
DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
export TFDS_DATA_DIR=/fruitbasket/datasets/tensorflow_datasets

cd "$CODE_DIR"
# `workon` is presumably virtualenvwrapper's activator; it must already be
# loaded in this shell.
workon "$VIRTUALENV_NAME"
# Append CODE_DIR to PYTHONPATH so the project's modules are importable.
# ${PYTHONPATH:+...} adds the ':' separator only when PYTHONPATH is non-empty,
# so an unset PYTHONPATH does not produce an empty (leading) entry.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${CODE_DIR}"

###################################################################################
###################################################################################

# Experiment layout: every artifact of this ImageNet run lives under EXPS_DIR.
EXPS_DIR="$DATA_DIR/imagenet2_lrm_npeff"
DATASETS_DIR="$EXPS_DIR/datasets"
MODELS_DIR="$EXPS_DIR/models"
FISHER_DIR="$EXPS_DIR/fishers"
PER_EXAMPLE_FISHERS_DIR="$EXPS_DIR/per_example_fishers"
PERTURBATIONS_DIR="$EXPS_DIR/perturbations"

###################################################################################
###################################################################################


#######################################
# Run compute_lvrm_pefs.py for resnet50_imagenet on one split and write the
# resulting per-example Fisher (PEF) file into PER_EXAMPLE_FISHERS_DIR.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory of the --output_path file
# Arguments:
#   $1 device         - value for CUDA_VISIBLE_DEVICES
#   $2 n_vals_pe      - forwarded as --n_fisher_values_per_example
#   $3 n_ex           - forwarded as --n_examples
#   $4 split          - forwarded as --split
#   $5 min_prob_class - forwarded as --min_prob_class
#   $6 max_classes    - forwarded as --max_classes
# Returns:
#   the exit status of the Python script
#######################################
compute_pefs() {
    local device=$1
    local n_vals_pe=$2
    local n_ex=$3
    local split=$4
    local min_prob_class=$5
    local max_classes=$6

    local model="resnet:resnet50_imagenet"
    # The output file name encodes every parameter of the run.
    local pef="resnet50_imagenet.${split}.${n_ex}ex.${n_vals_pe}.mpc${min_prob_class}.${max_classes}mc.h5"

    # Every expansion is quoted so values such as a TFDS split spec
    # ("train[:10%]") are never word-split or glob-expanded by the shell.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/data_gen/compute_lvrm_pefs.py \
        --output_path="${PER_EXAMPLE_FISHERS_DIR}/${pef}" \
        --model="${model}" \
        --from_pt=false \
        --task=imagenet/resnet \
        --split="${split}" \
        --n_examples="${n_ex}" \
        --n_fisher_values_per_example="${n_vals_pe}" \
        --sequence_length=224 \
        --min_prob_class="${min_prob_class}" \
        --max_classes="${max_classes}"
}

# Args: device n_vals_pe n_ex split min_prob_class max_classes
compute_pefs \
    3 65536 20000 train 3e-3 35
# Validation-split run, currently disabled:
# compute_pefs 3 65536 30000 validation 3e-3 35


###################################################################################


# Fetch a PEF file from mango to the local machine.
# NOTE(review): this name (100ex, no ".<N>mc" suffix) does not match the file
# names compute_pefs above produces — confirm which file is intended.
rsync -ra -e ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/project_data/extract_merge1/imagenet2_lrm_npeff/per_example_fishers/resnet50_imagenet.train.100ex.65536.mpc3e-3.h5" \
    "$HOME/Desktop/projects_data/extract_merge1"


cd ~/Desktop/projects/cuda_m_npeff

# Rebuild, then run run_lvrm_npeff on the fetched file (output to a scratch
# path in /tmp). '&&' rather than ';' so a failed build does not fall through
# to running a stale ./build/mains/run_lvrm_npeff binary.
cmake --build build && ./build/mains/run_lvrm_npeff \
    --pef_filepath="${HOME}/Desktop/projects_data/extract_merge1/resnet50_imagenet.train.100ex.65536.mpc3e-3.h5" \
    --min_nonzero_per_col=12 \
    --n_components=16 \
    --learning_rate_G_G_only=3e-5 \
    --learning_rate_G=3e-4 \
    --output_filepath="/tmp/adfasdfasdf.h5" \
    --n_preprocess_cpu_threads=4 \
    --n_iters_G_only=30 \
    --n_iters_joint=70

###################################################################################



# Setup for the cuda_m_npeff section. It re-declares the experiment layout,
# presumably so the section can be pasted into a fresh shell. It includes
# PERTURBATIONS_DIR because run_perturbations_kl reads it.
cd /fruitbasket/users/m/project_code/cuda_m_npeff
export PATH="/home/m/.local/bin:$PATH"


DATA_DIR=/fruitbasket/users/m/project_data/extract_merge1
EXPS_DIR="${DATA_DIR}/imagenet2_lrm_npeff"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"
PERTURBATIONS_DIR="${EXPS_DIR}/perturbations"

# Staging directory on /playpen. The trailing '/' makes `cp file "$dir"` fail
# instead of creating a regular file when the directory does not exist.
PLAYPLAN_PEFS_DIR="/playpen/users/m/project_data/imagenet2_lrm_npeff/per_example_fishers/"



# Stage the train-split PEF file on /playpen, then fit a 512-component
# decomposition with run_lvrm_npeff on GPUs 0 and 3.
cp -- "${PER_EXAMPLE_FISHERS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.35mc.h5" "$PLAYPLAN_PEFS_DIR"


CUDA_VISIBLE_DEVICES=0,3 ./build/mains/run_lvrm_npeff \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.35mc.h5" \
    --min_nonzero_per_col=8 \
    --n_components=512 \
    --learning_rate_G_G_only=7e-4 \
    --learning_rate_G=3e-4 \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.h5" \
    --n_preprocess_cpu_threads=64 \
    --n_iters_G_only=100 \
    --n_iters_joint=1500



# Stage the validation PEFs and the 512-component decomposition on /playpen.
# Then run fit_lvrm_coeffs to fit the validation PEFs against that
# decomposition.
cp -- "${PER_EXAMPLE_FISHERS_DIR}/resnet50_imagenet.validation.30000ex.65536.mpc3e-3.35mc.h5" "$PLAYPLAN_PEFS_DIR"
cp -- "${PER_EXAMPLE_FISHERS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.h5" "$PLAYPLAN_PEFS_DIR"


CUDA_VISIBLE_DEVICES=3 ./build/mains/fit_lvrm_coeffs \
    --pef_filepath="${PLAYPLAN_PEFS_DIR}/resnet50_imagenet.validation.30000ex.65536.mpc3e-3.35mc.h5" \
    --decomposition_filepath="${PLAYPLAN_PEFS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.h5" \
    --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.fit_to_validation.001.h5" \
    --n_preprocess_cpu_threads=64 \
    --n_iters=500 \
    --mu_eps=1e-7 \
    --n_columns_per_chunk=2500000




#######################################
# Run run_m_npeff_perturbations_kl.py for resnet50_imagenet on one split.
# Globals:
#   PER_EXAMPLE_FISHERS_DIR (read) - directory holding the NMF and PEF inputs
#   PERTURBATIONS_DIR       (read) - directory of the --output_filepath file
# Arguments:
#   $1 device                 - value for CUDA_VISIBLE_DEVICES
#   $2 split                  - forwarded as --split
#   $3 perturbation_magnitude - forwarded as --perturbation_magnitude
#   $4 max_abs_cos_sim        - forwarded as --max_abs_cos_sim (the call site
#                               passes -1 for its "no_semiorth" run)
#   $5 nmf_filename           - file name inside PER_EXAMPLE_FISHERS_DIR
#   $6 pef_filename           - file name inside PER_EXAMPLE_FISHERS_DIR
#   $7 output_name            - file name inside PERTURBATIONS_DIR
# Returns:
#   the exit status of the Python script
#######################################
run_perturbations_kl() {
    local device=$1
    local split=$2
    local perturbation_magnitude=$3
    local max_abs_cos_sim=$4
    local nmf_filename=$5
    local pef_filename=$6
    local output_name=$7

    # Fixed settings for every run.
    local -r model="resnet:resnet50_imagenet"
    local -r n_top_examples=128
    local -r n_total_examples=10000

    # Quote every expansion so forwarded values are passed as single,
    # un-globbed arguments.
    CUDA_VISIBLE_DEVICES="$device" python ./scripts1/comp_measure/run_m_npeff_perturbations_kl.py \
        --nmf_filepath="${PER_EXAMPLE_FISHERS_DIR}/${nmf_filename}" \
        --pef_filepath="${PER_EXAMPLE_FISHERS_DIR}/${pef_filename}" \
        --model="${model}" \
        --from_pt=false \
        --task=imagenet/resnet \
        --sequence_length=224 \
        --split="${split}" \
        --output_filepath="${PERTURBATIONS_DIR}/${output_name}" \
        --n_top_examples="${n_top_examples}" \
        --n_total_examples="${n_total_examples}" \
        --perturbation_magnitude="${perturbation_magnitude}" \
        --max_abs_cos_sim="${max_abs_cos_sim}"
}

# run_perturbations_kl 1 validation 1e-1 0.35 \
#     resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.fit_to_validation.001.h5 \
#     resnet50_imagenet.validation.30000ex.65536.mpc3e-3.35mc.h5 \
#     kl_perturbations.resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.fit_to_validation.001.json


# KL-perturbation run on the validation split with max_abs_cos_sim=-1 (output
# tagged "no_semiorth").
# Args: device split perturbation_magnitude max_abs_cos_sim nmf_file pef_file output_name
run_perturbations_kl 0 validation 1e-1 -1 \
    resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.fit_to_validation.001.h5 \
    resnet50_imagenet.validation.30000ex.65536.mpc3e-3.35mc.h5 \
    kl_perturbations.no_semiorth.resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.fit_to_validation.001.json


# Copy every file in the remote perturbations directory to the local machine.
# The '*' is quoted locally so it is expanded on the remote side.
rsync --recursive --archive --rsh=ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/project_data/extract_merge1/imagenet2_lrm_npeff/perturbations/*" \
    "$HOME/Desktop/projects_data/extract_merge1/neurips2023/perturbation_kl_jsons/imagenet2_lrm_npeff/"




# CUDA_VISIBLE_DEVICES=0,3 ./build/mains/run_lvrm_npeff \
#     --pef_filepath="${PLAYPLAN_PEFS_DIR}/resnet50_imagenet.train.20000ex.65536.mpc3e-3.35mc.h5" \
#     --min_nonzero_per_col=8 \
#     --n_components=32 \
#     --learning_rate_G_G_only=1e-3 \
#     --learning_rate_G=3e-4 \
#     --output_filepath="${PER_EXAMPLE_FISHERS_DIR}/TEST_LVRM_NPEFF_001.h5" \
#     --n_preprocess_cpu_threads=64 \
#     --n_examples=10000 \
#     --n_iters_G_only=100 \
#     --n_iters_joint=500 \
#     --log_loss_frequency=10


