import numpy as np
import random
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd

from copy import deepcopy

from scipy.sparse.csgraph import johnson # for computing shortest paths
from scipy.sparse.linalg import eigs # for custom spectral clustering algorithm
from scipy.sparse.linalg import eigsh # for custom spectral clustering algorithm
from sklearn.cluster import SpectralClustering
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt
from scipy.sparse.csgraph import johnson # for computing shortest paths

from custom_helper_functions import env_to_state, get_probs, get_environment_from_action
from custom_helper_functions import remove_subgoals_custom_rules, curate_subgoal_path_custom_rules

def get_environments(env_initial, ac, max_steps, action_prob_threshold=0.00001, num_actions=5):
    """Enumerate the environments reachable from ``env_initial`` under a policy.

    The search is breadth-first: at every step each environment on the current
    frontier is expanded with every action whose policy probability (read as
    ``get_probs(env, ac)[0][action]``) is at least ``action_prob_threshold``.

    Args:
        env_initial: environment from which all moves are computed.
        ac: policy object, forwarded unchanged to ``get_probs``.
        max_steps: maximum number of steps to move away from ``env_initial``.
        action_prob_threshold: minimum probability for an action to be expanded.
        num_actions: number of discrete actions tried per environment
            (actions ``0 .. num_actions - 1``).

    Returns:
        Tuple ``(env_list, number_steps)``. ``env_list`` starts with
        ``env_initial`` followed by every generated environment (duplicates are
        NOT removed here); ``number_steps[i]`` is the number of steps from
        ``env_initial`` at which ``env_list[i]`` was generated.
    """
    env_list = [env_initial]
    number_steps = [0]
    frontier = [env_initial]  # environments added during the previous step
    for step in range(1, max_steps + 1):
        if not frontier:
            # Nothing new was reachable in the previous step, so no later step
            # can add anything either (and the de-duplication below cannot
            # handle an empty list of its own accord).
            break
        # Expand each distinct state only once; the step counts passed here are
        # placeholders because only the environments are kept.
        frontier, _, _ = get_unique_environments(frontier, [0] * len(frontier))
        expanded = []
        for env in frontier:
            probs = get_probs(env, ac)
            for action in range(num_actions):
                if probs[0][action] >= action_prob_threshold:
                    # Simulate on a copy so the frontier environment stays intact
                    # for the remaining actions.
                    expanded.extend(get_environment_from_action(deepcopy(env), action))
        env_list.extend(expanded)
        number_steps.extend([step] * len(expanded))
        # The next frontier is a separate copy of the objects stored in env_list.
        frontier = deepcopy(expanded)

    return (env_list, number_steps)

def get_unique_environments(env_list, number_steps):
    """Drop environments whose grid-world state duplicates another one.

    Every environment is converted with ``env_to_state``, the states are
    flattened and sorted with ``np.lexsort`` so that identical states become
    neighbours, and one representative per group of identical states is kept.

    Args:
        env_list: list of environments.
        number_steps: ``number_steps[i]`` is the number of steps from the
            initial environment to ``env_list[i]``.

    Returns:
        Tuple ``(env_list_final, state_list_final, number_steps_final)`` with
        the unique environments, their states (as returned by
        ``env_to_state``) and the matching step counts. The output follows the
        lexsort order of the states, not the input order.
    """
    if len(env_list) == 0:
        # Nothing to de-duplicate; avoids indexing into an empty state list.
        return ([], [], [])

    state_list = [env_to_state(env) for env in env_list]

    # One flat row per state. reshape(-1) works for 1-D states as well as for
    # real 2-D grids.
    flat_states = np.array([np.asarray(state).reshape(-1) for state in state_list])

    # lexsort takes one key per row, hence the transpose: each state component
    # becomes a sort key and identical states end up adjacent.
    order = np.lexsort(flat_states.T)
    sorted_states = flat_states[order]

    # A sorted row opens a new group when it differs from the row before it.
    starts_group = np.ones(len(order), dtype=bool)
    starts_group[1:] = np.any(sorted_states[1:] != sorted_states[:-1], axis=1)
    inds = order[starts_group]  # indices into the unsorted input lists

    state_list_final = [state_list[i] for i in inds]
    env_list_final = [env_list[i] for i in inds]
    number_steps_final = [number_steps[i] for i in inds]

    return (env_list_final, state_list_final, number_steps_final)

