import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
import random

random.seed(20)  # Set seed to any integer
np.random.seed(20)

def sample_uniform(S, k):
    """
    Sample k points uniformly at random from S (with replacement).

    Parameters:
        S (list or np.ndarray): Candidate points, shape (n, d)
        k (int): Number of points to sample

    Returns:
        list: k sampled points from S
    """
    S = np.array(S)

    if len(S) == 0 or k == 0:
        return []

    # Assign equal probability to each point
    probs = np.ones(len(S)) / len(S)

    # Sample k points from S with replacement according to uniform distribution
    indices = np.random.choice(len(S), size=k, replace=True, p=probs)
    return S[indices].tolist()

def sample_by_distance(S, C, k):
    """
    Sample k points from S with probability proportional to distance from C.
    
    Parameters:
        S (list or np.ndarray): Candidate points, shape (n, d)
        C (list or np.ndarray): Center points, shape (m, d)
        k (int): Number of points to sample

    Returns:
        list: k sampled points from S
    """
    S = np.array(S)
    C = np.array(C)

    if len(S) == 0 or k == 0:
        return []

    if len(C) == 0:
        # If no centers, sample uniformly
        probs = np.ones(len(S))
    else:
        # Compute distance from each point in S to its nearest center in C
        dists = np.min(np.linalg.norm(S[:, np.newaxis] - C, axis=2), axis=1)
        probs = dists

    # Normalize to get probabilities
    total = np.sum(probs)
    if total == 0:
        probs = np.ones(len(S)) / len(S)  # fallback to uniform
    else:
        probs = probs / total

    # Sample k points from S with replacement according to probs
    indices = np.random.choice(len(S), size=k, replace=True, p=probs)
    return S[indices].tolist()

def count_within_radius(S, C, r):
    """
    Count the number of points in S within distance r from any point in C.

    Parameters:
        S (list or np.ndarray): Points to check, shape (n, d)
        C (list or np.ndarray): Center points, shape (m, d)
        r (float): Radius threshold

    Returns:
        int: Number of points in S within distance r from some center in C
    """
    S = np.array(S)
    C = np.array(C)

    if len(S) == 0 or len(C) == 0:
        return 0 # No points or centers → no coverage

    # Compute distances from each point in S to all points in C
    dists = np.linalg.norm(S[:, np.newaxis] - C, axis=2)

    # Find the minimum distance to a center for each point in S
    min_dists = np.min(dists, axis=1)

    # Count how many distances are <= r
    return np.sum(min_dists <= r)

def remove_within_radius(S, C, r):
    """
    Remove points in S that are within distance r from any point in C.

    Parameters:
        S (list or np.ndarray): Candidate points, shape (n, d)
        C (list or np.ndarray): Center points, shape (m, d)
        r (float): Distance threshold

    Returns:
        list: Filtered list of points in S that are farther than r from all points in C
    """
    S = np.array(S)
    C = np.array(C)

    if len(S) == 0 or len(C) == 0:
        return S.tolist()  # No points or centers → return all of S

    # Compute distances between each point in S and all centers in C
    dists = np.linalg.norm(S[:, np.newaxis] - C, axis=2)
    
    # Minimum distance to any center for each point in S
    min_dists = np.min(dists, axis=1)

    # Filter points whose min distance > r
    filtered_S = S[min_dists > r]

    return filtered_S.tolist()

def assign_labels(S, C):
    """
    Assigns each point in S to the index of the closest center in C.

    Parameters:
    - S (ndarray): An (n, d) array of n points in d-dimensional space.
    - C (ndarray): A (k, d) array of k center points.

    Returns:
    - labels (ndarray): An array of length n where labels[i] is the index j of the closest center C[j] to S[i].
    """
    return np.argmin(np.linalg.norm(S[:, np.newaxis] - C, axis=2), axis=1)

def mettu_plaxton(points, k, c):
    """
    Mettu–Plaxton-style algorithm:
    - Runs O(log n) rounds
    - Samples k points per round using sample_by_distance

    Parameters:
        points (np.ndarray): Array of shape (n, d)
        k (int): Number of samples per round
        constant c

    Returns:
        np.ndarray: Array of shape (k * log n, d) with sampled candidates
    """
    n = len(points)
    log_n = max(1, int(c*np.ceil(np.log2(n))))
    candidates = points.copy()
    centers = []

    for _ in range(log_n):
        new_samples = sample_uniform(candidates, k)
        centers.extend(new_samples)
        r = 0.1
        while(count_within_radius(candidates, centers, r) < len(candidates)/2):
            r = r*1.05
        candidates = remove_within_radius(candidates, centers, r)

    return np.array(centers)

def adaptive_sampling(points, k, c):
    """
    Adaptive sampling algorithm:
    - Runs O(k) rounds
    - Samples 1 points per round using sample_by_distance

    Parameters:
        points (np.ndarray): Array of shape (n, d)
        k (int): Number of samples per round
        constant c

    Returns:
        np.ndarray: Array of shape (k * log n, d) with sampled candidates
    """
    n = len(points)
    candidates = points.copy()
    centers = []

    for _ in range(k*c):
        new_samples = sample_by_distance(candidates, centers, 1)
        centers.extend(new_samples)
    return np.array(centers)

def generate_gaussian_points(n, k, random_seed=None):
    if random_seed is not None:
        np.random.seed(random_seed)

    points = []
    labels = []

    for i in range(k):
        # Randomly generate mean in range [-10, 10] for each Gaussian
        mean = np.random.uniform(-10, 10, size=2)
        
        # Create a random positive-definite covariance matrix
        A = np.random.rand(2, 2)
        cov = np.dot(A, A.T) + np.eye(2) * 0.5  # ensures it's positive-definite

        # Sample n points from this Gaussian
        samples = np.random.multivariate_normal(mean, cov, n)
        points.append(samples)
        labels.extend([i] * n)

    # Stack all the points together
    all_points = np.vstack(points)
    all_labels = np.array(labels)
    return all_points, all_labels

