#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <climits>
#include <cstdint>
#include <vector>

// Per-(row, dim) weighted-median neighbour selection over a CSR graph.
//
// Launch layout: 1-D grid, one thread per (row, dim) pair, flattened as
// idx = row * d + dim. Threads with idx >= m * d exit immediately. No shared
// memory is used.
//
// For node `row` and feature column `dim`, the neighbours colIdx[j] with
// j in [csrRowPtr[row], csrRowPtr[row + 1]) are visited in increasing order of
// X_rev_argsort[colIdx[j]][dim]. That tensor is expected to hold each node's
// rank along dim 0 of the feature matrix; the host launcher builds it as
// argsort(argsort(X, 0), 0). Weights value[j] are accumulated until the
// running sum reaches half of the row's total weight. The neighbour at which
// that happens is written to outIdx[row][dim].
//
// Neighbours with equal rank (e.g. the same column listed twice in a row) are
// visited in CSR order, so every copy's weight is counted.
//
// Special cases:
//   * empty row           -> outIdx[row][dim] = row (self-reference)
//   * total weight <= 0   -> outIdx[row][dim] = first neighbour in CSR order
//
// rowIdx and X_argsort are not read; they are kept so the signature matches the
// launch site.
// Cost: O(deg^2) per thread (selection by repeated linear scan of the row).
template <typename scalar_t>
__global__ void dimmedian_idx_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> csrRowPtr,
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> rowIdx,
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> colIdx,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> value,
    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> X_argsort,
    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> X_rev_argsort,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> outIdx,
    const int64_t m,
    const int64_t d)
{
    // Compute the flat index in 64 bits: the 32-bit unsigned product
    // blockIdx.x * blockDim.x can exceed INT_MAX in the tail block.
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // For idx >= 0 this is equivalent to (row >= m || dim >= d), and it also
    // avoids dividing by zero when d == 0.
    if (idx >= m * d)
    {
        return;
    }
    const int row = static_cast<int>(idx / d);
    const int dim = static_cast<int>(idx % d);

    // Row extent in the CSR arrays (read once instead of on every iteration).
    const int begin = static_cast<int>(csrRowPtr[row]);
    const int end = static_cast<int>(csrRowPtr[row + 1]);

    // No neighbours: point the node at itself.
    if (begin >= end)
    {
        outIdx[row][dim] = row;
        return;
    }

    // Total weight of this row, summed in CSR order.
    float weightSum = 0.0f;
    for (int j = begin; j < end; ++j)
    {
        weightSum += value[j];
    }

    // No positive mass to split: fall back to the first neighbour in CSR order.
    if (weightSum <= 0.0f)
    {
        outIdx[row][dim] = colIdx[begin];
        return;
    }

    const float halfWeight = weightSum / 2.0f;
    float cumSum = 0.0f;
    int lastRank = -1;  // rank of the most recently visited neighbour
    int lastPos = -1;   // CSR position of the most recently visited neighbour
    int selected = -1;  // CSR position of the current median candidate

    while (cumSum < halfWeight)
    {
        // Find the unvisited neighbour with the smallest (rank, CSR position).
        int bestRank = INT_MAX;
        int bestPos = -1;
        for (int j = begin; j < end; ++j)
        {
            const int rank = static_cast<int>(X_rev_argsort[colIdx[j]][dim]);
            // (rank, j) must come strictly after (lastRank, lastPos); the
            // position tie-break keeps duplicate columns (equal rank) from being
            // skipped.
            const bool notVisited =
                (rank > lastRank) || (rank == lastRank && j > lastPos);
            // Strict '<' keeps the first (smallest j) among equal ranks.
            if (notVisited && rank < bestRank)
            {
                bestRank = rank;
                bestPos = j;
            }
        }

        if (bestPos < 0)
        {
            break;  // every neighbour has been visited
        }

        lastRank = bestRank;
        lastPos = bestPos;
        cumSum += value[bestPos];
        selected = bestPos;
    }

    // `selected` is either -1 (loop never advanced, e.g. NaN weights) or lies in
    // [begin, end); fall back to the first neighbour in the former case.
    outIdx[row][dim] = colIdx[selected >= 0 ? selected : begin];
}

// Expands a sorted COO row-index array into a CSR row-pointer array of N + 1
// entries, so that out_data[r] counts the entries of ind_data that are < r.
//
// Launch layout: 1-D grid with at least numel + 1 threads. Each thread owns one
// "boundary":
//   thread 0          -> slots [0, ind_data[0]] are set to 0
//   thread t (0<t<n)  -> slots (ind_data[t-1], ind_data[t]] are set to t
//   thread numel      -> slots (ind_data[numel-1], N] are set to numel
// Threads past numel do nothing.
//
// Preconditions (not checked): ind_data is sorted ascending with values in
// [0, N], and numel >= 1 (thread 0 always reads ind_data[0]).
__global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
                               int64_t N, int64_t numel) {
  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;

  if (tid > numel) {
    return;
  }

  if (tid == 0) {
    // Rows up to and including the first index start at offset 0.
    const int64_t first = ind_data[0];
    for (int64_t r = 0; r <= first; ++r) {
      out_data[r] = 0;
    }
    return;
  }

  if (tid == numel) {
    // Rows after the last index all start at the total entry count.
    const int64_t last = ind_data[numel - 1];
    for (int64_t r = last + 1; r <= N; ++r) {
      out_data[r] = numel;
    }
    return;
  }

  // Rows strictly after ind_data[tid - 1] up to ind_data[tid] begin at tid.
  const int64_t lo = ind_data[tid - 1];
  const int64_t hi = ind_data[tid];
  for (int64_t r = lo + 1; r <= hi; ++r) {
    out_data[r] = tid;
  }
}

