#include "./fwd_decls.cc.h"
#include "./dn_v3_constants.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cuda/bernoulli/bernoulli_util.h>
#include <cuda/util/args_validation.h>
#include <cuda/util/misc_util.h>

#include <cuda/util/editor_hack.h>


namespace sfrp {
namespace bernoulli {
namespace dense {


// Random-projection kernel: for one row of `mat` and one group of output columns, accumulates
// sign-weighted sums of the row's entries and writes them, scaled by 1/sqrt(d_og), into `out`.
//
// Launch layout consumed by the indexing below:
//   - blockDim.x must equal v3::N_THREADS_PER_BLOCK (used as the stride over the row and as the
//     shared-memory array size); it must be a power of two for the tree reduction.
//   - blockIdx.y selects the row of `mat` and `out`.
//   - blockIdx.z selects a group of `n_outputs_per_thread` consecutive output columns.
//   - blockIdx.x is not read.
//
// Memory: `mat` is indexed as [row * d_og + col] and `out` as [row * d_proj + col]. Static shared
// memory use is N_THREADS_PER_BLOCK * sizeof(scalar_t) per block.
//
// Randomness: each thread seeds its own curand state with `seed` and a sequence number derived only
// from blockIdx.z and threadIdx.x, so every row (blockIdx.y) draws the same random words in the
// same order. The mapping from a random word to a sign is done by `bit_parity_as_pm1`, which is
// defined elsewhere (presumably it turns bit `i` of the word into +1/-1; confirm against its
// definition).
//
// Parameters:
//   mat     input matrix (read-only).
//   out     output matrix.
//   seed    curand seed shared by all threads.
//   n_vecs  number of rows; not read by the kernel body (rows are selected via blockIdx.y).
//   d_og    number of columns in `mat`.
//   d_proj  number of columns in `out`; output columns at or beyond it are skipped.
//
// NOTE: n_outputs_per_thread cannot be larger than 32 since curand generates 32 bits per call.
// NOTE: Values that are too large can cause the kernel to fail (potentially silently).
template <typename scalar_t, typename curandState_t, uint32_t n_outputs_per_thread>
__global__ void rp_v3_alg3_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    constexpr uint32_t n_threads_per_block = v3::N_THREADS_PER_BLOCK;

    // Enforce the documented preconditions at compile time instead of failing silently at runtime.
    static_assert(n_outputs_per_thread >= 1 && n_outputs_per_thread <= 32,
                  "n_outputs_per_thread must be in [1, 32]: curand yields 32 bits per call.");
    static_assert(n_threads_per_block != 0 && (n_threads_per_block & (n_threads_per_block - 1)) == 0,
                  "N_THREADS_PER_BLOCK must be a power of two for the shared-memory reduction.");

    // These depend on the specific way that the blocks are constructed by the launcher.
    // The multiplications are widened to 64 bits before multiplying.
    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;
    const uint64_t base_out_col_index = static_cast<uint64_t>(blockIdx.z) * n_outputs_per_thread;

    // One curand sequence per (column group, thread); deliberately independent of the row.
    const uint64_t sequence_number = static_cast<uint64_t>(blockIdx.z) * n_threads_per_block + thread_index;

    curandState_t state;
    curand_init(seed, sequence_number, /*offset=*/ 0, &state);

    // Per-thread partial sums, one per output column handled by this block.
    scalar_t t_cache[n_outputs_per_thread] = {0.0f};
    for (uint64_t offset = thread_index; offset < d_og; offset += n_threads_per_block) {
        const scalar_t mat_value = mat[vec_index * d_og + offset];

        // One random word per input element supplies the signs for all outputs of this thread.
        const uint32_t rand_value = curand(&state);
        for (int i = 0; i < static_cast<int>(n_outputs_per_thread); i++) {
            t_cache[i] += mat_value * bit_parity_as_pm1(rand_value, i);
        }
    }

    // TODO: Maybe do multiple of these at once.

    __shared__ scalar_t s_cache[n_threads_per_block];

