#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Forward kernel: for every (sample i, RAM j) pair, counts over all keys k how
// many of the addressed memory cells memory[j][addr][o] are strictly positive,
// adding those counts into output[i][j][o]. `output` is expected to be
// zero-initialised by the caller.
//
// The address for key k is the XOR over l in [0, tuple_lenght) of
//   input[i][mappings[j][l]] * keys[k][l]
// i.e. the XOR of keys[k][l] over the tuple positions whose input bit is set.
//
// Launch layout: the x dimension walks RAMs and the y dimension walks samples.
// Both are strided, bounded loops, so any grid/block shape produces the same
// result. Each (i, j) pair is visited by exactly one thread, so the plain +=
// into output[i][j][*] needs no atomics.
//
// Preconditions NOT checked here: tuple_lenght >= 1, every mappings value is a
// valid column of `input`, and every computed address is < memory.size(1).
__global__ void domafilter_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> input,
    const torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> memory,
    const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> mappings,
    const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> keys,
    torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> output,
    const int tuple_lenght,
    const int num_keys,
    const int num_output) {

    const int batch_size = output.size(0);
    const int num_rams = memory.size(0);
    const int ram_stride = blockDim.x * gridDim.x;

    // Bounded stride over samples: correct even if gridDim.y != batch size.
    for (int i = blockIdx.y; i < batch_size; i += gridDim.y) {
        for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_rams; j += ram_stride) {
            for (int k = 0; k < num_keys; ++k) {

                int addr = input[i][mappings[j][0]] * keys[k][0];
                for (int l = 1; l < tuple_lenght; ++l) {
                    addr ^= input[i][mappings[j][l]] * keys[k][l];
                }

                for (int o = 0; o < num_output; ++o) {
                    // bool -> float: adds 1 for a positive cell, 0 otherwise.
                    output[i][j][o] += memory[j][addr][o] > 0;
                }
            }
        }
    }
}

