#include <torch/extension.h>

#include <algorithm>
#include <cuda.h>
#include <cuda_runtime.h>

#include "./sfrp.fwd_decl.h"

// Integer ceiling division, intended for positive integral operands.
// The expansion is fully parenthesized so it remains a single operand when used inside a larger
// expression (e.g. `x / CEIL_DIV(a, b)`). Note that `b` is evaluated twice.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))


// 64-bit integer mixer: two xor-shift / multiply rounds followed by a final xor-shift.
// Source: https://stackoverflow.com/a/12996028
__device__ __forceinline__ uint64_t fast_hash64(uint64_t x) {
    x ^= x >> 30;
    x *= UINT64_C(0xbf58476d1ce4e5b9);
    x ^= x >> 27;
    x *= UINT64_C(0x94d049bb133111eb);
    return x ^ (x >> 31);
}

// 32-bit integer mixer: two xor-shift / multiply rounds followed by a final xor-shift.
// Source: https://stackoverflow.com/a/12996028
__device__ __forceinline__ uint32_t fast_hash32(uint32_t x) {
    x ^= x >> 16;
    x *= 0x45d9f3b;
    x ^= x >> 16;
    x *= 0x45d9f3b;
    x ^= x >> 16;
    return x;
}


// Seeded 32-bit hash of a 64-bit key: the key is mixed with fast_hash64, truncated to its low
// 32 bits, xor-ed with the seed and mixed again with fast_hash32.
// Ad-hoc construction; its statistical quality has not been verified.
__device__ __forceinline__ uint32_t fast_seeded_hash(uint32_t seed, uint64_t x) {
    const uint32_t folded_key = (uint32_t) fast_hash64(x);
    const uint32_t seeded_key = folded_key ^ seed;
    return fast_hash32(seeded_key);
}

///////////////////////////////////////////////////////////////////////////////

// Maps the lowest bit of `hash` to a sign: +1.0f when the bit is set, -1.0f otherwise.
__device__ __forceinline__ float pm_1_from_hash(uint32_t hash) {
    const bool low_bit_set = (hash & 1u) != 0u;
    return low_bit_set ? 1.0f : -1.0f;
}

// Unscaled (+1 / -1) projection-matrix entry for the given seed and flat matrix index:
// the sign is taken from the lowest bit of the seeded hash of the index.
__device__ __forceinline__ float compute_rank_proj_hypercubic_mat_entry_v01(uint32_t seed, uint64_t flat_index) {
    return pm_1_from_hash(fast_seeded_hash(seed, flat_index));
}


///////////////////////////////////////////////////////////////////////////////


// Materializes the projection matrix: out[row * d_og + col] = (+/-1) / sqrt(d_og), with the sign
// derived from (seed, flat index).
//
// Launch layout used by this kernel:
//   - blockIdx.x * blockDim.x + threadIdx.x -> column within a row (guarded against d_og)
//   - blockIdx.y                            -> row index
// `out` is indexed as a densely packed row-major buffer with row length d_og.
template <typename scalar_t>
__global__ void cuda_make_rand_proj_hypercubic_mat_v01_kernel(scalar_t* out, uint32_t seed, size_t d_og) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in 32 bits and wraps
    // around once the column index reaches 2^32.
    const uint64_t column_index = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (column_index >= d_og) {
        return;
    }

    const uint64_t flat_index = static_cast<uint64_t>(blockIdx.y) * d_og + column_index;
    const float scale = rsqrtf((float) d_og);
    out[flat_index] = scale * compute_rank_proj_hypercubic_mat_entry_v01(seed, flat_index);
}


