// Asserts, through TORCH_CHECK, that `name` is a contiguous CUDA tensor with
// exactly `n_dim` dimensions whose dtype equals `type`. Failure messages embed
// the spelling of the `name` argument.
//
// The checks live in one compound statement so the macro acts as a single
// statement (an unbraced `if` guards all of them, not just the first). A bare
// `{ ... }` is used rather than `do { ... } while (0)` so call sites written
// without a trailing semicolon still compile. Arguments are parenthesized so
// expression arguments keep their intended precedence. `name` is evaluated by
// every check, so pass an expression without side effects.
#define CHECK_CUDA_TENSOR_DIM_TYPE(name, n_dim, type)                                   \
  {                                                                                     \
    TORCH_CHECK((name).device().is_cuda(), #name " must be a CUDA tensor!");            \
    TORCH_CHECK((name).is_contiguous(), #name " must be contiguous!");                  \
    TORCH_CHECK((name).dim() == (n_dim), "The dimension of " #name " is not correct!"); \
    TORCH_CHECK((name).dtype() == (type), "The type of " #name " is not correct!");     \
  }

// Asserts, through TORCH_CHECK, that `name` is a contiguous CUDA tensor whose
// dtype equals `type` (no constraint on the number of dimensions). Failure
// messages embed the spelling of the `name` argument.
//
// The checks live in one compound statement so the macro acts as a single
// statement (an unbraced `if` guards all of them, not just the first). A bare
// `{ ... }` is used rather than `do { ... } while (0)` so call sites written
// without a trailing semicolon still compile. Arguments are parenthesized so
// expression arguments keep their intended precedence. `name` is evaluated by
// every check, so pass an expression without side effects.
#define CHECK_CUDA_TENSOR_TYPE(name, type)                                          \
  {                                                                                 \
    TORCH_CHECK((name).device().is_cuda(), #name " must be a CUDA tensor!");        \
    TORCH_CHECK((name).is_contiguous(), #name " must be contiguous!");              \
    TORCH_CHECK((name).dtype() == (type), "The type of " #name " is not correct!"); \
  }

// Asserts, through TORCH_CHECK, that `name` is a contiguous CUDA tensor whose
// dtype is torch::kFloat32 or torch::kFloat16 (no constraint on the number of
// dimensions). Failure messages embed the spelling of the `name` argument.
//
// The checks live in one compound statement so the macro acts as a single
// statement (an unbraced `if` guards all of them, not just the first). A bare
// `{ ... }` is used rather than `do { ... } while (0)` so call sites written
// without a trailing semicolon still compile. `name` is parenthesized so that
// expression arguments such as `*ptr` expand correctly; it is evaluated by
// every check, so pass an expression without side effects.
#define CHECK_CUDA_TENSOR_FLOAT(name)                                         \
  {                                                                           \
    TORCH_CHECK((name).device().is_cuda(), #name " must be a CUDA tensor!");  \
    TORCH_CHECK((name).is_contiguous(), #name " must be contiguous!");        \
    TORCH_CHECK((name).dtype() == torch::kFloat32 ||                          \
                    (name).dtype() == torch::kFloat16,                        \
                "The type of " #name " is not kFloat32 or kFloat16!");        \
  }

// Asserts, through TORCH_CHECK, that `name` is a contiguous CUDA tensor with
// exactly `n_dim` dimensions whose dtype is torch::kFloat32 or
// torch::kFloat16. Failure messages embed the spelling of the `name` argument.
//
// The checks live in one compound statement so the macro acts as a single
// statement (an unbraced `if` guards all of them, not just the first). A bare
// `{ ... }` is used rather than `do { ... } while (0)` so call sites written
// without a trailing semicolon still compile. Arguments are parenthesized so
// expression arguments keep their intended precedence. `name` is evaluated by
// every check, so pass an expression without side effects.
#define CHECK_CUDA_TENSOR_DIM_FLOAT(name, n_dim)                                        \
  {                                                                                     \
    TORCH_CHECK((name).device().is_cuda(), #name " must be a CUDA tensor!");            \
    TORCH_CHECK((name).is_contiguous(), #name " must be contiguous!");                  \
    TORCH_CHECK((name).dim() == (n_dim), "The dimension of " #name " is not correct!"); \
    TORCH_CHECK((name).dtype() == torch::kFloat32 ||                                    \
                    (name).dtype() == torch::kFloat16,                                  \
                "The type of " #name " is not kFloat32 or kFloat16!");                  \
  }

// Asserts, through TORCH_CHECK, that `name` is a contiguous CUDA tensor with
// exactly `n_dim` dimensions whose dtype is torch::kFloat32, torch::kFloat16
// or torch::kBFloat16. Failure messages embed the spelling of the `name`
// argument.
//
// The checks live in one compound statement so the macro acts as a single
// statement (an unbraced `if` guards all of them, not just the first). A bare
// `{ ... }` is used rather than `do { ... } while (0)` so call sites written
// without a trailing semicolon still compile. Arguments are parenthesized so
// expression arguments keep their intended precedence. `name` is evaluated by
// every check, so pass an expression without side effects.
#define CHECK_CUDA_TENSOR_DIM_BFLOAT(name, n_dim)                                       \
  {                                                                                     \
    TORCH_CHECK((name).device().is_cuda(), #name " must be a CUDA tensor!");            \
    TORCH_CHECK((name).is_contiguous(), #name " must be contiguous!");                  \
    TORCH_CHECK((name).dim() == (n_dim), "The dimension of " #name " is not correct!"); \
    TORCH_CHECK((name).dtype() == torch::kFloat32 ||                                    \
                    (name).dtype() == torch::kFloat16 ||                                \
                    (name).dtype() == torch::kBFloat16,                                 \
                "The type of " #name " is not kFloat32, kFloat16 or kBFloat16!");       \
  }

// Selects a C++ element type from a runtime scalar type and invokes a callable
// with that type visible as `scalar_t`.
//
//   TYPE  value compared against torch::kFloat, torch::kDouble, torch::kHalf
//         and torch::kBFloat16 (mapped to float, double, at::Half and
//         at::BFloat16 respectively).
//   NAME  label identifying the call site; only its spelling is used, quoted
//         into the error message so a failed dispatch can be traced back.
//   ...   a callable written at the call site (typically a lambda); it is
//         invoked with no arguments inside the scope that declares `scalar_t`.
//
// The macro expands to an immediately-invoked `[&]` lambda, so it is an
// expression whose value is whatever the callable returns. Any other TYPE is
// rejected through TORCH_CHECK.
#define AT_DISPATCH_FLOATING_TYPES_AND_BFLOAT16_AND_HALF(TYPE, NAME, ...) \
    [&] { \
        const auto& the_type = TYPE; \
        if (the_type == torch::kFloat) { \
            using scalar_t = float; \
            return __VA_ARGS__(); \
        } else if (the_type == torch::kDouble) { \
            using scalar_t = double; \
            return __VA_ARGS__(); \
        } else if (the_type == torch::kHalf) { \
            using scalar_t = at::Half; \
            return __VA_ARGS__(); \
        } else if (the_type == torch::kBFloat16) { \
            using scalar_t = at::BFloat16; \
            return __VA_ARGS__(); \
        } else { \
            TORCH_CHECK(false, "Unsupported scalar type for dispatch in " #NAME ": ", the_type); \
        } \
    }()