def create_transition_matrix(env_list_final, state_list_final, actor_critic, action_prob_threshold=0.0001, num_actions=5):
    """Build the one-step transition matrix between the given environments.

    For every environment ``i`` and every action whose policy probability
    (``get_probs(env, actor_critic)[0][action]``) reaches
    ``action_prob_threshold``, the action is simulated on a copy of the
    environment. The action's probability is split equally between the
    environments returned by ``get_environment_from_action`` and added to the
    entry of the matching known state. Successors whose state is not in
    ``state_list_final`` are ignored.

    Args:
        env_list_final: list of unique environments.
        state_list_final: states of those environments, in the same order.
        actor_critic: object holding the policy, forwarded to ``get_probs``.
        action_prob_threshold: minimum probability for an action to be used.
        num_actions: number of discrete actions tried per environment.

    Returns:
        ``(n, n)`` numpy array; entry ``[i, j]`` is the probability of moving
        from state ``i`` to state ``j`` in one step.
    """
    num_states = len(env_list_final)
    transition_matrix = np.zeros((num_states, num_states))
    if num_states == 0:
        return transition_matrix

    # Flatten the known states once so each successor can be matched against
    # all of them with a single vectorised comparison.
    known_states = np.array([np.asarray(state).reshape(-1) for state in state_list_final])

    for i, env in enumerate(env_list_final):
        probs = get_probs(env, actor_critic)
        for action in range(num_actions):
            action_prob = probs[0][action]
            if action_prob >= action_prob_threshold:
                # Simulate on a copy so the stored environment is left untouched.
                successors = get_environment_from_action(deepcopy(env), action)
                num_successors = len(successors)
                if num_successors == 0:
                    continue
                # Assume each successor of the action is equally likely.
                share = float(action_prob) / num_successors
                for successor in successors:
                    target = np.asarray(env_to_state(successor)).reshape(-1)
                    matches = np.flatnonzero(np.all(known_states == target, axis=1))
                    if matches.size > 0:
                        # Accumulate: different actions (or several successors
                        # of one action) may lead to the same state, and their
                        # probabilities add up.
                        transition_matrix[i, matches[0]] += share
    return transition_matrix

def shortest_path_ssx(transition_matrix):
    """Compute all-pairs most-likely paths from a transition matrix.

    Each positive transition probability ``p`` becomes an edge of length
    ``-log(p)``. Lengths add up along a path, so the shortest path is the
    maximum-likelihood path, and states that are less likely to be reached
    from each other are further apart.

    Args:
        transition_matrix: square array; ``transition_matrix[i, j]`` is the
            probability of going from state ``i`` to state ``j`` in one step.

    Returns:
        Tuple ``(shortest_dist, predecessors)`` as returned by
        ``scipy.sparse.csgraph.johnson``: ``shortest_dist[i, j]`` is the
        shortest distance from ``i`` to ``j`` (``inf`` if unreachable) and
        ``predecessors[i, j]`` is the node before ``j`` on that path
        (``-9999`` if there is none).
    """
    # Local import: csr_matrix is not among the module-level imports.
    from scipy.sparse import csr_matrix

    transition_matrix = np.asarray(transition_matrix)
    num_states = transition_matrix.shape[0]
    rows, cols = np.nonzero(transition_matrix > 0)
    weights = -np.log(transition_matrix[rows, cols])
    # A certain transition (p == 1) has length 0. In a dense csgraph input a 0
    # means "no edge", which would silently drop that transition. A sparse
    # matrix keeps the explicitly stored zero as a real zero-length edge.
    graph = csr_matrix((weights, (rows, cols)), shape=(num_states, num_states))
    shortest_dist, predecessors = johnson(graph, return_predecessors=True)

    return (shortest_dist, predecessors)

