// #include <torch/extension.h>
// #include <ATen/ATen.h>
// #include <cuda_runtime.h>
// #include <vector>

// // 设备端内联函数，用于执行原子最大值更新
// __device__ inline void atomicMaxFloat(float* address, float val) {
//     int* address_as_int = (int*)address;
//     int old = *address_as_int, assumed;
//     do {
//         assumed = old;
//         old = atomicCAS(address_as_int, assumed,
//                         __float_as_int(fmaxf(val, __int_as_float(assumed))));
//     } while (assumed != old);
// }

// // CUDA 核函数用于并行更新最大值
// __global__ void atomicMaxKernel3D(float* data, long* x_indices, long* y_indices, long* z_indices, float* values, size_t N, int data_dim_x, int data_dim_y, int data_dim_z) {
//     int idx = blockIdx.x * blockDim.x + threadIdx.x;
//     if (idx < N) {
//         long index = x_indices[idx] * data_dim_y * data_dim_z + y_indices[idx] * data_dim_z + z_indices[idx];
//         if (index >= 0 && index < data_dim_x * data_dim_y * data_dim_z) {
//             atomicMaxFloat(&data[index], values[idx]);
//         } else {
//             printf("Out of bounds access at index %ld\n", index);
//         }
//     }
// }

// // C++ 接口函数，用于从PyTorch调用CUDA核函数
// void atomic_max_cuda(at::Tensor data, at::Tensor x_indices, at::Tensor y_indices, at::Tensor z_indices, at::Tensor values) {
//     auto data_ptr = data.data_ptr<float>();
//     auto x_indices_ptr = x_indices.data_ptr<long>();
//     auto y_indices_ptr = y_indices.data_ptr<long>();
//     auto z_indices_ptr = z_indices.data_ptr<long>();
//     auto values_ptr = values.data_ptr<float>();
//     auto N = values.numel();

//     // Read the extents of `data` (these are dimension sizes, not strides)
//     int data_dim_x = data.size(0);
//     int data_dim_y = data.size(1);
//     int data_dim_z = data.size(2);

//     // Launch the kernel, passing the dimension sizes
//     atomicMaxKernel3D<<<(N + 255) / 256, 256>>>(data_ptr, x_indices_ptr, y_indices_ptr, z_indices_ptr, values_ptr, N, data_dim_x, data_dim_y, data_dim_z);
//     cudaDeviceSynchronize();
// }

// // 注册PyTorch模块
// PYBIND11_MODULE(atomic_max_custom, m) {
//     m.def("atomic_max", &atomic_max_cuda, "Execute atomic max operation on CUDA with 3D indexing using floats.");
// }


#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda_runtime.h>
#include <vector>

// Device helper: atomically replaces *address with fmax(*address, val).
//
// Despite the name, this overload operates on double. The 64-bit value is
// reinterpreted as an unsigned long long so that atomicCAS can be used: the
// current bits are read, the maximum is computed, and the swap is retried
// until no other thread has modified the slot between the read and the CAS.
// Success is decided by comparing raw bit patterns, not floating-point values.
__device__ inline void atomicMaxFloat(double* address, double val) {
    unsigned long long* address_bits = (unsigned long long*) address;
    unsigned long long observed = *address_bits;
    for (;;) {
        // Snapshot the value this attempt is based on.
        const unsigned long long expected = observed;
        const double candidate = fmax(val, __longlong_as_double(expected));
        // atomicCAS returns what was actually stored before the attempt.
        observed = atomicCAS(address_bits, expected, __double_as_longlong(candidate));
        if (observed == expected) {
            break;  // nobody raced with us; the swap took effect
        }
    }
}

// Kernel: for each i in [0, N), performs
//     data[x[i]][y[i]][z[i]] = max(data[x[i]][y[i]][z[i]], values[i])
// atomically, where `data` is addressed as a flat array using
//     flat = (x * data_dim_y + y) * data_dim_z + z.
//
// Launch layout: 1D grid of 1D blocks, one thread per element of `values`;
// the grid must supply at least N threads (surplus threads exit immediately).
//
// Parameters:
//   data                       flat array updated in place
//   x_indices/y_indices/z_indices  per-element coordinates, N entries each
//   values                     N candidate values
//   N                          number of (coordinate, value) entries
//   data_dim_x/y/z             extents of `data` along each axis
//
// Each coordinate is validated against its own extent. Checking only the
// flattened index is not sufficient: an out-of-range coordinate (for example
// y == -1 with x >= 1) can still flatten to an in-range offset and would
// silently update the wrong cell. Entries that fail validation are skipped
// and reported with a device-side printf.
__global__ void atomicMaxKernel3D(double* data, long* x_indices, long* y_indices, long* z_indices, double* values, size_t N, int data_dim_x, int data_dim_y, int data_dim_z) {
    // Compute the global thread index in 64 bits so that it cannot overflow
    // and compares against the unsigned N without sign conversion.
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= N) {
        return;
    }

    const long x = x_indices[idx];
    const long y = y_indices[idx];
    const long z = z_indices[idx];

    const bool in_range = x >= 0 && x < data_dim_x &&
                          y >= 0 && y < data_dim_y &&
                          z >= 0 && z < data_dim_z;
    if (in_range) {
        // Flatten only after validation, in 64-bit arithmetic, so the offset
        // cannot overflow for arrays with more than 2^31 elements.
        const long long index =
            (static_cast<long long>(x) * data_dim_y + y) * data_dim_z + z;
        atomicMaxFloat(&data[index], values[idx]);
    } else {
        // Report the coordinates: the flattened index may look valid even
        // when an individual coordinate is not.
        printf("Out of bounds access at index (%ld, %ld, %ld)\n", x, y, z);
    }
}


