import numpy as np
import scipy
import math
from sklearn.cluster import KMeans
import copy
from tqdm import tqdm
import random
import time
from tools import dist

class baseline:
    """Single-pass sketch of a point set with periodic weighted k-means refits.

    Usage is two-step: call ``init_parameter(P, k)`` to bind the data and
    reset all state, then ``sketch()`` to process the rows of ``P`` in order.
    Results are exposed through the instance attributes populated by
    ``sketch`` (``center``, ``weight``, ``assignment``, ``index``, ``cost``,
    ``consistency``, ``solution``, ``time``).
    """

    # Class-level defaults. NOTE: init_parameter() re-assigns the same values
    # on the instance, so overrides must be applied *after* init_parameter().
    rep = 1            # number of Bernoulli trials per scale for each point
    double_time = 12   # number of cost-estimate scales (10**0 .. 10**11)

    def init_parameter(self, P, k):
        """Bind the data set and reset every piece of sketching state.

        Args:
            P: 2-D array-like of points, one row per point (rows are read as
                ``P[j, :]`` and the row count is taken from ``np.shape(P)[0]``).
            k: number of clusters passed to ``KMeans`` and used in the
                sampling probability and the per-scale size bound.
        """
        # np.random.randint(1) can only return 0, so the k-means seed is
        # effectively fixed at 0; the call is kept as-is so the global RNG
        # stream is consumed exactly as before.
        self.kmeans = KMeans(n_clusters=k, random_state=np.random.randint(1), n_init="auto")
        self.P = P
        self.k = k
        self.n = np.shape(self.P)[0]

        self.center = []          # points kept as sketch centres
        self.weight = []          # number of points currently charged to each centre
        self.weight_record = []   # weight of each centre at its last doubling check
        self.solution = []        # most recent k-means cluster centres
        self.index = []           # point indices at which a k-means refit happened
        self.consistency = []     # running count of changed cluster centres
        self.cost = []            # k-means inertia after each refit
        self.time = []            # wall-clock seconds spent on each processed point
        self.rep = 1
        self.double_time = 12

        self.assignment = [0] * self.n   # centre index each point is assigned to
        # One candidate cost estimate per scale: 1, 10, 100, ...
        self.opt_est = [10**i for i in range(self.double_time)]
        # Per-scale count of sampled centres and the cap applied to that count.
        self.size_cnt = np.zeros(self.double_time)
        self.size_cnt_bound = 4 * k * (1 + np.log10(self.n))

    def sketch(self):
        """Process the points of ``self.P`` in order and maintain the sketch.

        The first point always becomes a centre. Every later point ``j`` is
        compared with its nearest centre; for each scale in ``opt_est`` it is
        sampled with probability ``min(1, k*(1+log10(n))*d**2/opt_est)`` and
        becomes a new centre if any scale that is still below
        ``size_cnt_bound`` succeeds. Otherwise it is assigned to the nearest
        centre. Whenever the weight of the affected centre has more than
        doubled since its last record (and more than ``k`` centres exist), a
        weighted k-means is refit on the centres and ``index``, ``cost``,
        ``consistency`` and ``solution`` are updated.

        Returns nothing; an empty data set (``n == 0``) is a no-op.

        NOTE: the loop always starts at point 1, so calling ``sketch`` twice
        without ``init_parameter`` in between re-processes the same points.
        """
        if self.n == 0:
            # Nothing to sketch; avoids indexing P[0, :] on empty input.
            return

        if not self.center:
            # Seed the sketch with the first point as the only centre.
            self.center = [self.P[0, :]]
            self.assignment[0] = 0
            self.weight_record.append(1)
            self.weight.append(1)
            self.index.append(0)
            self.cost.append(0)
            self.consistency.append(0)

        # Loop invariants: the sampling factor and the scales as a float array.
        sampling_factor = self.k * (1 + np.log10(self.n))
        opt_est = np.array(self.opt_est, dtype=float)

        for j in tqdm(range(1, self.n)):
            t0 = time.time()
            candidate = self.P[j, :]
            # Presumably dist() returns the distance from `candidate` to every
            # current centre -- confirm against tools.dist.
            d = dist(candidate, self.center)
            nn = np.argmin(d)

            probability = np.minimum(sampling_factor * (d[nn] ** 2) / opt_est, 1)
            # Re-seeding per point makes the sampling reproducible; it also
            # resets NumPy's *global* RNG state for the rest of the process.
            np.random.seed(j)
            trials = np.random.binomial(self.rep, probability)
            # A scale succeeds when at least one of its `rep` trials succeeds.
            # (`== 1` would reject every draw with rep > 1 and probability 1.)
            success_sample = np.where(
                np.logical_and(trials >= 1, self.size_cnt < self.size_cnt_bound)
            )
            self.size_cnt[success_sample] += 1

            if np.size(success_sample) > 0:
                # Open a new centre at this point; its weight starts at 0 and
                # is bumped to 1 by the shared increment below.
                self.center.append(candidate)
                weight_index = len(self.center) - 1
                self.assignment[j] = weight_index
                self.weight_record.append(1)
                self.weight.append(0)
            else:
                self.assignment[j] = nn
                weight_index = nn

            self.weight[weight_index] += 1
            if self.weight[weight_index] > 2 * self.weight_record[weight_index]:
                if len(self.center) > self.k:
                    self.index.append(j)
                    self.kmeans.fit(np.array(self.center), sample_weight=self.weight)
                    self.cost.append(self.kmeans.inertia_)

                    if np.size(self.solution) == 0:
                        # First solution: all k cluster centres count as new.
                        self.solution = copy.deepcopy(self.kmeans.cluster_centers_)
                        self.consistency.append(self.k)
                    else:
                        # Count cluster centres whose coordinates changed
                        # (row-wise comparison with the previous solution).
                        difference = np.sum(
                            (self.solution - self.kmeans.cluster_centers_) ** 2, axis=1
                        )
                        self.consistency.append(
                            self.consistency[-1] + np.count_nonzero(difference)
                        )
                        self.solution = copy.deepcopy(self.kmeans.cluster_centers_)
                # Record the weight so the next check fires only after it
                # doubles again.
                self.weight_record[weight_index] = self.weight[weight_index]
            self.time.append(time.time() - t0)


