import sys

import matplotlib as mt
import matplotlib.animation as am
import matplotlib.gridspec as gs
import matplotlib.pylab as plt
import matplotlib.ticker as ticker
import numpy as np
import scipy
import scipy.ndimage
from scipy import signal

sys.path.append('../')
import CANN_mod.CANN_1D.CANN_1D as CANN
import useful_functions as uf

# Embed Type 42 (TrueType) fonts in PDF output.
mt.rcParams['pdf.fonttype'] = 42

# ---- Network / stimulus parameters for the bimodal run (mbar = 150) ----
N = 512                  # number of neurons on the ring
k0 = 5
J0 = 1
mbar = 150               # bimodal: 150, unimodal: 151.2
var_m = 0.12
vbar = 0.5
a = 0.4                  # Gaussian width used for the coupling kernel and input
tau = 3
tau_v = 144
J0 = J0 / N * 512
dt = tau / 10            # integration step
m = mbar * tau / tau_v
# kc=(1+tau/tau_v)*(1+2*m-tau/tau_v)*density*J0^2/(32*np.pi*a^2*(1+m)^4)
k = k0 / N * 512
A = np.sqrt(2 * np.pi * a**2)
alpha = 0.19             # amplitude of the moving external input
v = a / tau_v * vbar     # stimulus speed (radians per time unit)
T = 8 * np.pi / v        # total simulated time: the stimulus travels 8*pi (four laps)

figure_path = './figures/20220515bi/'

placefield = 2.5 * a     # radius around a neuron's preferred location used for the maps

# ---- Ring coordinates and translation-invariant Gaussian coupling ----
# NOTE(review): J0 is reset to 1 here, so the N-rescaling above has no effect.
J0 = 1
coordinatex = np.arange(-np.pi, np.pi, 2 * np.pi / N)
delta_z_squre = coordinatex**2
J = np.roll(J0 / A * np.exp(-delta_z_squre / (2 * a**2)), shift=int(N / 2))

# Build the 1D CANN with the Gaussian coupling kernel J.
net = CANN.CANN_1D(N, k, m, J, tau, tau_v, True)

n_steps = int(T / dt)
U_profile = np.zeros([n_steps, N])          # scaled rates (net.r * 5000), one row per step
Is = np.empty([n_steps, N], dtype=float)    # external input history
positions = np.empty(n_steps, dtype=float)  # stimulus position at each step
times = np.empty(n_steps, dtype=float)      # simulation time at each step

start = 0

# Drive the network with a Gaussian input moving at constant speed v around the ring.
for t in range(n_steps):
    z = start + v * t * dt
    z = z % (2 * np.pi) - np.pi + start      # shift into [-pi, pi) (start is 0 here)
    delta_z = coordinatex - z
    wrapped = np.pi < abs(delta_z)
    # Shortest distance on the ring; the sign is irrelevant because it is squared below.
    delta_z[wrapped] = 2 * np.pi - abs(delta_z[wrapped])
    I_ext = alpha * np.exp(-delta_z**2 / (4 * a**2))
    net.update(I_ext, dt)
    U_profile[t] = net.r * 5000
    Is[t] = I_ext
    positions[t] = z
    times[t] = t * dt

# Drop the first half as transient. `times` is trimmed together with U_profile and
# positions so that the later [begin:end] slicing stays aligned across all three.
half = int(T / dt / 2)
U_profile = U_profile[half:]
positions = positions[half:]
times = times[half:]
T = T / 2

# Population-vector decoding: the argument of the activity-weighted mean of
# exp(i*x) (uf.complex_arg presumably returns the complex angle), minus the true
# stimulus position, folded into [-pi/2, pi/2] by arcsin(sin(.)).
decoded_position = uf.complex_arg((U_profile * np.exp(1j * coordinatex)).sum(1) / U_profile.sum(1))
decoded_position = np.arcsin(np.sin((decoded_position.T - positions).T))

