#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>
using namespace torch::indexing;

#include "utilities.h"


//https://chatgpt.com/c/1571275f-e6cb-4805-bc58-b3c16dd92a4a

// [todo] double check the code below

// Evaluates a piecewise trilinear polynomial for row `t` (direct parametrization).
//
// Each sample is a triple (x_1, x_2, x_3) stored contiguously in `x`. Every
// coordinate is mapped to one of `num_chunks` equal-width cells covering
// [-1, 1]; the selected 3D cell owns 8 consecutive coefficients in
// `shared_parameters`, applied in this order:
//   [x_1, x_2, x_3, 1, x_1*x_2, x_1*x_3, x_2*x_3, x_1*x_2*x_3]
//
// Parameters:
//   x                 - inputs; row `t` starts at x + t * batch_size * 3
//   shared_parameters - coefficient table of this row (8 * num_chunks^3 values;
//                       the kernel in this file passes its shared-memory copy)
//   output            - results; row `t` starts at output + t * batch_size
//   num_chunks        - number of cells per axis (expected >= 1)
//   batch_size        - number of samples in the row
//   t                 - row index
//
// The threads of the block stride over the samples. The function ends with a
// block-wide barrier, so every thread of the block must call it.
template <typename scalar_t>
__device__ void apply_elementwise_function_3d(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shared_parameters,
    scalar_t* __restrict__ output,
    int num_chunks,
    int batch_size,
    int t
) {
    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Row offsets are computed in 64 bits: t * batch_size * 3 overflows a
    // 32-bit int for large inputs.
    const scalar_t* x_row = x + static_cast<size_t>(t) * batch_size * 3;
    scalar_t* output_row = output + static_cast<size_t>(t) * batch_size;

    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t inverse_chunk_size = static_cast<scalar_t>(num_chunks) / static_cast<scalar_t>(2);
    const int last_chunk = num_chunks - 1;

    for (int i = thread_idx; i < batch_size; i += block_size) {
        const scalar_t* point = x_row + static_cast<size_t>(i) * 3;
        const scalar_t x_1 = point[0];
        const scalar_t x_2 = point[1];
        const scalar_t x_3 = point[2];

        // Truncate towards zero to get the cell index, then clamp: x == 1
        // (or a value marginally outside [-1, 1]) would otherwise select the
        // non-existent cell `num_chunks` and read past the coefficient table.
        // For double inputs the index is still derived via a float conversion.
        int first_index = __float2int_rz(static_cast<float>((x_1 + one) * inverse_chunk_size));
        int second_index = __float2int_rz(static_cast<float>((x_2 + one) * inverse_chunk_size));
        int third_index = __float2int_rz(static_cast<float>((x_3 + one) * inverse_chunk_size));
        first_index = max(min(first_index, last_chunk), 0);
        second_index = max(min(second_index, last_chunk), 0);
        third_index = max(min(third_index, last_chunk), 0);

        // Row-major cell index, 8 coefficients per cell.
        const int cell = (first_index * num_chunks + second_index) * num_chunks + third_index;
        const scalar_t* coeff = shared_parameters + 8 * cell;

        // Linear and constant terms.
        scalar_t output_now = coeff[0] * x_1;
        output_now += coeff[1] * x_2;
        output_now += coeff[2] * x_3;
        output_now += coeff[3];

        // Mixed terms.
        output_now += coeff[4] * x_1 * x_2;
        output_now += coeff[5] * x_1 * x_3;
        output_now += coeff[6] * x_2 * x_3;
        output_now += coeff[7] * x_1 * x_2 * x_3;

        output_row[i] = output_now;
    }

    // Ensure all threads have finished processing before proceeding
    __syncthreads();
}

