from brian2 import *
from brian2tools import *
import scipy as sp
import struct
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
import pickle
from bqplot import *
import ipywidgets as widgets
import warnings
import os
from keras.datasets import cifar10
from utils import *


# ------- runtime environment -------
# Suppress all Python warnings, select Brian2's NumPy code-generation target,
# open a fresh Brian2 scope and seed NumPy's global random generator.
warnings.filterwarnings("ignore")
prefs.codegen.target = "numpy"
start_scope()
np.random.seed(100)
data_path = './Data/CIFAR10_data/'


###################################
# ------- switches -------
Switch_monitor = True       # create StateMonitors and save their records
Switch_plasticity = True    # run the plasticity phase; also read by the STDP code strings
READ_WEIGHT = True          # try to load 'weight.pkl' into S_EE before training

# ------- coding / timing parameters -------
coding_n = 10
CIFAR10_shape = (32, 32, 3)  # 3 channels for RGB

coding_duration = 10
duration = CIFAR10_shape[0] * coding_duration

# Fractions handed to CIFAR10.select_data for each phase.
F_plasticity = F_train = F_test = 0.05

# Simulation time step; the same quantity is bound to Dt and to the default clock.
Dt = 1*ms
defaultclock.dt = Dt

# ------- population sizes -------
n_ex = 400
n_inh = n_ex // 4
n_input = coding_n * CIFAR10_shape[1]
n_read = n_ex + n_inh

# ------- connectivity / weight scales -------
R = 2   # referenced by name inside the distance-dependent connect() expressions
f = 1   # common multiplier applied to every weight scale below

A_EE = 30 * f
A_EI = 60 * f
A_IE = 19 * f
A_II = 19 * f
A_inE = 18 * f
A_inI = 9 * f

# Connection probabilities from the input layer.
p_inE = p_inI = 0.1

learning_rate = 0.01


###########################################
# ------- helper objects -------
# NOTE(review): Function, Base, Readout, Result and CIFAR10Classification are
# not defined in this file; presumably they come from `utils` -- confirm.
function = Function()
base = Base(duration, Dt)
readout = Readout(function.logistic)
result = Result()
CIFAR10 = CIFAR10Classification(CIFAR10_shape, duration, Dt)

# ------- data preparation -------
# The raw Keras arrays are loaded here; the pipeline below reads its samples
# from CIFAR10.train / CIFAR10.test instead.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Stage 1: pick a fraction of the samples for each phase.
df_plasticity = CIFAR10.select_data(F_plasticity, CIFAR10.train)
df_train = CIFAR10.select_data(F_train, CIFAR10.train)
df_test = CIFAR10.select_data(F_test, CIFAR10.test)

# Stage 2: latency-encode every selection with the same encoder and coding_n.
df_en_plasticity = CIFAR10.encoding_latency_CIFAR10(
    CIFAR10._encoding_cos_rank_ignore_0, df_plasticity, coding_n)
df_en_train = CIFAR10.encoding_latency_CIFAR10(
    CIFAR10._encoding_cos_rank_ignore_0, df_train, coding_n)
df_en_test = CIFAR10.encoding_latency_CIFAR10(
    CIFAR10._encoding_cos_rank_ignore_0, df_test, coding_n)

# Stage 3: turn the encoded frames into (series data, labels) pairs.
data_plasticity_s, label_plasticity = CIFAR10.get_series_data_list(
    df_en_plasticity, is_group=True)
data_train_s, label_train = CIFAR10.get_series_data_list(
    df_en_train, is_group=True)
data_test_s, label_test = CIFAR10.get_series_data_list(
    df_en_test, is_group=True)

# ------- model equations (Brian2 equation / code strings) -------

# Input layer: the drive I is read from `stimulus` at the current time and
# neuron index.
# NOTE(review): `stimulus` is not defined in this file; presumably a
# TimedArray installed before each run -- confirm.
neuron_in = '''
I = stimulus(t,i) : 1
'''

