#include <algorithm>
#include <cmath>
#include <cstdint>

#include "./pef_normalization.h"


namespace npeff {
namespace preprocessing {


/// Normalizes every example's PEF entries in place using its stored Frobenius norm.
///
/// For each example, all (class, parameter) entries are divided by
/// `max(sqrt(norm), eps)`, where `norm` is the value stored for that example in
/// `pefs.pef_frobenius_norms`. The norms themselves are only read, never modified.
///
/// @param pefs                       PEFs to normalize; its tensor entries are overwritten.
/// @param non_finite_norms_to_zeros  When true, every entry of an example whose stored
///                                   norm is not finite (NaN or +/-inf) is set to 0
///                                   instead of being divided. When false, such examples
///                                   go through the regular division path unchanged.
/// @param eps                        Lower bound applied to the divisor `sqrt(norm)`.
/// @return The number of examples whose stored norm is not finite. This count does not
///         depend on `non_finite_norms_to_zeros`.
int64_t normalize_dn_pefs_in_place(inputs::DnLrmPefs& pefs, bool non_finite_norms_to_zeros, float eps) {
    auto& pefs_tensor = *pefs.pefs;
    const float* norms = pefs.pef_frobenius_norms->data.get();

    // The dimensions do not change while entries are rewritten, so query them once.
    const int64_t n_examples = pefs.n_examples();
    const int64_t n_classes = pefs.n_classes();
    const int64_t n_parameters = pefs.n_parameters();

    int64_t n_non_finite = 0;
    for (int64_t example_index = 0; example_index < n_examples; example_index++) {
        const float norm = norms[example_index];
        const bool norm_is_finite = std::isfinite(norm);

        // Counting happens in the same pass as normalization; this is safe because
        // the norms array is never written by this function.
        if (!norm_is_finite) { n_non_finite++; }

        if (!norm_is_finite && non_finite_norms_to_zeros) {
            for (int64_t class_index = 0; class_index < n_classes; class_index++) {
                for (int64_t j = 0; j < n_parameters; j++) {
                    pefs_tensor.get_entry(example_index, class_index, j) = 0.0f;
                }
            }
            continue;
        }

        // We divide by the square root of the Frobenius norm of the PEF matrix.
        // This is because we are storing effectively a matrix A while the PEF matrix
        // is given by AA^T.
        // Note: std::max(a, b) returns `a` when the comparison `a < b` is false, so a
        // NaN square root is not replaced by eps and propagates into the entries.
        const float sqrt_norm = std::max(std::sqrt(norm), eps);

        for (int64_t class_index = 0; class_index < n_classes; class_index++) {
            for (int64_t j = 0; j < n_parameters; j++) {
                pefs_tensor.get_entry(example_index, class_index, j) /= sqrt_norm;
            }
        }
    }

    return n_non_finite;
}



}  // preprocessing
}  // npeff
