#include "./fwd_decls.cc.h"
#include "./dn_v3_constants.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cuda/bernoulli/bernoulli_util.h>
#include <cuda/util/args_validation.h>
#include <cuda/util/misc_util.h>
#include <cuda/util/device/reductions.h>

#include <cuda/util/editor_hack.h>


namespace sfrp {
namespace bernoulli {
namespace dense {

// Lanes per warp as a compile-time constant. The built-in `warpSize` is a runtime
// value and cannot size the static shared array or serve as a template default below.
constexpr uint32_t WARP_SIZE = 32;


/**
 * Sums `thread_value` over each aligned group of `warp_size` lanes of the calling warp.
 *
 * Uses an XOR butterfly (`__shfl_xor_sync`), so when it finishes every lane of a group
 * holds the group's full sum, not only the first lane.
 *
 * The shuffle mask is 0xffffffff, so all 32 lanes of the warp must execute this call.
 * `warp_size` is both the butterfly span and the shuffle `width`; it must be a power of
 * two no larger than 32.
 */
template <typename scalar_t, uint32_t warp_size = WARP_SIZE>
__device__ __forceinline__ scalar_t reduce_scalar_within_warp(
    scalar_t thread_value
) {
    scalar_t partial_sum = thread_value;

    // Halve the exchange distance each step: warp_size/2, warp_size/4, ..., 1.
    int lane_mask = warp_size / 2;
    while (lane_mask >= 1) {
        partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, lane_mask, warp_size);
        lane_mask /= 2;
    }

    return partial_sum;
}


/**
 * Sums `thread_acc` over a group of `n_warps` consecutive warps and returns the total
 * to every thread of that group.
 *
 * Stage 1: each warp reduces its lanes with shuffles. Lane 0 then stores the warp's
 * partial sum in `shared_buffer[warp_index]`.
 * Stage 2: a power-of-two tree reduction over the `n_warps` partials leaves the group
 * total in `shared_buffer[0]`.
 *
 * Contains `__syncthreads()`, so EVERY thread of the block must call this, with the
 * same `n_warps`, the same number of times. Several independent groups may reduce
 * concurrently as long as each passes its own disjoint `shared_buffer` slice.
 *
 * @param thread_acc     This thread's contribution.
 * @param warp_lane      Lane of the calling thread within its warp (0..warp_size-1).
 * @param warp_index     Index of the calling thread's warp within its group (0..n_warps-1).
 * @param shared_buffer  Block-shared scratch with one slot per warp of the group.
 * @param n_warps        Number of warps in the group; must be a power of 2 (not checked).
 * @return               The sum of `thread_acc` over all threads of the group.
 */
template <typename scalar_t, uint32_t warp_size = WARP_SIZE>
__device__ __forceinline__ scalar_t reduce_scalar_across_threads_using_warps(
    scalar_t thread_acc,
    uint32_t warp_lane,
    uint32_t warp_index,
    // The shared buffer has a slot for each warp participating in the reduction.
    scalar_t* shared_buffer,
    // NOTE: n_warps must be a power of 2. This is NOT checked.
    uint32_t n_warps
) {
    const scalar_t warp_acc = reduce_scalar_within_warp<scalar_t, warp_size>(thread_acc);

    if (warp_lane == 0) {
        shared_buffer[warp_index] = warp_acc;
    }
    __syncthreads();

    // Tree reduction over the per-warp partials (requires n_warps to be a power of 2).
    // Only lane 0 of each active warp updates its slot. Letting all 32 lanes do the same
    // read-modify-write on one address is a data race under independent thread
    // scheduling: a lane may read the slot after another lane already added to it.
    for (uint32_t stride = n_warps / 2; stride != 0; stride /= 2) {
        if (warp_lane == 0 && warp_index < stride) {
            shared_buffer[warp_index] += shared_buffer[warp_index + stride];
        }
        __syncthreads();
    }

    const scalar_t group_total = shared_buffer[0];
    // Ensure every thread has read the total before any warp can overwrite
    // shared_buffer (e.g. in the caller's next reduction using the same slots).
    __syncthreads();
    return group_total;
}


/**
 * Transposed dense random projection kernel ("v3, alg1").
 *
 * For vector v = blockIdx.y and every output row r < d_og it computes
 *     out[v, r] = (1 / sqrt(d_og)) * sum_c mat[v, c] * sign(r, c),   c in [0, d_proj)
 * `sign(r, c)` comes from `bit_parity_as_pm1` applied to a per-thread cuRAND draw
 * (presumably +/-1; see that helper).
 *
 * Launch layout (as set up by trp_rp_v3_alg1):
 *   blockDim = (v3::N_THREADS_PER_BLOCK), gridDim = (1, n_vecs, N_THREADS_PER_BLOCK / n_output_rows_per_block).
 *   `n_threads_per_output_row = d_proj / n_outputs_per_og_thread` consecutive threads
 *   cooperate on one output row. Each handles `n_outputs_per_og_thread` consecutive
 *   columns of mat. A block processes `n_output_rows_per_block` rows at once and then
 *   advances by N_THREADS_PER_BLOCK rows. This stride assumes
 *   gridDim.z * n_output_rows_per_block == N_THREADS_PER_BLOCK.
 *
 * RNG: each (og_thread_index, column chunk) pair owns one cuRAND subsequence. One
 * `curand()` draw is consumed per row visited, and column i of the chunk uses
 * `bit_parity_as_pm1(rand_value, i)`.
 *
 * Preconditions (validated on the host): d_proj is a power of two, a multiple of
 * n_outputs_per_og_thread, and <= N_THREADS_PER_BLOCK * n_outputs_per_og_thread.
 * Assumes N_THREADS_PER_BLOCK is a multiple of WARP_SIZE.
 * Shared memory: static, N_THREADS_PER_BLOCK / WARP_SIZE scalars.
 * `n_vecs` is unused; the vector index comes from blockIdx.y.
 */
template <typename scalar_t, typename curandState_t, uint32_t n_outputs_per_og_thread>
__global__ void trp_rp_v3_alg1_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    const uint32_t n_threads_per_block = v3::N_THREADS_PER_BLOCK;
    const uint32_t n_output_rows_per_block = (n_threads_per_block * n_outputs_per_og_thread) / d_proj;

