import itertools
import munkres
import numpy as np
import pandas as pd
from scipy.io import arff
from sklearn.metrics import confusion_matrix as sk_confusion
from scipy.optimize import linear_sum_assignment
import sklearn
from utils import utils


def homogeneity_score(y_true, y_pred):
    """
    Return the homogeneity of a clustering with respect to the ground truth.

    A clustering is homogeneous when each of its clusters contains only
    members of a single class. Delegates to
    `sklearn.metrics.cluster.homogeneity_score`.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        float in [0.0, 1.0]; 1.0 means perfectly homogeneous labeling
    """
    scorer = sklearn.metrics.cluster.homogeneity_score
    score = scorer(y_true, y_pred)
    return score


def completeness_score(y_true, y_pred):
    """
    Return the completeness of a clustering with respect to the ground truth.

    A clustering is complete when all members of a given class are assigned
    to the same cluster. Delegates to
    `sklearn.metrics.cluster.completeness_score`.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        float in [0.0, 1.0]; 1.0 means perfectly complete labeling
    """
    scorer = sklearn.metrics.cluster.completeness_score
    score = scorer(y_true, y_pred)
    return score


def v_measure_score(y_true, y_pred):
    """
    Return the V-measure of a clustering with respect to the ground truth.

    The V-measure combines homogeneity and completeness into a single score.
    Delegates to `sklearn.metrics.cluster.v_measure_score`.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        float in [0.0, 1.0]; 1.0 means a labeling that is both perfectly
        homogeneous and perfectly complete
    """
    scorer = sklearn.metrics.cluster.v_measure_score
    score = scorer(y_true, y_pred)
    return score


def nmi(y_true, y_pred):
    """
    Return the Normalized Mutual Information between two labelings.

    The mutual information is normalised by the arithmetic mean of the two
    label entropies. This option is passed explicitly so the result does not
    depend on the scikit-learn version's default `average_method`.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return float
        score between 0.0 and 1.0; 1.0 means the labelings agree perfectly
    """
    method = 'arithmetic'
    scorer = sklearn.metrics.normalized_mutual_info_score
    return scorer(y_true, y_pred, average_method=method)


def ari(y_true, y_pred):
    """
    Return the Adjusted Rand Index between two labelings.

    Delegates to `sklearn.metrics.adjusted_rand_score`.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`

    # Return float
        1.0 means a perfect match. Random labelings score close to 0.0,
        and the score can be negative.
    """
    scorer = sklearn.metrics.adjusted_rand_score
    score = scorer(y_true, y_pred)
    return score


def purity(y_true, y_pred):
    """
    Calculate clustering purity.

    Each predicted cluster is credited with the number of its samples that
    belong to its most frequent true class. Purity is the total credit
    divided by the number of samples.

    Labels are encoded with `np.unique` before counting. Arbitrary label
    values (e.g. a `-1` noise label, or strings) are therefore handled
    correctly instead of wrapping around as negative array indices.

    # Arguments
        y_true: true labels, array-like with shape `(n_samples,)`
        y_pred: predicted labels, array-like with shape `(n_samples,)`
    # Return
        purity, float in [0, 1]
    # Raises
        ValueError: if the inputs differ in size or are empty
    """

    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_pred.size != y_true.size:
        raise ValueError("y_true and y_pred must have the same number of samples")
    if y_pred.size == 0:
        raise ValueError("purity is undefined for empty label arrays")

    # Dense 0..k-1 codes, so every label is a valid, non-colliding index.
    _, pred_codes = np.unique(y_pred, return_inverse=True)
    _, true_codes = np.unique(y_true, return_inverse=True)

    # contingency[p, t] = number of samples in predicted cluster p with true class t.
    contingency = np.zeros((pred_codes.max() + 1, true_codes.max() + 1), dtype=np.int64)
    np.add.at(contingency, (pred_codes, true_codes), 1)

    # Majority-vote mapping per cluster == taking each row's maximum count.
    return float(contingency.max(axis=1).sum() / y_pred.size)


