import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn import metrics


def unsupervised_accuracy(y, y_pred):
    """Return the accuracy of ``y_pred`` under the best label matching.

    A confusion matrix between ``y`` and ``y_pred`` is built, then a
    one-to-one assignment between its rows and columns is chosen so that the
    sum of the selected cells is maximal.

    Args:
        y: Reference labels.
        y_pred: Predicted (e.g. cluster) labels.

    Returns:
        The sum of the matched confusion-matrix cells divided by the sum of
        all cells.
    """
    contingency = metrics.confusion_matrix(y, y_pred)

    # maximize=True: we want the pairing with the largest total agreement.
    rows, cols = linear_sum_assignment(contingency, maximize=True)
    n_matched = contingency[rows, cols].sum()

    return n_matched / contingency.sum()


def compute_kauri_wad(tree, X, node=0):
    """Compute the sample-weighted average leaf depth of a KAURI tree.

    Samples of ``X`` are routed down the tree; each leaf contributes its
    depth (read from ``tree.depths``) times the number of samples reaching
    it. The total is divided by ``len(X)`` when the call starts at node 0.

    Args:
        tree: Object exposing the per-node sequences ``children_left``,
            ``children_right``, ``depths``, ``features``, ``thresholds`` and
            ``categorical_nodes``. A ``children_left`` value of ``-1`` marks
            a leaf.
        X: 2-D array of samples, indexed as ``X[:, feature]``.
        node: Index of the node to start from. Only a call with ``node == 0``
            on an internal node returns the normalised average; any other
            call returns the un-normalised weighted sum.

    Returns:
        The weighted average depth when ``node == 0`` is an internal node,
        otherwise the sum of ``n_samples * depth`` over the reachable leaves.
    """
    if tree.children_left[node] == -1:
        return len(X) * tree.depths[node]

    feature_values = X[:, tree.features[node]]
    # The threshold must be the one of *this* node: comparing against the
    # whole ``tree.thresholds`` array is a shape mismatch / wrong mask.
    threshold = tree.thresholds[node]
    if tree.categorical_nodes[node]:
        # Categorical split: samples equal to the stored value go left.
        X_left = feature_values == threshold
    else:
        X_left = feature_values <= threshold
    X_right = ~X_left

    wad = 0
    wad += compute_kauri_wad(tree, X[X_left], tree.children_left[node])
    wad += compute_kauri_wad(tree, X[X_right], tree.children_right[node])

    if node == 0:
        return wad / len(X)
    return wad


def compute_dt_wad(tree, X, depth=0, node=0):
    """Compute the sample-weighted average leaf depth of a decision tree.

    Samples of ``X`` are routed with ``X[:, feature] <= threshold`` going to
    the left child. Each leaf contributes ``n_samples * depth``.

    Args:
        tree: Object exposing per-node ``children_left``, ``children_right``,
            ``feature`` and ``threshold`` (presumably a scikit-learn
            ``tree_`` object; verify against the caller). A ``children_left``
            value of ``-1`` marks a leaf.
        X: 2-D array of samples, indexed as ``X[:, feature]``.
        depth: Depth assigned to ``node``; incremented for each child.
        node: Index of the node to start from.

    Returns:
        The weighted sum divided by ``len(X)`` when ``depth == 0`` on an
        internal node, otherwise the un-normalised weighted sum.
    """
    left_child = tree.children_left[node]
    if left_child == -1:
        return len(X) * depth

    goes_left = X[:, tree.feature[node]] <= tree.threshold[node]
    goes_right = ~goes_left

    left_total = compute_dt_wad(tree, X[goes_left], depth + 1, left_child)
    right_total = compute_dt_wad(tree, X[goes_right], depth + 1, tree.children_right[node])
    total = 0 + left_total + right_total

    # Only the outermost call (depth 0) turns the sum into an average.
    return total / len(X) if depth == 0 else total


def compute_exkmc_wad(tree, depth=0):
    """Compute the sample-weighted average leaf depth of an ExKMC-style tree.

    Args:
        tree: Node object exposing ``is_leaf()``, ``samples``, ``left`` and
            ``right``.
        depth: Depth assigned to ``tree``; incremented for each child.

    Returns:
        For a leaf, ``tree.samples * depth``. For an internal node, the sum
        of its children's results, divided by ``tree.samples`` only when
        ``depth == 0``.
    """
    if tree.is_leaf():
        return tree.samples * depth

    children_total = 0
    for child in (tree.left, tree.right):
        children_total += compute_exkmc_wad(child, depth + 1)

    # Only the outermost call (depth 0) turns the sum into an average.
    return children_total / tree.samples if depth == 0 else children_total


