import torch
import torch.nn as nn
import torch.nn.functional as F


class OnlineNeuron(nn.Module):
    """Stateful neuron with learnable per-timestep parameters, stepped one timestep per call.

    Outside parallel mode, each call to ``forward`` consumes the input for the
    current timestep ``self.t`` and then advances ``t`` by one. Every
    per-timestep parameter has length ``T``, so at most ``T`` calls may be made
    between calls to :meth:`reset`; further calls raise ``IndexError``.

    Args:
        T: Number of timesteps; sizes every per-timestep parameter.
        mem_bn: If truthy (and not in parallel mode), blend a per-timestep
            ``BatchNorm2d`` of the membrane potential with the raw membrane.
            ``BatchNorm2d`` requires 4-D input with ``num_features`` channels.
        num_features: Channel count for the ``BatchNorm2d`` layers; only used
            when ``mem_bn`` is enabled.
        parallel_mode: If truthy, ``forward`` is stateless and returns
            ``relu(x - v_th[self.t])`` without touching any membrane state.
    """

    def __init__(self, T, mem_bn=False, num_features=12, parallel_mode=False):
        super().__init__()
        # Coerce to bool so flags such as 0/1 or numpy.bool_ behave like
        # False/True. Identity checks (``is True``/``is False``) treat them as
        # neither, which left the module half-initialised (no ``alpha``) while
        # forward still tried to use it.
        mem_bn = bool(mem_bn)
        parallel_mode = bool(parallel_mode)

        # Learnable threshold for each timestep (initialised to 0).
        self.v_th = nn.Parameter(torch.zeros(T), requires_grad=True)
        self.T = T
        self.mem_bn = mem_bn
        self.parallel_mode = parallel_mode
        self.t = 0          # index of the next timestep to process
        self.v_mem = None   # membrane potential; allocated at t == 0 in forward
        if not parallel_mode:
            # sigmoid(alpha[t]) scales the membrane carried over from step t-1.
            self.alpha = nn.Parameter(torch.zeros(T), requires_grad=True)
            if mem_bn:
                # sigmoid-gated weights of the BN'd and raw membrane terms.
                self.bn_ratio = nn.Parameter(torch.zeros(T), requires_grad=True)
                self.mem_ratio = nn.Parameter(torch.zeros(T), requires_grad=True)
                self.bn = nn.ModuleList([nn.BatchNorm2d(num_features) for _ in range(T)])

    def forward(self, x):
        """Process one timestep of input and return the graded spike output.

        Args:
            x: Input for the current timestep. Must be 4-D with
                ``num_features`` channels when ``mem_bn`` is enabled.

        Returns:
            ``relu(v_mem - v_th[t])``, the amount by which the membrane
            exceeds the current threshold. In parallel mode this is
            ``relu(x - v_th[t])`` instead.
        """
        t = self.t
        if self.parallel_mode:
            # Stateless path. self.t is not advanced here, so this reads
            # v_th[0] unless the caller sets ``t`` externally.
            # NOTE(review): confirm this is the intended indexing.
            return F.relu(x - self.v_th[t])
        if t == 0:
            self.v_mem = torch.zeros_like(x)
        # Leaky integration. detach() stops gradients from flowing back through
        # earlier timesteps' membrane, so backprop stays local to this step.
        v_mem = self.alpha[t].sigmoid() * self.v_mem.detach() + x
        if self.mem_bn:
            v_mem = (self.bn_ratio[t].sigmoid() * self.bn[t](v_mem)
                     + self.mem_ratio[t].sigmoid() * v_mem)
        spike = F.relu(v_mem - self.v_th[t])
        # Soft reset: subtract the emitted amount from the membrane. This is
        # done out-of-place, so no tensor autograd may reference is mutated.
        self.v_mem = v_mem - spike
        self.t = t + 1
        return spike

    def reset(self):
        """Return to timestep 0 and drop the membrane state before a new sequence."""
        self.t = 0
        self.v_mem = None
