
import numpy as np
from kernel_kmeans import KernelMiniBatchKMeans, initialize_kernel_kmeans_plusplus



class SimpleCoefficientWindowPerCluster:
    """
    Fixed-size ring buffer, one per cluster, holding the R coefficients
    (the coefficients that belong to points drawn in batches).

    Each cluster owns ``tau`` slots spread over two parallel ``(k, tau)``
    arrays: ``data_index_arrays`` stores which data point sits in a slot and
    ``coefficient_arrays`` stores that point's coefficient.

    Only the next write position per cluster is tracked; the sums computed
    from a window do not depend on the order of its slots. Slots that have
    never been written hold coefficient 0 and data index 0, so they add
    nothing to any of the sums.
    """

    def __init__(self, tau, k):
        self.k = k
        self.tau = tau

        window_shape = (k, tau)
        self.coefficient_arrays = np.zeros(window_shape)
        self.data_index_arrays = np.zeros(window_shape, dtype=int)
        # Next slot to overwrite, per cluster.
        self.insert_indices = np.zeros(k, dtype=int)

    def add_batch(self, batch_data_indices, batch_coefficients, labels):
        """
        Write a batch into the windows, routing each point to the cluster
        named by its label.

        Points whose label is outside ``range(k)`` are skipped. When a
        cluster receives more than ``tau`` points, only its last ``tau``
        points (in batch order) are stored.
        """
        for cluster in range(self.k):
            members = np.nonzero(labels == cluster)[0]
            count = len(members)
            if count == 0:
                continue
            if count > self.tau:
                # Too many points for one window: keep the trailing tau.
                members = members[-self.tau:]
                count = self.tau

            first_slot = self.insert_indices[cluster]
            # Consecutive slots starting at the insert position, wrapping
            # back to slot 0 once the end of the window is reached.
            slots = (first_slot + np.arange(count)) % self.tau

            self.coefficient_arrays[cluster, slots] = batch_coefficients[members]
            self.data_index_arrays[cluster, slots] = batch_data_indices[members]
            self.insert_indices[cluster] = (first_slot + count) % self.tau

    def scale_windows(self, factors):
        """
        Multiply every coefficient of cluster j by ``factors[j]``, in place.
        ``factors`` is a length-k array.
        """
        self.coefficient_arrays *= factors[:, None]

    def batch_to_windows(self, batch_indices, K):
        """
        Inner product of every batch point with every cluster's window.

        Returns an array of shape ``(len(batch_indices), k)``.
        """
        result = np.zeros((batch_indices.shape[0], self.k))
        for cluster in range(self.k):
            window_points = self.data_index_arrays[cluster]
            window_coefficients = self.coefficient_arrays[cluster]
            kernel_block = K[np.ix_(batch_indices, window_points)]
            result[:, cluster] = np.sum(kernel_block * window_coefficients[np.newaxis, :], axis=1)
        return result

    def windows_to_windows_inner_product(self, K):
        """
        Inner product of each cluster's window with itself.

        Returns a length-k array.
        """
        result = np.zeros((self.k,))
        for cluster in range(self.k):
            window_points = self.data_index_arrays[cluster]
            window_coefficients = self.coefficient_arrays[cluster]
            pair_weights = np.outer(window_coefficients, window_coefficients)
            kernel_block = K[np.ix_(window_points, window_points)]
            result[cluster] = (pair_weights * kernel_block).sum()
        return result

    def L_to_windows_inner_product(self, initial_cluster_center_indices, initial_cluster_coefficients, K):
        """
        Inner product of each cluster's L term (its initial center times its
        L coefficient) with that cluster's window.

        Returns a length-k array.
        """
        result = np.zeros((self.k,))
        for cluster in range(self.k):
            center = initial_cluster_center_indices[cluster]
            kernel_row = K[center, self.data_index_arrays[cluster]]
            weighted_sum = np.sum(kernel_row * self.coefficient_arrays[cluster])
            result[cluster] = initial_cluster_coefficients[cluster] * weighted_sum
        return result

    def batch_delta(self, batch_indices, self_affinity_array, initial_cluster_center_indices, initial_cluster_coefficients, initial_cluster_self_inner_products, K):
        """
        Distance of each batch point to each implied center, L terms included:

        <phi(x),phi(x)> - 2<phi(x),L^j+R^j> + <L^j,L^j> + 2<L^j,R^j> + <R^j,R^j>

        Returns an array of shape ``(len(batch_indices), k)``.
        """
        point_to_L = K[batch_indices[:, np.newaxis], initial_cluster_center_indices] * initial_cluster_coefficients
        point_to_R = self.batch_to_windows(batch_indices, K)
        L_to_L = initial_cluster_self_inner_products * initial_cluster_coefficients**2
        R_to_R = self.windows_to_windows_inner_product(K)
        L_to_R = self.L_to_windows_inner_product(initial_cluster_center_indices, initial_cluster_coefficients, K)

        # Accumulate the terms one at a time, in the order of the formula.
        delta = self_affinity_array[batch_indices][:, np.newaxis] - 2 * point_to_L
        delta = delta - 2 * point_to_R
        delta = delta + L_to_L[np.newaxis, :]
        delta = delta + 2 * L_to_R[np.newaxis, :]
        delta = delta + R_to_R[np.newaxis, :]
        return delta

    def truncated_batch_delta(self, batch_indices, self_affinity_array, K):
        """
        Distance of each batch point to each implied center with the L terms
        dropped:

        <phi(x),phi(x)> - 2<phi(x),R^j> + <R^j,R^j>
        """
        point_to_R = self.batch_to_windows(batch_indices, K)
        R_to_R = self.windows_to_windows_inner_product(K)
        point_norms = self_affinity_array[batch_indices][:, np.newaxis]
        return point_norms - 2 * point_to_R + R_to_R