def compute_metastate_count(shortest_dist, predessors, labels_spectral, num_states, num_clusters, eps= 1e-8):
    """Count, per state, the shortest-path edges that leave its cluster.

    For every source ``i`` and every target ``j`` reachable from ``i``, the
    last edge ``p -> j`` of the shortest path is inspected (``p`` being the
    predecessor of ``j``). When ``p`` and ``j`` are in different clusters,
    the entry for ``p`` is incremented.

    Args:
        shortest_dist: ``shortest_dist[i, j]`` is the shortest distance
            between nodes ``i`` and ``j``.
        predessors: ``predessors[i, j]`` is the predecessor of node ``j`` on
            the path from ``i`` to ``j`` (negative when there is no path). The
            historical spelling is kept so keyword callers continue to work.
        labels_spectral: ``labels_spectral[i]`` is the cluster id of node ``i``.
        num_states: number of states in the graph.
        num_clusters: number of clusters in the graph.
        eps: distances of at least ``-np.log(eps)`` are treated as infinite,
            i.e. as "no path".

    Returns:
        Array of shape ``(num_states, num_clusters, num_clusters)``; entry
        ``[p, a, b]`` is the number of (source, target) pairs whose shortest
        path ends with an edge from state ``p`` (cluster ``a``) to a target in
        cluster ``b != a``.
    """
    predecessors = np.asarray(predessors)
    shortest_dist = np.asarray(shortest_dist)
    metastate_count = np.zeros((num_states, num_clusters, num_clusters))
    dist_inf = -np.log(eps)
    for i in range(num_states):
        # Work on a copy: marking unreachable targets must not alter the
        # caller's predecessor matrix.
        preds_from_i = predecessors[i, :].copy()
        # Targets at (or beyond) the "infinite" distance have no usable path.
        preds_from_i[shortest_dist[i, :] >= dist_inf - 1e-8] = -9999
        for j in np.where(preds_from_i >= 0)[0]:  # targets with a path from i
            pred_j = preds_from_i[j]
            if labels_spectral[j] != labels_spectral[pred_j]:
                # The edge pred_j -> j changes cluster.
                metastate_count[pred_j, labels_spectral[pred_j], labels_spectral[j]] += 1
    return metastate_count

def compute_numpaths_count(predecessors, labels_spectral, shortest_dist, num_states, num_clusters):
    """Count the shortest paths that leave a cluster through each state.

    For every ordered pair ``(s, t)`` in different clusters, the shortest path
    ``s -> t`` is walked backwards from ``t`` until the first node ``w``
    (other than ``s``) that lies in the cluster of ``s``. If the rest of the
    path from ``w`` back to ``s`` stays inside that cluster (``check_path``),
    ``w`` is credited with the path. Cluster labels of all states are taken as
    fixed.

    Args:
        predecessors: ``predecessors[i, j]`` is the predecessor of node ``j``
            on the path from ``i`` to ``j`` (``-9999`` when there is none).
        labels_spectral: ``labels_spectral[i]`` is the cluster id of node ``i``.
        shortest_dist: ``shortest_dist[i, j]`` is the shortest distance
            between nodes ``i`` and ``j``.
        num_states: number of states in the graph.
        num_clusters: number of clusters in the graph.

    Returns:
        Tuple ``(numpaths_count, numpaths_prob, numpaths_count_pred)``:
        ``numpaths_count[w, c]`` is the number of credited paths through node
        ``w`` in cluster ``c``; ``numpaths_prob[w, c]`` is the sum of
        ``exp(-shortest_dist[s, t])`` over those paths;
        ``numpaths_count_pred[w, v]`` is the number of credited paths that go
        from ``w`` straight to node ``v`` outside the cluster of ``w``.
    """
    numpaths_count = np.zeros((num_states, num_clusters))
    numpaths_prob = np.zeros((num_states, num_clusters))
    numpaths_count_pred = np.zeros((num_states, num_states))

    for s in range(num_states):
        cluster_s = labels_spectral[s]
        for t in range(num_states):
            if labels_spectral[t] == cluster_s:
                continue  # only paths that end in another cluster are relevant
            node = t
            while True:
                successor = node
                node = predecessors[s, successor]  # one step back along s -> t
                if node == s or node == -9999:
                    break  # reached the source, or there is no path
                if labels_spectral[node] != cluster_s:
                    continue  # still outside the cluster of s, keep walking
                # First node back inside the cluster of s: credit it only if
                # the remaining path to s never leaves that cluster.
                if check_path(s, node, predecessors, labels_spectral):
                    numpaths_count[node, cluster_s] += 1
                    numpaths_prob[node, cluster_s] += np.exp(-shortest_dist[s, t])
                    numpaths_count_pred[node, successor] += 1
                break

    return (numpaths_count, numpaths_prob, numpaths_count_pred)