    // Number of consecutive threads that cooperate on one output row.
    const uint32_t n_threads_per_output_row = d_proj / n_outputs_per_og_thread;

    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;

    const scalar_t* mat_row_start = mat + vec_index * d_proj;
    scalar_t* output_row_start = out + vec_index * d_og;

    // Which of the block's concurrent rows this thread works on, and its slot within that row.
    const uint32_t output_row_index_offset = thread_index / n_threads_per_output_row;
    const uint32_t og_output_row_thread_index = thread_index % n_threads_per_output_row;

    // Index of the "original" thread whose rows this thread visits; together with the
    // column chunk it selects an independent cuRAND subsequence.
    const uint32_t og_thread_index = blockIdx.z * n_output_rows_per_block + output_row_index_offset;
    const uint64_t sequence_number = og_output_row_thread_index * n_threads_per_block + og_thread_index;

    // Rows spanning whole warps are reduced through shared memory. Narrower rows occupy
    // an aligned sub-group of one warp (both sizes are powers of two because d_proj is),
    // so they are reduced with shuffles alone. Before this split, narrow rows were summed
    // across the whole warp and all rows collided on s_cache[0].
    const bool row_spans_full_warps = n_threads_per_output_row >= WARP_SIZE;
    const uint32_t n_warps_per_output_row = n_threads_per_output_row / WARP_SIZE;
    const uint32_t og_output_row_warp_index = og_output_row_thread_index / WARP_SIZE;
    const uint32_t warp_lane = thread_index % WARP_SIZE;

    // One partial-sum slot per warp of the block, partitioned between the concurrent rows.
    __shared__ scalar_t s_cache[n_threads_per_block / WARP_SIZE];
    scalar_t* shared_output_row_buffer = s_cache + n_warps_per_output_row * output_row_index_offset;

    // Loop-invariant output scale. Computed in double so the double instantiation is not
    // truncated to float precision as it would be with rsqrtf.
    const scalar_t output_scale = static_cast<scalar_t>(rsqrt(static_cast<double>(d_og)));

    curandState_t state;
    curand_init(seed, sequence_number, /*offset=*/ 0, &state);

