#include "include/utils.h"


/**
 * Python-facing entry point for the Jensen-Shannon computation.
 *
 * Both operands go through CHECK_INPUT (a macro from include/utils.h;
 * NOTE(review): presumably device/contiguity checks — confirm there).
 * The actual work is then delegated to jensen_shannon_cu.
 *
 * @param sparse_feats_1 first operand, forwarded unchanged to the kernel
 * @param sparse_feats_2 second operand, forwarded unchanged to the kernel
 * @return the tensor produced by jensen_shannon_cu
 */
torch::Tensor jensen_shannon(const torch::Tensor sparse_feats_1,
                             const torch::Tensor sparse_feats_2) {
    // Validate each operand separately so a failure names the offending argument.
    CHECK_INPUT(sparse_feats_1);
    CHECK_INPUT(sparse_feats_2);

    torch::Tensor result = jensen_shannon_cu(sparse_feats_1, sparse_feats_2);
    return result;
}

/**
 * Python-facing entry point for the Jaccard computation.
 *
 * Both operands go through CHECK_INPUT (a macro from include/utils.h;
 * NOTE(review): presumably device/contiguity checks — confirm there).
 * The actual work is then delegated to jaccard_cu.
 *
 * @param sparse_feats_1 first operand, forwarded unchanged to the kernel
 * @param sparse_feats_2 second operand, forwarded unchanged to the kernel
 * @return the tensor produced by jaccard_cu
 */
torch::Tensor jaccard(const torch::Tensor sparse_feats_1,
                      const torch::Tensor sparse_feats_2) {
    // Validate each operand separately so a failure names the offending argument.
    CHECK_INPUT(sparse_feats_1);
    CHECK_INPUT(sparse_feats_2);

    torch::Tensor result = jaccard_cu(sparse_feats_1, sparse_feats_2);
    return result;
}

/**
 * Python-facing entry point for the intersection computation.
 *
 * Both operands go through CHECK_INPUT (a macro from include/utils.h;
 * NOTE(review): presumably device/contiguity checks — confirm there).
 * The actual work is then delegated to intersection_cu.
 *
 * @param sparse_feats_1 first operand, forwarded unchanged to the kernel
 * @param sparse_feats_2 second operand, forwarded unchanged to the kernel
 * @return the tensor produced by intersection_cu
 */
torch::Tensor intersection(const torch::Tensor sparse_feats_1,
                           const torch::Tensor sparse_feats_2) {
    // Validate each operand separately so a failure names the offending argument.
    CHECK_INPUT(sparse_feats_1);
    CHECK_INPUT(sparse_feats_2);

    torch::Tensor result = intersection_cu(sparse_feats_1, sparse_feats_2);
    return result;
}

// Python bindings for the extension module.
//
// Each wrapper is exposed under its C++ name. The parameters are named
// explicitly with pybind11::arg so that:
//   * Python signatures / help() show `sparse_feats_1, sparse_feats_2`
//     instead of the anonymous `arg0, arg1`;
//   * callers may pass the tensors by keyword.
// Positional calls behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("jensen_shannon", &jensen_shannon,
          pybind11::arg("sparse_feats_1"), pybind11::arg("sparse_feats_2"));
    m.def("jaccard", &jaccard,
          pybind11::arg("sparse_feats_1"), pybind11::arg("sparse_feats_2"));
    m.def("intersection", &intersection,
          pybind11::arg("sparse_feats_1"), pybind11::arg("sparse_feats_2"));
}