#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
using namespace torch::indexing;

#include "utilities.h"


// 2D DKAN forward variant in which each thread loads one float4 (four
// consecutive input values) per step and accumulates four partial sums.
//
// What the code below does:
//   * `x` and `output` are offset by t * batch_size * tile_size elements and
//     then addressed as float4 arrays. Thread threadIdx.x processes float4
//     index i = threadIdx.x, threadIdx.x + blockDim.x, ... while
//     i < batch_size * (tile_size / 4).
//   * The next float4 is prefetched one iteration ahead.
//   * Input pairs are exchanged between lanes with __shfl_sync using a
//     sub-group width of tile_size / 4.
//   * For each input pair the four neighbouring cells of `shared_parameters`
//     (index, +half_size, +(n_chunks + 1) * half_size, +(n_chunks + 2) * half_size)
//     are combined with the four products of (chunk_size - delta) / delta.
//
// Parameters:
//   shared_parameters  parameter table read by every thread
//   x, output          input / output buffers, same element layout
//   n_tiles, n_params  unused here; kept for a uniform signature
//   n_chunks           number of grid intervals over [-1, 1]
//   batch_size         number of rows per tile
//   t                  tile index used to offset x and output
//
// NOTE(review): the float4 reinterpretation is only meaningful when scalar_t
//   is float and the row start is 16-byte aligned -- confirm against callers.
// NOTE(review): __shfl_sync requires a power-of-two width; tile_size / 4 is
//   not a power of two for tile_size 12 or 24, and is 0 for tile_size 2.
// NOTE(review): the shuffles use a full-warp mask inside a loop whose trip
//   count can differ between lanes of the same warp.
// NOTE(review): the delta is computed without the `+ 1.0f` that
//   dkan_kernel_2d_thread_per_tile applies -- left unchanged on purpose.
template <typename scalar_t, int tile_size>
__device__ void dkan_kernel_2d_thread_per_four_rows(
    const scalar_t* __restrict__ shared_parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_tiles,
    int n_chunks,
    int n_params,
    int batch_size,
    int t
) {
    // half baked, for benchmarking purposes
    constexpr int half_tile_size = tile_size / 2;
    constexpr int quarter_tile_size = tile_size / 4;
    constexpr int half_size = half_tile_size * tile_size;
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    constexpr int job_size = 2 * tile_size;
    int thread_idx = threadIdx.x;
    int block_size = blockDim.x;

    // Position of this thread inside its group of 2 * tile_size threads.
    int my_in_job_index = thread_idx % job_size;
    int my_output_index = my_in_job_index % 4;         // selects the rotation of the stored float4
    int my_input_shift_index = (my_in_job_index / 8);  // first source lane for the shuffles
    int my_swap_index = (my_in_job_index / 4) % 2;     // whether to start with the (z, w) pair

    int tile_pos_now;

    // 64-bit offset: t * batch_size * tile_size can exceed the range of int.
    const size_t row_offset = static_cast<size_t>(t) * batch_size * tile_size;
    const scalar_t* x_row = x + row_offset;
    scalar_t* output_row = output + row_offset;

    scalar_t inverse_chunk_size = static_cast<scalar_t>(n_chunks) / static_cast<scalar_t>(2.0f);
    scalar_t chunk_size = 2.0f / static_cast<scalar_t>(n_chunks);

    float temp_first = 0.0f;
    float temp_second = 0.0f;
    float temp_third = 0.0f;
    float temp_fourth = 0.0f;
    // Same type as the values being swapped, so no precision is lost.
    scalar_t tmp;

    scalar_t x_element_first, x_element_second, x_element_third, x_element_fourth;

    scalar_t x_delta_now_1, x_delta_now_2;
    int chunk_index_now;
    int first_index, second_index;
    int tile_pos = thread_idx % half_size;
    scalar_t x_element_1, x_element_2;

    const float4* __restrict__ x_row_float4 = reinterpret_cast<const float4*>(x_row);
    float4* __restrict__ output_row_float4 = reinterpret_cast<float4*>(output_row);

    // Prefetch the first float4 only when this thread has at least one
    // iteration to run; otherwise the load would be out of bounds.
    float x_element_first_next = 0.0f;
    float x_element_second_next = 0.0f;
    float x_element_third_next = 0.0f;
    float x_element_fourth_next = 0.0f;
    if (thread_idx < batch_size * quarter_tile_size) {
        const float4 first_load = x_row_float4[thread_idx];
        x_element_first_next = first_load.x;
        x_element_second_next = first_load.y;
        x_element_third_next = first_load.z;
        x_element_fourth_next = first_load.w;
    }

    float4 output_full;

    scalar_t param_1, param_2, param_3, param_4;
    scalar_t b_spline_1, b_spline_2, b_spline_3, b_spline_4;

    int source_index;
    int chunk_index_common;

    for (int i = thread_idx; i < batch_size * quarter_tile_size; i += block_size) {
        x_element_first = x_element_first_next;
        x_element_second = x_element_second_next;
        x_element_third = x_element_third_next;
        x_element_fourth = x_element_fourth_next;

        // Prefetch the float4 for the next iteration, if there is one.
        if (4 * (i + block_size) < batch_size * tile_size) {
            const float4 next_load = x_row_float4[i + block_size];
            x_element_first_next = next_load.x;
            x_element_second_next = next_load.y;
            x_element_third_next = next_load.z;
            x_element_fourth_next = next_load.w;
        }

        tile_pos_now = tile_pos;

        temp_first = 0.0f;
        temp_second = 0.0f;
        temp_third = 0.0f;
        temp_fourth = 0.0f;

        #pragma unroll
        for (int pair_index = 0; pair_index < 2; ++pair_index){

            source_index= my_input_shift_index;
            // Threads with my_swap_index == 1 expose their (z, w) pair first.
            if (my_swap_index == 1) {
                tmp = x_element_first;
                x_element_first = x_element_third;
                x_element_third = tmp;

                tmp = x_element_second;
                x_element_second = x_element_fourth;
                x_element_fourth = tmp;
            }
            #pragma unroll
            for (int j = 0; j < quarter_tile_size; ++j) {
                x_element_1 = __shfl_sync(0xffffffff, x_element_first, source_index, quarter_tile_size);
                x_element_2 = __shfl_sync(0xffffffff, x_element_second, source_index, quarter_tile_size);

                // Advance the source lane cyclically within the sub-group.
                source_index += 1;
                if (source_index >= quarter_tile_size) {
                    source_index -= quarter_tile_size;
                }

                // Grid cell of each input, truncated toward zero.
                first_index = __float2int_rz((x_element_1 + 1.0f) * inverse_chunk_size);
                second_index = __float2int_rz((x_element_2 + 1.0f) * inverse_chunk_size);

                x_delta_now_1 = fmaf(-first_index, chunk_size, x_element_1);
                x_delta_now_2 = fmaf(-second_index, chunk_size, x_element_2);

                b_spline_1 = (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2);
                b_spline_2 = (chunk_size - x_delta_now_1) * x_delta_now_2;
                b_spline_3 = x_delta_now_1 * (chunk_size - x_delta_now_2);
                b_spline_4 = x_delta_now_1 * x_delta_now_2;

                chunk_index_common =  first_index * (n_chunks + 1) * half_size + second_index * half_size;

                // --- accumulator 1 ---
                chunk_index_now = chunk_index_common + tile_pos_now;
                param_1 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_2 = shared_parameters[chunk_index_now];
                chunk_index_now += n_chunks * half_size;
                param_3 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_4 = shared_parameters[chunk_index_now];

                temp_first = __fmaf_rn(param_1 , b_spline_1, temp_first);
                temp_first = __fmaf_rn(param_2 , b_spline_2, temp_first);
                temp_first = __fmaf_rn(param_3 , b_spline_3, temp_first);
                temp_first = __fmaf_rn(param_4 , b_spline_4, temp_first);

                tile_pos_now += tile_size;
                if (tile_pos_now >= half_size) {
                    tile_pos_now -= half_size;
                }

                // --- accumulator 2 ---
                chunk_index_now = chunk_index_common + tile_pos_now;
                param_1 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_2 = shared_parameters[chunk_index_now];
                chunk_index_now += n_chunks * half_size;
                param_3 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_4 = shared_parameters[chunk_index_now];

                temp_second = __fmaf_rn(param_1 , b_spline_1, temp_second);
                temp_second = __fmaf_rn(param_2 , b_spline_2, temp_second);
                temp_second = __fmaf_rn(param_3 , b_spline_3, temp_second);
                temp_second = __fmaf_rn(param_4 , b_spline_4, temp_second);

                tile_pos_now += tile_size;
                if (tile_pos_now >= half_size) {
                    tile_pos_now -= half_size;
                }

                // --- accumulator 3 ---
                chunk_index_now = chunk_index_common + tile_pos_now;
                param_1 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_2 = shared_parameters[chunk_index_now];
                chunk_index_now += n_chunks * half_size;
                param_3 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_4 = shared_parameters[chunk_index_now];

                temp_third = __fmaf_rn(param_1 , b_spline_1, temp_third);
                temp_third = __fmaf_rn(param_2 , b_spline_2, temp_third);
                temp_third = __fmaf_rn(param_3 , b_spline_3, temp_third);
                temp_third = __fmaf_rn(param_4 , b_spline_4, temp_third);

                tile_pos_now += tile_size;
                if (tile_pos_now >= half_size) {
                    tile_pos_now -= half_size;
                }

                // --- accumulator 4 ---
                chunk_index_now = chunk_index_common + tile_pos_now;
                param_1 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_2 = shared_parameters[chunk_index_now];
                chunk_index_now += n_chunks * half_size;
                param_3 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_4 = shared_parameters[chunk_index_now];

                temp_fourth = __fmaf_rn(param_1 , b_spline_1, temp_fourth);
                temp_fourth = __fmaf_rn(param_2 , b_spline_2, temp_fourth);
                temp_fourth = __fmaf_rn(param_3 , b_spline_3, temp_fourth);
                temp_fourth = __fmaf_rn(param_4 , b_spline_4, temp_fourth);

                tile_pos_now += tile_size;
                if (tile_pos_now >= half_size) {
                    tile_pos_now -= half_size;
                }
            }
            // Expose the other pair for the second pass.
            tmp = x_element_first;
            x_element_first = x_element_third;
            x_element_third = tmp;

            tmp = x_element_second;
            x_element_second = x_element_fourth;
            x_element_fourth = tmp;
        }

        // Rotate the four accumulators by my_output_index before storing.
        if (my_output_index == 0) {
            output_full.x = temp_first;
            output_full.y = temp_second;
            output_full.z = temp_third;
            output_full.w = temp_fourth;
        } else if (my_output_index == 1) {
            output_full.x = temp_second;
            output_full.y = temp_third;
            output_full.z = temp_fourth;
            output_full.w = temp_first;
        } else if (my_output_index == 2) {
            output_full.x = temp_third;
            output_full.y = temp_fourth;
            output_full.z = temp_first;
            output_full.w = temp_second;
        } else {
            output_full.x = temp_fourth;
            output_full.y = temp_first;
            output_full.z = temp_second;
            output_full.w = temp_third;
        }

        output_row_float4[i] = output_full;
    }
}


