import numpy as np

try:
    from mmdagg.jax import mmdagg
except ImportError:
    try:
        from mmdagg.np import mmdagg
    except ImportError:
        raise ImportError("mmdagg package not found")


def mmdagg_test(X, Y, r=None):
    """Run the MMDAgg two-sample test on samples ``X`` and ``Y``.

    Parameters
    ----------
    X, Y : array_like
        Samples to compare, one observation per row. A 1-D input is
        treated as a column of scalar observations (reshaped to ``(n, 1)``)
        so univariate samples are accepted.
    r : object, optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    int or str
        The result of ``mmdagg(X, Y)`` converted to ``int`` (per the mmdagg
        docs this is presumably 1 = reject the null, 0 = do not reject;
        confirm for the installed version), or ``'na'`` when either sample
        has fewer than two observations or the test could not be computed.
        Any exception is caught, logged at DEBUG level, and mapped to
        ``'na'``.
    """
    import logging

    try:
        X = np.asarray(X)
        Y = np.asarray(Y)
        # mmdagg expects (n_samples, n_features); interpret 1-D input as
        # univariate samples instead of letting the library fail on it.
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        if Y.ndim == 1:
            Y = Y.reshape(-1, 1)
        # Both samples need at least two observations for the test.
        if X.shape[0] < 2 or Y.shape[0] < 2:
            return 'na'
        result = mmdagg(X, Y)
        # Array scalars (numpy/jax) expose .item(); plain ints do not.
        if hasattr(result, 'item'):
            result = result.item()
        return int(result)
    except Exception:
        # Best-effort contract: callers get 'na' rather than an exception,
        # but keep the traceback available for debugging.
        logging.getLogger(__name__).debug(
            "mmdagg_test failed; returning 'na'", exc_info=True
        )
        return 'na'
