import numpy as np
import pandas as pd
import argparse
from sklearn.cluster import MiniBatchKMeans

# Config
# Command-line configuration for the clustering run.
parser = argparse.ArgumentParser()

# Run identity and dataset selection
parser.add_argument("--model_ind", required=True, type=int)
parser.add_argument("--dataset", default="MNIST", type=str)
parser.add_argument("--dataset_root", default="/MNIST", type=str)

# Clustering
parser.add_argument("--num_clusters", default=10, type=int)

# Saving
parser.add_argument("--out_root", default="/saves/", type=str)
parser.add_argument(
    "--restart",
    action="store_true",
    default=False,
    dest="restart",
)

# Custom datasets
parser.add_argument(
    "--customdata_train_path",
    default="./src/datasets/mnist_multiple_train.pkl",
    type=str,
)
parser.add_argument(
    "--customdata_test_path",
    default="./src/datasets/mnist_multiple_test.pkl",
    type=str,
)

config = parser.parse_args()
# Each model index gets its own sub-directory under the output root.
config.out_dir = "{}{}/".format(config.out_root, config.model_ind)

# Load data

from sklearn.utils import shuffle


def _shuffle_and_split(frame):
    """
    Shuffles a dataset and separates its last column as the labels.

    frame: DataFrame whose last column holds the labels and whose remaining
           columns hold the features.
    returns: (features DataFrame, labels as an int numpy array)
    """
    print("Shape of dataset:", frame.shape)
    # Fixed seed so the row order is reproducible between runs.
    frame = shuffle(frame, random_state=0)
    split_labels = frame.iloc[:, -1].to_numpy().astype(int)
    features = frame.iloc[:, :-1]
    print(split_labels[:10])
    return features, split_labels


if config.dataset == "PartMNIST":
    # NOTE(review): read_pickle unpickles whatever the path points to and
    # unpickling can execute arbitrary code -- only use trusted files here.
    df, labels = _shuffle_and_split(
        pd.read_pickle(config.customdata_train_path))
    df_test, labels_test = _shuffle_and_split(
        pd.read_pickle(config.customdata_test_path))
elif config.dataset == "RotMNIST":
    # RotMNIST is stored as plain-text numeric tables.
    df, labels = _shuffle_and_split(
        pd.DataFrame(np.loadtxt(config.customdata_train_path)))
    df_test, labels_test = _shuffle_and_split(
        pd.DataFrame(np.loadtxt(config.customdata_test_path)))
else:
    # Without this branch an unsupported dataset (including the default
    # "MNIST") left df/labels undefined and the script failed further down
    # with an unrelated NameError.
    raise ValueError(
        "Unsupported --dataset {!r}; expected 'PartMNIST' or 'RotMNIST'"
        .format(config.dataset))

# Initialize KMeans model
# Take the cluster count from --num_clusters (default 10) instead of a
# hard-coded 10, so the command-line option actually has an effect.
n_digits = config.num_clusters
kmeans = MiniBatchKMeans(n_clusters=n_digits)

df_np = df.to_numpy()
df_test_np = df_test.to_numpy()
# Train dataset
X = df_np
Y = labels.astype(int)
# Test dataset
X_test = df_test_np
Y_test = labels_test.astype(int)
# Fit the model to the training data
print("Fitting K-means")
kmeans.fit(df_np)


# Assignment problem

def infer_cluster_labels(kmeans, actual_labels):
    """
    Associates most probable label with each cluster in KMeans model.

    kmeans: fitted estimator; only its `n_clusters` and `labels_` attributes
            are read.
    actual_labels: non-negative integer labels, one per fitted sample
                   (same order as `kmeans.labels_`).
    returns: dictionary mapping each label to the list of cluster indices
             whose most common label it is. Clusters with no assigned
             samples are left out.
    """
    actual_labels = np.asarray(actual_labels)
    assignments = np.asarray(kmeans.labels_)

    inferred_labels = {}

    for i in range(kmeans.n_clusters):
        # True labels of the samples assigned to cluster i, flattened so that
        # an (n, 1) label column works the same as a 1-D vector.
        members = np.ravel(actual_labels[assignments == i])

        # An empty cluster has no majority label; argmax on an empty
        # bincount would raise, so skip it.
        if members.size == 0:
            continue

        # Most common label (ties resolve to the smallest label value).
        label = int(np.argmax(np.bincount(members)))
        inferred_labels.setdefault(label, []).append(i)

    return inferred_labels


def infer_data_labels(X_labels, cluster_labels):
    """
    Determines label for each array, depending on the cluster it has been assigned to.

    X_labels: sequence of cluster indices, one per sample.
    cluster_labels: dictionary mapping a label to the clusters assigned to
                    it (as produced by infer_cluster_labels).
    returns: uint8 numpy array of predicted labels, one per sample; samples
             whose cluster appears under no label keep the value 0.
    """
    # Invert the mapping once (cluster -> label) so each sample is a single
    # dictionary lookup instead of a scan over every label's cluster list.
    # If a cluster is listed under several labels the last one wins, which
    # matches overwriting in iteration order.
    cluster_to_label = {}
    for label, members in cluster_labels.items():
        for member in members:
            cluster_to_label[member] = label

    predicted_labels = np.zeros(len(X_labels), dtype=np.uint8)

    for i, cluster in enumerate(X_labels):
        if cluster in cluster_to_label:
            predicted_labels[i] = cluster_to_label[cluster]

    return predicted_labels

# Test the infer_cluster_labels() and infer_data_labels() functions
# Map clusters to labels on the training set, then label every training
# sample through its predicted cluster.
cluster_labels = infer_cluster_labels(kmeans, Y)
X_clusters = kmeans.predict(X)
predicted_labels = infer_data_labels(X_clusters, cluster_labels)

# Show the first 20 predictions followed by the first 20 true labels.
for preview in (predicted_labels, Y):
    print(preview[:20])


# Accuracies
from sklearn import metrics

def calculate_metrics(estimator, data, labels):
    """
    Prints the cluster count, inertia and homogeneity of a fitted estimator.

    estimator: fitted clustering model; its `n_clusters`, `inertia_` and
               `labels_` attributes are read.
    data: not used by this function.
    labels: true labels compared against `estimator.labels_` for the
            homogeneity score.
    """
    cluster_count = estimator.n_clusters
    print('Number of Clusters: {}'.format(cluster_count))

    inertia = estimator.inertia_
    print('Inertia: {}'.format(inertia))

    homogeneity = metrics.homogeneity_score(labels, estimator.labels_)
    print('Homogeneity: {}'.format(homogeneity))

# Cluster counts to evaluate
clusters = [10, 16, 36, 64, 144, 256]

for n_clusters in clusters:
    # Fit a fresh model for this cluster count and report its metrics.
    estimator = MiniBatchKMeans(n_clusters=n_clusters)
    estimator.fit(X)
    calculate_metrics(estimator, X, Y)

    # Training accuracy: label each training sample via its fitted cluster.
    cluster_labels = infer_cluster_labels(estimator, Y)
    predicted_Y = infer_data_labels(estimator.labels_, cluster_labels)
    train_accuracy = metrics.accuracy_score(Y, predicted_Y)
    print('Train Accuracy: {}\n'.format(train_accuracy))

    # Test accuracy: assign test samples to clusters, then reuse the
    # cluster-to-label mapping learned from the training set.
    test_clusters = estimator.predict(X_test)
    predicted_Y_test = infer_data_labels(test_clusters, cluster_labels)
    test_accuracy = metrics.accuracy_score(Y_test, predicted_Y_test)
    print('Test accuracy: {}\n'.format(test_accuracy))