from sklearn.metrics.cluster import normalized_mutual_info_score
import sklearn
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import pylab as pl
import matplotlib as mpl
import shutil

#from HierarchyChunkTest import *
#from ChunkTest import *
from OverlapChunkTest1 import *
from OverlapChunkTest2 import *
#from HierarchyFixedChunkTest import *
#from HierarchyProbChunkTest import *
from LongChunkTest import *
from FixedChunkTest import *
from GraphWalkTest import *
import sys
import time

# Global Matplotlib font handling for vector output ('svg.fonttype' = 'none',
# 'pdf.fonttype' = 42; see the Matplotlib rcParams docs) with Arial as the
# sans-serif face. Keys are applied in this order.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
})
# Matplotlib style overrides for figures. Not applied anywhere in this
# module's visible code; pass to ``mpl.rcParams.update(params)`` to use them.
params = {
    'backend': 'ps',
    'axes.labelsize': 11,
    # 'text.fontsize' is not a valid rcParam in current Matplotlib (updating
    # rcParams with it raises KeyError); 'font.size' is the equivalent key.
    'font.size': 11,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'text.usetex': False,
    'figure.figsize': [100 / 2.54, 60 / 2.54],  # 100 cm x 60 cm, in inches
}

class MRIL:
    """Recurrently inhibited network trained by ``input`` and scored by ``evaluation``.

    ``input`` drives the units with a sequence, adapting the feedforward
    weights ``w`` (once the somatic-history window is full) and the lateral
    inhibitory weights ``w_inh`` (from stochastic spikes and decaying traces).
    ``evaluation`` runs the trained network on a fresh sequence from the task
    environment, groups units with correlated activity, labels every step by
    the most active group and returns the normalized mutual information
    between those labels and the task's ground truth.
    """

    def __init__(self, trial=30, N=5):
        """Set hyperparameters and allocate state for ``N`` units.

        Args:
            trial: length of the bookkeeping arrays ``time_array`` and
                ``MI_list`` (not read by any method of this class).
            N: number of units.
        """
        self.name = "MRIL"
        self.beta = 5  # sigmoid slope used in g() and in the w update
        self.trial = trial
        self.time_array = np.zeros(trial)
        self.MI_list = np.zeros(trial)
        self.MI_s = []
        self.MI_m = []
        self.N = N
        self.dt = 1
        # The somatic-potential history used to normalize V_som during
        # training holds window * width samples.
        self.window = 1000
        self.width = 10
        self.time_delay = 10
        # Together with a 1e-3 factor, scales f into a per-step spike probability.
        self.gain = 10
        self.tau = 15
        self.tau_syn = 50
        self.g_L = 1 / self.tau  # per-dt decay rate of V_som
        self.g_d = 0.7  # coupling of V_som towards V_dend
        self.eps = 10 ** (-3)  # learning rate of w
        self.p_connect = 1  # probability of each off-diagonal inhibitory link
        self.w_inh_max = 0.5 / np.sqrt(self.N)
        self.w_inh = np.ones((self.N, self.N)) * self.w_inh_max
        self.spike_time = np.zeros(self.N)
        self.mask = np.zeros((self.N, self.N))
        self.tau_p = 20  # time constant of the A_*_p traces
        self.tau_d = 40  # time constant of the A_*_d traces

    def g(self, x):
        """Logistic sigmoid ``1 / (1 + exp(beta * (theta - x)))`` with theta = 0.5.

        Works elementwise on scalars and numpy arrays.
        """
        alpha = 1
        theta = 0.5
        return 1 / (1 + alpha * np.exp(self.beta * (-x + theta)))

    def _report_progress(self, i):
        """Print the completed percentage of training at each 5 % step."""
        pct = int(i / self.simtime_len * 100)
        if pct % 5 == 0 and pct > int((i - 1) / self.simtime_len * 100):
            print(" " + str(pct))

    def _constrain_inhibition(self):
        """Zero masked-out inhibitory weights and clip the rest to [0, w_inh_max]."""
        self.w_inh *= self.mask
        self.w_inh[self.w_inh < 0] = 0
        self.w_inh[self.w_inh > self.w_inh_max] = self.w_inh_max

    def input(self, input_sequence):
        """Train the network on ``input_sequence``; return the elapsed seconds.

        Args:
            input_sequence: 2-D array; row ``i`` is the presynaptic input
                (``PSP_unit``) at time step ``i``.

        Returns:
            Wall-clock seconds spent, also stored in ``self.elapsed_time``.

        Side effects: re-draws the feedforward weights ``self.w``, adds
        connections to ``self.mask``, updates ``self.w_inh`` in place and
        overwrites the dynamic state attributes (``V_som``, ``f``, traces, ...).
        """
        start = time.time()
        self.nsecs = input_sequence.shape[0]
        self.simtime = np.arange(0, self.nsecs, self.dt)
        self.simtime_len = len(self.simtime)
        self.n_in = input_sequence.shape[1]
        self.n_syn = self.n_in
        self.PSP = np.zeros(self.n_in)
        self.I_syn = np.zeros(self.n_in)
        self.w = np.random.rand(self.n_syn, self.N) * 0.2

        # Off-diagonal inhibitory connectivity (no self-inhibition): one
        # random draw per ordered pair (i, j != i), in row-major order.
        for i in range(self.N):
            for j in range(self.N):
                if i != j and np.random.rand() < self.p_connect:
                    self.mask[i, j] = 1
        self._constrain_inhibition()

        self.V_som_list = np.zeros((self.N, self.window * self.width))
        self.V_dend = np.zeros(self.N)
        self.V_som = np.zeros(self.N)
        self.f = np.zeros(self.N)
        self.cA_p = -1 * 0.105 * 0.02
        self.cA_d = 0.525 * 0.1 * 0.02
        self.A_pre_p = np.zeros(self.N)
        self.A_pre_d = np.zeros(self.N)
        self.A_post_p = np.zeros(self.N)
        self.A_post_d = np.zeros(self.N)
        self.spike_mat = np.zeros((self.n_in, self.simtime_len), dtype=bool)
        print("")
        print("***********")
        print("Learning... ")
        print("***********")

        history_len = self.width * self.window
        decay_p = 1.0 - self.dt / self.tau_p
        decay_d = 1.0 - self.dt / self.tau_d
        # g_d / (g_d + g_L) maps V_dend onto the fixed point of the V_som
        # update without inhibition (for dt == 1).
        soma_denom = self.g_d + self.g_L

        for i in range(self.simtime_len):
            self._report_progress(i)
            self.PSP_unit = input_sequence[i, :]
            self.V_dend = np.dot(self.w.T, self.PSP_unit)
            # Slide the history window left by one and append the new V_som.
            self.V_som_list = np.roll(self.V_som_list, -1, axis=1)
            self.V_som = ((1.0 - self.dt * self.g_L) * self.V_som
                          + self.g_d * (self.V_dend - self.V_som)
                          + np.dot(-self.w_inh, self.f))
            self.V_som_list[:, -1] = self.V_som

            if i > history_len:
                mean = np.mean(self.V_som_list, axis=1)
                std = np.std(self.V_som_list, axis=1)
                # A constant history (e.g. silent input) has zero spread;
                # dividing by 0 would make f NaN and poison V_som and w for
                # the rest of training, so divide by 1 instead.
                std = np.where(std > 0, std, 1.0)
                self.f = self.g((self.V_som - mean) / std)
                self.f = np.clip(self.f, 10 ** (-6), 1 - 10 ** (-6))
                pred = self.g(self.V_dend * self.g_d / soma_denom)
                self.w += (self.eps * np.outer(self.f - pred, self.PSP_unit).T
                           * self.beta * (1 - pred))
                self.w[self.w <= 0] = 0

            self.A_pre_p = decay_p * self.A_pre_p
            self.A_pre_d = decay_d * self.A_pre_d
            self.A_post_p = decay_p * self.A_post_p
            self.A_post_d = decay_d * self.A_post_d

            # One uniform draw per unit, in unit order; unit k spikes when its
            # draw is below dt * f[k] * gain * 1e-3.
            spikes = np.random.rand(self.N) < self.dt * self.f * self.gain * (10 ** -3)
            for k in np.flatnonzero(spikes):
                self.A_pre_p[k] += self.cA_p
                self.A_pre_d[k] += self.cA_d
                self.A_post_p[k] += self.cA_p
                self.A_post_d[k] += self.cA_d
                # Sequential on purpose: later spiking units see the traces
                # and weights already changed by earlier ones in this step.
                self.w_inh[k, :] += (self.A_pre_p + self.A_pre_d) * self.w_inh_max
                self.w_inh[:, k] += (self.A_post_p + self.A_post_d) * self.w_inh_max
                self.spike_time[k] = i

            self._constrain_inhibition()
        self.elapsed_time = time.time() - start
        return self.elapsed_time

    def _normalize_activity(self):
        """Min-max normalize each unit's test activity into ``self.avg_norm1``."""
        activity = self.f_list[:, 0:self.sample_len]
        self.max1 = np.max(activity, axis=1)
        self.min1 = np.min(activity, axis=1)
        span = self.max1 - self.min1
        # A unit with constant activity would give 0/0 = NaN; map it to 0.
        span = np.where(span > 0, span, 1.0)
        self.avg_norm1 = (activity - self.min1[:, None]) / span[:, None]

    def _sort_by_preferred_time(self):
        """Order units by the circular mean time of their normalized activity."""
        basis = np.exp(np.arange(self.sample_len) / self.sample_len * 2 * np.pi * 1j)
        self.t = np.zeros(self.N)
        for j in range(self.N):
            weights = self.avg_norm1[j, :]
            resultant = np.dot(weights, basis)
            total = np.sum(weights)
            # Weights are >= 0 after normalization, so dividing by their sum
            # leaves the angle unchanged; skip it for an all-zero row (0/0).
            if total > 0:
                resultant = resultant / total
            self.arg = np.angle(resultant)
            if self.arg < 0:
                self.arg += 2 * np.pi
            self.t[j] = self.sample_len / (2 * np.pi) * self.arg
        self.index = np.argsort(self.t)
        self.avg_sorted = self.avg_norm1[self.index, :]

    def _plot_activity_map(self):
        """Save a heat map of the sorted normalized activity to activity_map.pdf."""
        fig, ax = plt.subplots(figsize=(4, 3))
        self.cax = plt.imshow(self.avg_sorted, interpolation='nearest',
                              aspect="auto", cmap='jet')
        self.cbar = fig.colorbar(self.cax, ticks=[0, 1], orientation='vertical')
        self.cbar.ax.set_yticklabels(['min', 'max'], fontsize=10)
        plt.xlabel("time steps", fontsize=10)
        plt.ylabel("Neurons (sorted)", fontsize=10)
        plt.yticks([0, self.N - 1], ["1", "%d" % self.N], fontsize=10)
        ax.tick_params(length=1.3, width=0.05, labelsize=10)
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        plt.ylim([-0.5, self.N - 0.5])
        plt.xlim([0, self.plot_len])
        fig.subplots_adjust(left=0.15, bottom=0.25, right=1)
        # ``format`` (not ``fmt``) is the savefig keyword for the file type.
        plt.savefig('activity_map.pdf', format='pdf', dpi=350)
        # Each call creates a new figure; close it so repeated evaluations
        # do not keep every figure alive.
        plt.close(fig)

    def _group_units(self):
        """Group units whose test activity correlates (> 0.5); average each group."""
        # Pairwise Pearson correlations; a unit with constant activity yields
        # NaN, which never passes the threshold.
        corr = np.corrcoef(self.f_list) if self.N > 1 else None
        self.groups = []
        grouped = set()
        for i in range(self.N):
            # A unit already in some group does not seed a new one, though it
            # may still be added to another group seeded by a different unit.
            if i in grouped:
                continue
            members = [i] + [j for j in range(i + 1, self.N) if corr[i, j] > 0.5]
            self.groups.append(members)
            grouped.update(members)
        self.group_num = len(self.groups)
        print(self.group_num)
        self.pop_act = np.zeros((self.group_num, self.sample_len))
        for row, members in enumerate(self.groups):
            self.pop_act[row, :] = self.f_list[members, :].mean(axis=0)

    def _ground_truth(self, task, env, input_class, test_length):
        """Return per-step ground-truth labels for non-hierarchical tasks.

        Sets ``self.truth`` for tasks whose classes map through a label table;
        returns an empty list for tasks it does not recognize.
        """
        if task == GraphWalkTest:
            self.truth = env.true_label
        elif task == OverlapChunkTest1 or task == OverlapChunkTest2:
            self.truth = env.trueLabel()
        elif task == LongChunkTest or task == FixedChunkTest:
            return [int(input_class[i]) for i in range(test_length)]
        else:
            return []
        return [self.truth[int(input_class[i])] for i in range(test_length)]

    @staticmethod
    def _is_hierarchy_task(task):
        """Return True if ``task`` is one of the hierarchical chunk tasks.

        Those classes may not be in scope (their imports are commented out at
        the top of this file), so look them up at run time instead of
        referencing the names directly, which would raise NameError.
        """
        for name in ("HierarchyFixedChunkTest", "HierarchyProbChunkTest"):
            cls = globals().get(name)
            if cls is not None and task == cls:
                return True
        return False

    def evaluation(self, task, env, test_length):
        """Run the trained network on a fresh sequence and score its grouping.

        Args:
            task: task class; compared against the imported task classes to
                decide how ground-truth labels are derived.
            env: task environment providing ``getSequence(length)`` and, for
                some tasks, ``true_label`` / ``trueLabel()``.
            test_length: number of test steps.

        Returns:
            The normalized mutual information score (also ``self.MI``).

        Raises:
            ValueError: if no ground truth can be derived for ``task``.

        Side effects: writes ``activity_map.pdf`` to the working directory.
        """
        print("")
        print("***********")
        print("Testing... ")
        print("***********")
        input_sequence, input_class = env.getSequence(test_length)
        self.plot_len = 500

        self.V_dend_list = np.zeros((self.N, test_length))
        self.V_dend = np.zeros(self.N)
        self.V_som = np.zeros(self.N)
        self.f_list = np.zeros((self.N, test_length))

        for i in range(test_length):
            self.PSP_unit = input_sequence[i, :]
            self.V_dend = np.dot(self.w.T, self.PSP_unit)
            self.V_som = ((1.0 - self.dt * self.g_L) * self.V_som
                          + self.g_d * (self.V_dend - self.V_som)
                          + np.dot(-self.w_inh, self.f))
            # Unlike training, the output is the sigmoid of the raw (not
            # window-normalized) somatic potential.
            self.f[:] = self.g(self.V_som)
            self.f_list[:, i] = self.f

        self.sample_len = test_length
        self._normalize_activity()
        self._sort_by_preferred_time()
        self._plot_activity_map()
        self._group_units()

        # Label each step by the group with the highest population activity.
        self.labels = list(np.argmax(self.pop_act, axis=0))
        self.truth_list = self._ground_truth(task, env, input_class, test_length)

        if self._is_hierarchy_task(task):
            in_class = np.array(input_class)
            # Average the NMI against every class column after the first.
            mi_a = [normalized_mutual_info_score(self.labels, in_class[:, col])
                    for col in range(1, in_class.shape[1])]
            self.MI = np.mean(mi_a)
        else:
            if len(self.truth_list) != len(self.labels):
                raise ValueError("no ground-truth labels for task %r" % (task,))
            self.MI = normalized_mutual_info_score(self.labels, self.truth_list)
            print("none hierarchy")
        print("MI_score", self.MI)
        return self.MI
#np.savetxt('Mixed_MI.txt',[np.mean(MI_list),np.std(MI_list)], delimiter=',')