class CoefficientWindowPerCluster:
    """
    Store a sliding window, per cluster, of the R coefficients.

    Each cluster owns ``tau`` logical slots. They are stored twice, in arrays
    of length ``2 * tau``: logical slot ``s`` lives at position ``s`` and at
    position ``s + tau``. Thanks to this mirror, any run of at most ``tau``
    consecutive slots, including one that wraps past the end, can be read as
    one contiguous slice.

    One array holds data indices and one holds coefficients. Since storage is
    per cluster, labels are not kept.
    """

    def __init__(self, tau, b, k):
        """
        tau: number of slots per cluster (cast to int).
        b:   batch size used as the denominator of the decay factor in
             ``update_R_term_coefficients``.
        k:   number of clusters.
        """
        self.b = b
        self.k = k
        self.tau = int(tau)
        self.coefficient_arrays = np.zeros((k, 2 * self.tau))
        self.data_index_arrays = np.zeros((k, 2 * self.tau), dtype=int)
        # Next logical slot to overwrite; always in [0, tau).
        self.insert_indices = np.zeros(k, dtype=int)
        # The window of cluster j is [start, start + width) in the doubled array.
        self.window_start_indices = np.zeros(k, dtype=int)
        self.window_widths = np.zeros(k, dtype=int)
        # [start, stop) of the most recent batch of cluster j in the doubled
        # array; start == stop means that batch held no point for cluster j.
        self.current_batch_start_stop = np.zeros((k, 2), dtype=int)

    def add_batch(self, batch_data_indices, batch_coefficients, labels):
        """
        Add a batch to the windows by cluster.

        Points whose label is outside ``range(k)`` are skipped. If a cluster
        receives more than ``tau`` points, only its last ``tau`` are stored.
        A cluster that receives no point gets an empty "current batch", so
        that ``update_R_term_coefficients`` leaves it untouched.
        """
        tau = self.tau
        for j in range(self.k):
            mask = (labels == j)
            b_j = int(np.sum(mask))
            pos = self.insert_indices[j]

            if b_j == 0:
                # Do not keep the previous batch's range around: it would make
                # the next decay step reuse a stale cluster size.
                self.current_batch_start_stop[j] = [pos + tau, pos + tau]
                continue

            data_indices = batch_data_indices[mask]
            coefficients = batch_coefficients[mask]
            if b_j > tau:
                # only use the last tau elements
                data_indices = data_indices[-tau:]
                coefficients = coefficients[-tau:]
                b_j = tau

            # Logical slots receiving the batch, wrapping around at tau.
            slots = (pos + np.arange(b_j)) % tau
            # Write both copies so the two halves stay mirror images.
            for offset in (0, tau):
                self.coefficient_arrays[j, slots + offset] = coefficients
                self.data_index_arrays[j, slots + offset] = data_indices

            # the insert index always lies in the bottom half of the array
            new_pos = (pos + b_j) % tau
            self.insert_indices[j] = new_pos

            self.window_widths[j] = min(tau, self.window_widths[j] + b_j)
            if self.window_widths[j] == tau:
                # Full window: the slot about to be overwritten next is the
                # oldest one, so starting there lists the window oldest first.
                self.window_start_indices[j] = new_pos

            # The stop index lies in the top half; the start may lie in the
            # bottom half when the batch wrapped around.
            self.current_batch_start_stop[j] = [new_pos + tau - b_j, new_pos + tau]

    def window_view(self, j):
        """
        Return views of the current window's data indices and coefficients
        for cluster j, as a ``(data_indices, coefficients)`` pair.
        """
        start = self.window_start_indices[j]
        stop = start + self.window_widths[j]
        return self.data_index_arrays[j, start:stop], self.coefficient_arrays[j, start:stop]

    def current_batch_indices(self, j):
        """
        Return the data indices of the most recent batch for cluster j
        (empty if that batch held no point for cluster j).

        The stop index is always in the top half of the array while the start
        may be in the bottom half, so the range is always contiguous.
        """
        start, stop = self.current_batch_start_stop[j]
        return self.data_index_arrays[j, start:stop]

    def update_R_term_coefficients(self):
        """
        Assume the most recent batch is the current batch. Multiply every
        coefficient of cluster j that is not part of the current batch by
        ``1 - sqrt(b_j / b)``, where ``b_j`` is the number of points stored
        for cluster j from the current batch. Clusters with ``b_j == 0`` are
        left unchanged.
        """
        tau = self.tau
        for j in range(self.k):
            start, stop = self.current_batch_start_stop[j]
            b_j_current = stop - start
            if b_j_current == 0:
                continue

            factor = 1 - np.sqrt(b_j_current / self.b)

            # Logical slots occupied by the current batch.
            is_current = np.zeros(tau, dtype=bool)
            is_current[np.arange(start, stop) % tau] = True
            stale = ~is_current

            # Decay both mirrored copies of every other slot.
            self.coefficient_arrays[j, :tau][stale] *= factor
            self.coefficient_arrays[j, tau:][stale] *= factor

    def batch_to_R_inner_products(self, batch_indices, K):
        """
        Compute the inner products of a batch to the R coefficients of each
        cluster. Returns an array of shape ``(len(batch_indices), k)``.
        """
        result = np.zeros((batch_indices.shape[0], self.k))
        for j in range(self.k):
            data_indices, coefficients = self.window_view(j)
            kernel_block = K[np.ix_(batch_indices, data_indices)]
            result[:, j] = np.sum(kernel_block * coefficients, axis=1)
        return result

    def R_to_R_inner_products(self, K):
        """
        Compute the inner products of the R coefficients to themselves for
        each cluster. Returns a length-k array.
        """
        result = np.zeros((self.k,))
        for j in range(self.k):
            data_indices, coefficients = self.window_view(j)
            kernel_block = K[np.ix_(data_indices, data_indices)]
            result[j] = (np.outer(coefficients, coefficients) * kernel_block).sum()
        return result

    def L_to_R_inner_products(self, initial_cluster_center_indices, initial_cluster_coefficients, K):
        """
        Compute the inner products of the L coefficients to the R coefficients
        for each cluster. Returns a length-k array.
        """
        result = np.zeros((self.k,))
        for j in range(self.k):
            data_indices, coefficients = self.window_view(j)
            # L^j involves only the initial center of cluster j, so only that
            # center's kernel row may enter the sum.
            kernel_row = K[initial_cluster_center_indices[j], data_indices]
            result[j] = initial_cluster_coefficients[j] * np.sum(kernel_row * coefficients)
        return result

    def batch_delta(self, batch_indices, self_affinity_array, initial_cluster_center_indices, initial_cluster_coefficients, initial_cluster_self_inner_products, K):
        """
        Compute the distance of a batch to each of the implied centers before truncation:

        <phi(x),phi(x)> - 2<phi(x),L^j+R^j> + <L^j,L^j> + 2<L^j,R^j> + <R^j,R^j>

        Returns an array of shape ``(len(batch_indices), k)``.
        """
        batch_to_L_inner_products = K[batch_indices[:, np.newaxis], initial_cluster_center_indices] * initial_cluster_coefficients
        batch_to_R_inner_products = self.batch_to_R_inner_products(batch_indices, K)
        L_to_L_inner_products = initial_cluster_self_inner_products * initial_cluster_coefficients**2
        R_to_R_inner_products = self.R_to_R_inner_products(K)
        L_to_R_inner_products = self.L_to_R_inner_products(initial_cluster_center_indices, initial_cluster_coefficients, K)

        batch_delta = (
            self_affinity_array[batch_indices][:, np.newaxis]
            - 2 * batch_to_L_inner_products
            - 2 * batch_to_R_inner_products
            + L_to_L_inner_products
            + 2 * L_to_R_inner_products
            + R_to_R_inner_products
        )
        return batch_delta

    def truncated_batch_delta(self, batch_indices, self_affinity_array, K):
        """
        Compute the distance of a batch to each of the implied centers after truncation:

        <phi(x),phi(x)> - 2<phi(x),R^j> + <R^j,R^j>
        """
        batch_to_R_inner_products = self.batch_to_R_inner_products(batch_indices, K)
        R_to_R_inner_products = self.R_to_R_inner_products(K)

        batch_delta = self_affinity_array[batch_indices][:, np.newaxis] - 2 * batch_to_R_inner_products + R_to_R_inner_products
        return batch_delta

