import numpy as np
from CANN_1D import CANN_v2
from scipy import ndimage

def CANN(alphas,fano=0.001,sampling_times=10):
    """Simulate a noisy 1-D CANN whose external input is switched mid-run.

    For every value ``alpha`` in ``alphas`` the network is reset
    ``sampling_times`` times and simulated for ``frames = T / dt = 5000``
    steps (``T = 500``, ``dt = 0.1``). During the first ``set_time`` frames
    it receives the weak initialisation input ``0.02 * I_init``; from then on
    it receives ``alpha * In``, a Gaussian bump centred at ``cor == 0``.
    Noise with variance ``tau * U * fano`` is added to the input every step.

    Parameters
    ----------
    alphas : float, int, NumPy scalar, or 1-D array-like of float
        Input strength(s) used after ``set_time``. A scalar is treated as a
        single trial. ``alpha > 0`` starts the bump at ``I0`` (``In`` shifted
        by N/5 sites) and times its arrival at ``In``'s peak. Otherwise the
        bump starts next to ``In``'s peak and its arrival at ``I0``'s peak is
        timed.
    fano : float
        Scale factor of the input-noise variance.
    sampling_times : int
        Number of independent noisy repetitions per alpha.

    Returns
    -------
    FR_profile : ndarray, shape (trials, sampling_times, frames - set_time, N)
        ``net.r`` recorded at every frame after the input switch. Note that
        the full buffer holds trials * sampling_times * 5000 * 512 float64
        values (~200 MB per alpha at the default sampling_times).
    centers : ndarray, shape (trials, sampling_times, frames - set_time)
        Estimated bump centre at each recorded frame.
    Interval : ndarray, shape (trials, sampling_times)
        ``(final_t - set_time) * dt``: time from the input switch until the
        centre first came within ``threshold * 2 * pi`` of the target. It is
        ``(frames - set_time) * dt`` if the target was never reached.
    """
    N = 512
    tau = 1
    trans = True

    a = 0.5
    cor = np.arange(-np.pi, np.pi, 2*np.pi/N)  # N evenly spaced ring positions
    J0 = 1
    rho = N / (2*np.pi)

    # k is set to 4% of kc (kc computed from a, rho and J0).
    kc = np.pi*a*rho*J0**2/(4*np.sqrt(2*np.pi))
    k = 0.04*kc

    # Gaussian recurrent kernel. Rolling by N/2 moves its peak (cor == 0,
    # index N/2) to index 0 -- presumably the layout CANN_v2 expects; TODO confirm.
    J = J0 * np.exp(-(cor)**2/(2*a**2))
    J = np.roll(J,shift=int(N/2))

    T = 500
    dt = 0.1

    net = CANN_v2(N,k,J,tau,trans)

    # In: Gaussian input centred at cor == 0; I0: the same bump shifted by N/5 sites.
    In = np.exp(-(cor)**2/(4*a**2))
    I0 = np.roll(In,shift=int(N/5))

    set_time = 100  # frame index at which the input switches to alpha * In

    threshold = 0.004  # arrival tolerance as a fraction of the ring length 2*pi

    frames = int(T/dt)
    # Accept Python scalars, NumPy scalars (e.g. np.float64), lists and arrays.
    alphas = np.atleast_1d(np.asarray(alphas, dtype=float))
    trials = alphas.shape[0]

    FR_profile = np.zeros((trials,sampling_times,frames,N))
    centers = np.zeros((trials,sampling_times,frames))
    Interval = np.zeros((trials,sampling_times))

    for trial, alpha in enumerate(alphas):
        print('alpha:', alpha)
        if alpha > 0:
            I_init = I0
            target = cor[In.argmax()]
        else:
            I_init = np.roll(In,shift=2)
            target = cor[I0.argmax()]
        for samp in range(sampling_times):
            net.reset()

            final_t = frames  # sentinel: target not reached (yet)
            for t in range(frames):
                if t < set_time:
                    I_ext = 0.02 * I_init
                else:
                    I_ext = alpha * In

                # Noise variance tau * U * fano. U is clipped at 0 because the
                # square root of a negative U would inject NaN into the network.
                noise_var = tau*np.maximum(net.U, 0)*fano
                I_ext = I_ext + np.sqrt(noise_var)*np.random.randn(N)

                net.update(I_ext,dt)

                r = net.r
                # Coarse centre = peak of the smoothed profile. It is refined by
                # the rate-weighted mean of the offsets from that peak, wrapped
                # onto the circle with np.angle.
                smooth_r = ndimage.gaussian_filter1d(r,sigma=10)
                smooth_center = cor[smooth_r.argmax()]
                offsets = np.angle(np.exp(1j*(cor-smooth_center)))
                center = offsets @ r / r.sum() + smooth_center

                distance = np.abs(center-target)
                # Record the first post-switch frame within tolerance, so
                # Interval is always measured from the input switch.
                if final_t == frames and t >= set_time and distance < threshold*2*np.pi:
                    final_t = t

                FR_profile[trial, samp, t] = r
                centers[trial, samp, t] = center

            Interval[trial, samp] = (final_t-set_time)*dt

    # Drop the initialisation period; keep only frames after the switch.
    return FR_profile[:, :, set_time:, :], centers[:, :, set_time:], Interval

# Run CANN for the selected alphas and save one firing-rate heatmap per alpha (with the bump-centre trace overlaid) as a PDF.
import matplotlib.pyplot as plt
import numpy as np
import os

# Output location and size of the saved figures.
figure_path = './figures/NeurIPS_Figures/Neural_Sequence/0/'
figsize = [9,6]

