import numpy as np
import pytest

from src.baseline_sep_metrics import (
    balanced_accuracy_index,
    roc_auc_index,
    thornton_separability_index,
)

# Shared feature matrix for every case below: rows 0-1 sit at (0, 0)/(0, 0.1)
# and rows 2-3 at (10, 10)/(10, 10.1), i.e. two tight pairs of points that are
# far apart compared with the 0.1 spacing inside each pair.
X = np.array(
    [
        [0.0, 0.0],
        [0.0, 0.1],
        [10.0, 10.0],
        [10.0, 10.1],
    ]
)


def th2(X, y):
    """Return ``thornton_separability_index`` of (X, y) with ``n_neighbors=2``.

    Thin adapter so every metric under test shares the ``method(X, y)``
    signature used by the parametrized test.
    """
    score = thornton_separability_index(X, y, n_neighbors=2)
    return score


def rcauc(X, y):
    """Return ``roc_auc_index`` of (X, y) with ``n_splits=2``.

    Thin adapter so every metric under test shares the ``method(X, y)``
    signature. NOTE(review): ``n_splits=2`` is presumably chosen because the
    fixture has only two samples per class; confirm against the metric.
    """
    score = roc_auc_index(X, y, n_splits=2)
    return score


def racc(X, y):
    """Return ``balanced_accuracy_index`` of (X, y) with ``n_splits=2``.

    Thin adapter so every metric under test shares the ``method(X, y)``
    signature. NOTE(review): ``n_splits=2`` is presumably chosen because the
    fixture has only two samples per class; confirm against the metric.
    """
    score = balanced_accuracy_index(X, y, n_splits=2)
    return score


# Two labelings of the module-level X:
#   [0, 0, 1, 1] -> each class occupies its own tight pair of points
#                   ("separated"),
#   [0, 1, 1, 0] -> each pair holds one sample of each class ("mixed").
# Explicit ids replace pytest's auto-generated ones (e.g. "y0-method0-1.0"),
# which say nothing about which metric or labeling failed.
@pytest.mark.parametrize(
    ["y", "method", "expected"],
    [
        pytest.param(np.array([0, 0, 1, 1]), th2, 1.0, id="separated-thornton"),
        pytest.param(np.array([0, 0, 1, 1]), rcauc, 1.0, id="separated-roc_auc"),
        pytest.param(np.array([0, 0, 1, 1]), racc, 1.0, id="separated-bal_acc"),
        pytest.param(np.array([0, 1, 1, 0]), th2, 0.0, id="mixed-thornton"),
        pytest.param(np.array([0, 1, 1, 0]), rcauc, 0.0, id="mixed-roc_auc"),
        pytest.param(np.array([0, 1, 1, 0]), racc, 0.5, id="mixed-bal_acc"),
    ],
)
def test_supervised(y, method, expected):
    """Check each metric adapter against its known value on the fixture X.

    ``pytest.approx`` absorbs floating-point noise from the metric computation.
    """
    value = method(X, y)
    assert value == pytest.approx(expected)