// Host entry point called from Python: scatter-max `values` into `data`.
//
// For every i, data[x_indices[i], y_indices[i], z_indices[i]] is replaced by
// the maximum of its current value and values[i]. `data` is modified in place.
//
// Requirements (violations raise a c10::Error via TORCH_CHECK):
//   - all tensors are CUDA tensors on the same device;
//   - data is a contiguous 3-D float64 tensor (it is addressed through a raw
//     pointer with a flat offset, so a strided view cannot be used);
//   - values is float64; x/y/z_indices are int64;
//   - x/y/z_indices each hold exactly values.numel() elements.
// Index and value tensors that are not contiguous are copied to contiguous
// storage first (they are only read).
//
// The call blocks until the kernel has finished and reports launch or
// execution failures instead of ignoring them.
// NOTE(review): the kernel is launched on the default stream of the current
// device; no device guard is installed here — confirm callers only pass
// tensors that live on the current device.
void atomic_max_cuda(at::Tensor data, at::Tensor x_indices, at::Tensor y_indices, at::Tensor z_indices, at::Tensor values) {
    TORCH_CHECK(data.is_cuda() && x_indices.is_cuda() && y_indices.is_cuda() &&
                    z_indices.is_cuda() && values.is_cuda(),
                "atomic_max: all tensors must be CUDA tensors");
    TORCH_CHECK(x_indices.device() == data.device() && y_indices.device() == data.device() &&
                    z_indices.device() == data.device() && values.device() == data.device(),
                "atomic_max: all tensors must be on the same device");
    TORCH_CHECK(data.dim() == 3,
                "atomic_max: data must be 3-dimensional, got ", data.dim(), " dimension(s)");
    TORCH_CHECK(data.is_contiguous(),
                "atomic_max: data must be contiguous because it is updated in place");
    TORCH_CHECK(data.scalar_type() == at::kDouble && values.scalar_type() == at::kDouble,
                "atomic_max: data and values must be float64 tensors");
    TORCH_CHECK(x_indices.scalar_type() == at::kLong && y_indices.scalar_type() == at::kLong &&
                    z_indices.scalar_type() == at::kLong,
                "atomic_max: index tensors must be int64 tensors");

    const int64_t N = values.numel();
    TORCH_CHECK(x_indices.numel() == N && y_indices.numel() == N && z_indices.numel() == N,
                "atomic_max: x/y/z index tensors must each have values.numel() = ", N, " elements");

    // Nothing to do; a launch with zero blocks would be an invalid configuration.
    if (N == 0) {
        return;
    }

    // The kernel receives the extents as int, so reject anything that would be
    // truncated by the narrowing conversion.
    constexpr int64_t kIntMax = 2147483647;
    const int64_t dim_x = data.size(0);
    const int64_t dim_y = data.size(1);
    const int64_t dim_z = data.size(2);
    TORCH_CHECK(dim_x <= kIntMax && dim_y <= kIntMax && dim_z <= kIntMax,
                "atomic_max: each dimension of data must fit in a 32-bit int");

    // The read-only inputs are accessed linearly through raw pointers.
    x_indices = x_indices.contiguous();
    y_indices = y_indices.contiguous();
    z_indices = z_indices.contiguous();
    values = values.contiguous();

    auto data_ptr = data.data_ptr<double>();
    auto x_indices_ptr = x_indices.data_ptr<long>();
    auto y_indices_ptr = y_indices.data_ptr<long>();
    auto z_indices_ptr = z_indices.data_ptr<long>();
    auto values_ptr = values.data_ptr<double>();

    // One thread per entry, rounded up to whole blocks.
    constexpr int kThreadsPerBlock = 256;
    const int64_t num_blocks = (N + kThreadsPerBlock - 1) / kThreadsPerBlock;
    TORCH_CHECK(num_blocks <= kIntMax, "atomic_max: too many elements for a single 1D grid");

    atomicMaxKernel3D<<<static_cast<unsigned int>(num_blocks), kThreadsPerBlock>>>(
        data_ptr, x_indices_ptr, y_indices_ptr, z_indices_ptr, values_ptr,
        static_cast<size_t>(N),
        static_cast<int>(dim_x), static_cast<int>(dim_y), static_cast<int>(dim_z));

    // Launch-configuration errors are only visible through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "atomic_max: kernel launch failed: ", cudaGetErrorString(err));

    // Faults raised while the kernel runs surface at the next synchronizing call.
    err = cudaDeviceSynchronize();
    TORCH_CHECK(err == cudaSuccess,
                "atomic_max: kernel execution failed: ", cudaGetErrorString(err));
}

// Python binding: defines the extension module `atomic_max_custom` and exposes
// atomic_max_cuda to Python under the name `atomic_max`.
PYBIND11_MODULE(atomic_max_custom, module) {
    module.def(
        "atomic_max",
        &atomic_max_cuda,
        "Execute atomic max operation on CUDA with 3D indexing.");
}