// Fills `out` (indexed as [d_proj, d_og], read from out.size(0) / out.size(1)) with the explicit
// projection matrix for `seed`.
//
// One thread is launched per matrix entry: grid.x covers the columns in chunks of 1024 threads and
// grid.y is the row index, so d_proj is bounded by the device's grid.y limit. A launch that violates
// this is reported through the launch check below instead of being silently dropped.
void cuda_make_rand_proj_hypercubic_mat_v01(torch::Tensor out, uint32_t seed) {
    const auto d_proj = out.size(0);
    const auto d_og = out.size(1);

    // Nothing to fill; a zero-sized grid would also be an invalid launch configuration.
    if (d_proj == 0 || d_og == 0) {
        return;
    }

    const int32_t n_threads = 1024;
    const dim3 blocks((d_og + n_threads - 1) / n_threads, d_proj);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "cuda_make_rand_proj_hypercubic_mat_v01", ([&] {
        cuda_make_rand_proj_hypercubic_mat_v01_kernel<scalar_t><<<blocks, n_threads>>>(
            out.data_ptr<scalar_t>(), seed, d_og);
    }));

    // Kernel launches do not return a status; configuration errors only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "cuda_make_rand_proj_hypercubic_mat_v01 kernel launch failed: ",
        cudaGetErrorString(launch_status));
}


///////////////////////////////////////////////////////////////////////////////
// https://siboehm.com/articles/22/CUDA-MMM



// Naive projection kernel: one thread computes one output element.
//
//   out[vec, col] = rsqrt(d_og) * sum_i mat[vec, i] * P[col, i]
//
// where P[col, i] is the +/-1 entry generated on the fly from (seed, col * d_og + i).
//
// Launch layout used by this kernel:
//   - blockIdx.x * blockDim.x + threadIdx.x -> output column (guarded against d_proj)
//   - blockIdx.y * blockDim.y + threadIdx.y -> input row / vector index (guarded against n_vecs)
// `mat` and `out` are indexed as densely packed row-major buffers with row lengths d_og and d_proj.
template <typename scalar_t>
__global__ void cuda_rand_proj_hypercubic_v01_alg01_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    // Widen before multiplying so the block offset is not computed (and wrapped) in 32 bits.
    const uint64_t out_col_index = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    // vec_index is the row index.
    const uint64_t vec_index = static_cast<uint64_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    if (out_col_index >= d_proj || vec_index >= n_vecs) {
        return;
    }

    // Loop-invariant base offsets of the input row and of the generated projection row.
    const uint64_t og_row_base = vec_index * d_og;
    const uint64_t proj_row_base = out_col_index * d_og;

    scalar_t acc = 0.0;
    for (uint64_t i = 0; i < d_og; i++) {
        acc += mat[og_row_base + i] * compute_rank_proj_hypercubic_mat_entry_v01(seed, proj_row_base + i);
    }

    const uint64_t out_flat_index = vec_index * d_proj + out_col_index;
    out[out_flat_index] = acc * rsqrtf((float) d_og);
}


// Projects every row of `mat` with the seeded +/-1 matrix, one thread per output element.
//
// mat.shape = [n_vecs, d_og]
// out.shape = [n_vecs, d_proj]
// Here, assumes that everything has been already validated.
//
// grid.x covers the output columns in chunks of up to 1024 threads and grid.y is the vector index,
// so n_vecs is bounded by the device's grid.y limit; a violation is reported by the launch check.
void cuda_rand_proj_hypercubic_v01_alg01(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // Empty output: nothing to compute. This also avoids n_threads == 0 (an integer division by
    // zero in the grid computation below) and a zero-sized grid.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    const int32_t n_threads = std::min(d_proj, (int64_t) 1024);
    const dim3 blocks((d_proj + n_threads - 1) / n_threads, n_vecs);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "cuda_rand_proj_hypercubic_v01_alg01", ([&] {
        cuda_rand_proj_hypercubic_v01_alg01_kernel<scalar_t><<<blocks, n_threads>>>(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));

    // Kernel launches do not return a status; configuration errors only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "cuda_rand_proj_hypercubic_v01_alg01 kernel launch failed: ",
        cudaGetErrorString(launch_status));
}



///////////////////////////////////////////////////////////////////////////////
// https://github.com/jiekebo/CUDA-By-Example/blob/master/5-dotproduct.cu


// Block size the alg02 kernel is written for. The tree reduction in the kernel halves the active
// range each step, so this must be a power of two.
const uint32_t HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK = 1024;

