import numpy as np
import copy
from RF import *
import torch
from connected_component import *
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import rand_score
import matplotlib.pyplot as plt
import matplotlib.pyplot as plot
from itertools import product

# Per-dataset settings looked up by run_MSPP via its `dataset` argument.

# Value read as `h`: the divisor that maps coordinates to grid cells.
MS_clustering_h = {
	'UserKnowledge': 0.5,
	'iris': 0.65,
	'WallRobot': 0.2,
	'WirelessLocalization': 0.5,
}

# Value read as `iteration`: the pass count at which the shifting loop stops.
MS_clustering_iteration = {
	'UserKnowledge': 100,
	'iris': 100,
	'WallRobot': 100,
	'WirelessLocalization': 100,
}

# Value read as `threshold`: forwarded to CC() together with the final kernel matrix.
MS_clustering_threshold = {
	'UserKnowledge': 0.9,
	'iris': 0.9,
	'WallRobot': 0.9,
	'WirelessLocalization': 0.9,
}


def _mspp_shift_step(points, h, offsets):
	"""Run one grid-based mean-shift pass and return the shifted points.

	Every point is assigned to the grid cell ``floor(point / h)``. Its new
	position is the mean of all points lying in that cell and in the cells
	reachable through ``offsets``.

	Args:
		points: float array of shape (n_points, n_dims) with non-negative
			coordinates (the caller shifts the data so each column starts at 0).
		h: grid cell width.
		offsets: iterable of length-n_dims integer tuples, one per neighbour
			cell to include (the caller passes every combination of -1/0/1).

	Returns:
		A new float array with the same shape as ``points``.
	"""
	n_dims = points.shape[1]

	# One spare cell per axis so floor(max / h) is always a valid index.
	shapes = np.ceil(points.max(axis=0) / h).astype(int) + 1
	cell_idx = np.floor(points / h).astype(int)
	own_cell = tuple(cell_idx.T)

	# Per-cell point count and coordinate sum. np.add.at accumulates
	# repeated indices, which a plain fancy-indexed `+=` would not.
	cell_count = np.zeros(tuple(shapes))
	cell_sum = np.zeros(shapes.tolist() + [n_dims])
	np.add.at(cell_count, own_cell, 1)
	np.add.at(cell_sum, own_cell, points)

	neighbor_sum = np.zeros_like(points)
	neighbor_count = np.zeros(len(points))

	for offset in offsets:
		neighbor_idx = cell_idx + offset
		# Neighbours outside the grid hold no points, so they are skipped.
		# Clamping them onto the border cell instead would add that cell
		# again for every clamped offset and bias the mean toward it.
		in_grid = np.all((neighbor_idx >= 0) & (neighbor_idx < shapes), axis=1)
		hit = tuple(neighbor_idx[in_grid].T)
		neighbor_sum[in_grid] += cell_sum[hit]
		neighbor_count[in_grid] += cell_count[hit]

	# Guard against division by zero; a point's own cell normally makes
	# the count at least 1.
	neighbor_count[neighbor_count == 0] = 1
	return neighbor_sum / neighbor_count.reshape(-1, 1)


def run_MSPP(X, y, count, dataset):
	"""Cluster X with grid-based mean shift and score the result against y.

	The data is shifted so every column starts at 0, then repeatedly moved
	with ``_mspp_shift_step`` using the per-dataset cell width and pass
	count. The final positions are turned into a kernel matrix with
	``Gaussian_kernel(points, points, 1)`` and labelled by
	``CC(kernel, threshold)``.
	NOTE(review): both helpers come from star imports; presumably CC takes
	connected components of the thresholded kernel graph -- confirm.

	Args:
		X: 2-D array-like of samples (n_points, n_dims); converted to float.
		y: reference labels, one per sample.
		count: unused; the original body overwrote it before reading it.
			Kept so existing callers keep working.
		dataset: key into MS_clustering_h, MS_clustering_iteration and
			MS_clustering_threshold.

	Returns:
		Tuple ``(NMI, AMI, ARS, RS)`` from the sklearn metrics
		normalized_mutual_info_score, adjusted_mutual_info_score,
		adjusted_rand_score and rand_score, each called as ``(labels, y)``.

	Raises:
		KeyError: if ``dataset`` is missing from a settings table.
	"""
	h = MS_clustering_h[dataset]
	iteration = MS_clustering_iteration[dataset]
	threshold = MS_clustering_threshold[dataset]

	# Work in float so integer input cannot break the in-place accumulation,
	# and on a fresh array so the caller's X is never modified.
	points = np.asarray(X, dtype=float)
	points = points - points.min(axis=0)

	# Every combination of -1/0/1 per axis: the cell itself plus all
	# adjacent cells.
	offsets = list(product((-1, 0, 1), repeat=points.shape[1]))

	# The original loop tested its counter after the pass, so at least one
	# pass always ran; max() preserves that for a non-positive setting.
	for _ in range(max(iteration, 1)):
		points = _mspp_shift_step(points, h, offsets)

	final_gram = Gaussian_kernel(points, points, 1)
	labels = CC(final_gram, threshold)

	NMI = normalized_mutual_info_score(labels, y)
	AMI = adjusted_mutual_info_score(labels, y)
	ARS = adjusted_rand_score(labels, y)
	RS = rand_score(labels, y)
	return NMI, AMI, ARS, RS