# -*- coding: utf-8 -*-

from typing import Union, Dict

import brainpy as bp
import brainpy.math as bm
# from brainpy._src.dynsys import NeuGroup, NeuGroupNS
#from brainpy._src.connect import TwoEndConnector, One2One
from brainpy.types import ArrayType, Shape
import jax.numpy as jnp

from brainpy._src.initialize import variable


class RandomInit_3D(bp.init.Initializer):
    """Initializer drawing integer-valued floats shifted towards zero.

    Values are sampled with ``bp.math.random.randint(0, max_value, shape)``
    (upper bound exclusive), converted to float, and shifted down by
    ``max_value // 2``.

    Parameters
    ----------
    max_value : int
        Exclusive upper bound of the raw integer draw.
    seed : int or None, optional
        Seed applied to ``bp.math.random`` when the initializer is built.
        Defaults to ``2023``, the previously hard-coded value. Pass ``None``
        to leave the generator's current state untouched.
    """

    def __init__(self, max_value, seed=2023):
        # NOTE: this reseeds bp.math.random's module-level generator, so it
        # also affects any other code that draws from it afterwards.
        if seed is not None:
            bp.math.random.seed(seed)
        self.max_value = max_value

    def __call__(self, shape):
        """Return a float array of ``shape`` in ``[-max_value//2, max_value - 1 - max_value//2]``."""
        draws = bp.math.random.randint(0, self.max_value, shape)
        return draws * 1.0 - self.max_value // 2

class RandomInit(bp.init.Initializer):
    """Initializer drawing integers, then rescaling and offsetting them.

    Values are sampled with ``bp.math.random.randint(0, max_value, shape)``
    (upper bound exclusive), converted to float, divided by ``scale`` and
    shifted by ``move``.

    Parameters
    ----------
    max_value : int
        Exclusive upper bound of the raw integer draw.
    scale : float, optional
        Divisor applied to the raw draw (default ``1.``).
    move : float, optional
        Offset added after scaling (default ``0.``).
    seed : int or None, optional
        Seed applied to ``bp.math.random`` when the initializer is built.
        Defaults to ``2023``, the previously hard-coded value. Pass ``None``
        to leave the generator's current state untouched.
    """

    def __init__(self, max_value, scale=1., move=0., seed=2023):
        # NOTE: this reseeds bp.math.random's module-level generator, so it
        # also affects any other code that draws from it afterwards.
        if seed is not None:
            bp.math.random.seed(seed)
        self.max_value = max_value
        self.scale = scale
        self.move = move

    def __call__(self, shape):
        """Return a float array of ``shape`` equal to ``randint(0, max_value) / scale + move``."""
        draws = bp.math.random.randint(0, self.max_value, shape)
        return draws * 1.0 / self.scale + self.move




# class InputGroup(NeuGroupNS):

#   def __init__(
#       self,
#       size: Shape,
#       keep_size: bool = False,
#       mode: bm.Mode = None,
#       name: str = None,
#   ):
#     super(InputGroup, self).__init__(name=name,
#                                      size=size,
#                                      keep_size=keep_size,
#                                      mode=mode)
#     self.spike = variable(bm.zeros, self.mode, self.varshape)

#   def update(self, x):
#     self.spike.value = x
#     return x

#   def reset_state(self, batch_size=None):
#     self.spike = variable(bm.zeros, batch_size, self.varshape)


