#include "./fwd_decls.cc.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cuda/bernoulli/bernoulli_util.h>
#include <cuda/util/args_validation.h>
#include <cuda/util/misc_util.h>
#include <cuda/util/device/reductions.h>

#include <cuda/util/editor_hack.h>

namespace sfrp {
namespace bernoulli {
namespace sparse {


// Number of lanes the kernel below treats as one warp when splitting threadIdx.x
// into a (warp index, lane) pair.
constexpr uint32_t WARP_SIZE = 32;

// Block size used for every launch in this file; it is also the template argument
// of the block-wide reduction, so the two must stay in sync.
constexpr uint32_t N_THREADS_PER_BLOCK = 1024;


// Computes one element of the projected output per thread block.
//
// Expected launch layout (see rp_v1_alg3):
//   - blockDim.x == N_THREADS_PER_BLOCK (the block-wide reduction is instantiated for that size),
//   - blockIdx.y selects the input row of `mat`,
//   - blockIdx.z selects the output column of `out`.
//
// Parameters:
//   mat                 Input values, read at `row * d_og + offset` (i.e. rows are assumed contiguous).
//   out                 Output values, written at `row * d_proj + column`.
//   seed                Seed passed to curand_init for every thread.
//   n_vecs              Number of input rows. Not read by the kernel; the grid's y-extent plays that role.
//   d_og                Length of each input row.
//   d_proj              Length of each output row.
//   sparse_region_size  Span (in elements) that the per-warp offset advances by on every iteration.
//
// Preconditions (not checked on the device):
//   - sparse_region_size >= N_THREADS_PER_BLOCK. Otherwise n_blocks_per_sparse_region is 0, which makes
//     both the modulo and the final division divide by zero (and a size of 0 would never advance the loop).
//
// The cuRAND sequence number depends only on blockIdx.z and threadIdx.x, not on blockIdx.y, so every input
// row is combined with the same random stream for a given output column.
template <typename scalar_t, typename curandState_t>
__global__ void rp_v1_alg3_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj,
    uint64_t sparse_region_size
) {
    const uint32_t n_threads_per_block = N_THREADS_PER_BLOCK;
    const uint32_t n_blocks_per_sparse_region = sparse_region_size / N_THREADS_PER_BLOCK;

    // These depend on the specific way that the blocks are constructed by the host launcher.
    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;
    const uint64_t out_col_index = blockIdx.z;

    // Also depends on the launch layout: one independent random sequence per (output column, thread).
    const uint64_t sequence_number = blockIdx.z * n_threads_per_block + thread_index;

    const uint32_t warp_index = thread_index / WARP_SIZE;
    const uint32_t warp_lane = thread_index % WARP_SIZE;


    curandState_t state;
    curand_init(seed, sequence_number, /*offset=*/ 0, &state);


    scalar_t thread_acc = 0.0f;
    uint64_t warp_base_offset = warp_index * WARP_SIZE;

    // warp_base_offset is identical for all lanes of a warp, so every lane of a warp executes the same
    // number of iterations and the full-mask shuffle below is reached by all 32 lanes together.
    while (warp_base_offset < d_og) {

        // Lane 0 draws the chunk index for the whole warp; the other lanes receive it via the shuffle.
        // Initialised to 0 so that no lane hands an indeterminate register to __shfl_sync.
        uint32_t warp_sparse_region_offset_in_blocks = 0;
        if (warp_lane == 0) {
            warp_sparse_region_offset_in_blocks = curand(&state) % n_blocks_per_sparse_region;
        }
        warp_sparse_region_offset_in_blocks = __shfl_sync(0xffffffff, warp_sparse_region_offset_in_blocks, 0);

        // Widen before multiplying so that the element offset is computed in 64 bits.
        const uint64_t warp_sparse_region_offset =
            static_cast<uint64_t>(warp_sparse_region_offset_in_blocks) * static_cast<uint64_t>(n_threads_per_block);
        const uint64_t thread_offset = warp_base_offset + warp_sparse_region_offset + warp_lane;

        // The chosen chunk may straddle the end of the row; lanes past the end contribute nothing.
        if (thread_offset < d_og) {
            const scalar_t mat_value = mat[vec_index * d_og + thread_offset];
            const uint32_t rand_value = curand(&state);
            // parity_as_pm1 is defined elsewhere; presumably it maps the random word to +1 or -1.
            thread_acc += mat_value * parity_as_pm1(rand_value);
        }

        warp_base_offset += sparse_region_size;
    }

    // Keep the reduced value in scalar_t: storing it in a float would silently drop precision whenever the
    // kernel is instantiated for double.
    const scalar_t output_value = device::reduce_scalar_across_block<scalar_t, N_THREADS_PER_BLOCK>(thread_acc);
    if (thread_index == 0) {
        const uint64_t out_flat_index = vec_index * d_proj + out_col_index;
        // TODO: See if there is some cheap way to correct the norm of the random projection column
        // to correct for when a chunk straddles the end of the vector.
        // ::rsqrt is overloaded for float and double, so the scale is evaluated in scalar_t precision
        // (for float this is the same rsqrtf call as before).
        const scalar_t scale = ::rsqrt(static_cast<scalar_t>(d_og / n_blocks_per_sparse_region));
        out[out_flat_index] = output_value * scale;
    }
}


// Host-side launcher for rp_v1_alg3_kernel.
//
// Launches one block of N_THREADS_PER_BLOCK threads per (row of `mat`, column of `out`) pair, using
// grid = (1, mat.size(0), out.size(1)).
//
// Parameters:
//   mat                 Input tensor; size(0) is the number of vectors and size(1) the original dimension.
//   seed                Seed forwarded to the kernel's cuRAND initialisation.
//   sparse_region_size  Must be a positive multiple of N_THREADS_PER_BLOCK.
//   out                 Output tensor; size(1) is the projected dimension. Its dtype selects the kernel
//                       instantiation.
//
// Raises (via TORCH_CHECK) when the original dimension is smaller than the block size, when
// sparse_region_size is zero or not a multiple of the block size, or when the number of vectors or the
// projected dimension exceeds the CUDA limit for gridDim.y / gridDim.z. check_valid_mat_and_out is defined
// elsewhere and presumably raises on incompatible tensors.
template <typename curandState_t>
void rp_v1_alg3(at::Tensor mat, uint32_t seed, uint64_t sparse_region_size, at::Tensor out) {
    // Make sure these tensors are valid and compatible.
    check_valid_mat_and_out(mat, out);

    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // TODO: It's probably possible to support smaller dimensions by taking the largest power of two that is
    // smaller or equal to than d_og. My current use-case involves dimensions much greater than this, so I'm
    // not supporting this now.
    const int32_t n_threads = N_THREADS_PER_BLOCK;
    TORCH_CHECK(d_og >= n_threads, "The original dimension must be greater than or equal to the number of threads.");

    // A size of zero passes the divisibility test below but makes the kernel divide by zero and loop forever.
    TORCH_CHECK(sparse_region_size > 0, "The sparse region size must be positive.");
    TORCH_CHECK((sparse_region_size % n_threads) == 0, "The sparse region size must be a multiple of the number of threads.");

    // n_vecs and d_proj become gridDim.y and gridDim.z, which CUDA caps at 65535 each. Exceeding the cap
    // makes the launch fail, so reject it up front with a readable message.
    constexpr int64_t max_grid_dim_yz = 65535;
    TORCH_CHECK(n_vecs <= max_grid_dim_yz, "The number of vectors must not exceed ", max_grid_dim_yz, ".");
    TORCH_CHECK(d_proj <= max_grid_dim_yz, "The projected dimension must not exceed ", max_grid_dim_yz, ".");

    // Nothing to compute, and a grid extent of zero is not a valid launch configuration.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    const dim3 blocks(1, static_cast<unsigned int>(n_vecs), static_cast<unsigned int>(d_proj));

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "rp_v1_alg3", ([&] {
        CURRY_KERNEL_CALL((rp_v1_alg3_kernel<scalar_t, curandState_t>), blocks, n_threads)(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj,
            sparse_region_size
        );
    }));
}



// Non-template entry point: runs rp_v1_alg3 with cuRAND's XORWOW generator state.
// All four arguments are forwarded unchanged.
void rp_v1_xorwow_alg3(at::Tensor mat, uint32_t seed, uint64_t sparse_region_size, at::Tensor out) {
    using xorwow_state_t = curandStateXORWOW_t;
    rp_v1_alg3<xorwow_state_t>(mat, seed, sparse_region_size, out);
}



}  // sparse
}  // bernoulli
}  // sfrp