static_assert(
    HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK != 0 &&
        (HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK & (HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK - 1)) == 0,
    "HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK must be a power of two");


// Dot-product style projection kernel: one block computes one output element.
//
// Each thread accumulates a strided partial sum over the d_og axis, the partial sums are combined
// with a shared-memory tree reduction, and thread 0 writes the scaled result.
//
// Launch layout used by this kernel:
//   - threadIdx.x -> lane within the block; blockDim.x must equal
//                    HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK (the shared buffer is sized for it)
//   - blockIdx.y  -> input row / vector index
//   - blockIdx.z  -> output column
// `mat` and `out` are indexed as densely packed row-major buffers with row lengths d_og and d_proj.
template <typename scalar_t>
__global__ void cuda_rand_proj_hypercubic_v01_alg02_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    const uint32_t n_threads_per_block = HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK;

    // These depend on the specific way that the blocks are constructed by the launcher.
    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;
    const uint64_t out_col_index = blockIdx.z;

    __shared__ scalar_t cache[n_threads_per_block];

    // Loop-invariant base offsets of the input row and of the generated projection row.
    const uint64_t og_row_base = vec_index * d_og;
    const uint64_t proj_row_base = out_col_index * d_og;

    // Accumulate in scalar_t (not float) so double inputs are not truncated to single precision.
    scalar_t partial = 0.0;
    for (uint64_t offset = thread_index; offset < d_og; offset += n_threads_per_block) {
        partial += mat[og_row_base + offset]
            * compute_rank_proj_hypercubic_mat_entry_v01(seed, proj_row_base + offset);
    }

    cache[thread_index] = partial;

    __syncthreads();

    // Tree reduction; relies on n_threads_per_block being a power of two.
    for (uint32_t stride = n_threads_per_block / 2; stride != 0; stride /= 2) {
        if (thread_index < stride) {
            cache[thread_index] += cache[thread_index + stride];
        }
        // Reached by every thread of the block in every iteration.
        __syncthreads();
    }

    if (thread_index == 0) {
        const uint64_t out_flat_index = vec_index * d_proj + out_col_index;
        out[out_flat_index] = cache[0] * rsqrtf((float) d_og);
    }
}


// Projects every row of `mat` with the seeded +/-1 matrix, one block per output element.
//
// mat.shape = [n_vecs, d_og]
// out.shape = [n_vecs, d_proj]
// Here, assumes that everything has been already validated.
//
// The grid is (1, n_vecs, d_proj), so both n_vecs and d_proj are bounded by the device's grid.y /
// grid.z limits; a violation is reported by the launch check below.
void cuda_rand_proj_hypercubic_v01_alg02(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // TODO: It is probably possible to support smaller dimensions, e.g. by using the largest power of
    // two that is less than or equal to d_og as the block size. The current use-case involves
    // dimensions much greater than this, so it is not supported for now.
    const int32_t n_threads = HYPERCUBIC_V01_ALG02_N_THREADS_PER_BLOCK;
    TORCH_CHECK(
        d_og >= n_threads,
        "The original dimension must be greater than or equal to the number of threads per block (",
        n_threads, "), got ", d_og, ".");

    // Empty output: nothing to compute, and a zero-sized grid is an invalid launch configuration.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    const dim3 blocks(1, n_vecs, d_proj);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "cuda_rand_proj_hypercubic_v01_alg02", ([&] {
        cuda_rand_proj_hypercubic_v01_alg02_kernel<scalar_t><<<blocks, n_threads>>>(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));

    // Kernel launches do not return a status; configuration errors only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "cuda_rand_proj_hypercubic_v01_alg02 kernel launch failed: ",
        cudaGetErrorString(launch_status));
}


///////////////////////////////////////////////////////////////////////////////

// Block size the alg03 kernel is written for. The tree reduction in the kernel halves the active
// range each step, so this must be a power of two.
const uint32_t HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK = 1024;
// Number of consecutive output columns handled by one block.
// Of the powers of 2, 8 seems to be best for this.
const uint32_t HYPERCUBIC_V01_ALG03_N_OUTPUTS_PER_THREAD = 8;