// Benchmark variant of the 2D DKAN kernel in which each thread owns two
// inputs per step and keeps two running sums.
//
// What the code below does:
//   * It does NOT read `x`: the inputs are synthetic values derived from the
//     loop index (i * 1 / (tile_size * batch_size + 4096)), so only the
//     shuffle / parameter-lookup / FMA cost is exercised.
//   * Input pairs are exchanged between lanes with __shfl_sync using a
//     sub-group width of tile_size / 2.
//   * The two sums are accumulated over ALL iterations and written once at the
//     end to output_row[2 * threadIdx.x] and output_row[2 * threadIdx.x + 1].
//
// Parameters:
//   shared_parameters  parameter table read by every thread
//   x                  unused (see above)
//   output             receives the two sums per thread
//   n_tiles, n_params  unused here; kept for a uniform signature
//   n_chunks           number of grid intervals over [-1, 1]
//   batch_size         number of rows per tile
//   t                  tile index used to offset output
//
// NOTE(review): __shfl_sync requires a power-of-two width; tile_size / 2 is
//   not a power of two for tile_size 12 or 24.
// NOTE(review): the shuffles use a full-warp mask inside a loop whose trip
//   count can differ between lanes of the same warp.
template <typename scalar_t, int tile_size>
__device__ void dkan_kernel_2d_thread_per_two_rows(
    const scalar_t* __restrict__ shared_parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_tiles,
    int n_chunks,
    int n_params,
    int batch_size,
    int t
) {
    // half baked, for benchmarking purposes
    constexpr int half_tile_size = tile_size / 2;
    constexpr int half_size = half_tile_size * tile_size;
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    int thread_idx = threadIdx.x;
    int block_size = blockDim.x;

    int tile_pos_now;

    // 64-bit offset: t * batch_size * tile_size can exceed the range of int.
    scalar_t* output_row = output + static_cast<size_t>(t) * batch_size * tile_size;

    scalar_t inverse_chunk_size = static_cast<scalar_t>(n_chunks) / static_cast<scalar_t>(2.0f);
    scalar_t chunk_size = 2.0f / static_cast<scalar_t>(n_chunks);

    float temp_first = 0.0f;
    float temp_second = 0.0f;

    scalar_t x_element_first, x_element_second;

    scalar_t x_delta_now_1, x_delta_now_2;
    int chunk_index_now;
    int first_index, second_index;
    int tile_pos = thread_idx % half_size;
    scalar_t x_element_1, x_element_2;

    // Synthetic inputs stand in for the real load of x (benchmarking only).
    float inverse_batch_size = 1.0f / static_cast<float>(tile_size * batch_size + 4096);

    scalar_t x_element_first_next = static_cast<float>(thread_idx) * inverse_batch_size;
    scalar_t x_element_second_next = 0.5f * static_cast<float>(thread_idx) * inverse_batch_size;

    scalar_t param_1, param_2, param_3, param_4;
    scalar_t b_spline_1, b_spline_2, b_spline_3, b_spline_4;

    for (int i = thread_idx; i < batch_size * half_tile_size; i += block_size) {
        x_element_first = x_element_first_next;
        x_element_second = x_element_second_next;
        if (2 * (i + block_size) < batch_size * tile_size) {
            x_element_first_next = static_cast<float>(i + block_size) * inverse_batch_size;
            x_element_second_next = 0.5f * static_cast<float>(i + block_size) * inverse_batch_size;
        }

        tile_pos_now = tile_pos;
        #pragma unroll
        for (int j = 0; j < half_tile_size; ++j) {
            x_element_1 = __shfl_sync(0xffffffff, x_element_first, j, half_tile_size);
            x_element_2 = __shfl_sync(0xffffffff, x_element_second, j, half_tile_size);

            // Grid cell of each input, truncated toward zero.
            first_index = __float2int_rz((x_element_1 + 1.0f) * inverse_chunk_size);
            second_index = __float2int_rz((x_element_2 + 1.0f) * inverse_chunk_size);

            x_delta_now_1 = fmaf(-first_index, chunk_size, x_element_1);
            x_delta_now_2 = fmaf(-second_index, chunk_size, x_element_2);

            b_spline_1 = (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2);
            b_spline_2 = (chunk_size - x_delta_now_1) * x_delta_now_2;
            b_spline_3 = x_delta_now_1 * (chunk_size - x_delta_now_2);
            b_spline_4 = x_delta_now_1 * x_delta_now_2;

            // --- accumulator 1 ---
            chunk_index_now = first_index * (n_chunks + 1) * half_size + second_index * half_size + tile_pos_now;
            param_1 = shared_parameters[chunk_index_now];
            chunk_index_now += half_size;
            param_2 = shared_parameters[chunk_index_now];
            chunk_index_now += n_chunks * half_size;
            param_3 = shared_parameters[chunk_index_now];
            chunk_index_now += half_size;
            param_4 = shared_parameters[chunk_index_now];

            temp_first = __fmaf_rn(param_1 , b_spline_1, temp_first);
            temp_first = __fmaf_rn(param_2 , b_spline_2, temp_first);
            temp_first = __fmaf_rn(param_3 , b_spline_3, temp_first);
            temp_first = __fmaf_rn(param_4 , b_spline_4, temp_first);

            tile_pos_now += tile_size;
            if (tile_pos_now >= half_size) {
                tile_pos_now -= half_size;
            }

            // --- accumulator 2 ---
            chunk_index_now = first_index * (n_chunks + 1) * half_size + second_index * half_size + tile_pos_now;
            param_1 = shared_parameters[chunk_index_now];
            chunk_index_now += half_size;
            param_2 = shared_parameters[chunk_index_now];
            chunk_index_now += n_chunks * half_size;
            param_3 = shared_parameters[chunk_index_now];
            chunk_index_now += half_size;
            param_4 = shared_parameters[chunk_index_now];

            temp_second = __fmaf_rn(param_1 , b_spline_1, temp_second);
            temp_second = __fmaf_rn(param_2 , b_spline_2, temp_second);
            temp_second = __fmaf_rn(param_3 , b_spline_3, temp_second);
            temp_second = __fmaf_rn(param_4 , b_spline_4, temp_second);

            tile_pos_now += tile_size;
            if (tile_pos_now >= half_size) {
                tile_pos_now -= half_size;
            }
        }
    }

    // Sink the results so the work is not optimised away. Only threads whose
    // two slots lie inside this tile's output region may write; these are
    // exactly the threads that ran at least one loop iteration.
    if (thread_idx < batch_size * half_tile_size) {
        output_row[2 * thread_idx] = temp_first;
        output_row[2 * thread_idx + 1] = temp_second;
    }
}

