from sklearn.cluster import KMeans
import numpy as np
import matplotlib.pyplot as plt
from img_util import *
from RF import *
import torch
import copy
from scipy import stats
from sys import exit
from sklearn.metrics.cluster import adjusted_rand_score
from numpy import linalg as LA
from mpl_toolkits.mplot3d import Axes3D
import cv2
from connected_component import *
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import MeanShift
from sklearn.mixture import GaussianMixture
from itertools import product



def RFMS_shift(data, dim, bandwidth, iteration, learning_rate, smoothing_parameter,
               QMC=True, blurring=False, clip_range=(0, 255), rng=None):
    """Shift every sample uphill on a random-feature density estimate.

    The density estimate is built from the random-feature encodings of the
    samples:

        density_encoding = sum(encodings) / (n * (2*pi*bandwidth**2) ** (in_dim/2))

    Each iteration estimates the gradient at every point from one forward
    finite difference along a random Gaussian direction ``z``:

        grad ~= z * (p(x + s*z) - p(x)) / s

    Here ``p`` is ``encoder.similarity(encode(x), density_encoding)`` and ``s``
    is ``smoothing_parameter``. The point then moves by
    ``learning_rate * grad / p(x)``.

    Args:
        data: ``(n, in_dim)`` array of samples, for example flattened RGB
            pixels.
        dim: Encoding dimension passed to the random-feature encoder.
        bandwidth: Kernel bandwidth. It is passed to the encoder and also used
            in the normalising constant above.
        iteration: Number of update steps.
        learning_rate: Multiplier applied to every update.
        smoothing_parameter: Finite-difference step ``s`` of the gradient
            estimate.
        QMC: If True, use ``QMCRF_Encoder``; otherwise use ``MCRF_Encoder``.
        blurring: If True, rebuild the density estimate from the shifted
            positions after every step (blurring mean shift). Otherwise keep
            the estimate built from the original ``data``.
        clip_range: ``(low, high)`` bounds applied to the positions after each
            step, or ``None`` to disable clipping. The default is the 8-bit
            pixel range that was previously hard-coded.
        rng: Optional ``numpy.random.Generator`` for the random directions.
            When ``None``, the global ``np.random`` state is used, as before.

    Returns:
        The shifted positions, with the same shape as ``data``.
    """
    n, in_dim = data.shape
    encoder_cls = QMCRF_Encoder if QMC else MCRF_Encoder
    encoder = encoder_cls(in_dim, dim, bandwidth)
    encodings = encoder.encode_x(data)

    # Normalising constant of a Gaussian kernel in in_dim dimensions.
    c = (2 * np.pi * (bandwidth ** 2)) ** (in_dim / 2)
    density_encoding = (1 / (n * c)) * encodings.sum(axis=0)

    # Keep the original global-RNG behaviour unless a Generator is supplied.
    normal = np.random.normal if rng is None else rng.normal

    current_position = copy.deepcopy(data)
    current_encoding = copy.deepcopy(encodings)

    for _ in range(iteration):
        random_direction = normal(size=(n, in_dim))
        probe = current_position + smoothing_parameter * random_direction
        probe_encoding = encoder.encode_x(probe)
        # NOTE(review): presumably similarity() evaluates the density estimate
        # at each encoded point and may return complex values (hence .real);
        # confirm against the RF module. The real part is taken on BOTH
        # densities so the divisor below cannot leak an imaginary component
        # into the positions. For real-valued results .real is a no-op.
        local_density = encoder.similarity(current_encoding, density_encoding).real
        probe_density = encoder.similarity(probe_encoding, density_encoding).real
        gradient = random_direction * (
            (probe_density - local_density) / smoothing_parameter
        )[:, np.newaxis]

        # Normalise by the local density (mean-shift style step).
        # NOTE: a zero local density yields inf/nan here.
        update = learning_rate * (gradient / local_density[:, np.newaxis])

        current_position = current_position + update
        if clip_range is not None:
            low, high = clip_range
            current_position = np.clip(current_position, low, high)
        current_encoding = encoder.encode_x(current_position)
        if blurring:
            density_encoding = (1 / (n * c)) * current_encoding.sum(axis=0)

    return current_position





def run_RFMS_seg(img, width, height, count, data, target, *,
                 dim=10, bandwidth=25, step_size=0.1, learning_rate=30,
                 iteration=100, cell_size=100):
    """Segment an image by shifting its pixels with RFMS and binning the results.

    Every pixel is moved with ``RFMS_shift``, using QMC random features and
    blurring. The shifted values are then bucketed on a regular grid of side
    ``cell_size``. Pixels whose shifted values land in the same cell share a
    label.

    Args:
        img, count, target: Unused here. They are kept so the positional
            signature stays unchanged for existing callers.
        width, height: Image size. The labels are reshaped to
            ``(height, width)``.
        data: ``(height * width, n_channels)`` array of pixel values.
        dim, bandwidth, learning_rate, iteration: Forwarded to ``RFMS_shift``.
        step_size: Forwarded to ``RFMS_shift`` as its
            ``smoothing_parameter``, the finite-difference step.
        cell_size: Side length of the grid used to group the shifted pixels.
            With the default [0, 255] clipping in ``RFMS_shift`` and
            ``cell_size=100``, there are at most 3 bins per channel.

    Returns:
        ``(label_image, labels)``. ``label_image`` has shape
        ``(height, width)``. ``labels`` is the flat 1-D label array. Label ids
        index the grid cells in the lexicographically sorted order produced
        by ``np.unique``.
    """
    shifted = RFMS_shift(data, dim, bandwidth, iteration, learning_rate,
                         step_size, QMC=True, blurring=True)

    cells = np.floor(shifted / cell_size).astype(int)
    _, preds = np.unique(cells, axis=0, return_inverse=True)
    # NumPy 2.0.0 could return the inverse with an extra axis when `axis` is
    # given. Flatten it so callers always receive a 1-D label array.
    preds = np.asarray(preds).reshape(-1)

    return np.reshape(preds, (height, width)), preds









if __name__ == '__main__':
    # Demo: segment one image with RFMS and display it next to the predicted
    # labels.
    img, labels = load_img(385028)
    width, height, count, data, target = create_dataset(img, labels, False)
    print(count)

    # run_RFMS_seg returns exactly two values: (label image, flat labels).
    # The previous three-way unpack (new_img, pred, center) raised ValueError.
    _, pred = run_RFMS_seg(img, width, height, count, data, target)
    pred = np.reshape(pred, labels.shape)

    mapping = get_mapping()
    # The shifted image is not returned by run_RFMS_seg, so show the input.
    show_img(img)
    show_labels(pred, mapping)





















