import numpy as np
import scipy
import math
from sklearn.cluster import KMeans
import copy
from tqdm import tqdm
import random
import time
from scipy.spatial.distance import cdist

class OneStepSwap:
    """Online selection of ``k`` cluster centers by single-swap local search.

    The first ``k`` samples seed the centers. Each later sample ``i`` is either
    ignored or swapped in for one existing center, whichever gives the lowest
    weighted cost over the samples seen so far. The cost is the sum over samples
    of ``weight * squared Euclidean distance`` to the nearest center.

    Per-sample history is kept in three parallel lists:

    * ``time``: seconds spent processing the sample.
    * ``cost``: weighted cost after the sample.
    * ``consistency``: 1 if the sample was made a center, else 0.
    """

    def __init__(self, k):
        """Create an empty model that will keep ``k`` centers.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k}")
        self.k = k
        self.centers = None  # array of current centers, one row per center
        self.data = None  # (n_samples, n_features) float array set by pre_process
        self.weight = None  # (n_samples,) per-sample weights set by pre_process
        self.time = []
        self.cost = []
        self.consistency = []
        self.cluster_counts = np.zeros(k)

    def pre_process(self, raw_data, weight=None):
        """Store ``raw_data`` as a float array together with per-sample weights.

        Args:
            raw_data: array-like of shape (n_samples, n_features). A 1-D input
                is treated as n_samples points with a single feature.
            weight: array-like with one weight per sample, or None for unit
                weights.

        Raises:
            ValueError: if ``weight`` does not have exactly one entry per sample.
        """
        # np.array always copies, so later work never touches the caller's array.
        data = np.array(raw_data, dtype=float)
        if data.ndim == 1:
            data = data[:, np.newaxis]
        if weight is None:
            weight = np.ones(data.shape[0])
        else:
            weight = np.array(weight, dtype=float)
        if weight.shape != (data.shape[0],):
            raise ValueError(
                f"weight must have shape ({data.shape[0]},), got {weight.shape}"
            )
        self.data = data
        self.weight = weight

    def update_counts_and_cost(self, index):
        """Assign ``data[:index]`` to the nearest current center and record it.

        Sets ``cluster_counts`` to the number of those samples nearest to each
        center. The result has length ``k``, with 0 for empty clusters. Also
        appends the weighted squared-distance cost of those samples to ``cost``,
        using the same weighting as ``one_step_swap``.

        Args:
            index: number of leading samples of ``data`` to evaluate.
        """
        prefix = self.data[:index, :]
        if prefix.shape[0] == 0:
            # Nothing to assign; argmin over an empty axis would raise.
            self.cluster_counts = np.zeros(self.k, dtype=int)
            self.cost.append(0.0)
            return
        sq_dist = cdist(self.centers, prefix) ** 2
        nearest = np.argmin(sq_dist, axis=0)
        # bincount keeps one slot per center, including empty clusters.
        # np.unique would drop empty clusters and shift later counts onto the
        # wrong center.
        self.cluster_counts = np.bincount(nearest, minlength=self.k)
        self.cost.append(np.sum(self.weight[:index] * np.min(sq_dist, axis=0)))

    def _swap_costs(self, i):
        """Return the weighted cost of ``data[:i]`` for each possible swap.

        Entry ``j`` of the result is the weighted cost of ``data[:i]`` after
        replacing center ``j`` by sample ``i``.

        Each candidate set differs from the current centers in a single row.
        For a given point, the nearest remaining center is therefore its current
        nearest one, unless that center is the one being replaced. In that case
        it is the runner-up. One distance matrix per step thus replaces one per
        candidate.
        """
        seen = self.data[:i, :]
        w = self.weight[np.newaxis, :i]
        weighted = cdist(self.centers, seen) ** 2 * w  # (n_centers, i)
        new_row = (cdist(self.data[i:i + 1, :], seen) ** 2 * w)[0]  # (i,)
        n_centers = weighted.shape[0]
        owner = np.argmin(weighted, axis=0)
        if n_centers > 1:
            smallest_two = np.partition(weighted, 1, axis=0)
            first, second = smallest_two[0], smallest_two[1]
        else:
            # With a single center, removing it leaves only the new point.
            first = weighted[0]
            second = np.full_like(first, np.inf)
        costs = np.empty(n_centers)
        for j in range(n_centers):
            remaining = np.where(owner == j, second, first)
            costs[j] = np.sum(np.minimum(remaining, new_row))
        return costs

    def one_step_swap(self, *, show_progress=True):
        """Run the online single-swap procedure over every sample in ``data``.

        The first ``k`` samples become the initial centers, each with cost 0
        and consistency 1. For every later sample ``i`` the cheaper of two
        options is kept:

        * Keep the centers: the previous cost plus ``weight[i]`` times the
          squared distance from sample ``i`` to its nearest center.
        * Replace one center by sample ``i``: the weighted cost of ``data[:i]``
          against the modified centers. Sample ``i`` then costs 0.

        A swap is made only when it is strictly cheaper. On ties the lowest
        center index wins. Results are appended to ``time``, ``cost`` and
        ``consistency``, and ``centers`` holds the final centers.

        If no data is loaded, the history lists are emptied. If there are
        fewer than ``k`` samples, every sample becomes a center.

        Args:
            show_progress: wrap the main loop in a ``tqdm`` progress bar.
        """
        if self.data is None:
            self.time, self.cost, self.consistency = [], [], []
            return

        n_samples = self.data.shape[0]
        if n_samples < self.k:
            # Every sample is its own center, so every cost is zero.
            self.centers = self.data.copy()
            self.time = [0] * n_samples
            self.cost = [0] * n_samples
            self.consistency = [1] * n_samples
            return

        self.time = [0] * self.k
        self.cost = [0] * self.k
        self.consistency = [1] * self.k
        # Copy: centers are overwritten in place below and must not alias data.
        self.centers = self.data[:self.k, :].copy()

        indices = range(self.k, n_samples)
        if show_progress:
            indices = tqdm(indices)
        for i in indices:
            t0 = time.perf_counter()
            point = self.data[i, :]
            nearest_sq = np.min(np.linalg.norm(self.centers - point, axis=1) ** 2)
            keep_cost = self.cost[-1] + self.weight[i] * nearest_sq

            swap_costs = self._swap_costs(i)
            best = int(np.argmin(swap_costs))  # first index of the minimum
            if swap_costs[best] < keep_cost:
                self.centers[best, :] = point
                self.cost.append(swap_costs[best])
                self.consistency.append(1)
            else:
                self.cost.append(keep_cost)
                self.consistency.append(0)
            self.time.append(time.perf_counter() - t0)