    for (uint64_t i = 0; i < n_outputs_per_thread; i++) {
        const uint64_t out_col_index = base_out_col_index + i;
        // Uniform across the block (depends only on blockIdx.z and i), so the barriers below are
        // reached either by every thread or by none.
        if (out_col_index >= d_proj) { break; }

        s_cache[thread_index] = t_cache[i];

        __syncthreads();

        // Tree reduction; requires n_threads_per_block to be a power of 2 (asserted above).
        for (uint32_t stride = n_threads_per_block / 2; stride != 0; stride /= 2) {
            if (thread_index < stride) {
                s_cache[thread_index] += s_cache[thread_index + stride];
            }
            __syncthreads();
        }

        // Only thread 0 touches s_cache[0] from here until the next barrier, so no extra
        // synchronization is needed before the next iteration refills the cache.
        if (thread_index == 0) {
            const uint64_t out_flat_index = vec_index * d_proj + out_col_index;
            // Overloaded ::rsqrt keeps full precision when scalar_t is double
            // (rsqrtf would truncate the scale factor to single precision).
            out[out_flat_index] = s_cache[0] * ::rsqrt(static_cast<scalar_t>(d_og));
        }
    }
}




// Host launcher for rp_v3_alg3_kernel.
//
// Reads the sizes from `mat` (dim 0 = number of rows, dim 1 = original dimension) and `out`
// (dim 1 = projected dimension), builds a grid with one block per (row, group of
// `n_outputs_per_thread` output columns), and dispatches on the floating-point dtype of `out`.
//
// Template parameters:
//   curandState_t          curand state type used by the kernel.
//   n_outputs_per_thread   number of output columns handled per block.
//
// Parameters:
//   mat   input tensor; validated together with `out` by check_valid_mat_and_out (defined
//         elsewhere; presumably checks shape/dtype/device compatibility, confirm there).
//   seed  curand seed forwarded to the kernel.
//   out   output tensor, written in place.
//
// Raises (via TORCH_CHECK) if the original dimension is smaller than the block size, if the grid
// would exceed the CUDA y/z grid limits, or if the kernel launch reports an error.
template <typename curandState_t, uint32_t n_outputs_per_thread>
void rp_v3_alg3(at::Tensor mat, uint32_t seed, at::Tensor out) {
    // Make sure these tensors are valid and compatible.
    check_valid_mat_and_out(mat, out);

    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // TODO: It's probably possible to support smaller dimensions by taking the largest power of two that is
    // smaller than or equal to d_og. My current use-case involves dimensions much greater than this, so I'm
    // not supporting this now.
    const int32_t n_threads = v3::N_THREADS_PER_BLOCK;
    TORCH_CHECK(d_og >= n_threads, "The original dimension must be greater than or equal to the number of threads.");

    // Nothing to compute; a grid with a zero-sized dimension is an invalid launch configuration.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    // CUDA limits gridDim.y and gridDim.z to 65535. Exceeding that makes the launch fail, which
    // would otherwise go unnoticed and leave `out` unwritten.
    constexpr int64_t max_grid_dim_yz = 65535;
    const int64_t n_col_blocks = CEIL_DIV(d_proj, n_outputs_per_thread);
    TORCH_CHECK(n_vecs <= max_grid_dim_yz,
                "rp_v3_alg3: number of vectors (", n_vecs, ") exceeds the CUDA grid limit of ", max_grid_dim_yz, ".");
    TORCH_CHECK(n_col_blocks <= max_grid_dim_yz,
                "rp_v3_alg3: projected dimension (", d_proj, ") needs ", n_col_blocks,
                " column blocks, exceeding the CUDA grid limit of ", max_grid_dim_yz, ".");

    const dim3 blocks(1, static_cast<unsigned int>(n_vecs), static_cast<unsigned int>(n_col_blocks));

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "rp_v3_alg3", ([&] {
        CURRY_KERNEL_CALL((rp_v3_alg3_kernel<scalar_t, curandState_t, n_outputs_per_thread>), blocks, n_threads)(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));

    // Kernel launches do not return errors directly; surface launch failures explicitly.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "rp_v3_alg3: kernel launch failed: ", cudaGetErrorString(launch_status));
}


// Non-template entry point: runs rp_v3_alg3 with the XORWOW curand generator and the
// per-thread output count configured in v3::XORWOW_N_OUTPUTS_PER_THREAD.
void rp_v3_xorwow_alg3(at::Tensor mat, uint32_t seed, at::Tensor out) {
    using xorwow_state_t = curandStateXORWOW_t;
    constexpr uint32_t outputs_per_thread = v3::XORWOW_N_OUTPUTS_PER_THREAD;

    rp_v3_alg3<xorwow_state_t, outputs_per_thread>(mat, seed, out);
}

// TODO: rp_v3_mrg32k3a_alg3 was always failing. See if this can be fixed.
// void rp_v3_mrg32k3a_alg3(at::Tensor mat, uint32_t seed, at::Tensor out) {
//     rp_v3_alg3<curandStateMRG32k3a_t, 1>(mat, seed, out);
// }




}  // dense
}  // bernoulli
}  // sfrp
