import numpy as np
import copy
from RF import *
import torch
from connected_component import *
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import rand_score
import matplotlib.pyplot as plt
import matplotlib.pyplot as plot

# Per-dataset hyperparameters for run_MS, keyed by dataset name (same key order in every table).
_MS_DATASETS = ('a1', 'a2', 'a3', 'unbalance', 's1', 's2', 's3', 's4')

# Bandwidth handed to Gaussian_kernel by run_MS: 0.1 for a1-a3, 0.2 for all other datasets.
MS_clustering_bandwidth = {
	name: (0.1 if name in ('a1', 'a2', 'a3') else 0.2) for name in _MS_DATASETS
}
# Number of mean-shift update steps run_MS performs.
MS_clustering_iteration = dict.fromkeys(_MS_DATASETS, 100)
# Threshold handed to CC by run_MS when labelling the final kernel matrix.
MS_clustering_threshold = dict.fromkeys(_MS_DATASETS, 0.9)


def run_MS(X, y, count, dataset, *, bandwidth=None, iteration=None, threshold=None):
	"""Run mean-shift clustering on ``X`` and score the labels against ``y``.

	Each iteration replaces every row of ``current`` with the average of the
	original points ``X`` weighted by ``Gaussian_kernel(current, X, bandwidth)``,
	after normalising each row of kernel weights to sum to one.  ``X`` itself is
	never moved.  After ``iteration`` steps, the kernel matrix between the shifted
	points is passed to ``CC`` together with ``threshold`` to obtain cluster
	labels.

	Args:
		X: data points; presumably a 2-D numpy array of shape
			(n_samples, n_features) -- TODO confirm.  A ``torch.Tensor`` is
			accepted as well.
		y: ground-truth labels, presumably one per row of ``X``.
		count: unused; kept so existing callers keep working.
		dataset: key into the ``MS_clustering_*`` tables.  Only consulted for
			hyperparameters that are not passed explicitly.
		bandwidth: optional override for ``MS_clustering_bandwidth[dataset]``.
		iteration: optional override for ``MS_clustering_iteration[dataset]``.
		threshold: optional override for ``MS_clustering_threshold[dataset]``.

	Returns:
		Tuple ``(NMI, AMI, ARS, RS)``: normalized mutual information, adjusted
		mutual information, adjusted Rand score and Rand score between ``y``
		and the predicted labels.

	Raises:
		KeyError: if ``dataset`` is missing from a table whose value was not
			overridden.
	"""
	def _lookup(table, override, name):
		# An explicit override wins; otherwise fall back to the per-dataset table.
		if override is not None:
			return override
		try:
			return table[dataset]
		except KeyError:
			raise KeyError(
				f"no default MS {name} for dataset {dataset!r}; "
				f"known datasets: {sorted(table)}"
			) from None

	iteration = _lookup(MS_clustering_iteration, iteration, 'iteration')
	bandwidth = _lookup(MS_clustering_bandwidth, bandwidth, 'bandwidth')
	threshold = _lookup(MS_clustering_threshold, threshold, 'threshold')

	# as_tensor shares memory with numpy arrays (like from_numpy) and also accepts tensors.
	X = torch.as_tensor(X)
	# Independent copy that gets shifted; X stays fixed as the reference point set.
	current = X.clone()

	for _ in range(iteration):
		gram = Gaussian_kernel(current, X, bandwidth)
		# Row-normalise so each row of weights sums to 1.
		# NOTE(review): a row whose kernel values all underflow to 0 becomes NaN here.
		adjusted = gram / torch.sum(gram, 1).unsqueeze(1)
		# Mathematically adjusted @ X: weighted mean of the original points.
		current = (X.T @ adjusted.T).T

	final_gram = Gaussian_kernel(current, current, bandwidth)
	labels = CC(final_gram, threshold)

	# sklearn's signature is (labels_true, labels_pred).  All four scores are
	# symmetric, so this order only follows convention; it does not change values.
	NMI = normalized_mutual_info_score(y, labels)
	AMI = adjusted_mutual_info_score(y, labels)
	ARS = adjusted_rand_score(y, labels)
	RS = rand_score(y, labels)
	return NMI, AMI, ARS, RS