import pytest
import gemclus
from gemclus.gemini import MMDOvA, MMDOvO, WassersteinOvA, WassersteinOvO
from sklearn.utils.estimator_checks import check_estimator

from ._tree import Douglas
from sklearn import metrics


@pytest.mark.parametrize(
    "gemini", [MMDOvA(), MMDOvO(), WassersteinOvA(), WassersteinOvO()]
)
# NOTE(review): the ``timeout`` mark is provided by the pytest-timeout
# plugin; without it installed the mark is ignored (with a warning).
@pytest.mark.timeout(5)
def test_something_sagittarius(gemini):
    """Fit ``Douglas`` with each GEMINI variant on a small synthetic dataset.

    The dataset comes from ``gemclus.data.celeux_one`` (300 samples,
    ``random_state=0``). The test checks that ``fit_predict`` completes and
    returns one label per sample. The adjusted Rand index against the ground
    truth is printed for manual inspection only; no quality threshold is
    enforced.

    Args:
        gemini: A GEMINI objective instance injected by ``parametrize``.
    """
    X, y = gemclus.data.celeux_one(n=300, p=1, random_state=0)

    clf = Douglas(verbose=True, random_state=0, gemini=gemini)
    y_pred = clf.fit_predict(X)

    # Without an assertion this test could only fail by raising; at least pin
    # that a prediction is produced for every input sample.
    assert len(y_pred) == len(y)

    # sklearn's signature is (labels_true, labels_pred); ARI is symmetric, so
    # the printed value is the same as before.
    print(metrics.adjusted_rand_score(y, y_pred))

def test_something_batches():
    """Fit ``Douglas`` with mini-batch training on a synthetic dataset.

    Uses 1000 samples from ``gemclus.data.celeux_one`` with ``batch_size=100``
    and ``n_epochs=10`` under the ``MMDOvA`` objective. The test checks that
    ``fit_predict`` returns one label per sample. The adjusted Rand index is
    printed for manual inspection only.
    """
    X, y = gemclus.data.celeux_one(n=1000, p=1, random_state=0)

    clf = Douglas(verbose=True, random_state=0, gemini=MMDOvA(), batch_size=100, n_epochs=10)
    y_pred = clf.fit_predict(X)

    # The batched path must still label every sample, not just one batch.
    assert len(y_pred) == len(y)

    # sklearn's signature is (labels_true, labels_pred); ARI is symmetric.
    print(metrics.adjusted_rand_score(y, y_pred))


def test_good_estimator():
    """Run scikit-learn's estimator conformance suite on a default ``Douglas``."""
    default_model = Douglas()
    # check_estimator raises on the first API-contract violation it finds.
    check_estimator(default_model)
