#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cstdint>
#include <limits>
#include <vector>

/**
 * Circularly shifts every H x W plane of a densely packed (B, C, D, H, W)
 * float buffer forward by a per-batch offset.
 *
 * For each element (b, c, d, h, w) the value is written to
 * (b, c, d, (h + shiftH) mod H, (w + shiftW) mod W), where
 * shiftH = shifts[2 * b] and shiftW = shifts[2 * b + 1].
 *
 * Launch layout: 1-D grid, 1-D block. The body is a grid-stride loop, so any
 * grid/block size is valid; no shared memory is used.
 *
 * Preconditions:
 *  - `input` and `output` each hold B*C*D*H*W floats and do not alias
 *    (both are declared __restrict__).
 *  - `shifts` holds at least 2*B ints.
 *
 * @param input   source buffer (read-only)
 * @param output  destination buffer; every element is written exactly once
 * @param shifts  per-batch (shiftH, shiftW) pairs; any sign or magnitude
 * @param B,C,D,H,W  extents of the five dimensions
 */
__global__ void roll_cuda_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int* __restrict__ shifts,
    int B, int C, int D, int H, int W) {

    // 64-bit sizes: the product of the five extents can exceed INT_MAX.
    const int64_t batchSize = static_cast<int64_t>(C) * D * H * W;
    const int64_t totalElements = batchSize * B;
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < totalElements;
         idx += stride) {

        const int w = static_cast<int>(idx % W);
        const int h = static_cast<int>((idx / W) % H);
        const int64_t b = idx / batchSize;

        // C++ '%' keeps the sign of the dividend, so reduce the shift to
        // [0, dim) first. This also avoids overflow in `h + shift` when the
        // raw shift is close to INT_MAX / INT_MIN.
        int shiftH = shifts[b * 2] % H;
        if (shiftH < 0) shiftH += H;
        int shiftW = shifts[b * 2 + 1] % W;
        if (shiftW < 0) shiftW += W;

        // Both operands are in [0, dim), so a single conditional subtract wraps.
        int64_t rolled_h = static_cast<int64_t>(h) + shiftH;
        if (rolled_h >= H) rolled_h -= H;
        int64_t rolled_w = static_cast<int64_t>(w) + shiftW;
        if (rolled_w >= W) rolled_w -= W;

        // (b, c, d) are unchanged, so the destination plane starts where the
        // source plane starts: strip the in-plane offset from idx.
        const int64_t planeBase = idx - (static_cast<int64_t>(h) * W + w);

        // rolled_h/rolled_w are guaranteed in range, so no extra bounds check.
        output[planeBase + rolled_h * W + rolled_w] = input[idx];
    }
}

/**
 * Inverse of the forward roll: circularly shifts every H x W plane of a
 * densely packed (B, C, D, H, W) float buffer backward by a per-batch offset.
 *
 * For each element (b, c, d, h, w) the value is written to
 * (b, c, d, (h - shiftH) mod H, (w - shiftW) mod W), where
 * shiftH = shifts[2 * b] and shiftW = shifts[2 * b + 1].
 *
 * Launch layout: 1-D grid, 1-D block. The body is a grid-stride loop, so any
 * grid/block size is valid; no shared memory is used.
 *
 * Preconditions:
 *  - `input` and `output` each hold B*C*D*H*W floats and do not alias
 *    (both are declared __restrict__).
 *  - `shifts` holds at least 2*B ints.
 *
 * @param input   source buffer (read-only)
 * @param output  destination buffer; every element is written exactly once
 * @param shifts  per-batch (shiftH, shiftW) pairs; any sign or magnitude
 * @param B,C,D,H,W  extents of the five dimensions
 */