def round_vector_to_nearest_power(v, q):
    """
    Rounds each coordinate of vector v to the nearest power of q.
    
    - Positive values → nearest positive power of q
    - Negative values → nearest negative power of q
    - Zero stays zero

    Parameters:
    - v: 1D NumPy array of real numbers
    - q: base > 1

    Returns:
    - A NumPy array with each value rounded accordingly
    """
    v = np.asarray(v)
    if q <= 1:
        raise ValueError("q must be > 1")

    result = np.zeros_like(v)

    # Non-zero values
    nonzero_mask = v != 0
    abs_vals = np.abs(v[nonzero_mask])
    exponents = np.round(np.log(abs_vals) / np.log(q))
    powers = np.power(q, exponents)

    result[nonzero_mask] = np.sign(v[nonzero_mask]) * powers
    return result

def kmeans_cost(S, L, C):
    """
    Compute the k-means cost for a set of points S, label assignments L, and centers C.
    
    Parameters:
        S (ndarray): An (n, d) array of n points in d dimensions.
        L (ndarray): A (n,) array of integer labels assigning each point to a center.
        C (ndarray): A (k, d) array of k cluster centers.
        
    Returns:
        float: The total k-means cost (sum of squared distances to assigned centers).
    """
    diffs = S - C[L]  # Subtract assigned centers
    squared_distances = np.sum(diffs**2, axis=1)  # Squared L2 norm per point
    return np.sum(squared_distances)

digits = load_digits()
n = len(digits.data)
d = len(digits.data[0])
k = 10
c = 2

all_mp_costs = []
all_as_costs = []
all_rounded_as_costs = []
all_mp_comms = []
all_as_comms = []
all_rounded_as_comms = []

for c in range(11,21):
    #Sample points for coresets
    mp_centers = mettu_plaxton(digits.data, k, c)
    as_centers = adaptive_sampling(digits.data, k, c)
    q = 2

    rounded_as_centers = round_vector_to_nearest_power(as_centers, q)

    mp_labels = assign_labels(digits.data, mp_centers)
    as_labels = assign_labels(digits.data, as_centers)
    rounded_as_labels = assign_labels(digits.data, rounded_as_centers)

    #Construct coresets
    #Give the proper weight to each sample
    
    S_mp = mp_centers[mp_labels]
    S_as = as_centers[mp_labels]
    S_rounded_as = rounded_as_centers[mp_labels]

    kmeans = KMeans(n_clusters=k, init='k-means++', random_state=42)
    kmeans.fit(S_mp)  # S is your (n, d) array of data points

    mp_labels = kmeans.labels_       # labels[i] is the index of the cluster assigned to S[i]
    mp_centers = kmeans.cluster_centers_  # the resulting cluster centers
    mp_cost = kmeans.inertia_
    
    kmeans.fit(S_as)  # S is your (n, d) array of data points

    as_labels = kmeans.labels_       # labels[i] is the index of the cluster assigned to S[i]
    as_centers = kmeans.cluster_centers_  # the resulting cluster centers
    as_cost = kmeans.inertia_
    
    kmeans.fit(S_rounded_as)  # S is your (n, d) array of data points

    rounded_as_labels = kmeans.labels_       # labels[i] is the index of the cluster assigned to S[i]
    rounded_as_centers = kmeans.cluster_centers_  # the resulting cluster centers
    rounded_as_cost = kmeans.inertia_
    
    #Compute costs of coresets

    #rounded_as_cost = kmeans_cost(digits.data, rounded_as_labels, rounded_as_centers)
    #as_cost = kmeans_cost(digits.data, as_labels, as_centers)
    #mp_cost = kmeans_cost(digits.data, mp_labels, mp_centers)

    all_mp_costs.append(mp_cost)
    all_as_costs.append(as_cost)
    all_rounded_as_costs.append(rounded_as_cost)

    all_mp_comms.append(len(mp_centers)*32)
    all_as_comms.append(len(as_centers)*32)
    all_rounded_as_comms.append(len(as_centers)*5)

cs = [11,12,13,14,15,16,17,18,19,20]

##plt.plot(cs, all_mp_comms, marker='o', markerfacecolor='none', label="MP Comms")
##plt.plot(cs, all_as_comms, marker='^', label="AS Comms")
##plt.plot(cs, all_rounded_as_comms, marker='x', label="EAS Comms")
##
### X-axis from 0 to 10 with ticks at every integer
##plt.xlim(11, 20)
##plt.xticks(np.arange(11, 21, 1))  # Ticks from 11 to 20
##
##plt.xlabel('Sampling Coefficient')
##plt.ylabel('Communication (Bits)')
##plt.title('Communication Costs')
##plt.legend()
##plt.grid(True)
##plt.show()

plt.plot(cs, all_mp_costs, marker='o', markerfacecolor='none', label="MP Costs")
plt.plot(cs, all_as_costs, marker='^', label="AS Costs")
plt.plot(cs, all_rounded_as_costs, marker='x', label="EAS Costs")

# X-axis from 11 to 20 with ticks at every integer
plt.xlim(11, 20)
plt.xticks(np.arange(11, 21, 1))  # Ticks from 11 to 20

plt.xlabel('Sampling Coefficient')
plt.ylabel('Cost')
plt.title('Clustering Costs')
plt.legend()
plt.grid(True)
plt.show()
