import numpy as np
import scipy
import math
from sklearn.cluster import KMeans
import copy
from tqdm import tqdm
import random
from sklearn.neighbors import KDTree
from scipy.spatial.distance import cdist
from scipy.spatial.distance import pdist
import pickle
import time
from tools import dist

def constrained_kmeans(data, weight,potential_centers, k, max_iterations=100):
    """Weighted Lloyd-style k-means whose centers must come from a candidate set.

    Args:
        data: array-like of shape (n_points, dim) with the points to cluster.
        weight: array-like of shape (n_points,) with one weight per point.
        potential_centers: array-like of shape (n_candidates, dim); every
            returned center is one of these rows.
        k: number of centers to select (must not exceed n_candidates).
        max_iterations: upper bound on assignment/update rounds.

    Returns:
        (centers, labels, cost) where ``centers`` has shape (k, dim),
        ``labels[j]`` is the index of the center nearest to ``data[j]`` and
        ``cost`` is the weighted sum of squared distances from each point to
        its nearest returned center.

    Raises:
        ValueError: propagated from ``np.random.choice`` when ``k`` exceeds the
            number of candidates.
    """
    # Accept plain lists as well as arrays; fancy indexing below needs ndarrays.
    data = np.asarray(data)
    weight = np.asarray(weight)
    potential_centers = np.asarray(potential_centers)

    # Randomly choose initial centers from the potential set
    centers = potential_centers[np.random.choice(len(potential_centers), k, replace=False)]

    for _ in range(max_iterations):
        # Assign each point to the nearest center
        labels = np.argmin(cdist(data, centers), axis=1)

        new_centers = []
        for i in range(k):
            members = labels == i
            if not members.any():
                # If no points are assigned, choose a random potential center
                new_center = potential_centers[np.random.choice(len(potential_centers))]
            else:
                # Snap the weighted mean of the cluster to the closest candidate
                member_weight = weight[members]
                mean_point = np.sum(data[members]*member_weight[:, np.newaxis], axis=0)/np.sum(member_weight)
                new_center = potential_centers[np.argmin(cdist(potential_centers, [mean_point]))]

            new_centers.append(new_center)

        new_centers = np.array(new_centers)
        if np.all(centers == new_centers):
            break
        centers = new_centers

    # Recompute the assignment against the centers that are actually returned,
    # so labels exist even when max_iterations is 0 and are never stale when
    # the loop stops without converging.
    final_distances = cdist(data, centers)
    labels = np.argmin(final_distances, axis=1)

    # Weighted objective, matching the weighted centroid update above
    # (identical to the unweighted sum when all weights are 1).
    cost = np.sum(weight * np.min(final_distances, axis=1)**2)

    return centers, labels, cost