# Smooth the decoding error; its oscillation defines the cycles used below.
# scipy.ndimage.filters is a deprecated alias namespace, so call scipy.ndimage directly.
smoothed_decoded_position = scipy.ndimage.gaussian_filter(decoded_position, sigma=5)


# Trim to an even number of whole oscillation (theta) cycles: start at the third
# valley of the smoothed decoding error and stop at a valley chosen so that the
# number of cycles in between is even.
valleys = uf.find_valleys(smoothed_decoded_position)
if len(valleys) < 5:
    # With 3 or 4 valleys begin == end and every array below would be empty.
    raise RuntimeError(
        'need at least 5 valleys in the decoded position to select whole cycles, '
        'found {}'.format(len(valleys)))
begin = valleys[2]
end = valleys[-1] if len(valleys) % 2 == 1 else valleys[-2]
smoothed_decoded_position = smoothed_decoded_position[begin:end]
U_profile = U_profile[begin:end]
positions = positions[begin:end]
times = times[begin:end]

# Cycle boundaries inside the trimmed window, including both ends.
valleys = uf.find_valleys(smoothed_decoded_position)
valleys = np.hstack([np.array([0]), valleys])
valleys = np.hstack([valleys, U_profile.shape[0]])

# Phase of every sample: a linear 0..360 degree ramp within each cycle,
# rounded with uf.my_round.
phases = []
for cycle_start, cycle_end in zip(valleys[:-1], valleys[1:]):
    cycle_len = cycle_end - cycle_start
    phases.extend(uf.my_round(j / cycle_len * 360) for j in range(cycle_len))
phases = np.array(phases)

# Accumulator for the phase x relative-position map: 721 rows of 0.5-degree
# phase bins, one column per neuron spacing across the place field.
phase_pos = np.zeros([721, int(2 * placefield / (2 * np.pi / N)) + 1])

def _count_modes(profile):
    """Return the number of prominent, well-separated peaks in a phase profile.

    Candidate peaks come from uf.find_peaks. A peak is kept when its width
    (scipy.signal.peak_widths, in 0.5-degree bins) exceeds 60 and its prominence
    exceeds 0.05. Of two surviving neighbours closer than 100 bins, the lower
    one is discarded.
    """
    peaks = np.array(uf.find_peaks(profile))
    if len(peaks) == 0:
        return 0
    widths = signal.peak_widths(profile, peaks)[0]
    prominences = signal.peak_prominences(profile, peaks)[0]
    peaks = peaks[(widths > 60) & (prominences > 0.05)]
    heights = profile[peaks]
    too_close = np.flatnonzero(np.diff(peaks) < 100)
    dropped = {i if heights[i] < heights[i + 1] else i + 1 for i in too_close}
    return len(peaks) - len(dropped)


# Classify every neuron by the number of modes in its activity-vs-phase profile:
# 0 = unimodal, 1 = bimodal, -1 = anything else.
cell_types = -np.ones(N)
phase_bins = (np.asarray(phases) / 0.5).astype(int)[:U_profile.shape[0]]  # 0.5-degree bins

for index in range(N):
    # Activity-weighted phase histogram (bincount sums sequentially, like a += loop).
    phases_i = np.bincount(phase_bins, weights=U_profile[:, index], minlength=721)
    # NOTE(review): phase is circular but the default 'reflect' boundary is used here.
    phases_i = scipy.ndimage.gaussian_filter(phases_i, sigma=14)

    peak_value = phases_i.max()
    if peak_value <= 0:
        continue  # silent neuron: normalising would divide by zero
    phases_i /= peak_value
    phases_i -= phases_i.min()

    # Count peaks and valleys separately and accept the larger count.
    num_peaks = max(_count_modes(phases_i), _count_modes(-phases_i))

    if num_peaks == 1:
        cell_types[index] = 0  # unimodal
    elif num_peaks == 2:
        cell_types[index] = 1  # bimodal

