from PIL import Image, ImageCms, ImageFilter
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import matplotlib
import random
from sklearn import preprocessing
import cv2
from mpl_toolkits.mplot3d import Axes3D
from os import walk


def load_img(id):
	"""Load one BSDS500 image as a LAB array together with a ground-truth entry.

	The JPEG at data/BSDS500/images/<id>.jpg is smoothed (Gaussian blur of
	radius 1, then a 3x3 median filter), converted to RGB and finally
	transformed from sRGB to LAB through ImageCms.

	Args:
		id: image identifier; it is passed through str() to build both paths.

	Returns:
		A tuple (lab_pixels, ground_truth): the LAB image as a numpy array and
		the element at index [0][0][0][0][0] of the 'groundTruth' entry of
		data/BSDS500/groundTruth/<id>.mat.
	"""
	stem = str(id)

	# Denoise before the colour conversion; the filter order matters.
	smoothed = (
		Image.open('data/BSDS500/images/' + stem + '.jpg')
		.filter(ImageFilter.GaussianBlur(radius = 1))
		.filter(ImageFilter.MedianFilter(size=3))
	)
	rgb_image = smoothed.convert('RGB')

	mat_contents = scipy.io.loadmat('data/BSDS500/groundTruth/' + stem + '.mat')

	# Build an sRGB -> LAB transform from Pillow's built-in profiles.
	to_lab = ImageCms.buildTransformFromOpenProfiles(
		ImageCms.createProfile("sRGB"),
		ImageCms.createProfile("LAB"),
		"RGB",
		"LAB",
	)
	lab_pixels = np.array(ImageCms.applyTransform(rgb_image, to_lab))

	# NOTE(review): presumably this selects the first annotator's first field
	# (the segmentation map) - confirm against the .mat layout.
	first_entry = mat_contents['groundTruth'][0][0]
	ground_truth = first_entry[0][0][0]
	return lab_pixels, ground_truth


def show_img(img):
	"""Render `img` with matplotlib under the title 'Image' and show the figure.

	Args:
		img: anything accepted by plt.imshow.

	Returns:
		None.
	"""
	plt.imshow(img)
	plt.title('Image')
	plt.show()


def get_mapping():
	"""Build a 256-level colormap for displaying label images.

	The first anchor colour is white (1, 1, 1); the remaining 255 anchors are
	random RGB triples drawn with random.random().

	Returns:
		A matplotlib LinearSegmentedColormap named 'new_map'.
	"""
	palette = [(1,1,1)]
	for _ in range(255):
		# Three draws per colour, in R, G, B order.
		palette.append((random.random(), random.random(), random.random()))
	return matplotlib.colors.LinearSegmentedColormap.from_list('new_map', palette, N=256)


def show_labels(labels,mapping):
	"""Render a label image with the given colormap under the title 'Labels'.

	Args:
		labels: array-like passed to plt.imshow.
		mapping: colormap forwarded as imshow's `cmap` argument.

	Returns:
		None.
	"""
	plt.imshow(labels, cmap=mapping)
	plt.title('Labels')
	plt.show()



def relabel(y):
	"""Map arbitrary labels to consecutive integers in order of first appearance.

	Args:
		y: iterable of hashable labels (it is traversed twice).

	Returns:
		A tuple (new_labels, count): a numpy array holding the remapped label
		of every element of `y`, and the number of distinct labels found.
	"""
	index_of = {}
	for label in y:
		# The next free index equals the number of labels seen so far.
		index_of.setdefault(label, len(index_of))
	remapped = np.array([index_of[label] for label in y])
	return remapped, len(index_of)


def norm_data(X):
	"""Standardize `X` with a StandardScaler fitted on `X` itself.

	Args:
		X: 2-D samples-by-features data accepted by sklearn's StandardScaler.

	Returns:
		The transformed data returned by the scaler.
	"""
	# fit() returns the scaler, so the fitted instance can transform directly.
	return preprocessing.StandardScaler().fit(X).transform(X)


def create_dataset(img,labels,norm):
	"""Flatten an image and its label map into per-pixel samples and targets.

	Pixels are emitted in row-major order (row by row, left to right), and the
	labels are remapped to consecutive integers with relabel().

	Args:
		img: array of shape (height, width, 3).
		labels: 2-D label map covering at least height x width entries; only
			the top-left height x width region is used.
		norm: when truthy, the samples are standardized with norm_data().

	Returns:
		A tuple (width, height, count, data, target) where `data` is a float
		array of shape (height * width, 3), `target` holds the relabelled
		class of every pixel and `count` is the number of distinct labels.

	Raises:
		ValueError: if `img` is not 3-dimensional with exactly 3 channels.
		IndexError: if `labels` does not cover the whole image.
	"""
	pixels = np.asarray(img)
	height, width, channels = pixels.shape
	if channels != 3:
		raise ValueError('expected an image with 3 channels, got %d' % channels)

	label_grid = np.asarray(labels)[:height, :width]
	if label_grid.shape != (height, width):
		raise IndexError('labels must be a 2-D map covering the %dx%d image' % (height, width))

	# A C-order reshape visits pixels exactly like nested "for h / for w"
	# loops would, without iterating over every pixel in Python.
	data = pixels.reshape(height * width, channels).astype('float')
	target, count = relabel(label_grid.ravel())

	if norm:
		data = np.array(norm_data(data)).astype('float')

	return width, height, count, data, target




