#from sparse import sparse_layer
import torch.nn as nn
import torch
from .spikeLayer import *

class MLP_optimalThres(nn.Module):
    """ReLU MLP with three hidden layers that records running activation maxima.

    Every ``forward`` call updates ``self.max_active`` with the element-wise
    maximum of the stored tensor and the current output of each of the four
    linear stages (post-ReLU for the hidden layers, raw output for the last).

    Args:
        indim: Number of input features after flattening.
        hiddim: Indexable holding the three hidden-layer widths.
        outdim: Number of output features.

    Note:
        The stored maxima have the same shape as the activations they were
        created from (batch dimension included), so later inputs must produce
        activations that broadcast against them.
    """

    def __init__(self, indim, hiddim, outdim) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.Linear1 = nn.Linear(indim, hiddim[0])
        self.Linear2 = nn.Linear(hiddim[0], hiddim[1])
        self.Linear3 = nn.Linear(hiddim[1], hiddim[2])
        self.last_layer = nn.Linear(hiddim[2], outdim)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        # Placeholders until init_thresh() or the first forward() installs tensors.
        self.max_active = [0] * 4

    def _record_max(self, index, activation):
        """Fold ``activation`` into the running maximum stored at ``index``."""
        # Detach so the statistic never keeps the autograd graph of earlier
        # batches alive (otherwise memory grows with every forward pass).
        current = activation.detach()
        previous = self.max_active[index]
        if not torch.is_tensor(previous):
            # Not initialised yet: start from zeros, exactly like init_thresh().
            previous = torch.zeros_like(current)
        self.max_active[index] = torch.maximum(previous, current)

    def init_thresh(self, x):
        """Reset every running maximum to zeros shaped like the activations of ``x``.

        Args:
            x: Sample input; it is only used to obtain the activation shapes.
        """
        # Author's note (translated): this used to be zeros_like(out), which
        # amusingly gives every sample in a batch its own threshold.
        # The code still uses zeros_like(out), so the batch dimension is kept.
        with torch.no_grad():  # only shapes are needed, no graph required
            out = self.flatten(x)
            out = self.relu1(self.Linear1(out))
            self.max_active[0] = torch.zeros_like(out)
            out = self.relu2(self.Linear2(out))
            self.max_active[1] = torch.zeros_like(out)
            out = self.relu3(self.Linear3(out))
            self.max_active[2] = torch.zeros_like(out)
            out = self.last_layer(out)
            self.max_active[3] = torch.zeros_like(out)

    def forward(self, x):
        """Run the MLP on ``x`` and update ``self.max_active`` along the way.

        Args:
            x: Input tensor; it is flattened by ``nn.Flatten`` first.

        Returns:
            The output of ``last_layer`` (no activation applied).
        """
        out = self.flatten(x)
        out = self.relu1(self.Linear1(out))
        self._record_max(0, out)
        out = self.relu2(self.Linear2(out))
        self._record_max(1, out)
        out = self.relu3(self.Linear3(out))
        self._record_max(2, out)
        out = self.last_layer(out)
        self._record_max(3, out)

        return out
    
class MLP_PosNeg_spiking(nn.Module):
    """Spiking counterpart of an MLP, built from an existing model's layers.

    Each of the source model's four linear layers is wrapped in a
    ``SPIKE_PosNeg_layer`` constructed with ``thresh_list[i]`` and its
    negation; the source model's ``flatten`` module is reused as is.

    Args:
        thresh_list: Indexable with one threshold per layer (indices 0..3).
        model: Source model exposing ``flatten``, ``Linear1``, ``Linear2``,
            ``Linear3`` and ``last_layer``.
    """

    # Attribute names of the wrapped layers, in execution order.
    _LAYER_NAMES = ("Linear1", "Linear2", "Linear3", "last_layer")

    def __init__(self, thresh_list, model):
        super().__init__()
        self.flatten = model.flatten
        for position, attr in enumerate(self._LAYER_NAMES):
            threshold = thresh_list[position]
            wrapped = SPIKE_PosNeg_layer(threshold, -threshold, getattr(model, attr))
            setattr(self, attr, wrapped)

    def _spiking_layers(self):
        """Return the four wrapped layers in execution order."""
        return (self.Linear1, self.Linear2, self.Linear3, self.last_layer)

    def init_layer(self):
        """Call ``init_mem()`` on every wrapped layer, first to last."""
        for layer in self._spiking_layers():
            layer.init_mem()

    def forward(self, x, time):
        """Feed ``x`` through all layers for the given ``time`` argument.

        Each layer is called as ``layer(out, time)`` and returns a pair; only
        the first element is propagated, the second one is discarded.

        Returns:
            The first element returned by ``last_layer``.
        """
        signal = self.flatten(x)
        for layer in self._spiking_layers():
            signal, _ = layer(signal, time)
        return signal

    def weight_bias_norm(self):
        """Do nothing: an MLP has no BN to fold."""
        return None
    
'''
Backup: SNN inference loop over T time steps (rate coding).

self.init_layer()
with torch.no_grad():
    out_spike_sum = 0
    for time in range(self.T):
        # forward pass for this time step -> output
        out_spike_sum += output
        if (time + 1) == self.T:
            sub_result = out_spike_sum / (time + 1)
    return sub_result  # rate coding
'''
    