unimodals = np.where(cell_types == 0)[0]
bimodals = np.where(cell_types == 1)[0]
num_uni = unimodals.shape[0]
num_bi = bimodals.shape[0]

print('unimodel:{} bimodal:{}'.format(num_uni, num_bi))


# Dominant oscillation (theta) frequency of the smoothed decoding error.
signal_len = smoothed_decoded_position.shape[0]
spec = np.fft.fft(smoothed_decoded_position - smoothed_decoded_position.mean())
spec = spec[:spec.shape[0] // 2]

# Use the magnitude spectrum: the real part alone depends on the oscillation's
# phase and can hide the dominant component.
amplitude = np.abs(spec)
ps = np.asarray(uf.find_peaks(amplitude), dtype=int)
criterion = amplitude.max() / 2
candidates = np.flatnonzero(amplitude[ps] > criterion)
print('freq', candidates)
if candidates.size == 0:
    raise RuntimeError('no spectral peak above half of the maximum amplitude')
freq = ps[candidates[0]]
# FFT bin `freq` means `freq` cycles over the analysed (trimmed) signal, so the
# period in samples is its length divided by freq (not the untrimmed T/dt).
theta_period = int(signal_len / freq)

print('argmax:', freq, 'T:', theta_period)

font_size = 80
plt.rcParams['font.size'] = 40
single_neuron_index = 120

# ---- Activity trace of one bimodal cell around the stimulus passing it ----
plt.figure(figsize=(12, 4))
i = single_neuron_index
cell = bimodals[i]
preferred_x = coordinatex[cell]
center = np.argmin(abs(preferred_x - positions))  # sample closest to the cell's location

half_window = int(2 / 3 * np.pi / v / dt)  # steps to travel 2*pi/3
lap = int(2 * np.pi / v / dt)              # steps for one full lap
start = center - half_window
end = center + half_window
# Shift by one lap when the window leaves the recorded (trimmed) data.
if start < 0:
    start += lap
    end += lap
elif end > U_profile.shape[0]:
    start -= lap
    end -= lap
activity_cell = U_profile[start:end, cell]

# Crop to 1.5 cycles before the first / after the last peak above 5% of the maximum.
cell_peaks = np.array(uf.find_peaks(activity_cell))
strong = activity_cell[cell_peaks] > (activity_cell.max() / 20)
margin = int(theta_period * 3 / 2)
start = max(cell_peaks[strong][0] - margin, 0)  # a negative start would wrap the slice
end = cell_peaks[strong][-1] + margin
activity_cell = activity_cell[start:end]
X_time_steps = np.arange(-half_window, half_window, 1)[start:end]
X_cycle_steps = np.concatenate([
    np.flip(np.arange(0, X_time_steps[0] - 100, -theta_period), 0),
    np.arange(0, X_time_steps[-1], theta_period)[1:],
])
X_theta_cycle = (X_cycle_steps / theta_period).astype(int)  # np.int was removed in NumPy 1.24

plt.plot(X_time_steps, activity_cell, lw=3, color='k')
plt.xticks(X_cycle_steps, X_theta_cycle, fontsize=40)
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(5)
plt.title('bimodal')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
#plt.savefig(figure_path+'single_bi.pdf',format='pdf',dpi=1000)

    

##print(phase_pos.shape)
# Accumulate activity into the phase x relative-position map, using only samples
# where the stimulus lies within `placefield` of each neuron's preferred location.
bin_width = 2 * np.pi / N  # neuron spacing on the ring
for i in range(N):
    # Signed stimulus offset from neuron i, wrapped to [-pi, pi].
    offset = positions - coordinatex[i]
    offset[offset > np.pi] -= 2 * np.pi
    offset[offset < -np.pi] += 2 * np.pi
    in_field = np.abs(offset) < placefield

    phase = phases[in_field]
    act = U_profile[in_field, i]
    position = uf.my_round_n(offset[in_field] + placefield, bin_width)

    rows = (np.asarray(phase) / 0.5).astype(int)
    cols = (np.asarray(position) / bin_width).astype(int)
    # Unbuffered, in-order accumulation: repeated (row, col) pairs add up exactly
    # as the element-by-element loop would.
    np.add.at(phase_pos, (rows, cols), act)




#print('num_uni={} num_bi={}'.format(num_uni,num_bi))

#from numpy.core.function_base import linspace
# Axes of the phase-position map.
plot_position = np.linspace(-1, 1, int(2 * placefield / (2 * np.pi / N)) + 1)  # offset / placefield
plot_phase = np.linspace(0, 360, 721)                    # 0.5-degree phase bins
plot_phase_roll = np.roll(plot_phase, shift=0, axis=0)   # phase labels (shift currently 0)

# Stack rows 0..719 on top of the full 0..720 map so smoothing along phase sees
# the circular continuation. Roll by 310 bins (155 deg) to match the experimental
# figure, smooth, and keep the first 721 rows.
phase_pos = np.vstack([phase_pos[0:-1], phase_pos[0:]])
phase_pos = np.roll(phase_pos, shift=+310, axis=0)
smooth_phase_pos = scipy.ndimage.gaussian_filter(phase_pos, sigma=(13, 7))[0:721]

# Per phase row: smoothed relative position of maximal activity.
lines = plot_position[np.argmax(smooth_phase_pos, axis=1)]
lines = scipy.ndimage.gaussian_filter(lines, sigma=7)

# Phase range of the sweep: from the row with the lowest peak position to the
# row with the highest.
bound = np.array([np.argmin(lines), np.argmax(lines)])
if bound[0] == 0:
    # Skip a plateau at row 0: move one past the first row whose value differs
    # (a constant `lines` leaves the bound at 0 instead of raising IndexError).
    differs = np.flatnonzero(lines != lines[0])
    if differs.size:
        bound[0] = differs[0] + 1

# ---- Bimodal run: phase-position probability map ----
smooth_phase_pos = smooth_phase_pos / N                         # average over neurons
smooth_phase_pos = smooth_phase_pos / smooth_phase_pos.sum()    # normalise to a probability map
f = plt.figure(figsize=(16, 28))
gs0 = gs.GridSpec(nrows=2, ncols=1, height_ratios=[3, 1])

# Top panel: full map with the sweep bounds (bin index / 2 = degrees) marked.
full = f.add_subplot(gs0[0, 0])
plt.pcolormesh(plot_position, plot_phase, smooth_phase_pos, cmap='rainbow')
plt.xticks([])
# Set tick positions only. Passing a per-bin label list installs a FixedFormatter
# that later tick changes index positionally, which mislabels the ticks.
plt.yticks(np.linspace(0, 360, 3), fontsize=font_size)
plt.hlines(bound / 2, -1, 1, linestyle='dashed', lw=5, colors='black')

cb = plt.colorbar(orientation='horizontal', fraction=0.040, location='top')
plt.clim([0, 2.4e-5])
cb.set_ticks([0, 1.2e-5])
cb.update_ticks()
cb.formatter.set_powerlimits((0, 0))
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(5)

# Bottom panel: only the rows between the two sweep bounds.
part = f.add_subplot(gs0[1, 0])
low = bound[0]
high = bound[1]
plt.pcolormesh(plot_position, plot_phase_roll[low:high], smooth_phase_pos[low:high], cmap='rainbow')
plt.xticks(np.linspace(-1, 1, 3), fontsize=font_size)
plt.yticks([int(plot_phase_roll[low]), int(plot_phase_roll[high])], fontsize=font_size)
cb = plt.colorbar(orientation='horizontal', fraction=0.12, location='top')
plt.clim([0, 1.8e-5])
cb.set_ticks([0, 8.7e-6])
cb.update_ticks()
cb.formatter.set_powerlimits((0, 0))
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(5)
f.tight_layout(pad=5)



#plt.title('unimodal cells')
#plt.savefig(figure_path+'bimodal.pdf',format='pdf',dpi=1000)

# ---- Network / stimulus parameters for the unimodal run (mbar = 151.2) ----
N = 512                  # number of neurons on the ring
k0 = 5
J0 = 1
mbar = 151.2             # bimodal: 150, unimodal: 151.2
var_m = 0.12
vbar = 0.5
a = 0.4                  # Gaussian width used for the coupling kernel and input
tau = 3
tau_v = 144
J0 = J0 / N * 512
dt = tau / 10            # integration step
m = mbar * tau / tau_v
# kc=(1+tau/tau_v)*(1+2*m-tau/tau_v)*density*J0^2/(32*np.pi*a^2*(1+m)^4)
k = k0 / N * 512
A = np.sqrt(2 * np.pi * a**2)
alpha = 0.19             # amplitude of the moving external input
v = a / tau_v * vbar     # stimulus speed (radians per time unit)
T = 8 * np.pi / v        # total simulated time: the stimulus travels 8*pi (four laps)

figure_path = './figures/20220515bi/'

placefield = 2.5 * a     # radius around a neuron's preferred location used for the maps

# ---- Ring coordinates and translation-invariant Gaussian coupling ----
# NOTE(review): J0 is reset to 1 here, so the N-rescaling above has no effect.
J0 = 1
coordinatex = np.arange(-np.pi, np.pi, 2 * np.pi / N)
delta_z_squre = coordinatex**2
J = np.roll(J0 / A * np.exp(-delta_z_squre / (2 * a**2)), shift=int(N / 2))

# Build the 1D CANN with the Gaussian coupling kernel J.
net = CANN.CANN_1D(N, k, m, J, tau, tau_v, True)

n_steps = int(T / dt)
U_profile = np.zeros([n_steps, N])          # scaled rates (net.r * 5000), one row per step
Is = np.empty([n_steps, N], dtype=float)    # external input history
positions = np.empty(n_steps, dtype=float)  # stimulus position at each step
times = np.empty(n_steps, dtype=float)      # simulation time at each step

start = 0

# Drive the network with a Gaussian input moving at constant speed v around the ring.
for t in range(n_steps):
    z = start + v * t * dt
    z = z % (2 * np.pi) - np.pi + start      # shift into [-pi, pi) (start is 0 here)
    delta_z = coordinatex - z
    wrapped = np.pi < abs(delta_z)
    # Shortest distance on the ring; the sign is irrelevant because it is squared below.
    delta_z[wrapped] = 2 * np.pi - abs(delta_z[wrapped])
    I_ext = alpha * np.exp(-delta_z**2 / (4 * a**2))
    net.update(I_ext, dt)
    U_profile[t] = net.r * 5000
    Is[t] = I_ext
    positions[t] = z
    times[t] = t * dt

# Drop the first half as transient. `times` is trimmed together with U_profile and
# positions so that the later [begin:end] slicing stays aligned across all three.
half = int(T / dt / 2)
U_profile = U_profile[half:]
positions = positions[half:]
times = times[half:]
T = T / 2

# Population-vector decoding: the argument of the activity-weighted mean of
# exp(i*x) (uf.complex_arg presumably returns the complex angle), minus the true
# stimulus position, folded into [-pi/2, pi/2] by arcsin(sin(.)).
decoded_position = uf.complex_arg((U_profile * np.exp(1j * coordinatex)).sum(1) / U_profile.sum(1))
decoded_position = np.arcsin(np.sin((decoded_position.T - positions).T))

# Smooth the decoding error; its oscillation defines the cycles used below.
# scipy.ndimage.filters is a deprecated alias namespace, so call scipy.ndimage directly.
smoothed_decoded_position = scipy.ndimage.gaussian_filter(decoded_position, sigma=5)


# Trim to an even number of whole oscillation (theta) cycles: start at the third
# valley of the smoothed decoding error and stop at a valley chosen so that the
# number of cycles in between is even.
valleys = uf.find_valleys(smoothed_decoded_position)
if len(valleys) < 5:
    # With 3 or 4 valleys begin == end and every array below would be empty.
    raise RuntimeError(
        'need at least 5 valleys in the decoded position to select whole cycles, '
        'found {}'.format(len(valleys)))
begin = valleys[2]
end = valleys[-1] if len(valleys) % 2 == 1 else valleys[-2]
smoothed_decoded_position = smoothed_decoded_position[begin:end]
U_profile = U_profile[begin:end]
positions = positions[begin:end]
times = times[begin:end]

# Cycle boundaries inside the trimmed window, including both ends.
valleys = uf.find_valleys(smoothed_decoded_position)
valleys = np.hstack([np.array([0]), valleys])
valleys = np.hstack([valleys, U_profile.shape[0]])

# Phase of every sample: a linear 0..360 degree ramp within each cycle,
# rounded with uf.my_round.
phases = []
for cycle_start, cycle_end in zip(valleys[:-1], valleys[1:]):
    cycle_len = cycle_end - cycle_start
    phases.extend(uf.my_round(j / cycle_len * 360) for j in range(cycle_len))
phases = np.array(phases)

# Accumulator for the phase x relative-position map: 721 rows of 0.5-degree
# phase bins, one column per neuron spacing across the place field.
phase_pos = np.zeros([721, int(2 * placefield / (2 * np.pi / N)) + 1])

def _count_modes(profile):
    """Return the number of prominent, well-separated peaks in a phase profile.

    Candidate peaks come from uf.find_peaks. A peak is kept when its width
    (scipy.signal.peak_widths, in 0.5-degree bins) exceeds 60 and its prominence
    exceeds 0.05. Of two surviving neighbours closer than 100 bins, the lower
    one is discarded.
    """
    peaks = np.array(uf.find_peaks(profile))
    if len(peaks) == 0:
        return 0
    widths = signal.peak_widths(profile, peaks)[0]
    prominences = signal.peak_prominences(profile, peaks)[0]
    peaks = peaks[(widths > 60) & (prominences > 0.05)]
    heights = profile[peaks]
    too_close = np.flatnonzero(np.diff(peaks) < 100)
    dropped = {i if heights[i] < heights[i + 1] else i + 1 for i in too_close}
    return len(peaks) - len(dropped)


# Classify every neuron by the number of modes in its activity-vs-phase profile:
# 0 = unimodal, 1 = bimodal, -1 = anything else.
cell_types = -np.ones(N)
phase_bins = (np.asarray(phases) / 0.5).astype(int)[:U_profile.shape[0]]  # 0.5-degree bins

for index in range(N):
    # Activity-weighted phase histogram (bincount sums sequentially, like a += loop).
    phases_i = np.bincount(phase_bins, weights=U_profile[:, index], minlength=721)
    # NOTE(review): phase is circular but the default 'reflect' boundary is used here.
    phases_i = scipy.ndimage.gaussian_filter(phases_i, sigma=14)

    peak_value = phases_i.max()
    if peak_value <= 0:
        continue  # silent neuron: normalising would divide by zero
    phases_i /= peak_value
    phases_i -= phases_i.min()

    # Count peaks and valleys separately and accept the larger count.
    num_peaks = max(_count_modes(phases_i), _count_modes(-phases_i))

    if num_peaks == 1:
        cell_types[index] = 0  # unimodal
    elif num_peaks == 2:
        cell_types[index] = 1  # bimodal

unimodals = np.where(cell_types == 0)[0]
bimodals = np.where(cell_types == 1)[0]
num_uni = unimodals.shape[0]
num_bi = bimodals.shape[0]

print('unimodel:{} bimodal:{}'.format(num_uni, num_bi))


# Dominant oscillation (theta) frequency of the smoothed decoding error.
signal_len = smoothed_decoded_position.shape[0]
spec = np.fft.fft(smoothed_decoded_position - smoothed_decoded_position.mean())
spec = spec[:spec.shape[0] // 2]

# Use the magnitude spectrum: the real part alone depends on the oscillation's
# phase and can hide the dominant component.
amplitude = np.abs(spec)
ps = np.asarray(uf.find_peaks(amplitude), dtype=int)
criterion = amplitude.max() / 2
candidates = np.flatnonzero(amplitude[ps] > criterion)
print('freq', candidates)
if candidates.size == 0:
    raise RuntimeError('no spectral peak above half of the maximum amplitude')
freq = ps[candidates[0]]
# FFT bin `freq` means `freq` cycles over the analysed (trimmed) signal, so the
# period in samples is its length divided by freq (not the untrimmed T/dt).
theta_period = int(signal_len / freq)

print('argmax:', freq, 'T:', theta_period)

font_size = 80
plt.rcParams['font.size'] = 40
single_neuron_index = 120

# ---- Activity trace of one unimodal cell around the stimulus passing it ----
# (The plot is titled 'unimodal', so the cell is taken from `unimodals`.)
plt.figure(figsize=(12, 4))
i = single_neuron_index
cell = unimodals[i]
preferred_x = coordinatex[cell]
center = np.argmin(abs(preferred_x - positions))  # sample closest to the cell's location

half_window = int(2 / 3 * np.pi / v / dt)  # steps to travel 2*pi/3
lap = int(2 * np.pi / v / dt)              # steps for one full lap
start = center - half_window
end = center + half_window
# Shift by one lap when the window leaves the recorded (trimmed) data.
if start < 0:
    start += lap
    end += lap
elif end > U_profile.shape[0]:
    start -= lap
    end -= lap
activity_cell = U_profile[start:end, cell]

# Crop to 1.5 cycles before the first / after the last peak above 5% of the maximum.
cell_peaks = np.array(uf.find_peaks(activity_cell))
strong = activity_cell[cell_peaks] > (activity_cell.max() / 20)
margin = int(theta_period * 3 / 2)
start = max(cell_peaks[strong][0] - margin, 0)  # a negative start would wrap the slice
end = cell_peaks[strong][-1] + margin
activity_cell = activity_cell[start:end]
X_time_steps = np.arange(-half_window, half_window, 1)[start:end]
X_cycle_steps = np.concatenate([
    np.flip(np.arange(0, X_time_steps[0] - 100, -theta_period), 0),
    np.arange(0, X_time_steps[-1], theta_period)[1:],
])
X_theta_cycle = (X_cycle_steps / theta_period).astype(int)  # np.int was removed in NumPy 1.24

plt.plot(X_time_steps, activity_cell, lw=3, color='k')
plt.xticks(X_cycle_steps, X_theta_cycle, fontsize=40)
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(5)
plt.title('unimodal')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
#plt.savefig(figure_path+'single_uni.pdf',format='pdf',dpi=1000)

    

##print(phase_pos.shape)
# Accumulate activity into the phase x relative-position map, using only samples
# where the stimulus lies within `placefield` of each neuron's preferred location.
bin_width = 2 * np.pi / N  # neuron spacing on the ring
for i in range(N):
    # Signed stimulus offset from neuron i, wrapped to [-pi, pi].
    offset = positions - coordinatex[i]
    offset[offset > np.pi] -= 2 * np.pi
    offset[offset < -np.pi] += 2 * np.pi
    in_field = np.abs(offset) < placefield

    phase = phases[in_field]
    act = U_profile[in_field, i]
    position = uf.my_round_n(offset[in_field] + placefield, bin_width)

    rows = (np.asarray(phase) / 0.5).astype(int)
    cols = (np.asarray(position) / bin_width).astype(int)
    # Unbuffered, in-order accumulation: repeated (row, col) pairs add up exactly
    # as the element-by-element loop would.
    np.add.at(phase_pos, (rows, cols), act)




#print('num_uni={} num_bi={}'.format(num_uni,num_bi))

#from numpy.core.function_base import linspace
# Axes of the phase-position map.
plot_position = np.linspace(-1, 1, int(2 * placefield / (2 * np.pi / N)) + 1)  # offset / placefield
plot_phase = np.linspace(0, 360, 721)                    # 0.5-degree phase bins
plot_phase_roll = np.roll(plot_phase, shift=0, axis=0)   # phase labels (shift currently 0)

# Stack rows 0..719 on top of the full 0..720 map so smoothing along phase sees
# the circular continuation. Roll by 310 bins (155 deg) to match the experimental
# figure, smooth, and keep the first 721 rows.
phase_pos = np.vstack([phase_pos[0:-1], phase_pos[0:]])
phase_pos = np.roll(phase_pos, shift=+310, axis=0)
smooth_phase_pos = scipy.ndimage.gaussian_filter(phase_pos, sigma=(13, 7))[0:721]

# Per phase row: smoothed relative position of maximal activity.
lines = plot_position[np.argmax(smooth_phase_pos, axis=1)]
lines = scipy.ndimage.gaussian_filter(lines, sigma=7)

# Phase range of the sweep: from the row with the lowest peak position to the
# row with the highest.
bound = np.array([np.argmin(lines), np.argmax(lines)])
if bound[0] == 0:
    # Skip a plateau at row 0: move one past the first row whose value differs
    # (a constant `lines` leaves the bound at 0 instead of raising IndexError).
    differs = np.flatnonzero(lines != lines[0])
    if differs.size:
        bound[0] = differs[0] + 1

# ---- Unimodal run: phase-position probability map ----
smooth_phase_pos = smooth_phase_pos / N                         # average over neurons
smooth_phase_pos = smooth_phase_pos / smooth_phase_pos.sum()    # normalise to a probability map
f = plt.figure(figsize=(16, 28))
gs0 = gs.GridSpec(nrows=2, ncols=1, height_ratios=[3, 1])

# Top panel: full map with the sweep bounds (bin index / 2 = degrees) marked.
full = f.add_subplot(gs0[0, 0])
plt.pcolormesh(plot_position, plot_phase, smooth_phase_pos, cmap='rainbow')
plt.xticks([])
# Set tick positions only. Passing a per-bin label list installs a FixedFormatter
# that later tick changes index positionally, which mislabels the ticks.
plt.yticks(np.linspace(0, 360, 3), fontsize=font_size)
plt.hlines(bound / 2, -1, 1, linestyle='dashed', lw=5, colors='black')

cb = plt.colorbar(orientation='horizontal', fraction=0.040, location='top')
plt.clim([0, 2.4e-5])
cb.set_ticks([0, 1.2e-5])
cb.update_ticks()
cb.formatter.set_powerlimits((0, 0))
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(5)

# Bottom panel: only the rows between the two sweep bounds.
part = f.add_subplot(gs0[1, 0])
low = bound[0]
high = bound[1]
plt.pcolormesh(plot_position, plot_phase_roll[low:high], smooth_phase_pos[low:high], cmap='rainbow')
plt.xticks(np.linspace(-1, 1, 3), fontsize=font_size)
plt.yticks([int(plot_phase_roll[low]), int(plot_phase_roll[high])], fontsize=font_size)
cb = plt.colorbar(orientation='horizontal', fraction=0.12, location='top')
plt.clim([0, 1.8e-5])
cb.set_ticks([0, 8.7e-6])
cb.update_ticks()
cb.formatter.set_powerlimits((0, 0))
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(5)
f.tight_layout(pad=5)

plt.show()