// // NOTE: n_outputs_per_thread cannot be larger than 32 since curand generates 32 bits per call.
// // NOTE: Values that are too large can cause the kernel to fail (potentially silently).
// template <typename scalar_t, typename curandState_t>
// __global__ void rp_v1_alg3_kernel(
//     const scalar_t* mat,
//     scalar_t* out,
//     uint32_t seed,
//     size_t n_vecs,
//     size_t d_og,
//     size_t d_proj,
//     uint64_t sparse_region_size
// ) {
//     const uint32_t n_threads_per_block = N_THREADS_PER_BLOCK;
//     const uint32_t n_blocks_per_sparse_region = sparse_region_size / N_THREADS_PER_BLOCK;

//     // These depend on the specific way that I constructed the blocks.
//     const uint64_t thread_index = threadIdx.x;
//     const uint64_t vec_index = blockIdx.y;
//     const uint64_t out_col_index = blockIdx.z;

//     // Also depends on the specific way that I constructed the blocks.
//     const uint64_t sequence_number = blockIdx.z * n_threads_per_block + thread_index;

//     const uint32_t warp_index = thread_index / WARP_SIZE;
//     const uint32_t warp_lane = thread_index % WARP_SIZE;


//     curandState_t state;
//     curand_init(seed, sequence_number, /*offset=*/ 0, &state);


