#include <torch/extension.h>

#include <vector>
#include <iostream>

#include <ATen/ATen.h>

// #include "thread_pool.hpp"

// thread_pool pool(std::thread::hardware_concurrency());

// CPU forward pass of the DoMaFilter lookup.
//
// For every sample i, RAM j and key row k, an address is built by XOR-ing
// together keys[k][l] for each tuple position l whose mapped input bit
// input[i][mappings[j][l]] is set. output[i][j][o] then counts the key rows
// whose addressed cell memory[j][addr][o] is strictly positive.
//
// Parameters:
//   input_tensor    - 2-D bool tensor, indexed [sample][input bit].
//   memory_tensor   - 3-D float tensor, indexed [ram][address][output].
//   mappings_tensor - 2-D int tensor, indexed [ram][tuple position]; every
//                     entry is used as a column index into input_tensor.
//   keys_tensor     - 2-D int64 tensor, indexed [key row][tuple position].
//
// Returns:
//   A float32 CPU tensor of shape (num_samples, num_rams, num_output).
//
// Throws (via TORCH_CHECK / accessor checks) when a tensor is not on the CPU,
// has an unexpected rank or dtype, or when a mapping index or a computed
// address falls outside the tensor it is used to index.
torch::Tensor forward(
    torch::Tensor input_tensor,
    torch::Tensor memory_tensor,
    torch::Tensor mappings_tensor,
    torch::Tensor keys_tensor) {

    // Accessors dereference raw host pointers, so every tensor must be on the CPU.
    TORCH_CHECK(input_tensor.device().is_cpu(), "forward: input must be a CPU tensor");
    TORCH_CHECK(memory_tensor.device().is_cpu(), "forward: memory must be a CPU tensor");
    TORCH_CHECK(mappings_tensor.device().is_cpu(), "forward: mappings must be a CPU tensor");
    TORCH_CHECK(keys_tensor.device().is_cpu(), "forward: keys must be a CPU tensor");

    // accessor<T, N>() rejects tensors whose rank or dtype does not match,
    // but it performs no bounds checking on element access.
    auto input = input_tensor.accessor<bool, 2>();
    auto memory = memory_tensor.accessor<float, 3>();
    auto mappings = mappings_tensor.accessor<int, 2>();
    auto keys = keys_tensor.accessor<int64_t, 2>();

    const int64_t num_samples = input_tensor.size(0);
    const int64_t num_inputs = input_tensor.size(1);
    const int64_t num_rams = memory_tensor.size(0);
    const int64_t num_addresses = memory_tensor.size(1);
    const int64_t num_output = memory_tensor.size(2);
    const int64_t tuple_length = mappings_tensor.size(1);
    const int64_t num_keys = keys_tensor.size(0);

    TORCH_CHECK(mappings_tensor.size(0) >= num_rams,
                "forward: mappings has ", mappings_tensor.size(0),
                " rows but memory has ", num_rams, " RAMs");
    TORCH_CHECK(keys_tensor.size(1) >= tuple_length,
                "forward: keys has ", keys_tensor.size(1),
                " columns but mappings has tuple length ", tuple_length);

    // Validate every mapping entry once up front so the hot loop can index
    // `input` without further checks.
    for (int64_t j = 0; j < num_rams; ++j) {
        for (int64_t l = 0; l < tuple_length; ++l) {
            const int64_t bit = mappings[j][l];
            TORCH_CHECK(bit >= 0 && bit < num_inputs,
                        "forward: mappings[", j, "][", l, "] = ", bit,
                        " is out of range for input with ", num_inputs, " columns");
        }
    }

    auto output_tensor = torch::zeros({num_samples, num_rams, num_output}, torch::kFloat32);
    auto output = output_tensor.accessor<float, 3>();

    // Each sample writes only to its own output[i] slice, so iterations of the
    // outermost loop are independent of one another.
    for (int64_t i = 0; i < num_samples; ++i) {
        for (int64_t j = 0; j < num_rams; ++j) {
            for (int64_t k = 0; k < num_keys; ++k) {

                // XOR the key of every tuple position whose input bit is set.
                // Starting from 0 gives the same result as seeding with
                // position 0 and also covers an empty tuple.
                uint64_t addr = 0;
                for (int64_t l = 0; l < tuple_length; ++l) {
                    if (input[i][mappings[j][l]]) {
                        addr ^= static_cast<uint64_t>(keys[k][l]);
                    }
                }

                // A negative key becomes a huge unsigned value and is caught here too.
                TORCH_CHECK(addr < static_cast<uint64_t>(num_addresses),
                            "forward: computed address ", addr,
                            " is out of range for memory with ", num_addresses, " addresses");
                const int64_t cell = static_cast<int64_t>(addr);

                // The address does not depend on the output channel, so it is
                // computed once and reused for every channel.
                for (int64_t o = 0; o < num_output; ++o) {
                    output[i][j][o] += memory[j][cell][o] > 0;
                }
            }
        }
    }

    return output_tensor;
}

