#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>
using namespace torch::indexing;

// Argument-validation helpers. Each raises a c10::Error (via TORCH_CHECK) whose
// message starts with the spelled-out argument expression.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do { ... } while (0) so the two checks behave as a single statement,
// e.g. under an unbraced `if`; the call site still supplies the trailing semicolon.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

//https://chatgpt.com/c/1571275f-e6cb-4805-bc58-b3c16dd92a4a

// TODO: double-check the code below.
/**
 * Evaluates a piecewise-linear function on row `t` of `x` and writes the
 * result to row `t` of `output`.
 *
 * The interval [-1, 1] is divided into `n_params / 2` equal-width chunks.
 * Chunk `c` uses `shared_parameters[2 * c]` as slope and
 * `shared_parameters[2 * c + 1]` as intercept. The chunk index is clamped to
 * [0, num_chunks - 1], so inputs outside [-1, 1] are evaluated with the line
 * of the first or last chunk.
 *
 * The threads of the block split the row between them: thread `threadIdx.x`
 * handles elements threadIdx.x, threadIdx.x + blockDim.x, ...
 * The function ends with `__syncthreads()`, so it must be called by every
 * thread of the block.
 *
 * @param x                  Base pointer of the input; row `t` starts at `x + t * num_x`.
 * @param shared_parameters  `n_params` parameters for this row (the kernel in
 *                           this file passes its shared-memory copy).
 * @param output             Base pointer of the output, same layout as `x`.
 * @param n_params           Number of parameters per row; when it is smaller
 *                           than 2 there is no chunk to evaluate and nothing
 *                           is written.
 * @param num_x              Number of elements per row.
 * @param t                  Row index.
 */
template <typename scalar_t>
__device__ void apply_elementwise_function(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shared_parameters,
    scalar_t* __restrict__ output,
    int n_params,
    int num_x,
    int t
) {
    const int first_element = threadIdx.x;
    const int block_size = blockDim.x;

    // Compute the row offset in 64 bits: `t * num_x` can exceed INT_MAX for
    // large tensors even though each factor fits in an int.
    const long long row_offset = static_cast<long long>(t) * num_x;
    const scalar_t* x_row = x + row_offset;
    scalar_t* output_row = output + row_offset;

    const int num_chunks = n_params / 2;

    // With fewer than two parameters there is no (slope, intercept) pair; the
    // clamp below would otherwise produce index -1 and read out of bounds.
    // `n_params` is the same for every thread, so this branch is block-uniform.
    if (num_chunks > 0) {
        const int last_chunk = num_chunks - 1;
        // Keep the arithmetic in scalar_t so float instantiations are not
        // silently promoted to double by a `1.0` / `2.0` literal.
        const scalar_t one = static_cast<scalar_t>(1);
        const scalar_t chunk_size =
            static_cast<scalar_t>(2) / static_cast<scalar_t>(num_chunks);

        for (int i = first_element; i < num_x; i += block_size) {
            const scalar_t x_element = x_row[i];

            // Map [-1, 1] onto [0, num_chunks] and truncate to a chunk index.
            int chunk_index = static_cast<int>((x_element + one) / chunk_size);
            chunk_index = min(max(chunk_index, 0), last_chunk);

            const scalar_t slope = shared_parameters[2 * chunk_index];
            const scalar_t intercept = shared_parameters[2 * chunk_index + 1];

            output_row[i] = slope * x_element + intercept;
        }
    }

    // Block-wide barrier kept from the original contract: all threads have
    // finished their stores before any of them continues.
    __syncthreads();
}


/**
 * Cooperatively copies row `t` of `parameters_direct` into `shared_parameters`.
 *
 * Thread `threadIdx.x` copies elements threadIdx.x, threadIdx.x + blockDim.x, ...
 * The function ends with `__syncthreads()`, so it must be called by every
 * thread of the block; after it returns, all `n_params` copied values are
 * visible to the whole block.
 *
 * @param parameters_direct  Base pointer of the parameter matrix; row `t`
 *                           starts at `parameters_direct + t * n_params`.
 * @param shared_parameters  Destination buffer with room for `n_params` elements.
 * @param n_params           Number of parameters per row.
 * @param t                  Row index.
 */
