#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
using namespace torch::indexing;

#include "../utilities.h"

/**
 * Compute-only variant of the 2D thread-per-tile kernel body.
 *
 * The routine performs the complete table-lookup / FMA accumulation, but it
 * never dereferences `x`: the input values are synthesised from the thread /
 * batch index instead (the real loads are shown in comments next to the places
 * where they would happen). The accumulated values are finally added to
 * `output` with atomicAdd.
 *
 * Work distribution: only threadIdx.x and blockDim.x are consulted. Each thread
 * walks the batch as i = threadIdx.x, threadIdx.x + blockDim.x, ... and keeps
 * accumulating into the same `tile_size` registers across all of its
 * iterations; the result is written once, at column threadIdx.x.
 *
 * Parameter table layout, as implied by the index arithmetic below:
 *   [first_index][second_index][half_tile_size][tile_size]
 * with (n_chunks + 1) entries along the second_index axis. For every input
 * pair the four neighbouring grid nodes (f, s), (f, s+1), (f+1, s), (f+1, s+1)
 * are read.
 *
 * @tparam tile_size  Number of inputs / outputs per tile; must be even.
 * @tparam cdf_grid   Not implemented; silently ignored.
 *
 * @param shared_parameters  Parameter table indexed as described above.
 *                           NOTE(review): the name suggests shared memory, but
 *                           that is decided by the caller - confirm.
 * @param x                  Unused in this variant.
 * @param output             Accumulation target, addressed as
 *                           output[(my_tile_out * tile_size + j) * batch_size + threadIdx.x].
 *                           Presumably holds N_out * batch_size floats - TODO confirm.
 * @param my_tile_in         Input tile index (scaled by tile_size for the bounds guard).
 * @param my_tile_out        Output tile index (scaled by tile_size for the write-back).
 * @param N_in               Number of input features; used only in the bounds guards.
 * @param N_out              Unused in this variant.
 * @param n_chunks           Number of grid cells per axis over [-1, 1].
 * @param batch_size         Batch length; bound of the per-thread loop.
 *
 * NOTE(review): the write-back uses the plain index j although the accumulator
 * slot k was filled from rotated parameter column (fast_index + k) % tile_size.
 * That matches the original code of this compute-only variant; confirm it is
 * intended before reusing the routine for real results.
 */
