#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>


/**
 * Row-gather kernel: for every entry i of `row_indexes` and every column j of
 * `original_matrix`, performs
 *
 *     target_matrix[i][j] = original_matrix[row_indexes[i]][j]
 *
 * Launch layout: 2-D grid of 2-D blocks. The x dimension enumerates entries of
 * `row_indexes` (output rows); the y dimension enumerates columns. One thread
 * handles one output element; threads that fall outside the data return early.
 *
 * @param original_matrix  2-D source accessor.
 * @param row_indexes      1-D accessor of source row numbers. They are stored
 *                         as `scalar_t` (same element type as the matrices)
 *                         and truncated to an integer here.
 * @param target_matrix    2-D destination accessor, written in place.
 *
 * Elements whose source row index lies outside `original_matrix`, or whose
 * destination lies outside `target_matrix`, are skipped (left untouched)
 * rather than touching out-of-bounds memory.
 */
template <typename scalar_t>
__global__ void deepcopy_kernel(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> original_matrix,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> row_indexes,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> target_matrix
){
    // Widen before multiplying so the global index cannot wrap in 32 bits.
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t j = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    // Grid tail: the launch grid is rounded up to whole blocks.
    if (i >= row_indexes.size(0) || j >= original_matrix.size(1)) return;

    // Never write past the destination, even if the caller passed a small one.
    if (i >= target_matrix.size(0) || j >= target_matrix.size(1)) return;

    // Indices arrive as scalar_t; truncate to an integer and validate the
    // range before using it to address the source matrix.
    const long long row_index = static_cast<long long>(row_indexes[i]);
    if (row_index < 0 || static_cast<size_t>(row_index) >= original_matrix.size(0)) return;

    // Plain element copy. No `0.0 +` term: that is a double literal, which
    // forced a float->double->float round trip per element and also rewrote
    // -0.0 as +0.0, so the result was not a bit-exact copy.
    target_matrix[i][j] = original_matrix[static_cast<size_t>(row_index)][j];
}


/**
 * Host launcher for `deepcopy_kernel`: copies the rows of `original_matrix`
 * selected by `row_indexes` into the caller-allocated `target_matrix`.
 *
 * @param original_matrix  2-D floating-point CUDA tensor (source).
 * @param row_indexes      1-D CUDA tensor of source row numbers. It must have
 *                         the same dtype as `original_matrix`, because the
 *                         kernel reads it through a `scalar_t` accessor.
 * @param target_matrix    2-D CUDA tensor of the same dtype with at least
 *                         `row_indexes.size(0)` rows and
 *                         `original_matrix.size(1)` columns; written in place.
 * @return `target_matrix` (the same tensor handle that was passed in).
 *
 * Raises a c10::Error (via TORCH_CHECK) on invalid arguments or when the
 * kernel launch is rejected by the CUDA runtime. The launch is asynchronous
 * and is queued on PyTorch's current CUDA stream for the tensors' device.
 */
torch::Tensor deepcopy_cu(
    const torch::Tensor original_matrix,
    const torch::Tensor row_indexes,
    torch::Tensor target_matrix
){
    TORCH_CHECK(original_matrix.is_cuda() && row_indexes.is_cuda() && target_matrix.is_cuda(),
                "deepcopy_cu: all tensors must be CUDA tensors");
    TORCH_CHECK(original_matrix.device() == row_indexes.device() &&
                original_matrix.device() == target_matrix.device(),
                "deepcopy_cu: all tensors must be on the same device");
    TORCH_CHECK(original_matrix.dim() == 2, "deepcopy_cu: original_matrix must be 2-D");
    TORCH_CHECK(row_indexes.dim() == 1, "deepcopy_cu: row_indexes must be 1-D");
    TORCH_CHECK(target_matrix.dim() == 2, "deepcopy_cu: target_matrix must be 2-D");
    TORCH_CHECK(row_indexes.scalar_type() == original_matrix.scalar_type() &&
                target_matrix.scalar_type() == original_matrix.scalar_type(),
                "deepcopy_cu: all tensors must share the dtype of original_matrix");

    // Keep sizes in 64 bits; Tensor::size() returns int64_t.
    const int64_t n1 = row_indexes.size(0);
    const int64_t n2 = original_matrix.size(1);

    TORCH_CHECK(target_matrix.size(0) >= n1 && target_matrix.size(1) >= n2,
                "deepcopy_cu: target_matrix is too small, need at least [",
                n1, ", ", n2, "]");

    // A grid dimension of 0 is an invalid launch configuration, so an empty
    // problem must return before launching.
    if (n1 == 0 || n2 == 0) {
        return target_matrix;
    }

    // Launch on the tensors' device and on PyTorch's current stream so the
    // copy is ordered with the surrounding PyTorch operations.
    const c10::cuda::CUDAGuard device_guard(original_matrix.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    // threads.x / blocks.x cover entries of row_indexes, threads.y / blocks.y
    // cover columns, matching the index computation inside the kernel.
    const dim3 threads(16, 16);
    const dim3 blocks(
        static_cast<unsigned int>((n1 + threads.x - 1) / threads.x),
        static_cast<unsigned int>((n2 + threads.y - 1) / threads.y));

    AT_DISPATCH_FLOATING_TYPES(original_matrix.scalar_type(), "deepcopy_cu",
    ([&] {
        deepcopy_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            original_matrix.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            row_indexes.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
            target_matrix.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
        );
    }));

    // Kernel launches do not return a status; fetch (and clear) it explicitly
    // so a bad configuration, e.g. a grid exceeding the device limits, is
    // reported here instead of at some unrelated later CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "deepcopy_cu: kernel launch failed: ", cudaGetErrorString(launch_status));

    return target_matrix;
}
