import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels

# MARK: - Lazy kernel k-means++ initialization
def initialize_kernel_kmeans_plusplus_lazy(X_data, k, kernel_function, random_gen=None):
    """Kernel k-means++ seeding that evaluates only the kernel entries it needs.

    Instead of taking a precomputed Gram matrix, the kernel is evaluated on
    demand: the diagonal K(x, x) for every point, plus the full row/column of
    each selected center. Entries that were never needed remain NaN.

    Args:
        X_data: 2-D array of input points, one point per row.
        k: Number of centers to select.
        kernel_function: Callable ``kernel_function(A, B)`` returning the kernel
            matrix between the rows of ``A`` and the rows of ``B``. It is called
            as ``kernel_function(X_data, X_data[idx])`` with a length-1 index
            array ``idx`` (result must fit an ``(n, 1)`` column) and as
            ``kernel_function(X_data[i:i + 1], X_data[i:i + 1])`` (result must
            hold a single value).
        random_gen: Optional ``numpy.random.Generator``; a fresh default
            generator is created when omitted.

    Returns:
        label_vector: ``(n,)`` int array; entry i is the position in ``C``
            (0..k-1) of the center closest to point i.
        X_to_Cs_distance_squared: ``(n,)`` squared feature-space distance of
            each point to its closest center (clamped at 0).
        C: List of the selected point indices, in selection order.
        K: ``(n, n)`` partially filled kernel matrix (NaN where not computed).

    Raises:
        ValueError: If the total squared distance is non-positive (e.g. fewer
            than ``k`` distinct points in feature space) or the sampling
            probabilities are not finite.
    """
    if random_gen is None:
        random_gen = np.random.default_rng()

    n = X_data.shape[0]
    # NaN marks kernel entries that have not been evaluated yet.
    K = np.full((n, n), np.nan)

    # First center: drawn uniformly from all points.
    rand_idx = random_gen.choice(n, 1, replace=False)
    # Compute the corresponding column/row of the kernel matrix.
    K[:, rand_idx] = kernel_function(X_data, X_data[rand_idx])
    K[rand_idx, :] = K[:, rand_idx].T
    C = [rand_idx.item()]

    # Every distance needs K(x, x), so the whole diagonal is computed up front.
    np.fill_diagonal(
        K,
        np.array([kernel_function(X_data[i:i + 1], X_data[i:i + 1]).item() for i in range(n)]),
    )

    # Read-only *view* of the diagonal: it reflects later writes to K[c, c]
    # when a new center's column is filled in.
    diag = K.diagonal()[:, np.newaxis]

    # ||phi(x) - phi(c)||^2 = K(x,x) + K(c,c) - 2K(x,c). Round-off can push the
    # value slightly below zero for (near-)duplicate points, so clamp at 0.
    X_to_Cs_distance_squared = np.maximum(
        (diag + (K[rand_idx, rand_idx] - 2 * K[:, rand_idx])).flatten(), 0.0
    )
    # Labels store the position of the closest center within C directly.
    label_vector = np.zeros(n, dtype=int)

    for i in range(1, k):
        cost = X_to_Cs_distance_squared.sum()
        if cost <= 0:
            raise ValueError("Sum of squared distances is non-positive, check kernel calculations.")

        probs = X_to_Cs_distance_squared / cost
        if np.any(probs < 0) or not np.isfinite(probs).all():
            raise ValueError("Computed probabilities contain non-positive or non-finite values.")
        # D^2 sampling: pick the next center proportionally to squared distance.
        rand_idx = random_gen.choice(n, 1, p=probs)
        C.append(rand_idx.item())

        # Compute the kernel entries for the new center.
        K[:, rand_idx] = kernel_function(X_data, X_data[rand_idx])
        K[rand_idx, :] = K[:, rand_idx].T

        # Distances to the new center; keep whichever center is closer.
        new_distance_squared = np.maximum(
            (diag + K[rand_idx, rand_idx] - 2 * K[:, rand_idx]).flatten(), 0.0
        )
        closer = new_distance_squared < X_to_Cs_distance_squared
        label_vector[closer] = i
        X_to_Cs_distance_squared[closer] = new_distance_squared[closer]

    return label_vector, X_to_Cs_distance_squared, C, K


