import brainpy_datasets as bd
import brainpy as bp
import brainpy.math as bm

import jax.numpy as jnp
import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

from utils.learning_strategy import layer_stdp #, InputGroup, RandomInit

def getdata():
    """Return the 8x8 binary pattern consumed by ``Net_sim``.

    Pixels set to 1 are row 3, columns 3-7, and column 3, rows 4-7;
    every other pixel is 0.
    """
    rows = [3] * 5 + list(range(4, 8))
    cols = list(range(3, 8)) + [3] * 4
    pattern = bm.zeros((8, 8))
    pattern[rows, cols] = 1
    return pattern

class Net_sim(bp.DynamicalSystem):
  """Two GIF populations (``n1`` -> ``n``) joined by ``layer_stdp`` projections.

  ``s`` (small-world) and ``map`` (one-to-one) connect ``n1`` to ``n``;
  ``s1`` is a small-world recurrent projection within ``n``.
  """

  def __init__(self, size):
    """Build the populations and projections.

    Args:
      size: size of both GIF populations.  ``limit`` broadcasts the
        64-element pattern from ``getdata`` against ``self.n``'s spikes, so
        the last axis of ``self.n`` must have 64 neurons.
    """
    super().__init__()

    self.n1 = bp.neurons.GIF(size=size)
    self.n = bp.neurons.GIF(size=size, V_initializer=bp.init.OneInit(-40))

    self.s = layer_stdp(self.n1, self.n, bp.conn.SmallWorld(num_neighbor=64, prob=1),
                        w_max=1,
                        w_min=-1.,
                        g_max=bp.init.OneInit(0.18),
                        )
    self.map = layer_stdp(self.n1, self.n, bp.conn.One2One(),
                        wdelay=1,
                        w_max=8, w_min=-1.,
                        theta_p= 1, theta_n = 1e-2,
                        g_max=bp.init.Normal(0.5, 0.5, seed=2023),
                        )

    self.s1 = layer_stdp(self.n, self.n, bp.conn.SmallWorld(num_neighbor=5, prob=0.2),
                        w_max=1,
                        w_min=-1.,
                        g_max=bp.init.OneInit(0.01),
                        )

    # Flattened 8x8 binary pattern, shape (1, 64).  Despite the name it is
    # used as a neuron mask in ``limit``.
    self.input_shape = getdata().reshape(1, -1)

  def limit(self):
    """Apply lateral inhibition among the masked neurons of ``self.n``.

    Counts how many neurons selected by ``self.input_shape`` spiked on this
    step.  Masked neurons that did not spike have their ``V`` lowered by
    ``inh_value`` per spiking masked neuron; every spiking neuron of
    ``self.n`` is then reset to ``lower_V``.
    """
    lower_V = -100
    inh_value = 20
    self_inh = 0  # applied to spiking neurons, but the reset below overrides it

    spike = self.n.spike.value
    mask = self.input_shape.reshape(-1)
    # Per-sample count of spiking masked neurons; keepdims lets it broadcast
    # against the mask for both batched (B, 64) and unbatched (64,) spikes.
    n_active = bm.sum(bm.where(mask, spike, 0), axis=-1, keepdims=True)
    prohibition = mask * n_active

    # Branch on ``spike`` directly.  ``prohibition`` is a count, not a 0/1
    # flag, so a test like ``spike + prohibition == 2`` would also match
    # silent neurons when exactly two masked neurons fire and wrongly exempt
    # them from inhibition.
    self.n.V.value = bm.where(spike,
                              self.n.V - prohibition * self_inh,
                              self.n.V - prohibition * inh_value)
    self.n.V.value = bm.where(spike, lower_V, self.n.V)

  def update(self, tdi, x):
    """Advance the network one step with external input ``x``.

    Updates ``s1`` and ``map``, feeds ``x`` through ``n1 >> s >> n``, calls
    ``weight_update`` on all three projections, then applies ``limit``.
    """
    self.s1.update(tdi)
    self.map(tdi)
    x >> self.n1 >> self.s >> self.n

    self.s.weight_update(tdi)
    self.s1.weight_update(tdi)
    self.map.weight_update(tdi)
    self.limit()

  def set_para(self, **kwargs):
    """Forward ``kwargs`` to ``set_para`` of ``s``, ``s1`` and ``map``."""
    self.s.set_para(**kwargs)
    self.s1.set_para(**kwargs)
    self.map.set_para(**kwargs)

  def clear_input(self):
    """Clear accumulated inputs of both neuron populations."""
    self.n.clear_input()
    self.n1.clear_input()

  def homeostatic(self, batch_size):
    """Forward ``batch_size`` to the ``map`` projection's ``homeostatic``."""
    self.map.homeostatic(batch_size)