#include "./fwd_decls.cc.h"

#include <cuda.h>
#include <cuda_runtime.h>

#include <cuda/bernoulli/bernoulli_util.h>
#include <cuda/util/args_validation.h>
#include <cuda/util/hashing.h>
#include <cuda/util/misc_util.h>

#include <cuda/util/editor_hack.h>


namespace sfrp {
namespace bernoulli {
namespace dense {


// Threads per CUDA block. The shared-memory tree reduction in rp_v2_alg3_kernel halves the active
// range each step, so this must be a power of two; CUDA also caps a block at 1024 threads.
constexpr uint32_t N_THREADS_PER_BLOCK = 1024;
static_assert(N_THREADS_PER_BLOCK > 0 && (N_THREADS_PER_BLOCK & (N_THREADS_PER_BLOCK - 1)) == 0,
              "N_THREADS_PER_BLOCK must be a power of two (required by the tree reduction).");
static_assert(N_THREADS_PER_BLOCK <= 1024, "CUDA blocks cannot exceed 1024 threads.");

// Number of consecutive output columns whose projection entries are drawn from the bits of a
// single hash value. NOTE: Cannot be larger than 32 since we use a 32-bit hash.
constexpr uint32_t N_OUTPUTS_PER_THREAD = 32;
static_assert(N_OUTPUTS_PER_THREAD >= 1 && N_OUTPUTS_PER_THREAD <= 32,
              "N_OUTPUTS_PER_THREAD must be in [1, 32] since each output consumes one bit of a 32-bit hash.");



// Dense random-projection kernel. For each input row `vec_index` of `mat` and output column `c`:
//   out[vec_index, c] = (1 / sqrt(d_og)) * sum_k mat[vec_index, k] * bit_parity_as_pm1(h(c, k), c % N_OUTPUTS_PER_THREAD)
// where h(c, k) = fast_seeded_hash(seed, (c - c % N_OUTPUTS_PER_THREAD) * d_og + k). A single 32-bit hash
// therefore supplies the projection entries for N_OUTPUTS_PER_THREAD consecutive output columns.
// (bit_parity_as_pm1 is defined elsewhere; presumably it maps bit i of the hash to +/-1 — confirm.)
//
// Assumes `mat` is row-major contiguous (n_vecs, d_og) and `out` is row-major contiguous (n_vecs, d_proj),
// as required by the flat indexing below. TODO: confirm the caller enforces contiguity.
//
// Launch layout:
//   blockDim  = (N_THREADS_PER_BLOCK, 1, 1); the strided accumulation and the tree reduction assume this.
//   gridDim.x = 1.
//   gridDim.y >= 1: blocks stride over input rows, so any y extent up to the 65535 hardware limit works.
//   gridDim.z >= 1: blocks stride over groups of N_OUTPUTS_PER_THREAD output columns.
// Uses N_THREADS_PER_BLOCK * sizeof(scalar_t) bytes of static shared memory.
//
// __launch_bounds__ tells the compiler the kernel always runs with N_THREADS_PER_BLOCK threads, so it caps
// per-thread register use accordingly. Without it, the per-thread accumulator array (notably for double)
// could push register use past what a 1024-thread block allows, and the launch would fail.
template <typename scalar_t>
__global__ void __launch_bounds__(N_THREADS_PER_BLOCK) rp_v2_alg3_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    const uint32_t n_threads_per_block = N_THREADS_PER_BLOCK;
    const uint32_t n_outputs_per_thread = N_OUTPUTS_PER_THREAD;

    const uint64_t thread_index = threadIdx.x;
    const uint64_t n_col_groups = (d_proj + n_outputs_per_thread - 1) / n_outputs_per_thread;

    // Computed in double and then cast, so double outputs keep full precision (rsqrtf is single
    // precision only). Loop-invariant, so hoisted out of the loops.
    const scalar_t scale = static_cast<scalar_t>(rsqrt(static_cast<double>(d_og)));

    __shared__ scalar_t s_cache[n_threads_per_block];