//     scalar_t thread_acc = 0.0f;
//     uint64_t warp_base_offset = warp_index * WARP_SIZE;

//     while (warp_base_offset < d_og) {

//         uint32_t warp_sparse_region_offset_in_blocks;
//         if(warp_lane == 0) {
//             warp_sparse_region_offset_in_blocks = curand(&state) % n_blocks_per_sparse_region;
//         }
//         warp_sparse_region_offset_in_blocks = __shfl_sync(0xffffffff, warp_sparse_region_offset_in_blocks, 0);

//         // NOTE: I'm too lazy to see if the casts to uint64_t are needed to prevent possible overflows.
//         uint64_t warp_sparse_region_offset = ((uint64_t) warp_sparse_region_offset_in_blocks) * ((uint64_t) n_threads_per_block);
//         // uint64_t thread_offset = warp_base_offset + warp_sparse_region_offset + thread_index;
//         uint64_t thread_offset = warp_base_offset + warp_sparse_region_offset + warp_lane;

//         if (thread_offset < d_og) {
//             scalar_t mat_value = mat[vec_index * d_og + thread_offset];
//             uint32_t rand_value = curand(&state);
//             thread_acc += mat_value * parity_as_pm1(rand_value);
//         }

//         warp_base_offset += sparse_region_size;
//     }

//     // uint32_t rand_value;
//     // uint8_t bit_counter = 0;

//     // while (warp_base_offset < d_og) {

//     //     uint32_t warp_sparse_region_offset_in_blocks;
//     //     if(warp_lane == 0) {
//     //         warp_sparse_region_offset_in_blocks = curand(&state) % n_blocks_per_sparse_region;
//     //     }
//     //     warp_sparse_region_offset_in_blocks = __shfl_sync(0xffffffff, warp_sparse_region_offset_in_blocks, 0);

//     //     // NOTE: I'm too lazy to see if the casts to uint64_t are needed to prevent possible overflows.
//     //     uint64_t warp_sparse_region_offset = ((uint64_t) warp_sparse_region_offset_in_blocks) * ((uint64_t) n_threads_per_block);
//     //     uint64_t thread_offset = warp_base_offset + warp_sparse_region_offset + thread_index;

//     //     if (thread_offset < d_og) {
//     //         scalar_t mat_value = mat[vec_index * d_og + thread_offset];
//     //         if (bit_counter == 0) {
//     //             rand_value = curand(&state);
//     //         }
//     //         thread_acc += mat_value * bit_parity_as_pm1(rand_value, bit_counter);
//     //         bit_counter = (bit_counter + 1) % 32;;
//     //     }

//     //     warp_base_offset += sparse_region_size;
//     // }


//     // float output_value = reduce_across_block(thread_acc);
//     float output_value = device::reduce_scalar_across_block<scalar_t, N_THREADS_PER_BLOCK>(thread_acc);
//     if (thread_index == 0) {
//         uint64_t out_flat_index = vec_index * d_proj + out_col_index;
//         // TODO: See if there is some cheap way to correct the norm of the random projection column
//         // to correct for when a chunk straddles the end of the vector.
//         out[out_flat_index] = output_value * rsqrtf((scalar_t) (d_og / n_blocks_per_sparse_region));
//     }
// }