def simplify_rules(ruleset):
    """Collapse redundant threshold rules so each feature keeps at most two.

    Each rule is a triple ``(feature, threshold, left_child)``. Rules with a
    truthy ``left_child`` are treated as upper bounds and only the smallest
    threshold is kept; rules with a falsy ``left_child`` are treated as lower
    bounds and only the largest threshold is kept.

    Args:
        ruleset: Non-empty sequence of ``(feature, threshold, left_child)``
            triples.

    Returns:
        A list of ``(feature, threshold, left_child)`` triples, ordered by
        feature (as sorted by ``np.unique``), with the upper-bound rule
        listed before the lower-bound rule for a given feature.

    Raises:
        AssertionError: If ``ruleset`` is empty or a rule does not have
            exactly three elements.
    """
    assert len(ruleset) > 0
    for elem in ruleset:
        assert len(elem) == 3  # structure is [feature[int], threshold[float], left_child[bool]]

    used_features = np.unique([x[0] for x in ruleset])
    simplified_rules = []

    for feature in used_features:
        feature_rules = [x for x in ruleset if x[0] == feature]

        lower_rules = [x for x in feature_rules if x[2]]
        greater_rules = [x for x in feature_rules if not x[2]]

        # Store the threshold value itself, not the whole rule tuple that
        # ``min``/``max`` with ``key=`` would return.
        if len(lower_rules) != 0:
            lowest_threshold = min(x[1] for x in lower_rules)
            simplified_rules.append((feature, lowest_threshold, True))
        if len(greater_rules) != 0:
            greatest_threshold = max(x[1] for x in greater_rules)
            simplified_rules.append((feature, greatest_threshold, False))
    return simplified_rules


def compute_kauri_waes(tree, X, ruleset=None, node=0):
    """Compute the sample-weighted average number of rules per leaf (KAURI).

    Samples of ``X`` are routed down the tree while the rules of the path are
    accumulated. At each leaf the path rules are reduced with
    ``simplify_rules`` and the leaf contributes
    ``n_samples * n_simplified_rules``.

    Samples are partitioned the same way as in ``compute_kauri_wad``:
    equality with the node value at categorical nodes, ``<=`` otherwise.

    NOTE(review): rules from categorical nodes are still passed to
    ``simplify_rules`` as if they were threshold rules, so several
    categorical rules on one feature may be merged — confirm this is the
    intended way to count them.

    Args:
        tree: Object exposing per-node ``children_left``, ``children_right``,
            ``features``, ``thresholds`` and ``categorical_nodes``. A
            ``children_left`` value of ``-1`` marks a leaf.
        X: 2-D array of samples, indexed as ``X[:, feature]``.
        ruleset: Rules accumulated on the path to ``node``; ``None`` is
            treated as an empty list.
        node: Index of the node to start from.

    Returns:
        The weighted sum divided by ``len(X)`` when ``node == 0`` is an
        internal node, otherwise the un-normalised weighted sum.

    Raises:
        AssertionError: From ``simplify_rules`` when a leaf is reached with
            an empty rule set (e.g. the starting node is itself a leaf).
    """
    if ruleset is None:
        ruleset = []
    if tree.children_left[node] == -1:
        simplified_rules = simplify_rules(ruleset)
        return len(X) * len(simplified_rules)

    feature = tree.features[node]
    threshold = tree.thresholds[node]
    if tree.categorical_nodes[node]:
        # Categorical split: samples equal to the stored value go left.
        X_left = X[:, feature] == threshold
    else:
        X_left = X[:, feature] <= threshold
    X_right = ~X_left

    left_rule = [(feature, threshold, True)]
    right_rule = [(feature, threshold, False)]

    waes = 0
    waes += compute_kauri_waes(tree, X[X_left], ruleset + left_rule, tree.children_left[node])
    waes += compute_kauri_waes(tree, X[X_right], ruleset + right_rule, tree.children_right[node])

    if node == 0:
        return waes / len(X)
    return waes