class layer_stdp(bp.synapses.AMPA):
  """AMPA synapse with trace-based STDP and an optional homeostatic rule.

  Every synapse keeps two traces:

  - a presynaptic trace, which grows by ``A1`` on a presynaptic spike and
    decays with time constant ``tau_s``;
  - a postsynaptic trace, which grows by ``A2`` on a postsynaptic spike and
    decays with ``tau_t``.

  When ``mod`` is True, ``weight_update`` applies two changes to ``g_max``:

  - on a presynaptic spike, ``-theta_n * trace_post`` (depression);
  - on a postsynaptic spike, ``theta_p * trace_pre`` (potentiation).

  Both changes are summed over the batch and divided by
  ``batch_size * update_interval``. The weights are then clipped to
  ``[w_min, w_max]``.

  Parameters
  ----------
  pre, post, conn, method, **kwargs
    Forwarded to ``bp.synapses.AMPA``; ``method`` is also used for the
    trace integrator.
  mod : bool
    Whether ``weight_update`` modifies weights (see ``set_para``).
  tau_s, tau_t : float
    Decay time constants of the pre- and postsynaptic traces.
  A1, A2 : float
    Trace increments on a pre- / postsynaptic spike.
  theta_p, theta_n : float
    Learning rates of potentiation / depression.
  update_interval : int
    Extra divisor applied to every STDP weight change.
  w_max, w_min : float
    Clipping bounds of ``g_max`` and target range of ``_norm``.
  wdelay : float
    Step size of the depression applied by ``homeostatic``.
  """

  def __init__(
      self,
      pre,
      post,
      conn,
      mod: bool = False,
      tau_s: float = 40.,
      tau_t: float = 40.,
      A1: float = 1,
      A2: float = 1,
      theta_p: float = 1e-4,
      theta_n: float = 1e-4,
      update_interval: int = 1,
      method: str = 'exp_auto',
      w_max: float = 1.,
      w_min: float = 0.,
      wdelay: float = 8.,
      **kwargs
  ):
    super(layer_stdp, self).__init__(pre=pre,
                                     post=post,
                                     conn=conn,
                                     method=method,
                                     **kwargs)

    # parameters
    self.tau_s = tau_s
    self.tau_t = tau_t
    self.A1 = A1
    self.A2 = A2
    self.mod = mod
    self.theta_p = theta_p
    self.theta_n = theta_n
    self.update_interval = update_interval
    self.w_max = w_max
    self.w_min = w_min
    self.wdelay = wdelay

    # Connection structure: (pre, post) index of every synapse, and the
    # shape of the dense connection matrix (n_pre, n_post).
    self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
    self.Wshape = self.conn.require('conn_mat').shape
    self.num = len(self.pre_ids)

    # Per-synapse STDP traces.
    self.trace_pre = variable(bm.zeros, self.mode, self.num)
    self.trace_post = variable(bm.zeros, self.mode, self.num)
    # Per-neuron flags, set to 1 by weight_update once the neuron has spiked.
    self.trace_pre_register = variable(bm.zeros, self.mode, self.Wshape[0])
    self.trace_post_register = variable(bm.zeros, self.mode, self.Wshape[1])

    self.integral_stdp = bp.odeint(method=method, f=self.derivative_stdp)

  @property
  def derivative_stdp(self):
    """Joint ODE system describing the exponential decay of both traces."""
    # The lambda argument names are significant: bp.JointEq uses them to
    # identify the integrated variables.
    dtrace_pre = lambda trace_pre, t: - trace_pre / self.tau_s
    dtrace_post = lambda trace_post, t: - trace_post / self.tau_t
    return bp.JointEq([dtrace_pre, dtrace_post])

  def set_para(self, **kwargs):
    """Toggle plasticity via ``mod=...``; omitting ``mod`` disables it."""
    self.mod = kwargs.get("mod", False)

  def reset_state(self, batch_size=None):
    """Reset the AMPA state, the STDP traces and the spike registers."""
    super().reset_state(batch_size=batch_size)
    self.trace_pre = variable(bm.zeros, batch_size, self.num)
    self.trace_post = variable(bm.zeros, batch_size, self.num)
    self.trace_pre_register = variable(bm.zeros, batch_size, self.Wshape[0])
    self.trace_post_register = variable(bm.zeros, batch_size, self.Wshape[1])

  def _trace_update(self, pre_spikes, post_spikes, t, dt):
    """Decay both traces over one step, then add ``A1``/``A2`` where spikes occurred."""
    self.trace_pre.value, self.trace_post.value = self.integral_stdp(
        self.trace_pre, self.trace_post, t, dt)
    self.trace_pre.value = bm.where(pre_spikes, self.trace_pre + self.A1, self.trace_pre)
    self.trace_post.value = bm.where(post_spikes, self.trace_post + self.A2, self.trace_post)

  def _norm(self):
    """Linearly rescale ``g_max`` from its observed range onto ``[w_min, w_max]``.

    If all weights are equal, the observed range is zero. Dividing by it
    would fill ``g_max`` with NaN, so in that case every weight is mapped
    to ``self.w_min`` instead.
    """
    g_hi = bm.max(self.g_max)
    g_lo = bm.min(self.g_max)
    span = g_hi - g_lo
    # A zero span means the numerator is zero everywhere; dividing by 1
    # keeps the result finite (= self.w_min).
    span = bm.where(span > 0, span, 1.)
    self.g_max.value = (self.g_max - g_lo) / span * (self.w_max - self.w_min) + self.w_min

  def homeostatic(self, batch_size):
    """Depress synapses whose postsynaptic neuron spiked but presynaptic did not.

    This uses the 0/1 spike registers filled by ``weight_update``. For each
    (pre, post) pair it counts, over the batch, how often
    ``trace_pre_register == 0`` and ``trace_post_register == 1``. The count
    times ``wdelay / batch_size`` is subtracted from ``g_max``, and the
    result is floored at ``w_min``.

    Parameters
    ----------
    batch_size : int
      Leading dimension the registers are viewed with; also divides the update.
    """
    in_features = self.Wshape[0]
    out_features = self.Wshape[1]

    # NOTE(review): relies on Array.view(...) acting as a reshape
    # (torch-style) — confirm against the installed BrainPy version.
    if isinstance(self.conn, bp.connect.One2One):
      # One2One pairs neuron i with neuron i: element-wise product.
      temp = (self.trace_pre_register == 0).view(batch_size, in_features) * \
             (self.trace_post_register == 1).view(batch_size, out_features)
    else:
      # Batched outer product -> (batch, in_features, out_features).
      temp = (self.trace_pre_register == 0).view(batch_size, in_features, 1) @ \
             (self.trace_post_register == 1).view(batch_size, 1, out_features)

    temp = bm.sum(temp, axis=0)
    self.g_max.value = self.g_max - temp * self.wdelay / batch_size
    self.g_max.value = bm.where(self.g_max < self.w_min, self.w_min, self.g_max)

  def weight_update(self, tdi):
    """Apply one STDP step to ``g_max``; does nothing unless ``self.mod`` is True.

    Parameters
    ----------
    tdi
      Presumably BrainPy's shared step info; only ``tdi.t`` and ``tdi.dt``
      are read (passed to the trace integrator).
    """
    if not self.mod:
      return

    # Remember which neurons have spiked (consumed by homeostatic).
    self.trace_pre_register.value = bm.where(self.pre.spike.value == 1, 1, self.trace_pre_register)
    self.trace_post_register.value = bm.where(self.post.spike.value == 1, 1, self.trace_post_register)

    # Spikes gathered per synapse; reshape(-1, num) always yields a 2-D
    # array whose leading axis is treated as the batch.
    pre_spikes = self.pre.spike.value[:, self.pre_ids].reshape(-1, self.num)
    post_spikes = self.post.spike.value[:, self.post_ids].reshape(-1, self.num)
    batch_size = pre_spikes.shape[0]

    # Per-synapse changes summed over the batch: depression on presynaptic
    # spikes (scaled by the postsynaptic trace), potentiation on
    # postsynaptic spikes (scaled by the presynaptic trace).
    w_pre = jnp.sum(bm.where(pre_spikes, - self.trace_post, 0).value, axis=0)
    w_post = jnp.sum(bm.where(post_spikes, self.trace_pre, 0).value, axis=0)

    if isinstance(self.conn, bp.connect.One2One):
      # One2One: g_max is updated with the flat per-synapse vectors.
      delta_pre, delta_post = w_pre, w_post
    else:
      # Scatter the per-synapse changes into dense Wshape matrices. These are
      # only built here because the One2One path never needs them.
      delta_pre = bm.zeros(self.Wshape)
      delta_post = bm.zeros(self.Wshape)
      delta_pre[self.pre_ids, self.post_ids] = w_pre
      delta_post[self.pre_ids, self.post_ids] = w_post

    self.g_max.value = self.g_max + delta_pre * self.theta_n / batch_size / self.update_interval + \
                       delta_post * self.theta_p / batch_size / self.update_interval

    # Clip to [w_min, w_max]: upper bound first, then lower bound.
    self.g_max.value = bm.where(self.g_max > self.w_max, self.w_max, self.g_max)
    self.g_max.value = bm.where(self.g_max < self.w_min, self.w_min, self.g_max)

    self._trace_update(pre_spikes, post_spikes, tdi.t, tdi.dt)