static_assert(
    HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK != 0 &&
        (HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK & (HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK - 1)) == 0,
    "HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK must be a power of two");

// Projection kernel where one block computes N_OUTPUTS_PER_THREAD consecutive output columns of one
// input row, so each loaded input value is reused for several outputs.
//
// Each thread keeps one strided partial sum per output column in a local array; the partial sums
// are then reduced column by column through a shared-memory tree reduction.
//
// Launch layout used by this kernel:
//   - threadIdx.x -> lane within the block; blockDim.x must equal
//                    HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK (the shared buffer is sized for it)
//   - blockIdx.y  -> input row / vector index
//   - blockIdx.z  -> group of N_OUTPUTS_PER_THREAD consecutive output columns
// `mat` and `out` are indexed as densely packed row-major buffers with row lengths d_og and d_proj.
template <typename scalar_t>
__global__ void cuda_rand_proj_hypercubic_v01_alg03_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    const uint32_t n_threads_per_block = HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK;
    const uint32_t n_outputs_per_thread = HYPERCUBIC_V01_ALG03_N_OUTPUTS_PER_THREAD;

    // These depend on the specific way that the blocks are constructed by the launcher.
    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;
    const uint64_t base_out_col_index = static_cast<uint64_t>(n_outputs_per_thread) * blockIdx.z;

    // TODO: Can I use vector ops below for extra speed??
    // TODO: Also maybe use multiple bits from the compute_rank_proj_hypercubic_mat_entry_v01 for different outputs
    //          - Would result in different projection matrix for the same seed.

    scalar_t t_cache[n_outputs_per_thread] = {0.0};
    uint64_t offset = thread_index;
    while (offset < d_og) {
        const uint64_t og_flat_index = vec_index * d_og + offset;
        // Keep the input in scalar_t: going through float would truncate double inputs.
        const scalar_t mat_value = mat[og_flat_index];

        // Columns past d_proj (last group only) are accumulated as well but never written out.
        for (uint64_t i = 0; i < n_outputs_per_thread; i++) {
            const uint64_t proj_mat_flat_index = (base_out_col_index + i) * d_og + offset;
            t_cache[i] += mat_value * compute_rank_proj_hypercubic_mat_entry_v01(seed, proj_mat_flat_index);
        }

        offset += n_threads_per_block;
    }

    // TODO: Maybe do multiple of these at once.

    __shared__ scalar_t s_cache[n_threads_per_block];

    for (uint64_t i = 0; i < n_outputs_per_thread; i++) {
        const uint64_t out_col_index = base_out_col_index + i;
        // Depends only on blockIdx.z and i, so the whole block leaves the loop together and the
        // barriers below are never skipped by only part of the block.
        if (out_col_index >= d_proj) { break; }

        s_cache[thread_index] = t_cache[i];

        __syncthreads();

        // Tree reduction; relies on n_threads_per_block being a power of two.
        for (uint32_t stride = n_threads_per_block / 2; stride != 0; stride /= 2) {
            if (thread_index < stride) {
                s_cache[thread_index] += s_cache[thread_index + stride];
            }
            __syncthreads();
        }

        if (thread_index == 0) {
            const uint64_t out_flat_index = vec_index * d_proj + out_col_index;
            out[out_flat_index] = s_cache[0] * rsqrtf((float) d_og);
        }
    }
}