    // The loop bound depends only on blockIdx.z, so all threads of the block run the same
    // number of iterations and reach the barriers inside the reduction together.
    for (uint64_t block_base_output_row_offset = blockIdx.z * n_output_rows_per_block;
         block_base_output_row_offset < d_og;
         block_base_output_row_offset += n_threads_per_block) {
        const uint64_t output_row_offset = block_base_output_row_offset + output_row_index_offset;
        const bool row_in_range = output_row_offset < d_og;

        scalar_t acc = 0;
        if (row_in_range) {
            const uint32_t rand_value = curand(&state);
            const scalar_t* mat_chunk = mat_row_start + og_output_row_thread_index * n_outputs_per_og_thread;
            #pragma unroll
            for (int i = 0; i < static_cast<int>(n_outputs_per_og_thread); i++) {
                acc += mat_chunk[i] * bit_parity_as_pm1(rand_value, i);
            }
        }

        scalar_t output_value;
        if (row_spans_full_warps) {
            output_value = reduce_scalar_across_threads_using_warps<scalar_t, WARP_SIZE>(
                acc,
                warp_lane, og_output_row_warp_index,
                shared_output_row_buffer, n_warps_per_output_row
            );
        } else {
            // XOR butterfly confined to the row's aligned group of n_threads_per_output_row
            // lanes (the shuffle `width`). All 32 lanes execute it, including those whose
            // row is out of range (they contribute acc == 0).
            output_value = acc;
            for (uint32_t lane_mask = n_threads_per_output_row / 2; lane_mask > 0; lane_mask /= 2) {
                output_value += __shfl_xor_sync(0xffffffff, output_value, lane_mask, n_threads_per_output_row);
            }
        }

        if (row_in_range && og_output_row_thread_index == 0) {
            output_row_start[output_row_offset] = output_value * output_scale;
        }
    }
}



// template <typename scalar_t, typename curandState_t, uint32_t n_outputs_per_og_thread>
// __global__ void trp_rp_v3_alg1_kernel(
//     const scalar_t* mat,
//     scalar_t* out,
//     uint32_t seed,
//     size_t n_vecs,
//     size_t d_og,
//     size_t d_proj
// ) {
//     const uint32_t n_threads_per_block = v3::N_THREADS_PER_BLOCK;
//     const uint32_t n_output_rows_per_block = (n_threads_per_block * n_outputs_per_og_thread) / d_proj;


//     const uint32_t n_threads_per_output_row = d_proj / n_outputs_per_og_thread;


//     const uint64_t thread_index = threadIdx.x;
//     const uint64_t vec_index = blockIdx.y;


//     const scalar_t* mat_row_start = mat + vec_index * d_proj;
//     scalar_t* output_row_start = out + vec_index * d_og;


//     const uint32_t output_row_index_offset = thread_index / n_threads_per_output_row;
//     const uint32_t og_output_row_thread_index = thread_index % n_threads_per_output_row;


//     const uint32_t og_thread_index = blockIdx.z * n_output_rows_per_block + output_row_index_offset;

//     const uint64_t sequence_number = og_output_row_thread_index * n_threads_per_block + og_thread_index;


//     __shared__ scalar_t s_cache[n_threads_per_block];
//     scalar_t* shared_output_row_buffer = s_cache + n_threads_per_output_row * output_row_index_offset;

//     curandState_t state;
//     curand_init(seed, sequence_number, /*offset=*/ 0, &state);

//     uint64_t block_base_output_row_offset = blockIdx.z * n_output_rows_per_block;
//     while (block_base_output_row_offset < d_og) {
//         scalar_t acc = 0.0;
//         uint64_t output_row_offset = block_base_output_row_offset + output_row_index_offset;

//         if (output_row_offset < d_og) {
//             const uint32_t rand_value = curand(&state);
//             for(int i = 0; i < n_outputs_per_og_thread; i++) {
//                 scalar_t mat_value = mat_row_start[og_output_row_thread_index * n_outputs_per_og_thread + i];
//                 acc += mat_value * bit_parity_as_pm1(rand_value, i);
//             }

//         }

//         scalar_t output_value = device::reduce_scalar_across_threads(
//             acc, og_output_row_thread_index,
//             shared_output_row_buffer, n_threads_per_output_row
//         );

//         if (output_row_offset < d_og && og_output_row_thread_index == 0) {
//             output_row_start[output_row_offset] = output_value * rsqrtf((scalar_t) d_og);
//         }

//         block_base_output_row_offset += n_threads_per_block;
//     }
// }



/**
 * Host launcher for trp_rp_v3_alg1_kernel.
 *
 * Reads `mat` as (n_vecs, d_proj) and writes `out` as (n_vecs, d_og). `check_valid_trp_mat_and_out`
 * is expected to enforce shapes and dtypes (not verified here).
 *
 * @tparam curandState_t            cuRAND generator state type used by the kernel.
 * @tparam n_outputs_per_og_thread  Number of consecutive columns of `mat` handled per kernel thread.
 * @param mat   Input tensor; size(0) = number of vectors, size(1) = d_proj.
 * @param seed  Seed passed to curand_init.
 * @param out   Output tensor; size(1) = d_og. Every element is overwritten.
 * Errors are raised via TORCH_CHECK for unsupported sizes.
 */
template <typename curandState_t, uint32_t n_outputs_per_og_thread>
void trp_rp_v3_alg1(at::Tensor mat, uint32_t seed, at::Tensor out) {
    // Make sure these tensors are valid and compatible.
    check_valid_trp_mat_and_out(mat, out);

    // NOTE: The mat and out tensors will have the opposite number of columns as in the
    // non-transposed version.
    const auto n_vecs = mat.size(0);
    const auto d_proj = mat.size(1);
    const auto d_og = out.size(1);

    const int32_t n_threads = v3::N_THREADS_PER_BLOCK;
    TORCH_CHECK(d_og >= n_threads, "The original dimension must be at least the number of threads.");

    // TODO: Support cases where d_proj > max_d_proj. Either use some temporary storage or atomic operations.
    const int32_t max_d_proj = n_threads * n_outputs_per_og_thread;
    TORCH_CHECK(d_proj <= max_d_proj, "TODO: The projected dimension is larger than what currently can be handled.");

    // TODO: Only supporting power of 2 d_proj in the first pass.
    TORCH_CHECK(is_power_of_2(d_proj), "TODO: Support non-power of 2 d_proj.");

    // The kernel gives each thread a whole chunk of n_outputs_per_og_thread columns. With
    // fewer columns (or a ragged last chunk), its threads-per-row count would be 0 or wrong,
    // and the kernel would divide by zero.
    TORCH_CHECK(
        d_proj >= n_outputs_per_og_thread && d_proj % n_outputs_per_og_thread == 0,
        "The projected dimension must be a positive multiple of the number of outputs per thread."
    );

    // Nothing to project. Launching with a zero grid dimension would be an invalid configuration.
    if (n_vecs == 0) {
        return;
    }
    // Each vector gets its own grid row along y, and CUDA caps gridDim.y at 65535.
    TORCH_CHECK(n_vecs <= 65535, "The number of vectors must be at most 65535 (CUDA grid y-dimension limit).");

    // gridDim.z * n_output_rows_per_block == n_threads, which the kernel's row stride relies on
    // (both factors are powers of two here, so CEIL_DIV divides exactly).
    const auto n_output_rows_per_block = (n_threads * n_outputs_per_og_thread) / d_proj;
    const dim3 blocks(1, n_vecs, CEIL_DIV(n_threads, n_output_rows_per_block));

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "trp_rp_v3_alg1", ([&] {
        CURRY_KERNEL_CALL((trp_rp_v3_alg1_kernel<scalar_t, curandState_t, n_outputs_per_og_thread>), blocks, n_threads)(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));
}


// XORWOW-backed entry point. It instantiates the generic transposed random projection
// launcher (trp_rp_v3_alg1) with cuRAND's XORWOW state. It passes
// v3::XORWOW_N_OUTPUTS_PER_THREAD as the number of columns each kernel thread handles.
void trp_rp_v3_xorwow_alg1(at::Tensor mat, uint32_t seed, at::Tensor out) {
    using xorwow_state_t = curandStateXORWOW_t;
    trp_rp_v3_alg1<xorwow_state_t, v3::XORWOW_N_OUTPUTS_PER_THREAD>(mat, seed, out);
}



}  // dense
}  // bernoulli
}  // sfrp
