import sklearn
from numba import jit
import sys 
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
import random
import matplotlib.pyplot as plt
import networkx as nx


# return k means cost given centers
def k_means_cost(points, centers):
    """Assign every point to its nearest center and return the k-means cost.

    Parameters
    ----------
    points : 2-D array, one point per row.
    centers : 2-D array, one center per row (same number of columns as points).

    Returns
    -------
    labels : 1-D int array; ``labels[i]`` is the row index in ``centers`` of
        the center closest to ``points[i]``.
    cost : sum over all points of the squared Euclidean distance to the
        nearest center.
    """
    # Ask sklearn for squared distances directly: taking a sqrt and then
    # squaring again is wasted work and needlessly loses precision.
    sq_dist = euclidean_distances(points, centers, squared=True)
    labels = np.argmin(sq_dist, axis=1)
    return labels, np.min(sq_dist, axis=1).sum()

# return k means cost given labels of points
def kmeans_cost_label(points, labels, d, num_labels):
    """Return the k-means cost of the clustering induced by ``labels``.

    Every label in ``range(num_labels)`` that owns at least one point gets a
    center at the mean of its points; labels with no points are skipped.
    The cost is then evaluated by ``k_means_cost``. Each point is charged its
    squared distance to the *nearest* of those centers, which is not
    necessarily the center of its own label.

    Returns
    -------
    The ``(labels, cost)`` pair produced by ``k_means_cost``. The returned
    labels index the non-empty centers in increasing label order, not the
    original label ids.
    """
    centers = np.zeros((num_labels, d))
    nonempty = []
    for label in range(num_labels):
        members = np.flatnonzero(labels == label)
        if members.size == 0:
            # An empty label contributes no center.
            continue
        centers[label] = np.average(points[members], axis=0)
        nonempty.append(label)
    return k_means_cost(points, centers[nonempty])
       

# k means ++ given distance matrix
def k_means_pp(X, k, distances, n):
    """Seed ``k`` centers by k-means++ sampling on a precomputed distance matrix.

    The first center is drawn uniformly. Each later center is drawn with
    probability proportional to its current distance to the nearest chosen
    center. Already-chosen points are never re-drawn. This matches, in
    distribution, the original rejection loop, but is guaranteed to terminate.

    Parameters
    ----------
    X : unused; kept only for interface compatibility.
    k : number of centers to choose; must not exceed ``n``.
    distances : (n, n) array; ``distances[i, j]`` is the sampling weight of
        point ``i`` relative to center ``j``. NOTE(review): standard k-means++
        samples proportional to *squared* distance, so the caller presumably
        passes squared distances. Confirm at the call site.
    n : number of points (rows of ``distances``).

    Returns
    -------
    1-D int array: for each point, the position (0..k-1) in the chosen center
    list of its nearest center.

    Raises
    ------
    ValueError : if ``k > n``. The original looped forever in this case.
    """
    if k > n:
        raise ValueError("k_means_pp: k must not exceed the number of points n")

    first = int(np.random.choice(n))  # scalar draw; int(size-1 array) is deprecated
    C = [first]
    # Running minimum distance from every point to its closest chosen center.
    dist_to_centers = np.array(distances[:, first], dtype=float)

    for _ in range(k - 1):
        weights = dist_to_centers.copy()
        weights[C] = 0.0  # never re-pick an existing center
        total = weights.sum()
        if total > 0:
            sampled_p = int(np.random.choice(n, p=weights / total))
        else:
            # All remaining mass is zero (e.g. duplicate points). Pick uniformly
            # among unchosen points instead of dividing by zero (NaN probabilities).
            candidates = np.setdiff1d(np.arange(n), C)
            sampled_p = int(np.random.choice(candidates))
        C.append(sampled_p)
        dist_to_centers = np.minimum(dist_to_centers, distances[:, sampled_p])

    # Given the centers, label each point by its closest one.
    labels = np.argmin(distances[:, C], axis=1)
    return labels

def distance(p1, p2):
    """Return the squared Euclidean distance between ``p1`` and ``p2`` (no square root)."""
    return np.square(p1 - p2).sum()