// Host launcher for the forward pass.
//
// Returns a new float32 tensor of shape (batch, num_rams, num_output) on the
// input's CUDA device, where output[i][j][o] is the number of keys k for which
// memory[j][addr(i, j, k)][o] > 0 (see domafilter_cuda_forward_kernel for the
// XOR address).
//
// Expected tensors (rank and dtype are validated below):
//   input_tensor    bool    2-D (batch, n)   -- columns are indexed by mappings
//   memory_tensor   float32 3-D (num_rams, m, num_output) -- dim 1 indexed by addr
//   mappings_tensor int32   2-D (>= num_rams, tuple_lenght)
//   keys_tensor     int32   2-D (num_keys, >= tuple_lenght)
// The kernel is enqueued on PyTorch's current CUDA stream and is not waited on,
// matching the asynchronous semantics of built-in PyTorch ops.
//
// Throws (via TORCH_CHECK) on non-CUDA or mixed-device inputs, wrong
// rank/dtype, inconsistent shapes, batch > 65535, or a failed kernel launch.
torch::Tensor domafilter_cuda_forward(
    torch::Tensor input_tensor,
    torch::Tensor memory_tensor,
    torch::Tensor mappings_tensor,
    torch::Tensor keys_tensor) {

    // The kernel dereferences raw device pointers, so every tensor must be a
    // CUDA tensor on one and the same device.
    TORCH_CHECK(input_tensor.is_cuda(),
                "domafilter_cuda_forward: input_tensor must be a CUDA tensor");
    TORCH_CHECK(memory_tensor.device() == input_tensor.device(),
                "domafilter_cuda_forward: memory_tensor must be on the same device as input_tensor");
    TORCH_CHECK(mappings_tensor.device() == input_tensor.device(),
                "domafilter_cuda_forward: mappings_tensor must be on the same device as input_tensor");
    TORCH_CHECK(keys_tensor.device() == input_tensor.device(),
                "domafilter_cuda_forward: keys_tensor must be on the same device as input_tensor");

    // Rank/dtype must match the packed accessors used for the launch below.
    TORCH_CHECK(input_tensor.dim() == 2 && input_tensor.scalar_type() == torch::kBool,
                "domafilter_cuda_forward: input_tensor must be a 2-D bool tensor");
    TORCH_CHECK(memory_tensor.dim() == 3 && memory_tensor.scalar_type() == torch::kFloat32,
                "domafilter_cuda_forward: memory_tensor must be a 3-D float32 tensor");
    TORCH_CHECK(mappings_tensor.dim() == 2 && mappings_tensor.scalar_type() == torch::kInt32,
                "domafilter_cuda_forward: mappings_tensor must be a 2-D int32 tensor");
    TORCH_CHECK(keys_tensor.dim() == 2 && keys_tensor.scalar_type() == torch::kInt32,
                "domafilter_cuda_forward: keys_tensor must be a 2-D int32 tensor");

    // Launch on the tensors' device even if the caller's current device differs.
    const c10::cuda::CUDAGuard device_guard(input_tensor.device());

    const int64_t batch_size = input_tensor.size(0);
    const int64_t num_rams = memory_tensor.size(0);
    const int64_t tuple_lenght = mappings_tensor.size(1);
    const int64_t num_keys = keys_tensor.size(0);
    const int64_t num_output = memory_tensor.size(2);

    auto output_tensor = torch::zeros({batch_size, num_rams, num_output},
        torch::dtype(torch::kFloat32).device(torch::kCUDA, input_tensor.device().index()));

    // Nothing to accumulate: return the zeros without launching. A zero-sized
    // grid would be an invalid launch configuration.
    if (output_tensor.numel() == 0 || num_keys == 0) {
        return output_tensor;
    }

    // The kernel reads mappings[j][0 .. tuple_lenght) for every RAM j and
    // keys[k][0 .. tuple_lenght) for every key k, starting at column 0.
    TORCH_CHECK(tuple_lenght >= 1,
                "domafilter_cuda_forward: mappings_tensor must have at least one column");
    TORCH_CHECK(mappings_tensor.size(0) >= num_rams,
                "domafilter_cuda_forward: mappings_tensor needs a row for each of the ",
                num_rams, " RAMs (memory_tensor.size(0))");
    TORCH_CHECK(keys_tensor.size(1) >= tuple_lenght,
                "domafilter_cuda_forward: keys_tensor needs at least ", tuple_lenght,
                " columns (mappings_tensor.size(1))");
    // Samples are mapped one-to-one onto gridDim.y, whose hardware limit is 65535.
    TORCH_CHECK(batch_size <= 65535,
                "domafilter_cuda_forward: batch size ", batch_size,
                " exceeds the CUDA gridDim.y limit of 65535");

    // 256 threads keeps the per-thread register budget comfortable. A
    // 1024-thread block without __launch_bounds__ can fail to launch with
    // "too many resources requested". Results do not depend on this value.
    constexpr int threads = 256;
    const dim3 blocks(static_cast<unsigned int>((num_rams + threads - 1) / threads),
                      static_cast<unsigned int>(batch_size));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    domafilter_cuda_forward_kernel<<<blocks, threads, 0, stream>>>(
        input_tensor.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
        memory_tensor.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        mappings_tensor.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        keys_tensor.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        output_tensor.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        static_cast<int>(tuple_lenght),
        static_cast<int>(num_keys),
        static_cast<int>(num_output)
    );

    // Surface launch-configuration errors immediately. Faults during execution
    // surface at the next synchronizing call, as with built-in PyTorch ops.
    const cudaError_t launch_error = cudaGetLastError();
    TORCH_CHECK(launch_error == cudaSuccess,
                "domafilter_cuda_forward: kernel launch failed: ",
                cudaGetErrorString(launch_error));

    return output_tensor;
}

