import numpy as np; import pylab as plt

class LIF():
    """Network of leaky integrate-and-fire (LIF) units driving one action unit.

    Simulates ``N`` excitatory (E) units, ``N`` inhibitory (I) units and a
    single "ACT" readout unit over ``blocks * trials * steps`` time bins.
    Membrane voltages are stored per bin; a bin holding the value ``1``
    marks a spike (see :meth:`spiking`). Constructing an instance runs the
    whole simulation immediately.
    """

    def __init__(self):
        self.init_misc()
        self.init_params()
        self.init_logs()
        self.run_sim()

    def init_misc(self):
        """Set simulation sizes, counters and the trial structure."""
        self.t = 0              # global time-bin index
        self.N = 200            # number of E units (and of I units)
        self.wait = 0           # length of the current run of GO steps
        self.trial_ind = 0      # trial index across all blocks
        self.trials = 10        # trials per block
        self.steps = 60         # time steps per trial
        self.stim_dur = 25      # for step > stim_dur, GO is forced on until the ACT unit fires
        self.blocks = 3
        self.PGOs = np.linspace(0, .9, 4)  # candidate GO probabilities (not used by run_sim)
        self.bins = self.trials * self.blocks * self.steps

    def init_params(self):
        """Set membrane constants, input gains and random connectivity."""
        self.V_leak = -70
        self.V_thresh = -50
        # update()/spiking()/init_logs() use these names; the resting potential
        # is the leak reversal and the spike threshold is V_thresh.
        self.resting = self.V_leak
        self.thresh = self.V_thresh
        self.Cm = .2            # not used by the current update rule
        self.gL = 20            # not used by the current update rule

        self.g_ext = 1.5        # constant drive to I units
        self.g_nogo = 0         # not used by the current update rule
        self.g_go = 2           # drive to E units while GO is on
        # Weights are indexed [presynaptic, postsynaptic] and applied as W.T @ spikes.
        # Draw order is kept fixed so seeded runs stay reproducible.
        self.I_E_W = -10 * np.random.rand(self.N, self.N) / self.N
        self.E_I_W = 10 * np.random.rand(self.N, self.N) / self.N
        self.E_E_W = 25 * np.random.rand(self.N, self.N) / self.N
        self.E_ACT_W = 10 * np.random.rand(self.N, 1) / self.N
        self.g_leak = .1
        self.g_noise = 1

    def init_logs(self):
        """Allocate voltage traces (all at rest) and per-trial logs.

        ``act_log`` and ``wait_log`` are pre-filled with 1e3 as a sentinel,
        which remains for trials in which the ACT unit never fired.
        ``I_log`` is allocated but not written by this class.
        """
        self.ACT_neuron = self.resting * np.ones((1, self.bins))
        self.stim_log, self.I_log = [np.zeros(self.bins) for _ in range(2)]
        self.E_neurons, self.I_neurons = [self.resting * np.ones((self.N, self.bins)) for _ in range(2)]
        self.act_log, self.wait_log, self.PGO_log = [1e3 * np.ones(self.blocks * self.trials) for _ in range(3)]

    def run_sim(self):
        """Run every block/trial/step, updating the network and logs each step."""
        for self.block in range(self.blocks):
            # Fixed GO probability per block (assumes blocks <= 3); the random
            # choice from self.PGOs is disabled.
            self.PGO = [0, .45, .9][self.block]
            for self.trial in range(self.trials):
                self.PGO_log[self.trial_ind] = self.PGO
                self.not_acted = 1
                for self.step in range(self.steps):
                    # update() writes bin t + 1, so the final bin cannot be advanced.
                    if self.t < self.bins - 1:
                        # Short-circuit: the binomial draw only happens when GO is not forced.
                        self.GO = (self.not_acted and (self.step > self.stim_dur)) or np.random.binomial(1, p=self.PGO)
                        self.NOGO = 1 - self.GO
                        self.update_neurons()
                        self.logging()
                        self.t += 1
                self.trial_ind += 1

    def update_neurons(self):
        """Advance I, E and ACT units by one bin using last bin's spikes."""
        self.check_spikes()

        # I units: constant external drive plus excitation from E spikes.
        I_syn = self.g_ext + self.E_I_W.T @ self.E_spikes
        self.update(self.I_neurons, I_syn)

        # E units: GO drive, recurrent excitation and inhibition from I spikes.
        I_syn = self.g_go * self.GO + self.E_E_W.T @ self.E_spikes + self.I_E_W.T @ self.I_spikes
        self.update(self.E_neurons, I_syn)

        # ACT unit: driven by E spikes in the current bin (just marked by spiking()).
        I_syn = self.E_ACT_W.T @ (self.E_neurons[:, self.t] == 1)
        self.update(self.ACT_neuron, I_syn)

    def check_spikes(self):
        """Set boolean spike masks for the previous bin (spike marker value 1).

        At t == 0 there is no previous bin, so no unit is treated as spiking
        (indexing with -1 would wrap around to the last bin).
        """
        if self.t == 0:
            self.ACT_spike = np.zeros(self.ACT_neuron.shape[0], dtype=bool)
            self.I_spikes = np.zeros(self.I_neurons.shape[0], dtype=bool)
            self.E_spikes = np.zeros(self.E_neurons.shape[0], dtype=bool)
            return
        prev = self.t - 1
        self.ACT_spike = self.ACT_neuron[:, prev] == 1
        self.I_spikes = self.I_neurons[:, prev] == 1
        self.E_spikes = self.E_neurons[:, prev] == 1

    def get_I_syn(self, V, g_AMPA, g_NMDA, g_GABA, s_GABA):
        """Placeholder for a conductance-based synaptic current; not implemented.

        The draft body referenced an undefined ``s_AMPA`` and returned nothing,
        so any call failed. Nothing in this class calls it.
        """
        raise NotImplementedError("get_I_syn is an unfinished stub")

    def update(self, neurons, inputs):
        """Write bin t + 1 of ``neurons`` in place, then apply spike/reset.

        V[t+1] = V[t] + g_leak * (resting - V[t]) + inputs + g_noise * randn
        """
        reversal = self.g_leak * (self.resting - neurons[:, self.t])
        noise = self.g_noise * np.random.randn(neurons.shape[0])
        d_E = reversal + inputs + noise
        neurons[:, self.t + 1] = neurons[:, self.t] + d_E
        self.spiking(neurons)

    def spiking(self, neurons):
        """Mark units crossing threshold at t + 1 as spiking in bin t, and reset them.

        A spike is recorded by writing 1 into bin t; the crossing value at
        t + 1 is replaced by the resting potential.
        """
        self.spike = neurons[:, self.t + 1] > self.thresh
        neurons[:, self.t] = neurons[:, self.t] * (1 - self.spike) + self.spike
        neurons[:, self.t + 1] = neurons[:, self.t + 1] * (1 - self.spike) + self.spike * self.resting

    def logging(self):
        """Log the stimulus; on the first ACT spike of a trial, log action and wait times."""
        self.stim_log[self.t] = self.GO
        # Counts consecutive GO steps; multiplying by GO resets it to 0 when GO is off.
        self.wait = (1 + self.wait) * self.GO
        if self.ACT_spike.any() and self.not_acted:
            self.wait_log[self.trial_ind] = self.wait
            self.act_log[self.trial_ind] = self.step
            self.not_acted = 0

    def plot(self, x1, title, leg = None, x2 = None, x3 = None, convolve = False, ylim1 = None, ylim2 = None):
        """Plot up to three series on one wide figure and show it.

        ``convolve`` is accepted for compatibility but has no effect
        (smoothing is disabled). The y-limits are applied only when
        ``ylim1`` is given.
        """
        plt.figure(figsize=(40, 5))
        for series in (x1, x2, x3):
            if series is not None:
                plt.plot(series, alpha=.5)
        if ylim1 is not None:
            plt.ylim([ylim1, ylim2])
        if leg is not None:
            plt.legend(leg)
        plt.title(title)
        plt.show()

if __name__ == "__main__":
    # Run the full simulation (on construction), then plot example traces
    # and the per-trial logs. Guarded so importing this module has no side effects.
    L = LIF()
    L.plot(L.stim_log[:-1], title = "stim", convolve = True)
    L.plot(L.E_neurons[0:3].T, title = "example E neurons")
    L.plot(L.I_neurons[0:3].T, title = "example I neurons")

    L.plot(L.E_neurons.mean(0).T, x2 = L.I_neurons.mean(0).T, x3 = L.ACT_neuron.T, title = "mean E neurons", leg = ["mean E", "mean I", "act"])
    L.plot(L.act_log, x2 = 100*L.PGO_log, x3 = L.wait_log, title = "logs", leg = ["action times", "PGO", "wait times"], convolve = True, ylim1 = 0, ylim2 = 110)