// 2D DKAN forward variant in which each thread processes one whole row of
// tile_size inputs and produces the matching tile_size outputs.
//
// What the code below does:
//   * Thread threadIdx.x handles rows threadIdx.x, threadIdx.x + blockDim.x, ...
//     of the tile selected by `t` (element offset t * batch_size * tile_size).
//   * The row is loaded with __ldcg one iteration ahead, in an order rotated
//     by 2 * slow_index so that the pair consumed at step j is pair
//     (j + slow_index) mod (tile_size / 2).
//   * For each input pair, the grid cell is found by truncating
//     (x + 1) * n_chunks / 2, and the four neighbouring parameter cells are
//     blended with products of (chunk_size - delta) and delta.
//   * Output column k of the thread is (fast_index + k) mod tile_size; the
//     results are written back with the same rotation.
//
// Parameters:
//   shared_parameters  parameter table read by every thread
//   x, output          input / output buffers, same element layout
//   n_tiles, n_params  unused here; kept for a uniform signature
//   n_chunks           number of grid intervals over [-1, 1]
//   batch_size         number of rows per tile
//   t                  tile index used to offset x and output
//
// The cdf_grid template flag is currently ignored.
template <typename scalar_t, int tile_size, bool cdf_grid>
__device__ void dkan_kernel_2d_thread_per_tile(
    const scalar_t* __restrict__ shared_parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_tiles,
    int n_chunks,
    int n_params,
    int batch_size,
    int t
) {
    //static_assert(!cdf_grid, "cdf_grid = true is not implemented."); for now just silently ignore cdf_grid = true

    constexpr int half_tile_size = tile_size / 2;
    constexpr int half_size = half_tile_size * tile_size;
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    int thread_idx = threadIdx.x;
    int block_size = blockDim.x;

    // Threads beyond the last row have nothing to do. Leaving here also keeps
    // the prefetch below from reading past the end of this tile's input.
    if (thread_idx >= batch_size) {
        return;
    }

    // 64-bit offset: t * batch_size * tile_size can exceed the range of int.
    const size_t row_offset = static_cast<size_t>(t) * batch_size * tile_size;
    const scalar_t* x_row = x + row_offset;
    scalar_t* output_row = output + row_offset;

    scalar_t inverse_chunk_size;
    scalar_t chunk_size;

    inverse_chunk_size = static_cast<scalar_t>(n_chunks) / static_cast<scalar_t>(2.0f);
    chunk_size = 2.0f / static_cast<scalar_t>(n_chunks);

    int first_index, second_index;
    scalar_t x_delta_now_1, x_delta_now_2;

    scalar_t x_current[tile_size], x_current_next[tile_size];

    scalar_t output_current[tile_size];
    scalar_t param_1, param_2, param_3, param_4;
    scalar_t b_spline_1, b_spline_2, b_spline_3, b_spline_4;
    int chunk_index_now;
    int chunk_index_precomputed;

    // Per-thread starting position inside the [tile_size / 2][tile_size]
    // parameter slab; different threads start at different positions.
    int full_index = thread_idx % half_size;
    int fast_index = full_index % tile_size;
    int slow_index = full_index / tile_size;

    int index_now;

    int fast_index_now, slow_index_now;

    // Prefetch the first row in rotated order.
    #pragma unroll
    for (int j = 0; j < tile_size; ++j) {
        index_now = j + 2 * slow_index;
        if (index_now >= tile_size) {
            index_now -= tile_size;
        }
        //x_current_next[j] = x_row[thread_idx * tile_size + index_now];
        x_current_next[j] = __ldcg(x_row + thread_idx * tile_size + index_now);
    }

    for (int i = thread_idx * tile_size; i < batch_size * tile_size; i += block_size * tile_size) {

        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            x_current[j] = x_current_next[j];
        }

        // Prefetch the row for the next iteration (same rotation), if any.
        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            index_now = j + 2 * slow_index;
            if (index_now >= tile_size) {
                index_now -= tile_size;
            }
            if (i + index_now + block_size * tile_size < batch_size * tile_size) {
                //x_current_next[j] = x_row[i + index_now + block_size * tile_size];
                x_current_next[j] = __ldcg(x_row + i + index_now + block_size * tile_size);
            }
        }

        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            output_current[j] = 0.0f;
        }

        #pragma unroll
        for (int j = 0; j < half_tile_size; ++j) {
            // Pair index matching the rotated load order.
            slow_index_now = j + slow_index;
            if (slow_index_now >= half_tile_size) {
                slow_index_now -= half_tile_size;
            }
            // Grid cell of each input, truncated toward zero.
            first_index = __float2int_rz((x_current[2 * j] + 1.0f) * inverse_chunk_size);
            second_index = __float2int_rz((x_current[2 * j + 1] + 1.0f) * inverse_chunk_size);

            // delta = (x + 1) - index * chunk_size
            x_delta_now_1 = __fmaf_rn(-first_index, chunk_size, x_current[2 * j]);
            x_delta_now_2 = __fmaf_rn(-second_index, chunk_size, x_current[2 * j + 1]);

            x_delta_now_1 += 1.0f;
            x_delta_now_2 += 1.0f;

            b_spline_1 = (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2);
            b_spline_2 = (chunk_size - x_delta_now_1) * x_delta_now_2;
            b_spline_3 = x_delta_now_1 * (chunk_size - x_delta_now_2);
            b_spline_4 = x_delta_now_1 * x_delta_now_2;

            chunk_index_precomputed = first_index * (n_chunks + 1) * half_size + second_index * half_size + slow_index_now * tile_size;

            #pragma unroll
            for (int k = 0; k < tile_size; ++k) {

                fast_index_now = fast_index + k;
                if (fast_index_now >= tile_size) {
                    fast_index_now -= tile_size;
                }
                chunk_index_now = chunk_index_precomputed + fast_index_now;
                param_1 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_2 = shared_parameters[chunk_index_now];
                chunk_index_now += n_chunks * half_size;
                param_3 = shared_parameters[chunk_index_now];
                chunk_index_now += half_size;
                param_4 = shared_parameters[chunk_index_now];

                output_current[k] = __fmaf_rn(param_1 , b_spline_1, output_current[k]);
                output_current[k] = __fmaf_rn(param_2 , b_spline_2, output_current[k]);
                output_current[k] = __fmaf_rn(param_3 , b_spline_3, output_current[k]);
                output_current[k] = __fmaf_rn(param_4 , b_spline_4, output_current[k]);

            }
        }

        // Undo the column rotation while storing.
        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            index_now = fast_index + j;
            if (index_now >= tile_size) {
                index_now -= tile_size;
            }
            output_row[i + index_now] = output_current[j];
            /*asm volatile(
                "atom.global.add.cg.f32 [%0], %1;"
                :
                : "l"(output_row + i + index_now),
                "f"(output_current[j])
            );*/
        }
    }
}