template <typename scalar_t>
__device__ void copy_parameters_to_shared(
    const scalar_t* __restrict__ parameters_direct,
    scalar_t* shared_parameters,
    int n_params,
    int t
) {
    const int first_element = threadIdx.x;
    const int block_size = blockDim.x;

    // Compute the row offset in 64 bits: `t * n_params` can exceed INT_MAX
    // for large parameter matrices even though each factor fits in an int.
    const scalar_t* source_row =
        parameters_direct + static_cast<long long>(t) * n_params;

    for (int i = first_element; i < n_params; i += block_size) {
        shared_parameters[i] = source_row[i];
    }

    // Make the copied parameters visible to every thread before they are read.
    __syncthreads();
}

/**
 * Forward kernel: for each row `t`, evaluates the piecewise-linear function
 * described by `parameters_direct[t, :]` on `x[t, :]` and stores the result in
 * `output[t, :]`.
 *
 * Launch layout:
 *   - 1-D grid, one block per row (`blockIdx.x` is the row index); blocks with
 *     `blockIdx.x >= n_func` exit immediately.
 *   - 1-D block of any size; the threads stride over the row.
 *   - Dynamic shared memory: at least `n_params * sizeof(scalar_t)` bytes.
 *
 * @param parameters_direct  Row-major parameter matrix with `n_params` columns.
 * @param x                  Row-major input matrix with `num_x` columns.
 * @param output             Row-major output matrix, same layout as `x`.
 * @param n_params           Parameters per row.
 * @param num_x              Elements per row of `x` / `output`.
 * @param n_func             Number of rows to process.
 */
template <typename scalar_t>
__global__ void one_d_linear_cuda_forward_kernel(
    const scalar_t* __restrict__ parameters_direct,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_params,
    int num_x,
    int n_func
) {
    // Dynamic shared memory is declared as raw bytes so every template
    // instantiation shares one declaration. The explicit alignment keeps the
    // reinterpret_cast below suitably aligned for both float and double.
    extern __shared__ __align__(16) char buffer[];
    scalar_t* shared_parameters = reinterpret_cast<scalar_t*>(buffer);

    const int t = blockIdx.x;  // Row handled by this block.

    // Guard against a grid larger than the number of rows. The condition is
    // uniform across the block, so no thread is left waiting at a barrier.
    if (t >= n_func) {
        return;
    }

    // Stage this row's parameters in shared memory (ends with a barrier).
    copy_parameters_to_shared(parameters_direct, shared_parameters, n_params, t);

    // Evaluate the piecewise-linear function on this row of x.
    apply_elementwise_function(x, shared_parameters, output, n_params, num_x, t);
}


/**
 * Validates the inputs and launches `one_d_linear_cuda_forward_kernel`.
 *
 * @param parameters_direct  Contiguous CUDA tensor of shape (n_func, n_params)
 *                           with a floating dtype (float or double);
 *                           n_params must be at least 2.
 * @param x                  Contiguous CUDA tensor of shape (n_func, num_x) with
 *                           the same dtype and device as `parameters_direct`.
 * @return A new tensor shaped like `x` holding the piecewise-linear outputs.
 * @throws c10::Error if any of the preconditions above is violated, if a size
 *         does not fit in a 32-bit int, if one parameter row does not fit in
 *         the per-block shared memory, or if the kernel launch fails.
 *
 * The kernel is enqueued on PyTorch's current CUDA stream for the device of
 * `x`; it is not synchronized here.
 */