# Reservoir neurons (used for both G_ex and G_inh): v relaxes towards I with a
# 30 ms time constant and is held during the refractory period; g and h decay
# with 3 ms and 6 ms time constants; I is their sum plus a constant 13.5.
# x, y, z are per-neuron parameters used as coordinates by the
# distance-dependent connect() expressions.
neuron = '''
dv/dt = (I-v) / (30*ms) : 1 (unless refractory)
dg/dt = (-g)/(3*ms) : 1
dh/dt = (-h)/(6*ms) : 1
I = (g+h)+13.5: 1
x : 1
y : 1
z : 1
'''

# Readout neurons: same dynamics as above but without the refractory flag,
# without the constant 13.5 offset and without coordinates.
neuron_read = '''
dv/dt = (I-v) / (30*ms) : 1
dg/dt = (-g)/(3*ms) : 1 
dh/dt = (-h)/(6*ms) : 1
I = (g+h): 1
'''

# Static synapse: a single weight w per connection.
synapse = '''
w : 1
'''

# Presynaptic spike on an excitatory connection: add w to the target's g.
on_pre_ex = '''
g+=w
'''

# Presynaptic spike on an inhibitory connection: subtract w from the target's h.
on_pre_inh = '''
h-=w
'''

# Plastic synapse: weight w bounded by [w_min, w_max] plus two exponentially
# decaying traces (a_ahead, a_latter). A_latter is derived from A_ahead and the
# ratio of the two time constants, scaled by -1.2.
synapse_stdp = '''
w : 1
w_max : 1
w_min : 1
tau_ahead : second
tau_latter : second
A_ahead : 1
A_latter = -A_ahead * tau_ahead / tau_latter * 1.2 : 1
da_ahead/dt = -a_ahead/tau_ahead : 1 (clock-driven)
da_latter/dt = -a_latter/tau_latter : 1 (clock-driven)
'''

# Presynaptic spike on a plastic synapse: deliver w to g, bump a_ahead (only
# when Switch_plasticity is truthy), apply the a_latter trace to w with
# clipping, then reset a_latter.
on_pre_ex_stdp = '''
g+=w
a_ahead += A_ahead * int(Switch_plasticity)
w = clip(w+a_latter, w_min, w_max)
a_latter = 0
'''

# Postsynaptic spike on a plastic synapse: bump a_latter (only when
# Switch_plasticity is truthy), apply the a_ahead trace to w with clipping,
# then reset a_ahead.
on_post_ex_stdp = '''
a_latter += A_latter * int(Switch_plasticity)
w = clip(w+a_ahead, w_min, w_max)
a_ahead = 0
'''

# ------- neuron groups -------
# Keyword arguments shared by the two reservoir populations.
_reservoir_kwargs = dict(threshold='v > 15', reset='v = 13.5', method='euler')

Input = NeuronGroup(
    n_input,
    neuron_in,
    threshold='I > 0',
    method='euler',
    refractory=0 * ms,
    name='neurongroup_input',
)

G_ex = NeuronGroup(n_ex, neuron, refractory=3 * ms, name='neurongroup_ex',
                   **_reservoir_kwargs)

G_inh = NeuronGroup(n_inh, neuron, refractory=2 * ms, name='neurongroup_in',
                    **_reservoir_kwargs)

# The readout group has no threshold or reset.
G_readout = NeuronGroup(n_read, neuron_read, method='euler',
                        name='neurongroup_read')

# ------- synapses -------
# Input -> reservoir (static, excitatory).
S_inE = Synapses(Input, G_ex, synapse,
                 on_pre=on_pre_ex, method='euler', name='synapses_inE')
S_inI = Synapses(Input, G_inh, synapse,
                 on_pre=on_pre_ex, method='euler', name='synapses_inI')

# Recurrent reservoir connections; only E -> E uses the STDP model.
S_EE = Synapses(G_ex, G_ex, synapse_stdp,
                on_pre=on_pre_ex_stdp, on_post=on_post_ex_stdp,
                method='euler', name='synapses_EE')