def compute_dt_waes(tree, X, ruleset=None, node=0):
    """Compute the sample-weighted average number of rules per leaf.

    Samples of ``X`` are routed with ``X[:, feature] <= threshold`` going to
    the left child while the rules of the path are accumulated. At each leaf
    the path rules are reduced with ``simplify_rules`` and the leaf
    contributes ``n_samples * n_simplified_rules``.

    Args:
        tree: Object exposing per-node ``children_left``, ``children_right``,
            ``feature`` and ``threshold`` (presumably a scikit-learn
            ``tree_`` object; verify against the caller). A ``children_left``
            value of ``-1`` marks a leaf.
        X: 2-D array of samples, indexed as ``X[:, feature]``.
        ruleset: Rules accumulated on the path to ``node``; ``None`` is
            treated as an empty list.
        node: Index of the node to start from.

    Returns:
        The weighted sum divided by ``len(X)`` when ``node == 0`` is an
        internal node, otherwise the un-normalised weighted sum.
    """
    path_rules = [] if ruleset is None else ruleset

    left_child = tree.children_left[node]
    if left_child == -1:
        return len(X) * len(simplify_rules(path_rules))

    split_feature = tree.feature[node]
    split_threshold = tree.threshold[node]
    goes_left = X[:, split_feature] <= split_threshold
    goes_right = ~goes_left

    # Each branch extends a copy of the path; the caller's list is untouched.
    rules_if_left = path_rules + [(split_feature, split_threshold, True)]
    rules_if_right = path_rules + [(split_feature, split_threshold, False)]

    left_total = compute_dt_waes(tree, X[goes_left], rules_if_left, left_child)
    right_total = compute_dt_waes(tree, X[goes_right], rules_if_right, tree.children_right[node])
    total = 0 + left_total + right_total

    return total / len(X) if node == 0 else total


def compute_exkmc_waes(tree, ruleset=None):
    """Compute the sample-weighted average number of rules per leaf (ExKMC).

    The rules of the path are accumulated while descending. At each leaf they
    are reduced with ``simplify_rules`` and the leaf contributes
    ``tree.samples * n_simplified_rules``.

    Args:
        tree: Node object exposing ``is_leaf()``, ``samples``, ``feature``,
            ``value``, ``left`` and ``right``.
        ruleset: Rules accumulated on the path to ``tree``; ``None`` is
            treated as an empty list.

    Returns:
        The weighted sum divided by ``tree.samples`` when the accumulated
        rule set is empty (i.e. the outermost call) on an internal node,
        otherwise the un-normalised weighted sum.
    """
    path_rules = [] if ruleset is None else ruleset

    if tree.is_leaf():
        return tree.samples * len(simplify_rules(path_rules))

    rules_if_left = path_rules + [(tree.feature, tree.value, True)]
    rules_if_right = path_rules + [(tree.feature, tree.value, False)]

    left_total = compute_exkmc_waes(tree.left, rules_if_left)
    right_total = compute_exkmc_waes(tree.right, rules_if_right)
    total = 0 + left_total + right_total

    # An empty incoming path identifies the root call, which normalises.
    if len(path_rules) == 0:
        return total / tree.samples
    return total


def compute_kmeans_score(X, y):
    """Return the summed squared distance of samples to their cluster mean.

    For every distinct label in ``y``, the mean of the matching rows of ``X``
    is computed and the squared Euclidean distances from those rows to that
    mean are added to the total.

    Args:
        X: Array of samples, indexable by an integer index array.
        y: Label of each sample.

    Returns:
        The sum over all clusters of the squared distances to the cluster
        mean.
    """
    inertia = 0
    for label in np.unique(y):
        (member_idx,) = np.where(y == label)
        members = X[member_idx]

        # keepdims keeps the mean 2-D, as pairwise_distances expects.
        centroid = members.mean(0, keepdims=True)

        sq_dists = metrics.pairwise_distances(members, centroid, metric="sqeuclidean")
        inertia += sq_dists.sum()

    return inertia


def get_exkmc_n_leaves(tree):
    """Count the leaves below (and including) an ExKMC-style tree node.

    Args:
        tree: Node object exposing ``is_leaf()``, ``left`` and ``right``.

    Returns:
        ``1`` for a leaf, otherwise the sum of the leaf counts of the left
        and right children.
    """
    if tree.is_leaf():
        return 1

    n_left = get_exkmc_n_leaves(tree.left)
    n_right = get_exkmc_n_leaves(tree.right)
    return n_left + n_right


def get_tree_n_leaves(tree):
    """Count the leaves of an array-based tree.

    Args:
        tree: Object exposing an iterable ``children_left``; an entry equal
            to ``-1`` marks a leaf.

    Returns:
        The number of entries of ``tree.children_left`` equal to ``-1``.
    """
    return sum(1 for left_child in tree.children_left if left_child == -1)