# MARK: - CoefficientWindow
class CoefficientWindow:
    """
    Store a sliding window of the R coefficients.

    The window holds the last ``tau`` batches of ``b`` points each
    (``size = tau * b`` slots). It is backed by three arrays of length
    ``2 * size``: one for data indices, one for coefficients and one for
    labels. Slot ``s`` is stored at both ``s`` and ``s + size``, so the
    window is always the contiguous slice
    ``[window_start_index, window_start_index + window_width)``, listed from
    the oldest batch to the most recent one.
    """

    def __init__(self, tau, b, k):
        self.tau = tau
        self.b = b
        self.k = k
        self.size = tau * b
        self.coefficient_array = np.zeros(2 * self.size)
        self.label_array = np.zeros(2 * self.size, dtype=int)
        self.data_index_array = np.zeros((2 * self.size), dtype=int)
        # Total number of points inserted so far (not reduced modulo size).
        self.insert_index = 0
        self.window_start_index = 0
        self.window_width = 0

    def add_batch(self, data_indices, coefficients, labels):
        """
        Add a batch of coefficients and labels to the window.
        INPUTS:
            data_indices: np.array of shape (b,)
            coefficients: np.array of shape (b,)
            labels: np.array of shape (b,)
        """
        assert coefficients.shape[0] == self.b
        assert labels.shape[0] == self.b

        pos = self.insert_index % self.size

        # Write the batch into both mirrored copies.
        for start in (pos, pos + self.size):
            stop = start + self.b
            self.coefficient_array[start:stop] = coefficients
            self.data_index_array[start:stop] = data_indices
            self.label_array[start:stop] = labels

        self.insert_index += self.b

        # The window start only moves when this insert overwrote the oldest
        # batch, i.e. when the window was already full beforehand. Moving it
        # on the insert that merely fills the window would leave the start one
        # batch too far, and the window's last batch would no longer be the
        # current one.
        was_full = self.window_width == self.size
        self.window_width = min(self.size, self.window_width + self.b)
        if was_full:
            self.window_start_index = (self.window_start_index + self.b) % self.size

    def window_view(self):
        """
        Return views of the current window's data indices, coefficients and
        labels, in that order.
        """
        start = self.window_start_index
        stop = start + self.window_width
        return self.data_index_array[start:stop], self.coefficient_array[start:stop], self.label_array[start:stop]

    def current_batch_labels(self):
        """
        Return the labels of the most recent batch.
        """
        current_batch_start = (self.insert_index - self.b) % self.size
        return self.label_array[current_batch_start:current_batch_start + self.b]

    def update_R_term_coefficients(self):
        """
        Assume the most recent batch is the current batch. For every cluster
        j with ``b_j > 0`` points in the current batch, multiply the
        coefficients of all earlier window points labelled j by
        ``1 - sqrt(b_j / b)``.
        """
        current_batch_labels = self.current_batch_labels()

        # Logical slots of every window batch except the last (current) one.
        stale_slots = (self.window_start_index + np.arange(self.window_width - self.b)) % self.size
        stale_labels = self.label_array[stale_slots]

        for j in range(self.k):
            b_j_current = np.sum(current_batch_labels == j)
            if b_j_current == 0:
                continue
            # decay the coefficients of the points in cluster j by (1-sqrt(b_j/b))
            factor = 1 - np.sqrt(b_j_current / self.b)
            slots_j = stale_slots[stale_labels == j]
            # update lower part and upper part at the same time
            self.coefficient_array[slots_j] *= factor
            self.coefficient_array[slots_j + self.size] *= factor

    def batch_to_R_inner_products(self, batch_indices, K):
        """
        Compute the inner products of a batch to the R coefficients of each
        cluster. Returns an array of shape ``(len(batch_indices), k)``.
        """
        data_indices, coefficients, labels = self.window_view()

        weighted_kernel = K[np.ix_(batch_indices, data_indices)] * coefficients
        # to get the actual inner products to the clusters we need to sum over the cluster masks
        result = np.zeros((batch_indices.shape[0], self.k))
        for j in range(self.k):
            result[:, j] = np.sum(weighted_kernel[:, labels == j], axis=1)
        return result

    def R_to_R_inner_products(self, K):
        """
        Compute the inner products of the R coefficients to themselves for
        each cluster. Returns a length-k array (0 for clusters with no point
        in the window).
        """
        data_indices, coefficients, labels = self.window_view()
        result = np.zeros((self.k,))
        for j in range(self.k):
            mask = (labels == j)
            idx_j = data_indices[mask]
            coef_j = coefficients[mask]
            result[j] = (np.outer(coef_j, coef_j) * K[np.ix_(idx_j, idx_j)]).sum()
        return result

    def L_to_R_inner_products(self, initial_cluster_center_indices, initial_cluster_coefficients, K):
        """
        Compute the inner products of the L coefficients to the R coefficients
        for each cluster. Returns a length-k array (0 for clusters with no
        point in the window).
        """
        data_indices, coefficients, labels = self.window_view()
        result = np.zeros((self.k,))
        for j in range(self.k):
            mask = (labels == j)
            # L^j involves only the initial center of cluster j, so only that
            # center's kernel row may enter the sum.
            kernel_row = K[initial_cluster_center_indices[j], data_indices[mask]]
            result[j] = initial_cluster_coefficients[j] * np.sum(kernel_row * coefficients[mask])
        return result

    def batch_delta(self, batch_indices, self_affinity_array, initial_cluster_center_indices, initial_cluster_coefficients, initial_cluster_self_inner_products, K):
        """
        Compute the distance of a batch to each of the implied centers before truncation:

        <phi(x),phi(x)> - 2<phi(x),L^j+R^j> + <L^j,L^j> + 2<L^j,R^j> + <R^j,R^j>

        Returns an array of shape ``(len(batch_indices), k)``.
        """
        ixgrid = np.ix_(batch_indices, initial_cluster_center_indices)
        batch_to_L_inner_products = K[ixgrid] * initial_cluster_coefficients
        batch_to_R_inner_products = self.batch_to_R_inner_products(batch_indices, K)
        L_to_L_inner_products = initial_cluster_self_inner_products * initial_cluster_coefficients**2
        R_to_R_inner_products = self.R_to_R_inner_products(K)
        L_to_R_inner_products = self.L_to_R_inner_products(initial_cluster_center_indices, initial_cluster_coefficients, K)

        batch_delta = (
            self_affinity_array[batch_indices][:, np.newaxis]
            - 2 * batch_to_L_inner_products
            - 2 * batch_to_R_inner_products
            + L_to_L_inner_products
            + 2 * L_to_R_inner_products
            + R_to_R_inner_products
        )
        return batch_delta

    def truncated_batch_delta(self, batch_indices, self_affinity_array, K):
        """
        Compute the distance of a batch to each of the implied centers after truncation:

        <phi(x),phi(x)> - 2<phi(x),R^j> + <R^j,R^j>
        """
        batch_to_R_inner_products = self.batch_to_R_inner_products(batch_indices, K)
        R_to_R_inner_products = self.R_to_R_inner_products(K)

        batch_delta = self_affinity_array[batch_indices][:, np.newaxis] - 2 * batch_to_R_inner_products + R_to_R_inner_products
        return batch_delta






