import numpy as np
import pandas as pd
import pytest

from ccvae.metrics import knn_metric
from ccvae.metrics import entropy_of_mixing
from ccvae.metrics import silhouette_coeff


def test_entropy_of_mixing():
    """Compare ``entropy_of_mixing`` with hand-computed entropies.

    Nine 2-D points are laid out in two well-separated groups, each containing
    one query point. Neighbour counts are obtained from
    ``knn_metric(..., return_counts=True)`` and handed to
    ``entropy_of_mixing``; both returned arrays are compared, per query point,
    with values written out by hand below.
    """

    def entropy(*probabilities):
        # Natural-log entropy of the given probabilities: -sum(p * log(p)).
        return -sum(p * np.log(p) for p in probabilities)

    k = 4

    # One row per point: (x, y, is_query, label, class_partition flag).
    rows = [
        (0.0, 0.0, True, 0, True),  # query
        (1.0, 0.0, False, 1, True),
        (1.0, 1.0, False, 0, True),  # nearest to cluster 2
        (-1.0, 0.0, False, 0, True),
        (0.0, -1.0, False, 0, False),
        (10.0, 10.0, True, 2, False),  # query
        (11.0, 10.0, False, 2, True),
        (10.0, 11.0, False, 2, False),
        (10.0, 9.0, False, 3, True),
    ]
    points = pd.DataFrame([[x, y] for x, y, _, _, _ in rows])
    is_query = np.array([row[2] for row in rows])
    point_labels = np.array([row[3] for row in rows])
    partition = np.array([row[4] for row in rows])

    # Only the neighbour counts are needed here; the first return value is ignored.
    _, neighbour_counts = knn_metric(
        points, is_query, point_labels, partition, n_neighbours=k, return_counts=True
    )
    mixing, mixing_disease = entropy_of_mixing(neighbour_counts, n_neighbours=k)

    # Expected values, one entry per query point (first group, then second group).
    want_mixing = np.array(
        [
            entropy((1 + 1) / 6, 4 / 6),
            entropy(4 / 6, 2 / 6),
        ]
    )
    want_mixing_disease = np.array(
        [
            entropy((1 + 1) / 5, 3 / 5),
            entropy(2 / 4, 2 / 4),
        ]
    )

    np.testing.assert_almost_equal(mixing, want_mixing)
    np.testing.assert_almost_equal(mixing_disease, want_mixing_disease)


def test_silhouette_coeff():
    """Check ``silhouette_coeff`` on a single query point in several layouts.

    Every scenario uses five points on the x-axis with the first point (the
    origin) as the only query, and ``n_neighbours=4``. Scenarios differ in the
    x-coordinate of the last point, the labels, and the class partition.
    The first returned value is always checked; the second ("disease") value
    is checked only where the scenario supplies an expectation for it.
    """

    def line_points(last_x):
        # Five collinear points; only the final x-coordinate varies between scenarios.
        return pd.DataFrame([[0.0, 0.0], [1.1, 0.0], [-1.0, 0.0], [2.0, 0.0], [last_x, 0.0]])

    # The first three scenarios run against the very same DataFrame object.
    shared_points = line_points(-2.0)

    # Each entry: (features, labels, class_partition, expected, expected_disease).
    # ``None`` for expected_disease means the second return value is not checked.
    scenarios = [
        # Normal
        (shared_points, [0, 0, 0, 0, 0], [True, True, True, False, False], 1 - 1.05 / 2.0, None),
        # All same class
        (shared_points, [0, 0, 0, 0, 0], [True, True, True, True, True], 1, None),
        # All other class
        (shared_points, [0, 0, 0, 0, 0], [True, False, False, False, False], -1, None),
        # Disease Normal all same disease
        (
            line_points(-2.0),
            [0, 0, 0, 0, 0],
            [True, True, True, False, False],
            1 - 1.05 / 2.0,
            1 - 1.05 / 2.0,
        ),
        # Disease Normal all different disease
        (
            line_points(-2.0),
            [0, 1, 0, 0, 0],
            [True, True, True, False, False],
            1 - 1.05 / 2.0,
            1 - 1.0 / 2.0,
        ),
        # Disease Normal all different disease #2
        (
            line_points(-4.0),
            [0, 1, 0, 0, 1],
            [True, True, True, False, False],
            1 - 1.05 / 3.0,
            1 - 1.0 / 2.0,
        ),
        # Disease Normal all different disease -- edge case no diseases both classes
        (
            line_points(-4.0),
            [0, 1, 1, 1, 1],
            [True, True, True, False, False],
            1 - 1.05 / 3.0,
            0,
        ),
        # Disease Normal all different disease -- edge case no diseases of same class
        (
            line_points(-4.0),
            [0, 1, 1, 0, 1],
            [True, True, True, False, False],
            1 - 1.05 / 3.0,
            -1,
        ),
        # Disease Normal all different disease -- edge case no diseases of other class
        (
            line_points(-4.0),
            [0, 1, 0, 1, 1],
            [True, True, True, False, False],
            1 - 1.05 / 3.0,
            1,
        ),
    ]

    for points, point_labels, partition, want, want_disease in scenarios:
        # Fresh mask per call: only the first point is a query.
        is_query = np.array([True, False, False, False, False])
        got, got_disease = silhouette_coeff(
            points, is_query, np.array(point_labels), np.array(partition), n_neighbours=4
        )
        np.testing.assert_almost_equal(got, np.array([want]))
        if want_disease is not None:
            np.testing.assert_almost_equal(got_disease, np.array([want_disease]))