import torch
import torch.nn as nn
import torch.nn.functional as F
import util

class FINAL_ourmodel_PLIF_PTRACE(nn.Module):
    """Spiking residual CNN built from ``util.pPLIF_Node`` layers with an accumulator read-out.

    Topology (``ch = scale``), input reshaped to ``(batch, 3, 32, 32)``:

    * stem: 3x3 conv -> spiking layer (``ch`` x 32x32)
    * SRB1: two 5x5 conv layers, BatchNorm-only skip (``ch`` x 32x32), then ``pool1``
    * SRB2: two 5x5 conv layers, 1x1 conv skip (``2*ch`` x 16x16), then ``pool2``
    * SRB3: two 3x3 conv layers, 1x1 conv skip (``4*ch`` x 8x8)
    * SRB4: two 5x5 conv layers, BatchNorm-only skip (``4*ch`` x 8x8), then ``pool3``
    * head: linear -> 100 spiking units -> mean of groups of 10 -> ``pPLI`` accumulator

    The spike-trace nodes ``pPTRACE1`` ... ``pPTRACE9`` and the ``spk_*``
    parameters are created here but are not used by this class's ``forward``;
    ``FINAL_ourmodel_et_PLIF_PTRACE_v3`` uses them.
    """

    def __init__(
        self,
        num_steps: int = 5,
        init_tau: float = 2.0,              # membrane decaying time constant
        init_spk_trace_tau: float = 0.5,    # spike trace decaying time constant
        init_acc_tau: float = 2.0,          # accumulative membrane decaying time constant
        init_parametric_tau: bool = True,
        init_version: str = 'v1',
        init_decay_acc: bool = False,
        scale: int = 64,
        subthresh: float = 0.5,
        init_spk_trace_th: float = -0.35,
        init_spk_trace_a: float = -0.8
    ):
        """Build all layers and the learnable time-constant parameters.

        Args:
            num_steps: number of simulation timesteps per forward pass.
            init_tau, init_spk_trace_tau, init_acc_tau, init_spk_trace_th,
            init_spk_trace_a: currently IGNORED. The corresponding parameters
                are initialised from hard-coded tensors below. These arguments
                are kept so existing call sites keep working.
            init_parametric_tau: forwarded to every ``util.pPTRACE_Node`` as
                ``parametric_tau``.
            init_version: forwarded to every ``util.pPTRACE_Node`` as ``version``.
            init_decay_acc: forwarded to ``util.pPLI_Node`` as ``decay_acc``.
            scale: channel width of the first stage; later stages use 2x and 4x.
            subthresh: forwarded to every ``util.HeavisideBoxcarCall`` surrogate.
        """
        super().__init__()
        self.num_steps = num_steps
        self.scale = scale
        self.subthresh = subthresh

        def conv_bn(c_in, c_out, kernel, padding=0):
            # Bias-free stride-1 convolution followed by batch norm.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel, stride=1, padding=padding, bias=False),
                nn.BatchNorm2d(c_out))

        def plif():
            # Each spiking layer owns its own surrogate instance with identical settings.
            return util.pPLIF_Node(surrogate_function=util.HeavisideBoxcarCall(
                thresh=1.0, subthresh=self.subthresh, alpha=1.0, spiking=True))

        def ptrace():
            return util.pPTRACE_Node(
                parametric_tau=init_parametric_tau,
                version=init_version,
                surrogate_function1=util.SpkTraceBinaryModifiedCall,
                surrogate_function2=util.SpkTraceModifiedCall,
                surrogate_function3=util.SpkTraceSwitchCall)

        ch1, ch2, ch4 = scale, scale * 2, scale * 4

        # Submodules are registered in the original order so the state_dict
        # layout and seeded weight initialisation are unchanged.

        # Stem.
        self.conv1 = conv_bn(3, ch1, 3, padding=1)
        self.pPLIF0 = plif()

        # Residual block 1 (ch1 -> ch1), skip path is BatchNorm only.
        self.pPTRACE1 = ptrace()
        self.SRB1_conv1 = conv_bn(ch1, ch1, 5, padding=2)
        self.pPLIF1 = plif()
        self.pPTRACE2 = ptrace()
        self.SRB1_conv2 = conv_bn(ch1, ch1, 5, padding=2)
        self.SRB1_skip = nn.Sequential(nn.BatchNorm2d(ch1))
        self.pPLIF2 = plif()
        self.pool1 = util.OnePool2d.apply

        # Residual block 2 (ch1 -> ch2), 1x1 projection skip.
        self.pPTRACE3 = ptrace()
        self.SRB2_conv1 = conv_bn(ch1, ch2, 5, padding=2)
        self.pPLIF3 = plif()
        self.pPTRACE4 = ptrace()
        self.SRB2_conv2 = conv_bn(ch2, ch2, 5, padding=2)
        self.SRB2_skip = conv_bn(ch1, ch2, 1)
        self.pPLIF4 = plif()
        self.pool2 = util.OnePool2d.apply

        # Residual block 3 (ch2 -> ch4), 3x3 convs, 1x1 projection skip.
        self.pPTRACE5 = ptrace()
        self.SRB3_conv1 = conv_bn(ch2, ch4, 3, padding=1)
        self.pPLIF5 = plif()
        self.pPTRACE6 = ptrace()
        self.SRB3_conv2 = conv_bn(ch4, ch4, 3, padding=1)
        self.SRB3_skip = conv_bn(ch2, ch4, 1)
        self.pPLIF6 = plif()

        # Residual block 4 (ch4 -> ch4), skip path is BatchNorm only.
        self.pPTRACE7 = ptrace()
        self.SRB4_conv1 = conv_bn(ch4, ch4, 5, padding=2)
        self.pPLIF7 = plif()
        self.pPTRACE8 = ptrace()
        self.SRB4_conv2 = conv_bn(ch4, ch4, 5, padding=2)
        self.SRB4_skip = nn.Sequential(nn.BatchNorm2d(ch4))
        self.pPLIF8 = plif()
        self.pool3 = util.OnePool2d.apply

        # Head: 100 spiking units averaged in groups of 10 -> 10 accumulator inputs.
        self.pPTRACE9 = ptrace()
        self.fc1 = nn.Linear(ch4 * 4 * 4, 100, bias=False)
        self.pPLIF9 = plif()
        self.boost1 = nn.AvgPool1d(10, 10)
        self.pPLI = util.pPLI_Node(decay_acc=init_decay_acc, surrogate_function=util.AccAlwaysGradCall())

        # Learnable time constants; forward passes apply torch.sigmoid to them.
        # They used to be initialised from the init_* arguments (e.g.
        # torch.ones(9) * init_tau) and are now hard-coded.
        # NOTE(review): the values look like they come from a trained run -- confirm.
        # Device placement is left to Module.to()/.cuda(): the former
        # `param.to(device)` calls here were no-ops (Tensor.to is not in-place).
        self.tau_0 = nn.Parameter(torch.tensor([1.9870]))
        self.tau_vector = nn.Parameter(torch.tensor(
            [1.6541, 0.9080, 0.9243, 0.6587, 1.0764, 0.9103, 0.9568, 0.9435, 1.0807]))
        # NOTE(review): identical values to tau_vector -- confirm this is intentional.
        self.spk_trace_tau_vector = nn.Parameter(torch.tensor(
            [1.6541, 0.9080, 0.9243, 0.6587, 1.0764, 0.9103, 0.9568, 0.9435, 1.0807]))
        self.spk_threshold_vector = nn.Parameter(torch.tensor(
            [0.0060, -0.0637, -0.8105, -0.0786, -1.4248, -0.3059, 0.2116, -0.0589, -0.7510]))
        self.spk_a_vector = nn.Parameter(torch.tensor(
            [-0.3246, -0.6443, -0.6722, -0.9817, -0.9581, -0.9971, -1.1768, -0.9332, -0.7385]))
        self.acc_tau = nn.Parameter(torch.tensor([5.6898]))

    def _fresh_state(self, batch_size, device):
        """Allocate initial (membrane, spike) tensors for the 10 spiking layers.

        Index 0 is the stem, 1-8 are the residual-block layers in order, and 9
        is the fc1 layer. Membranes start at 0.5 and spikes at 0.

        Every tensor is a separate allocation. The previous
        ``mem = spike = torch.zeros(...)`` followed by ``mem.fill_(0.5)``
        aliased the two, so spikes silently started at 0.5 as well.
        """
        ch = self.scale
        shapes = ([(ch, 32, 32)] * 3 + [(2 * ch, 16, 16)] * 2
                  + [(4 * ch, 8, 8)] * 4 + [(100,)])
        mems = [torch.full((batch_size, *shape), 0.5, device=device) for shape in shapes]
        spikes = [torch.zeros((batch_size, *shape), device=device) for shape in shapes]
        return mems, spikes

    def _plif_step(self, x, mems, spikes, acc_mem, decay_0, decay_vector, acc_decay):
        """Advance every spiking layer by one timestep.

        ``mems`` and ``spikes`` are the per-layer lists from ``_fresh_state``
        and are updated in place. Returns the new accumulator membrane.
        """
        batch_size = x.size(0)

        def fire(i, node, decay, current):
            # Recurrent state is detached so no gradient crosses timesteps.
            mems[i], spikes[i] = node(mems[i].detach(), spikes[i].detach(), decay, current)
            return spikes[i]

        s0 = fire(0, self.pPLIF0, decay_0, self.conv1(x))
        s1 = fire(1, self.pPLIF1, decay_vector[0], self.SRB1_conv1(s0))
        s2 = fire(2, self.pPLIF2, decay_vector[1], self.SRB1_conv2(s1) + self.SRB1_skip(s0))
        p1 = self.pool1(s2)
        s3 = fire(3, self.pPLIF3, decay_vector[2], self.SRB2_conv1(p1))
        s4 = fire(4, self.pPLIF4, decay_vector[3], self.SRB2_conv2(s3) + self.SRB2_skip(p1))
        p2 = self.pool2(s4)
        s5 = fire(5, self.pPLIF5, decay_vector[4], self.SRB3_conv1(p2))
        s6 = fire(6, self.pPLIF6, decay_vector[5], self.SRB3_conv2(s5) + self.SRB3_skip(p2))
        s7 = fire(7, self.pPLIF7, decay_vector[6], self.SRB4_conv1(s6))
        s8 = fire(8, self.pPLIF8, decay_vector[7], self.SRB4_conv2(s7) + self.SRB4_skip(s6))
        p3 = self.pool3(s8)
        s9 = fire(9, self.pPLIF9, decay_vector[8], self.fc1(p3.view(batch_size, -1)))

        boost = self.boost1(s9.unsqueeze(1)).squeeze(1)
        return self.pPLI(acc_mem.detach(), acc_decay, boost)

    def forward(self, x):
        """Simulate ``max(num_steps, 1)`` timesteps on a static input.

        The first ``num_steps - 1`` steps run under ``torch.no_grad()``. Only
        the last step builds an autograd graph.

        Args:
            x: input batch; must be reshapeable to ``(batch, 3, 32, 32)``.

        Returns:
            ``(acc_mem, num_steps)``: the output of ``self.pPLI`` after the last
            step, and ``self.num_steps``. Per the original note, the caller
            applies softmax / cross-entropy.
        """
        # Side effect kept from the original: record the input device on the module.
        self.device = x.device
        batch_size = x.size(0)
        x = x.reshape(batch_size, 3, 32, 32)

        mems, spikes = self._fresh_state(batch_size, x.device)
        acc_mem = torch.zeros(batch_size, 10, device=x.device)

        decay_0 = torch.sigmoid(self.tau_0)
        decay_vector = torch.sigmoid(self.tau_vector)
        acc_decay = torch.sigmoid(self.acc_tau)

        # Warm-up timesteps: advance the state without recording a graph.
        for _ in range(self.num_steps - 1):
            with torch.no_grad():
                acc_mem = self._plif_step(x, mems, spikes, acc_mem, decay_0, decay_vector, acc_decay)

        # Final, differentiable timestep.
        acc_mem = self._plif_step(x, mems, spikes, acc_mem, decay_0, decay_vector, acc_decay)
        return acc_mem, self.num_steps