def compute_numpaths_count_all(predecessors, labels_spectral, shortest_dist, num_states, num_clusters):
    """Count cluster-leaving shortest paths for every state on the path.

    For every ordered pair ``(s, t)`` in different clusters, the whole
    shortest path ``s -> t`` is walked backwards from ``t`` to ``s``. Every
    intermediate node ``w`` for which ``check_path(s, w, ...)`` holds (``w``
    and the rest of the path back to ``s`` are in the cluster of ``s``) is
    credited in the column of that cluster.

    The previous implementation reused ``meta_pred`` as both the ``while``
    guard and the variable of an inner ``for`` loop over clusters. As a
    result the walk stopped after a single step whenever ``s`` belonged to
    the highest-numbered cluster and covered the full path otherwise. The
    walk now covers the full path for every cluster.

    NOTE(review): credit is only ever given in the column of the current
    cluster of ``s`` (and ``check_path`` uses the current labels), so the
    columns of other clusters stay zero. Confirm this matches the intended
    "every possible assignment" semantics.

    Args:
        predecessors: ``predecessors[i, j]`` is the predecessor of node ``j``
            on the path from ``i`` to ``j`` (``-9999`` when there is none).
        labels_spectral: ``labels_spectral[i]`` is the cluster id of node ``i``.
        shortest_dist: ``shortest_dist[i, j]`` is the shortest distance
            between nodes ``i`` and ``j``.
        num_states: number of states in the graph.
        num_clusters: number of clusters in the graph.

    Returns:
        Tuple ``(numpaths_count, numpaths_prob, numpaths_count_pred)``:
        ``numpaths_count[w, c]`` is the number of credited paths through node
        ``w`` for cluster ``c``; ``numpaths_prob[w, c]`` is the sum of
        ``exp(-shortest_dist[s, t])`` over those paths;
        ``numpaths_count_pred[w, v]`` counts the credited paths on which ``v``
        directly follows ``w``.
    """
    numpaths_count = np.zeros((num_states, num_clusters))
    numpaths_prob = np.zeros((num_states, num_clusters))
    numpaths_count_pred = np.zeros((num_states, num_states))

    for s in range(num_states):
        meta_s = labels_spectral[s]
        # Only cluster ids that are valid column indices can receive credit.
        creditable = 0 <= meta_s < num_clusters
        for t in range(num_states):
            if labels_spectral[t] == meta_s:
                continue  # only paths that end in another cluster are relevant
            pred = t  # start with the last node
            while True:
                pred_temp = pred
                pred = predecessors[s, pred]  # one step back along s -> t
                if pred == s or pred == -9999:
                    break  # reached the source, or there is no path
                # Credit pred when it and the rest of the path back to s stay
                # in the cluster of s.
                if creditable and check_path(s, pred, predecessors, labels_spectral):
                    numpaths_count[pred, meta_s] += 1
                    numpaths_prob[pred, meta_s] += np.exp(-shortest_dist[s, t])
                    numpaths_count_pred[pred, pred_temp] += 1

    return (numpaths_count, numpaths_prob, numpaths_count_pred)