// Evaluates a bilinear (degree-1 B-spline) surface for row `t`.
//
// Each sample is a pair (x_1, x_2) stored contiguously in `x`. The knot grid
// has (num_chunks + 1) x (num_chunks + 1) values in `shared_parameters`
// (row-major, first coordinate slowest) spanning [-1, 1] on each axis.
// The result is the sum of the four surrounding knot values weighted by
// products of `delta` and `chunk_size - delta`. The weights are NOT divided by
// chunk_size, so the value is scaled by chunk_size^2 relative to a normalized
// bilinear interpolation.
//
// Parameters:
//   x                 - inputs; row `t` starts at x + t * batch_size * 2
//   shared_parameters - knot values of this row ((num_chunks + 1)^2 values;
//                       the kernel in this file passes its shared-memory copy)
//   output            - results; row `t` starts at output + t * batch_size
//   num_chunks        - number of cells per axis (expected >= 1)
//   batch_size        - number of samples in the row
//   t                 - row index
//
// The threads of the block stride over the samples. The function ends with a
// block-wide barrier, so every thread of the block must call it.
template <typename scalar_t>
__device__ void apply_elementwise_function_2d_bspline(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shared_parameters,
    scalar_t* __restrict__ output,
    int num_chunks,
    int batch_size,
    int t
) {
    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Row offsets are computed in 64 bits to avoid 32-bit overflow on large inputs.
    const scalar_t* x_row = x + static_cast<size_t>(t) * batch_size * 2;  // Each data point has 2 components
    scalar_t* output_row = output + static_cast<size_t>(t) * batch_size;

    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t inverse_chunk_size = static_cast<scalar_t>(num_chunks) / static_cast<scalar_t>(2);
    const scalar_t chunk_size = static_cast<scalar_t>(2) / static_cast<scalar_t>(num_chunks);

    const int param_size = num_chunks + 1;  // knots per axis
    const int last_cell = num_chunks - 1;

    for (int i = thread_idx; i < batch_size; i += block_size) {
        const scalar_t* point = x_row + static_cast<size_t>(i) * 2;

        // Shift from [-1, 1] to [0, 2] so the cell index is a plain truncation.
        const scalar_t x_1 = point[0] + one;
        const scalar_t x_2 = point[1] + one;

        // Clamp the cell index: at the upper edge (x == 1) the truncation yields
        // `num_chunks`, whose right-hand knot (index num_chunks + 1) does not
        // exist. With the clamp the sample lands in the last cell with
        // delta == chunk_size, which gives the same value without the bad read.
        int first_index = __float2int_rz(static_cast<float>(x_1 * inverse_chunk_size));
        int second_index = __float2int_rz(static_cast<float>(x_2 * inverse_chunk_size));
        first_index = max(min(first_index, last_cell), 0);
        second_index = max(min(second_index, last_cell), 0);

        // Offset inside the cell, computed in scalar_t (the previous fmaf call
        // truncated double inputs to float). The compiler may still fuse this
        // into a single FMA.
        const scalar_t x_1_delta = x_1 - static_cast<scalar_t>(first_index) * chunk_size;
        const scalar_t x_2_delta = x_2 - static_cast<scalar_t>(second_index) * chunk_size;
        const scalar_t x_1_rest = chunk_size - x_1_delta;
        const scalar_t x_2_rest = chunk_size - x_2_delta;

        // Index of the lower-left knot; the other three corners are at
        // +1 (next x_2 knot) and +param_size (next x_1 knot).
        const int base = first_index * param_size + second_index;

        scalar_t output_now = shared_parameters[base] * x_1_rest * x_2_rest;
        output_now += shared_parameters[base + 1] * x_1_rest * x_2_delta;
        output_now += shared_parameters[base + param_size] * x_1_delta * x_2_rest;
        output_now += shared_parameters[base + param_size + 1] * x_1_delta * x_2_delta;

        output_row[i] = output_now;
    }

    // Ensure all threads have finished processing before proceeding
    __syncthreads();
}


