import numpy as np
import pytest
from sklearn import metrics
from sklearn.utils.estimator_checks import check_estimator

import gemclus

from ._tree import Kauri

@pytest.mark.parametrize(
    ["X", "y"], [gemclus.data.celeux_one(n=300, p=1, random_state=0)]
)
@pytest.mark.timeout(5)
def test_something_sagittarius(X, y):
    """Fit Kauri on univariate Celeux data and check the returned partition.

    The ``timeout`` marker (provided by the pytest-timeout plugin) guards
    against regressions that make fitting this small 1-D dataset slow.
    """
    model = Kauri(max_clusters=3, verbose=True)

    y_pred = model.fit_predict(X)

    # One label per sample, and never more clusters than requested.
    assert len(y_pred) == len(X)
    assert len(np.unique(y_pred)) <= 3

    # ARI is reported for information only; no quality threshold is enforced.
    print(metrics.adjusted_rand_score(y, y_pred))

@pytest.mark.parametrize(
    ["X", "y"], [gemclus.data.celeux_one(n=300, p=5, random_state=0)]
)
def test_something_more_variables(X, y):
    """Fit Kauri on 5-variable Celeux data and inspect which features it splits on."""
    model = Kauri(max_clusters=3, verbose=True)

    y_pred = model.fit_predict(X)

    # One label per sample, and never more clusters than requested.
    assert len(y_pred) == len(X)
    assert len(np.unique(y_pred)) <= 3

    # ARI is reported for information only; no quality threshold is enforced.
    print(metrics.adjusted_rand_score(y, y_pred))

    # None entries in tree_.features are skipped (presumably nodes without a
    # split); the remaining entries are expected to be column indices of X.
    used_features = np.unique([f for f in model.tree_.features if f is not None])
    assert np.all((used_features >= 0) & (used_features < X.shape[1]))

    print("Used variables: ")
    print(used_features)

@pytest.mark.parametrize(
    ["X", "y"], [gemclus.data.gstm(n=300, random_state=0)]
)
def test_gstm_min_leaf(X, y):
    """Fit Kauri on GSTM data under leaf-size and split-size constraints."""
    model = Kauri(max_clusters=4, min_samples_leaf=10, min_samples_split=20, verbose=True)

    y_pred = model.fit_predict(X)

    assert len(y_pred) == len(X)
    _, counts = np.unique(y_pred, return_counts=True)
    # Never more clusters than requested.
    assert len(counts) <= 4
    # Assuming each cluster is a union of leaves, and every leaf holds at
    # least min_samples_leaf training samples, no cluster can be smaller.
    assert counts.min() >= 10

    # ARI is reported for information only; no quality threshold is enforced.
    print(metrics.adjusted_rand_score(y, y_pred))

    print(y_pred)

@pytest.mark.parametrize(
    ["X", "y"], [gemclus.data.gstm(n=300, random_state=0)]
)
def test_gstm_max_depth(X, y):
    """Fit a depth-limited Kauri on GSTM data and check its partition and score."""
    model = Kauri(max_clusters=4, min_samples_split=20, max_depth=2, verbose=True)

    y_pred = model.fit_predict(X)

    # One label per sample, and never more clusters than requested.
    assert len(y_pred) == len(X)
    assert len(np.unique(y_pred)) <= 4

    # ARI is reported for information only; no quality threshold is enforced.
    print(metrics.adjusted_rand_score(y, y_pred))

    print(y_pred)

    # The model's own score on the training data must be a real number.
    score = model.score(X)
    assert np.isfinite(score)
    print(score)

def test_good_estimator():
    """Run scikit-learn's estimator compliance checks on a default Kauri."""
    check_estimator(Kauri())