// CPU backward pass of the DoMaFilter lookup: gradient with respect to memory.
//
// Addresses are rebuilt exactly as in the forward pass: for sample i, RAM j
// and key row k, keys[k][l] is XOR-ed in for every tuple position l whose
// mapped input bit input[i][mappings[j][l]] is set. output_grad[i][j][o] is
// then accumulated into memory_grad[j][addr][o]. The values stored in
// memory_tensor are not read; only its shape, dtype and layout are used to
// allocate the result.
//
// Parameters:
//   input_tensor       - 2-D bool tensor, indexed [sample][input bit].
//   memory_tensor      - 3-D float tensor, indexed [ram][address][output].
//   mappings_tensor    - 2-D int tensor, indexed [ram][tuple position]; every
//                        entry is used as a column index into input_tensor.
//   keys_tensor        - 2-D int64 tensor, indexed [key row][tuple position].
//   output_grad_tensor - 3-D float tensor, indexed [sample][ram][output].
//
// Returns:
//   A tensor created with zeros_like(memory_tensor) holding the accumulated
//   gradient.
//
// Throws (via TORCH_CHECK / accessor checks) when a tensor is not on the CPU,
// has an unexpected rank or dtype, or when a mapping index, a computed address
// or the output_grad shape does not fit the tensors being indexed.
torch::Tensor backward(
    torch::Tensor input_tensor,
    torch::Tensor memory_tensor,
    torch::Tensor mappings_tensor,
    torch::Tensor keys_tensor,
    torch::Tensor output_grad_tensor) {

    // Accessors dereference raw host pointers, so every tensor must be on the CPU.
    TORCH_CHECK(input_tensor.device().is_cpu(), "backward: input must be a CPU tensor");
    TORCH_CHECK(memory_tensor.device().is_cpu(), "backward: memory must be a CPU tensor");
    TORCH_CHECK(mappings_tensor.device().is_cpu(), "backward: mappings must be a CPU tensor");
    TORCH_CHECK(keys_tensor.device().is_cpu(), "backward: keys must be a CPU tensor");
    TORCH_CHECK(output_grad_tensor.device().is_cpu(), "backward: output_grad must be a CPU tensor");

    // accessor<T, N>() rejects tensors whose rank or dtype does not match,
    // but it performs no bounds checking on element access.
    auto input = input_tensor.accessor<bool, 2>();
    auto mappings = mappings_tensor.accessor<int, 2>();
    auto keys = keys_tensor.accessor<int64_t, 2>();
    auto output_grad = output_grad_tensor.accessor<float, 3>();

    // The float/3-D requirement on memory is enforced through this accessor,
    // since zeros_like keeps memory's dtype and rank.
    auto memory_grad_tensor = torch::zeros_like(memory_tensor);
    auto memory_grad = memory_grad_tensor.accessor<float, 3>();

    const int64_t num_samples = input_tensor.size(0);
    const int64_t num_inputs = input_tensor.size(1);
    const int64_t num_rams = memory_tensor.size(0);
    const int64_t num_addresses = memory_tensor.size(1);
    const int64_t num_output = memory_tensor.size(2);
    const int64_t tuple_length = mappings_tensor.size(1);
    const int64_t num_keys = keys_tensor.size(0);

    TORCH_CHECK(mappings_tensor.size(0) >= num_rams,
                "backward: mappings has ", mappings_tensor.size(0),
                " rows but memory has ", num_rams, " RAMs");
    TORCH_CHECK(keys_tensor.size(1) >= tuple_length,
                "backward: keys has ", keys_tensor.size(1),
                " columns but mappings has tuple length ", tuple_length);
    TORCH_CHECK(output_grad_tensor.size(0) >= num_samples &&
                output_grad_tensor.size(1) >= num_rams &&
                output_grad_tensor.size(2) >= num_output,
                "backward: output_grad must cover shape (", num_samples, ", ",
                num_rams, ", ", num_output, ") but has sizes ",
                output_grad_tensor.sizes());

    // Validate every mapping entry once up front so the hot loop can index
    // `input` without further checks.
    for (int64_t j = 0; j < num_rams; ++j) {
        for (int64_t l = 0; l < tuple_length; ++l) {
            const int64_t bit = mappings[j][l];
            TORCH_CHECK(bit >= 0 && bit < num_inputs,
                        "backward: mappings[", j, "][", l, "] = ", bit,
                        " is out of range for input with ", num_inputs, " columns");
        }
    }

    // Different samples can hit the same memory cell, so contributions are
    // accumulated with += rather than assigned.
    for (int64_t i = 0; i < num_samples; ++i) {
        for (int64_t j = 0; j < num_rams; ++j) {
            for (int64_t k = 0; k < num_keys; ++k) {

                // XOR the key of every tuple position whose input bit is set.
                // Starting from 0 gives the same result as seeding with
                // position 0 and also covers an empty tuple.
                uint64_t addr = 0;
                for (int64_t l = 0; l < tuple_length; ++l) {
                    if (input[i][mappings[j][l]]) {
                        addr ^= static_cast<uint64_t>(keys[k][l]);
                    }
                }

                // Guard the write below: an unchecked address would corrupt
                // memory outside the gradient buffer.
                TORCH_CHECK(addr < static_cast<uint64_t>(num_addresses),
                            "backward: computed address ", addr,
                            " is out of range for memory with ", num_addresses, " addresses");
                const int64_t cell = static_cast<int64_t>(addr);

                for (int64_t o = 0; o < num_output; ++o) {
                    memory_grad[j][cell][o] += output_grad[i][j][o];
                }
            }
        }
    }

    return memory_grad_tensor;
}

// Python bindings: registers the two CPU kernels defined in this file on the
// extension module, under the names "forward" and "backward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, py_mod) {
    // Forward lookup kernel.
    py_mod.def("forward", forward, "DoMaFilter CPU forward");

    // Gradient of the lookup with respect to the memory tensor.
    py_mod.def("backward", backward, "DoMaFilter CPU backward");
}