// Benchmark stub of the 2D DKAN kernel with one thread per input element.
//
// What the code below does:
//   * It does NOT read `x`: inputs are synthetic values derived from the loop
//     index (i * 1 / (tile_size * batch_size + 4096)).
//   * Each thread computes its own grid cell index and the lanes exchange
//     indices and values with __shfl_sync (sub-group width tile_size).
//   * The interpolation weights are replaced by a single shuffled value and
//     second_index is forced equal to first_index, so the result is NOT a
//     valid DKAN output; only the memory / shuffle / FMA pattern is exercised.
//   * One running sum per thread is written to output_row[threadIdx.x] at the end.
//
// Parameters:
//   shared_parameters  parameter table read by every thread
//   x                  unused (see above)
//   output             receives one sum per thread
//   n_tiles, n_params  unused here; kept for a uniform signature
//   n_chunks           number of grid intervals over [-1, 1]
//   batch_size         number of rows per tile
//   t                  tile index used to offset output
//
// NOTE(review): __shfl_sync requires a power-of-two width; tile_size is not a
//   power of two for 12 or 24.
// NOTE(review): the shuffles use a full-warp mask inside a loop whose trip
//   count can differ between lanes of the same warp.
template <typename scalar_t, int tile_size>
__device__ void dkan_kernel_2d(
    const scalar_t* __restrict__ shared_parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_tiles,
    int n_chunks,
    int n_params,
    int batch_size,
    int t
) {
    constexpr int half_tile_size = tile_size / 2;
    constexpr int half_size = half_tile_size * tile_size;
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    int thread_idx = threadIdx.x;
    int block_size = blockDim.x;

    int tile_pos_now;

    // 64-bit offset: t * batch_size * tile_size can exceed the range of int.
    scalar_t* output_row = output + static_cast<size_t>(t) * batch_size * tile_size;

    scalar_t inverse_chunk_size = static_cast<scalar_t>(n_chunks) / static_cast<scalar_t>(2.0f);
    scalar_t temp, x_element;
    int chunk_index_now;
    int first_index, second_index;
    int tile_pos = thread_idx % half_size;

    // Synthetic inputs stand in for the real load of x (benchmarking only).
    float inverse_batch_size = 1.0f / static_cast<float>(tile_size * batch_size + 4096);
    scalar_t x_element_next = static_cast<float>(thread_idx) * inverse_batch_size;

    scalar_t param_1, param_2, param_3, param_4;

    int my_index;
    float b_spline_1;

    temp = 0.0f;
    for (int i = thread_idx; i < batch_size * tile_size; i += block_size) {
        x_element = x_element_next;
        if (i + block_size < batch_size * tile_size) {
            x_element_next = static_cast<float>(i + block_size) * inverse_batch_size;
        }
        // Grid cell of this thread's input, truncated toward zero.
        my_index = __float2int_rz((x_element + 1.0f) * inverse_chunk_size);
        tile_pos_now = tile_pos;
        #pragma unroll
        for (int j = 0; j < half_tile_size; ++j) {
            // Full version (disabled): shuffle both inputs of the pair and
            // derive indices and deltas from them.
            //x_element_1 = __shfl_sync(0xffffffff, x_element, 2 * j, tile_size);
            //x_element_2 = __shfl_sync(0xffffffff, x_element, 2 * j + 1, tile_size);
            //first_index = __float2int_rz((x_element_1 + 1.0f) * inverse_chunk_size);
            //second_index = __float2int_rz((x_element_2 + 1.0f) * inverse_chunk_size);
            //x_delta_now_1 = __fmaf_rn(-first_index, chunk_size, x_element_1);
            //x_delta_now_2 = __fmaf_rn(-second_index, chunk_size, x_element_2);

            first_index = __shfl_sync(0xffffffff, my_index, 2 * j, tile_size);
            //second_index = __shfl_sync(0xffffffff, my_index, 2 * j + 1, tile_size);
            second_index = first_index;

            chunk_index_now = first_index * (n_chunks + 1) * half_size + second_index * half_size + tile_pos_now;
            param_1 = shared_parameters[chunk_index_now];
            chunk_index_now += half_size;
            param_2 = shared_parameters[chunk_index_now];
            chunk_index_now += n_chunks * half_size;
            param_3 = shared_parameters[chunk_index_now];
            chunk_index_now += half_size;
            param_4 = shared_parameters[chunk_index_now];

            // Full version (disabled): weight each parameter by the product of
            // (chunk_size - delta) / delta terms.
            /*temp = __fmaf_rn(param_1, (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2), temp);
            temp = __fmaf_rn(param_2, (chunk_size - x_delta_now_1) * x_delta_now_2, temp);
            temp = __fmaf_rn(param_3, x_delta_now_1 * (chunk_size - x_delta_now_2), temp);
            temp = __fmaf_rn(param_4, x_delta_now_1 * x_delta_now_2, temp);*/
            b_spline_1 = __shfl_sync(0xffffffff, x_element, 2 * j, tile_size);

            temp = __fmaf_rn(param_1, b_spline_1, temp);
            temp = __fmaf_rn(param_2, b_spline_1, temp);
            temp = __fmaf_rn(param_3, b_spline_1, temp);
            temp = __fmaf_rn(param_4, b_spline_1, temp);

            tile_pos_now += tile_size;
            if (tile_pos_now >= half_size) {
                tile_pos_now -= half_size;
            }
        }
    }

    // Sink the result so the work is not optimised away. Only threads whose
    // slot lies inside this tile's output region may write.
    if (thread_idx < batch_size * tile_size) {
        output_row[thread_idx] = temp;
    }
}