class Consistent_Clustering:
   """Simulation harness that feeds points one prefix at a time and tracks centers.

   ``simulate`` grows a prefix of ``total_P``, maintains a center set ``C`` and,
   alongside it, refits a scikit-learn ``KMeans`` model as a baseline; costs,
   timings and counts of changed centers are appended to the history lists.
   """
   # Class-level defaults. The list-valued ones are shared by every instance
   # until an instance rebinds them (init() and test() do so for most of them).

   # Not referenced by any method visible in this file.
   z=2
   # Number of centers (passed to KMeans as n_clusters in simulate()).
   k=5
   # Exclusive upper bound for the random integer coordinates generated below
   # and in gen_point().
   aspect_ratio=1000000
   # Number of leading rows of total_P currently in use (see simulate()).
   n=0
   # Total number of points.
   total_cnt=50
   # Evaluated once at class creation with n == 0, i.e. an empty (0, 2) array.
   total_P=np.random.randint(0,aspect_ratio,[n,2])
   # One weight per point.
   total_weight=np.ones(total_cnt)
   # Current centers, one row per center.
   C=np.zeros((k,2))
   # Per-point squared distance to the nearest center (see nearest_neighbor()).
   cost=np.zeros(total_cnt)
   # Weighted sum of `cost` (see nearest_neighbor()).
   total_cost=0
   # Not referenced by any method visible in this file.
   nn=[]
   # Not referenced as an attribute; detect_deleting_number() uses a local of the same name.
   del_num=0
   # Not referenced by any method visible in this file.
   cluster=[[]]
   # Per-center list of points used by robustify(); simulate() seeds each entry with [C[i]].
   wit=[]
   # Histories appended by simulate(): baseline KMeans inertia and the cost of C.
   cost_kmeans=[]
   cost_alg=[]
   # Base of the ball radius robust_parameter**i used in robustify().
   robust_parameter=2
   # Divisor applied to neighbour distances in the swap test of swap_center().
   wsp_parameter=0.001
   # Multiplier on total_cost used as the threshold in detect_deleting_number().
   del_parameter=10
   # Integers here, but init()/test() rebind both to lists and simulate() appends to them.
   consistency=0
   baseline_consistency=0
   # Written by detect_deleting_number(), read by simulate().
   inter_cost=0
   # Stored by test(); not read by any method visible in this file.
   start_point=0
   # Per-step elapsed times appended by simulate().
   time=[]

   def __init__(self):
      """Create an instance with a small random default problem.

      Sets k=5 and 50 random integer points in [0, 10000)^2 with unit weights,
      and gives the instance its own empty history containers so that
      instances never share (or mutate) the class-level default lists and
      ``simulate()`` can append to every history it uses.
      """
      self.z = 2
      self.k = 5
      self.aspect_ratio=10000
      self.total_cnt=50
      self.total_P=np.random.randint(0,self.aspect_ratio,[self.total_cnt,2])
      self.total_weight=np.ones(self.total_cnt)

      # Per-instance mutable state. Without these the class-level lists would
      # be shared across instances, `consistency`/`baseline_consistency` would
      # still be the integer defaults (no .append), and `time_brute_force`
      # would not exist at all.
      self.time=[]
      self.time_brute_force=[]
      self.wit=[]
      self.cost_alg=[]
      self.cost_kmeans=[]
      self.consistency=[]
      self.baseline_consistency=[]
      self.nn=[]
      self.cluster=[[]]

   def init(self,k,P,weight,del_parameter=10,wsp_parameter=1e-6):
      """Load a problem instance and reset all recorded histories.

      Args:
          k: number of centers.
          P: array-like of points, one row per point; copied into ``total_P``.
          weight: array-like of per-point weights; copied into ``total_weight``.
          del_parameter: threshold multiplier stored in ``self.del_parameter``.
          wsp_parameter: separation divisor stored in ``self.wsp_parameter``.
      """
      self.k=k

      # Problem data; the point count is the leading dimension of the copy.
      points=np.array(P)
      self.total_P=points
      self.total_cnt=points.shape[0]
      self.total_weight=np.array(weight)

      # Tunables.
      self.wsp_parameter=wsp_parameter
      self.del_parameter=del_parameter

      # Every history starts as its own fresh empty list.
      for history in ("consistency","baseline_consistency","cost_alg",
                      "cost_kmeans","wit","time","time_brute_force"):
         setattr(self,history,[])
      self.total_cost=0


   def gen_point(self): #only for test
      np.random.seed(0)
      self.total_P=np.random.randint(0,self.aspect_ratio,[self.total_cnt,2])
      total_weight=np.ones(self.total_cnt)

   def test(self,k,del_parameter=10,wsp_parameter=0.00001,start_point=0):
      self.gen_point()
      self.del_parameter=del_parameter
      self.wsp_parameter=wsp_parameter
      self.consistency=[0]
      self.baseline_consistency=[0]
      self.cost_alg=[]
      self.cost_kmeans=[]
      self.wit=[]
      self.total_cost=0
      self.start_point=start_point



   def nearest_neighbor(self):
      self.cost=np.min(cdist(self.P,self.C),axis=1)**2
      self.total_cost=np.sum(self.cost*self.weight)

   def robustify(self,center_index):
      """Check, and if needed rebuild, the point chain ``self.wit[center_index]``.

      The method derives an integer ``level`` from the smallest non-zero
      distance between center ``center_index`` and the rows of ``self.C``. It
      then (a) walks the existing chain from its last element towards the
      first, breaking out on the tests below, and (b) unless it returns early,
      rebuilds the chain starting from ``self.C[center_index]`` and writes the
      chain's first element back into ``self.C[center_index]``.

      Reads ``self.C``, ``self.wit``, ``self.tree``, ``self.P``, ``self.weight``
      and ``self.robust_parameter``. Returns ``None``.

      NOTE(review): no method visible in this file calls robustify, and several
      statements below look unfinished (see the inline notes) -- confirm the
      intended behaviour before relying on it.
      """
      # --- robust level ---
      # NOTE(review): for an integer center_index, self.C[center_index] is a
      # 1-D row, while scipy's cdist is documented for 2-D inputs -- confirm.
      center_dist=cdist(self.C[center_index],self.C)
      # Smallest non-zero center distance, scaled by 1/200, on a log10 scale
      # and clamped at 0. (The radii below use powers of robust_parameter.)
      level=max(0,np.ceil(np.log10(np.min(center_dist[np.nonzero(center_dist)])/200)).astype(int))

      # --- check robustness of the existing chain ---
      if level<=len(self.wit[center_index])-1:
        for i in range(len(self.wit[center_index])-1,0,-1):
          # Query the tree for points within radius robust_parameter**i of chain element i.
          # NOTE(review): the query point is passed as a 1-D array -- confirm
          # the tree accepts that shape.
          indices = self.tree.query_radius(np.array(self.wit[center_index][i]), self.robust_parameter**i)[0]

          # NOTE(review): np.size(indices==0) is the element count of a boolean
          # array, so this is truthy whenever `indices` is non-empty; the
          # rebuild loop below tests np.size(indices)==0 instead -- confirm.
          if np.size(indices==0):
            break

          # Extract the points within the radius
          points_within_r = self.P[indices,:]

          # Squared distance of every ball point to chain element i (one value per point).
          squared_distances = np.sum((points_within_r - self.wit[center_index][i])**2, axis=1)

          # NOTE(review): squared_distances is an array, so avg_cost is an array
          # too (no sum over points, no per-point weighting); the `if` tests
          # below are only unambiguous for a single-element array. The `!=`
          # comparisons between chain elements are element-wise as well.
          avg_cost=squared_distances/np.sum(self.weight[indices])
          # Break when element i-1 differs from element i although avg_cost is
          # at or above robust_parameter**(2*i)/50 ...
          if avg_cost>=(self.robust_parameter)**(2*i)/50 and self.wit[center_index][i]!=self.wit[center_index][i-1]:
            break
          else:
            # ... or when avg_cost is below that bound but element i-1 is not
            # the weighted mean of the ball points.
            if avg_cost<(self.robust_parameter)**(2*i)/50 and self.wit[center_index][i-1]!=((np.sum(self.P[indices,:]*self.weight[indices,np.newaxis],axis=0))/np.sum(self.weight[indices])):
              break

      # NOTE(review): `i` is only bound if the loop above executed at least one
      # iteration; it is 1 both when the loop ran to completion and when it
      # broke on its last iteration -- confirm the intended condition.
      if i==1:
        return

      # --- robustify: rebuild the chain from the current center ---
      self.wit[center_index]=[self.C[center_index]]
      for i in range(level,-1,-1):
        # The most recently inserted element is always at the front.
        c=self.wit[center_index][0]
        # Query the tree for points within radius robust_parameter**i of c.
        indices = self.tree.query_radius(np.array(c), self.robust_parameter**i)[0]

        # Extract the points within the radius
        points_within_r = self.P[indices,:]

        # Squared distance of every ball point to c (one value per point).
        squared_distances = np.sum((points_within_r - np.array(c)) ** 2, axis=1)

        if(np.size(indices)==0):
          # Empty ball: pad the front with i+1 copies of the current front
          # element (one per remaining level) and stop.
          for j in range(i,-1,-1):
            self.wit[center_index].insert(0,self.wit[center_index][0])
          break
        # NOTE(review): array-valued, same concern as in the check loop above.
        avg_cost=squared_distances/np.sum(self.weight[indices])

        if avg_cost>=(self.robust_parameter)**(2*i)/50:
          # Keep the same point for this level.
          self.wit[center_index].insert(0,c)
        else:
          # NOTE(review): weight_mul is computed but never used.
          weight_mul=np.append(self.weight[indices],self.weight[indices])
          weight_mul=weight_mul.reshape((2,-1)).transpose()
          # Move to the weighted mean of the points inside the ball.
          new_center=((np.sum(self.P[indices,:]*self.weight[indices,np.newaxis],axis=0))/np.sum(self.weight[indices]))
          self.wit[center_index].insert(0,new_center.flatten())
      # The front of the rebuilt chain becomes the center itself.
      self.C[center_index]=self.wit[center_index][0]



   def detect_deleting_number(self):
      """Return how many centers can be dropped before the cost threshold is hit.

      For ``removed = 0, 1, ..., k-2`` the current points are re-clustered with
      ``k-removed-1`` centers chosen from ``self.C`` via ``constrained_kmeans``.
      At the first ``removed`` whose cost reaches
      ``self.del_parameter * self.total_cost`` the method stores the cost of
      the previous trial (0 if there was none) in ``self.inter_cost`` and
      returns ``removed``.

      If no trial reaches the threshold (including ``k <= 1``, where no trial
      is run) the result is 0 and ``self.inter_cost`` is left untouched.
      """
      previous_cost=0
      for removed in range(self.k-1):
        remaining=self.k-removed-1
        _,_,trial_cost=constrained_kmeans(self.P,self.weight,self.C,remaining)
        if trial_cost>=self.del_parameter*self.total_cost:
          self.inter_cost=previous_cost
          return removed
        previous_cost=trial_cost
      return 0

   def swap_center(self):
    #an approximate solution


    #print(self.kmeans.cluster_centers_)
    candidate_center=copy.deepcopy(self.kmeans.cluster_centers_)
    inter_center=candidate_center


    #print('current center:',self.C)
    #print('candidate center:',candidate_center)


    swap_cnt=self.k
    new_cluster=[]

    dist_load=cdist(self.P[self.phase_begin:self.n],candidate_center)
    nn_indices=np.argmin(dist_load,axis=1)


    new_cluster=(set(nn_indices))

    candidate_swapping=list((set(range(self.k))-(new_cluster)))
    #print(candidate_center)
    if(len(candidate_swapping)>1):

      tree_V=KDTree(candidate_center[candidate_swapping])
      tree_U=KDTree(self.C)

      distance_cross,nearest_cross=tree_U.query(candidate_center[candidate_swapping],k=1)
      distance_V, _=tree_V.query(candidate_center[candidate_swapping],k=2)
      distance_U, _=tree_U.query(self.C,k=2)

      distance_cross=distance_cross.flatten()
      nearest_cross=nearest_cross.flatten()
      #print(np.shape(distance_cross))

      #print(distance_cross)
      #print(distance_cross<distance_U[nearest_cross,1]/self.wsp_parameter)
      #print(distance_cross<distance_V[:,1]/self.wsp_parameter)

      swapping_index=np.where(np.logical_and((distance_cross<(distance_U[nearest_cross,1]/self.wsp_parameter)).flatten(), (distance_cross<(distance_V[:,1]/self.wsp_parameter)).flatten()))[0]
      #print(swapping_index)

      #print('before:',inter_center)
      #print(candidate_swapping[0])
      for i in swapping_index:
        inter_center[candidate_swapping[i]]=self.C[nearest_cross[i]]


    self.C=inter_center
    #print('after:',self.C)

    #print(self.kmeans.cluster_centers_)
    #print('swap',swap_cnt,'points','\n')
    #print('after swapping',self.C,'\n')


   def simulate(self):
    """Run the incremental simulation over ``self.total_P`` and record histories.

    The first ``k`` points are clustered with scikit-learn ``KMeans``; the
    resulting centers seed ``self.C`` and ``self.wit``. Afterwards the method
    works in phases: ``detect_deleting_number`` yields ``ell``, up to ``ell``
    intermediate steps are recorded, ``self.n`` advances by ``ell+1``, and
    ``swap_center`` produces the new ``self.C``. Throughout, a ``KMeans`` refit
    serves as the baseline.

    Appends to ``self.time``, ``self.time_brute_force``, ``self.cost_alg``,
    ``self.cost_kmeans``, ``self.consistency`` and ``self.baseline_consistency``
    (all expected to be lists already). Prints a ``n/total`` progress line per
    phase. Returns ``None``.
    """
    # Too few points for k clusters plus one more: use the points themselves
    # as centers, record zero cost and total_cnt changes, and stop.
    if self.total_cnt<self.k+1:
      self.time=[0 for _ in range(self.total_cnt)]
      self.C=self.total_P
      self.cost_kmeans.append(0)
      self.cost_alg.append(0)
      self.consistency.append(self.total_cnt)
      self.baseline_consistency.append(self.total_cnt)
      return

    
    self.n=self.k

    # k zero entries in every history before the first real measurement.
    for i in range(self.k):
      self.time.append(0)
      self.time_brute_force.append(0)
      self.consistency.append(0)
      self.baseline_consistency.append(0)
      self.cost_alg.append(0)
      self.cost_kmeans.append(0)

    # --- initial solution on the first k points ---
    t0=time.time()
    self.kmeans = KMeans(n_clusters=self.k, random_state=0, n_init="auto")
    self.kmeans.fit(self.total_P[0:self.n,:],sample_weight=self.total_weight[0:self.n])
    self.C=(self.kmeans.cluster_centers_)
    self.P=self.total_P[0:self.n,:]
    self.weight=self.total_weight[0:self.n]
    self.tree=KDTree(self.P)
    # The initial solution is recorded as k center changes for both methods.
    self.consistency.append(self.k)
    self.baseline_consistency.append(self.k)
    # Every center starts with a one-element chain containing itself.
    for i in range(self.k):
      self.wit.append([self.C[i]])

    self.time.append(time.time()-t0)
    self.time_brute_force.append(time.time()-t0)
    self.phase_begin=self.n

    # Both cost histories start from the same k-means inertia.
    self.cost_kmeans.append(self.kmeans.inertia_)
    self.cost_alg.append(self.kmeans.inertia_)


    # --- one iteration per phase ---
    while self.n <self.total_cnt:
      print('%d/%d'%(self.n,self.total_cnt),end='\r')
      # Remember the centers at phase start (swap_center rebinds self.C, so
      # this reference keeps pointing at the old array).
      self.phase_begin_C=self.C
      self.phase_begin=self.n
      # Snapshot of the baseline centers at phase start.
      base_C=copy.deepcopy(self.kmeans.cluster_centers_)

      t0=time.time()
      ell=self.detect_deleting_number()

      # Intermediate steps of the phase: the algorithm's cost is recorded as
      # inter_cost with zero time, while the baseline is refit each step.
      for i in range(min(self.total_cnt-1-self.n,ell)):
        t_brute=time.time()
        self.cost_alg.append(self.inter_cost)
        # NOTE(review): the slice [0:self.n+i] excludes row n+i, so the first
        # iteration (i == 0) refits on the same rows as at phase start --
        # confirm whether n+i+1 was intended.
        self.P=self.total_P[0:self.n+i,:]
        self.weight=self.total_weight[0:self.n+i]
        self.kmeans.fit(self.P,sample_weight=self.weight)
        self.cost_kmeans.append(self.kmeans.inertia_)

        # Carry the running totals forward; only the baseline may grow here.
        self.consistency.append(self.consistency[-1])
        self.baseline_consistency.append(self.baseline_consistency[-1])

        # Count baseline centers that do not coincide with any phase-start
        # baseline center (non-zero distance to the nearest one).
        dist_load=np.min(cdist(self.kmeans.cluster_centers_,base_C),axis=1)
        self.baseline_consistency[-1]=self.baseline_consistency[-1]+np.count_nonzero(dist_load)
        self.time.append(0)
        self.time_brute_force.append(time.time()-t_brute)

      self.n=self.n+ell+1
      # NOTE(review): with `>=` the loop also exits when n lands exactly on
      # total_cnt, i.e. without a final step on the full point set -- confirm.
      if self.n>=self.total_cnt:
        break

      # Add ell to the algorithm's running change count.
      self.consistency[-1]=self.consistency[-1]+ell

      

      self.P=self.total_P[0:self.n,:]
      self.weight=self.total_weight[0:self.n]
      self.tree=KDTree(self.P)

      # swap_center reads self.kmeans.cluster_centers_, which at this point
      # still come from the most recent fit above; the refit on the new
      # self.P only happens below.
      self.swap_center()

      self.time.append(time.time()-t0)


      # Baseline: full refit on the current prefix.
      t0_brute=time.time()
      self.kmeans.fit(self.P,sample_weight=self.weight)

      base_end_C=copy.deepcopy(self.kmeans.cluster_centers_)
      self.baseline_consistency.append(self.baseline_consistency[-1])
      self.consistency.append(self.consistency[-1])

      # Centers of the algorithm that differ from every phase-start center.
      dist_load=np.min(cdist(self.C,self.phase_begin_C),axis=1)
      self.consistency[-1]=self.consistency[-1]+np.count_nonzero(dist_load)

      # Same count for the baseline centers.
      dist_load=np.min(cdist(base_end_C,base_C),axis=1)
      self.baseline_consistency[-1]=self.baseline_consistency[-1]+np.count_nonzero(dist_load)


      # Refresh self.cost / self.total_cost for the new self.C.
      self.nearest_neighbor()
      
      self.cost_kmeans.append(self.kmeans.inertia_)
      self.time_brute_force.append(time.time()-t0_brute)
      self.cost_alg.append(self.total_cost)