def initialize_kernel_kmeans_plusplus(K, k, random_gen=None):
    """Run k-means++ seeding in kernel feature space on a precomputed Gram matrix.

    Args:
        K: ``(n, n)`` kernel matrix.
        k: Number of centers to select.
        random_gen: Optional ``numpy.random.Generator``; a fresh default
            generator is created when omitted.

    Returns:
        labels: ``(n,)`` index (0..k-1, position in ``C``) of each point's
            closest center.
        distances: ``(n,)`` squared feature-space distance of each point to
            its closest center (clamped at 0).
        C: List of the selected point indices, in selection order.

    Raises:
        ValueError: If the total squared distance is non-positive (e.g. fewer
            than ``k`` distinct points in feature space) or the sampling
            probabilities are not finite.
    """
    if random_gen is None:
        random_gen = np.random.default_rng()

    n = K.shape[0]

    # First center: drawn uniformly from all points.
    rand_idx = random_gen.choice(n, 1, replace=False)
    C = [rand_idx.item()]

    # Column j will hold ||phi(x) - c_j||^2 = K(x,x) + K(c_j,c_j) - 2K(x,c_j);
    # start every column at K(x,x) and add the center-dependent terms per center.
    diagonal_K = K.diagonal()
    X_to_Cs_distance_squared = np.tile(diagonal_K[:, np.newaxis], (1, k))
    X_to_Cs_distance_squared[:, 0] += (K[rand_idx, rand_idx] - 2 * K[:, rand_idx].flatten())
    # Round-off can make distances of (near-)duplicate points slightly negative.
    X_to_Cs_distance_squared[:, 0] = np.maximum(X_to_Cs_distance_squared[:, 0], 0)

    dist_squared_vector = X_to_Cs_distance_squared[:, 0]

    for i in range(1, k):
        cost = dist_squared_vector.sum()
        if cost <= 0:
            raise ValueError("Sum of squared distances is non-positive, check kernel calculations.")
        probs = dist_squared_vector / cost
        if np.any(probs < 0) or not np.isfinite(probs).all():
            raise ValueError("Computed probabilities contain non-positive or non-finite values.")
        # D^2 sampling: pick the next center proportionally to squared distance.
        rand_idx = random_gen.choice(n, 1, p=probs)
        C.append(rand_idx.item())

        # Distances to the newly chosen center, then the running minimum.
        X_to_Cs_distance_squared[:, i] += (K[rand_idx, rand_idx] - 2 * K[:, rand_idx].flatten())
        X_to_Cs_distance_squared[:, i] = np.maximum(X_to_Cs_distance_squared[:, i], 0)
        dist_squared_vector = np.minimum(dist_squared_vector, X_to_Cs_distance_squared[:, i])

    labels = X_to_Cs_distance_squared.argmin(axis=1)
    distances = dist_squared_vector

    return labels, distances, C