    // All loop bounds below depend only on block-uniform values (blockIdx, gridDim, kernel arguments),
    // so every thread of the block reaches every __syncthreads() inside them.
    for (uint64_t vec_index = blockIdx.y; vec_index < n_vecs; vec_index += gridDim.y) {
        for (uint64_t col_group = blockIdx.z; col_group < n_col_groups; col_group += gridDim.z) {
            const uint64_t base_out_col_index = col_group * n_outputs_per_thread;

            // TODO: Can I use vector ops below for extra speed??

            // Per-thread partial sums for the N_OUTPUTS_PER_THREAD output columns of this group.
            scalar_t t_cache[n_outputs_per_thread] = {};
            for (uint64_t offset = thread_index; offset < d_og; offset += n_threads_per_block) {
                // Keep the input in scalar_t: the original `float` truncated double inputs.
                const scalar_t mat_value = mat[vec_index * d_og + offset];

                const uint64_t base_proj_mat_flat_index = base_out_col_index * d_og + offset;
                const uint32_t hash = sfrp::hashing::fast_seeded_hash(seed, base_proj_mat_flat_index);
                for (int i = 0; i < n_outputs_per_thread; i++) {
                    t_cache[i] += mat_value * bit_parity_as_pm1(hash, i);
                }
            }

            // TODO: Maybe do multiple of these at once.

            for (uint32_t i = 0; i < n_outputs_per_thread; i++) {
                const uint64_t out_col_index = base_out_col_index + i;
                // Block-uniform condition, so breaking here cannot desynchronize the barriers.
                if (out_col_index >= d_proj) { break; }

                s_cache[thread_index] = t_cache[i];
                __syncthreads();

                // Tree reduction over the block; requires n_threads_per_block to be a power of 2.
                // Threads with no input elements (when d_og < n_threads_per_block) contribute zeros.
                for (uint32_t j = n_threads_per_block / 2; j != 0; j /= 2) {
                    if (thread_index < j) {
                        s_cache[thread_index] += s_cache[thread_index + j];
                    }
                    __syncthreads();
                }

                // Only thread 0 ever writes s_cache[0], so reading it here cannot race with the
                // next iteration's writes by other threads.
                if (thread_index == 0) {
                    out[vec_index * d_proj + out_col_index] = s_cache[0] * scale;
                }
            }
        }
    }
}


// Host entry point: writes into `out` the dense random projection of the rows of `mat` computed by
// rp_v2_alg3_kernel, using `seed` to derive the projection entries.
//
// mat:  2-D tensor (n_vecs, d_og).
// out:  2-D tensor whose column count is d_proj; presumably (n_vecs, d_proj) with the same dtype as
//       `mat` — check_valid_mat_and_out is assumed to enforce this (defined elsewhere; confirm).
// Raises (via TORCH_CHECK) if d_og == 0, since the 1/sqrt(d_og) scaling is undefined there.
void rp_v2_alg3(at::Tensor mat, uint32_t seed, at::Tensor out) {
    // Make sure these tensors are valid and compatible.
    check_valid_mat_and_out(mat, out);

    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // Any positive d_og works: threads beyond d_og simply contribute zeros to the block reduction.
    TORCH_CHECK(d_og > 0, "The original dimension must be positive.");

    // Nothing to write. This also avoids launching with a zero grid dimension, which CUDA rejects.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    const int32_t n_threads = N_THREADS_PER_BLOCK;
    const int64_t n_outputs_per_thread = N_OUTPUTS_PER_THREAD;

    // gridDim.y and gridDim.z are limited to 65535 by the hardware. The kernel strides over rows and
    // column groups, so clamping here only changes how the work is distributed across blocks.
    const int64_t max_grid_dim_yz = 65535;
    const int64_t n_col_groups = CEIL_DIV(d_proj, n_outputs_per_thread);
    const dim3 blocks(
        1,
        static_cast<unsigned int>(n_vecs < max_grid_dim_yz ? n_vecs : max_grid_dim_yz),
        static_cast<unsigned int>(n_col_groups < max_grid_dim_yz ? n_col_groups : max_grid_dim_yz)
    );

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "rp_v2_alg3", ([&] {
        CURRY_KERNEL_CALL(rp_v2_alg3_kernel, blocks, n_threads)(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));
}



}  // dense
}  // bernoulli
}  // sfrp