def post_processing(centers,preds):
	"""Print the distinct predicted labels with their frequencies.

	Args:
		centers: not used by this function.
		preds: predicted labels; passed to np.unique and returned untouched.

	Returns:
		`preds`, the very same object that was passed in.
	"""
	banner = "===Post-processing==="
	print(banner)

	distinct, frequencies = np.unique(preds, return_counts=True)
	print(distinct, frequencies)

	print(banner)
	return preds



def get_all_id():
	"""Return the integer ids of the files directly inside the image folder.

	Each file name in data/BSDS500/images/ is cut at its first '.' and the
	leading part is parsed as an int. Files whose leading part is not an
	integer (for example 'Thumbs.db' or '.DS_Store') are skipped instead of
	aborting the whole listing. Sub-directories are not visited.

	Returns:
		A list of ints in the order reported by os.walk; empty when the
		directory does not exist or holds no matching file.
	"""
	# Only the first tuple produced by walk() (the top-level directory) is
	# needed; the default covers a missing directory, where walk() yields nothing.
	_, _, filenames = next(walk('data/BSDS500/images/'), (None, None, ()))

	ids = []
	for filename in filenames:
		stem = filename.split('.')[0]
		try:
			ids.append(int(stem))
		except ValueError:
			# Not an image id; ignore stray files rather than crash.
			continue
	return ids



def save_img(_id,_method,img,mapping):
	"""Save an image as result/<_id>_<_method>.png.

	A 2-D array is written through the colormap `mapping`. A 3-D array is
	divided by 255 and clamped to [0, 1] before being written as colour data.
	Arrays of any other rank are not written.

	Args:
		_id: image identifier; converted with str(), so both the integer ids
			returned by get_all_id() and plain strings are accepted.
		_method: name of the method, used as the file-name suffix.
		img: array with a `shape` attribute.
		mapping: colormap used only for 2-D input.

	Returns:
		Always 1.
	"""
	# str() keeps integer ids from raising TypeError during concatenation.
	out_path = 'result/' + str(_id) + '_' + _method + '.png'

	rank = len(img.shape)
	if rank == 2:
		plt.imsave(out_path, img, cmap=mapping)
	elif rank == 3:
		# Scale to [0, 1]; clip out-of-range values so imsave accepts them.
		plt.imsave(out_path, np.clip(img / 255, 0, 1))
	return 1









def seg(img,new_img):
	"""Segment `new_img` with a marker-based watershed and plot the result.

	Markers are derived from an Otsu-thresholded, morphologically opened
	grayscale version of `new_img`; the watershed itself runs on a copy of
	`img`. Three panels are then shown: the original image, the image with
	watershed boundaries painted [255, 0, 0], and a randomly coloured
	segmentation.

	Args:
		img: image shown as the original and handed to cv2.watershed.
		new_img: image the markers are computed from; it is cast to uint8 and
			read as RGB.

	Returns:
		None.
	"""
	filtered = new_img.astype(np.uint8)
	grayscale = cv2.cvtColor(filtered, cv2.COLOR_RGB2GRAY)

	# With THRESH_OTSU the threshold value is chosen automatically.
	binary_mask = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

	# Opening removes small specks from the mask.
	structuring = np.ones((3,3), np.uint8)
	opened = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, structuring, iterations=2)

	# Sure foreground: pixels far from the mask border (>= 70% of the max distance).
	distances = cv2.distanceTransform(opened, cv2.DIST_L2, 5)
	sure_fg = cv2.threshold(distances, 0.7 * distances.max(), 255, 0)[1]
	# Sure background: the dilated mask.
	sure_bg = cv2.dilate(opened, structuring, iterations=3)
	sure_fg = sure_fg.astype(np.uint8)
	undecided = cv2.subtract(sure_bg, sure_fg)

	# Shift the component labels by one so 0 is free to mean "unknown" for watershed.
	markers = cv2.connectedComponents(sure_fg)[1] + 1
	markers[undecided == 255] = 0

	# watershed() updates `markers` in place; boundary pixels end up as -1.
	outlined = img.copy()
	cv2.watershed(outlined, markers)
	outlined[markers == -1] = [255, 0, 0]

	# One random colour per marker value, except 1 (the background).
	colored = np.zeros_like(img)
	for region in np.unique(markers):
		if region != 1:
			colored[markers == region] = np.random.randint(0, 255, 3, dtype=np.uint8)

	# Show the three results side by side.
	panels = (
		('Original Image', img),
		('Watershed Boundaries', outlined),
		('Segmented Image (Colored)', colored),
	)
	plt.figure(figsize=(15, 5))
	for position, (caption, picture) in enumerate(panels, start=1):
		plt.subplot(1, 3, position)
		plt.title(caption)
		plt.imshow(picture)
		plt.axis('off')

	plt.show()








if __name__ == '__main__':
	# When run as a script, list every image id found in the dataset folder.
	available_ids = get_all_id()
	print(available_ids)

	# The string literal below is an inert expression statement: a disabled
	# demo that loads one image, shows it with its labels and scatter-plots
	# 1000 shuffled pixels in 3-D.
	'''
	img, labels = load_img(3063)
	print(img.shape)
	print(labels.shape)

	mapping = get_mapping()
	show_img(img)
	show_labels(labels,mapping)

	width, height, count, data, target = create_dataset(img,labels,False)
	print(width)
	print(height)
	print(count)
	print(data.shape)
	print(target.shape)


	fig = plt.figure()
	ax = fig.add_subplot(111, projection='3d')
	np.random.shuffle(data)
	ax.scatter(data[:1000,0], data[:1000,1], data[:1000,2])
	ax.set_xlabel('R')
	ax.set_ylabel('G')
	ax.set_zlabel('B')
	plt.show()
	'''


