def check_path(i, j, predecessors, labels_spectral):
    """Check whether the path from ``i`` to ``j`` stays in the cluster of ``i``.

    The path is followed backwards from ``j`` through ``predecessors[i, :]``.
    The walk succeeds when it arrives at ``i`` and fails as soon as it meets
    a node in a different cluster or a missing predecessor (``-9999``).
    With ``j == i`` the first step is already ``predecessors[i, i]``, so the
    result is only True if that entry equals ``i``.

    Args:
        i: first node (start of the path).
        j: second node (end of the path).
        predecessors: ``predecessors[i, j]`` is the predecessor of node ``j``
            on the path from ``i`` to ``j``.
        labels_spectral: array of cluster ids.

    Returns:
        True if such a path exists, False otherwise.
    """
    node = j
    while labels_spectral[i] == labels_spectral[node]:
        node = predecessors[i, node]
        if node == i:
            return True  # walked all the way back to the start
        if node == -9999:
            return False  # no predecessor: the path does not exist
    # Met a node outside the cluster of i before reaching i.
    return False

def check_coverage(state_list_final, labels_spectral, subgoals_list, end_states_list, predecessors):
    """Print whether each ordinary state is covered by a subgoal or end state.

    For every cluster, each member state that is neither a subgoal nor an end
    state of that cluster is tested against the cluster's subgoals first and
    its end states second. The state is reported as covered by the first
    candidate ``k`` for which ``check_path(state, k, ...)`` is true, and as
    not covered when there is none. Nothing is printed for subgoals and end
    states themselves.

    Args:
        state_list_final: list of unique grid worlds (not used by the check).
        labels_spectral: array of cluster ids.
        subgoals_list: list of lists of subgoal ids for each cluster.
        end_states_list: list of lists of end state ids for each cluster.
        predecessors: ``predecessors[i, j]`` is the predecessor of node ``j``
            on the path from ``i`` to ``j``.
    """
    cluster_total = len(np.unique(labels_spectral))
    for cluster in range(cluster_total):
        for state in np.where(labels_spectral == cluster)[0]:
            if state in subgoals_list[cluster] or state in end_states_list[cluster]:
                continue  # subgoals and end states need no coverage
            for candidate in subgoals_list[cluster] + end_states_list[cluster]:
                if check_path(state, candidate, predecessors, labels_spectral):
                    print('State '+str(state)+' in Cluster '+str(cluster) + ' is covered by State '+str(candidate))
                    break
            else:  # no candidate covered the state
                print('State '+str(state)+' in Cluster '+str(cluster) + ' is not covered.')