__global__ void unroll_cuda_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int* __restrict__ shifts,
    int B, int C, int D, int H, int W) {

    // 64-bit sizes: the product of the five extents can exceed INT_MAX.
    const int64_t batchSize = static_cast<int64_t>(C) * D * H * W;
    const int64_t totalElements = batchSize * B;
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < totalElements;
         idx += stride) {

        const int w = static_cast<int>(idx % W);
        const int h = static_cast<int>((idx / W) % H);
        const int64_t b = idx / batchSize;

        // C++ '%' keeps the sign of the dividend, so reduce the shift to
        // [0, dim) first. This also avoids overflow in `h - shift` when the
        // raw shift is close to INT_MAX / INT_MIN.
        int shiftH = shifts[b * 2] % H;
        if (shiftH < 0) shiftH += H;
        int shiftW = shifts[b * 2 + 1] % W;
        if (shiftW < 0) shiftW += W;

        // Both operands are in [0, dim), so a single conditional add wraps.
        int rolled_h = h - shiftH;
        if (rolled_h < 0) rolled_h += H;
        int rolled_w = w - shiftW;
        if (rolled_w < 0) rolled_w += W;

        // (b, c, d) are unchanged, so the destination plane starts where the
        // source plane starts: strip the in-plane offset from idx.
        const int64_t planeBase = idx - (static_cast<int64_t>(h) * W + w);

        // rolled_h/rolled_w are guaranteed in range, so no extra bounds check.
        output[planeBase + static_cast<int64_t>(rolled_h) * W + rolled_w] = input[idx];
    }
}

/**
 * Host launcher for the forward roll.
 *
 * Validates the arguments, makes them contiguous, and launches
 * roll_cuda_kernel on the current CUDA stream of `input`'s device.
 *
 * @param input   5-D float32 CUDA tensor, indexed as (B, C, D, H, W)
 * @param shifts  int32 CUDA tensor on the same device holding at least 2*B
 *                values; the kernel reads shifts[2*b] and shifts[2*b + 1]
 *                for batch b
 * @return a new contiguous tensor with the same shape and dtype as `input`
 * @throws c10::Error (via TORCH_CHECK) if any precondition is violated or the
 *         kernel launch fails
 */
torch::Tensor roll_cuda(
    torch::Tensor input,
    torch::Tensor shifts) {

    TORCH_CHECK(input.is_cuda(), "roll_cuda: input must be a CUDA tensor");
    TORCH_CHECK(shifts.is_cuda(), "roll_cuda: shifts must be a CUDA tensor");
    TORCH_CHECK(input.device() == shifts.device(),
                "roll_cuda: input and shifts must be on the same device");
    TORCH_CHECK(input.dim() == 5,
                "roll_cuda: input must be 5-D (B, C, D, H, W), got ", input.dim(), "-D");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "roll_cuda: input must be float32");
    TORCH_CHECK(shifts.scalar_type() == torch::kInt32,
                "roll_cuda: shifts must be int32");
    TORCH_CHECK(shifts.numel() >= 2 * input.size(0),
                "roll_cuda: shifts must hold at least 2 values per batch element");
    // Extents and the flat element count are handed to the kernel as int.
    TORCH_CHECK(input.numel() <= std::numeric_limits<int>::max(),
                "roll_cuda: input has too many elements for 32-bit indexing");

    // The kernel indexes raw pointers assuming a dense row-major layout.
    input = input.contiguous();
    shifts = shifts.contiguous();

    const int B = static_cast<int>(input.size(0));
    const int C = static_cast<int>(input.size(1));
    const int D = static_cast<int>(input.size(2));
    const int H = static_cast<int>(input.size(3));
    const int W = static_cast<int>(input.size(4));

    auto output = torch::zeros_like(input);

    const int64_t numel = input.numel();
    if (numel == 0) {
        // A zero-block launch is an invalid configuration; nothing to do.
        return output;
    }

    // Launch on the tensor's device and on PyTorch's current stream so the
    // kernel is ordered with the surrounding PyTorch operations.
    const c10::cuda::CUDAGuard device_guard(input.device());
    const auto stream = at::cuda::getCurrentCUDAStream();

    const int threads = 1024;
    const int blocks = static_cast<int>((numel + threads - 1) / threads);

    roll_cuda_kernel<<<blocks, threads, 0, stream>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        shifts.data_ptr<int>(),
        B, C, D, H, W);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

