
from get_data import *
from RF import *
import numpy as np
import torch
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import rand_score
from scipy.cluster.hierarchy import DisjointSet
import matplotlib.pyplot as plot
import random
from connected_component import *
import copy
import matplotlib.pyplot as plt
import matplotlib.pyplot as plot



# Per-dataset hyper-parameter tables for the RFMS experiments, keyed by
# dataset name. Every table uses the same key order:
# a1, a2, a3, unbalance, s1, s2, s3, s4, iris.
_RFMS_A_SETS = ('a1', 'a2', 'a3')
_RFMS_S_SETS = ('s1', 's2', 's3', 's4')
_RFMS_ALL_SETS = _RFMS_A_SETS + ('unbalance',) + _RFMS_S_SETS + ('iris',)

# Number of shift iterations handed to RFMS_shift.
RFMS_simple_iteration = dict.fromkeys(_RFMS_ALL_SETS, 100)

# Kernel bandwidth: given to RFMS_shift and to the final Gaussian_kernel call.
RFMS_simple_bandwidth = {
    **dict.fromkeys(_RFMS_A_SETS, 0.1),
    'unbalance': 0.5,
    **dict.fromkeys(_RFMS_S_SETS, 0.2),
    'iris': 0.5,
}

# Threshold handed to CC when grouping the shifted points.
RFMS_simple_threshold = dict.fromkeys(_RFMS_ALL_SETS, 0.9)

# Perturbation scale; passed to RFMS_shift as its smoothing_parameter.
RFMS_simple_step_size = dict.fromkeys(_RFMS_ALL_SETS, 0.0005)

# Step size of each RFMS_shift update.
RFMS_simple_learning_rate = {
    **dict.fromkeys(_RFMS_A_SETS + ('unbalance',), 0.001),
    **dict.fromkeys(_RFMS_S_SETS, 0.003),
    'iris': 0.03,
}

# Input dimensionality per dataset (looked up by run_RFMS* but not passed on;
# presumably informational — confirm before relying on it).
RFMS_simple_n_input = {**dict.fromkeys(_RFMS_ALL_SETS, 2), 'iris': 4}

# Encoding dimension passed as `dim` to RFMS_shift.
RFMS_simple_dim = dict.fromkeys(_RFMS_ALL_SETS, 300)




