#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
using namespace torch::indexing;

#include "../utilities.h"

template <int tile_size, bool cdf_grid>
/**
 * Backward pass of one (input tile, output tile) pair of the 2-D / pairwise DKAN layer.
 *
 * For every sample of the batch this accumulates
 *   - d(loss)/d(parameters) into shared_parameters_grad (atomicAdd), and
 *   - d(loss)/d(x)          into x_grad                 (atomicAdd, so the caller must zero it).
 * Each input pair (x[2k], x[2k+1]) selects a cell of an n_chunks x n_chunks grid; the four bilinear
 * weights (b_spline_1..4) of the cell corners multiply the corresponding corner parameters.
 *
 * Thread layout (1-D block, taken from threadIdx.x):
 *   low log2(tile_size / 4) bits : job_index - which float4 of x / output_grad this thread loads;
 *                                  also the first shuffle source inside its group
 *   next 2 bits                  : rotation of the 4 output columns (my_column_start_index)
 *   next bit                     : which of the two input pairs in the float4 is processed first
 * A group of tile_size / 4 consecutive threads handles one sample per iteration and exchanges
 * output_grad values with __shfl_sync (width tile_size / 4). The rotations only change the order in
 * which parameter entries are visited (presumably to spread shared-memory bank accesses - TODO confirm).
 *
 * Parameter indexing used below: grid point (a, b), a, b in [0, n_chunks], starts at
 * (a * (n_chunks + 1) + b) * half_size with half_size = tile_size * tile_size / 2; inside it, the entry
 * for local output o (0..tile_size-1) and local input pair p (0..tile_size/2-1) is at
 * o * (tile_size / 2) + p.
 *
 * Preconditions (not checked here):
 *   - tile_size in {4, 8, 16, 32};
 *   - blockDim.x is a multiple of tile_size / 4 (job_index is derived once from threadIdx.x);
 *   - N_in, N_out are multiples of 4 and x / output_grad / x_grad are 16-byte aligned (float4 access);
 *   - shared_parameters(_grad) hold (n_chunks + 1)^2 * half_size floats;
 *   - every thread of the block calls this function (it ends with __syncthreads()).
 */
