import numpy as np

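#helper functions
#generates clusters of data: each point is its cluster's mean plus a random unit direction scaled by a random magnitude
#returns (means, data) where data has shape (total number of points) x dimensions, with each cluster's points stored contiguously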
def data_generator(num_clusters, dimensions, points_per_cluster,radius=None,means=None,distribution='halfnormal'):
    """Generate synthetic clustered data.

    Every point is ``means[i] + magnitude * direction``:

    * ``direction`` is a unit vector (a normalised standard-normal draw).
    * ``magnitude = sqrt(s)``, where ``s`` is drawn from ``distribution``.

    Within each cluster the magnitudes are rescaled so that their sample mean
    equals ``radius[i]``. The average distance of a cluster's points from its
    mean is therefore exactly ``radius[i]``.

    Parameters
    ----------
    num_clusters : int
        Number of clusters.
    dimensions : int
        Dimensionality of each point.
    points_per_cluster : int or sequence of int
        Number of points in each cluster. A scalar is used for every cluster;
        a sequence must have length ``num_clusters``.
    radius : float or sequence of float, optional
        Target mean distance from the cluster mean, per cluster (a scalar is
        used for every cluster). Defaults to 1.
    means : array-like of shape (num_clusters, dimensions), optional
        Cluster means. Defaults to uniform random values in [0, 1)^dimensions.
    distribution : {'halfnormal', 'exponential', 'lomax'}
        Distribution of the squared magnitudes before rescaling.

    Returns
    -------
    means : the cluster means used (the ``means`` argument if it was given).
    data : ndarray of shape (sum(points_per_cluster), dimensions)
        Cluster ``i`` occupies a contiguous run of rows, in cluster order.

    Raises
    ------
    ValueError
        If ``distribution`` is not recognised, or if ``points_per_cluster`` or
        ``radius`` cannot be broadcast to ``num_clusters`` entries.
    """
    # Samplers for the squared magnitudes, each returning an (n, 1) column.
    samplers = {
        'halfnormal': lambda n: np.abs(np.random.randn(n, 1)),
        'exponential': lambda n: np.random.exponential(scale=1, size=(n, 1)),
        # numpy's "pareto" actually draws from the Lomax (Pareto II) distribution
        'lomax': lambda n: np.random.pareto(a=1, size=(n, 1)),
    }
    if distribution not in samplers:
        raise ValueError(
            "unknown distribution %r; expected one of %s" % (distribution, sorted(samplers))
        )

    # Broadcasting lets a scalar apply to every cluster. It also rejects sequences
    # whose length differs from num_clusters; previously a longer sequence left
    # uninitialised rows at the end of the np.empty buffer.
    counts = np.broadcast_to(np.asarray(points_per_cluster), (num_clusters,))
    radii = np.broadcast_to(np.asarray(1.0 if radius is None else radius, dtype=float), (num_clusters,))

    if means is None:
        means = np.random.rand(num_clusters,dimensions) #random mean for each cluster in [0,1)^dimensions (one per row)

    offsets = np.concatenate([[0], np.cumsum(counts)]).astype(int)  # row where each cluster starts
    data = np.empty((offsets[-1], dimensions))
    for i in range(num_clusters):
        num_samples = counts[i]
        if num_samples == 0:
            continue  # nothing to place; also avoids taking the mean of an empty array
        directions = np.random.randn(num_samples, dimensions)
        directions /= np.linalg.norm(directions, axis=1, keepdims=True)  # unit length

        magnitudes = np.sqrt(samplers[distribution](num_samples))
        # Rescale so the mean distance from the cluster mean is exactly radii[i].
        magnitudes *= radii[i] / np.mean(magnitudes)
        data[offsets[i]:offsets[i + 1]] = means[i] + magnitudes * directions
    return means,data

#update cluster assignments for a new center (a point switches to the newest center if it's now the closest)
def cluster_reassignment(X,C,nearest_cluster_assignments,potential_center):
    """Point every sample that is strictly closer to ``potential_center`` at it.

    The candidate receives index ``C.shape[0]``, which is the index it will have
    once the caller appends it to ``C``. Ties keep the existing assignment.

    Note: ``nearest_cluster_assignments`` is modified in place, and the same
    array object is returned.
    """
    new_index = C.shape[0]
    dist_to_assigned = ((X - C[nearest_cluster_assignments]) ** 2).sum(axis=1)
    dist_to_candidate = ((X - potential_center) ** 2).sum(axis=1)
    closer = dist_to_candidate < dist_to_assigned
    nearest_cluster_assignments[closer] = new_index
    return nearest_cluster_assignments

#get the cost if we were to choose this candidate as a seed
def potential_cost(X,C,nearest_cluster_assignments,potential_center):
    """Return the total k-means cost if ``potential_center`` were added as a center.

    Parameters
    ----------
    X : ndarray (N, D)
        Data points, one per row.
    C : ndarray (K, D) or None
        Centers chosen so far. ``None`` or an empty array means no center has
        been chosen yet.
    nearest_cluster_assignments : ndarray (N,) of int or None
        Index into ``C`` of each point's nearest center. ``None`` makes
        ``cost`` recompute nearest distances directly.
    potential_center : ndarray (D,)
        The candidate center.

    Returns
    -------
    Sum over points of the squared distance to whichever is closer: the point's
    current center or ``potential_center``.
    """
    new_differences = np.sum((X - potential_center) ** 2, axis=1)
    # No centers yet: the candidate would serve every point on its own.
    if C is None or len(C) == 0:
        return np.sum(new_differences)

    current_differences = cost(X, C, nearest_cluster_assignments)
    # Each point keeps whichever of its current center and the candidate is closer.
    return np.sum(np.minimum(new_differences, current_differences))

#per-point squared distance from data X to its nearest center in C (sum it to get the kmeans objective)
def cost(X, C, nearest_cluster_assignments):
    """Return each point's squared distance to its nearest center.

    The result is an (N,) array of per-point costs, not their sum; summing it
    gives the k-means objective.

    Parameters
    ----------
    X : ndarray (N, D)
        Data points, one per row.
    C : array-like (K, D), or a single center of shape (D,)
        Centers.
    nearest_cluster_assignments : ndarray (N,) of int, or None / empty
        Index into ``C`` of each point's nearest center. When it is ``None``
        or empty, the nearest center is found by comparing every point with
        every center.
    """
    C = np.atleast_2d(C)  # lets a single 1-D center be passed directly
    if nearest_cluster_assignments is None or len(nearest_cluster_assignments) == 0:
        # Assignments unknown: broadcast (N, 1, D) - (1, K, D). This materialises
        # an N x K x D intermediate array.
        squared_distances = np.sum((X[:, np.newaxis, :] - C[np.newaxis, :, :]) ** 2, axis=2)
        return np.min(squared_distances, axis=1)
    # Trusts the caller to keep each assignment pointing at that point's nearest center.
    return np.sum((X - C[nearest_cluster_assignments]) ** 2, axis=1)

#generate 1 seed
def oneSeeder(data,centers,num_candidates,cluster_assignments):
    """Choose one more seed using greedy D^2 (k-means++ style) sampling.

    First seed: ``num_candidates`` indices are drawn uniformly (duplicates
    removed), and the one whose cost as the sole center is lowest is kept.

    Later seeds:

    1. ``num_candidates`` points are drawn with probability proportional to
       their current per-point cost (squared distance to the nearest center).
    2. If more than one was drawn, the candidate that minimises the resulting
       total cost is kept.
    3. The chosen point is appended to ``centers``, and ``cluster_assignments``
       is updated.

    Parameters
    ----------
    data : ndarray (N, D)
        Data points.
    centers : ndarray (K, D) or None
        Centers chosen so far. ``None`` or empty means this is the first seed.
    num_candidates : int
        Number of candidates sampled per step; must be >= 1.
    cluster_assignments : ndarray (N,) of int or None
        Index of each point's nearest center. For later seeds this array is
        updated in place (via ``cluster_reassignment``) and also returned.

    Returns
    -------
    (greedy_choice, centers, cluster_assignments)
        ``greedy_choice`` is the row index into ``data`` of the new seed.
        ``centers`` is a new array containing the added seed.

    Raises
    ------
    ValueError
        If ``num_candidates < 1``.
    """
    if num_candidates < 1:
        raise ValueError("num_candidates must be at least 1")

    if centers is None or len(centers) == 0:  # first seed edge case
        candidates = np.unique(np.random.randint(low=0,high=data.shape[0],size=num_candidates))
        potential_costs = np.array([potential_cost(data,None,None,data[candidate]) for candidate in candidates])  #cost if we were to choose candidate center
        greedy_choice = candidates[np.argmin(potential_costs)]
        centers = data[[greedy_choice]].astype(float)  # (1, D) float copy of the chosen point
        cluster_assignments = np.zeros(data.shape[0],dtype=int)
        return greedy_choice,centers,cluster_assignments

    price = cost(data,centers,cluster_assignments)
    total = price.sum()
    if total > 0:
        proportional_probability = price / total
    else:
        # Every point already coincides with a center, so there is no D^2 signal.
        # Fall back to uniform sampling instead of dividing by zero.
        proportional_probability = np.full(price.shape[0], 1.0 / price.shape[0])
    cumulative_probability = np.cumsum(proportional_probability)

    samples = np.random.rand(num_candidates)
    # Inverse-CDF lookup: the first index whose cumulative probability exceeds the
    # sample. side='right' never selects a zero-probability point, even when a
    # sample lands exactly on a CDF step. Clipping to the last positive-probability
    # index guards against cumsum rounding leaving the final value just below a
    # sample; this used to raise IndexError.
    last_valid = np.flatnonzero(proportional_probability > 0)[-1]
    candidates = np.minimum(np.searchsorted(cumulative_probability, samples, side='right'), last_valid)

    if num_candidates>1:
        potential_costs = np.array([potential_cost(data,centers,cluster_assignments,data[candidate]) for candidate in candidates])  #cost if we were to choose candidate center
        greedy_choice = candidates[np.argmin(potential_costs)]  #greedily choose the candidate that minimizes the next cost
    else:
        greedy_choice = candidates[0]  #with one candidate there is nothing to compare
    cluster_assignments = cluster_reassignment(data,centers,cluster_assignments,data[greedy_choice])  #points closer to the new seed switch to it (in place)
    centers = np.append(centers,[data[greedy_choice]],axis=0)  #add greedy choice to the list of chosen centers
    return greedy_choice,centers,cluster_assignments

def seed_quality_metric(covered_clusters,chosen_cluster):
    """Score a seed: 1 if its cluster had no seed yet, otherwise 0.

    ``covered_clusters`` is indexed by cluster id; an entry equal to 0 means
    that cluster is not yet covered.
    """
    return 1 if covered_clusters[chosen_cluster] == 0 else 0