// 1D DKAN forward pass for the tile selected by `t`, one thread per element.
//
// What the code below does:
//   * The block walks the batch_size * tile_size elements of the tile in
//     steps of blockDim.x; thread threadIdx.x handles element base + threadIdx.x.
//   * Each thread computes the grid cell of its input (truncating
//     (x + 1) * n_chunks / 2) and the remainder x - index * chunk_size, and
//     publishes both in x_delta[threadIdx.x] / chunk_index[threadIdx.x].
//   * After a block barrier, each thread reads the tile_size entries of its
//     own group (threads sharing threadIdx.x / tile_size) and accumulates
//       (chunk_size - delta) * P[idx] + delta * P[idx + 1]
//     with idx = cell * tile_size * tile_size + j * tile_size + element_idx.
//
// Parameters:
//   shared_parameters  parameter table read by every thread
//   x, output          input / output buffers, same element layout
//   x_delta            scratch, at least blockDim.x entries, exchanged between threads
//   chunk_index        scratch, at least blockDim.x entries, exchanged between threads
//   n_tiles, n_chunks, n_params  n_chunks sets the grid; the others are unused here
//   tile_size          number of elements per row
//   batch_size         number of rows per tile
//   t                  tile index used to offset x and output
//
// The outer loop has the same trip count for every thread of the block and
// uses a per-thread predicate for the tail, so both __syncthreads() calls are
// reached by all threads even when batch_size * tile_size is not a multiple
// of blockDim.x.
//
// NOTE(review): the second parameter is read at idx + 1, i.e. the neighbouring
//   element_idx, not the neighbouring grid cell -- confirm the intended layout.
// NOTE(review): the delta omits the `+ 1.0f` that
//   dkan_kernel_2d_thread_per_tile applies -- left unchanged on purpose.
// NOTE(review): assumes blockDim.x is a multiple of tile_size; otherwise the
//   last group reads scratch entries that no thread wrote.
template <typename scalar_t>
__device__ void dkan_kernel_1d(
    const scalar_t* __restrict__ shared_parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    scalar_t* __restrict__ x_delta,
    int* __restrict__ chunk_index,
    int n_tiles,
    int tile_size,
    int n_chunks,
    int n_params,
    int batch_size,
    int t
) {
    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    const int tile_idx = thread_idx / tile_size;
    const int element_idx = thread_idx % tile_size;

    // First scratch slot of this thread's group.
    const int offset = tile_idx * tile_size;

    // 64-bit offset: t * batch_size * tile_size can exceed the range of int.
    const size_t row_offset = static_cast<size_t>(t) * batch_size * tile_size;
    const scalar_t* x_row = x + row_offset;
    scalar_t* output_row = output + row_offset;

    scalar_t inverse_chunk_size = static_cast<scalar_t>(n_chunks) / static_cast<scalar_t>(2.0f);
    scalar_t chunk_size = 2.0f / static_cast<scalar_t>(n_chunks);
    scalar_t temp, x_element;
    scalar_t x_delta_now;
    int chunk_index_now;

    const int total_elements = batch_size * tile_size;

    for (int base = 0; base < total_elements; base += block_size) {
        const int i = base + thread_idx;
        const bool in_range = i < total_elements;

        if (in_range) {
            x_element = x_row[i];
            temp = (x_element + 1.0f) * inverse_chunk_size;
            chunk_index_now = __float2int_rz(temp);
            x_delta[thread_idx] = fmaf(-chunk_index_now, chunk_size, x_element);
            chunk_index[thread_idx] = chunk_index_now;
        }
        // Make every thread's delta / cell index visible to its group.
        __syncthreads();

        if (in_range) {
            // Perform the DKAN operation
            temp = 0.0f;
            for (int j = 0; j < tile_size; ++j) {
                x_delta_now = x_delta[j + offset];
                chunk_index_now = chunk_index[j + offset] * tile_size * tile_size + j * tile_size + element_idx;
                temp += (chunk_size - x_delta_now) * shared_parameters[chunk_index_now] + x_delta_now * shared_parameters[chunk_index_now + 1];
            }

            output_row[i] = temp;
        }
        // Do not overwrite the scratch buffers until every thread has read them.
        __syncthreads();
    }
}