torch::Tensor one_d_linear_cuda_forward(
    torch::Tensor parameters_direct,
    torch::Tensor x
) {
    // Ensure the input tensors are contiguous and on the GPU.
    CHECK_INPUT(parameters_direct);
    CHECK_INPUT(x);

    TORCH_CHECK(parameters_direct.dim() == 2,
                "parameters_direct must be 2-D (n_func, n_params), got ",
                parameters_direct.dim(), " dimension(s)");
    TORCH_CHECK(x.dim() == 2,
                "x must be 2-D (n_func, num_x), got ", x.dim(), " dimension(s)");
    TORCH_CHECK(x.device() == parameters_direct.device(),
                "parameters_direct and x must be on the same device");
    TORCH_CHECK(x.scalar_type() == parameters_direct.scalar_type(),
                "parameters_direct and x must have the same dtype");

    const int64_t n_func = parameters_direct.size(0);
    const int64_t n_params = parameters_direct.size(1);
    const int64_t num_x = x.size(1);

    // The kernel indexes both tensors by the same row index.
    TORCH_CHECK(x.size(0) == n_func,
                "x must have the same number of rows as parameters_direct (",
                n_func, "), got ", x.size(0));
    // Each chunk needs a (slope, intercept) pair.
    TORCH_CHECK(n_params >= 2,
                "parameters_direct must have at least 2 parameters per row, got ",
                n_params);
    // The kernel takes its sizes as int; reject values that would be truncated.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(n_func <= kIntMax && n_params <= kIntMax && num_x <= kIntMax,
                "tensor dimensions must each fit in a 32-bit int");

    // Output tensor, same shape as x.
    auto output = torch::empty_like(x);

    // Nothing to compute; a zero-sized grid would be an invalid launch.
    if (output.numel() == 0) {
        return output;
    }

    // Launch on the tensors' device and on PyTorch's current stream so the
    // kernel is ordered with the surrounding PyTorch operations.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    constexpr unsigned int kThreadsPerBlock = 1024;
    const dim3 block_dim(kThreadsPerBlock);
    const dim3 grid_dim(static_cast<unsigned int>(n_func));  // One block per row.

    AT_DISPATCH_FLOATING_TYPES(parameters_direct.scalar_type(), "one_d_linear_cuda_forward", ([&] {
        // Shared memory holds one row of parameters: parameters_direct[t, :].
        const size_t shared_mem_size = static_cast<size_t>(n_params) * sizeof(scalar_t);
        TORCH_CHECK(shared_mem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
                    "one_d_linear_cuda_forward: a parameter row needs ", shared_mem_size,
                    " bytes of shared memory, which exceeds the per-block limit");

        one_d_linear_cuda_forward_kernel<scalar_t><<<grid_dim, block_dim, shared_mem_size, stream>>>(
            parameters_direct.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            static_cast<int>(n_params),
            static_cast<int>(num_x),
            static_cast<int>(n_func)
        );
    }));

    // Kernel launches do not report errors directly; surface bad launch
    // configurations here instead of at some unrelated later CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "one_d_linear_cuda_forward kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return output;
}

/**
 * Entry point registered with Python as `forward`.
 *
 * Verifies that both tensors live on a CUDA device and are contiguous, then
 * delegates to `one_d_linear_cuda_forward`.
 *
 * @param parameters_direct  Parameter tensor forwarded to the CUDA launcher.
 * @param x                  Input tensor forwarded to the CUDA launcher.
 * @return The tensor produced by `one_d_linear_cuda_forward`.
 */
torch::Tensor one_d_linear_gpu_forward(
    torch::Tensor parameters_direct,
    torch::Tensor x
) {
    // Same checks, in the same order, as CHECK_INPUT on each argument.
    CHECK_CUDA(parameters_direct);
    CHECK_CONTIGUOUS(parameters_direct);
    CHECK_CUDA(x);
    CHECK_CONTIGUOUS(x);

    torch::Tensor result = one_d_linear_cuda_forward(parameters_direct, x);
    return result;
}

// Python bindings: registers `one_d_linear_gpu_forward` under the name
// `forward` on the extension module, with the given docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &one_d_linear_gpu_forward, "1D Linear forward (CUDA)");
}

// Empty program entry point: performs no work.
// NOTE(review): presumably kept so this file can also be built as a
// standalone executable; confirm whether it is still needed in the extension.
int main() {

}