// Evaluates a trilinear (degree-1 B-spline) volume for row `t`.
//
// Each sample is a triple (x_1, x_2, x_3) stored contiguously in `x`. The knot
// grid has (num_chunks + 1)^3 values in `shared_parameters` (row-major, first
// coordinate slowest) spanning [-1, 1] on each axis. The result is the sum of
// the eight surrounding knot values weighted by products of `delta` and
// `chunk_size - delta`. The weights are NOT divided by chunk_size, so the
// value is scaled by chunk_size^3 relative to a normalized trilinear
// interpolation.
//
// Parameters:
//   x                 - inputs; row `t` starts at x + t * batch_size * 3
//   shared_parameters - knot values of this row ((num_chunks + 1)^3 values;
//                       the kernel in this file passes its shared-memory copy)
//   output            - results; row `t` starts at output + t * batch_size
//   num_chunks        - number of cells per axis (expected >= 1)
//   batch_size        - number of samples in the row
//   t                 - row index
//
// The threads of the block stride over the samples. The function ends with a
// block-wide barrier, so every thread of the block must call it.
template <typename scalar_t>
__device__ void apply_elementwise_function_3d_bspline(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shared_parameters,
    scalar_t* __restrict__ output,
    int num_chunks,
    int batch_size,
    int t
) {
    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Row offsets are computed in 64 bits to avoid 32-bit overflow on large inputs.
    const scalar_t* x_row = x + static_cast<size_t>(t) * batch_size * 3;
    scalar_t* output_row = output + static_cast<size_t>(t) * batch_size;

    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t inverse_chunk_size = static_cast<scalar_t>(num_chunks) / static_cast<scalar_t>(2);
    const scalar_t chunk_size = static_cast<scalar_t>(2) / static_cast<scalar_t>(num_chunks);

    const int param_size = num_chunks + 1;            // knots per axis
    const int stride_2 = param_size;                  // step to the next x_2 knot
    const int stride_1 = param_size * param_size;     // step to the next x_1 knot
    const int last_cell = num_chunks - 1;

    for (int i = thread_idx; i < batch_size; i += block_size) {
        const scalar_t* point = x_row + static_cast<size_t>(i) * 3;

        // Shift from [-1, 1] to [0, 2] so the cell index is a plain truncation.
        const scalar_t x_1 = point[0] + one;
        const scalar_t x_2 = point[1] + one;
        const scalar_t x_3 = point[2] + one;

        // Clamp the cell index: at the upper edge (x == 1) the truncation yields
        // `num_chunks`, whose upper neighbour knot does not exist. With the
        // clamp the sample lands in the last cell with delta == chunk_size,
        // which gives the same value without the out-of-bounds read.
        int first_index = __float2int_rz(static_cast<float>(x_1 * inverse_chunk_size));
        int second_index = __float2int_rz(static_cast<float>(x_2 * inverse_chunk_size));
        int third_index = __float2int_rz(static_cast<float>(x_3 * inverse_chunk_size));
        first_index = max(min(first_index, last_cell), 0);
        second_index = max(min(second_index, last_cell), 0);
        third_index = max(min(third_index, last_cell), 0);

        // Offsets inside the cell, computed in scalar_t (the previous fmaf call
        // truncated double inputs to float). The compiler may still fuse each
        // of these into a single FMA.
        const scalar_t x_1_delta = x_1 - static_cast<scalar_t>(first_index) * chunk_size;
        const scalar_t x_2_delta = x_2 - static_cast<scalar_t>(second_index) * chunk_size;
        const scalar_t x_3_delta = x_3 - static_cast<scalar_t>(third_index) * chunk_size;
        const scalar_t x_1_rest = chunk_size - x_1_delta;
        const scalar_t x_2_rest = chunk_size - x_2_delta;
        const scalar_t x_3_rest = chunk_size - x_3_delta;

        // Knot at the lower corner of the cell; the remaining seven corners are
        // reached by adding stride_1 / stride_2 / 1.
        const int base = first_index * stride_1 + second_index * stride_2 + third_index;

        // Corners with the lower x_1 knot.
        scalar_t output_now = shared_parameters[base] * x_1_rest * x_2_rest * x_3_rest;
        output_now += shared_parameters[base + 1] * x_1_rest * x_2_rest * x_3_delta;
        output_now += shared_parameters[base + stride_2] * x_1_rest * x_2_delta * x_3_rest;
        output_now += shared_parameters[base + stride_2 + 1] * x_1_rest * x_2_delta * x_3_delta;

        // Corners with the upper x_1 knot.
        output_now += shared_parameters[base + stride_1] * x_1_delta * x_2_rest * x_3_rest;
        output_now += shared_parameters[base + stride_1 + 1] * x_1_delta * x_2_rest * x_3_delta;
        output_now += shared_parameters[base + stride_1 + stride_2] * x_1_delta * x_2_delta * x_3_rest;
        output_now += shared_parameters[base + stride_1 + stride_2 + 1] * x_1_delta * x_2_delta * x_3_delta;

        output_row[i] = output_now;
    }

    // Ensure all threads have finished processing before proceeding
    __syncthreads();
}