// Entry kernel for the DKAN forward pass.
//
// Each block handles the tile blockIdx.x:
//   1. the block's n_params parameters (starting at parameters + blockIdx.x * n_params)
//      are copied cooperatively into dynamic shared memory;
//   2. after a block barrier, the dimension-specific device kernel is run.
//
// Launch requirements, as implied by the code:
//   * at least n_params * sizeof(scalar_t) bytes of dynamic shared memory;
//   * one block per tile (gridDim.x tiles are processed).
//
// Parameters:
//   parameters  all tiles' parameters, n_params contiguous values per tile
//   x, output   input / output buffers forwarded to the device kernel
//   n_tiles, n_chunks, n_params, batch_size  forwarded to the device kernel
//
// dim == 1 is currently a no-op (its device kernel call is disabled).
// The cdf_grid template flag is not used by this kernel.
//
// NOTE(review): `buffer` is declared as char[]; its alignment for
//   scalar_t = double is presumably fine in practice -- confirm.
template <typename scalar_t, int dim, int tile_size, bool cdf_grid>
__global__ void linear_cuda_dkan_forward_kernel(
    const scalar_t* __restrict__ parameters,
    const scalar_t* __restrict__ x,
    scalar_t* __restrict__ output,
    int n_tiles,
    int n_chunks,
    int n_params,
    int batch_size
) {
    extern __shared__ char buffer[];

    char* shared_mem = buffer;

    // Reinterpret the char buffer as scalar_t* for shared parameters
    scalar_t* shared_parameters = reinterpret_cast<scalar_t*>(shared_mem);
    int t = blockIdx.x; // Index of the tile handled by this block

    // Stage this tile's parameters in shared memory. Without this the device
    // kernels below would read uninitialised shared memory.
    const scalar_t* tile_parameters = parameters + static_cast<size_t>(t) * n_params;
    for (int p = threadIdx.x; p < n_params; p += blockDim.x) {
        shared_parameters[p] = tile_parameters[p];
    }
    // All parameters must be in place before any thread starts reading them.
    __syncthreads();

    if constexpr (dim == 1) {
        /*dkan_kernel_1d(
            shared_parameters,
            x,
            output,
            x_delta,
            chunk_index,
            n_tiles,
            tile_size,
            n_chunks,
            n_params,
            batch_size,
            t
        );*/
    } else if constexpr (dim == 2) {
        dkan_kernel_2d_thread_per_four_rows<scalar_t, tile_size>(
            shared_parameters,
            x,
            output,
            n_tiles,
            n_chunks,
            n_params,
            batch_size,
            t
        );
    } else {
        static_assert(dim == 1 || dim == 2, "dim must be 1 or 2");
    }
}

