import numpy as np

from src.cvt.cvt_archive_2 import CVTArchive2


def test_sklearn_nn():
    """Check that the sklearn nearest-neighbor lookup honors the cosine metric.

    With ``sklearn_nn_kwargs={"metric": "cosine"}``, each measure should be
    assigned to a centroid with the highest cosine similarity to it. The
    expected similarity is computed here by brute force over
    ``archive.centroids`` and compared against what ``index_of`` returned.
    """
    archive = CVTArchive2(
        solution_dim=10,
        cells=100,
        ranges=[(-1, 1)] * 2,
        use_kd_tree=False,
        use_sklearn_nn=True,
        sklearn_nn_kwargs={"metric": "cosine"},
    )

    measures = np.array([[-0.5, -0.5], [1, 1]])

    indices = archive.index_of(measures)
    # One centroid index per input measure.
    assert indices.shape == (len(measures),)

    # Brute-force cosine similarity between every measure and every centroid.
    # Assumes centroids has one row per cell, with the same number of columns
    # as the measures.
    centroids = np.asarray(archive.centroids)
    unit_measures = measures / np.linalg.norm(measures, axis=1, keepdims=True)
    unit_centroids = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    similarity = unit_measures @ unit_centroids.T  # (n_measures, n_cells)

    # Compare similarity values rather than raw indices, so that two centroids
    # that tie (or nearly tie within float error) cannot fail the test
    # spuriously.
    chosen = similarity[np.arange(len(measures)), indices]
    np.testing.assert_allclose(chosen, similarity.max(axis=1))


def test_empty_array():
    """A batch containing zero measures yields an empty 1-D index array."""
    cvt = CVTArchive2(
        solution_dim=10,
        cells=100,
        ranges=[(-1, 1)] * 2,
        use_kd_tree=False,
        use_sklearn_nn=True,
        sklearn_nn_kwargs={"metric": "cosine"},
    )

    # Zero rows, but still two columns to match the two measure ranges.
    no_measures = np.empty((0, 2))
    result = cvt.index_of(no_measures)

    assert result.shape == (0,)