__device__ void dkan_full_kernel_2d_thread_per_four_columns_backward(
    const float* __restrict__ shared_parameters,
    const float* __restrict__ x,
    const float* __restrict__ output_grad,
    float* __restrict__ shared_parameters_grad,
    float* __restrict__ x_grad,
    int my_tile_in,
    int my_tile_out,
    int N_in,
    int N_out,
    int n_chunks,
    int batch_size
) {

    constexpr int half_tile_size = tile_size / 2;
    constexpr int quarter_tile_size = tile_size / 4;
    constexpr int quarter_tile_size_m_one = quarter_tile_size - 1;

    constexpr int quarter_tile_size_log = []() constexpr {
        if constexpr (tile_size == 4) {
            return 0;
        } else if constexpr (tile_size == 8) {
            return 1;
        } else if constexpr (tile_size == 16) {
            return 2;
        } else if constexpr (tile_size == 32) {
            return 3;
        } else {
            static_assert(tile_size == 4 || tile_size == 8 || tile_size == 16 || tile_size == 32,
                        "Unsupported tile size");
            return -1; // This line will never be reached.
        }
    }();

    constexpr int half_size = half_tile_size * tile_size;
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Offsets of this tile, in float4 units.
    const int input_offset = my_tile_in * quarter_tile_size;
    const int output_offset = my_tile_out * quarter_tile_size;

    const int quarter_N_in = N_in / 4;
    const int quarter_N_out = N_out / 4;

    // [-1, 1] is split into n_chunks chunks of width chunk_size (in x for the uniform grid, in tanh(x)
    // for the cdf grid); (value + 1) * inverse_chunk_size gives the chunk index.
    const float inverse_chunk_size = static_cast<float>(n_chunks) / 2.0f;
    const float chunk_size = 2.0f / static_cast<float>(n_chunks);

    // Thread index decomposition (see header).
    const int job_index = thread_idx & quarter_tile_size_m_one;
    const int my_source_start_index = job_index;
    const int my_column_start_index = (thread_idx >> quarter_tile_size_log) & 3;
    const int my_pair_start_index = (thread_idx >> (quarter_tile_size_log + 2)) & 1;

    const float4* __restrict__ x_float4 = reinterpret_cast<const float4*>(x);
    const float4* __restrict__ output_grad_float4 = reinterpret_cast<const float4*>(output_grad);

    const int total_work = batch_size * quarter_tile_size;

    // Lanes of this thread's warp that exist in the block (the last warp may be partial).
    const int lane_id = thread_idx & 31;
    const int warp_thread_count = min(32, block_size - (thread_idx - lane_id));

    // Prefetch the first sample's data (software pipelining: loads for iteration k+1 are issued
    // before the math of iteration k).
    int sample_index = thread_idx >> quarter_tile_size_log;
    float4 x_element_next_full;
    float4 output_grad_element_next_full;
    if (sample_index < batch_size) {
        x_element_next_full = __ldcg(x_float4 + sample_index * quarter_N_in + input_offset + job_index);
        output_grad_element_next_full = __ldcg(output_grad_float4 + sample_index * quarter_N_out + output_offset + job_index);
    }

    float x_elements[4];
    float output_grad_elements[4];
    float x_elements_grad[4];

    for (int i = thread_idx; i < total_work; i += block_size) {

        // Lanes of this warp still inside the loop in this iteration. Lane l works on (i - lane_id) + l,
        // so the active set is a prefix of the warp; naming departed (non-exited) lanes in the shuffle
        // mask would be undefined behavior. Whole shuffle groups are always active together because
        // total_work is a multiple of quarter_tile_size.
        const int active_lanes = min(warp_thread_count, total_work - (i - lane_id));
        const unsigned shfl_mask = active_lanes >= 32 ? 0xffffffffu : ((1u << active_lanes) - 1u);

        x_elements[0] = x_element_next_full.x;
        x_elements[1] = x_element_next_full.y;
        x_elements[2] = x_element_next_full.z;
        x_elements[3] = x_element_next_full.w;

        // Rotate the 4 output columns by my_column_start_index (identical for the whole shuffle group).
        if (my_column_start_index == 0) {
            output_grad_elements[0] = output_grad_element_next_full.x;
            output_grad_elements[1] = output_grad_element_next_full.y;
            output_grad_elements[2] = output_grad_element_next_full.z;
            output_grad_elements[3] = output_grad_element_next_full.w;
        } else if (my_column_start_index == 1) {
            output_grad_elements[0] = output_grad_element_next_full.y;
            output_grad_elements[1] = output_grad_element_next_full.z;
            output_grad_elements[2] = output_grad_element_next_full.w;
            output_grad_elements[3] = output_grad_element_next_full.x;
        } else if (my_column_start_index == 2) {
            output_grad_elements[0] = output_grad_element_next_full.z;
            output_grad_elements[1] = output_grad_element_next_full.w;
            output_grad_elements[2] = output_grad_element_next_full.x;
            output_grad_elements[3] = output_grad_element_next_full.y;
        } else {
            output_grad_elements[0] = output_grad_element_next_full.w;
            output_grad_elements[1] = output_grad_element_next_full.x;
            output_grad_elements[2] = output_grad_element_next_full.y;
            output_grad_elements[3] = output_grad_element_next_full.z;
        }

        // Prefetch the next iteration's sample.
        if ((i + block_size) < total_work) {
            const int next_sample_index = (i + block_size) >> quarter_tile_size_log;
            x_element_next_full = __ldcg(x_float4 + next_sample_index * quarter_N_in + input_offset + job_index);
            output_grad_element_next_full = __ldcg(output_grad_float4 + next_sample_index * quarter_N_out + output_offset + job_index);
        }

        // Process the two input pairs in the order given by my_pair_start_index.
        if (my_pair_start_index == 1) {
            float tmp = x_elements[0];
            x_elements[0] = x_elements[2];
            x_elements[2] = tmp;

            tmp = x_elements[1];
            x_elements[1] = x_elements[3];
            x_elements[3] = tmp;
        }

        #pragma unroll
        for (int pair_index = 0; pair_index < 2; ++pair_index) {
            // Pair position inside the original float4 (undoes the swap above).
            const int pair_index_full = (pair_index + my_pair_start_index) & 1;

            int first_index, second_index;
            float b_spline_1, b_spline_2, b_spline_3, b_spline_4;
            float db1_dx1, db1_dx2, db2_dx1, db2_dx2, db3_dx1, db3_dx2, db4_dx1, db4_dx2;

            if constexpr (cdf_grid) {
                const float x_element_1 = x_elements[2 * pair_index];
                const float x_element_2 = x_elements[2 * pair_index + 1];

                // Cell selection happens in tanh space.
                first_index = __float2int_rz((tanhf(x_element_1) + 1.0f) * inverse_chunk_size);
                second_index = __float2int_rz((tanhf(x_element_2) + 1.0f) * inverse_chunk_size);
                first_index = max(0, min(first_index, n_chunks - 1));
                second_index = max(0, min(second_index, n_chunks - 1));

                // Cell borders mapped back to x space with atanhf. The outermost cells would have
                // infinite borders (atanh(+-1)), so their width is mirrored from the neighbouring cell.
                float left_border, right_border;
                if (first_index == 0) {
                    right_border = atanhf(chunk_size - 1.0f);
                    const float right_border_next = atanhf(2.0f * chunk_size - 1.0f);
                    left_border = right_border - (right_border_next - right_border);
                } else if (first_index == n_chunks - 1) {
                    left_border = atanhf(first_index * chunk_size - 1.0f);
                    const float left_border_prev = atanhf((first_index - 1) * chunk_size - 1.0f);
                    right_border = left_border + (left_border - left_border_prev);
                } else {
                    left_border = atanhf(first_index * chunk_size - 1.0f);
                    right_border = atanhf((first_index + 1) * chunk_size - 1.0f);
                }

                const float x_1_to_left = x_element_1 - left_border;
                const float x_1_to_right = right_border - x_element_1;
                float factor = (right_border - left_border);

                if (second_index == 0) {
                    right_border = atanhf(chunk_size - 1.0f);
                    const float right_border_next = atanhf(2.0f * chunk_size - 1.0f);
                    left_border = right_border - (right_border_next - right_border);
                } else if (second_index == n_chunks - 1) {
                    left_border = atanhf(second_index * chunk_size - 1.0f);
                    const float left_border_prev = atanhf((second_index - 1) * chunk_size - 1.0f);
                    right_border = left_border + (left_border - left_border_prev);
                } else {
                    left_border = atanhf(second_index * chunk_size - 1.0f);
                    right_border = atanhf((second_index + 1) * chunk_size - 1.0f);
                }

                const float x_2_to_left = x_element_2 - left_border;
                const float x_2_to_right = right_border - x_element_2;

                // Normalize by the cell area so the four weights sum to 1.
                factor = 1.0f / (factor * (right_border - left_border));

                b_spline_1 = x_1_to_right * x_2_to_right * factor;
                b_spline_2 = x_1_to_right * x_2_to_left * factor;
                b_spline_3 = x_1_to_left  * x_2_to_right * factor;
                b_spline_4 = x_1_to_left  * x_2_to_left * factor;

                // Partial derivatives w.r.t. x_element_1 / x_element_2; the borders depend only on the
                // (piecewise-constant) cell index and are treated as constants.
                db1_dx1 = -x_2_to_right * factor;
                db1_dx2 = -x_1_to_right * factor;
                db2_dx1 = -x_2_to_left * factor;
                db2_dx2 =  x_1_to_right * factor;
                db3_dx1 =  x_2_to_right * factor;
                db3_dx2 = -x_1_to_left * factor;
                db4_dx1 =  x_2_to_left * factor;
                db4_dx2 =  x_1_to_left * factor;
            } else {
                // Uniform grid: shift x from [-1, 1] to [0, 2].
                const float x_element_1 = x_elements[2 * pair_index] + 1.0f;
                const float x_element_2 = x_elements[2 * pair_index + 1] + 1.0f;

                first_index = __float2int_rz((x_element_1) * inverse_chunk_size);
                second_index = __float2int_rz((x_element_2) * inverse_chunk_size);
                first_index = max(0, min(first_index, n_chunks - 1));
                second_index = max(0, min(second_index, n_chunks - 1));

                // Offset of x inside its cell (d x_delta / d x = 1).
                const float x_delta_now_1 = fmaf(-first_index, chunk_size, x_element_1);
                const float x_delta_now_2 = fmaf(-second_index, chunk_size, x_element_2);

                // NOTE(review): unlike the cdf branch, no 1 / chunk_size^2 normalization is applied here;
                // confirm this matches the forward kernel.
                b_spline_1 = (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2);
                b_spline_2 = (chunk_size - x_delta_now_1) * x_delta_now_2;
                b_spline_3 = x_delta_now_1 * (chunk_size - x_delta_now_2);
                b_spline_4 = x_delta_now_1 * x_delta_now_2;

                db1_dx1 = -1.0f * (chunk_size - x_delta_now_2);
                db1_dx2 = -1.0f * (chunk_size - x_delta_now_1);
                db2_dx1 = -1.0f * x_delta_now_2;
                db2_dx2 = (chunk_size - x_delta_now_1);
                db3_dx1 = (chunk_size - x_delta_now_2);
                db3_dx2 = -1.0f * x_delta_now_1;
                db4_dx1 = x_delta_now_2;
                db4_dx2 = x_delta_now_1;
            }

            // Start of the (first_index, second_index) grid point's parameter slab.
            const int chunk_index_common = first_index * (n_chunks + 1) * half_size + second_index * half_size;

            float x_element_1_grad = 0.0f;
            float x_element_2_grad = 0.0f;

            #pragma unroll
            for (int column_index = 0; column_index < 4; ++column_index) {
                #pragma unroll
                for (int j = 0; j < quarter_tile_size; ++j) {
                    // Fetch the output gradient of local output (source_index * 4 + column) from the
                    // group lane that loaded it.
                    const int source_index = (j + my_source_start_index) & quarter_tile_size_m_one;
                    const float output_grad_now = __shfl_sync(shfl_mask, output_grad_elements[column_index], source_index, quarter_tile_size);
                    int chunk_index_now = chunk_index_common + (source_index * 4 + ((column_index + my_column_start_index) & 3)) * half_tile_size + job_index * 2 + pair_index_full;

                    // Corner (first, second)
                    const float param_1 = shared_parameters[chunk_index_now];
                    atomicAdd(shared_parameters_grad + chunk_index_now, output_grad_now * b_spline_1);
                    // Corner (first, second + 1)
                    chunk_index_now += half_size;
                    const float param_2 = shared_parameters[chunk_index_now];
                    atomicAdd(shared_parameters_grad + chunk_index_now, output_grad_now * b_spline_2);
                    // Corner (first + 1, second)
                    chunk_index_now += n_chunks * half_size;
                    const float param_3 = shared_parameters[chunk_index_now];
                    atomicAdd(shared_parameters_grad + chunk_index_now, output_grad_now * b_spline_3);
                    // Corner (first + 1, second + 1)
                    chunk_index_now += half_size;
                    const float param_4 = shared_parameters[chunk_index_now];
                    atomicAdd(shared_parameters_grad + chunk_index_now, output_grad_now * b_spline_4);

                    x_element_1_grad += output_grad_now * (db1_dx1 * param_1 + db2_dx1 * param_2 + db3_dx1 * param_3 + db4_dx1 * param_4);
                    x_element_2_grad += output_grad_now * (db1_dx2 * param_1 + db2_dx2 * param_2 + db3_dx2 * param_3 + db4_dx2 * param_4);
                }
            }
            x_elements_grad[2 * pair_index] = x_element_1_grad;
            x_elements_grad[2 * pair_index + 1] = x_element_2_grad;
        }

        // Undo the pair swap so gradients line up with the original float4 layout.
        if (my_pair_start_index == 1) {
            float tmp = x_elements_grad[0];
            x_elements_grad[0] = x_elements_grad[2];
            x_elements_grad[2] = tmp;

            tmp = x_elements_grad[1];
            x_elements_grad[1] = x_elements_grad[3];
            x_elements_grad[3] = tmp;
        }

        // i < total_work guarantees sample_index < batch_size here.
        sample_index = i >> quarter_tile_size_log;
        const int x_grad_index = sample_index * quarter_N_in + input_offset + job_index;

        // Other output tiles contribute to the same x entries, hence the atomics.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
        // Vector float4 atomicAdd (global memory) requires compute capability 9.x+.
        float4 x_grad_element_full;
        x_grad_element_full.x = x_elements_grad[0];
        x_grad_element_full.y = x_elements_grad[1];
        x_grad_element_full.z = x_elements_grad[2];
        x_grad_element_full.w = x_elements_grad[3];
        atomicAdd(reinterpret_cast<float4*>(x_grad) + x_grad_index, x_grad_element_full);
#else
        float* x_grad_target = x_grad + 4 * x_grad_index;
        atomicAdd(x_grad_target + 0, x_elements_grad[0]);
        atomicAdd(x_grad_target + 1, x_elements_grad[1]);
        atomicAdd(x_grad_target + 2, x_elements_grad[2]);
        atomicAdd(x_grad_target + 3, x_elements_grad[3]);
#endif
    }
    __syncthreads();
}