#include "./args_validation.h"


// Argument-validation helpers built on TORCH_CHECK.
//
// Each macro expands to exactly one statement (wrapped in `do { ... } while (0)`),
// so it is safe as the body of an unbraced `if`/`else` and must be followed by `;`.
// The argument is parenthesised before member access; it is evaluated more than
// once (CHECK_INPUT evaluates it twice), so pass a named tensor, not an expression
// with side effects. `#x` stringises the argument as written, for the error message.
#define CHECK_CUDA(x) \
    do { TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor"); } while (0)
#define CHECK_CONTIGUOUS(x) \
    do { TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); } while (0)
// Both checks live inside one statement so that, under an unbraced `if`, neither
// runs unconditionally.
#define CHECK_INPUT(x)        \
    do {                      \
        CHECK_CUDA(x);        \
        CHECK_CONTIGUOUS(x);  \
    } while (0)

// Requires a rank-2 tensor; dim() is the tensor's number of dimensions.
#define CHECK_2D(x) \
    do { TORCH_CHECK((x).dim() == 2, #x " must be 2-D"); } while (0)


namespace sfrp {


// Validates a (mat, out) tensor pair.
//
// Requirements, checked in order (the first violation raises via TORCH_CHECK):
//   * mat and out are each contiguous, 2-D CUDA tensors;
//   * they share the same scalar type;
//   * they live on the same device;
//   * mat.size(0) == out.size(0).
// The second dimensions are intentionally not compared.
void check_valid_mat_and_out(at::Tensor mat, at::Tensor out) {
    CHECK_INPUT(mat);
    CHECK_2D(mat);

    CHECK_INPUT(out);
    CHECK_2D(out);

    TORCH_CHECK(mat.scalar_type() == out.scalar_type(), "The dtypes of mat and out must match.");
    // Being CUDA tensors is not sufficient: mat and out could sit on different
    // GPUs. NOTE(review): assumes the consuming kernel runs on a single device
    // and reads/writes both buffers there -- confirm against the kernel launch.
    TORCH_CHECK(mat.device() == out.device(), "mat and out must be on the same device.");
    TORCH_CHECK(mat.size(0) == out.size(0), "The first dimension of mat and out must match.");
}


// Validates a (mat, out) tensor pair for the transposed variant.
//
// Delegates to check_valid_mat_and_out, so it enforces exactly the same
// constraints; per the original author, the non-transposed check also covers
// the transposed case.
void check_valid_trp_mat_and_out(at::Tensor mat, at::Tensor out) {
    check_valid_mat_and_out(mat, out);
}


}  // namespace sfrp