class FINAL_ourmodel_et_PLIF_PTRACE_v3(FINAL_ourmodel_PLIF_PTRACE):
    """Variant of ``FINAL_ourmodel_PLIF_PTRACE`` whose last timestep uses spike traces.

    Every timestep updates the same spiking state as the parent. It also
    updates one spike trace per layer (``pPTRACE1`` ... ``pPTRACE9``), driven
    by that layer's output spikes.

    On the final (differentiable) timestep, each layer's input current is
    built with ``util.SpkTraceSwitchCall`` from two inputs:

    * the spike-driven current, detached, and
    * the same conv/linear path applied to the trace nodes' second output.

    NOTE(review): presumably the switch forwards the first value and
    back-propagates through the second -- confirm against ``util``.
    NOTE(review): on the final step every BatchNorm on these paths is called
    twice (spike input and trace input), so in train mode its running
    statistics are updated twice.
    """

    def _fresh_trace_state(self, batch_size, device):
        """Allocate independent initial state: membranes at 0.5, spikes and traces at 0.

        ``mems``/``spikes`` hold the 10 spiking layers. Index 0 is the stem and
        index 9 is the fc1 layer. ``traces[i]`` is driven by layer ``i``'s
        spikes for i in 0..8.

        The previous chained ``mem = spike = trace = trace_b = torch.zeros(...)``
        followed by ``mem.fill_(0.5)`` aliased all four, so spikes and traces
        silently started at 0.5 too.
        """
        ch = self.scale
        shapes = ([(ch, 32, 32)] * 3 + [(2 * ch, 16, 16)] * 2
                  + [(4 * ch, 8, 8)] * 4 + [(100,)])
        mems = [torch.full((batch_size, *shape), 0.5, device=device) for shape in shapes]
        spikes = [torch.zeros((batch_size, *shape), device=device) for shape in shapes]
        # Trace i has the shape of layer i's spikes; the head layer (9) has no trace.
        traces = [torch.zeros((batch_size, *shape), device=device) for shape in shapes[:9]]
        return mems, spikes, traces

    def _trace_step(self, x, mems, spikes, traces, acc_mem, coeffs, use_trace_grad):
        """Advance spiking state and spike traces by one timestep.

        ``mems``, ``spikes`` and ``traces`` are updated in place.

        When ``use_trace_grad`` is true, each layer's input current is built
        with ``util.SpkTraceSwitchCall`` from the detached spike-driven current
        and the trace-driven current. Otherwise only the spike-driven current
        is computed.

        Returns the new accumulator membrane.
        """
        decay_0, decay_vector, trace_decay, acc_decay, trace_th, trace_a = coeffs
        batch_size = x.size(0)
        trace_nodes = (self.pPTRACE1, self.pPTRACE2, self.pPTRACE3,
                       self.pPTRACE4, self.pPTRACE5, self.pPTRACE6,
                       self.pPTRACE7, self.pPTRACE8, self.pPTRACE9)

        def fire(i, node, decay, current):
            # Recurrent state is detached so no gradient crosses timesteps.
            mems[i], spikes[i] = node(mems[i].detach(), spikes[i].detach(), decay, current)
            return spikes[i]

        def trace(i, spike):
            # Update trace i from layer i's spikes; return the node's second output.
            traces[i], trace_b = trace_nodes[i](
                traces[i].detach(), trace_decay[i], spike, trace_th[i], trace_a[i])
            return trace_b

        def pool(pool_fn, spike, trace_b):
            # The pooled trace is only needed (and only computed) on the trace-grad step.
            return pool_fn(spike), (pool_fn(trace_b) if use_trace_grad else None)

        def drive(layer, spike_args, trace_args):
            current = layer(*spike_args)
            if not use_trace_grad:
                return current
            return util.SpkTraceSwitchCall()(current.detach(), layer(*trace_args))

        def residual(conv, skip):
            def apply(main, shortcut):
                return conv(main) + skip(shortcut)
            return apply

        def head(t):
            return self.fc1(t.view(batch_size, -1))

        s0 = fire(0, self.pPLIF0, decay_0, self.conv1(x))
        b0 = trace(0, s0)
        s1 = fire(1, self.pPLIF1, decay_vector[0], drive(self.SRB1_conv1, (s0,), (b0,)))
        b1 = trace(1, s1)
        s2 = fire(2, self.pPLIF2, decay_vector[1],
                  drive(residual(self.SRB1_conv2, self.SRB1_skip), (s1, s0), (b1, b0)))
        b2 = trace(2, s2)
        p1, p1_b = pool(self.pool1, s2, b2)
        s3 = fire(3, self.pPLIF3, decay_vector[2], drive(self.SRB2_conv1, (p1,), (p1_b,)))
        b3 = trace(3, s3)
        s4 = fire(4, self.pPLIF4, decay_vector[3],
                  drive(residual(self.SRB2_conv2, self.SRB2_skip), (s3, p1), (b3, p1_b)))
        b4 = trace(4, s4)
        p2, p2_b = pool(self.pool2, s4, b4)
        s5 = fire(5, self.pPLIF5, decay_vector[4], drive(self.SRB3_conv1, (p2,), (p2_b,)))
        b5 = trace(5, s5)
        s6 = fire(6, self.pPLIF6, decay_vector[5],
                  drive(residual(self.SRB3_conv2, self.SRB3_skip), (s5, p2), (b5, p2_b)))
        b6 = trace(6, s6)
        s7 = fire(7, self.pPLIF7, decay_vector[6], drive(self.SRB4_conv1, (s6,), (b6,)))
        b7 = trace(7, s7)
        s8 = fire(8, self.pPLIF8, decay_vector[7],
                  drive(residual(self.SRB4_conv2, self.SRB4_skip), (s7, s6), (b7, b6)))
        b8 = trace(8, s8)
        p3, p3_b = pool(self.pool3, s8, b8)
        s9 = fire(9, self.pPLIF9, decay_vector[8], drive(head, (p3,), (p3_b,)))

        boost = self.boost1(s9.unsqueeze(1)).squeeze(1)
        return self.pPLI(acc_mem.detach(), acc_decay, boost)

    def forward(self, x):
        """Simulate ``max(num_steps, 1)`` timesteps on a static input.

        The first ``num_steps - 1`` steps run under ``torch.no_grad()`` and
        use only spike-driven currents (the traces are still updated). The last
        step is differentiable and uses ``util.SpkTraceSwitchCall`` currents.

        Args:
            x: input batch; must be reshapeable to ``(batch, 3, 32, 32)``.

        Returns:
            ``(acc_mem, num_steps)``: the output of ``self.pPLI`` after the last
            step, and ``self.num_steps``. Per the original note, the caller
            applies softmax / cross-entropy.
        """
        # Side effect kept from the original: record the input device on the module.
        self.device = x.device
        batch_size = x.size(0)
        x = x.reshape(batch_size, 3, 32, 32)

        mems, spikes, traces = self._fresh_trace_state(batch_size, x.device)
        acc_mem = torch.zeros(batch_size, 10, device=x.device)

        coeffs = (
            torch.sigmoid(self.tau_0),                 # stem layer decay
            torch.sigmoid(self.tau_vector),            # decays of layers 1..9
            torch.sigmoid(self.spk_trace_tau_vector),  # trace decays
            torch.sigmoid(self.acc_tau),               # accumulator decay
            torch.sigmoid(self.spk_threshold_vector),  # trace thresholds
            torch.sigmoid(self.spk_a_vector),          # trace "a" coefficients
        )

        # Warm-up timesteps: advance state and traces without recording a graph.
        for _ in range(self.num_steps - 1):
            with torch.no_grad():
                acc_mem = self._trace_step(x, mems, spikes, traces, acc_mem, coeffs,
                                           use_trace_grad=False)

        # Final, differentiable timestep with trace-routed currents.
        acc_mem = self._trace_step(x, mems, spikes, traces, acc_mem, coeffs,
                                   use_trace_grad=True)
        return acc_mem, self.num_steps