// Evaluates a piecewise linear (degree-1 B-spline) curve for row `t`.
//
// `shared_parameters` holds num_chunks + 1 knot values spanning [-1, 1]. For a
// sample with in-cell offset `delta` the result is
//   (chunk_size - delta) * knot[k] + delta * knot[k + 1].
// The weights are NOT divided by chunk_size, so the value is scaled by
// chunk_size relative to a normalized linear interpolation.
//
// Parameters:
//   x                 - inputs; row `t` starts at x + t * batch_size
//   shared_parameters - knot values of this row (num_chunks + 1 values; the
//                       kernel in this file passes its shared-memory copy)
//   output            - results; row `t` starts at output + t * batch_size
//   num_chunks        - number of cells (expected >= 1)
//   batch_size        - number of samples in the row
//   t                 - row index
//
// The threads of the block stride over the samples. The function ends with a
// block-wide barrier, so every thread of the block must call it.
template <typename scalar_t>
__device__ void apply_elementwise_function_1d_bspline(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shared_parameters,
    scalar_t* __restrict__ output,
    int num_chunks,
    int batch_size,
    int t
) {
    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Row offsets are computed in 64 bits to avoid 32-bit overflow on large inputs.
    const size_t row_offset = static_cast<size_t>(t) * batch_size;
    const scalar_t* x_row = x + row_offset;
    scalar_t* output_row = output + row_offset;

    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t inverse_chunk_size = static_cast<scalar_t>(num_chunks) / static_cast<scalar_t>(2);
    const scalar_t chunk_size = static_cast<scalar_t>(2) / static_cast<scalar_t>(num_chunks);
    const int last_cell = num_chunks - 1;

    for (int i = thread_idx; i < batch_size; i += block_size) {
        // Shift from [-1, 1] to [0, 2] so the cell index is a plain truncation.
        const scalar_t x_element = x_row[i] + one;

        // Clamp the cell index: at x == 1 the truncation yields `num_chunks`
        // and knot[num_chunks + 1] does not exist. With the clamp the sample
        // lands in the last cell with delta == chunk_size (same value, no
        // out-of-bounds read).
        int chunk_index = __float2int_rz(static_cast<float>(x_element * inverse_chunk_size));
        chunk_index = max(min(chunk_index, last_cell), 0);

        // Offset inside the cell, computed in scalar_t (the previous fmaf call
        // truncated double inputs to float).
        const scalar_t x_delta = x_element - static_cast<scalar_t>(chunk_index) * chunk_size;

        output_row[i] = (chunk_size - x_delta) * shared_parameters[chunk_index] + x_delta * shared_parameters[chunk_index + 1];
    }

    // Ensure all threads have finished processing before proceeding
    __syncthreads();
}