# MARK: -KMBKM class
class KernelMiniBatchKMeans:
    """Mini-batch kernel k-means on a precomputed kernel (Gram) matrix ``K``.

    Cluster centers are never materialised in feature space. Center ``j`` is
    tracked only through
      * ``X_to_C_inner_products[:, j]`` -- <phi(x_i), c_j> for every point i,
      * ``C_to_C_inner_products[j]``    -- ||c_j||^2,
    from which ||phi(x_i) - c_j||^2 = K(x_i, x_i) - 2<phi(x_i), c_j> + ||c_j||^2.

    Parameters:
        n_clusters: Number of clusters k.
        batch_size: Number of indices drawn (with replacement) per iteration.
        n_iterations: Number of mini-batch updates performed by ``fit``.
        new_lr: If True the learning rate is sqrt(b / B); otherwise it is
            b / (cumulative count of batch points assigned to the cluster).
        full_batch: If True every iteration uses all points with learning rate 1.
        use_coreset: Stored only; not read by this class.
        random_state: Seed passed to ``np.random.default_rng``.
        lazy: Stored only; ``fit`` always uses the kernel matrix it is given.

    After ``fit``: ``labels_`` holds the final assignment and ``final_inertia``
    the sum of squared feature-space distances to the assigned centers.
    """

    def __init__(self, n_clusters, batch_size, n_iterations, new_lr=False, full_batch=False,
                 use_coreset=False, random_state=None, lazy=False):
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        self.n_iterations = n_iterations
        self.full_batch = full_batch
        self.final_inertia = 0
        self.new_lr = new_lr  # selects the sqrt(b / B) learning rate
        self.use_coreset = use_coreset  # accepted for API compatibility; unused here
        self.random_state = random_state
        self.lazy = lazy

    def compute_new_labels_and_distances(
        self, K, labels, batch, X_to_C_inner_products, C_to_C_inner_products, X_to_implied_centers, weight_history=None):
        """Apply one mini-batch update to every cluster and relabel all points.

        For each cluster j that contains at least one batch point, the implicit
        center moves towards the mean of those points:
        c_j <- (1 - lr) * c_j + lr * mean(batch points in j).

        Updated in place: ``X_to_C_inner_products``, ``C_to_C_inner_products``,
        ``X_to_implied_centers`` (columns of untouched clusters keep their
        previous values, so this matrix must already hold valid distances) and,
        when ``new_lr`` is False, ``weight_history`` (required in that case).

        Returns:
            ``(new_labels, X_to_closest_center)``: argmin / min over the
            columns of ``X_to_implied_centers``.
        """
        n = K.shape[0]
        # B counts repeated indices; b below counts distinct batch points in a
        # cluster. NOTE(review): with sampling with replacement these differ.
        B = len(batch)

        in_batch = np.zeros(n, dtype=bool)
        in_batch[batch] = True
        # Current labels of the batch points; -1 for points outside the batch.
        labels_batch = labels.copy()
        labels_batch[~in_batch] = -1

        phi_a_phi_a_term = K.diagonal()
        for j in range(self.n_clusters):
            mask = labels_batch == j
            if not mask.any():
                # No batch point in cluster j: its center stays unchanged.
                continue

            b = mask.sum()
            if self.new_lr:
                learning_rate = np.sqrt(b / B)
            else:
                weight_history[j] += b
                learning_rate = b / weight_history[j]

            if self.full_batch:
                learning_rate = 1

            # <phi(x), mean of batch points in j> for every point x
            phi_a_cm_batch = K[:, mask].sum(axis=1) / b
            # <c_j, mean of batch points in j>
            C_cm_batch_term = X_to_C_inner_products[mask, j].sum() / b
            # ||mean of batch points in j||^2
            cm_batch_cm_batch_term = K[mask][:, mask].sum() / (b * b)

            X_to_C_inner_products[:, j] = (1 - learning_rate) * X_to_C_inner_products[:, j] + learning_rate * phi_a_cm_batch

            C_to_C_inner_products[j] = (1 - learning_rate) ** 2 * C_to_C_inner_products[j] \
                                       + 2 * (1 - learning_rate) * learning_rate * C_cm_batch_term \
                                       + learning_rate ** 2 * cm_batch_cm_batch_term

            X_to_implied_centers[:, j] = phi_a_phi_a_term - 2 * X_to_C_inner_products[:, j] + C_to_C_inner_products[j]

        new_labels = X_to_implied_centers.argmin(axis=1)
        X_to_closest_center = X_to_implied_centers.min(axis=1)

        return new_labels, X_to_closest_center

    def get_inertia(self, distances_squared):
        """Return the sum of squared distances to the assigned centers."""
        return distances_squared.sum()

    def fit(self, X, K, labels, distances_squared, C):
        """Run mini-batch kernel k-means starting from the given seeding.

        Args:
            X: Data array; only ``X.shape[0]`` (number of points) is used.
            K: ``(n, n)`` kernel matrix.
            labels: ``(n,)`` initial cluster index (0..n_clusters-1) per point.
            distances_squared: ``(n,)`` initial squared distances to the
                assigned centers; used for the inertia when no iteration runs.
            C: Indices of the ``n_clusters`` initial center points.

        Sets ``self.labels_`` and ``self.final_inertia``.
        """
        rng = np.random.default_rng(self.random_state)
        n = X.shape[0]

        X_to_C_inner_products = K[:, C].copy()
        # ||c_j||^2 for the initial centers = diagonal entries K[c_j, c_j].
        C_to_C_inner_products = K[C, C].copy()

        # Must start with the true distances to the initial centers: clusters
        # without points in a batch keep their column unchanged, and an all-zero
        # column would otherwise attract every point.
        X_to_implied_centers = np.empty((K.shape[0], self.n_clusters))
        X_to_implied_centers[:] = (K.diagonal()[:, np.newaxis]
                                   - 2 * X_to_C_inner_products
                                   + C_to_C_inner_products[np.newaxis, :])

        weight_history = None
        if not self.new_lr:
            weight_history = np.zeros(self.n_clusters)

        all_indices = np.arange(n)
        for _ in range(self.n_iterations):
            if self.full_batch:
                # Full batch means every point, not n draws with replacement.
                batch = all_indices
            else:
                batch = rng.choice(n, self.batch_size, replace=True)

            labels, distances_squared = self.compute_new_labels_and_distances(
                K, labels, batch, X_to_C_inner_products, C_to_C_inner_products,
                X_to_implied_centers, weight_history
            )

        # Final inertia and cluster assignments for the entire dataset
        # (also well-defined when n_iterations == 0).
        self.final_inertia, self.labels_ = self.get_inertia(distances_squared), labels

