import numpy as np
import math
from scipy.spatial.distance import squareform, pdist
from matplotlib import pyplot as plt

def neighborsMat(D, r):
    '''
    Build the radius-r neighborhood graph from a distance matrix.

    D is indexed as D[i, j]; entry (i, j) of the returned n x n boolean
    matrix (n = D.shape[0]) is True exactly when D[i, j] <= r.
    '''
    size = D.shape[0]
    withinRadius = np.zeros((size, size), dtype=bool)
    # A single vectorized comparison fills every row at once.
    withinRadius[...] = D <= r
    return withinRadius

def ballCarving(D, r, C1=5, C2=10):
    '''
    Partition the points into clusters by greedy ball carving at radius r,
    followed by growing shells around each carved ball.

    Phase 1 (carving): repeatedly pick, among points not yet ruled out, the
    one with the most points within distance r, make it a center, assign its
    r-ball to it, and rule out every point within C1*r of it as a future
    center.
    Phase 2 (shells): for each center, grow the ball in steps of
    r / log2(n) while the radius stays below 2*r. Every ball tried is
    assigned to the center; growth stops at the first radius whose ball holds
    fewer than twice as many points as the previously accepted ball.
    Phase 3 (leftovers): each still-unassigned point goes to the first
    center, in selection order, that lies within C2*r of it.

    D  : (n, n) array of pairwise distances, indexed D[i, j]
    r  : carving radius
    C1 : centers are chosen more than C1*r apart
    C2 : leftover points are attached to a center within C2*r

    Returns (centers, clustering): centers is the list of chosen point
    indices in selection order; clustering is a length-n float array whose
    i-th entry is the index of the center that point i was assigned to.

    Raises AssertionError if some point ends up unassigned (possible when
    C2 < C1 or r is negative).
    '''
    D = np.asarray(D)
    n = D.shape[0]
    if n == 0:
        return [], -1 * np.ones(0)

    # Thresholding D directly keeps this function self-contained.
    nbrMat = D <= r
    farMat = D <= C1 * r
    numNbrs = np.sum(nbrMat, axis=0)
    centers = []
    centerSizes = []
    clustering = -1 * np.ones(n)

    while np.max(numNbrs) >= 0:  # finding initial balls
        center = int(np.argmax(numNbrs))  # maximizes #points within r
        centers.append(center)
        clustering[nbrMat[center]] = center
        centerSizes.append(np.sum(nbrMat[center]))
        numNbrs[farMat[center]] = -1  # outlawing future centers within C1*r
        # Always retire the chosen center itself so the loop terminates even
        # if D[center, center] > C1*r (e.g. negative r).
        numNbrs[center] = -1

    if n > 1:  # log2(1) == 0 would make the step undefined
        # With log2(n) steps between r and 2r the point count cannot double
        # at every step, so some shell boundary is always accepted.
        delta = r / math.log2(n)
        for i, center in enumerate(centers):  # expanding shells
            radius = r + delta
            while radius < 2 * r:
                pointsInShell = D[center, :] <= radius
                clustering[pointsInShell] = center
                shellSize = np.sum(pointsInShell)
                if shellSize < 2 * centerSizes[i]:
                    break
                centerSizes[i] = shellSize
                radius += delta

    for center in centers:  # assigning remaining points
        clustering[(D[center, :] <= C2 * r) & (clustering == -1)] = center

    assert np.min(clustering) >= 0, "some points were left unclustered"
    return centers, clustering


def kCenter(D, k):
    '''
    Choose k centers by greedy farthest-first traversal (the classical
    2-approximation for k-center when D is a metric).

    Point 0 is the first center; each subsequent center is the point
    farthest from the centers chosen so far (ties go to the lowest index).
    If k exceeds the number of distinct points, indices repeat.

    D : (n, n) array of pairwise distances, indexed D[i, j]
    k : number of centers, at least 1

    Returns (centers, r): centers is a list of k point indices and r is the
    covering radius, i.e. the largest distance from any point to its nearest
    center among all k chosen centers.

    Raises ValueError if k < 1.
    '''
    if k < 1:
        raise ValueError("k must be at least 1")

    centers = [0]
    # Running distance from every point to its nearest chosen center.
    # np.minimum allocates a new array, so D itself is never modified.
    distToCenters = np.asarray(D[0])
    for _ in range(1, k):
        farthest = int(np.argmax(distToCenters))
        centers.append(farthest)
        distToCenters = np.minimum(distToCenters, D[farthest])

    # The radius accounts for every center, including the last one added.
    r = np.max(distToCenters)
    return centers, r


def approxIP(D, k, C1=5, C2=10, C3=1/100):
    '''
    Cluster the points around k centers; intended to yield an O(log n)
    approximate IP-stable clustering.

    D is either an n x n distance matrix or a 1-D condensed distance vector,
    which is expanded with squareform. The k centers and their covering
    radius come from kCenter; the points are then grouped by ballCarving at
    radius C3 * (covering radius), with C1 and C2 passed through, and every
    carved group is merged into the k-center nearest to that group's ball
    center.

    Returns (centers, clustering): the k center indices and a length-n float
    array giving, for each point, the index of the center it belongs to.
    '''
    # A 1-D input is a condensed distance vector; expand it to n x n.
    distMat = squareform(D) if len(D.shape) == 1 else D
    numPoints = distMat.shape[0]

    kCenters, coverRadius = kCenter(distMat, k)
    carvedCenters, carvedLabels = ballCarving(distMat, coverRadius*C3, C1=C1, C2=C2)

    labels = -np.ones(numPoints)
    for carved in carvedCenters:
        # Attach the whole carved ball to the k-center nearest its center.
        nearestIdx = np.argmin(distMat[carved, kCenters])
        labels[carvedLabels == carved] = kCenters[nearestIdx]

    assert(np.min(labels) >= 0)
    return kCenters, labels


def test():
    '''
    Visual smoke test: draw 1000 uniform random points in the unit square,
    run ballCarving and approxIP on them, and show one scatter plot per
    result with each cluster in its own color and its center as a star.
    '''
    # TODO: create some artificial data to test
    numPoints = 1000
    dim = 2
    points = np.random.uniform(size=numPoints*dim).reshape(numPoints, dim)
    D = squareform(pdist(points, metric='euclidean'))
    # Radius-0.1 neighborhoods; only useful for an ad-hoc debug scatter of
    # points[nbrMat[0, :]] around point 0.
    nbrMat = neighborsMat(D, 0.1)

    palette = 'bgrcmk'

    def drawClusters(centerList, labels, alpha, **figKwargs):
        # One figure per clustering; colors cycle through the palette.
        plt.figure(**figKwargs)
        for idx, c in enumerate(centerList):
            shade = palette[idx % 6]
            members = labels == c
            plt.scatter(points[members,0], points[members,1], marker='o', color=shade, alpha=alpha)
            plt.scatter(points[c,0], points[c,1], marker='*', color=shade)

    kCenters, radius = kCenter(D, 5)

    carveCenters, carveLabels = ballCarving(D, radius/100)
    drawClusters(carveCenters, carveLabels, 0.3)

    ipCenters, ipLabels = approxIP(D, 5)
    drawClusters(ipCenters, ipLabels, 0.5, figsize=(10,8))
    plt.show()

# Run the plotting demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    test()
