#include "./fwd_decls.cc.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cuda/bernoulli/bernoulli_util.h>
#include <cuda/util/args_validation.h>
#include <cuda/util/misc_util.h>

#include <cuda/util/editor_hack.h>

namespace sfrp {
namespace bernoulli {
namespace sparse {


// const uint32_t WARP_SIZE = 32;

// Number of threads per block used when launching the kernels in this file.
constexpr uint32_t N_THREADS_PER_BLOCK = 1024;



// Reciprocal square root evaluated in the precision of the kernel's scalar type. Calling rsqrtf for
// every dispatched type would silently truncate the double-precision path to float.
static __device__ __forceinline__ float rp_v2_alg1_rsqrt(float x) { return rsqrtf(x); }
static __device__ __forceinline__ double rp_v2_alg1_rsqrt(double x) { return rsqrt(x); }

// Sparse random projection of each row of `mat` onto `d_proj` random +/-1 columns.
//
// Launch layout (as constructed by rp_v2_alg1):
//   - gridDim.x == 1 (blockIdx.x is not used).
//   - blockIdx.y selects the input vector (row of `mat` and of `out`).
//   - blockIdx.z * blockDim.x + threadIdx.x is the output column; threads past d_proj exit early.
// No shared memory is used. `mat` is indexed as a dense row-major (n_vecs, d_og) buffer and `out` as a
// dense row-major (n_vecs, d_proj) buffer.
//
// Column j of the implicit projection matrix is split into consecutive regions of `sparse_region_size`
// entries. Within each region, one entry is chosen via curand() % sparse_region_size and given a random
// sign. Entries chosen past the end of the vector (in a trailing partial region) are skipped. The RNG
// subsequence depends only on the output column, so every vector is projected by the same matrix.
//
// Precondition (checked on the host): sparse_region_size > 0.
template <typename scalar_t, typename curandState_t>
__global__ void rp_v2_alg1_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj,
    uint64_t sparse_region_size
) {
    // These depend on the specific way that the blocks are constructed (see header comment).
    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;
    // Widen before multiplying so the column index is computed in 64 bits.
    const uint64_t out_col_index = static_cast<uint64_t>(blockIdx.z) * blockDim.x + thread_index;

    // NOTE: This was originally [which was wrong]:
    //      if (out_col_index > d_proj) { return; }
    // If doing a single vector, then this would not impact the output since it would write beyond it. If
    // doing multiple vectors, this would result in an incorrect value getting written to the first value
    // of all but the first vector.
    if (out_col_index >= d_proj) { return; }

    const uint64_t out_flat_index = vec_index * d_proj + out_col_index;

    // An empty input has nothing to project. Without this guard the normalization below would be 0 * inf.
    if (d_og == 0) {
        out[out_flat_index] = static_cast<scalar_t>(0);
        return;
    }

    // One RNG subsequence per output column, shared by all vectors.
    curandState_t state;
    curand_init(seed, /*subsequence=*/ out_col_index, /*offset=*/ 0, &state);

    scalar_t acc = static_cast<scalar_t>(0);
    for (uint64_t base_offset = 0; base_offset < d_og; base_offset += sparse_region_size) {
        const uint64_t region_offset = curand(&state) % sparse_region_size;
        const uint64_t offset = base_offset + region_offset;

        // In a trailing partial region the chosen entry may fall past the end of the vector.
        if (offset < d_og) {
            const scalar_t mat_value = mat[vec_index * d_og + offset];
            const uint32_t rand_value = curand(&state);
            acc += mat_value * parity_as_pm1(rand_value);
        }
    }

    // Every input coordinate is selected with probability 1 / sparse_region_size, including coordinates in
    // a trailing partial region (up to the modulo bias of curand() % sparse_region_size). Each column
    // therefore has d_og / sparse_region_size nonzeros in expectation. Using the real-valued quotient,
    // rather than the truncated integer one, normalizes columns to unit norm in expectation even when the
    // last region straddles the end of the vector. It also avoids rsqrt(0) when sparse_region_size > d_og.
    // An exact per-column correction would instead divide by the number of entries actually hit.
    const scalar_t expected_nnz =
        static_cast<scalar_t>(d_og) / static_cast<scalar_t>(sparse_region_size);
    out[out_flat_index] = acc * rp_v2_alg1_rsqrt(expected_nnz);
}



// Host launcher for rp_v2_alg1_kernel. It projects every row of `mat` into the matching row of `out`,
// using `curandState_t` as the per-thread RNG state and `seed` to fix the projection matrix.
//
// Launches one block row (grid y) per input vector and CEIL_DIV(d_proj, N_THREADS_PER_BLOCK) blocks along
// grid z to cover the output columns.
//
// Errors (via TORCH_CHECK):
//   - mat/out incompatible (as decided by check_valid_mat_and_out);
//   - d_og smaller than the number of threads per block;
//   - sparse_region_size == 0;
//   - n_vecs or the number of column blocks exceeding the 65535 grid y/z limit.
template <typename curandState_t>
void rp_v2_alg1(at::Tensor mat, uint32_t seed, uint64_t sparse_region_size, at::Tensor out) {
    // Make sure these tensors are valid and compatible.
    check_valid_mat_and_out(mat, out);

    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // TODO: It's probably possible to support smaller dimensions by taking the largest power of two that is
    // smaller or equal to than d_og. My current use-case involves dimensions much greater than this, so I'm
    // not supporting this now.
    const int32_t n_threads = N_THREADS_PER_BLOCK;
    TORCH_CHECK(d_og >= n_threads, "The original dimension must be greater than the number of threads.");

    // The kernel computes `curand() % sparse_region_size` and advances by `sparse_region_size` each
    // iteration. A zero region size would be undefined behavior and a loop that never terminates.
    TORCH_CHECK(sparse_region_size > 0, "sparse_region_size must be positive.");

    // A zero grid dimension is an invalid launch configuration, and there is nothing to write.
    if (n_vecs == 0 || d_proj == 0) { return; }

    const int64_t n_col_blocks = static_cast<int64_t>(CEIL_DIV(d_proj, n_threads));

    // Vectors map to gridDim.y and column blocks to gridDim.z, both of which are limited to 65535.
    constexpr int64_t kMaxGridDimYZ = 65535;
    TORCH_CHECK(n_vecs <= kMaxGridDimYZ,
                "Too many vectors for a single launch: the number of vectors must be at most 65535.");
    TORCH_CHECK(n_col_blocks <= kMaxGridDimYZ,
                "The projected dimension is too large for a single launch.");

    const dim3 blocks(1, n_vecs, n_col_blocks);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "rp_v2_alg1", ([&] {
        CURRY_KERNEL_CALL((rp_v2_alg1_kernel<scalar_t, curandState_t>), blocks, n_threads)(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj,
            sparse_region_size
        );
    }));
}



// Sparse algorithm 1 backed by cuRAND's XORWOW generator. All arguments are forwarded unchanged to
// rp_v2_alg1, which writes the projection of `mat` into `out`.
void rp_v2_xorwow_alg1(at::Tensor mat, uint32_t seed, uint64_t sparse_region_size, at::Tensor out) {
    using xorwow_state_t = curandStateXORWOW_t;
    return rp_v2_alg1<xorwow_state_t>(mat, seed, sparse_region_size, out);
}


}  // sparse
}  // bernoulli
}  // sfrp