// Backward kernel: routes output_grad back onto the memory cells that the
// forward addressing selects.
//
// Thread layout: a 1-D grid over RAMs (x dimension only). A thread whose index
// is at or past memory.size(0) exits immediately, so the grid must supply at
// least memory.size(0) threads. `memory` is consulted only for that size.
//
// For its RAM j, each thread walks every sample i and key k, rebuilds the XOR
// address (input[i][mappings[j][l]] * keys[k][l] folded with ^ over l), and
// adds output_grad[i][j][o] to memory_grad[j][addr][o] for every o. All writes
// into memory_grad[j] come from this one thread, so plain += (no atomics) is
// race-free.
//
// Preconditions NOT checked here: tuple_lenght >= 1 and every computed address
// is < memory_grad.size(1).
__global__ void domafilter_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> input,
    const torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> memory,
    const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> mappings,
    const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> keys,
    const torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> output_grad,
    torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> memory_grad,
    const int batch_size,
    const int tuple_lenght,
    const int num_keys,
    const int num_output) {

    const int ram = blockIdx.x * blockDim.x + threadIdx.x;
    if (ram >= memory.size(0)) {
        return;  // tail thread beyond the last RAM
    }

    const auto ram_map = mappings[ram];  // input columns sampled by this RAM
    auto ram_grad = memory_grad[ram];    // the slice of the result this thread owns

    for (int sample = 0; sample < batch_size; ++sample) {
        const auto sample_bits = input[sample];
        const auto sample_grad = output_grad[sample][ram];

        for (int key = 0; key < num_keys; ++key) {
            const auto key_row = keys[key];

            // XOR-fold the selected key entries into a memory address.
            int address = sample_bits[ram_map[0]] * key_row[0];
            for (int pos = 1; pos < tuple_lenght; ++pos) {
                address ^= sample_bits[ram_map[pos]] * key_row[pos];
            }

            auto cell_grad = ram_grad[address];
            for (int out = 0; out < num_output; ++out) {
                cell_grad[out] += sample_grad[out];
            }
        }
    }
}