def get_subgoals(numpaths_count, numpaths_prob, labels_spectral, diversity, m, rho, num_states, num_clusters, use_count = True, eps_add_goal=.2):
    """Greedily select up to ``m`` subgoal states for every cluster.

    The first subgoal of a cluster is the state with the largest path measure
    in that cluster's column. Further subgoals are added one at a time from
    the remaining cluster members: a candidate with a positive path measure
    scores its path measure minus ``rho`` times its mean ``diversity`` value
    to the subgoals chosen so far; other candidates score their (non-positive)
    path measure. The best candidate is accepted while its score is positive
    and increases the running objective by at least ``eps_add_goal``
    (relative change). The chosen subgoals are printed per cluster.

    Args:
        numpaths_count: ``numpaths_count[i, j]`` is the number of paths
            through node ``i`` in cluster ``j`` that go to a different cluster.
        numpaths_prob: ``numpaths_prob[i, j]`` is the summed likelihood of
            those paths.
        labels_spectral: array of cluster ids.
        diversity: ``diversity[i, j]`` is the shortest distance between
            states ``i`` and ``j``.
        m: maximum number of subgoals per cluster.
        rho: weight of the diversity penalty.
        num_states: number of states (kept for interface compatibility).
        num_clusters: number of clusters.
        use_count: rank by ``numpaths_count`` if True, else by ``numpaths_prob``.
        eps_add_goal: minimum relative objective change required to add a
            further subgoal.

    Returns:
        ``subgoals_list`` where ``subgoals_list[i]`` is the list of subgoal
        states of cluster ``i`` (empty when no state of the cluster lies on a
        path to another cluster).
    """
    numpaths = numpaths_count if use_count else numpaths_prob
    subgoals_list = []
    for cluster in range(num_clusters):
        subgoals = []
        if np.max(numpaths[:, cluster]) > 0:  # some state has paths leaving the cluster
            subgoals.append(np.argmax(numpaths[:, cluster]))
            members = np.where(labels_spectral == cluster)[0]
            obj = numpaths[subgoals[0], cluster]  # running objective, drives the stopping rule
            for _ in range(m - 1):  # try to add m-1 more subgoals
                candidates = np.array([s for s in members if s not in subgoals], dtype=int)
                if candidates.size == 0:
                    break  # every member of the cluster is already a subgoal
                paths = numpaths[candidates, cluster]
                score = np.array(paths, dtype=float)
                eligible = paths > 0  # only these are potential subgoals
                if np.any(eligible):
                    # Mean diversity value to the subgoals selected so far.
                    mean_diversity = diversity[np.ix_(candidates[eligible], subgoals)].mean(axis=1)
                    score[eligible] -= rho * mean_diversity
                best = np.argmax(score)
                obj_new = obj + score[best]
                if score[best] > 0 and (obj_new - obj) / obj >= eps_add_goal:
                    subgoals.append(candidates[best])
                    obj = obj_new
                else:
                    break
        print('The subgoals for cluster '+str(cluster)+' are:')
        for subgoal in subgoals:
            print('state '+str(subgoal))
        subgoals_list.append(subgoals)
    return subgoals_list

def create_subgoal_paths(labels_spectral, subgoals_list, env_list_final, transition_matrix, number_steps_final, params=None):
    """Build, for every subgoal, a set of in-cluster states leading to it.

    Starting at the subgoal, a predecessor inside the same cluster is sampled
    repeatedly with probability proportional to its one-step transition
    probability into the current node. Sampling is used rather than always
    taking the most likely predecessor. The walk stops when the current node
    has no in-cluster predecessor or after ``max(number_steps_final)`` steps.

    Args:
        labels_spectral: array of cluster ids.
        subgoals_list: ``subgoals_list[i]`` is the list of subgoals of cluster ``i``.
        env_list_final: list of unique environments.
        transition_matrix: numpy matrix of one-step transition probabilities.
        number_steps_final: number of steps from the initial environment to
            each environment in ``env_list_final``.
        params: dictionary of extra parameters forwarded to
            ``curate_subgoal_path_custom_rules`` (defaults to an empty dict).

    Returns:
        ``all_paths`` where ``all_paths[i]`` is the list of curated subgoal
        paths of cluster ``i``; empty curated paths are dropped.
    """
    if params is None:
        params = {}  # fresh dict per call instead of a shared mutable default
    num_clusters = len(np.unique(labels_spectral))
    # Upper bound on the number of predecessors collected per subgoal.
    max_path_steps = np.max(number_steps_final) if len(number_steps_final) > 0 else 0

    all_paths = []
    for i in range(num_clusters):
        # Remove subgoals according to custom rules.
        subgoals = remove_subgoals_custom_rules(subgoals_list[i], env_list_final)
        inds_cluster = np.where(labels_spectral == i)[0]

        path = []
        for subgoal in subgoals:
            subgoal_path = [subgoal]
            node = subgoal
            count = 0
            while True:
                weights = transition_matrix[inds_cluster, node]
                total = np.sum(weights)
                if not total > 0:
                    # No state of this cluster leads into node. random.choices
                    # rejects all-zero weights, so stop here explicitly.
                    break
                ind_pred = random.choices(range(len(weights)), list(weights / total))[0]
                node_pred = inds_cluster[ind_pred]
                if transition_matrix[node_pred, node] > 0 and count < max_path_steps:
                    subgoal_path.append(node_pred)
                    node = node_pred
                    count += 1
                else:
                    break

            # NOTE(review): np.unique sorts by state index, so after the
            # reversal the states are in descending index order rather than in
            # traversal order. Kept as is; confirm whether traversal order
            # was intended.
            subgoal_path = list(np.unique(subgoal_path))
            subgoal_path.reverse()
            subgoal_path.remove(subgoal)
            subgoal_path.append(subgoal)  # make sure the subgoal is the last state

            subgoal_path = curate_subgoal_path_custom_rules(subgoal_path, env_list_final, params)
            if len(subgoal_path) > 0:
                path.append(subgoal_path)

        all_paths.append(path)

    return all_paths

