import numpy as np
import copy
from RF import *
import torch
from connected_component import *
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import rand_score
import matplotlib.pyplot as plt
import matplotlib.pyplot as plot

# Per-dataset hyper-parameters consumed by run_GMS, keyed by dataset name.

# Kernel width handed to Gaussian_kernel for both the ascent and the final Gram matrix.
GMS_clustering_bandwidth = {
	'a1': 0.1,
	'a2': 0.1,
	'a3': 0.1,
	'unbalance': 0.5,
	's1': 0.2,
	's2': 0.2,
	's3': 0.2,
	's4': 0.2,
}

# Number of update steps applied to every point.
GMS_clustering_iteration = {
	'a1': 100,
	'a2': 100,
	'a3': 100,
	'unbalance': 100,
	's1': 100,
	's2': 100,
	's3': 100,
	's4': 100,
}

# Threshold passed to CC together with the final Gram matrix to form clusters.
GMS_clustering_threshold = {
	'a1': 0.9,
	'a2': 0.9,
	'a3': 0.9,
	'unbalance': 0.9,
	's1': 0.9,
	's2': 0.9,
	's3': 0.9,
	's4': 0.9,
}

# Step size multiplying each update.
GMS_clustering_learning_rate = {
	'a1': 0.001,
	'a2': 0.001,
	'a3': 0.001,
	'unbalance': 0.001,
	's1': 0.003,
	's2': 0.003,
	's3': 0.003,
	's4': 0.003,
}


def run_GMS(X, y,count,dataset):
	"""Cluster ``X`` with Gaussian mean shift and score the result against ``y``.

	Every point starts at its own location and, for ``iteration`` steps, moves by
	``learning_rate`` times the mean-shift vector divided by ``bandwidth**2``::

	    update_i = sum_j K_ij (X_j - x_i) / (bandwidth**2 * sum_j K_ij)

	where ``K = Gaussian_kernel(current, X, bandwidth)``. The final Gram matrix
	of the moved points and ``threshold`` are handed to ``CC`` (defined
	elsewhere) to obtain cluster labels.

	Args:
		X: data matrix of shape (num, d); a numpy array or anything
			``torch.as_tensor`` accepts. Non-floating data is promoted to float64.
		y: ground-truth labels, one per row of ``X``.
		count: unused; kept so existing callers keep working.
		dataset: key selecting hyper-parameters from the ``GMS_clustering_*`` tables.

	Returns:
		Tuple ``(NMI, AMI, ARS, RS)`` of sklearn clustering scores
		(normalized MI, adjusted MI, adjusted Rand, Rand) between ``y`` and the labels.

	Raises:
		KeyError: if ``dataset`` is not present in the hyper-parameter tables.
	"""
	iteration = GMS_clustering_iteration[dataset]
	bandwidth = GMS_clustering_bandwidth[dataset]
	threshold = GMS_clustering_threshold[dataset]
	learning_rate = GMS_clustering_learning_rate[dataset]

	# as_tensor shares memory with a numpy input (like from_numpy) and also accepts tensors.
	X = torch.as_tensor(X)
	if not X.is_floating_point():
		# Integer coordinates would make the kernel-weighted matmul mix dtypes.
		X = X.to(torch.float64)
	current = X.clone()

	for _ in range(iteration):
		# NOTE(review): assumes Gaussian_kernel(A, B, h) returns a (len(A), len(B))
		# matrix with gram[i, j] = K(current[i], X[j]) -- the original bmm relied on this too.
		gram = Gaussian_kernel(current, X, bandwidth)
		# Unnormalised density at each moving point, kept 2-D (num, 1) so it
		# broadcasts over the feature axis even when d == 1 or num == 1.
		weight = gram.sum(1, keepdim=True)
		# sum_j K_ij (X_j - current_i) == gram @ X - weight * current; this avoids
		# materialising (num, num, d) difference tensors every iteration.
		# The density normaliser (2*pi*h^2)^(d/2) * num cancels between gradient
		# and density, so it is omitted: computing it explicitly under/overflows
		# for large d and turns the update into nan.
		update = (gram @ X - weight * current) / ((bandwidth ** 2) * weight)
		current = current + learning_rate * update

	final_gram = Gaussian_kernel(current, current, bandwidth)
	labels = CC(final_gram, threshold)

	# sklearn's signature is (labels_true, labels_pred); all four scores are
	# symmetric, so this ordering does not change the values.
	NMI = normalized_mutual_info_score(y, labels)
	AMI = adjusted_mutual_info_score(y, labels)
	ARS = adjusted_rand_score(y, labels)
	RS = rand_score(y, labels)
	return NMI, AMI, ARS, RS