// Host launcher for the backward pass with respect to `memory_tensor`.
//
// Returns a tensor shaped/typed like memory_tensor (torch::zeros_like). For
// every sample i, RAM j and key k, it adds output_grad[i][j][o] into
// memory_grad[j][addr][o], where addr is the same XOR address the forward pass
// reads. Only the memory gradient is produced by this function.
//
// Expected tensors (rank and dtype are validated below):
//   input_tensor       bool    2-D (batch, n)
//   memory_tensor      float32 3-D (num_rams, m, num_output)
//   mappings_tensor    int32   2-D (>= num_rams, tuple_lenght)
//   keys_tensor        int32   2-D (num_keys, >= tuple_lenght)
//   output_grad_tensor float32 3-D (batch, num_rams, num_output)
// The kernel is enqueued on PyTorch's current CUDA stream and is not waited on.
//
// Throws (via TORCH_CHECK) on non-CUDA or mixed-device inputs, wrong
// rank/dtype, inconsistent shapes, or a failed kernel launch.
torch::Tensor domafilter_cuda_backward(
    torch::Tensor input_tensor,
    torch::Tensor memory_tensor,
    torch::Tensor mappings_tensor,
    torch::Tensor keys_tensor,
    torch::Tensor output_grad_tensor) {

    // The kernel dereferences raw device pointers, so every tensor must be a
    // CUDA tensor on one and the same device.
    TORCH_CHECK(input_tensor.is_cuda(),
                "domafilter_cuda_backward: input_tensor must be a CUDA tensor");
    TORCH_CHECK(memory_tensor.device() == input_tensor.device(),
                "domafilter_cuda_backward: memory_tensor must be on the same device as input_tensor");
    TORCH_CHECK(mappings_tensor.device() == input_tensor.device(),
                "domafilter_cuda_backward: mappings_tensor must be on the same device as input_tensor");
    TORCH_CHECK(keys_tensor.device() == input_tensor.device(),
                "domafilter_cuda_backward: keys_tensor must be on the same device as input_tensor");
    TORCH_CHECK(output_grad_tensor.device() == input_tensor.device(),
                "domafilter_cuda_backward: output_grad_tensor must be on the same device as input_tensor");

    // Rank/dtype must match the packed accessors used for the launch below.
    TORCH_CHECK(input_tensor.dim() == 2 && input_tensor.scalar_type() == torch::kBool,
                "domafilter_cuda_backward: input_tensor must be a 2-D bool tensor");
    TORCH_CHECK(memory_tensor.dim() == 3 && memory_tensor.scalar_type() == torch::kFloat32,
                "domafilter_cuda_backward: memory_tensor must be a 3-D float32 tensor");
    TORCH_CHECK(mappings_tensor.dim() == 2 && mappings_tensor.scalar_type() == torch::kInt32,
                "domafilter_cuda_backward: mappings_tensor must be a 2-D int32 tensor");
    TORCH_CHECK(keys_tensor.dim() == 2 && keys_tensor.scalar_type() == torch::kInt32,
                "domafilter_cuda_backward: keys_tensor must be a 2-D int32 tensor");
    TORCH_CHECK(output_grad_tensor.dim() == 3 && output_grad_tensor.scalar_type() == torch::kFloat32,
                "domafilter_cuda_backward: output_grad_tensor must be a 3-D float32 tensor");

    // Launch on the tensors' device even if the caller's current device differs.
    const c10::cuda::CUDAGuard device_guard(input_tensor.device());

    const int64_t batch_size = input_tensor.size(0);
    const int64_t num_rams = memory_tensor.size(0);
    const int64_t tuple_lenght = mappings_tensor.size(1);
    const int64_t num_keys = keys_tensor.size(0);
    const int64_t num_output = memory_tensor.size(2);

    // The kernel indexes output_grad[i][j][o] for every sample, RAM and output
    // channel without bounds checks.
    TORCH_CHECK(output_grad_tensor.size(0) == batch_size &&
                output_grad_tensor.size(1) == num_rams &&
                output_grad_tensor.size(2) == num_output,
                "domafilter_cuda_backward: output_grad_tensor must have shape (",
                batch_size, ", ", num_rams, ", ", num_output, ")");

    auto memory_grad_tensor = torch::zeros_like(memory_tensor);

    // Nothing to scatter: return the zeros without launching. A zero-block
    // grid would be an invalid launch configuration.
    if (memory_grad_tensor.numel() == 0 || batch_size == 0 || num_keys == 0) {
        return memory_grad_tensor;
    }

    // The kernel reads mappings[j][0 .. tuple_lenght) for every RAM j and
    // keys[k][0 .. tuple_lenght) for every key k, starting at column 0.
    TORCH_CHECK(tuple_lenght >= 1,
                "domafilter_cuda_backward: mappings_tensor must have at least one column");
    TORCH_CHECK(mappings_tensor.size(0) >= num_rams,
                "domafilter_cuda_backward: mappings_tensor needs a row for each of the ",
                num_rams, " RAMs (memory_tensor.size(0))");
    TORCH_CHECK(keys_tensor.size(1) >= tuple_lenght,
                "domafilter_cuda_backward: keys_tensor needs at least ", tuple_lenght,
                " columns (mappings_tensor.size(1))");

    // One thread per RAM. 256 threads keeps the per-thread register budget
    // comfortable. A 1024-thread block without __launch_bounds__ can fail to
    // launch with "too many resources requested". Results do not depend on
    // this value.
    constexpr int threads = 256;
    const int blocks = static_cast<int>((num_rams + threads - 1) / threads);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    domafilter_cuda_backward_kernel<<<blocks, threads, 0, stream>>>(
        input_tensor.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
        memory_tensor.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        mappings_tensor.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        keys_tensor.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        output_grad_tensor.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        memory_grad_tensor.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        static_cast<int>(batch_size),
        static_cast<int>(tuple_lenght),
        static_cast<int>(num_keys),
        static_cast<int>(num_output)
    );

    // Surface launch-configuration errors immediately. Faults during execution
    // surface at the next synchronizing call, as with built-in PyTorch ops.
    const cudaError_t launch_error = cudaGetLastError();
    TORCH_CHECK(launch_error == cudaSuccess,
                "domafilter_cuda_backward: kernel launch failed: ",
                cudaGetErrorString(launch_error));

    return memory_grad_tensor;
}