# k means ++ with no distance matrix
def kpp2(data, k, d):
    '''
    Choose k initial centroids with k-means++ seeding (no distance matrix needed).

    The first centroid is a uniformly random data point. Each further centroid
    is sampled with probability proportional to a point's squared Euclidean
    distance to its nearest *already chosen* centroid. Points that are already
    centroids are never re-drawn.

    inputs:
        data - numpy array of data points, one point per row
        k - number of clusters (must not exceed the number of rows of data)
        d - number of columns of each point
    returns:
        (k, d) float array whose rows are copies of k distinct rows of data
    raises:
        ValueError if k exceeds the number of data points (the original
        looped forever in that case)
    '''
    n = data.shape[0]
    if k > n:
        raise ValueError("kpp2: k must not exceed the number of data points")

    centroids = np.zeros((k, d))
    centroids_idx = [np.random.randint(n)]
    centroids[0, :] = data[centroids_idx[0], :]

    # Squared distance from every point to its nearest chosen centroid.
    # Only rows that have actually been filled are considered. The not-yet-
    # filled zero rows of `centroids` must not act as phantom centers at the
    # origin (the original inner loop scanned all k rows).
    min_dist = np.sum((data - centroids[0]) ** 2, axis=1)

    for c_id in range(1, k):
        weights = min_dist.astype(float)  # copy, so the mask below is local
        weights[centroids_idx] = 0.0
        total = weights.sum()
        if total > 0:
            # Sample proportionally to distance (not "the farthest point").
            sampled_center = int(np.random.choice(n, p=weights / total))
        else:
            # Every remaining point coincides with a centroid. Pick uniformly
            # among unchosen rows instead of dividing by zero.
            remaining = np.setdiff1d(np.arange(n), centroids_idx)
            sampled_center = int(np.random.choice(remaining))
        centroids_idx.append(sampled_center)
        centroids[c_id, :] = data[sampled_center, :]
        min_dist = np.minimum(
            min_dist, np.sum((data - centroids[c_id]) ** 2, axis=1)
        )
    return centroids
        

# noisy oracle
def noisy_oracle(true_labels, index, num_labels, prob_error):
    """Return the label of point ``index``, corrupted with probability ``prob_error``.

    When ``random.random() <= prob_error`` the answer is a label drawn
    uniformly from ``0 .. num_labels - 1``, which may coincide with the true
    one. Otherwise ``true_labels[index]`` is returned unchanged. The second
    random draw happens only in the corrupted case.
    """
    corrupted = random.random() <= prob_error
    if not corrupted:
        return true_labels[index]
    return random.randint(0, num_labels - 1)