// Evaluates a piecewise linear function for row `t` (direct parametrization).
//
// The interval [-1, 1] is split into `num_chunks` equal-width cells; cell k
// owns the pair (slope, intercept) stored at shared_parameters[2k] and
// shared_parameters[2k + 1], and the result is slope * x + intercept.
//
// Parameters:
//   x                 - inputs; row `t` starts at x + t * batch_size
//   shared_parameters - coefficients of this row (2 * num_chunks values; the
//                       kernel in this file passes its shared-memory copy)
//   output            - results; row `t` starts at output + t * batch_size
//   num_chunks        - number of cells (expected >= 1)
//   batch_size        - number of samples in the row
//   t                 - row index
//
// The threads of the block stride over the samples. The function ends with a
// block-wide barrier, so every thread of the block must call it.
template <typename scalar_t>
__device__ void apply_elementwise_function_1d(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shared_parameters,
    scalar_t* __restrict__ output,
    int num_chunks,
    int batch_size,
    int t
) {
    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Row offsets are computed in 64 bits to avoid 32-bit overflow on large inputs.
    const size_t row_offset = static_cast<size_t>(t) * batch_size;
    const scalar_t* x_row = x + row_offset;
    scalar_t* output_row = output + row_offset;

    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t inverse_chunk_size = static_cast<scalar_t>(num_chunks) / static_cast<scalar_t>(2);
    const int last_chunk = num_chunks - 1;

    for (int i = thread_idx; i < batch_size; i += block_size) {
        const scalar_t x_element = x_row[i];

        // Cell index by truncation towards zero. It must be clamped BEFORE the
        // multiplication by 2: x == 1 maps to cell `num_chunks`, which is one
        // past the last (slope, intercept) pair.
        int cell = __float2int_rz(static_cast<float>((x_element + one) * inverse_chunk_size));
        cell = max(min(cell, last_chunk), 0);
        const int chunk_index = 2 * cell;

        // Calculate the linear function based on shared parameters
        output_row[i] = shared_parameters[chunk_index] * x_element + shared_parameters[chunk_index + 1];
    }

    // Ensure all threads have finished processing before proceeding
    __syncthreads();
}

// Forward kernel: block `blockIdx.x` handles one row.
//
// The block first hands its row of `parameters` to copy_parameters_to_shared
// together with a dynamic shared-memory buffer, then evaluates every sample of
// the same row of `x` through the helper selected at compile time by
// <dim, direct_parametrization>.
//
// Launch requirements visible from this file:
//   - 1D grid, one block per row
//   - dynamic shared memory of at least n_params * sizeof(scalar_t) bytes
//
// NOTE(review): copy_parameters_to_shared and the 2d / 4d helpers are not
// defined in this part of the file; this kernel assumes the copy helper
// synchronizes the block before the parameters are read - confirm.
template <typename scalar_t, int dim, bool direct_parametrization>
__global__ void linear_cuda_forward_kernel(
    const scalar_t* __restrict__ parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_chunks,
    int n_params,
    int batch_size,
    int n_func
) {
    static_assert(dim == 1 || dim == 2 || dim == 3 || dim == 4, "dim must be 1, 2, 3, or 4");

    // Dynamic shared memory is declared as raw bytes and viewed as scalar_t.
    extern __shared__ char buffer[];
    scalar_t* shared_parameters = reinterpret_cast<scalar_t*>(buffer);

    // Row handled by this block.
    int t = blockIdx.x;

    // Stage the parameters of row `t`.
    copy_parameters_to_shared(parameters, shared_parameters, n_params, t);

    // Pick the evaluation routine: first by dimensionality, then by
    // parametrization (direct coefficients vs. B-spline knots).
    if constexpr (dim == 1) {
        if constexpr (direct_parametrization) {
            apply_elementwise_function_1d(x, shared_parameters, output, n_chunks, batch_size, t);
        } else {
            apply_elementwise_function_1d_bspline(x, shared_parameters, output, n_chunks, batch_size, t);
        }
    } else if constexpr (dim == 2) {
        if constexpr (direct_parametrization) {
            apply_elementwise_function_2d(x, shared_parameters, output, n_chunks, batch_size, t);
        } else {
            apply_elementwise_function_2d_bspline(x, shared_parameters, output, n_chunks, batch_size, t);
        }
    } else if constexpr (dim == 3) {
        if constexpr (direct_parametrization) {
            apply_elementwise_function_3d(x, shared_parameters, output, n_chunks, batch_size, t);
        } else {
            apply_elementwise_function_3d_bspline(x, shared_parameters, output, n_chunks, batch_size, t);
        }
    } else {
        if constexpr (direct_parametrization) {
            apply_elementwise_function_4d(x, shared_parameters, output, n_chunks, batch_size, t);
        } else {
            apply_elementwise_function_4d_bspline(x, shared_parameters, output, n_chunks, batch_size, t);
        }
    }
}