// Projects every row of `mat` with the seeded +/-1 matrix; each block produces a group of
// HYPERCUBIC_V01_ALG03_N_OUTPUTS_PER_THREAD consecutive output columns for one row.
//
// mat.shape = [n_vecs, d_og]
// out.shape = [n_vecs, d_proj]
// Here, assumes that everything has been already validated.
//
// The grid is (1, n_vecs, ceil(d_proj / outputs_per_thread)), so n_vecs and the number of column
// groups are bounded by the device's grid.y / grid.z limits; a violation is reported by the launch
// check below.
void cuda_rand_proj_hypercubic_v01_alg03(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // TODO: It is probably possible to support smaller dimensions, e.g. by using the largest power of
    // two that is less than or equal to d_og as the block size. The current use-case involves
    // dimensions much greater than this, so it is not supported for now.
    const int32_t n_threads = HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK;
    TORCH_CHECK(
        d_og >= n_threads,
        "The original dimension must be greater than or equal to the number of threads per block (",
        n_threads, "), got ", d_og, ".");

    // Empty output: nothing to compute, and a zero-sized grid is an invalid launch configuration.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    const int32_t n_outputs_per_thread = HYPERCUBIC_V01_ALG03_N_OUTPUTS_PER_THREAD;

    // Written out instead of relying on macro expansion; same value as CEIL_DIV(d_proj, n_outputs_per_thread).
    const auto n_col_groups = (d_proj + n_outputs_per_thread - 1) / n_outputs_per_thread;
    const dim3 blocks(1, n_vecs, n_col_groups);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "cuda_rand_proj_hypercubic_v01_alg03", ([&] {
        cuda_rand_proj_hypercubic_v01_alg03_kernel<scalar_t><<<blocks, n_threads>>>(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));

    // Kernel launches do not return a status; configuration errors only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "cuda_rand_proj_hypercubic_v01_alg03 kernel launch failed: ",
        cudaGetErrorString(launch_status));
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////


// Block size the v02 alg03 kernel is written for. The tree reduction in the kernel halves the
// active range each step, so this must be a power of two.
const uint32_t HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK = 1024;
// Number of consecutive output columns handled by one block.
// NOTE: Cannot be larger than 32 since we use a 32-bit hash (one bit per output column).
const uint32_t HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD = 32;

static_assert(
    HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK != 0 &&
        (HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK & (HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK - 1)) == 0,
    "HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK must be a power of two");
static_assert(
    HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD >= 1 && HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD <= 32,
    "HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD must be in [1, 32]: one hash bit is used per output");

// Projection kernel (v02 matrix definition) where one block computes N_OUTPUTS_PER_THREAD
// consecutive output columns of one input row.
//
// For each input position a single 32-bit seeded hash is computed from the flat index of the
// group's first column; bit i of that hash selects the +/-1 sign used for column (base + i).
//
// Each thread keeps one strided partial sum per output column in a local array; the partial sums
// are then reduced column by column through a shared-memory tree reduction.
//
// Launch layout used by this kernel:
//   - threadIdx.x -> lane within the block; blockDim.x must equal
//                    HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK (the shared buffer is sized for it)
//   - blockIdx.y  -> input row / vector index
//   - blockIdx.z  -> group of N_OUTPUTS_PER_THREAD consecutive output columns
// `mat` and `out` are indexed as densely packed row-major buffers with row lengths d_og and d_proj.
template <typename scalar_t>
__global__ void cuda_rand_proj_hypercubic_v02_alg03_kernel(
    const scalar_t* mat,
    scalar_t* out,
    uint32_t seed,
    size_t n_vecs,
    size_t d_og,
    size_t d_proj
) {
    const uint32_t n_threads_per_block = HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK;
    const uint32_t n_outputs_per_thread = HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD;

    // These depend on the specific way that the blocks are constructed by the launcher.
    const uint64_t thread_index = threadIdx.x;
    const uint64_t vec_index = blockIdx.y;
    const uint64_t base_out_col_index = static_cast<uint64_t>(n_outputs_per_thread) * blockIdx.z;

    // TODO: Can I use vector ops below for extra speed??

    scalar_t t_cache[n_outputs_per_thread] = {0.0f};
    uint64_t offset = thread_index;
    while (offset < d_og) {
        const uint64_t og_flat_index = vec_index * d_og + offset;
        // Keep the input in scalar_t: going through float would truncate double inputs.
        const scalar_t mat_value = mat[og_flat_index];

        const uint64_t base_proj_mat_flat_index = base_out_col_index * d_og + offset;
        const uint32_t hash = fast_seeded_hash(seed, base_proj_mat_flat_index);
        // One hash bit per output column; the shift stays below 32 because of the static_assert above.
        for (uint32_t i = 0; i < n_outputs_per_thread; i++) {
            t_cache[i] += mat_value * pm_1_from_hash(hash >> i);
        }

        offset += n_threads_per_block;
    }

    // TODO: Maybe do multiple of these at once.

    __shared__ scalar_t s_cache[n_threads_per_block];

    for (uint64_t i = 0; i < n_outputs_per_thread; i++) {
        const uint64_t out_col_index = base_out_col_index + i;
        // Depends only on blockIdx.z and i, so the whole block leaves the loop together and the
        // barriers below are never skipped by only part of the block.
        if (out_col_index >= d_proj) { break; }

        s_cache[thread_index] = t_cache[i];

        __syncthreads();

        // Tree reduction; relies on n_threads_per_block being a power of two.
        for (uint32_t stride = n_threads_per_block / 2; stride != 0; stride /= 2) {
            if (thread_index < stride) {
                s_cache[thread_index] += s_cache[thread_index + stride];
            }
            __syncthreads();
        }

        if (thread_index == 0) {
            const uint64_t out_flat_index = vec_index * d_proj + out_col_index;
            out[out_flat_index] = s_cache[0] * rsqrtf((float) d_og);
        }
    }
}

// Projects every row of `mat` with the seeded +/-1 matrix (v02 definition); each block produces a
// group of HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD consecutive output columns for one row.
//
// mat.shape = [n_vecs, d_og]
// out.shape = [n_vecs, d_proj]
// Here, assumes that everything has been already validated.
//
// The grid is (1, n_vecs, ceil(d_proj / outputs_per_thread)), so n_vecs and the number of column
// groups are bounded by the device's grid.y / grid.z limits; a violation is reported by the launch
// check below.
void cuda_rand_proj_hypercubic_v02_alg03(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
    const auto n_vecs = mat.size(0);
    const auto d_og = mat.size(1);
    const auto d_proj = out.size(1);

    // TODO: It is probably possible to support smaller dimensions, e.g. by using the largest power of
    // two that is less than or equal to d_og as the block size. The current use-case involves
    // dimensions much greater than this, so it is not supported for now.
    const int32_t n_threads = HYPERCUBIC_V02_ALG03_N_THREADS_PER_BLOCK;
    TORCH_CHECK(
        d_og >= n_threads,
        "The original dimension must be greater than or equal to the number of threads per block (",
        n_threads, "), got ", d_og, ".");

    // Empty output: nothing to compute, and a zero-sized grid is an invalid launch configuration.
    if (n_vecs == 0 || d_proj == 0) {
        return;
    }

    const int32_t n_outputs_per_thread = HYPERCUBIC_V02_ALG03_N_OUTPUTS_PER_THREAD;

    // Written out instead of relying on macro expansion; same value as CEIL_DIV(d_proj, n_outputs_per_thread).
    const auto n_col_groups = (d_proj + n_outputs_per_thread - 1) / n_outputs_per_thread;
    const dim3 blocks(1, n_vecs, n_col_groups);

    AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "cuda_rand_proj_hypercubic_v02_alg03", ([&] {
        cuda_rand_proj_hypercubic_v02_alg03_kernel<scalar_t><<<blocks, n_threads>>>(
            mat.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>(),
            seed,
            n_vecs,
            d_og,
            d_proj
        );
    }));

    // Kernel launches do not return a status; configuration errors only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "cuda_rand_proj_hypercubic_v02_alg03 kernel launch failed: ",
        cudaGetErrorString(launch_status));
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////