# faster version of algorithm 2 using numba jit, tested with this version
@jit(nopython=True)
def algo2new(points, eps):
    """Robust, outlier-trimmed estimate of the mean of 1-D ``points`` (numba-jitted).

    With ``n <= 10`` points the plain mean is returned. Otherwise the result is
    the average of 25 randomized trials. Each trial:
      1. shuffles the points and splits them into X1 (first n//2) and X2 (rest);
      2. sorts X1 and finds the shortest interval [a, b] spanning
         ``counter = int((1 - 5*eps) * (n//2))`` consecutive sorted X1 values;
      3. adds the mean of the X2 values lying in [a, b] to the running total.
    If ``counter == 1`` a trial simply uses the mean of X2.

    NOTE(review): a trial whose interval contains no X2 value contributes 0.0
    but is still counted in the division by 25, which biases the estimate
    toward zero.
    NOTE(review): for eps >= 0.2, ``counter <= 0`` and the slices below no
    longer describe a window of ``counter`` points. Callers presumably keep
    eps well below 0.2; confirm.
    NOTE(review): under numba, np.random here presumably draws from numba's own
    RNG state, so np.random.seed() called outside jitted code may not make this
    reproducible. Confirm against the numba docs.
    """
    
    n = len(points)
    
    # Too few points to split meaningfully: fall back to the plain mean.
    if n <= 10:
        return points.mean()
    
    to_return = 0.0
    for i in range(25):
        # Rebinds `points` to a shuffled copy; each trial reshuffles the previous one.
        points = np.random.permutation(points)
        X1 = points[:n//2]
        X2 = points[n//2:]
        X1 = np.sort(X1)

        # Number of consecutive sorted X1 values the interval must cover.
        counter = int((1-5*eps)*(n//2))
        
        if counter == 1:
            to_return += X2.mean()
        else:
            # X1_left[i] / X1_right[i] are the ends of the window X1[i : i+counter].
            X1_left = X1[:-counter+1]
            X1_right = X1[counter-1:]

            # Narrowest window (first one on ties) defines [a, b].
            good_indx = np.argmin(X1_right-X1_left)
            a = X1_left[good_indx]
            b = X1_right[good_indx]
            to_index = np.where((a <= X2) & (X2 <= b))[0]
            if len(to_index) == 0:
                # No X2 value in [a, b]: this trial contributes 0.0 (see NOTE above).
                to_return += 0.0
            else:
                to_return += X2[to_index].mean()
    return to_return/25.0
    


# algorithm 2 from paper without jit, not tested
def algo2(points, eps):
    """Robust, outlier-trimmed estimate of the mean of 1-D ``points`` (pure NumPy).

    This mirrors ``algo2new`` step for step, so the two agree up to randomness:

    * ``n <= 10``: return the plain mean (NaN for an empty input).
    * otherwise, for each of 25 trials: shuffle, split into X1 (first n//2)
      and X2 (rest), sort X1, and find the shortest interval [a, b] spanning
      ``counter = int((1 - 5*eps) * (n//2))`` consecutive sorted X1 values.
      Add the mean of the X2 values inside [a, b]. A trial whose interval
      captures no X2 value contributes 0.0. When ``counter == 1`` the trial
      uses the plain X2 mean.
    * return the total divided by 25.

    Parameters
    ----------
    points : 1-D array-like of numbers.
    eps : trimming parameter. ``counter`` is clamped to ``[1, n//2]``. When eps
        is large enough that counter would drop below 1 (always for
        eps >= 0.2), no trimming is done.
    """
    points = np.asarray(points, dtype=float)
    n = len(points)
    # Same small-sample cut-off as algo2new (this version used n < 10 before).
    if n <= 10:
        return float(np.mean(points))

    half = n // 2
    counter = int((1 - 5 * eps) * half)
    # Keep the window meaningful. Previously counter <= 0 produced a negative
    # index (X1[i - 1]) that silently wrapped around to the end of X1.
    counter = min(max(counter, 1), half)

    to_return = 0.0
    for _ in range(25):
        # Randomly partition the points into two halves.
        shuffled = np.random.permutation(points)
        X1 = np.sort(shuffled[:half])
        X2 = shuffled[half:]

        if counter == 1:
            to_return += X2.mean()
            continue

        # widths[i] is the spread of the window X1[i : i + counter]. argmin
        # returns the first narrowest one, like the original strict-< scan.
        widths = X1[counter - 1:] - X1[:half - counter + 1]
        start = int(np.argmin(widths))
        a, b = X1[start], X1[start + counter - 1]

        inside = X2[(a <= X2) & (X2 <= b)]
        # NOTE(review): an empty window contributes 0.0 (mirrors algo2new),
        # which biases the average toward zero.
        if inside.size:
            to_return += inside.mean()
    return to_return / 25.0

# main algo of paper
def algo1(points, oracle_labels, k, eps):
    """Estimate one center per oracle label, coordinate by coordinate.

    For up to ``k`` rounds, take the most frequent oracle label not yet
    processed (ties go to the smallest label, via ``argmax`` of
    ``bincount``). Gather the points carrying that label and estimate each
    coordinate of its center with ``algo2new`` using the same ``eps``.
    Rows of the result for labels that never occur stay zero.

    Assumes ``oracle_labels`` is a 1-D array of non-negative ints smaller
    than ``k``. ``bincount`` rejects negative values, and each label is used
    as a row index into the ``(k, d)`` result.

    Returns
    -------
    (k, d) float array of estimated centers, where ``d = points.shape[1]``.
    """
    _, num_dims = points.shape
    centers = np.zeros((k, num_dims))
    processed = []

    for _ in range(k):
        remaining = oracle_labels[~np.isin(oracle_labels, processed)]
        if remaining.size == 0:
            # Every label present has already been handled.
            continue

        label = np.argmax(np.bincount(remaining))
        members = points[oracle_labels == label]

        # Robustly estimate the center one dimension at a time.
        for dim in range(num_dims):
            centers[label, dim] = algo2new(members[:, dim], eps)

        processed.append(label)
    return centers

#sampling baseline
def sampling_baseline(points, labels, d, num_labels, rate=50):
    """Baseline: estimate each cluster center from a random sample of its points.

    For every label in ``range(num_labels)``, draw ``int(rate / 100 * size)``
    of that label's points uniformly without replacement and use their mean
    as the center. Labels whose sample size rounds down to zero get no
    center at all.

    Returns
    -------
    The ``(labels, cost)`` pair from ``k_means_cost`` evaluated on all points
    against the sampled centers.
    """
    fraction = rate / 100.0
    centers = np.zeros((num_labels, d))
    sampled_labels = []
    for label in range(num_labels):
        members = np.flatnonzero(labels == label)
        sample_size = int(fraction * len(members))
        if sample_size <= 0:
            continue
        chosen = np.random.permutation(members)[:sample_size]
        centers[label] = np.average(points[chosen], axis=0)
        sampled_labels.append(label)
    return k_means_cost(points, centers[sampled_labels])



