from sklearn.cluster import MeanShift
import numpy as np
import matplotlib.pyplot as plt
from img_util import *
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.utils import resample
import torch
from RF import *
from sklearn.mixture import GaussianMixture
import copy


def run_MS_seg(img, width, height, count, data, target, *,
               scale=6, bandwidth=30, n_iter=100, cell_size=100):
    """Segment an image by (non-blurring) mean shift on downsampled colors.

    The image is resized to ``int(height / scale)`` x ``int(width / scale)``
    pixels and each pixel's three channel values become one point. Every
    point is then moved ``n_iter`` times to the kernel-weighted mean of the
    *original*, fixed point set. The converged points are binned on a grid
    with cells of side ``cell_size``; each occupied cell becomes one label.
    Each label is replicated over a ``scale`` x ``scale`` block to build a
    ``(height, width)`` label map.

    Args:
        img: Image passed to ``cv2.resize``; the resized copy must have
            exactly 3 channels.
        width: Output width in pixels.
        height: Output height in pixels.
        count: Unused.
        data: Unused (never read).
        target: Unused. NOTE(review): count/data/target presumably exist so
            callers can share one argument list — confirm before removing.
        scale: Integer downsampling factor per axis. Default 6.
        bandwidth: Bandwidth forwarded to ``Gaussian_kernel``. Default 30.
        n_iter: Number of mean-shift iterations. Default 100.
        cell_size: Grid spacing used to merge converged points into labels.
            Default 100.

    Returns:
        Tuple ``(labels, labels.flatten())``. ``labels`` is an integer array
        of shape ``(height, width)``. The bottom ``height % scale`` rows and
        the right ``width % scale`` columns are not covered by any block and
        are filled with 0.

    Raises:
        ValueError: If the downsampled grid would be empty, or if the resized
            image does not have exactly 3 channels.
    """
    rows, cols = int(height / scale), int(width / scale)
    if rows == 0 or cols == 0:
        raise ValueError(
            f"image of size {width}x{height} is smaller than one "
            f"{scale}x{scale} block")

    # cv2.resize takes dsize as (columns, rows).
    small = cv2.resize(img, (cols, rows), interpolation=cv2.INTER_LINEAR)
    if small.ndim != 3 or small.shape[2] != 3:
        raise ValueError(f"expected a 3-channel image, got shape {small.shape}")
    points = small.reshape(rows * cols, 3).astype(np.float64)

    # Standard mean shift: the weights compare the current positions with the
    # fixed original points, and each point moves to the weighted mean of
    # those ORIGINAL points (averaging the moving points instead would turn
    # this into a mix of standard and blurring mean shift).
    origin = torch.from_numpy(points)
    X = origin.clone()
    for _ in range(n_iter):
        gram = Gaussian_kernel(X, origin, bandwidth)  # presumably (N, N)
        weights = gram / gram.sum(dim=1, keepdim=True)  # rows sum to 1
        X = weights @ origin
    modes = X.numpy()

    # Points whose modes land in the same grid cell share a label.
    cells = np.floor(modes / cell_size).astype(int)
    _, labels = np.unique(cells, axis=0, return_inverse=True)
    # reshape (not an assumed 1-D inverse) also copes with NumPy versions
    # that return the inverse with an extra axis.
    labels = np.reshape(labels, (rows, cols))

    # Replicate each label over its scale x scale block of output pixels.
    full = np.repeat(np.repeat(labels, scale, axis=0), scale, axis=1)

    # Fill the uncovered remainder strip with 0.
    # NOTE(review): 0 is also a real cluster id, so these border pixels are
    # merged into cluster 0; mode='edge' may be the better choice.
    full = np.pad(
        full,
        ((0, height - full.shape[0]), (0, width - full.shape[1])),
        mode='constant',
        constant_values=0,
    )
    return full, full.flatten()



def run_BMS_seg(img, width, height, count, data, target, *,
                scale=6, bandwidth=30, n_iter=100, cell_size=100):
    """Segment an image by blurring mean shift on downsampled colors.

    The image is resized to ``int(height / scale)`` x ``int(width / scale)``
    pixels and each pixel's three channel values become one point. On each
    of ``n_iter`` iterations the whole point set is replaced by its own
    kernel-weighted means (blurring mean shift: the kernel and the average
    both use the current points). The final points are binned on a grid with
    cells of side ``cell_size``; each occupied cell becomes one label. Each
    label is replicated over a ``scale`` x ``scale`` block to build a
    ``(height, width)`` label map.

    Args:
        img: Image passed to ``cv2.resize``; the resized copy must have
            exactly 3 channels.
        width: Output width in pixels.
        height: Output height in pixels.
        count: Unused.
        data: Unused (never read).
        target: Unused. NOTE(review): count/data/target presumably exist so
            callers can share one argument list — confirm before removing.
        scale: Integer downsampling factor per axis. Default 6.
        bandwidth: Bandwidth forwarded to ``Gaussian_kernel``. Default 30.
        n_iter: Number of blurring iterations. Default 100.
        cell_size: Grid spacing used to merge final points into labels.
            Default 100.

    Returns:
        Tuple ``(labels, labels.flatten())``. ``labels`` is an integer array
        of shape ``(height, width)``. The bottom ``height % scale`` rows and
        the right ``width % scale`` columns are not covered by any block and
        are filled with 0.

    Raises:
        ValueError: If the downsampled grid would be empty, or if the resized
            image does not have exactly 3 channels.
    """
    rows, cols = int(height / scale), int(width / scale)
    if rows == 0 or cols == 0:
        raise ValueError(
            f"image of size {width}x{height} is smaller than one "
            f"{scale}x{scale} block")

    # cv2.resize takes dsize as (columns, rows).
    small = cv2.resize(img, (cols, rows), interpolation=cv2.INTER_LINEAR)
    if small.ndim != 3 or small.shape[2] != 3:
        raise ValueError(f"expected a 3-channel image, got shape {small.shape}")
    points = small.reshape(rows * cols, 3).astype(np.float64)

    # Blurring mean shift: the point set itself is replaced every iteration,
    # so both the kernel and the weighted average use the current points.
    X = torch.from_numpy(points)
    for _ in range(n_iter):
        gram = Gaussian_kernel(X, X, bandwidth)  # presumably (N, N)
        weights = gram / gram.sum(dim=1, keepdim=True)  # rows sum to 1
        X = weights @ X
    modes = X.numpy()

    # Points that end up in the same grid cell share a label.
    cells = np.floor(modes / cell_size).astype(int)
    _, labels = np.unique(cells, axis=0, return_inverse=True)
    # reshape (not an assumed 1-D inverse) also copes with NumPy versions
    # that return the inverse with an extra axis.
    labels = np.reshape(labels, (rows, cols))

    # Replicate each label over its scale x scale block of output pixels.
    full = np.repeat(np.repeat(labels, scale, axis=0), scale, axis=1)

    # Fill the uncovered remainder strip with 0.
    # NOTE(review): 0 is also a real cluster id, so these border pixels are
    # merged into cluster 0; mode='edge' may be the better choice.
    full = np.pad(
        full,
        ((0, height - full.shape[0]), (0, width - full.shape[1])),
        mode='constant',
        constant_values=0,
    )
    return full, full.flatten()


