def acc(labels_true, labels_pred):
    """
    Calculate the unsupervised clustering accuracy.

    Predicted clusters are matched one-to-one to true classes with the
    Hungarian algorithm (`scipy.optimize.linear_sum_assignment`) so that the
    number of matched samples is maximised. The accuracy is that number
    divided by the number of samples.

    Labels are encoded with `np.unique` first. Negative labels (e.g. `-1`
    for noise) therefore no longer wrap around and collide with other labels.

    # Arguments
        labels_true: true labels, array-like with shape `(n_samples,)`
        labels_pred: predicted labels, array-like with shape `(n_samples,)`
    # Return
        accuracy, float in [0, 1]
    # Raises
        ValueError: if the inputs differ in size or are empty
    """

    labels_true = np.asarray(labels_true).ravel()
    labels_pred = np.asarray(labels_pred).ravel()
    if labels_pred.size != labels_true.size:
        raise ValueError("labels_true and labels_pred must have the same number of samples")
    if labels_pred.size == 0:
        raise ValueError("accuracy is undefined for empty label arrays")

    _, pred_codes = np.unique(labels_pred, return_inverse=True)
    _, true_codes = np.unique(labels_true, return_inverse=True)

    # w[p, t] = number of samples in predicted cluster p with true class t.
    w = np.zeros((pred_codes.max() + 1, true_codes.max() + 1), dtype=np.int64)
    np.add.at(w, (pred_codes, true_codes), 1)

    # Minimising (max - w) maximises the matched counts; rectangular w is fine.
    rows, cols = linear_sum_assignment(w.max() - w)

    return float(w[rows, cols].sum() / labels_pred.size)


def compare_columns_cm(first_cm, second_cm):
    """
    Pair up identical columns of two confusion matrices.

    Intended for a confusion matrix before (`first_cm`) and after
    (`second_cm`) the Hungarian column permutation. Every pair
    `[i, j]` such that column `i` of `first_cm` equals column `j` of
    `second_cm` is reported. Duplicate columns yield several pairs. Columns
    of different lengths never match.

    # Return
        numpy.array of `[i, j]` pairs in (i, j) lexicographic order; an empty
        1-D array when nothing matches
    """

    matches = [
        [src, dst]
        for src in range(first_cm.shape[1])
        for dst in range(second_cm.shape[1])
        if np.array_equal(first_cm[:, src], second_cm[:, dst])
    ]
    return np.asarray(matches)


def reorder_predict_labels(y_pred):
    """
    Relabel `y_pred` in place with consecutive 0-based codes.

    The distinct values of `y_pred` are taken in the sorted order produced
    by `np.unique` and mapped to 0, 1, 2, ... . Each element is then
    overwritten with its code. The input object is modified in place and
    also returned.

    # Arguments
        y_pred: mutable sequence or numpy.array of predicted labels
    # Return
        the same object `y_pred`, now holding the codes
    """

    distinct = np.unique(y_pred)
    code_of = dict(zip(distinct, np.arange(distinct.size)))
    for position, label in enumerate(y_pred):
        y_pred[position] = code_of[label]
    return y_pred


def map_confusion_matrix_labels(y_true, y_pred):
    """
    Map confusion-matrix columns before and after the Hungarian permutation.

    Builds the confusion matrix with `predict_to_confusion`, permutes it with
    `maximize_trace`, and pairs identical columns with `compare_columns_cm`.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        numpy.array of `[original_column, permuted_column]` pairs, sorted by
        the permuted column index
    """

    before = predict_to_confusion(y_true, y_pred)
    after = maximize_trace(before)
    pairs = compare_columns_cm(before, after)

    # The Hungarian step leaves the pairs in original-column order; sort
    # them by the position each column was moved to.
    order = pairs[:, 1].argsort()
    return pairs[order]

def clustering_error(confusion):
    """
    Return the CE (clustering error) of a clustering given its confusion matrix.

    The matrix is first permuted with `maximize_trace`. The error is then
    `1 - trace / total`, i.e. the fraction of samples that fall off the
    best-matching diagonal.

    Note: disjoint clustering only, i.e. it only works for data that belong
    to exactly one class label.
    """

    best = maximize_trace(confusion)
    on_diagonal = np.trace(best)
    return 1 - on_diagonal / np.sum(best)


def permute_cols(a, inds):
    """
    Permute the columns of the square matrix `a`.

    `inds` is an iterable of `(from, to)` pairs (tuples, lists, or rows of a
    2-D array). A permutation matrix `p` with `p[from, to] = 1` is built and
    `a @ p` is returned. Column `from` of `a` therefore ends up at position
    `to` in the result.
    """

    p = np.zeros_like(a)
    for src, dst in inds:
        # Index with two scalars: `p[pair] = 1` with a list/array pair would
        # use fancy indexing and set whole rows instead of a single entry.
        p[src, dst] = 1
    return np.dot(a, p)


