#include "./fwd_decls.cc.h"

#include <iostream>

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cuda/bernoulli/bernoulli_util.h>
#include <cuda/util/args_validation.h>
#include <cuda/util/misc_util.h>
#include <cuda/util/device/reductions.h>

#include <cuda/util/editor_hack.h>

namespace sfrp {
namespace bernoulli {
namespace sparse {


// Block size (threads per block, along x) used by the launcher in this file.
constexpr uint32_t N_THREADS_PER_BLOCK = 1024;

// Transposed sparse random projection; this variant accumulates straight into global
// memory with atomic adds.
//
// Launch layout (as set up by the host launcher in this file):
//   * blockIdx.y selects the vector, i.e. the row of `mat` and of `out`.
//   * blockIdx.z * blockDim.x + threadIdx.x selects the entry within that row of `mat`;
//     threads whose entry index is >= d_proj exit immediately.
//
// Each thread owns one entry of `mat`. It seeds a private cuRAND state (the entry index
// is the sequence number, so the draws do not depend on the vector index) and then walks
// over [0, d_og) in steps of `sparse_region_size`. In every step it draws a position
// inside the current region and, if that position is still below d_og, draws a second
// number for the sign and atomically adds the signed, scaled entry to that position of
// the vector's row in `out`.
//
// `out` is only ever added to, so it has to hold its initial values (the host launcher
// zero-fills it) before the launch. `sparse_region_size` is used as a divisor, a modulus
// and the loop step, so it must be non-zero. `n_vecs` is not read by the kernel body.
template <typename scalar_t, typename curandState_t>
__global__ void trp_rp_v2_alg1_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj,
    uint64_t sparse_region_size
) {
    // Row (vector) and in-row entry handled by this thread; both follow from the launch
    // layout described above.
    const uint64_t row = blockIdx.y;
    const uint64_t entry = blockIdx.z * blockDim.x + static_cast<uint64_t>(threadIdx.x);

    // The last block along z may reach past the end of the row.
    if (entry >= d_proj) { return; }

    // Normalisation factor: reciprocal square root of the (truncating) integer quotient
    // d_og / sparse_region_size.
    // NOTE(review): rsqrtf works in single precision, also when scalar_t is double.
    const float inv_norm = rsqrtf(static_cast<scalar_t>(d_og / sparse_region_size));
    const scalar_t contribution = mat[row * d_proj + entry] * inv_norm;

    // One independent random sequence per entry index.
    curandState_t rng;
    curand_init(seed, /*sequence=*/ entry, /*offset=*/ 0, &rng);

    scalar_t* const out_row = out + row * d_og;
    for (uint64_t region_start = 0; region_start < d_og; region_start += sparse_region_size) {
        const uint64_t target = region_start + curand(&rng) % sparse_region_size;

        // The final region may stick out past d_og; a draw that lands there is dropped
        // (and no sign is drawn for it).
        if (target >= d_og) { continue; }

        // presumably parity_as_pm1 maps the parity of the draw to +1 / -1 — defined
        // outside this file, confirm there.
        const uint32_t sign_bits = curand(&rng);

        // Threads from different blocks add into the same row of `out`, so the
        // atomicity must be scoped at least at the device level.
        atomicAdd(out_row + target, contribution * parity_as_pm1(sign_bits));
    }
}


// Host-side launcher for trp_rp_v2_alg1_kernel.
//
// Validates the arguments, zero-fills `out` and launches one thread per entry of `mat`;
// the kernel accumulates its result into `out` with atomic adds.
//
// Parameters:
//   mat                 input; mat.size(0) is the number of vectors (n_vecs) and
//                       mat.size(1) the projected dimension (d_proj).
//   seed                seed handed to curand_init inside the kernel.
//   sparse_region_size  step with which the kernel walks over the original dimension;
//                       must lie in [1, d_og].
//   out                 output; out.size(1) is the original dimension (d_og). Its
//                       previous contents are overwritten.
//
// Fails through TORCH_CHECK (before `out` is modified) when d_og is smaller than the
// block size, when sparse_region_size is outside [1, d_og], or when the launch grid
// would exceed the CUDA limit on the y / z grid dimensions. Whatever
// check_valid_trp_mat_and_out rejects is reported by that helper.
template <typename curandState_t>
void trp_rp_v2_alg1(at::Tensor mat, uint32_t seed, uint64_t sparse_region_size, at::Tensor out) {
    // Make sure these tensors are valid and compatible.
    check_valid_trp_mat_and_out(mat, out);

    // NOTE: The mat and out tensors will have the opposite number of columns as in the
    // the non-transposed version.
    const auto n_vecs = mat.size(0);
    const auto d_proj = mat.size(1);
    const auto d_og = out.size(1);

    const int32_t n_threads = N_THREADS_PER_BLOCK;
    TORCH_CHECK(d_og >= n_threads, "The original dimension must be greater than the number of threads.");

    // The kernel divides by, takes a modulus with, and steps its loop by the region
    // size: zero would mean a division by zero and a loop that never advances.
    TORCH_CHECK(sparse_region_size > 0, "The sparse region size must be greater than zero.");

    // The kernel scales by rsqrt(d_og / sparse_region_size) using integer division; a
    // region larger than d_og makes the quotient 0 and the scale infinite.
    // (d_og is known to be positive here because of the check against n_threads.)
    TORCH_CHECK(
        sparse_region_size <= static_cast<uint64_t>(d_og),
        "The sparse region size must not be greater than the original dimension.");

    // The grid uses y for the vectors and z for the blocks along d_proj; CUDA caps
    // both of these grid dimensions at 65535.
    constexpr int64_t MAX_GRID_DIM_YZ = 65535;
    const auto n_blocks_per_vec = CEIL_DIV(d_proj, n_threads);
    TORCH_CHECK(
        n_vecs <= MAX_GRID_DIM_YZ,
        "The number of vectors exceeds the maximum CUDA grid dimension (65535).");
    TORCH_CHECK(
        n_blocks_per_vec <= MAX_GRID_DIM_YZ,
        "The projected dimension needs more blocks than the maximum CUDA grid dimension (65535).");

    // Fill the output with zeros since we will be doing atomic adds in its global memory.
    out.zero_();

    // Nothing to accumulate; a grid with a zero-sized dimension is not a valid launch.
    if (n_vecs == 0 || d_proj == 0) { return; }

    const dim3 blocks(1, n_vecs, n_blocks_per_vec);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "trp_rp_v2_alg1", ([&] {
        CURRY_KERNEL_CALL((trp_rp_v2_alg1_kernel<scalar_t, curandState_t>), blocks, n_threads)(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj,
            sparse_region_size
        );
    }));
}


// Non-template entry point: runs trp_rp_v2_alg1 with the XORWOW cuRAND state type.
// All arguments are forwarded unchanged.
void trp_rp_v2_xorwow_alg1(at::Tensor mat, uint32_t seed, uint64_t sparse_region_size, at::Tensor out) {
    using rng_state_t = curandStateXORWOW_t;
    trp_rp_v2_alg1<rng_state_t>(mat, seed, sparse_region_size, out);
}


}  // sparse
}  // bernoulli
}  // sfrp