# Same 512-point ring grid that CANN builds internally.
cor = np.arange(-np.pi, np.pi, 2*np.pi/512)

# First grid indices lying within one grid step of z = -2 and z = 3. The
# heatmaps show only positions in [bottom_idx, top_idx).
bottom_idx = np.flatnonzero(np.abs(cor + 2) < 2*np.pi/512)[0]
top_idx = np.flatnonzero(np.abs(cor - 3) < 2*np.pi/512)[0]

# Input strengths are alpha_0 times -0.2 ... -3.4 (negative half first),
# then +0.2 ... +3.4, in steps of 0.2.
alpha_0 = 0.04
factors_1 = np.arange(0.2,3.5,0.2)
factors_2 = np.negative(factors_1)
factors = np.concatenate((factors_2, factors_1))
Alphas = alpha_0 * factors
trials = len(factors)

# Plot every plot_inter-th alpha of each half. The positive half is anchored
# on the factor +1 (reference) curve and the negative half on factor -1.
plot_inter = 3
ref_idx = np.flatnonzero(np.abs(factors - 1) < 1e-4)[0]
print('ref_idx',ref_idx)
plot_index_pos = np.arange(ref_idx - plot_inter, trials, plot_inter)

neg_ref_idx = np.flatnonzero(np.abs(factors + 1) < 1e-4)[0]
print('neg_ref_idx',neg_ref_idx)
plot_index_neg = np.arange(neg_ref_idx - plot_inter, trials // 2, plot_inter)

plot_idx = np.concatenate([plot_index_neg, plot_index_pos])
alpha_plot = Alphas[plot_idx]


# Embed TrueType (Type 42) fonts so text in saved PDF/EPS files stays
# editable in Illustrator.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
# Figure style. Matplotlib 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 removed the old names, so use whichever name
# this installation provides.
_seaborn_white = 'seaborn-v0_8-white'
if _seaborn_white not in plt.style.available:
    _seaborn_white = 'seaborn-white'
plt.style.use(_seaborn_white)

# No autoscale margins, so the heatmaps fill the axes.
plt.rcParams['axes.xmargin'] = 0
plt.rcParams['axes.ymargin'] = 0

# Reversed 'hot' colormap for the firing-rate heatmaps. Use pyplot.get_cmap:
# matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9.
cmap = plt.get_cmap('hot').reversed()

# One Spectral colour per alpha (over all trials), then the subset being
# plotted; these colour the bump-centre traces.
z_cmap = plt.cm.Spectral
z_colors = z_cmap(np.linspace(0,1,trials))
plot_z_colors = z_colors[plot_idx]


# Create the output directory (and parents) if needed. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
os.makedirs(figure_path, exist_ok=True)

# Run the simulation for the selected alphas. The per-step centres that CANN
# returns are discarded and recomputed below from the sample-averaged rates.
r_record, _, Intervals = CANN(alpha_plot)

# Mean time-to-target per alpha, truncated to an integer number of time units.
Intervals = Intervals.mean(axis=1).astype(int)

# Average over samples: (trials, samples, time, N) -> (trials, time, N).
r_record = r_record.mean(axis=1)

# Bump centre per (trial, time): the peak of the smoothed profile, refined by
# the rate-weighted mean of the offsets from that peak (wrapped with np.angle).
smooth_r = ndimage.gaussian_filter1d(r_record,sigma=10,axis=2)
smooth_center = cor[smooth_r.argmax(axis=2)]
mm = np.zeros(r_record.shape[:2])
for i in range(r_record.shape[0]):
    offsets = np.angle(np.exp(1j*(cor[None,:]-smooth_center[i][:,None])))
    mm[i] = (r_record[i,:,:] * offsets).sum(axis=1)
num = r_record.sum(axis=2)
centers = mm/num + smooth_center

# Common x-axis limit for every figure: the longest mean interval.
x_max = Intervals.max()

# One figure per selected alpha: the down-sampled firing-rate heatmap over
# ring positions [bottom_idx, top_idx), with the bump-centre trace on top.
for i, alpha_idx in enumerate(plot_idx):
    # Intervals are in time units; there are 10 frames per unit since dt = 0.1.
    end = int(Intervals[i]*10)
    Plot_figure = r_record[i, :end, bottom_idx:top_idx]
    # Full-resolution time axis for the centre trace.
    z_time = np.arange(0, Plot_figure.shape[0])*0.1

    # The heatmap keeps only every down_rate-th frame.
    down_rate = 30
    Plot_figure_down = Plot_figure[::down_rate, :]
    frames = Plot_figure_down.shape[0]
    x = np.arange(frames)*0.1*down_rate
    y = cor[bottom_idx:top_idx]

    plt.figure(figsize=figsize)
    plt.pcolormesh(x, y, Plot_figure_down.T, cmap=cmap)
    plt.xlabel('Time(ms)')
    plt.ylabel('z')
    plt.xlim(0, x_max)
    plt.colorbar()
    plt.plot(z_time, centers[i, :end], color=plot_z_colors[i], linewidth=8)
    plt.tick_params(axis='both', which='both', length=4, width=1, direction='out')
    # NOTE(review): the literal backslash in this file name acts as a path
    # separator on Windows; kept unchanged so existing file names still match.
    out_name = 'NeuralSeq_$\\alpha=%.2f$' % Alphas[alpha_idx]
    plt.savefig(figure_path + out_name + '.pdf', format='pdf', dpi=100, bbox_inches='tight')


plt.show()