///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////


// const uint32_t HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK = 1024;
// // const uint32_t HYPERCUBIC_V01_ALG03_MAX_SUMMATION_SIZE = 4 * 1024;
// const uint32_t HYPERCUBIC_V01_ALG03_MAX_SUMMATION_SIZE = 16 * 1024;


// template <typename scalar_t>
// __global__ void cuda_rand_proj_hypercubic_v01_alg03_kernel(
//     const scalar_t* mat,
//     scalar_t* out,
//     uint32_t seed,
//     size_t n_vecs,
//     size_t d_og,
//     size_t d_proj
// ) {
//     const uint32_t n_threads_per_block = HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK;

//     const int32_t max_summation_size = HYPERCUBIC_V01_ALG03_MAX_SUMMATION_SIZE;
//     const uint64_t block_summation_size = n_threads_per_block * max_summation_size;

//     const uint64_t summation_offset = blockIdx.x * block_summation_size;
//     const uint64_t summation_end = min(d_og, summation_offset + block_summation_size);

//     // These depend on the specific way that I constructed the blocks.
//     const uint64_t thread_index = threadIdx.x;
//     const uint64_t vec_index = blockIdx.y;
//     const uint64_t out_col_index = blockIdx.z;


//     __shared__ scalar_t cache[n_threads_per_block];

