import pytest
import gemclus
from sklearn.utils.estimator_checks import check_estimator

from ._tree import TorchDouglas
from sklearn import metrics

@pytest.mark.parametrize(
    "gemini", ["mmd_ova", "mmd_ovo", "wasserstein_ova", "wasserstein_ovo"]
)
@pytest.mark.timeout(5)
def test_something_sagittarius(gemini):
    """Fit TorchDouglas with each supported GEMINI and sanity-check the output.

    The test checks that ``fit_predict`` yields one label per sample and that
    the adjusted Rand index against the ground truth is a valid score.
    """
    X, y = gemclus.data.celeux_one(n=300, p=1, random_state=0)

    clf = TorchDouglas(verbose=True, random_state=0, gemini=gemini)
    y_pred = clf.fit_predict(X)

    # One cluster assignment is expected for every input sample.
    assert len(y_pred) == len(y)

    # adjusted_rand_score takes (labels_true, labels_pred).
    ari = metrics.adjusted_rand_score(y, y_pred)
    # Kept for manual inspection of clustering quality (visible with `pytest -s`).
    print(ari)
    # ARI is bounded in [-0.5, 1.0]; this comparison also rejects NaN.
    assert -0.5 <= ari <= 1.0

def test_something_batches():
    """Fit TorchDouglas with mini-batches and sanity-check the output.

    The test trains with ``batch_size=100`` for 10 epochs on 1000 samples. It
    checks that every sample receives a label and that the adjusted Rand index
    against the ground truth is a valid score.
    """
    X, y = gemclus.data.celeux_one(n=1000, p=1, random_state=0)

    clf = TorchDouglas(verbose=True, random_state=0, gemini="mmd_ova", batch_size=100, n_epochs=10)
    y_pred = clf.fit_predict(X)

    # Batched training must still produce one assignment per input sample.
    assert len(y_pred) == len(y)

    # adjusted_rand_score takes (labels_true, labels_pred).
    ari = metrics.adjusted_rand_score(y, y_pred)
    # Kept for manual inspection of clustering quality (visible with `pytest -s`).
    print(ari)
    # ARI is bounded in [-0.5, 1.0]; this comparison also rejects NaN.
    assert -0.5 <= ari <= 1.0

def test_good_estimator():
    """Run scikit-learn's estimator API compliance checks on TorchDouglas (mmd_ova)."""
    check_estimator(TorchDouglas(gemini="mmd_ova"))
    