# MARK: - TMBKKM

class WTKernelMiniBatchKMeans:
    """
    Mini-batch kernel k-means whose centers are kept implicitly as
    coefficients: one coefficient on each cluster's initial center (the L
    term) plus a per-cluster sliding window of ``tau`` batch-point
    coefficients (the R term, a ``SimpleCoefficientWindowPerCluster``).
    """

    def __init__(self, n_clusters, batch_size, n_iterations, tau=10, lazy=True, new_lr=False, random_state=None):
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        self.n_iterations = n_iterations
        self.centroids = None
        self.tau = tau
        self.lazy = lazy
        self.final_inertia = 0
        self.new_lr = new_lr  # new learning rate flag
        self.random_state = random_state

    # needed for consistency with sklearn (for running experiments)
    def set_params(self, **params):
        """Set each keyword as an attribute on this object and return ``self``."""
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def update_L_term_coefficients(self, L_coefficients, batch_labels):
        """
        Multiply ``L_coefficients[j]`` in place by ``1 - sqrt(b_j / batch_size)``,
        where ``b_j`` is the number of entries of ``batch_labels`` equal to j.
        Returns ``L_coefficients``.
        """
        for j in range(self.n_clusters):
            batch_mask = (batch_labels == j)
            b_j = np.sum(batch_mask)
            L_coefficients[j] *= (1 - np.sqrt(b_j / self.batch_size))
        return L_coefficients

    def update_L_term_coefficients_per_cluster(self, L_coefficients, factors):
        """
        Multiply ``L_coefficients`` element-wise by ``factors``, in place, and
        return it.
        """
        L_coefficients *= factors
        return L_coefficients

    def _clip_batch_labels(self, labels, tau):
        """
        Cap each cluster's share of a batch at ``tau`` points.

        For a cluster with more than ``tau`` points, the labels of its
        earliest excess points are set to -1 in place so they are ignored
        downstream. Returns the per-cluster number of points kept, as a float
        array of length ``n_clusters``.
        """
        kept_counts = np.zeros(self.n_clusters)
        for j in range(self.n_clusters):
            members = np.where(labels == j)[0]
            kept = len(members)
            if kept > tau:
                labels[members[:kept - tau]] = -1
                kept = tau
            kept_counts[j] = kept
        return kept_counts

    def fit(self, X, K, init_labels=None, init_distances=None, init_C=None, W=None):
        """
        Run ``n_iterations`` mini-batch updates and label every point.

        X:      only ``X.shape[0]`` (the number of points) is used.
        K:      kernel matrix, indexed by data indices on both axes.
        init_labels, init_distances: accepted but not used.
        init_C: indices of the initial centers; when None they are obtained
                from ``initialize_kernel_kmeans_plusplus(K, n_clusters)``.
        W:      optional per-point weights; batches are sampled with
                probability proportional to W (uniformly when None).

        Stores the window in ``self.window`` and the labels in
        ``self.labels_``; returns ``self.labels_``.
        """
        if init_C is None:
            init_C = initialize_kernel_kmeans_plusplus(K, self.n_clusters)

        n = X.shape[0]
        if W is None:
            sampling_dist = np.ones(n) / n
        else:
            sampling_dist = W / np.sum(W)

        initial_cluster_center_indices = np.array(init_C)
        rng = np.random.default_rng(self.random_state)

        k = self.n_clusters
        batch_size = self.batch_size
        tau = self.tau

        self.window = SimpleCoefficientWindowPerCluster(tau, k)
        # diagonal of K
        self_affinity_array = K.diagonal()

        initial_cluster_coefficients = np.ones(k)
        initial_cluster_self_inner_products = self_affinity_array[initial_cluster_center_indices]

        for _ in range(self.n_iterations):
            # Rescale so that, per cluster, the L coefficient plus the window
            # coefficients sum to 1 (points leaving the window take their
            # weight with them).
            total_coefficients = initial_cluster_coefficients + np.sum(self.window.coefficient_arrays, axis=1)
            factor = 1 / total_coefficients
            initial_cluster_coefficients *= factor
            self.window.coefficient_arrays *= factor[:, np.newaxis]

            # Always sample the configured batch size. The clipped count
            # computed below is kept in its own variable so that it cannot
            # shrink the batches drawn in later iterations.
            batch_indices = rng.choice(n, batch_size, replace=True, p=sampling_dist)

            next_batch_delta = self.window.batch_delta(
                batch_indices,
                self_affinity_array,
                initial_cluster_center_indices,
                initial_cluster_coefficients,
                initial_cluster_self_inner_products,
                K,
            )
            next_batch_labels = np.argmin(next_batch_delta, axis=1)

            # batch cluster sizes (at most tau each), alphas and 1-alphas
            b_js = self._clip_batch_labels(next_batch_labels, tau)
            effective_batch_size = int(np.sum(b_js))
            b_js_nonzero = np.maximum(b_js, 1)
            alphas = np.sqrt(b_js / effective_batch_size)
            one_minus_alphas = 1 - alphas

            # Scale L coefficients and window coefficients
            self.update_L_term_coefficients_per_cluster(initial_cluster_coefficients, one_minus_alphas)
            self.window.scale_windows(one_minus_alphas)

            # Each kept point of cluster j gets alpha_j / b_j. Points
            # relabelled -1 pick up the last cluster's value here, but
            # add_batch skips them, so that value is never stored.
            batch_coefficients = alphas[next_batch_labels] / b_js_nonzero[next_batch_labels]

            self.window.add_batch(batch_indices, batch_coefficients, next_batch_labels)

        # NOTE(review): the coefficients are not rescaled again after the last
        # add_batch, so this final assignment may use sums slightly below 1.
        data_delta = self.window.batch_delta(
            np.arange(n),
            self_affinity_array,
            initial_cluster_center_indices,
            initial_cluster_coefficients,
            initial_cluster_self_inner_products,
            K,
        )
        self.labels_ = np.argmin(data_delta, axis=1)
        return self.labels_


        