// Host launcher for linear_cuda_forward_kernel.
//
// Template parameters:
//   dim                    - input dimensionality per sample (1..4)
//   direct_parametrization - true: per-cell polynomial coefficients,
//                            false: B-spline knot values
//
// Arguments:
//   parameters - tensor whose dim 0 is the number of functions (rows) and whose
//                dim 1 is the number of cells per axis (direct) or the number
//                of knots per axis, i.e. cells + 1 (B-spline)
//   x          - tensor with x.size(0) == parameters.size(0) rows and
//                x.size(1) samples per row, `dim` values per sample
//
// Returns a tensor of shape {x.size(0), x.size(1)} with x's dtype and device.
// Raises (via TORCH_CHECK) on inconsistent shapes, dtypes or devices, on sizes
// that do not fit the kernel's 32-bit arguments, and on CUDA errors reported
// while configuring or launching the kernel.
//
// NOTE(review): the layout of `parameters` beyond dims 0 and 1 is consumed by
// copy_parameters_to_shared, which is not defined in this part of the file;
// presumably each row holds exactly n_params contiguous values - confirm, and
// consider validating parameters.numel() here.
template <int dim, bool direct_parametrization>
torch::Tensor linear_cuda_forward(
    torch::Tensor parameters,
    torch::Tensor x
) {
    static_assert(dim == 1 || dim == 2 || dim == 3 || dim == 4, "dim must be 1, 2, 3, or 4");

    // Ensure the input tensors are contiguous and on the GPU
    CHECK_INPUT(parameters);
    CHECK_INPUT(x);

    TORCH_CHECK(parameters.dim() >= 2, "parameters must have at least 2 dimensions, got ", parameters.dim());
    TORCH_CHECK(x.dim() >= 2, "x must have at least 2 dimensions, got ", x.dim());
    TORCH_CHECK(x.scalar_type() == parameters.scalar_type(),
                "x and parameters must have the same dtype");
    TORCH_CHECK(x.device() == parameters.device(),
                "x and parameters must be on the same device");

    constexpr int64_t kIntMax = std::numeric_limits<int>::max();

    const int64_t n_func = parameters.size(0);
    const int64_t batch_size = x.size(1);

    // The kernel uses one block per row of `parameters` and indexes the same
    // row of `x` and `output`, so the row counts have to agree.
    TORCH_CHECK(x.size(0) == n_func,
                "x.size(0) (", x.size(0), ") must equal parameters.size(0) (", n_func, ")");
    // Each sample occupies `dim` consecutive values of its row.
    TORCH_CHECK(x.numel() == n_func * batch_size * dim,
                "x must hold ", dim, " value(s) per sample; expected ",
                n_func * batch_size * dim, " elements, got ", x.numel());
    TORCH_CHECK(n_func <= kIntMax && batch_size <= kIntMax,
                "x is too large for the kernel's 32-bit size arguments");

    // Cells per axis: given directly, or knots - 1 for the B-spline form.
    const int64_t n_chunks = direct_parametrization ? parameters.size(1) : parameters.size(1) - 1;
    TORCH_CHECK(n_chunks >= 1, "at least one chunk per axis is required, got ", n_chunks);

    // Parameters per row: (2 * n_chunks)^dim for the direct form
    // (2^dim coefficients per cell), (n_chunks + 1)^dim knots otherwise.
    const int64_t params_per_axis = direct_parametrization ? 2 * n_chunks : n_chunks + 1;
    TORCH_CHECK(params_per_axis <= kIntMax, "too many chunks per axis: ", n_chunks);
    int64_t n_params = 1;
    for (int axis = 0; axis < dim; ++axis) {
        n_params *= params_per_axis;
        TORCH_CHECK(n_params <= kIntMax, "parameter table per row is too large");
    }

    auto output = torch::empty({x.size(0), x.size(1)}, x.options()); // Output tensor

    // Nothing to compute; a grid with zero blocks would be a launch error.
    if (output.numel() == 0) {
        return output;
    }

    // Launch on x's device and on PyTorch's current stream so the kernel is
    // ordered with the surrounding tensor operations.
    const c10::cuda::CUDAGuard device_guard(x.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Define grid and block sizes
    const dim3 block_dim(1024); // Number of threads per block (adjust as needed)
    const dim3 grid_dim(static_cast<unsigned int>(n_func)); // One block per row

    AT_DISPATCH_FLOATING_TYPES(parameters.scalar_type(), "one_d_linear_cuda_forward", ([&] {
        // Dynamic shared memory: one row of parameters (n_params elements).
        const size_t shared_mem_size = static_cast<size_t>(n_params) * sizeof(scalar_t);
        TORCH_CHECK(shared_mem_size <= static_cast<size_t>(kIntMax),
                    "parameter table per row is too large for shared memory");

        auto kernel = linear_cuda_forward_kernel<scalar_t, dim, direct_parametrization>;

        // Opt in to the requested amount of dynamic shared memory (required
        // above the default per-block limit) and surface a failure instead of
        // letting the launch fail silently.
        cudaError_t status = cudaFuncSetAttribute(
            kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            static_cast<int>(shared_mem_size));
        TORCH_CHECK(status == cudaSuccess,
                    "cudaFuncSetAttribute failed for ", shared_mem_size,
                    " bytes of shared memory: ", cudaGetErrorString(status));

        kernel<<<grid_dim, block_dim, shared_mem_size, stream>>>(
            parameters.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            static_cast<int>(n_chunks),
            static_cast<int>(n_params),
            static_cast<int>(batch_size),
            static_cast<int>(n_func)
        );

        // Launch-configuration errors are only visible through cudaGetLastError.
        status = cudaGetLastError();
        TORCH_CHECK(status == cudaSuccess,
                    "linear_cuda_forward_kernel launch failed: ", cudaGetErrorString(status));
    }));

    return output;
}

// Entry point for the GPU forward pass.
//
// Runs CHECK_INPUT on both tensors and forwards them unchanged to
// linear_cuda_forward instantiated with the same
// <dim, direct_parametrization> template arguments, returning its result.
//
// NOTE(review): CHECK_INPUT is not defined in this part of the file
// (presumably it comes from utilities.h and verifies that the tensor is a
// contiguous CUDA tensor - confirm). linear_cuda_forward applies the same
// checks again, so they run twice on this path.
template <int dim, bool direct_parametrization>
torch::Tensor linear_gpu_forward(
    torch::Tensor parameters,
    torch::Tensor x
) {
    // Validate the inputs before handing them to the launcher.
    CHECK_INPUT(parameters);
    CHECK_INPUT(x);

    return linear_cuda_forward<dim, direct_parametrization>(parameters, x);
}

