#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
using namespace torch::indexing;

#include "../utilities.h"

/**
 * Accumulates one (input tile, output tile) contribution of a 2D piecewise-bilinear
 * layer into `output`, with each thread handling four consecutive floats (one
 * float4) of the input and output rows.
 *
 * Work decomposition (as implemented below):
 *   - The block processes `batch_size * tile_size / 4` jobs in a block-stride loop
 *     (`i = threadIdx.x; i += blockDim.x`). Job `i` belongs to sample
 *     `i / (tile_size / 4)`; the float4 slot inside the tile is
 *     `threadIdx.x & (tile_size / 4 - 1)`.
 *   - Each job reads one float4 of `x`, treats it as two input pairs
 *     (x, y) and (z, w), and exchanges values between the `tile_size / 4` lanes of
 *     its shuffle group so every lane sees every pair of the tile.
 *   - For every pair the 2D grid cell is located, four weights are formed, and the
 *     four corner parameters of that cell are gathered from `shared_parameters`
 *     for four output slots.
 *   - The four accumulators are rotated into a float4 and added to `output` with a
 *     vector atomicAdd.
 *
 * Template parameters:
 *   tile_size  Tile edge length; must be 4, 8, 16 or 32 (enforced by static_assert).
 *   cdf_grid   true:  grid nodes are placed at atanh(k * 2 / n_chunks - 1), i.e. uniform
 *                     in tanh(x); weights are normalised by the cell extents.
 *              false: grid is uniform on [-1, 1] with spacing 2 / n_chunks; weights are
 *                     left unnormalised (products of distances).
 *
 * Parameters:
 *   shared_parameters  Read-only parameter table. Indexed as
 *                      (first_cell * (n_chunks + 1) + second_cell) * (tile_size^2 / 2) + ...,
 *                      so it must hold (n_chunks + 1)^2 nodes of tile_size^2 / 2 floats each.
 *                      NOTE(review): the name suggests it lives in shared memory - confirm
 *                      with the caller.
 *   x                  Input; reinterpreted as float4 and read at
 *                      sample * (N_in / 4) + my_tile_in * (tile_size / 4) + slot.
 *   output             Output; reinterpreted as float4 and updated atomically at
 *                      sample * (N_out / 4) + my_tile_out * (tile_size / 4) + slot.
 *   my_tile_in         Index of the input tile handled by this call.
 *   my_tile_out        Index of the output tile handled by this call.
 *   N_in, N_out        Row lengths in floats. NOTE(review): the float4 reinterpretation
 *                      assumes both are multiples of 4 and that `x` / `output` are 16-byte
 *                      aligned - confirm with the launcher.
 *   n_chunks           Number of grid cells per axis.
 *   batch_size         Number of samples (rows).
 *
 * Review notes:
 *   - atomicAdd on float4 is documented by the CUDA programming guide as available only
 *     on compute capability 9.x and only for global memory.
 *   - The shuffles use the full-warp mask 0xffffffff inside the block-stride loop. That is
 *     only well-defined if all lanes of a warp execute the same loop iterations
 *     (e.g. blockDim.x and batch_size * tile_size / 4 are both multiples of 32) -
 *     TODO confirm with the launch configuration.
 */
