
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>

// Packs the low 3 bits of every input byte into a dense bit stream.
//
// Both buffers are addressed as flat arrays: every group of 8 consecutive
// input bytes produces 3 consecutive output bytes. Inside a group, element 0
// occupies the most significant 3 bits of the 24-bit word and element 7 the
// least significant 3 bits; the word is stored most-significant byte first.
//
// Parameters:
//   input  - rows * cols readable bytes (only the low 3 bits of each are used)
//   output - rows * ((cols * 3) / 8) writable bytes
//   rows, cols - logical shape of `input`; only their products are used here
//
// Launch layout: 1-D grid of 1-D blocks of any size. A grid-stride loop is
// used, so correctness does not depend on the grid covering every group.
// No shared memory is required. Trailing bytes that do not form a complete
// group (on either side) are left untouched, as before.
__global__ void packbits_kernel(const uint8_t* __restrict__ input, uint8_t* __restrict__ output, int rows, int cols) {
    // All sizes and indices are 64-bit: the 32-bit products used previously
    // (blockIdx.x * blockDim.x * 8, rows * cols, cols * 3) wrap for large tensors.
    const int64_t totalInputBytes = static_cast<int64_t>(rows) * cols;
    const int64_t totalOutputBytes = static_cast<int64_t>(rows) * ((static_cast<int64_t>(cols) * 3) / 8);

    // A group is processed only if all 8 source bytes and all 3 destination
    // bytes are in range.
    const int64_t inputGroups = totalInputBytes / 8;
    const int64_t outputGroups = totalOutputBytes / 3;
    const int64_t numGroups = inputGroups < outputGroups ? inputGroups : outputGroups;

    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t group = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         group < numGroups;
         group += stride) {
        const uint8_t* src = input + group * 8;
        uint8_t* dst = output + group * 3;

        // Assemble the 24-bit word: element i goes to bits [23 - 3i, 21 - 3i].
        uint32_t packed = 0;
        #pragma unroll
        for (int i = 0; i < 8; ++i) {
            packed |= static_cast<uint32_t>(src[i] & 7) << (21 - i * 3);
        }

        dst[0] = static_cast<uint8_t>((packed >> 16) & 0xFF);
        dst[1] = static_cast<uint8_t>((packed >> 8) & 0xFF);
        dst[2] = static_cast<uint8_t>(packed & 0xFF);
    }
}

// Packs a 2-D uint8 CUDA tensor holding 3-bit values into a dense byte tensor.
//
// Args:
//   input: CUDA tensor of dtype uint8 and shape (rows, cols), with cols a
//          multiple of 8. Only the low 3 bits of each element are kept.
//          Non-contiguous inputs are copied to a contiguous buffer first.
// Returns:
//   uint8 tensor of shape (rows, cols * 3 / 8) on the same device.
// Raises (c10::Error via TORCH_CHECK):
//   if the input is not a 2-D uint8 CUDA tensor, if cols is not a multiple of
//   8, if the element count does not fit in a 32-bit int, or if the kernel
//   launch fails.
//
// The kernel is enqueued on PyTorch's current stream for the input's device;
// no host-side synchronization is performed.
torch::Tensor packbits_cuda(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "packbits_cuda: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kUInt8, "packbits_cuda: input must have dtype uint8");
    TORCH_CHECK(input.dim() == 2, "packbits_cuda: input must be 2-D, got ", input.dim(), " dimension(s)");

    const int64_t rows = input.size(0);
    const int64_t cols = input.size(1);

    // The kernel packs the flat buffer in groups of 8 elements; a row length
    // that is not a multiple of 8 would make groups straddle rows and leave
    // the output layout inconsistent with the (rows, cols * 3 / 8) shape.
    TORCH_CHECK(cols % 8 == 0, "packbits_cuda: input.size(1) must be a multiple of 8, got ", cols);

    // rows and cols are handed to the kernel as int.
    TORCH_CHECK(input.numel() <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "packbits_cuda: input has too many elements (", input.numel(), ")");

    // Make sure allocation and launch happen on the input's device, not on
    // whichever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes raw memory, so it needs a dense row-major layout.
    const torch::Tensor src = input.contiguous();

    // Every output byte is written by the kernel (cols % 8 == 0), so the
    // buffer does not need to be zero-filled.
    auto output = torch::empty({rows, (cols * 3) / 8}, torch::dtype(torch::kUInt8).device(input.device()));

    const int64_t groups = src.numel() / 8;
    if (groups == 0) {
        return output;  // nothing to pack; a zero-block launch would be invalid
    }

    // One thread per group of 8 input elements.
    constexpr int threads = 1024;
    const int blocks = static_cast<int>((groups + threads - 1) / threads);

    packbits_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        src.data_ptr<uint8_t>(), output.data_ptr<uint8_t>(), static_cast<int>(rows), static_cast<int>(cols));

    // Launch-configuration errors are only reported through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "packbits_cuda: kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}