def maximize_trace(a):
    """
    Permute the columns of `a` so that its trace is maximal.

    A non-square `a` is first zero-padded (extra rows or columns at the
    bottom/right) to a square matrix. Padding promotes the dtype as
    concatenating with float64 zeros would. For every row `i` of `a` and
    every unit vector `e_j`, the cost `||e_j - a[i, :]||^2` is stored at
    `cost[j, i]`. The Munkres (Hungarian) algorithm finds the cheapest
    assignment, and `permute_cols` applies it to `a`.

    # Return
        the (possibly padded) matrix with its columns permuted
    """

    if a.shape[0] != a.shape[1]:
        size = max(a.shape[0], a.shape[1])
        padded = np.zeros((size, size), dtype=np.result_type(a.dtype, np.float64))
        padded[:a.shape[0], :a.shape[1]] = a
        a = padded

    assert a.shape[0] == a.shape[1]
    size = a.shape[0]
    targets = np.eye(size, dtype=int)
    cost = np.zeros_like(a)
    for row, col in itertools.product(range(size), repeat=2):
        cost[col, row] = sum((targets[col, :] - a[row, :]) ** 2)

    assignment = munkres.Munkres().compute(cost)
    return permute_cols(a, assignment)


def predict_to_confusion(y_true, y_pred):
    """
    Build the confusion matrix for `y_true` and `y_pred`.

    Thin alias for this module's `confusion_matrix`.
    """
    matrix = confusion_matrix(y_true, y_pred)
    return matrix


def predict_to_clustering_error(y_true, y_pred):
    """
    Score `y_pred` against `y_true` via `confusion_to_clustering_error`.

    The confusion matrix comes from `predict_to_confusion`, and the sample
    count used for normalisation is `len(y_true)`.
    """
    confusion = predict_to_confusion(y_true, y_pred)
    n_samples = len(y_true)
    return confusion_to_clustering_error(confusion, n_samples)


def confusion_to_clustering_error(confusion, n_samples):
    """
    Score a clustering from its confusion matrix and the sample count.

    The matrix is permuted with `maximize_trace`, and the resulting trace is
    divided by `n_samples`.

    NOTE(review): this returns the matched fraction (trace / n_samples),
    whereas `clustering_error` returns `1 - trace / total`. Confirm which
    quantity callers expect under the name "clustering error".

    Note: disjoint clustering only, i.e. it only works for data that belong
    to exactly one class label.
    """

    best = maximize_trace(confusion)
    on_diagonal = np.trace(best)
    return on_diagonal / n_samples


def results_to_clustering_error(data_file, results_file):
    """
    Score the clustering stored in a results file against labelled ARFF data.

    Receives the data file (in .arff format) and a .arff.results file, which
    contains clustering assignments. Builds the confusion matrix with
    `confusion_matrix` and scores it with `confusion_to_clustering_error`.

    Multiclass multilabel data (each instance may belong to one or more
    categories) stores its labels in one-hot notation in the last columns.
    For example, with 3-dimensional data and three possible categories:

    3 1 2 0 0 1
    2 2 2 1 1 0
    4 5 1 0 1 1, and so on.

    The first data instance (array [3, 1, 2]) belongs solely to the third
    label, etc. Such data is detected by having more than `dim + 1` columns.
    It is NOT scored here: the function prints "Multilabel" and returns None.

    In the data file, labels must be in the last column(s).

    # Return
        `(ce, found_clusters)` for single-label data, where `ce` is 0 when
        the results file holds no assignments; None for multilabel data.
    """

    # Use a context manager so the data file handle is always closed.
    with open(data_file, 'r') as data_fh:
        data, _ = arff.loadarff(data_fh)
    data = pd.DataFrame(data)

    data_n_winner, found_clusters, dim = utils.read_results(results_file)

    # Checking whether it's multiclass multilabel (subspace clustering) problem
    if data.shape[1] > dim + 1:
        print("Multilabel")
        return

    # `len` works for lists and numpy arrays alike (`not array` is ambiguous).
    if data_n_winner is None or len(data_n_winner) == 0:
        ce = 0
    else:
        confusion = confusion_matrix(data, data_n_winner)
        ce = confusion_to_clustering_error(confusion, data.shape[0])

    return ce, found_clusters


