import torch

from .extension import feat_to_v_attr, feat_to_e_attr


# Public API of this module: only the Python-level wrappers are exported;
# the raw extension symbols imported above stay internal.
__all__ = [
    "cpp_feat_to_v_attr",
    "cpp_feat_to_e_attr",
]


def cpp_feat_to_v_attr(
    ingredients: torch.LongTensor,
    attn_cls: torch.Tensor,
    n_vertices: int,
    mean: bool = False,
    ingredients_only: bool = False
) -> torch.Tensor:
    """Forward to the compiled ``feat_to_v_attr`` extension operator.

    Every argument is passed positionally and unchanged, in signature
    order. No validation, conversion or device handling happens on the
    Python side, so any input requirements are enforced (or not) by the
    extension itself.

    Args:
        ingredients: tensor annotated as ``torch.LongTensor`` (int64).
            NOTE(review): expected shape/meaning is defined by the C++
            implementation — confirm there.
        attn_cls: tensor forwarded to the extension (presumably
            attention values, judging by the name — TODO confirm).
        n_vertices: integer forwarded to the extension (presumably the
            number of vertices in the output — TODO confirm).
        mean: boolean flag forwarded as-is (presumably selects averaging
            rather than summation — TODO confirm in the extension).
        ingredients_only: boolean flag forwarded as-is; its effect is
            defined by the extension.

    Returns:
        The tensor produced by the extension's ``feat_to_v_attr``.
    """
    # Keep the exact positional order the extension binding expects.
    forwarded = (ingredients, attn_cls, n_vertices, mean, ingredients_only)
    return feat_to_v_attr(*forwarded)


def cpp_feat_to_e_attr(
    ingredients: torch.LongTensor,
    attn: torch.Tensor,
    geo_sim: torch.Tensor,
    n_vertices: int,
    mean: bool = False
) -> torch.Tensor:
    """Forward to the compiled ``feat_to_e_attr`` extension operator.

    Every argument is passed positionally and unchanged, in signature
    order. No validation, conversion or device handling happens on the
    Python side, so any input requirements are enforced (or not) by the
    extension itself.

    Args:
        ingredients: tensor annotated as ``torch.LongTensor`` (int64).
            NOTE(review): expected shape/meaning is defined by the C++
            implementation — confirm there.
        attn: tensor forwarded to the extension (presumably attention
            values, judging by the name — TODO confirm).
        geo_sim: tensor forwarded to the extension (presumably a
            geometric-similarity term — TODO confirm).
        n_vertices: integer forwarded to the extension (presumably the
            number of vertices — TODO confirm).
        mean: boolean flag forwarded as-is (presumably selects averaging
            rather than summation — TODO confirm in the extension).

    Returns:
        The tensor produced by the extension's ``feat_to_e_attr``.
    """
    # Keep the exact positional order the extension binding expects.
    forwarded = (ingredients, attn, geo_sim, n_vertices, mean)
    return feat_to_e_attr(*forwarded)