// Launches one <scalar_t, dim, tile_size, cdf_grid> instantiation of
// linear_cuda_dkan_forward_kernel.
//
// Raises the kernel's dynamic shared-memory limit to the size of the parameter
// table (n_params elements of scalar_t), launches it with that much dynamic
// shared memory, and turns a CUDA error from either step into an exception so
// that a failed launch cannot be mistaken for a valid result.
template <typename scalar_t, int dim, int tile_size, bool cdf_grid>
static void launch_linear_cuda_dkan_forward_kernel(
    const torch::Tensor& parameters,
    const torch::Tensor& x,
    torch::Tensor& output,
    dim3 grid_dim,
    dim3 block_dim,
    int n_tiles,
    int n_chunks,
    int n_params,
    int batch_size
) {
    const size_t shared_mem_size = static_cast<size_t>(n_params) * sizeof(scalar_t);

    const cudaError_t attr_status = cudaFuncSetAttribute(
        linear_cuda_dkan_forward_kernel<scalar_t, dim, tile_size, cdf_grid>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        static_cast<int>(shared_mem_size));
    if (attr_status != cudaSuccess) {
        // Clear the recorded error so that unrelated later CUDA calls are not blamed for it.
        (void)cudaGetLastError();
        TORCH_CHECK(false,
                    "linear_cuda_dkan_forward: cannot allow ", shared_mem_size,
                    " bytes of dynamic shared memory per block: ",
                    cudaGetErrorString(attr_status));
    }

    // No stream argument is given, so the launch goes to the default stream.
    linear_cuda_dkan_forward_kernel<scalar_t, dim, tile_size, cdf_grid>
        <<<grid_dim, block_dim, shared_mem_size>>>(
            parameters.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            n_tiles,
            n_chunks,
            n_params,
            batch_size
        );

    // Kernel launches do not return a status; configuration errors only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "linear_cuda_dkan_forward: kernel launch failed: ",
                cudaGetErrorString(launch_status));
}