//     float tmp = 0.0;
//     uint64_t offset = summation_offset + thread_index;
//     while (offset < summation_end) {
//         uint64_t og_flat_index = vec_index * d_og + offset;
//         uint64_t proj_mat_flat_index = out_col_index * d_og + offset;
//         tmp += mat[og_flat_index] * compute_rank_proj_hypercubic_mat_entry_v01(seed, proj_mat_flat_index);
//         offset += n_threads_per_block;
//     }

//     cache[thread_index] = tmp;

//     __syncthreads();

//     // We must have n_threads_per_block be a power of 2 due to the following code.
//     int i = n_threads_per_block / 2;
//     while (i != 0) {
//         if (thread_index < i) {
//             cache[thread_index] += cache[thread_index + i];
//         }
//         __syncthreads();
//         i /= 2;
//     }

//     if (thread_index == 0) {
//         uint64_t out_flat_index = vec_index * d_proj + out_col_index;
//         // out[out_flat_index] = cache[0] * rsqrtf((float) d_og);
//         atomicAdd(out + out_flat_index, cache[0] * rsqrtf((float) d_og));
//     }
// }


// // mat.shape = [n_vecs, d_og]
// // out.shape = [n_vecs, d_proj]
// // Here, assumes that everything has been already validated.
// void cuda_rand_proj_hypercubic_v01_alg03(torch::Tensor mat, uint32_t seed, torch::Tensor out) {
//     const auto n_vecs = mat.size(0);
//     const auto d_og = mat.size(1);
//     const auto d_proj = out.size(1);

//     // TODO: It's probably possible to support smaller dimensions by taking the largest power of two that is
//     // smaller or equal to than d_og. My current use-case involves dimensions much greater than this, so I'm
//     // not supporting this now.
//     const int32_t n_threads = HYPERCUBIC_V01_ALG03_N_THREADS_PER_BLOCK;
//     TORCH_CHECK(d_og >= n_threads, "The original dimension must be greater than the number of threads.");

//     const int32_t max_summation_size = HYPERCUBIC_V01_ALG03_MAX_SUMMATION_SIZE;
//     const uint64_t block_summation_size = n_threads * max_summation_size;

//     const int32_t n_summation_blocks = CEIL_DIV(d_og, block_summation_size);

//     const dim3 blocks(n_summation_blocks, n_vecs, d_proj);

//     // Need to zero out the output tensor since we'll be doing atomicAdds to it.
//     out.zero_();

//     AT_DISPATCH_FLOATING_TYPES(out.type(), "cuda_rand_proj_hypercubic_v01_alg03", ([&] {
//         cuda_rand_proj_hypercubic_v01_alg03_kernel<scalar_t><<<blocks, n_threads>>>(
//             mat.data<scalar_t>(),
//             out.data<scalar_t>(),
//             seed,
//             n_vecs,
//             d_og,
//             d_proj
//         );
//     }));
// }



///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

// For random projection of vec, create 1 thread per output
// For random projection of matrix, create 1 thread per output, split each row of input matrix into its own block

// Can also try the dot-product based approach (will require some additional storage)

// Also think about implementing this for per-variable gradients, some stuff might be different.



// Basically, have 1 block per output position, then do a strided sum to temp buffer

// Later maybe try ...