// Converts a COO row-index tensor into a CSR row-pointer tensor of length N + 1,
// where out[r] is the number of entries of `ind` that are < r.
//
//   ind      1-D int64 CUDA tensor of row indices. It must be sorted ascending
//            with values in [0, N); neither property is checked here.
//   N        number of rows.
//   THREADS  threads per block for ind2ptr_kernel.
//
// Returns a tensor with ind's options. An empty `ind` yields N + 1 zeros.
// The kernel runs on the current CUDA stream of ind's device.
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t N, const int THREADS = 256) {

  auto out = torch::empty(N + 1, ind.options());

  if (ind.numel() == 0)
    return out.zero_();

  TORCH_CHECK(ind.is_cuda(), "ind2ptr_cuda: ind must be a CUDA tensor");
  TORCH_CHECK(THREADS > 0, "ind2ptr_cuda: THREADS must be positive, got ", THREADS);

  // Launch on ind's device even if the caller's current device differs.
  const c10::cuda::CUDAGuard device_guard(ind.device());

  // The kernel walks the raw pointer with unit stride.
  ind = ind.contiguous();

  const int64_t numel = ind.numel();
  auto ind_data = ind.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  // The kernel needs numel + 1 threads (one per row boundary, including the tail).
  ind2ptr_kernel<<<(numel + 2 + THREADS - 1) / THREADS, THREADS, 0,
                   stream>>>(ind_data, out_data, N, numel);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}

// Host launcher for dimmedian_idx_cuda_forward_kernel.
//
// For every node r in [0, N) and every feature column k of X, finds the
// weighted-median neighbour of r. Neighbours come from the COO edge list
// (rowIndices[e] -> colIndices[e], weight edge_weights[e]) and are ordered by
// their rank in X[:, k]. The result holds that neighbour's column index.
//
//   X             2-D CUDA tensor; ranks are taken along dim 0.
//   rowIndices    int64 COO row indices. They must be sorted ascending for the
//                 CSR conversion in ind2ptr_cuda (not checked here).
//   colIndices    COO column indices. They index rows of X, so they must be
//                 < X.size(0) (not checked here).
//   edge_weights  per-edge weights, converted to float32.
//   N             number of output rows (nodes in the CSR structure).
//   THREADS       threads per block for the kernel launch.
//
// Returns an (N, X.size(1)) int64 tensor. Rows without neighbours map to
// themselves. Rows whose weights sum to <= 0 map to their first neighbour.
// All work runs on X's device, on the current PyTorch CUDA stream.
at::Tensor dimmedian_idx_forward_cuda(
    torch::Tensor X,
    torch::Tensor rowIndices,
    torch::Tensor colIndices,
    torch::Tensor edge_weights,
    const int64_t N,
    const int THREADS = 1024)
{
    TORCH_CHECK(X.is_cuda(), "dimmedian_idx_forward_cuda: X must be a CUDA tensor");
    TORCH_CHECK(X.dim() == 2,
                "dimmedian_idx_forward_cuda: X must be 2-D, got ", X.dim(), " dims");
    TORCH_CHECK(THREADS > 0,
                "dimmedian_idx_forward_cuda: THREADS must be positive, got ", THREADS);

    // Make X's device current so every op and the kernel launch target it.
    const c10::cuda::CUDAGuard device_guard(X.device());

    const int64_t d = X.size(1);

    // Nothing to compute. Return early, because a zero-block launch is an
    // invalid configuration.
    if (N == 0 || d == 0)
    {
        return torch::empty({N, d}, X.options().dtype(torch::kInt64));
    }

    // Per-column rank of every row of X: argsort along dim 0 gives the sorting
    // permutation, and argsort of that permutation gives each row's position in it.
    torch::Tensor X_argsort = X.argsort(static_cast<int64_t>(0), false).to(torch::kInt32);
    torch::Tensor X_rev_argsort = X_argsort.argsort(static_cast<int64_t>(0), false).to(torch::kInt32);

    torch::Tensor values = edge_weights.to(torch::kFloat32);

    // All index tensors are int32 from here on, matching the kernel instantiation below.
    torch::Tensor csrPtr = ind2ptr_cuda(rowIndices, N).to(torch::kInt32);
    rowIndices = rowIndices.to(torch::kInt32);
    colIndices = colIndices.to(torch::kInt32);

    // -1 marks "not written"; the kernel writes every (row, dim) it owns.
    torch::Tensor outIdx = torch::full({N, d}, -1, X.options().dtype(torch::kInt32));

    const dim3 n_blocks((N * d + THREADS - 1) / THREADS);
    // Launch on the current PyTorch stream. PyTorch streams are non-blocking, so
    // the legacy default stream would not be ordered after ind2ptr_cuda and the
    // tensor conversions above.
    auto stream = at::cuda::getCurrentCUDAStream();

    dimmedian_idx_cuda_forward_kernel<int32_t><<<n_blocks, THREADS, 0, stream>>>(
        csrPtr.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
        rowIndices.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
        colIndices.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
        values.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        X_argsort.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
        X_rev_argsort.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
        outIdx.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
        N,
        d);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return outIdx.to(torch::kInt64);
}