def results_to_confusion_matrix(data_file, results_file):
    """
    Build the confusion matrix of the clustering stored in a results file.

    Receives the data file (in .arff format) and a .arff.results file, which
    contains clustering assignments. Returns the confusion matrix built by
    `confusion_matrix` for that solution.

    Multiclass multilabel data (each instance may belong to one or more
    categories) stores its labels in one-hot notation in the last columns.
    For example, with 3-dimensional data and three possible categories:

    3 1 2 0 0 1
    2 2 2 1 1 0
    4 5 1 0 1 1, and so on.

    The first data instance (array [3, 1, 2]) belongs solely to the third
    label, etc. Such data is detected by having more than `dim + 1` columns.
    It is NOT handled here: the function prints "Multilabel" and returns None.

    In the data file, labels must be in the last column(s).

    # Return
        the confusion matrix; an empty list when the results file holds no
        assignments; None for multilabel data.
    """

    # Use a context manager so the data file handle is always closed.
    with open(data_file, 'r') as data_fh:
        data, _ = arff.loadarff(data_fh)
    data = pd.DataFrame(data)

    data_n_winner, found_clusters, dim = utils.read_results(results_file)

    # Checking whether it's multiclass multilabel (subspace clustering) problem
    if data.shape[1] > dim + 1:
        print("Multilabel")
        return

    # `len` works for lists and numpy arrays alike (`not array` is ambiguous).
    if data_n_winner is None or len(data_n_winner) == 0:
        confusion = []
    else:
        confusion = confusion_matrix(data, data_n_winner)

    return confusion


def conditional_entropy(confusion):
    """
    Compute the conditional entropy H(X|Y), in bits, of a confusion matrix.

    Follows Tuytelaars et al. 2008, "Unsupervised Object Discovery: a
    Comparison":

        H(X|Y) = sum_i p(y_i) * sum_j p(x_j | y_i) * log2(1 / p(x_j | y_i))

    The convention 0 * log2(1/0) = 0 is applied exactly. Previously, small
    constants were added to the denominators; they biased the result and
    made a perfect clustering score slightly negative. Empty found clusters
    (all-zero rows) contribute nothing.

    Obs.: the rows of the confusion matrix must be the found clusters (Y),
    whereas the columns are the reference (ground truth, X). That is, the
    number of rows must equal the number of found clusters, and the number
    of columns must equal the number of class labels.

    # Arguments
        confusion: 2-D array-like of non-negative counts
    # Return
        float >= 0; 0.0 when every found cluster contains a single class,
        nan when the matrix sums to zero.
    """

    counts = np.asarray(confusion, dtype=np.float64)
    total = counts.sum()
    if total == 0:
        return float('nan')

    row_sums = counts.sum(axis=1, keepdims=True)
    # p(y_i): probability of each found cluster.
    p_y = row_sums[:, 0] / total

    # Both np.where branches are evaluated; the masked-out 0/0 and 1/0 are harmless.
    with np.errstate(divide='ignore', invalid='ignore'):
        p_x_given_y = np.where(row_sums > 0, counts / row_sums, 0.0)
        surprise = np.where(p_x_given_y > 0, np.log2(1.0 / p_x_given_y), 0.0)

    return float((p_y[:, None] * p_x_given_y * surprise).sum())