def _ssx_spectral_embedding(shortest_dist, num_clusters, gamma, projection_type):
    """Return the row-normalised spectral embedding of the states."""
    affinity_matrix = np.exp(-gamma * np.asarray(shortest_dist))
    np.fill_diagonal(affinity_matrix, 0)
    # shortest_dist is not symmetric, so symmetrise the affinity.
    affinity_matrix = (affinity_matrix + np.transpose(affinity_matrix)) / 2
    d = np.sum(affinity_matrix, axis=1)
    d[d == 0] = 0.000001  # numerical fix for isolated states

    if projection_type in (1, 2):
        inv_sqrt_d = np.reciprocal(np.sqrt(d))
        # D^{-1/2} * A * D^{-1/2}
        L = affinity_matrix * inv_sqrt_d[:, None] * inv_sqrt_d[None, :]
        _, V = eigsh(L, k=num_clusters, which='LM')
        if projection_type == 2:
            # D^{-1} * A is not symmetric, so eigsh must not be applied to it.
            # It shares its eigenvalues with the symmetric matrix above, and
            # its eigenvectors are D^{-1/2} times the symmetric ones. (The row
            # normalisation below removes this per-row scale again.)
            V = V * inv_sqrt_d[:, None]
    else:  # projection_type == 3: L = D - A, smallest eigenvectors
        L = np.diag(d) - affinity_matrix
        _, V = eigsh(L, k=num_clusters, which='SM')
    return normalize(V, norm='l2', axis=1)


def _ssx_centroids(Y, labels_cluster, num_clusters):
    """Return per-cluster means of Y; clusters without points get ``inf``."""
    centroids = np.full((num_clusters, Y.shape[1]), np.inf)
    for i in range(num_clusters):
        members = Y[labels_cluster == i]
        if len(members) > 0:
            centroids[i, :] = np.mean(members, axis=0)
    return centroids


def _ssx_objective(Y, centroids, labels_cluster, numpaths_prob, eta):
    """Return the within-cluster squared distance minus ``eta`` times the out-path likelihood."""
    states = np.arange(len(labels_cluster))
    spread = np.sum(np.square(Y - centroids[labels_cluster]))
    return spread - eta * np.sum(numpaths_prob[states, labels_cluster])


