from abc import abstractmethod
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import logging

from .neuron_spikingjelly import OTTTLIFNode
from . import surrogate


class OPZOLIFNode(OTTTLIFNode):
    """LIF neuron variant of ``OTTTLIFNode`` used for OPZO training.

    Compared with the parent class, firing forwards a ``return_grad`` flag to the
    surrogate function, and ``single_step_forward`` can optionally also return that
    extra output and/or skip the trace computation during training.
    """

    def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
                 v_reset: float = None, surrogate_function: Callable = None,
                 detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False):
        """Create the neuron; all arguments are forwarded to ``OTTTLIFNode.__init__``.

        :param surrogate_function: surrogate spiking function. It must accept a
            ``return_grad`` keyword argument (see ``neuronal_fire``). When ``None``
            (the default), a new ``surrogate.SigmoidOPZO(alpha=4.)`` is created for
            this neuron.
        """
        # Build the default per instance rather than in the signature: a default
        # expression is evaluated only once, at class-definition time, so every
        # neuron would otherwise share a single surrogate object (and any state it holds).
        if surrogate_function is None:
            surrogate_function = surrogate.SigmoidOPZO(alpha=4.)
        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)

    def neuronal_fire(self, return_grad=False):
        """Apply the surrogate function to ``v - v_threshold``.

        :param return_grad: forwarded to the surrogate function. When ``True``,
            ``single_step_forward`` unpacks the result as ``(spike, grad)``;
            otherwise the result is used directly as the spike tensor.
        """
        return self.surrogate_function(self.v - self.v_threshold, return_grad=return_grad)

    def single_step_forward(self, x: torch.Tensor, return_grad=False, return_trace=True):
        """Advance the neuron by one time step.

        During training, output the spike and (optionally) the trace. During
        inference, output only the spike. During training, successive parametric
        modules should be wrapped by ``GradwithTrace`` (defined in ``layer.py``)
        so that gradients are calculated from the traces.

        :param x: input for this step. On the first call it also determines the
            shape/dtype/device of the lazily created ``v`` (and ``trace``) buffers.
        :param return_grad: training only. Also return the second value produced
            by ``surrogate_function(..., return_grad=True)``.
        :param return_trace: training only. Update the ``trace`` buffer and return it
            alongside the spike.
        :return: in training mode, one of ``spike``, ``(spike, grad)``,
            ``[spike, trace]`` or ``([spike, trace], grad)`` depending on the flags.
            In eval mode, ``spike``; both flags are ignored there.
        :raises ValueError: in training mode when ``backend`` is not ``'torch'``.
        """
        # Lazily create the membrane potential to match the input.
        if not hasattr(self, 'v'):
            if self.v_reset is None:
                self.register_buffer('v', torch.zeros_like(x))
            else:
                self.register_buffer('v', torch.ones_like(x) * self.v_reset)

        if self.training:
            # Only the pure-PyTorch path supports training. Check before allocating
            # any trace state so a failed call leaves no stray buffer behind.
            if self.backend != 'torch':
                raise ValueError(self.backend)

            if return_trace and not hasattr(self, 'trace'):
                self.register_buffer('trace', torch.zeros_like(x))

            # charge -> fire -> reset; the charge and reset methods are inherited.
            self.neuronal_charge(x)
            if return_grad:
                spike, grad = self.neuronal_fire(return_grad=True)
            else:
                spike = self.neuronal_fire()
            self.neuronal_reset(spike)

            if return_trace:
                # `track_trace` is inherited from the parent class and updates the
                # running trace from this step's spike and `tau`.
                self.trace = self.track_trace(spike, self.trace, self.tau)
                out = [spike, self.trace]
            else:
                out = spike

            return (out, grad) if return_grad else out

        # Inference: dispatch to the inherited eval kernels, chosen by reset mode
        # (soft when v_reset is None, hard otherwise) and by decay_input.
        if self.v_reset is None:
            if self.decay_input:
                spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v,
                                                                                         self.v_threshold, self.tau)
            else:
                spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v,
                                                                                            self.v_threshold,
                                                                                            self.tau)
        else:
            if self.decay_input:
                spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v,
                                                                                         self.v_threshold,
                                                                                         self.v_reset, self.tau)
            else:
                spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v,
                                                                                            self.v_threshold,
                                                                                            self.v_reset,
                                                                                            self.tau)
        return spike