S_EI = Synapses(G_ex, G_inh, synapse,
                on_pre=on_pre_ex, method='euler', name='synapses_EI')
S_IE = Synapses(G_inh, G_ex, synapse,
                on_pre=on_pre_inh, method='euler', name='synapses_IE')
S_II = Synapses(G_inh, G_inh, synapse,
                on_pre=on_pre_inh, method='euler', name='synapses_I')

# Reservoir -> readout with a fixed unit weight. These two objects are left
# unnamed, so their creation order must not change.
S_E_readout = Synapses(G_ex, G_readout, 'w = 1 : 1',
                       on_pre=on_pre_ex, method='euler')
S_I_readout = Synapses(G_inh, G_readout, 'w = 1 : 1',
                       on_pre=on_pre_inh, method='euler')

# ------- initial neuron state -------
# Reservoir membrane values start at 13.5 plus 1.5*rand(); the readout starts
# at 0. g and h start at 0 everywhere. G_ex is handled before G_inh so the
# order of the random draws stays the same.
for _group, _v_init in (
    (G_ex, '13.5+1.5*rand()'),
    (G_inh, '13.5+1.5*rand()'),
    (G_readout, '0'),
):
    _group.v = _v_init
    _group.g = '0'
    _group.h = '0'
del _group, _v_init  # do not leave loop temporaries in the module namespace

# base.allocate receives both reservoir groups and the three sizes 5, 5, 20.
# Its second return value is bound to G_in (not G_inh); G_in is the name the
# monitor section uses.
G_ex, G_in = base.allocate([G_ex, G_inh], 5, 5, 20)

# ------- network topology -------
# Input projects only onto targets with index j < 0.3*N_post.
S_inE.connect(condition='j<0.3*N_post', p=p_inE)
S_inI.connect(condition='j<0.3*N_post', p=p_inI)

# Recurrent connection probability decays with the squared distance between
# the (x, y, z) values of the pre- and postsynaptic neurons, scaled by R**2.
# Each pathway multiplies the same kernel by its own peak probability.
_distance_kernel = ('exp(-((x_pre-x_post)**2+(y_pre-y_post)**2'
                    '+(z_pre-z_post)**2)/R**2)')
S_EE.connect(condition='i != j', p='0.3*' + _distance_kernel)
S_EI.connect(p='0.2*' + _distance_kernel)
S_IE.connect(p='0.4*' + _distance_kernel)
S_II.connect(condition='i != j', p='0.1*' + _distance_kernel)

# One-to-one readout wiring: excitatory neurons fill readout indices
# 0..n_ex-1, inhibitory neurons the indices after them.
S_E_readout.connect(j='i')
S_I_readout.connect(j='i+n_ex')

# ------- synaptic weights -------
# Weights come from function.gamma(scale, shape); the pathway order is kept
# unchanged in case it draws from the seeded random generator.
for _syn, _scale in (
    (S_inE, A_inE),
    (S_inI, A_inI),
    (S_EE, A_EE),
    (S_IE, A_IE),
    (S_EI, A_EI),
    (S_II, A_II),
):
    _syn.w = function.gamma(_scale, _syn.w.shape)

# ------- transmission delays -------
S_EE.pre.delay = '1.5*ms'
for _syn in (S_EI, S_IE, S_II):
    _syn.pre.delay = '0.8*ms'
del _syn, _scale  # do not leave loop temporaries in the module namespace

# ------- STDP parameters for E -> E -------
# The clipping bounds are the extremes of the freshly drawn weights.
S_EE.w_max = np.max(S_EE.w)
S_EE.w_min = np.min(S_EE.w)
S_EE.A_ahead = learning_rate
S_EE.tau_ahead = '3*ms'
S_EE.tau_latter = '3*ms'

# ------- monitors -------
# Every monitor records all indices of its source (record=True).
if Switch_monitor:
    m_g_in = StateMonitor(G_in, ['I', 'v'], record=True)
    m_g_ex = StateMonitor(G_ex, ['I', 'v'], record=True)
    m_read = StateMonitor(G_readout, ['I', 'v'], record=True)
    m_input = StateMonitor(Input, 'I', record=True)
    m_s_ee = StateMonitor(S_EE, 'w', record=True)