template <int tile_size, bool cdf_grid>
__device__ void dkan_full_kernel_2d_thread_per_tile_only_computations(
    const float* __restrict__ shared_parameters,
    const float* __restrict__ x,
    float* __restrict__ output,
    int my_tile_in,
    int my_tile_out,
    int N_in,
    int N_out,
    int n_chunks,
    int batch_size
) {
    // cdf_grid = true is not implemented; for now it is silently ignored.
    static_assert(tile_size % 2 == 0, "tile_size must be even");

    constexpr int half_tile_size = tile_size / 2;
    constexpr int half_size = half_tile_size * tile_size;

    // Deliberately unused in the compute-only variant.
    (void)x;
    (void)N_out;

    const int thread_idx = threadIdx.x;
    const int block_size = blockDim.x;

    // Position of this thread inside a (half_tile_size x tile_size) parameter
    // slab; used to rotate the order in which the slab is traversed.
    const int full_index = thread_idx % half_size;
    const int fast_index = full_index % tile_size;
    const int slow_index = full_index / tile_size;

    const int input_shift = my_tile_in * tile_size;
    const int output_shift = my_tile_out * tile_size;

    // The grid covers [-1, 1] with n_chunks cells.
    const float inverse_chunk_size = static_cast<float>(n_chunks) / 2.0f;
    const float chunk_size = 2.0f / static_cast<float>(n_chunks);

    // Scale for the synthetic inputs. NOTE(review): the +4096 presumably keeps
    // the generated values strictly below 1 so that the cell index stays
    // smaller than n_chunks - confirm.
    const float inverse_batch_size = 1.0f / static_cast<float>(batch_size + 4096);

    // 64-bit quantities for the flat-index arithmetic: the products
    // feature * batch_size overflow a 32-bit int for large tensors.
    const long long batch_size_64 = batch_size;
    const long long total_inputs = static_cast<long long>(N_in) * batch_size_64;

    // Distance between consecutive first_index rows of the parameter table.
    const int param_row_stride = (n_chunks + 1) * half_size;

    float x_current[tile_size];
    float x_current_next[tile_size];
    float output_current[tile_size];

    // Zero both arrays: x_current_next is only written under a bounds guard and
    // must never feed an indeterminate value into the table index computation.
    #pragma unroll
    for (int j = 0; j < tile_size; ++j) {
        x_current_next[j] = 0.0f;
        output_current[j] = 0.0f;
    }

    // Prefetch stage for the first iteration.
    #pragma unroll
    for (int j = 0; j < tile_size; ++j) {
        int index_now = j + 2 * slow_index;
        if (index_now >= tile_size) {
            index_now -= tile_size;
        }

        if ((input_shift + j) * batch_size_64 + thread_idx < total_inputs) {
            // Real load would be: x[(input_shift + j) * batch_size + thread_idx]
            x_current_next[j] = static_cast<float>(thread_idx - index_now) * inverse_batch_size;
        }
    }

    for (int i = thread_idx; i < batch_size; i += block_size) {
        // Promote the prefetched values to the working set.
        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            x_current[j] = x_current_next[j];
        }

        // Prefetch stage for the next iteration.
        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            int index_now = j + 2 * slow_index;
            if (index_now >= tile_size) {
                index_now -= tile_size;
            }

            if ((input_shift + j) * batch_size_64 + i + block_size < total_inputs) {
                // Real load would be: x[(input_shift + j) * batch_size + i + block_size]
                x_current_next[j] = static_cast<float>(i - index_now) * inverse_batch_size;
            }
        }

        // Inputs are consumed in pairs (2j, 2j + 1); each pair selects one grid cell.
        #pragma unroll
        for (int j = 0; j < half_tile_size; ++j) {
            int slow_index_now = j + slow_index;
            if (slow_index_now >= half_tile_size) {
                slow_index_now -= half_tile_size;
            }

            const float x_first = x_current[2 * j];
            const float x_second = x_current[2 * j + 1];

            // Cell index along each axis: truncate (x + 1) / chunk_size.
            const int first_index = __float2int_rz((x_first + 1.0f) * inverse_chunk_size);
            const int second_index = __float2int_rz((x_second + 1.0f) * inverse_chunk_size);

            // Offset inside the cell: (x - index * chunk_size) + 1.
            float x_delta_now_1 = __fmaf_rn(static_cast<float>(-first_index), chunk_size, x_first);
            float x_delta_now_2 = __fmaf_rn(static_cast<float>(-second_index), chunk_size, x_second);
            x_delta_now_1 += 1.0f;
            x_delta_now_2 += 1.0f;

            // Products of the per-axis linear weights, one per cell corner
            // (not normalised by chunk_size^2).
            const float b_spline_1 = (chunk_size - x_delta_now_1) * (chunk_size - x_delta_now_2);
            const float b_spline_2 = (chunk_size - x_delta_now_1) * x_delta_now_2;
            const float b_spline_3 = x_delta_now_1 * (chunk_size - x_delta_now_2);
            const float b_spline_4 = x_delta_now_1 * x_delta_now_2;

            const int chunk_index_precomputed =
                first_index * param_row_stride + second_index * half_size + slow_index_now * tile_size;

            #pragma unroll
            for (int k = 0; k < tile_size; ++k) {
                int fast_index_now = fast_index + k;
                if (fast_index_now >= tile_size) {
                    fast_index_now -= tile_size;
                }

                // Corners (f, s), (f, s+1), (f+1, s), (f+1, s+1) of the selected cell.
                const int corner_00 = chunk_index_precomputed + fast_index_now;
                const int corner_01 = corner_00 + half_size;
                const int corner_10 = corner_01 + n_chunks * half_size;
                const int corner_11 = corner_10 + half_size;

                const float param_1 = shared_parameters[corner_00];
                const float param_2 = shared_parameters[corner_01];
                const float param_3 = shared_parameters[corner_10];
                const float param_4 = shared_parameters[corner_11];

                // Keep this exact accumulation order: it determines the rounding.
                output_current[k] = __fmaf_rn(param_1, b_spline_1, output_current[k]);
                output_current[k] = __fmaf_rn(param_2, b_spline_2, output_current[k]);
                output_current[k] = __fmaf_rn(param_3, b_spline_3, output_current[k]);
                output_current[k] = __fmaf_rn(param_4, b_spline_4, output_current[k]);
            }
        }
    }

    // Threads with thread_idx >= batch_size never entered the loop above and
    // hold only zeros; letting them write would touch a neighbouring row or
    // memory past the end of a (rows x batch_size) buffer.
    if (thread_idx < batch_size) {
        #pragma unroll
        for (int j = 0; j < tile_size; ++j) {
            atomicAdd(&output[(output_shift + j) * batch_size_64 + thread_idx], output_current[j]);
        }
    }
}