#include <cstdint>

#include <torch/extension.h>

#include "./sfrp.fwd_decl.h"



// Input-validation helpers for the bindings below. Each one expands to
// TORCH_CHECK, which throws with the given message when the condition fails.
//
// Arguments are parenthesized so any tensor expression can be passed safely.
// The multi-check CHECK_INPUT is wrapped in do { ... } while (0) so it acts as
// a single statement (e.g. under an unbraced `if`, both checks stay guarded).
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

#define CHECK_2D(x) TORCH_CHECK((x).dim() == 2, #x " must be 2-D")




// Python entry point `make_rand_proj_hypercubic_mat_v01(out, seed)`.
// Requires `out` to be a contiguous, 2-D CUDA tensor, then hands `out` and
// `seed` unchanged to cuda_make_rand_proj_hypercubic_mat_v01.
// NOTE(review): presumably the CUDA routine fills `out` in place (this wrapper
// returns nothing) — confirm against its definition.
void make_rand_proj_hypercubic_mat_v01(torch::Tensor out, uint32_t seed) {
    // Device, then memory layout, then rank — the same order CHECK_INPUT +
    // CHECK_2D would apply them.
    CHECK_CUDA(out);
    CHECK_CONTIGUOUS(out);
    CHECK_2D(out);

    cuda_make_rand_proj_hypercubic_mat_v01(out, seed);
}


void rand_proj_hypercubic_v01_alg01(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    CHECK_INPUT(mat); CHECK_2D(mat);
    CHECK_INPUT(out); CHECK_2D(out);

    TORCH_CHECK(mat.type() == out.type(), "The dtypes of mat and out must match.");
    TORCH_CHECK(mat.size(0) == out.size(0), "The first dimension of mat and out must match.");

    cuda_rand_proj_hypercubic_v01_alg01(mat, seed, out);
}

void rand_proj_hypercubic_v01_alg02(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    CHECK_INPUT(mat); CHECK_2D(mat);
    CHECK_INPUT(out); CHECK_2D(out);

    TORCH_CHECK(mat.type() == out.type(), "The dtypes of mat and out must match.");
    TORCH_CHECK(mat.size(0) == out.size(0), "The first dimension of mat and out must match.");

    cuda_rand_proj_hypercubic_v01_alg02(mat, seed, out);
}

void rand_proj_hypercubic_v01_alg03(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    CHECK_INPUT(mat); CHECK_2D(mat);
    CHECK_INPUT(out); CHECK_2D(out);

    TORCH_CHECK(mat.type() == out.type(), "The dtypes of mat and out must match.");
    TORCH_CHECK(mat.size(0) == out.size(0), "The first dimension of mat and out must match.");

    cuda_rand_proj_hypercubic_v01_alg03(mat, seed, out);
}

void rand_proj_hypercubic_v02_alg03(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    CHECK_INPUT(mat); CHECK_2D(mat);
    CHECK_INPUT(out); CHECK_2D(out);

    TORCH_CHECK(mat.type() == out.type(), "The dtypes of mat and out must match.");
    TORCH_CHECK(mat.size(0) == out.size(0), "The first dimension of mat and out must match.");

    cuda_rand_proj_hypercubic_v02_alg03(mat, seed, out);
}




// Python bindings for this extension.
// NOTE: the projection functions are exported with single-digit suffixes
// (v1_alg1, ...), while their C++ counterparts use two-digit suffixes
// (v01_alg01, ...). The matrix builder keeps "v01" in both places. The
// Python-facing names are what callers import, so they are kept as-is.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // Projection-matrix construction.
    ext.def("make_rand_proj_hypercubic_mat_v01", &make_rand_proj_hypercubic_mat_v01);

    // Projections: version 1 with algorithms 1-3, then version 2 with algorithm 3.
    ext.def("rand_proj_hypercubic_v1_alg1", &rand_proj_hypercubic_v01_alg01)
       .def("rand_proj_hypercubic_v1_alg2", &rand_proj_hypercubic_v01_alg02)
       .def("rand_proj_hypercubic_v1_alg3", &rand_proj_hypercubic_v01_alg03)
       .def("rand_proj_hypercubic_v2_alg3", &rand_proj_hypercubic_v02_alg03);
}




// Possible idea (may not pan out): compute something like cos(seed1 * flat_index + seed2), then use
// the parity of the least-significant mantissa bit as the random bit.
//  - If so, 64-bit integers / doubles may be needed to avoid repeating values.
// First implement versions based on cuRAND, since there will likely be a way to make those more robust.