# ------- network -------
# Collect every Brian2 object visible here into an explicit Network and
# snapshot its state under the name 'init'.
net = Network(collect())
net.store('init')


###############################################
# ------- plasticity phase -------
# NOTE(review): run_net_plasticity is not defined in the visible part of this
# file; presumably it is provided by `utils` or another section -- confirm.
if Switch_plasticity:
    weight_changed, spectral_radius, monitor_record_pre_train = run_net_plasticity(data_plasticity_s)

    # ------- save monitor data and results -------
    # The three objects saved here exist only when the plasticity phase has
    # run, so the save is nested under that switch. Previously, running with
    # Switch_plasticity = False and Switch_monitor = True raised a NameError.
    if Switch_monitor:
        result.result_save('monitor_pre_train.pkl', **monitor_record_pre_train)
        result.result_save('weight_changed.pkl', weight_changed=weight_changed)
        result.result_save('spectral_radius.pkl', spectral_radius=spectral_radius)

# ------- close plasticity -------
# From here on the STDP code strings see Switch_plasticity as False, so the
# trace increments are multiplied by 0.
Switch_plasticity = False
# Write the current E -> E weights into the stored 'init' snapshot, leaving
# the rest of the snapshot untouched.
# NOTE: this relies on the private attributes Network._stored_state and
# _full_state(), which may change between Brian2 releases.
net._stored_state['init'][S_EE.name]['w'] = S_EE._full_state()['w']


###############################################
# ------- optionally load saved weights -------
# NOTE(review): the plasticity phase saves 'weight_changed.pkl', while this
# reads 'weight.pkl' -- confirm that the file names are meant to differ.
if READ_WEIGHT:
    try:
        weight = result.result_pick('weight.pkl')
        S_EE.w = weight
    except FileNotFoundError:
        print ('Have not trained by plasticity, initial weight have been used')

# ------- collect reservoir states -------
# NOTE(review): run_net is not defined in the visible part of this file;
# presumably it is provided by `utils` or another section -- confirm.
states_train, monitor_record_train = run_net(data_train_s)   # training set
states_test, monitor_record_test = run_net(data_test_s)      # test set

# ------- readout -------
# Fit and score the readout on the collected states.
score_train, score_test = readout.readout_sk(
    states_train,
    states_test,
    label_train,
    label_test,
    solver="lbfgs",
    multi_class="multinomial",
)

# ------- report -------
print('Train score: ',score_train)
print('Test score: ',score_test)

# ------- save monitor data and results -------
if Switch_monitor:
    result.result_save('monitor_train.pkl', **monitor_record_train)
    result.result_save('monitor_test.pkl', **monitor_record_test)
    result.result_save('states_records.pkl',
                       states_train=states_train,
                       states_test=states_test)


#####################################
# ------- visualisation of the recurrent weights -------
# One panel per recurrent pathway, placed in the first four slots of a
# 4x2 subplot grid.
fig_init_w = plt.figure(figsize=(16, 16))
for _slot, _syn in zip((421, 422, 423, 424), (S_EE, S_EI, S_IE, S_II)):
    subplot(_slot)
    brian_plot(_syn.w)
del _slot, _syn  # do not leave loop temporaries in the module namespace
show()

# ------- animation (for Jupyter) -------
# Reads back the test-phase monitor record written by the save step.
# NOTE(review): that file is only written when Switch_monitor is True --
# confirm the intended behaviour otherwise.
monitor = result.result_pick('monitor_test.pkl')
play, slider, fig = result.animation(
    np.arange(monitor['m_read.v'].shape[1]),
    monitor['m_read.v'],
    duration,
    10*duration,
)
# Bare expression: shown as the cell output when run in a notebook.
widgets.VBox([widgets.HBox([play, slider]), fig])