def RFMS_shift(data,dim,bandwidth,iteration,learning_rate,smoothing_parameter,QMC=True,blurring=False,encoder=None,rng=None):
    """Move every point uphill on a random-feature kernel density estimate.

    The random-feature encodings of ``data`` are summed (and normalised) into a
    single density encoding. At each iteration every point is perturbed by
    Gaussian noise scaled by ``smoothing_parameter``. The change in estimated
    density along that perturbation gives a finite-difference gradient
    estimate, which is divided by the local density and applied with step
    ``learning_rate``.

    Parameters
    ----------
    data : array of shape (n, in_dim)
        Points to shift. Never modified; a copy is shifted and returned.
    dim : int
        Encoding dimension passed to the encoder constructor.
    bandwidth : float
        Kernel bandwidth. Passed to the encoder and used in the density
        normaliser ``(2*pi*bandwidth**2)**(in_dim/2)``.
    iteration : int
        Number of update steps.
    learning_rate : float
        Step size applied to the density-normalised gradient estimate.
    smoothing_parameter : float
        Scale of the random perturbation used for the gradient estimate.
    QMC : bool
        Build a ``QMCRF_Encoder`` if True, otherwise an ``MCRF_Encoder``.
        Ignored when ``encoder`` is given.
    blurring : bool
        If True, rebuild the density encoding from the shifted points after
        every step (as in blurring mean shift). Otherwise the density stays
        fixed to the original data.
    encoder : object, optional
        Pre-built encoder exposing ``encode_x(x)`` and
        ``similarity(encodings, density_encoding)``. When None (default), one
        is constructed from ``QMC``, ``dim`` and ``bandwidth``.
    rng : numpy.random.Generator or RandomState, optional
        Source of the perturbation noise. Defaults to the global ``np.random``
        state (the original behaviour).

    Returns
    -------
    array of shape (n, in_dim)
        The shifted points.
    """
    n, in_dim = data.shape
    if encoder is None:
        if QMC:
            encoder = QMCRF_Encoder(in_dim,dim,bandwidth)
        else:
            encoder = MCRF_Encoder(in_dim,dim,bandwidth)
    sampler = np.random if rng is None else rng

    encodings = encoder.encode_x(data)

    # Gaussian-kernel normalising constant, folded into a single scale factor.
    c = (2*np.pi*(bandwidth**2))**(in_dim/2)
    scale = 1 / (n*c)
    density_encoding = scale * encodings.sum(axis=0)

    # Copy so the returned array never aliases the caller's data
    # (matters when iteration == 0).
    current_position = copy.deepcopy(data)
    # `encodings` is only ever rebound below, never mutated, so no copy needed.
    current_encoding = encodings

    for _ in range(iteration):
        random_direction = sampler.normal(size=(n, in_dim))
        probe = current_position + (smoothing_parameter*random_direction)
        probe_encoding = encoder.encode_x(probe)
        # `similarity` may be complex-valued; only its real part is a density.
        # Take .real on BOTH terms: dividing by a complex local density would
        # otherwise promote the update, and every later position, to complex.
        local_density = encoder.similarity(current_encoding,density_encoding).real
        probe_density = encoder.similarity(probe_encoding,density_encoding).real
        # Finite-difference gradient estimate along the random direction.
        gradient = random_direction * ((probe_density - local_density)/smoothing_parameter)[:, np.newaxis]
        # Normalise by the local density before stepping.
        update = learning_rate * (gradient / local_density[:, np.newaxis])
        current_position = current_position + update
        current_encoding = encoder.encode_x(current_position)
        if blurring:
            density_encoding = scale * current_encoding.sum(axis=0)

    return current_position



def _run_RFMS_common(X, y, dataset, blurring):
    """Shift, cluster and score X using the tuned settings for ``dataset``.

    Shared implementation of run_RFMS and run_RFMS_blurring, which differ
    only in the ``blurring`` flag forwarded to RFMS_shift.

    Returns (NMI, AMI, ARI, Rand index) of the predicted labels against y.
    Raises KeyError if ``dataset`` is not in the RFMS_simple_* tables.
    """
    bandwidth = RFMS_simple_bandwidth[dataset]
    shifted = RFMS_shift(
        X,
        RFMS_simple_dim[dataset],
        bandwidth,
        RFMS_simple_iteration[dataset],
        RFMS_simple_learning_rate[dataset],
        RFMS_simple_step_size[dataset],  # used as the smoothing_parameter
        QMC=True,
        blurring=blurring,
    )

    # Gaussian-kernel Gram matrix of the shifted points; CC turns it into
    # cluster labels using the per-dataset threshold (presumably by linking
    # pairs whose similarity exceeds it — confirm in connected_component).
    final_gram = Gaussian_kernel(shifted, shifted, bandwidth)
    pred = CC(final_gram, RFMS_simple_threshold[dataset])
    return (
        normalized_mutual_info_score(pred, y),
        adjusted_mutual_info_score(pred, y),
        adjusted_rand_score(pred, y),
        rand_score(pred, y),
    )


def run_RFMS(X, y,count,dataset):
    """Run non-blurring RFMS clustering on X and score it against labels y.

    ``count`` is not used; it is kept so existing callers keep working.
    Returns (NMI, AMI, ARI, Rand index).
    """
    return _run_RFMS_common(X, y, dataset, blurring=False)


def run_RFMS_blurring(X, y,count,dataset):
    """Run blurring RFMS clustering on X and score it against labels y.

    The density estimate is rebuilt from the shifted points every iteration.
    ``count`` is not used; it is kept so existing callers keep working.
    Returns (NMI, AMI, ARI, Rand index).
    """
    return _run_RFMS_common(X, y, dataset, blurring=True)








