#include <thread>
#include <vector>

#include "./compute_tr_xx.h"

namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization {


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Private/local helpers.

// Worker functor that computes the tr_xx contribution of one contiguous chunk
// of examples. The examples are split into `n_chunks` chunks and this object
// handles chunk number `chunk_index`. Calling the object stores the chunk's
// partial sum in `tr_xx_contribution`.
//
// The referenced tensor must outlive this object. `n_chunks` must be positive
// and `chunk_index` is expected to be in [0, n_chunks).
class TrXxComputer {
    // matrix.shape = [n_examples, n_classes, n_parameters]
    const Dense3Tensor<float>& matrix;
    const int64_t n_examples;
    const int64_t n_classes;
    const int64_t n_parameters;

    const int64_t chunk_index;
    const int64_t n_chunks;

    // Half-open range [start_example_index, end_example_index) of the
    // examples handled by this worker.
    const int64_t start_example_index;
    const int64_t end_example_index;

public:
    // Partial sum for this chunk; stays at -1.0 until operator() has run.
    double tr_xx_contribution = -1.0;


    TrXxComputer(const Dense3Tensor<float>& matrix, int64_t chunk_index, int64_t n_chunks) :
        matrix(matrix),
        n_examples(matrix.d1),
        n_classes(matrix.d2),
        n_parameters(matrix.d3),
        // 
        chunk_index(chunk_index),
        n_chunks(n_chunks),
        // The boundary helper only reads n_examples and n_chunks, both of
        // which are declared (and therefore initialized) before these members.
        start_example_index(compute_chunk_boundary(chunk_index)),
        end_example_index(compute_chunk_boundary(chunk_index + 1))
    {}

    // Sums the contributions of every example in this worker's range and
    // publishes the total in `tr_xx_contribution`.
    void operator()() {
        // Accumulate locally and write the shared-visible member only once.
        double total = 0.0;
        for(int64_t i=start_example_index; i<end_example_index; i++) {
            total += compute_contribution_from_example(i);
        }
        this->tr_xx_contribution = total;
    }

private:

    // Returns the sum over all ordered class pairs (i, j) of the squared dot
    // product between rows i and j of the given example.
    double compute_contribution_from_example(int64_t example_index) const {
        double ret = 0.0;

        for(int64_t i=0; i<n_classes; i++) {
            ret += compute_squared_dot_product(example_index, i, i);
            // The dot product is symmetric in (i, j), so each off-diagonal
            // pair is evaluated once and counted twice.
            for(int64_t j=i+1; j<n_classes; j++) {
                ret += 2.0 * compute_squared_dot_product(example_index, i, j);
            }
        }

        return ret;
    }

    // Returns the square of the dot product (taken over the parameter axis)
    // between the rows of the two given classes for the given example.
    double compute_squared_dot_product(int64_t example_index, int64_t class_index_1, int64_t class_index_2) const {
        double ret = 0.0;
        for (int64_t i=0; i < n_parameters; i++) {
            // Widen before multiplying: a float * float product would be
            // rounded (and could underflow or overflow) in single precision
            // before ever reaching the double accumulator.
            const double lhs = static_cast<double>(matrix.get_entry(example_index, class_index_1, i));
            const double rhs = static_cast<double>(matrix.get_entry(example_index, class_index_2, i));
            ret += lhs * rhs;
        }
        return ret * ret;
    }

    /////////////////////////////////////////////////////////////////
    // Initialization helper functions.

    // Returns the index of the first example belonging to chunk
    // `boundary_index`; passing `n_chunks` yields `n_examples`.
    //
    // Every chunk receives n_examples / n_chunks examples and the first
    // n_examples % n_chunks chunks receive one extra, so chunk sizes differ
    // by at most one and no single worker absorbs the whole remainder.
    int64_t compute_chunk_boundary(int64_t boundary_index) const {
        const int64_t base_size = n_examples / n_chunks;
        const int64_t remainder = n_examples % n_chunks;
        const int64_t n_extra = boundary_index < remainder ? boundary_index : remainder;
        return boundary_index * base_size + n_extra;
    }
};

///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Public functions.

// matrix.shape = [n_examples, n_classes, n_parameters]
double compute_tr_xx(const Dense3Tensor<float>& matrix, int64_t n_threads) {
    std::vector<TrXxComputer> workers;
    for(int64_t i=0; i<n_threads; i++) {
        workers.emplace_back(matrix, i, n_threads);
    }

    std::vector<std::thread> threads;
    for(auto& worker : workers) {
        threads.emplace_back(std::ref(worker));
    }

    for(auto& thread : threads) {
        thread.join();
    }

    double tr_xx = 0.0;
    for(auto& worker : workers) {
        tr_xx += worker.tr_xx_contribution;
    }

    return tr_xx;
}


}  // dn_lrm_factorization
}  // factorizations
}  // npeff
