import numpy as np
from helper_functions import *
from concurrent.futures import ThreadPoolExecutor, as_completed

# Per-repetition worker (run_single_ri) followed by the threaded driver (threaded_experiment)
def run_single_ri(r, dimension, points_per_cluster, num_clusters, distribution, candidates, R):
    """Run one repetition of the seeding experiment on freshly generated data.

    A data set is drawn with ``data_generator`` and then ``num_clusters``
    centers are seeded one at a time with ``oneSeeder``. After each seeding
    step the coverage metric and the summed clustering cost are recorded.

    Parameters
    ----------
    r
        Repetition identifier; not used in the computation, only echoed back
        so the caller can match the result to its repetition.
    dimension
        Forwarded to ``data_generator`` as ``dimensions``.
    points_per_cluster
        Per-cluster point counts; forwarded to ``data_generator`` and also
        cumulatively summed here to obtain cluster boundaries.
    num_clusters
        Number of clusters to generate and number of seeding steps to run.
    distribution
        Forwarded to ``data_generator``.
    candidates
        Indexable; ``candidates[c]`` is passed to ``oneSeeder`` as
        ``num_candidates`` on step ``c``.
    R
        Forwarded to ``data_generator`` as ``radius``.

    Returns
    -------
    tuple
        ``(r, coverage, objective)`` where ``coverage`` and ``objective`` are
        length-``num_clusters`` arrays holding one value per seeding step.
    """
    _generated_means, data = data_generator(
        num_clusters=num_clusters,
        dimensions=dimension,
        points_per_cluster=points_per_cluster,
        radius=R,
        means=None,
        distribution=distribution,
    )

    # One slot per seeding step.
    coverage_by_step = np.zeros(num_clusters)
    objective_by_step = np.zeros(num_clusters)
    # How many times each cluster has been picked so far.
    hits_per_cluster = np.zeros(num_clusters)

    # Exclusive end index of every cluster, so clusters may differ in size.
    # NOTE(review): this assumes the generated rows are laid out cluster by
    # cluster in order -- confirm against data_generator.
    cluster_bounds = np.cumsum(points_per_cluster)

    centers_so_far = None
    point_assignments = None
    for step in range(num_clusters):
        picked, centers_so_far, point_assignments = oneSeeder(
            data,
            centers=centers_so_far,
            num_candidates=candidates[step],
            cluster_assignments=point_assignments,
        )

        # First cluster whose end boundary lies beyond the picked index.
        owning_cluster = np.where(picked < cluster_bounds)[0][0]

        # The metric is evaluated against the counts *before* this pick is added.
        coverage_by_step[step] = seed_quality_metric(hits_per_cluster, owning_cluster)
        hits_per_cluster[owning_cluster] += 1

        step_costs = cost(data, centers_so_far, point_assignments)
        objective_by_step[step] = np.sum(step_costs)

    return r, coverage_by_step, objective_by_step

def threaded_experiment(repetitions=100, dimension=8, num_clusters=8,
               distribution='halfnormal', points_per_cluster=None, candidates=None, R=None):
    """Run ``run_single_ri`` ``repetitions`` times on a thread pool and average the results.

    Parameters
    ----------
    repetitions
        Number of independent repetitions to submit.
    dimension
        Data dimension. If it differs from ``num_clusters`` it is overridden
        to ``num_clusters`` (a notice is printed).
    num_clusters
        Number of clusters, and number of seeding steps per repetition.
    distribution
        Forwarded unchanged to ``run_single_ri``.
    points_per_cluster
        Per-cluster sizes. Defaults to random integers in ``[64, 256)``,
        drawn once and shared by all repetitions.
    candidates
        Array-like of length ``num_clusters`` giving the candidate count for
        each seeding step. Defaults to 2 for every step. Lists and tuples are
        accepted as well as NumPy arrays.
    R
        Per-cluster radius. Defaults to ``2 * np.random.rand(num_clusters)``,
        drawn once and shared by all repetitions.

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``(2, num_clusters)``: row 0 is the per-step mean of
        the coverage metric over repetitions, row 1 the per-step mean of the
        objective. ``None`` is returned (after printing a message) when
        ``candidates`` does not have ``num_clusters`` entries.
    """
    if candidates is None:
        candidates = np.tile(2, num_clusters)
    else:
        # Normalise list/tuple/scalar input so the length check below cannot
        # fail with AttributeError on objects without ``.shape``.
        candidates = np.atleast_1d(candidates)
        if candidates.shape[0] != num_clusters:
            print(f'Invalid candidate set. Need {num_clusters} candidates, but have {candidates.shape[0]}')
            return None

    if num_clusters != dimension:
        print('Altering dimension so 1 cluster per dimension')
        dimension = num_clusters

    if points_per_cluster is None:
        points_per_cluster = np.random.randint(low=64, high=256, size=num_clusters)

    if R is None:
        R = 2 * np.random.rand(num_clusters)  # uniform (0,2)

    # Preallocate output arrays, indexed [seeding step, repetition].
    coverage_counts = np.zeros((num_clusters, repetitions))
    kmeans_objective = np.zeros((num_clusters, repetitions))

    # Report progress roughly every 1%. The step is clamped to at least 1 so
    # that repetitions < 100 does not cause a modulo-by-zero.
    progress_step = max(1, repetitions // 100)
    completed = 0
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(run_single_ri, r, dimension, points_per_cluster,
                            num_clusters, distribution, candidates, R)
            for r in range(repetitions)
        ]

        # Results arrive in completion order; each one carries its own
        # repetition index, which selects the column to fill.
        for future in as_completed(futures):
            r, coverage, obj = future.result()
            coverage_counts[:, r] = coverage
            kmeans_objective[:, r] = obj
            completed += 1
            if completed % progress_step == 0:
                print(f"Completed {100*completed/repetitions:.2f}%", end='\r')

    coverage_probability_per_cluster = np.mean(coverage_counts, axis=1)
    average_objective_per_cluster = np.mean(kmeans_objective, axis=1)
    return np.asarray((coverage_probability_per_cluster, average_objective_per_cluster))

if __name__ == '__main__':
    # Script entry point: run a single repetition with 8 clusters in 8
    # dimensions. The result is discarded; nothing is printed here.
    demo_clusters = 8
    demo_sizes = np.random.randint(low=64, high=256, size=demo_clusters)
    # ceil(log2(k)) + 2 candidates for every seeding step.
    demo_candidates = np.tile(int(np.ceil(np.log2(demo_clusters))) + 2, demo_clusters)
    demo_radii = 2 * np.random.rand(demo_clusters)
    run_single_ri(
        r=1,
        dimension=demo_clusters,
        points_per_cluster=demo_sizes,
        num_clusters=demo_clusters,
        distribution='halfnormal',
        candidates=demo_candidates,
        R=demo_radii,
    )