/**
 * Host launcher for the backward (inverse) roll.
 *
 * Validates the arguments, makes them contiguous, and launches
 * unroll_cuda_kernel on the current CUDA stream of `grad_output`'s device.
 *
 * @param grad_output  5-D float32 CUDA tensor, indexed as (B, C, D, H, W)
 * @param shifts       int32 CUDA tensor on the same device holding at least
 *                     2*B values; the kernel reads shifts[2*b] and
 *                     shifts[2*b + 1] for batch b
 * @return a new contiguous tensor with the same shape and dtype as
 *         `grad_output`
 * @throws c10::Error (via TORCH_CHECK) if any precondition is violated or the
 *         kernel launch fails
 */
torch::Tensor roll_backward_cuda(torch::Tensor grad_output, torch::Tensor shifts) {
    TORCH_CHECK(grad_output.is_cuda(), "roll_backward_cuda: grad_output must be a CUDA tensor");
    TORCH_CHECK(shifts.is_cuda(), "roll_backward_cuda: shifts must be a CUDA tensor");
    TORCH_CHECK(grad_output.device() == shifts.device(),
                "roll_backward_cuda: grad_output and shifts must be on the same device");
    TORCH_CHECK(grad_output.dim() == 5,
                "roll_backward_cuda: grad_output must be 5-D (B, C, D, H, W), got ",
                grad_output.dim(), "-D");
    TORCH_CHECK(grad_output.scalar_type() == torch::kFloat32,
                "roll_backward_cuda: grad_output must be float32");
    TORCH_CHECK(shifts.scalar_type() == torch::kInt32,
                "roll_backward_cuda: shifts must be int32");
    TORCH_CHECK(shifts.numel() >= 2 * grad_output.size(0),
                "roll_backward_cuda: shifts must hold at least 2 values per batch element");
    // Extents and the flat element count are handed to the kernel as int.
    TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int>::max(),
                "roll_backward_cuda: grad_output has too many elements for 32-bit indexing");

    // The kernel indexes raw pointers assuming a dense row-major layout.
    grad_output = grad_output.contiguous();
    shifts = shifts.contiguous();

    const int B = static_cast<int>(grad_output.size(0));
    const int C = static_cast<int>(grad_output.size(1));
    const int D = static_cast<int>(grad_output.size(2));
    const int H = static_cast<int>(grad_output.size(3));
    const int W = static_cast<int>(grad_output.size(4));

    auto grad_input = torch::zeros_like(grad_output);

    const int64_t numel = grad_output.numel();
    if (numel == 0) {
        // A zero-block launch is an invalid configuration; nothing to do.
        return grad_input;
    }

    // Launch on the tensor's device and on PyTorch's current stream so the
    // kernel is ordered with the surrounding PyTorch operations.
    const c10::cuda::CUDAGuard device_guard(grad_output.device());
    const auto stream = at::cuda::getCurrentCUDAStream();

    const int threads = 1024;
    const int blocks = static_cast<int>((numel + threads - 1) / threads);

    unroll_cuda_kernel<<<blocks, threads, 0, stream>>>(
        grad_output.data_ptr<float>(),
        grad_input.data_ptr<float>(),
        shifts.data_ptr<int>(),
        B, C, D, H, W);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return grad_input;
}


// Python bindings: exposes the forward launcher as `roll_cuda` and the
// backward launcher as `roll_backward_cuda` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // def() returns the module, so both registrations chain in one expression.
    m.def("roll_cuda", &roll_cuda, "Roll operation with CUDA")
     .def("roll_backward_cuda", &roll_backward_cuda, "Roll backward operation with CUDA");
}