// Expands a dense stream of 3-bit values back into one byte per value.
//
// Both buffers are addressed as flat arrays: every group of 3 consecutive
// input bytes produces 8 consecutive output bytes. The 3 input bytes are read
// most-significant byte first into a 24-bit word; output element 0 is taken
// from the most significant 3 bits and element 7 from the least significant.
//
// Parameters:
//   input  - rows * ((cols * 3) / 8) readable bytes
//   output - rows * cols writable bytes; each receives a value in [0, 7]
//   rows, cols - logical shape of `output`; only their products are used here
//
// Launch layout: 1-D grid of 1-D blocks of any size. A grid-stride loop is
// used, so correctness does not depend on the grid covering every group.
// No shared memory is required. Trailing bytes that do not form a complete
// group (on either side) are left untouched, as before.
__global__ void unpackbits_kernel(const uint8_t* __restrict__ input, uint8_t* __restrict__ output, int rows, int cols) {
    // All sizes and indices are 64-bit: the 32-bit products used previously
    // (blockIdx.x * blockDim.x * 8, rows * cols, cols * 3) wrap for large tensors.
    const int64_t totalInputBytes = static_cast<int64_t>(rows) * ((static_cast<int64_t>(cols) * 3) / 8);
    const int64_t totalOutputBytes = static_cast<int64_t>(rows) * cols;

    // A group is processed only if all 3 source bytes and all 8 destination
    // bytes are in range.
    const int64_t inputGroups = totalInputBytes / 3;
    const int64_t outputGroups = totalOutputBytes / 8;
    const int64_t numGroups = inputGroups < outputGroups ? inputGroups : outputGroups;

    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t group = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         group < numGroups;
         group += stride) {
        const uint8_t* src = input + group * 3;
        uint8_t* dst = output + group * 8;

        // Rebuild the 24-bit word, first byte in the top 8 bits.
        const uint32_t word = (static_cast<uint32_t>(src[0]) << 16) |
                              (static_cast<uint32_t>(src[1]) << 8) |
                              static_cast<uint32_t>(src[2]);

        // Element i lives in bits [23 - 3i, 21 - 3i].
        #pragma unroll
        for (int i = 0; i < 8; ++i) {
            dst[i] = static_cast<uint8_t>((word >> (21 - i * 3)) & 7);
        }
    }
}

// Unpacks a 2-D uint8 CUDA tensor of densely packed 3-bit values into one
// byte per value.
//
// Args:
//   input: CUDA tensor of dtype uint8 and shape (rows, packed_cols), with
//          packed_cols a multiple of 3 (every 3 packed bytes hold 8 values).
//          Non-contiguous inputs are copied to a contiguous buffer first.
// Returns:
//   uint8 tensor of shape (rows, packed_cols * 8 / 3) on the same device, with
//   every element in [0, 7].
// Raises (c10::Error via TORCH_CHECK):
//   if the input is not a 2-D uint8 CUDA tensor, if packed_cols is not a
//   multiple of 3, if the output element count does not fit in a 32-bit int,
//   or if the kernel launch fails.
//
// The kernel is enqueued on PyTorch's current stream for the input's device.
// No device-wide synchronization is performed: work that consumes the result
// on the same stream is already ordered after the kernel, and a blanket
// cudaDeviceSynchronize() would stall every stream on every call.
torch::Tensor unpackbits_cuda(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "unpackbits_cuda: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kUInt8, "unpackbits_cuda: input must have dtype uint8");
    TORCH_CHECK(input.dim() == 2, "unpackbits_cuda: input must be 2-D, got ", input.dim(), " dimension(s)");

    const int64_t rows = input.size(0);
    const int64_t packed_cols = input.size(1);

    // The kernel unpacks the flat buffer in groups of 3 bytes; a packed row
    // length that is not a multiple of 3 would make groups straddle rows and
    // disagree with the size the kernel derives from the unpacked width.
    TORCH_CHECK(packed_cols % 3 == 0,
                "unpackbits_cuda: input.size(1) must be a multiple of 3, got ", packed_cols);

    const int64_t cols = (packed_cols * 8) / 3;  // unpacked row length

    // rows and cols are handed to the kernel as int.
    TORCH_CHECK(rows * cols <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "unpackbits_cuda: output would have too many elements (", rows * cols, ")");

    // Make sure allocation and launch happen on the input's device, not on
    // whichever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes raw memory, so it needs a dense row-major layout.
    const torch::Tensor src = input.contiguous();

    // Every output byte is written by the kernel (packed_cols % 3 == 0), so
    // the buffer does not need to be zero-filled.
    auto output = torch::empty({rows, cols}, torch::dtype(torch::kUInt8).device(input.device()));

    const int64_t groups = src.numel() / 3;
    if (groups == 0) {
        return output;  // nothing to unpack; a zero-block launch would be invalid
    }

    // One thread per group of 3 packed bytes (8 output elements).
    constexpr int threads = 256;
    const int blocks = static_cast<int>((groups + threads - 1) / threads);

    unpackbits_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        src.data_ptr<uint8_t>(), output.data_ptr<uint8_t>(), static_cast<int>(rows), static_cast<int>(cols));

    // Launch-configuration errors are only reported through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "unpackbits_cuda: kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}

// Python bindings: exposes the two host entry points under the extension
// module name chosen at build time (TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // (rows, cols) uint8 tensor -> packed uint8 tensor
    m.def("packbits_cuda", packbits_cuda, "Pack bits (CUDA)");

    // packed uint8 tensor -> (rows, cols) uint8 tensor
    m.def("unpackbits_cuda", unpackbits_cuda, "Unpack bits (CUDA)");
}
