from keras.utils import np_utils
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import DBSCAN
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial import distance
from datetime import datetime
import matplotlib.cm as cm
class SubSigma:
    """Hierarchical SyncMap variant that merges inputs into "meta-units".

    Each of the ``input_size`` inputs owns a point in a ``dimensions``-D space
    (``self.syncmap``). ``meta_train`` moves those points: at every time step
    the meta-units whose mean activation is above 0.1 ("plus") are pulled
    towards each other and away from the rest ("minus"), and vice versa.
    ``input`` alternates training with ``meta_combine``, which merges the two
    closest meta-units under a fresh label, until merging stops.

    The grouping is stored in ``self.meta_label``: 0 means "not merged yet",
    any other value identifies one merged group.
    """

    # Merging stops once this many meta-units remain. Also declared at class
    # level so instances restored by ``load`` from pickles written before this
    # attribute existed keep the previous hard-coded behaviour (3).
    target_units = 3

    def __init__(self, input_size, dimensions, adaptation_rate, end_combine_rate=0.3, noise=False,
                 target_units=3):
        """Create an untrained map with random point positions in [0, 1).

        Args:
            input_size: number of input variables (one map point each).
            dimensions: dimensionality of the map space; coerced with ``int``.
            adaptation_rate: step size applied to the velocity at every update.
            end_combine_rate: merging stops when (closest pair distance /
                mean pairwise distance) between meta-units exceeds this ratio.
            noise: only changes ``name`` and sets ``noise_b``; ``noise_b`` is
                not read by any method of this class.
            target_units: merging stops once this many meta-units remain.
        """
        self.name = "SubSigma"
        self.organized = False
        self.space_size = 10
        self.dimensions = int(dimensions)
        self.input_size = input_size
        self.adaptation_rate = adaptation_rate
        self.end_combine_rate = end_combine_rate
        self.target_units = target_units

        self.syncmap = np.random.rand(input_size, self.dimensions)
        # Use the int-coerced dimension: np.zeros rejects e.g. 2.0.
        self.velocity = np.zeros((input_size, self.dimensions))
        self.synapses_matrix = np.zeros([input_size, input_size])

        # Animation bookkeeping (kept for compatibility).
        self.ims = []
        self.fps = 0
        self.count_rate = 0

        # Hierarchical merge state.
        self.meta_count = 0  # last label handed out by meta_combine
        self.meta_label = np.zeros(input_size, dtype=int)
        self.label_history = []
        self.syncmap_history = []
        self.combinable = True

        self.noise_b = 0
        if noise is True:
            self.noise_b = 1
            self.name = "SyncMap with Noise"

    def _count_meta_units(self):
        """Return how many meta-units ``meta_label`` currently describes.

        Every input still labelled 0 is its own unit; every distinct non-zero
        label is one merged unit.
        """
        labels = self.meta_label
        return int(np.count_nonzero(labels == 0)) + int(np.count_nonzero(np.unique(labels)))

    def get_meta_center(self):
        """Rebuild ``meta_index`` and ``meta_center`` from ``meta_label``.

        ``meta_index[k]`` holds the input indices of meta-unit k and
        ``meta_center[k]`` is the mean of their ``syncmap`` points. Unmerged
        inputs (label 0) come first in index order, followed by merged groups
        in ascending label order. Before any merge each input is its own unit.
        """
        if self.meta_count == 0:
            self.meta_index = [np.array([a]) for a in range(self.input_size)]
            # Copy rather than alias: meta_train updates meta_center in place,
            # which previously also moved syncmap (double-applying that step).
            self.meta_center = np.array(self.syncmap, dtype=float)
            return

        singles = [np.array([a]) for a in np.flatnonzero(self.meta_label == 0)]
        groups = [np.flatnonzero(self.meta_label == label)
                  for label in np.unique(self.meta_label) if label != 0]
        self.meta_index = singles + groups
        self.meta_center = np.zeros((len(self.meta_index), self.dimensions))
        for k, members in enumerate(self.meta_index):
            self.meta_center[k] = self.syncmap[members].mean(axis=0)

    def meta_combine(self):
        """Merge the two closest meta-units, or stop merging.

        Uses the ``meta_center``/``meta_index`` produced by
        ``get_meta_center``. Sets ``combinable`` to False (without merging)
        when fewer than two units exist, when the closest-pair distance
        divided by the mean pairwise distance exceeds ``end_combine_rate``, or
        when ``target_units`` or fewer units remain. Otherwise both units get
        the new label ``meta_count + 1``.
        """
        n = self.meta_center.shape[0]
        if n < 2:
            self.combinable = False
            return

        distant = distance.cdist(self.meta_center, self.meta_center, 'euclidean')
        # Exclude self-distances so a unit can never be "merged" with itself.
        np.fill_diagonal(distant, np.inf)
        row, col = np.unravel_index(np.argmin(distant), distant.shape)
        combine_dot = [col, row]

        self.meta_averge_distant = distant[np.triu_indices(n, 1)].mean()
        self.min_distant = distant[row, col]
        # `<=` rather than `==`: with fewer units than the target the old
        # equality test never fired and input() could loop forever.
        if (self.min_distant / self.meta_averge_distant > self.end_combine_rate
                or self._count_meta_units() <= self.target_units):
            self.combinable = False
            return

        self.meta_count += 1
        for unit in combine_dot:
            self.meta_label[self.meta_index[unit]] = self.meta_count

    def combine_meta(self):
        """Recompute meta-unit centres, then attempt one merge."""
        self.get_meta_center()
        self.meta_combine()

    def meta_train(self, x):
        """Run one training pass over sequence ``x`` on the current meta-units.

        ``x`` is indexed as ``x[time, input]``. A meta-unit's activation is the
        mean activation of its member inputs; units above 0.1 form the "plus"
        set. Steps with at most one unit on either side are skipped. Each
        unit's velocity is applied both to its centre and to every member
        point of ``syncmap``; both are rescaled so their maximum equals
        ``space_size``. A snapshot of ``syncmap`` and ``meta_label`` is stored
        every 10 steps. If ``target_units`` or fewer units exist, sets
        ``combinable`` to False and returns without training.
        """
        sequence_size = x.shape[0]
        self.get_meta_center()
        n_units = len(self.meta_index)
        if n_units <= self.target_units:
            self.combinable = False
            return

        x_meta = np.zeros((sequence_size, n_units))
        for k, members in enumerate(self.meta_index):
            x_meta[:, k] = np.mean(x[:, members], axis=1)

        self.count = np.zeros(n_units)
        self.velocity = np.zeros((n_units, self.dimensions))
        plus = x_meta > 0.1
        for i in range(sequence_size):
            vplus = plus[i, :]
            vminus = ~vplus
            plus_mass = vplus.sum()
            minus_mass = vminus.sum()
            if plus_mass <= 1 or minus_mass <= 1:
                continue

            # Activation frequency per unit (kept as attributes for inspection).
            self.count += vplus
            self.feq = self.count / self.count.sum()

            center_plus = np.dot(vplus, self.meta_center) / plus_mass
            center_minus = np.dot(vminus, self.meta_center) / minus_mass
            dist_plus = distance.cdist(center_plus[None, :], self.meta_center, 'euclidean').T
            dist_minus = distance.cdist(center_minus[None, :], self.meta_center, 'euclidean').T

            # Attraction to the own group's centre (weight 3) plus repulsion
            # from the other group's centre, both normalised by distance.
            update_plus = vplus[:, np.newaxis] * (
                3 * (center_plus - self.meta_center) / dist_plus
                - (self.meta_center - center_minus) / dist_minus)
            update_minus = vminus[:, np.newaxis] * (
                3 * (center_minus - self.meta_center) / dist_minus
                - (self.meta_center - center_plus) / dist_plus)

            self.velocity *= 0.999  # damping
            noise = np.amax(self.meta_center) * np.random.rand(n_units, self.dimensions)
            self.velocity += (update_plus - update_minus) + noise

            self.meta_center += self.adaptation_rate * self.velocity
            self.meta_center = self.space_size * self.meta_center / self.meta_center.max()

            # Propagate each unit's motion to all of its member points.
            for k, members in enumerate(self.meta_index):
                self.syncmap[members] += self.adaptation_rate * self.velocity[k]
            self.syncmap = self.space_size * self.syncmap / self.syncmap.max()

            if i % 10 == 0:
                # Copy: the next step mutates syncmap in place.
                self.syncmap_history.append(self.syncmap.copy())
                self.label_history.append(self.meta_label.copy())

    def inputGeneral(self, x):
        """Train every input point directly (no meta-units) on sequence ``x``.

        Same attraction/repulsion rule as ``meta_train`` but applied to
        ``syncmap`` itself; a snapshot is stored after every processed step.
        NOTE(review): ``self.velocity`` must be (input_size, dimensions) here;
        after ``meta_train`` it has one row per meta-unit instead.
        """
        plus = x > 0.1
        minus = ~plus

        for i in range(x.shape[0]):
            vplus = plus[i, :]
            vminus = minus[i, :]
            plus_mass = vplus.sum()
            minus_mass = vminus.sum()
            if plus_mass <= 1 or minus_mass <= 1:
                continue

            center_plus = np.dot(vplus, self.syncmap) / plus_mass
            center_minus = np.dot(vminus, self.syncmap) / minus_mass
            dist_plus = distance.cdist(center_plus[None, :], self.syncmap, 'euclidean').T
            dist_minus = distance.cdist(center_minus[None, :], self.syncmap, 'euclidean').T

            update_plus = vplus[:, np.newaxis] * (
                3 * (center_plus - self.syncmap) / dist_plus
                - (self.syncmap - center_minus) / dist_minus)
            update_minus = vminus[:, np.newaxis] * (
                3 * (center_minus - self.syncmap) / dist_minus
                - (self.syncmap - center_plus) / dist_plus)

            self.velocity *= 0.999  # damping
            noise = np.amax(self.syncmap) * np.random.rand(self.input_size, self.dimensions)
            self.velocity += update_plus - update_minus + noise
            self.syncmap += self.adaptation_rate * self.velocity
            self.syncmap = self.space_size * self.syncmap / self.syncmap.max()

            # Copy: the next step's in-place update would otherwise alter it.
            self.syncmap_history.append(self.syncmap.copy())

    def input(self, x):
        """Train on sequence ``x`` and merge meta-units until merging stops.

        Runs one ``meta_train`` pass, then repeatedly merges the closest pair
        of meta-units, re-training on ``x`` after every 5 merges, until
        ``combinable`` becomes False.
        """
        self.meta_train(x)
        merges = 0
        while self.combinable:
            self.combine_meta()
            merges += 1
            if merges == 5:
                self.meta_train(x)
                merges = 0

    def organize(self, eps=3):
        """Expose the merge result as the map's labels and return it.

        ``eps`` is unused; it is kept for interface compatibility. Also sets
        ``labels``/``organized`` so ``activate`` and ``plot`` work afterwards.
        """
        self.labels = self.meta_label
        self.organized = True
        return self.meta_label

    def hierarchical_organize(self, n=2):
        """Cluster the map points into ``n`` groups with agglomerative clustering."""
        self.organized = True
        self.labels = AgglomerativeClustering(distance_threshold=None, n_clusters=n).fit_predict(self.syncmap)
        return self.labels

    def activate(self, x):
        '''
        Return the label of the index with maximum input value
        '''
        if not self.organized:
            print("Activating a non-organized SyncMap")
            return
        return self.labels[np.argmax(x)]

    def plotSequence(self, input_sequence, input_class, filename="plot.png"):
        """Plot true classes (green) vs. activated labels for steps 500-999."""
        input_sequence = input_sequence[500:1000]
        input_class = input_class[500:1000]

        a = np.asarray(input_class)
        t = list(range(len(a)))
        c = [self.activate(x) for x in input_sequence]

        plt.plot(t, a, '-g')
        plt.plot(t, c, '-.k')

        # No `quality=` kwarg: it was deprecated in Matplotlib 3.3 and later
        # removed (it only ever affected JPEG output).
        plt.savefig(filename, dpi=300)
        plt.show()
        plt.close()

    def plot(self, color=None, save=False, filename="plot_map.png"):
        """Scatter-plot the map (2-D or 3-D), coloured by ``labels`` by default."""
        if color is None:
            color = self.labels

        if self.dimensions == 2:
            plt.scatter(self.syncmap[:, 0], self.syncmap[:, 1], c=color)

        if self.dimensions == 3:
            plt.figure()
            ax = plt.axes(projection='3d')
            ax.scatter3D(self.syncmap[:, 0], self.syncmap[:, 1], self.syncmap[:, 2], c=color)

        if save:
            plt.savefig(filename)

        plt.show()
        plt.close()

    def plot_w_line(self, label, color=None, save=False, filename="plot_map.png"):
        """Print the map, then scatter-plot it coloured by ``label`` by default."""
        if color is None:
            color = label

        print(self.syncmap)
        if self.dimensions == 2:
            plt.scatter(self.syncmap[:, 0], self.syncmap[:, 1], c=color)

        if self.dimensions == 3:
            plt.figure()
            ax = plt.axes(projection='3d')
            ax.scatter3D(self.syncmap[:, 0], self.syncmap[:, 1], self.syncmap[:, 2], c=color)

        if save:
            plt.savefig(filename)

        plt.show()
        plt.close()

    def TP_matrix(self, x):
        """Estimate a row-normalised transition matrix from one-hot rows of ``x``.

        Rows whose maximum is not exactly 1 are skipped. A state with no
        observed outgoing transition keeps an all-zero row (previously 0/0
        produced NaN). Stores the result in ``self.TPMatrix`` and returns it;
        returns None for empty input.
        """
        if len(x) == 0:
            return

        output_size = x.shape[1]
        Dtable = np.zeros([output_size, output_size])

        prev_state = np.argmax(x[0])
        for row in x[1:]:
            if np.max(row) != 1:
                continue
            state = np.argmax(row)
            Dtable[prev_state][state] += 1
            prev_state = state

        totals = Dtable.sum(axis=1, keepdims=True)
        TPMatrix = np.divide(Dtable, totals, out=np.zeros_like(Dtable), where=totals != 0)

        self.TPMatrix = TPMatrix
        return TPMatrix

    def save(self, filename):
        """save class as self.name.txt"""
        # `cPickle` (Python 2) was never imported; pickle output is bytes,
        # so the file must be opened in binary mode.
        import pickle

        with open(filename + '.txt', 'wb') as file:
            pickle.dump(self.__dict__, file)

    def load(self, filename):
        """try load self.name.txt"""
        import pickle

        # SECURITY: unpickling can execute arbitrary code -- only load files
        # written by `save` from a trusted source.
        with open(filename + '.txt', 'rb') as file:
            self.__dict__ = pickle.load(file)

    def save_weight(self, path):
        """Save the map coordinates with ``np.save``."""
        np.save(path, self.syncmap)

    def animation_init(self):
        """FuncAnimation init callback: return the artists being animated."""
        return self.img1, self.anos

    def animation_update(self, n):
        """FuncAnimation frame callback: move points/annotations to frame ``n``."""
        lab = self.label_history[n]
        uniq = np.unique(lab)
        color = cm.rainbow(np.linspace(0, 1, len(uniq)))[:, :3]  # drop alpha
        frame = self.syncmap_history[n]
        if self.dimensions == 2:
            self.img1.set_offsets(frame)
        if self.dimensions == 3:
            self.img1._offsets3d = (frame[:, 0], frame[:, 1], frame[:, 2])
        self.img1.set_array(lab)
        for i, _ in enumerate(self.true_labels):
            self.anos[i].set_x(frame[i, 0])
            self.anos[i].set_y(frame[i, 1])
            self.anos[i].set_c(color[np.flatnonzero(uniq == lab[i])[0]])
        return self.img1, self.anos

    def plot_animation(self, name, true_labels=None):
        """Render the recorded history (every 5th snapshot) to an MP4 file.

        ``true_labels`` are the per-input annotation texts; when omitted the
        input indices are used. The movie is written under ``movies_output/``
        (created if missing). Both histories are cleared afterwards.
        """
        import os

        fig = plt.figure()
        if self.dimensions == 2:
            self.ax = fig.add_subplot()
        if self.dimensions == 3:
            self.ax = fig.add_subplot(111, projection='3d')

        self.syncmap_history = np.array(self.syncmap_history[::5])
        self.label_history = np.array(self.label_history[::5])
        if true_labels is None:
            # Previously enumerate(None) raised TypeError for the default.
            true_labels = [str(i) for i in range(self.input_size)]
        self.true_labels = true_labels
        self.anos = [self.ax.annotate(txt, (self.syncmap_history[0, i, 0], self.syncmap_history[0, i, 1]))
                     for i, txt in enumerate(self.true_labels)]

        # Pad both limits outward by 10% of their magnitude. The old
        # `min + min * 0.1` moved a positive minimum inward, clipping points.
        ax_lim_max = np.max(self.syncmap_history)
        ax_lim_max = ax_lim_max + abs(ax_lim_max) * 0.1
        ax_lim_min = np.min(self.syncmap_history)
        ax_lim_min = ax_lim_min - abs(ax_lim_min) * 0.1

        self.ax.set_xlim(ax_lim_min, ax_lim_max)
        self.ax.set_ylim(ax_lim_min, ax_lim_max)
        # NOTE(review): no `c` is passed, so Matplotlib may ignore this cmap;
        # "rgb" does not appear to be a built-in colormap name -- confirm intent.
        if self.dimensions == 2:
            self.img1 = self.ax.scatter([], [], [], cmap="rgb")
        if self.dimensions == 3:
            self.ax.set_zlim(ax_lim_min, ax_lim_max)
            self.img1 = self.ax.scatter3D([], [], [], cmap="rgb")

        print("Generating animation...")
        ani = animation.FuncAnimation(fig, self.animation_update, frames=len(self.syncmap_history),
                                      init_func=self.animation_init, interval=50, blit=False)
        print("Generating animation completed!")
        print("Saving animation as MP4")
        os.makedirs("movies_output", exist_ok=True)
        animation_path = "movies_output/movie_" + name + "_" + self.name + datetime.now().strftime('_%d%m%Y_%H%M%S') + ".mp4"
        ani.save(animation_path)
        print("Saving animation as MP4 completed!")

        # Reset both to lists so later training can append again
        # (label_history used to stay an ndarray, breaking `.append`).
        self.syncmap_history = []
        self.label_history = []