// Host wrapper for the DKAN linear forward kernel.
//
// Expected layouts:
//   x:                   [n_tiles, batch_size, tile_size]
//   parameters (dim==1): [n_tiles, tile_size, tile_size, n_chunks + 1]
//   parameters (dim==2): [n_tiles, n_chunks + 1, n_chunks + 1, tile_size / 2, tile_size]
//
// Template parameters:
//   dim         1 or 2, selects the parameter layout and the kernel code path
//   block_size  number of threads per block
//   cdf_grid    forwarded to the kernel as a template argument
//
// One block is launched per tile, with n_params * sizeof(scalar_t) bytes of
// dynamic shared memory.
//
// Returns a new tensor with the same shape and options as x. It is allocated
// uninitialised; its contents are whatever the kernel writes.
//
// Throws:
//   std::invalid_argument  if tile_size is not one of {2, 4, 8, 12, 16, 24, 32}
//   c10::Error             (via TORCH_CHECK) if x does not match the layout
//                          above, if dtypes differ, or if CUDA rejects the
//                          shared-memory request or the launch
template <int dim, int block_size, bool cdf_grid>
torch::Tensor linear_cuda_dkan_forward(
    torch::Tensor parameters,
    torch::Tensor x
) {
    static_assert(dim == 1 || dim == 2, "dim must be 1 or 2");

    // Ensure the input tensors are contiguous and on the GPU
    CHECK_INPUT(parameters);
    CHECK_INPUT(x);

    // The axis holding n_chunks + 1 and the axis holding tile_size depend on the layout.
    const int n_tiles = static_cast<int>(parameters.size(0));
    const int n_chunks = static_cast<int>(parameters.size(dim == 2 ? 1 : 3)) - 1;
    const int tile_size = static_cast<int>(parameters.size(dim == 2 ? 4 : 1));

    // Only these tile sizes have a kernel instantiation.
    if (tile_size != 2 && tile_size != 4 && tile_size != 8 &&
        tile_size != 12 && tile_size != 16 && tile_size != 24 && tile_size != 32) {
        throw std::invalid_argument("tile_size must be one of {2, 4, 8, 12, 16, 24, 32}");
    }

    // The kernel addresses x and output as tile * batch_size * tile_size, so a
    // mismatched x would be read (and output written) out of bounds.
    TORCH_CHECK(x.dim() == 3,
                "x must have shape [n_tiles, batch_size, tile_size], got ", x.dim(), " dimensions");
    TORCH_CHECK(x.size(0) == n_tiles,
                "x.size(0) (", x.size(0), ") must equal n_tiles (", n_tiles, ")");
    TORCH_CHECK(x.size(2) == tile_size,
                "x.size(2) (", x.size(2), ") must equal tile_size (", tile_size, ")");
    TORCH_CHECK(x.scalar_type() == parameters.scalar_type(),
                "x and parameters must have the same dtype");

    const int batch_size = static_cast<int>(x.size(1));

    // Number of parameter elements per tile, i.e. the shared-memory table length.
    int n_params;
    if constexpr (dim == 1) {
        n_params = tile_size * tile_size * (n_chunks + 1);
    } else {
        n_params = tile_size * tile_size * (n_chunks + 1) * (n_chunks + 1) / 2;
    }

    auto output = torch::empty({x.size(0), x.size(1), x.size(2)}, x.options()); // Output tensor

    // Nothing to compute; a grid of zero blocks would be an invalid launch configuration.
    if (output.numel() == 0) {
        return output;
    }

    const dim3 block_dim(block_size); // Number of threads per block
    const dim3 grid_dim(n_tiles);     // One block per tile

    AT_DISPATCH_FLOATING_TYPES(parameters.scalar_type(), "linear_cuda_dkan_forward", ([&] {
        // tile_size is a kernel template argument, so map the runtime value to an instantiation.
        switch (tile_size) {
            case 2:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 2, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            case 4:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 4, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            case 8:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 8, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            case 12:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 12, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            case 16:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 16, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            case 24:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 24, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            case 32:
                launch_linear_cuda_dkan_forward_kernel<scalar_t, dim, 32, cdf_grid>(
                    parameters, x, output, grid_dim, block_dim, n_tiles, n_chunks, n_params, batch_size);
                break;
            default:
                // Unreachable: tile_size was validated above.
                break;
        }
    }));

    return output;
}

// Maps a runtime block size onto the matching compile-time instantiation of
// linear_cuda_dkan_forward<dim, block_size, cdf_grid>.
//
// Both tensors are run through CHECK_INPUT first. Supported block sizes are the
// multiples of 128 from 128 to 1024; any other value throws std::runtime_error.
template <int dim, bool cdf_grid>
torch::Tensor linear_gpu_dkan_forward_inner(
    torch::Tensor parameters,
    torch::Tensor x,
    int block_size
) {
    CHECK_INPUT(parameters);
    CHECK_INPUT(x);

    if (block_size == 128) {
        return linear_cuda_dkan_forward<dim, 128, cdf_grid>(parameters, x);
    }
    if (block_size == 256) {
        return linear_cuda_dkan_forward<dim, 256, cdf_grid>(parameters, x);
    }
    if (block_size == 384) {
        return linear_cuda_dkan_forward<dim, 384, cdf_grid>(parameters, x);
    }
    if (block_size == 512) {
        return linear_cuda_dkan_forward<dim, 512, cdf_grid>(parameters, x);
    }
    if (block_size == 640) {
        return linear_cuda_dkan_forward<dim, 640, cdf_grid>(parameters, x);
    }
    if (block_size == 768) {
        return linear_cuda_dkan_forward<dim, 768, cdf_grid>(parameters, x);
    }
    if (block_size == 896) {
        return linear_cuda_dkan_forward<dim, 896, cdf_grid>(parameters, x);
    }
    if (block_size == 1024) {
        return linear_cuda_dkan_forward<dim, 1024, cdf_grid>(parameters, x);
    }

    // No instantiation exists for this block size.
    throw std::runtime_error(
        "Unsupported block size. Supported block sizes are {128, 256, 384, 512, 640, 768, 896, 1024}."
    );
}


// Public entry point for the DKAN forward pass with `dim` fixed at compile time.
//
// Converts the runtime `cdf_grid` flag into the corresponding template argument
// and forwards everything else unchanged to linear_gpu_dkan_forward_inner, whose
// result is returned as-is.
template <int dim>
torch::Tensor linear_gpu_dkan_forward(
    torch::Tensor parameters,
    torch::Tensor x,
    int block_size,
    bool cdf_grid
) {
    // Only the selected branch of the conditional is evaluated.
    return cdf_grid
        ? linear_gpu_dkan_forward_inner<dim, true>(parameters, x, block_size)
        : linear_gpu_dkan_forward_inner<dim, false>(parameters, x, block_size);
}