#include <cstdint>

#include <torch/extension.h>

#include "./cuda/bernoulli/fwd_decls.cc.h"


// Python module name used by PYBIND11_MODULE below. It is only defined here
// when nothing else has defined it already. The torch.utils.cpp_extension
// build normally passes -DTORCH_EXTENSION_NAME=<name> on the compiler command
// line, so this fallback matters for builds that do not.
#if !defined(TORCH_EXTENSION_NAME)
#   define TORCH_EXTENSION_NAME sfrp_torch
#endif


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Registers the sfrp Bernoulli random-projection entry points with Python.
    // Each m.def binds a Python-visible name to a C++ function pointer.
    // Registration order and names are part of the module's public surface.

    // Block-scope aliases for the two sfrp::bernoulli sub-namespaces.
    namespace bern_dn = sfrp::bernoulli::dense;
    namespace bern_sp = sfrp::bernoulli::sparse;

    // ---- Forward projections -----------------------------------------------

    // Dense Bernoulli variants.
    m.def("project_dn_bernoulli_v2_alg3",        &bern_dn::rp_v2_alg3);
    m.def("project_dn_bernoulli_v3_xorwow_alg3", &bern_dn::rp_v3_xorwow_alg3);

    // Sparse Bernoulli variants.
    m.def("project_sp_bernoulli_v1_xorwow_alg3", &bern_sp::rp_v1_xorwow_alg3);
    m.def("project_sp_bernoulli_v2_xorwow_alg1", &bern_sp::rp_v2_xorwow_alg1);

    // ---- Transposed projections --------------------------------------------

    // Dense Bernoulli, transposed.
    m.def("transposed_project_dn_bernoulli_v3_xorwow_alg1",
          &bern_dn::trp_rp_v3_xorwow_alg1);

    // Sparse Bernoulli, transposed.
    m.def("transposed_project_sp_bernoulli_v2_xorwow_alg1",
          &bern_sp::trp_rp_v2_xorwow_alg1);
}
