
import brainpy as bp
from utils.learning_strategy import layer_stdp, RandomInit, RandomInit_3D

class M_2D(bp.DynamicalSystem):
  """One GIF population (57 * 57 neurons) with a single recurrent plastic projection.

  The projection ``s`` is a ``layer_stdp`` from ``n`` onto itself. It is built
  on a ``bp.conn.SmallWorld(K, 0.2)`` connector with ``RandomInit(1024)``
  initial weights, ``w_max=1023`` and ``T=0.01``.

  Args:
    tau: membrane time constant of the GIF population.
    alpha: forwarded unchanged to ``layer_stdp``.
    beta: forwarded unchanged to ``layer_stdp``.
    T_dur: forwarded to ``layer_stdp`` as ``T_duration``.
    K: first argument (neighbour count) of the ``SmallWorld`` connector.
    delay: forwarded to ``layer_stdp`` as ``wdelay``.
  """

  def __init__(self, tau, alpha, beta, T_dur, K, delay):
    super().__init__()

    # Neuron parameters: reset/rest at -70, threshold values -46 / -58,
    # membrane voltage initialised to -70 and the threshold to -58.
    gif_params = dict(
        V_reset=-70, V_rest=-70, V_th_inf=-46, V_th_reset=-58,
        tau=tau, R=1.,
        V_initializer=bp.init.OneInit(-70),
        Vth_initializer=bp.init.OneInit(-58.),
    )
    self.n = bp.neurons.GIF(size=57 * 57, **gif_params)

    small_world = bp.conn.SmallWorld(K, 0.2)
    self.s = layer_stdp(
        self.n, self.n, small_world,
        wdelay=delay,
        w_max=1023,
        alpha=alpha,
        beta=beta,
        g_max=RandomInit(1024),
        T_duration=T_dur,
        T=0.01,
    )

  def update(self, tdi, x):
    """Run one step of the network.

    The order is: step the projection, feed ``x`` into ``n`` via ``>>``,
    then apply the projection's weight update.
    """
    syn = self.s
    syn.update(tdi)
    x >> self.n
    syn.weight_update(tdi)

  def set_para(self, **kwargs):
    """Pass every keyword argument through to the projection's ``set_para``."""
    syn = self.s
    syn.set_para(**kwargs)

  def clear_input(self):
    """Clear the input of population ``n`` via its ``clear_input``."""
    population = self.n
    population.clear_input()

  def homeostatic(self, batch_size):
    """Hand ``batch_size`` to the projection's ``homeostatic`` routine."""
    syn = self.s
    syn.homeostatic(batch_size)



class M_3D(bp.DynamicalSystem):
  """Two coupled GIF populations (57 * 57 neurons each) with four plastic projections.

  Projections (all ``layer_stdp``):
    * ``s``    : ``n``  -> ``n``   (SmallWorld probability 0.2)
    * ``s1``   : ``n1`` -> ``n1``  (SmallWorld probability 0.05)
    * ``n2n1`` : ``n``  -> ``n1``  (SmallWorld probability 0.05)
    * ``n12n`` : ``n1`` -> ``n``   (SmallWorld probability 0.05)

  Every projection uses ``RandomInit_3D(1024)`` initial weights,
  ``w_max=1023``, ``theta_p=1e-1``, ``theta_n=1e-2`` and ``T=0.01``.

  Args:
    tau: membrane time constant of both GIF populations.
    alpha: forwarded to every ``layer_stdp`` (labelled "para for AMPA").
    beta: forwarded to every ``layer_stdp`` (labelled "para for AMPA").
    T_dur: forwarded to every ``layer_stdp`` as ``T_duration``.
    K: first argument (neighbour count) of every ``SmallWorld`` connector.
    delay: forwarded to every ``layer_stdp`` as ``wdelay``.
  """

  def __init__(self, tau, alpha, beta, T_dur, K, delay):
    super().__init__()

    # Construction order is kept identical to the original (n, n1, s, s1,
    # n2n1, n12n) in case object naming or random initialisation depends on it.
    self.n = self._build_population(tau)
    self.n1 = self._build_population(tau)

    def projection(pre, post, prob):
      # All four projections share every setting except pre, post and the
      # SmallWorld probability.
      return self._build_projection(pre, post, prob,
                                    K=K, delay=delay, alpha=alpha,
                                    beta=beta, T_dur=T_dur)

    # Recurrent projections inside each population.
    self.s = projection(self.n, self.n, 0.2)
    self.s1 = projection(self.n1, self.n1, 0.05)
    # Cross-population projections.
    self.n2n1 = projection(self.n, self.n1, 0.05)
    self.n12n = projection(self.n1, self.n, 0.05)

  @staticmethod
  def _build_population(tau):
    """Create one 57 * 57 GIF population with the shared neuron parameters."""
    return bp.neurons.GIF(size=57 * 57,
                          V_reset=-70, V_rest=-70, V_th_inf=-46, V_th_reset=-58,
                          tau=tau, R=1.,
                          V_initializer=bp.init.OneInit(-70),
                          Vth_initializer=bp.init.OneInit(-58.),
                          )

  @staticmethod
  def _build_projection(pre, post, prob, *, K, delay, alpha, beta, T_dur):
    """Create one ``layer_stdp`` projection pre -> post on a SmallWorld(K, prob) connector."""
    return layer_stdp(pre, post, bp.conn.SmallWorld(K, prob),
                      wdelay=delay, w_max=1023,
                      theta_p=1e-1, theta_n=1e-2,
                      # para for AMPA
                      alpha=alpha, beta=beta,
                      g_max=RandomInit_3D(1024),
                      T_duration=T_dur, T=0.01,
                      )

  def update(self, tdi, x):
    """Run one network step.

    Steps ``s``, ``s1`` and ``n12n`` with ``tdi``, runs the ``>>`` chain
    through ``n``, ``n2n1`` and ``n1``, then applies the weight update of
    all four projections.
    """
    self.s.update(tdi)
    self.s1.update(tdi)
    self.n12n.update(tdi)
    # NOTE(review): n2n1 is the only projection without an explicit
    # ``update(tdi)``; it and n1 are invoked only through this ``>>`` chain.
    # Assuming brainpy's ``__rrshift__`` forwards to ``__call__``, the chain
    # evaluates n(x), then n2n1(<result of n>), then n1(<result of n2n1>).
    # n1 never receives ``x`` directly. Confirm this matches layer_stdp's
    # call signature before changing it; adding an explicit
    # ``self.n2n1.update(tdi)`` here could step that projection twice.
    x  >> self.n >> self.n2n1 >> self.n1
    self.s.weight_update(tdi)
    self.s1.weight_update(tdi)
    self.n12n.weight_update(tdi)
    self.n2n1.weight_update(tdi)

  def set_para(self, **kwargs):
    """Forward the same keyword parameters to all four projections."""
    self.s.set_para(**kwargs)
    self.s1.set_para(**kwargs)
    self.n2n1.set_para(**kwargs)
    self.n12n.set_para(**kwargs)

  def clear_input(self):
    """Clear the input of both populations via their ``clear_input``."""
    self.n.clear_input()
    self.n1.clear_input()

  def homeostatic(self, batch_size):
    """Forward ``batch_size`` to every projection's ``homeostatic`` routine."""
    self.s.homeostatic(batch_size)
    self.s1.homeostatic(batch_size)
    self.n12n.homeostatic(batch_size)
    self.n2n1.homeostatic(batch_size)