def cluster_ssx(predecessors, shortest_dist, num_states, num_clusters, gamma, eta, projection_type=1, eps=0.001, max_iters=50):
    """Cluster the states with spectral clustering plus an out-path regulariser.

    The states are embedded with the eigenvectors of a matrix built from the
    affinity ``A[i, j] = exp(-gamma * shortest_dist[i, j])`` (symmetrised,
    zero diagonal) and the rows are l2-normalised. A K-Means style iteration
    then alternates between reassigning every state to the cluster minimising
    ``||y - centroid||^2 - eta * numpaths_prob`` and recomputing centroids.
    It stops when the objective changes by at most ``eps`` or after
    ``max_iters`` iterations (at least one iteration is always run). The
    initial assignment is drawn from ``np.random``.

    Args:
        predecessors: ``predecessors[i, j]`` is the predecessor of node ``j``
            on the path from ``i`` to ``j``.
        shortest_dist: ``shortest_dist[i, j]`` is the shortest distance
            between nodes ``i`` and ``j``.
        num_states: number of states in the graph.
        num_clusters: number of clusters to form.
        gamma: parameter of the affinity matrix.
        eta: weight of the out-path regulariser.
        projection_type: 1 for ``D^{-1/2} A D^{-1/2}`` (largest eigenvectors),
            2 for ``D^{-1} A`` (largest eigenvectors), 3 for ``D - A``
            (smallest eigenvectors).
        eps: tolerance for terminating the clustering procedure.
        max_iters: maximum number of clustering iterations.

    Returns:
        Array of cluster ids, relabelled to ``0 .. k-1`` so that clusters
        which ended up empty leave no gaps.

    Raises:
        ValueError: if ``projection_type`` is unknown or ``num_clusters`` is
            not between 1 and ``num_states``.
    """
    if projection_type not in (1, 2, 3):
        raise ValueError("Projection Type unknown for Spectral Clustering")
    if not 1 <= num_clusters <= num_states:
        # The random initialisation below could never fill every cluster.
        raise ValueError("num_clusters must be between 1 and num_states")

    # Projected points to be clustered.
    Y = _ssx_spectral_embedding(shortest_dist, num_clusters, gamma, projection_type)

    # Random initial assignment; redraw until every cluster has a state.
    while True:
        labels_cluster = np.floor(np.random.random((num_states,)) * num_clusters).astype(int)
        if len(np.unique(labels_cluster)) == num_clusters:
            break

    centroids = _ssx_centroids(Y, labels_cluster, num_clusters)
    (_, numpaths_prob, _) = compute_numpaths_count(predecessors, labels_cluster, shortest_dist, num_states, num_clusters)
    obj_curr = _ssx_objective(Y, centroids, labels_cluster, numpaths_prob, eta)

    iter_count = 0
    while True:
        iter_count += 1
        obj_prev = obj_curr

        # Out-path likelihood of each state for each candidate cluster.
        (_, numpaths_prob_all, _) = compute_numpaths_count_all(predecessors, labels_cluster, shortest_dist, num_states, num_clusters)
        # Squared distance of every state to every centroid; empty clusters
        # have inf centroids and therefore never win the argmin.
        sq_dist = np.sum(np.square(Y[:, None, :] - centroids[None, :, :]), axis=2)
        labels_cluster = np.argmin(sq_dist - eta * numpaths_prob_all, axis=1)

        centroids = _ssx_centroids(Y, labels_cluster, num_clusters)
        (_, numpaths_prob, _) = compute_numpaths_count(predecessors, labels_cluster, shortest_dist, num_states, num_clusters)
        obj_curr = _ssx_objective(Y, centroids, labels_cluster, numpaths_prob, eta)

        if np.abs(obj_curr - obj_prev) <= eps or iter_count >= max_iters:
            break

    # Remove cluster ids that have no points assigned.
    _, labels_cluster_ret = np.unique(labels_cluster, return_inverse=True)
    return labels_cluster_ret

def jaccard_score(labels1, labels2):
    """Compare two clusterings cluster by cluster with the Jaccard index.

    Args:
        labels1: ``labels1[i]`` is the cluster of point ``i`` in the first
            clustering result.
        labels2: ``labels2[i]`` is the cluster of point ``i`` in the second
            clustering result.

    Returns:
        Array ``scores`` with one row per distinct label of ``labels1`` and
        one column per distinct label of ``labels2`` (both in sorted order);
        ``scores[i, j]`` is ``|A & B| / |A | B|`` for the point sets ``A`` and
        ``B`` of the two clusters.
    """
    # Point-index set of every cluster, built once per clustering.
    members1 = [set(np.where(labels1 == label)[0]) for label in np.unique(labels1)]
    members2 = [set(np.where(labels2 == label)[0]) for label in np.unique(labels2)]

    scores = np.zeros((len(members1), len(members2)))
    for row, first in enumerate(members1):
        for col, second in enumerate(members2):
            shared = len(first & second)
            combined = len(first | second)
            scores[row, col] = 1. * shared / combined
    return scores