def multilabelresults_to_clustering_error(data_file, results_file, qty_categories):
    """
    Score a clustering of multilabel data, where instances may belong to one or more labels.

    Receives the .arff file, the results file and the number of categories.
    The labels are stored in decimal form. For example, if the original
    label was [1, 0, 1] (the instance belongs to the first and third
    classes), the label in the file is 5.

    The results file's first line holds the number of found clusters (first
    field). Only the following lines with exactly two fields,
    `data_id winner_id`, are used.

    # Return
        `(ce, outconf)`:
        - `ce` is the trace of the trace-maximised confusion matrix divided
          by the summed per-sample union size, max(#categories of the
          sample, #clusters it was assigned to).
        - `outconf` is the un-permuted confusion matrix, shape
          `(qty_found_clusters, qty_categories)` (rows: found clusters,
          columns: categories).
    """

    with open(data_file, 'r') as data_fh:
        data, _ = arff.loadarff(data_fh)
    data = pd.DataFrame(data)

    with open(results_file, 'r') as results:
        # First line of results contains the number of found clusters and the dimension of the data
        first_line = results.readline().split()
        qty_found_clusters = int(first_line[0])

        # The results file also has a section with more than two columns that
        # is not needed; keep only the (data id, winner id) rows.
        data_n_winner = [fields for fields in (line.split() for line in results)
                         if len(fields) == 2]

    if not data_n_winner:
        data_n_winner.append([0, 0])

    # `np.int` was removed in NumPy 1.24; it was only an alias of builtin int.
    data_n_winner = np.asarray(data_n_winner, dtype=int)
    found_clusters_list = list(set(data_n_winner[:, -1]))
    column_of_winner = {winner: col for col, winner in enumerate(found_clusters_list)}

    # Converting winners to binary array: one row per sample, one column per found cluster
    bin_winner_array = np.zeros((data.shape[0], qty_found_clusters))
    for data_id, winner_id in data_n_winner:
        bin_winner_array[data_id, column_of_winner[winner_id]] = 1

    # Building bin label array: decimal label -> zero-padded binary digits
    width = int(qty_categories)
    bin_label_array = np.zeros((data.shape[0], qty_categories))
    for i in range(data.shape[0]):
        # Positional access; `data[i, -1]` raises KeyError on a DataFrame.
        label = data.iloc[i, -1]
        bits = bin(int(label))[2:].zfill(width)
        bin_label_array[i] = [int(bit) for bit in bits]

    # confusion[k, j] = number of samples assigned to found cluster k that
    # carry category j. Both arrays are 0/1, so a matrix product counts them.
    confusion = bin_winner_array.T @ bin_label_array

    outconf = np.copy(confusion)
    permuted = maximize_trace(confusion)

    # Finding union size of all data points
    label_counts = np.count_nonzero(bin_label_array, axis=1)
    winner_counts = np.count_nonzero(bin_winner_array, axis=1)
    count = np.maximum(label_counts, winner_counts).sum()

    ce = permuted.trace() / count

    return ce, outconf


def confusion_matrix(y_true, y_pred):
    """
    Build a confusion matrix from true labels and predicted assignments.

    Both inputs are wrapped in `pd.DataFrame`, and the lengths of their last
    columns decide the path taken:

    * Equal lengths (training and most other calls): both last columns are
      cast to float32, and the predictions are renumbered to 0..k-1 with
      `reorder_predict_labels`. The result is sklearn's
      `confusion_matrix(y_true, y_pred)` (rows: true labels, columns:
      predicted codes).
    * Different lengths (reading a .results file in which some samples were
      dropped as noise): each row of `y_pred` is read as (sample id in
      column 0, winner id in column 1). Only those samples are counted, into
      a float array of shape (n_found_clusters, n_true_clusters). Rows and
      columns follow each column's order of first appearance.

    NOTE(review): the two paths use opposite orientations (true x predicted
    vs. found x true); `conditional_entropy` expects found clusters on the
    rows. Confirm which orientation each caller relies on.
    NOTE(review): `reorder_predict_labels` writes into `.values` in place,
    which may be read-only under pandas Copy-on-Write. Verify on the pandas
    version in use.
    """

    true_frame = pd.DataFrame(y_true)
    pred_frame = pd.DataFrame(y_pred)
    true_col = true_frame.iloc[:, -1]
    pred_col = pred_frame.iloc[:, -1]

    if len(true_col) != len(pred_col):
        # Called only when reading from .results with samples considered as noise
        true_ids = list(true_col.unique())
        found_ids = list(pred_col.unique())
        counts = np.zeros((len(found_ids), len(true_ids)))

        for k in range(pred_frame.shape[0]):
            sample_id = int(pred_frame.iloc[k, 0])
            sample_row = np.array(true_frame.iloc[sample_id])
            winner = pred_frame.iloc[k, 1]
            # Row: the sample's found cluster; column: its true label.
            counts[found_ids.index(winner), true_ids.index(sample_row[-1])] += 1

        return counts

    truth = true_col.astype('float32').values
    predicted = reorder_predict_labels(pred_col.astype('float32').values)
    return sk_confusion(truth, predicted)