template <int tile_size, bool cdf_grid>
__device__ void dkan_full_kernel_2d_thread_per_four_rows(
    const float* __restrict__ shared_parameters,
    const float* __restrict__ x,
    float* __restrict__ output,
    int my_tile_in,
    int my_tile_out,
    int N_in,
    int N_out,
    int n_chunks,
    int batch_size
) {
    // half baked, for benchmarking purposes
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    constexpr int half_tile_size = tile_size / 2;
    constexpr int quarter_tile_size = tile_size / 4;
    constexpr int quarter_tile_size_m_one = quarter_tile_size - 1;

    // log2(tile_size / 4), so divisions by quarter_tile_size can be done with shifts.
    constexpr int quarter_tile_size_log = []() constexpr {
        if constexpr (tile_size == 4) {
            return 0;
        } else if constexpr (tile_size == 8) {
            return 1;
        } else if constexpr (tile_size == 16) {
            return 2;
        } else if constexpr (tile_size == 32) {
            return 3;
        } else {
            static_assert(tile_size == 4 || tile_size == 8 || tile_size == 16 || tile_size == 32,
                        "Unsupported tile size");
            return -1; // This line will never be reached.
        }
    }();

    // Number of floats stored per grid node in shared_parameters.
    constexpr int half_size = half_tile_size * tile_size;
    // Mask that wraps my_output_index into [0, 2 * tile_size).
    constexpr int output_index_range = 2 * tile_size - 1;

    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // All global offsets below are in float4 units.
    const int input_offset = my_tile_in * quarter_tile_size;
    const int output_offset = my_tile_out * quarter_tile_size;
    const int quarter_N_in = N_in / 4;
    const int quarter_N_out = N_out / 4;

    const float inverse_chunk_size = static_cast<float>(n_chunks) / 2.0f;
    const float chunk_size = 2.0f / static_cast<float>(n_chunks);

    const float4* __restrict__ x_float4 = reinterpret_cast<const float4*>(x);
    float4* __restrict__ output_float4 = reinterpret_cast<float4*>(output);

    // float4 slot of this thread inside the tile; constant across loop iterations.
    const int job_index = thread_idx & quarter_tile_size_m_one;

    // Prefetch the input of the first job. Zero-initialised so that threads without a
    // first job never hold an indeterminate value.
    float4 x_element_next_full = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
    {
        const int first_sample = thread_idx >> quarter_tile_size_log;
        if (first_sample < batch_size) {
            // 64-bit offset: sample * row length can exceed INT_MAX for large inputs.
            x_element_next_full = __ldcg(
                x_float4 + (static_cast<long long>(first_sample) * quarter_N_in + input_offset + job_index));
        }
    }

    // Decode the per-thread schedule from the thread index, low bits first:
    //   [swap bit][input shift : quarter_tile_size_log bits][output slot : 2 bits]
    int full_index = thread_idx;
    if constexpr (tile_size == 4) {
        full_index = full_index >> 2;
    }
    if constexpr (tile_size == 8) {
        full_index = full_index >> 1;
    }

    // 1 -> this lane processes the (z, w) pair before the (x, y) pair.
    const int my_swap_index = full_index & 1;
    full_index = full_index >> 1;

    // Lane (within the shuffle group) whose input this thread starts with.
    const int my_input_shift_index = full_index & quarter_tile_size_m_one;
    full_index = full_index >> quarter_tile_size_log;

    // Starting output slot offset, one of {0, 1, 2, 3} * half_tile_size. It is advanced
    // four times per inner iteration and therefore returns to this value each time.
    int my_output_index = (full_index & 3) * half_tile_size;

    const int total_jobs = batch_size * quarter_tile_size;

    for (int i = thread_idx; i < total_jobs; i += block_size) {

        float x_element_first = x_element_next_full.x;
        float x_element_second = x_element_next_full.y;
        float x_element_third = x_element_next_full.z;
        float x_element_fourth = x_element_next_full.w;

        // Issue the load for the next job early so it overlaps with the math below.
        if ((i + block_size) < total_jobs) {
            const int next_sample = (i + block_size) >> quarter_tile_size_log;
            x_element_next_full = __ldcg(
                x_float4 + (static_cast<long long>(next_sample) * quarter_N_in + input_offset + job_index));
        }

        float temp_first = 0.0f;
        float temp_second = 0.0f;
        float temp_third = 0.0f;
        float temp_fourth = 0.0f;

        int source_index = my_input_shift_index;

        if (my_swap_index == 1) {
            float tmp = x_element_first;
            x_element_first = x_element_third;
            x_element_third = tmp;

            tmp = x_element_second;
            x_element_second = x_element_fourth;
            x_element_fourth = tmp;
        }

        #pragma unroll
        for (int pair_index = 0; pair_index < 2; ++pair_index) {

            #pragma unroll
            for (int j = 0; j < quarter_tile_size; ++j) {
                int first_index, second_index;
                float b_spline_1, b_spline_2, b_spline_3, b_spline_4;

                if constexpr (cdf_grid) {
                    // Borders [left, right] of grid cell `cell` along one axis. Interior
                    // borders are atanh of the uniform grid nodes; the outermost cells
                    // (whose true border would be +-infinity) get a finite border by
                    // mirroring the width of the neighbouring cell.
                    const auto cell_borders = [n_chunks, chunk_size](int cell, float& left_border, float& right_border) {
                        if (cell == 0) {
                            right_border = atanhf(chunk_size - 1.0f);
                            const float right_border_next = atanhf(2.0f * chunk_size - 1.0f);
                            left_border = right_border - (right_border_next - right_border);
                        } else if (cell == n_chunks - 1) {
                            left_border = atanhf(cell * chunk_size - 1.0f);
                            const float left_border_prev = atanhf((cell - 1) * chunk_size - 1.0f);
                            right_border = left_border + (left_border - left_border_prev);
                        } else {
                            left_border = atanhf(cell * chunk_size - 1.0f);
                            right_border = atanhf((cell + 1) * chunk_size - 1.0f);
                        }
                    };

                    const float x_element_1 = __shfl_sync(0xffffffff, x_element_first, source_index, quarter_tile_size);
                    const float x_element_2 = __shfl_sync(0xffffffff, x_element_second, source_index, quarter_tile_size);

                    // Cell lookup is uniform in tanh(x); clamp so out-of-range or NaN
                    // inputs still address a valid cell.
                    first_index = __float2int_rz((tanhf(x_element_1) + 1.0f) * inverse_chunk_size);
                    second_index = __float2int_rz((tanhf(x_element_2) + 1.0f) * inverse_chunk_size);

                    first_index = max(0, min(first_index, n_chunks - 1));
                    second_index = max(0, min(second_index, n_chunks - 1));

                    float left_1, right_1, left_2, right_2;
                    cell_borders(first_index, left_1, right_1);
                    cell_borders(second_index, left_2, right_2);

                    const float x_1_to_left = x_element_1 - left_1;
                    const float x_1_to_right = right_1 - x_element_1;
                    const float x_2_to_left = x_element_2 - left_2;
                    const float x_2_to_right = right_2 - x_element_2;

                    // Normalise by the cell area so the four weights sum to one.
                    const float factor = 1.0f / ((right_1 - left_1) * (right_2 - left_2));

                    b_spline_1 = x_1_to_right * x_2_to_right * factor;
                    b_spline_2 = x_1_to_right * x_2_to_left * factor;
                    b_spline_3 = x_1_to_left * x_2_to_right * factor;
                    b_spline_4 = x_1_to_left * x_2_to_left * factor;
                } else {
                    // Shift inputs from [-1, 1] to [0, 2] so the cell index is a plain scale.
                    const float x_element_1 = __shfl_sync(0xffffffff, x_element_first, source_index, quarter_tile_size) + 1.0f;
                    const float x_element_2 = __shfl_sync(0xffffffff, x_element_second, source_index, quarter_tile_size) + 1.0f;

                    first_index = __float2int_rz((x_element_1) * inverse_chunk_size);
                    second_index = __float2int_rz((x_element_2) * inverse_chunk_size);

                    first_index = max(0, min(first_index, n_chunks - 1));
                    second_index = max(0, min(second_index, n_chunks - 1));

                    // Distance from the left border of the cell along each axis.
                    const float x_delta_now_1 = fmaf(-first_index, chunk_size, x_element_1);
                    const float x_delta_now_2 = fmaf(-second_index, chunk_size, x_element_2);

                    // Unnormalised weights (no division by chunk_size^2 is applied here).
                    b_spline_1 = (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2);
                    b_spline_2 = (chunk_size - x_delta_now_1) * x_delta_now_2;
                    b_spline_3 = x_delta_now_1 * (chunk_size - x_delta_now_2);
                    b_spline_4 = x_delta_now_1 * x_delta_now_2;
                }

                // Offset of node (first_index, second_index) plus the position of this
                // thread's (input pair, output quad) inside the node's parameter block.
                int chunk_index_common = first_index * (n_chunks + 1) * half_size + second_index * half_size;
                chunk_index_common += (job_index << 2) * half_tile_size + (source_index << 1) + ((pair_index + my_swap_index) & 1);

                // Gathers the four corner nodes (f, s), (f, s + 1), (f + 1, s), (f + 1, s + 1)
                // for the current output slot, advances to the next slot, and adds the
                // weighted sum to `acc`. The FMA order is kept fixed for reproducibility.
                const auto accumulate_corners = [&](float& acc) {
                    int chunk_index_now = chunk_index_common + my_output_index;

                    const float param_1 = shared_parameters[chunk_index_now];
                    chunk_index_now += half_size;
                    const float param_2 = shared_parameters[chunk_index_now];
                    chunk_index_now += n_chunks * half_size;
                    const float param_3 = shared_parameters[chunk_index_now];
                    chunk_index_now += half_size;
                    const float param_4 = shared_parameters[chunk_index_now];

                    my_output_index = (my_output_index + half_tile_size) & output_index_range;

                    acc = __fmaf_rn(param_1, b_spline_1, acc);
                    acc = __fmaf_rn(param_2, b_spline_2, acc);
                    acc = __fmaf_rn(param_3, b_spline_3, acc);
                    acc = __fmaf_rn(param_4, b_spline_4, acc);
                };

                accumulate_corners(temp_first);
                accumulate_corners(temp_second);
                accumulate_corners(temp_third);
                accumulate_corners(temp_fourth);

                // Next lane of the shuffle group, wrapping around.
                source_index = (source_index + 1) & quarter_tile_size_m_one;
            }

            // Second pass works on the remaining pair of the float4.
            x_element_first = x_element_third;
            x_element_second = x_element_fourth;
        }

        // The accumulators were filled starting at this thread's output slot; rotate them
        // back so that component k of the float4 holds output slot k.
        float4 output_full;
        if (my_output_index == 0) {
            output_full.x = temp_first;
            output_full.y = temp_second;
            output_full.z = temp_third;
            output_full.w = temp_fourth;
        } else if (my_output_index == half_tile_size) {
            output_full.y = temp_first;
            output_full.z = temp_second;
            output_full.w = temp_third;
            output_full.x = temp_fourth;
        } else if (my_output_index == 2 * half_tile_size) {
            output_full.z = temp_first;
            output_full.w = temp_second;
            output_full.x = temp_third;
            output_full.y = temp_fourth;
        } else {
            output_full.w = temp_first;
            output_full.x = temp_second;
            output_full.y = temp_third;
            output_full.z = temp_fourth;
        }

        // Accumulate rather than store: other calls may be adding their own contribution
        // to the same output element.
        const int sample_index = i >> quarter_tile_size_log;
        atomicAdd(
            &output_float4[static_cast<long long>(sample_index) * quarter_N_out + output_offset + job_index],
            output_full);
    }
}