import numpy as np
import scipy
import math
import copy
from tqdm import tqdm
import random
from sklearn.neighbors import KDTree
from scipy.spatial.distance import cdist
from scipy.spatial.distance import pdist
import time
from tools import dist

class incremental_coreset:
  """Streaming (one-pass) coreset construction over the rows of a matrix ``P``.

  The class defines no ``__init__``: create an instance, call
  ``init(P, k, sample_rate)`` to set up all state, then ``fit()`` to stream
  every row of ``P`` through ``update_myerson``. Output accumulates in the
  parallel lists ``coreset_points`` / ``coreset_weight`` together with
  ``index``; the time spent on each update is appended to ``time``.
  """

  # Number of Bernoulli trials drawn per cost guess when deciding whether a
  # point opens a new center.
  rep=1
  # Number of cost guesses tracked in parallel:
  # opt_est = 10**0, 10**1, ..., 10**(double_time-1).
  double_time=12


  def init(self,P,k,sample_rate):
    """Initialise all state for streaming the rows of ``P``.

    Args:
      P: 2-D array indexed as ``P[i, :]``; each row is one stream point.
      k: Scale factor for the center-opening probability and the per-guess
        center budget (presumably the number of clusters -- confirm with
        callers).
      sample_rate: Multiplier for the sampling probability used by
        ``update_coreset``.
    """
    self.P=P
    self.k=k
    self.n=np.shape(self.P)[0]
    self.center=[]          # rows of P that have been opened as centers

    # Counters used only by update_coreset, indexed as Ring_cnt[center][ring]
    # and Group_cnt[ring][count_rank]; 100000 x 20 is a fixed capacity.
    self.Ring_cnt=np.zeros((100000,20))
    self.Group_cnt=np.zeros((100000,20))

    self.coreset_points=[]
    self.coreset_weight=[]
    self.sample_rate=sample_rate
    self.time=[]            # per-point update time recorded by fit()
    self.index=[]           # stream index recorded with each coreset entry

    # Geometric ladder of cost guesses (the name suggests estimates of the
    # optimal clustering cost).
    self.opt_est=[10**i for i in range(self.double_time)]
    self.center_weight=[]          # points currently assigned to each center
    self.center_weight_record=[]   # each center's weight at its last emission
    # Cap on how many centers any single cost guess may open.
    self.size_cnt_bound=4*self.k*(1+np.log10(self.n))
    self.size_cnt=np.zeros(self.double_time)  # centers opened per guess

  def update_myerson(self,input_index):
    """Feed stream point ``input_index`` through the online center-opening step.

    The first point always becomes a center and a coreset point of weight 1.
    Each later point is handled as follows:

    * Its distance ``d_nn`` to the nearest center (from ``dist``) gives, for
      each cost guess ``opt_est[i]``, an opening probability of
      ``min(k * (1 + log10 n) * d_nn**2 / opt_est[i], 1)``. This is
      Meyerson-style online facility location, per the method name.
    * If any guess that still has budget succeeds, the point is opened as a
      new center and added to the coreset with weight 1.
    * Otherwise the point is counted toward its nearest center.
    * Whenever a center's weight exceeds twice its last recorded weight, a
      copy of that center is appended to the coreset with weight
      ``1.5 * record``, and the record is updated.

    Side effect: reseeds NumPy's global RNG with ``input_index``.
    """
    candidate=self.P[input_index,:]

    if not self.center:
      # The very first point always opens a center and enters the coreset.
      self.center.append(candidate)
      self.center_weight_record.append(1)
      self.center_weight.append(1)

      self.coreset_points.append(candidate)
      self.coreset_weight.append(1)
      self.index.append(input_index)
      return

    # Presumably dist returns one distance per center -- confirm in tools.
    d=dist(candidate,self.center)
    nn=np.argmin(d)

    # Opening probability under every cost guess at once; same operation
    # order as the per-guess formula, so results are bit-identical.
    scale=self.k*(1+np.log10(self.n))
    probability=np.minimum(scale*(d[nn]**2)/np.asarray(self.opt_est,dtype=float),1)

    # Reseed from the stream position so each point's draws are reproducible.
    np.random.seed(input_index)
    trials=np.random.binomial(self.rep,probability)
    # A guess succeeds when at least one of its `rep` trials fired (a test of
    # `== 1` would reject multi-success draws when rep > 1) and the guess is
    # still under its center budget.
    success_sample=np.flatnonzero((trials>0)&(self.size_cnt<self.size_cnt_bound))
    self.size_cnt[success_sample]+=1

    if success_sample.size>0:
      # Open a new center; its weight starts at 0 and is bumped to 1 below.
      self.center.append(candidate)
      self.center_weight_record.append(1)
      self.center_weight.append(0)

      self.coreset_points.append(candidate)
      self.coreset_weight.append(1)
      self.index.append(input_index)
      weight_index=len(self.center)-1
    else:
      weight_index=nn

    self.center_weight[weight_index]=self.center_weight[weight_index]+1
    if self.center_weight[weight_index]>2*self.center_weight_record[weight_index]:
      # NOTE(review): this records the triggering stream index, while the
      # coreset point appended is the center itself -- confirm intended.
      self.index.append(input_index)
      self.coreset_points.append(self.center[weight_index])
      self.coreset_weight.append(1.5*self.center_weight_record[weight_index])

      self.center_weight_record[weight_index]=self.center_weight[weight_index]

  def update_coreset(self,input_index):
    """Ring-based importance sampling of stream point ``input_index``.

    Only points whose nearest-center distance exceeds 1000 are considered.
    For those points:

    * The ring is ``r = floor(log10(d))`` (so ``r >= 3``).
    * The point bumps ``Ring_cnt[nn][r]``, and then the group counter
      ``Group_cnt[r][floor(log10(Ring_cnt[nn][r]))]``.
    * It is kept with probability
      ``p = min(sample_rate * ln(n) / Group_cnt[r][rank], 1)`` and weight
      ``1/p``.

    Requires at least one center to exist. Uses the global NumPy RNG.
    NOTE(review): distances >= 1e20 give r >= 20, which is past the last
    column of Ring_cnt / Group_cnt.
    """
    candidate=self.P[input_index,:]
    d=dist(candidate,self.center)
    nn=np.argmin(d)

    if d[nn]>1000:
      r=int(np.floor(np.log10(d[nn])))
      self.Ring_cnt[nn][r]+=1

      cnt_rank=int(np.floor(np.log10(self.Ring_cnt[nn][r])))
      self.Group_cnt[r][cnt_rank]+=1

      p=min(self.sample_rate*np.log(self.n)/self.Group_cnt[r][cnt_rank],1)
      trial=np.random.binomial(1,p)
      if trial==1:
        # Inverse-probability weight keeps the weighted sample unbiased.
        self.coreset_points.append(self.P[input_index,:])
        self.coreset_weight.append(1/p)
        self.index.append(input_index)

  def fit(self):
    """Stream every row of ``P`` through ``update_myerson``, timing each call."""
    for i in tqdm(range(self.n)):
      # perf_counter is monotonic and high-resolution, unlike time.time(),
      # so it is the right clock for measuring elapsed durations.
      t0=time.perf_counter()
      self.update_myerson(i)
      # NOTE: update_coreset(i) is currently not called in this loop.
      self.time.append(time.perf_counter()-t0)

