import torch
import torch.nn as nn
from copy import deepcopy
from modules import layer

__all__ = [
    'OPZOSpikingNFResNet',
    # BasicBlock-based variants
    'opzo_spiking_nfresnet18',
    'opzo_spiking_nfresnet34',
    # Bottleneck-based variants
    'opzo_spiking_nfresnet50',
    'opzo_spiking_nfresnet101',
    'opzo_spiking_nfresnet152',
    'opzo_spiking_nfresnext50_32x4d',
    'opzo_spiking_nfresnext101_32x8d',
    'opzo_spiking_wide_nfresnet50_2',
    'opzo_spiking_wide_nfresnet101_2',
]


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a biased 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    With ``stride=1`` this keeps the spatial size unchanged for any dilation.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=True,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Return a biased, unpadded 1x1 ``nn.Conv2d`` (e.g. for shortcut projections)."""
    kernel = 1
    return nn.Conv2d(in_planes, out_planes, kernel, stride=stride, bias=True)


def wsconv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a weight-standardized 3x3 ``layer.WSConv2d`` with ``gain=True``.

    Padding equals ``dilation`` and a bias term is always included.
    """
    return layer.WSConv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=True,
        gain=True,
        dilation=dilation,
    )


def wsconv1x1(in_planes, out_planes, stride=1):
    """Return a weight-standardized, biased 1x1 ``layer.WSConv2d`` with ``gain=True``."""
    return layer.WSConv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=True,
        gain=True,
    )


class BasicBlock(nn.Module):
    """Two-convolution residual block of the OPZO spiking NF-ResNet.

    ``forward`` has two modes:

    * ``bp=True`` or eval mode: an ordinary autograd pass (with optional
      Gaussian perturbations when ``deterministic=False``).
    * training with ``bp=False``: a ``torch.no_grad`` pass that additionally
      returns the perturbations, conv inputs and surrogate gradients that
      ``cal_grad`` later uses to write parameter gradients without
      backpropagating through the block.
    """
    # Output-channel multiplier relative to ``planes`` (read by _make_layer).
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, weight_standardization=True, beta=1.0, alpha=1.0,
                 spiking_neuron: callable = None, stochdepth_rate=0.0, 
                 feedback_mode='PZO', momentum_fb=0.99999, p_scale=0.2, h_in=224, w_in=224, num_classes=1000, 
                 local_loss=False, **kwargs):
        """Build the block.

        Args:
            inplanes: input channel count.
            planes: output channel count of both convolutions.
            stride: stride of ``conv1``; ``downsample`` (if given) is expected
                to downsample by the same amount.
            downsample: optional module applied to the ``sn1`` output to form
                the shortcut; when ``None`` the raw input ``x`` is the shortcut.
            groups, base_width, dilation: only the defaults are supported.
            weight_standardization: use ``layer.WSConv2d`` convolutions and
                scale neuron outputs by 2.74 (otherwise plain convs, scale 1).
            beta: the block input is divided by ``beta`` before ``sn1``.
            alpha: residual-branch scale, applied together with the
                zero-initialised ``skipinit_gain``.
            spiking_neuron: factory called as ``spiking_neuron(**kwargs)`` to
                build ``sn1`` and ``sn2``.
            stochdepth_rate: stored only; not used inside this block.
            feedback_mode: ``'PZO'`` and ``'DFA'`` allocate the feedback
                matrices ``self.fb``; ``'PZO'`` and ``'ZO'`` also record the
                injected perturbations during the training pass.
            momentum_fb: momentum of the PZO feedback-matrix update.
            p_scale: standard deviation of the injected Gaussian perturbations.
            h_in, w_in: spatial size of the block input; used to size the
                feedback and local-loss layers.
            num_classes: input size of ``self.fb`` / output size of ``self.local_fc``.
            local_loss: add two local linear classifiers whose loss gradient
                is mixed into the feedback signal in ``cal_grad``.
            **kwargs: forwarded to ``spiking_neuron``.
        """
        super(BasicBlock, self).__init__()
        self.stochdepth_rate = stochdepth_rate

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        if weight_standardization:
            self.conv1 = wsconv3x3(inplanes, planes, stride)
            self.conv2 = wsconv3x3(planes, planes)
            # NOTE(review): fixed gain applied to neuron outputs when using WS
            # convs; presumably chosen to preserve signal variance — confirm.
            self.sn_scale = 2.74
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.conv2 = conv3x3(planes, planes)
            self.sn_scale = 1.
        self.sn1 = spiking_neuron(**kwargs)
        self.sn2 = spiking_neuron(**kwargs)
        self.downsample = downsample
        self.stride = stride

        self.beta, self.alpha = beta, alpha
        # SkipInit-style gain: starts at zero so the residual branch initially
        # contributes nothing and the block outputs just the shortcut.
        self.skipinit_gain = nn.Parameter(torch.zeros(()))

        self.feedback_mode = feedback_mode
        self.momentum_fb = momentum_fb
        self.p_scale = p_scale
        if feedback_mode in ['DFA', 'PZO']:
            # One feedback matrix per gradient injection point: fb[0] for the
            # conv1/sn2 activation, fb[1] for the block output. Both map the
            # num_classes output gradient to a flattened (planes, h, w) map.
            hw_dim = h_in * w_in // (stride**2)
            self.fb = nn.ModuleList()
            self.fb.append(nn.Linear(num_classes, hw_dim*planes, bias=False))
            self.fb.append(nn.Linear(num_classes, hw_dim*planes, bias=False))

            if feedback_mode == 'DFA':
                # DFA: random projection, never updated by this block.
                for i in range(len(self.fb)):
                    nn.init.kaiming_normal_(self.fb[i].weight, mode='fan_out')
            else:
                # PZO: start at zero; learned by update_momentum_feedback.
                for i in range(len(self.fb)):
                    nn.init.constant_(self.fb[i].weight, 0)

        # for antithetic: perturbations stored on the first call of a pair and
        # re-used negated (then cleared) on the second call.
        self.perturb_var_list = []

        # local loss
        self.local_loss = local_loss
        if local_loss:
            hw_dim = h_in * w_in // (stride**2)
            self.local_fc = nn.ModuleList()
            self.local_fc.append(nn.Linear(hw_dim*planes, num_classes))
            self.local_fc.append(nn.Linear(hw_dim*planes, num_classes))

            # Weight of the local-loss gradient relative to the feedback signal.
            self.ll_lambda = 0.01

    def forward(self, x, bp=False, deterministic=True, antithetic=False, perturb_before_neuron=True, only_local_loss=False):
        """Run the block.

        Args:
            x: block input tensor.
            bp: use the plain autograd path (also taken whenever not training).
            deterministic: when False, add Gaussian perturbations of std
                ``p_scale``. In the training path perturbations are also added
                whenever ``feedback_mode`` is ``'PZO'`` or ``'ZO'``.
            antithetic: if perturbations were stored by a previous call, reuse
                them negated (and clear the store); otherwise store the new ones.
            perturb_before_neuron: inject the first perturbation on the conv1
                output (before ``sn2``) instead of on the ``sn2`` output.
            only_local_loss: unused in this method.

        Returns:
            ``out`` on the autograd path. Otherwise
            ``(out, perturb_list, in_for_grad_list, sg_list)``, with
            ``local_output`` appended when ``local_loss`` is enabled.
        """

        if bp or (not self.training):
            # NOTE(review): layer.OTTTSequential is a project helper; it appears
            # to let the convs consume the list outputs the neurons may return
            # (see the isinstance(out, list) checks) — confirm.
            conv1 = layer.OTTTSequential(*[self.conv1])
            conv2 = layer.OTTTSequential(*[self.conv2])
            downsample = layer.OTTTSequential(*[self.downsample])

            out = x / self.beta
            out = self.sn1(out)
            if isinstance(out, list):
                out[0] = out[0] * self.sn_scale
                out[1] = out[1] * self.sn_scale
            else:
                out = out * self.sn_scale

            if self.downsample is not None:
                identity = downsample(out)
            else:
                identity = x

            out = conv1(out)
            if (not deterministic) and perturb_before_neuron:
                perturb = self.get_perturb(out, self.p_scale)
                out = out + perturb
            out = self.sn2(out)
            if (not deterministic) and (not perturb_before_neuron):
                if isinstance(out, list):
                    # Same noise is added to both neuron outputs.
                    perturb = self.get_perturb(out[0], self.p_scale)
                    out[0] = out[0] + perturb
                    out[1] = out[1] + perturb
                else:
                    perturb = self.get_perturb(out, self.p_scale)
                    out = out + perturb
            if isinstance(out, list):
                out[0] = out[0] * self.sn_scale
                out[1] = out[1] * self.sn_scale
            else:
                out = out * self.sn_scale
            out = conv2(out)
            out = out * self.skipinit_gain * self.alpha + identity
            if (not deterministic):
                perturb = self.get_perturb(out, self.p_scale)
                out = out + perturb

            return out
        else:
            # Per-block records consumed later by cal_grad:
            #   in_for_grad_list: inputs to conv1 and conv2,
            #   perturb_list: injected perturbations (PZO/ZO only),
            #   sg_list: sn2 surrogate gradient, then ones for the block output.
            in_for_grad_list = []
            perturb_list = []
            sg_list = []
            if antithetic:
                if len(self.perturb_var_list) > 0:
                    use_antithetic = True
                    self.perturb_var_index = 0
                else:
                    use_antithetic = False
            if self.local_loss:
                local_output = []

            # forward propagation
            with torch.no_grad():
                out = x / self.beta
                # sn1: with return_grad=True the neuron returns a pair of
                # outputs plus a surrogate gradient (sn1's sg is not used).
                out, sg = self.sn1(out, return_grad=True)
                out[0] *= self.sn_scale
                out[1] *= self.sn_scale

                # out[1] is kept as the conv1 input for set_grad; out[0] is
                # what propagates forward.
                in_for_grad_list.append(out[1])
                out = out[0]

                # downsample
                if self.downsample is not None:
                    identity = self.downsample(out)
                else:
                    identity = x

                # conv1
                out = self.conv1(out)
                if perturb_before_neuron and ((self.feedback_mode in ['PZO', 'ZO']) or (not deterministic)):
                    if antithetic:
                        if use_antithetic:
                            perturb = -self.perturb_var_list[self.perturb_var_index]
                            self.perturb_var_index += 1
                        else:
                            perturb = self.get_perturb(out, self.p_scale)
                            self.perturb_var_list.append(perturb)
                    else:
                        perturb = self.get_perturb(out, self.p_scale)
                    out += perturb
                    if self.feedback_mode in ['PZO', 'ZO']:
                        perturb_list.append(perturb)

                # sn2
                out, sg = self.sn2(out, return_grad=True)
                sg_list.append(sg)
                if self.local_loss:
                    # NOTE(review): flatten(1) may return a view, so the in-place
                    # perturb/scale below may also modify this entry — confirm intent.
                    local_output.append(out[0].flatten(1))
                if (not perturb_before_neuron) and ((self.feedback_mode in ['PZO', 'ZO']) or (not deterministic)):
                    if antithetic:
                        if use_antithetic:
                            perturb = -self.perturb_var_list[self.perturb_var_index]
                            self.perturb_var_index += 1
                        else:
                            perturb = self.get_perturb(out[0], self.p_scale)
                            self.perturb_var_list.append(perturb)
                    else:
                        perturb = self.get_perturb(out[0], self.p_scale)
                    out[0] += perturb
                    out[1] += perturb
                    if self.feedback_mode in ['PZO', 'ZO']:
                        perturb_list.append(perturb)
                out[0] *= self.sn_scale
                out[1] *= self.sn_scale
                in_for_grad_list.append(out[1])
                out = out[0]

                # conv2
                out = self.conv2(out)
                # residual
                out *= self.skipinit_gain * self.alpha
                out += identity

                # No nonlinearity at the block output, so its "surrogate
                # gradient" is identically one.
                sg_list.append(torch.ones_like(out))
                if self.local_loss:
                    local_output.append(out.flatten(1))
                if (self.feedback_mode in ['PZO', 'ZO']) or (not deterministic):
                    if antithetic:
                        if use_antithetic:
                            perturb = -self.perturb_var_list[self.perturb_var_index]
                            self.perturb_var_index += 1
                        else:
                            perturb = self.get_perturb(out, self.p_scale)
                            self.perturb_var_list.append(perturb)
                    else:
                        perturb = self.get_perturb(out, self.p_scale)
                    out += perturb
                    if self.feedback_mode in ['PZO', 'ZO']:
                        perturb_list.append(perturb)

            # clear antithetic: the stored perturbations have now been used twice
            if antithetic and use_antithetic:
                self.perturb_var_list = []

            if self.local_loss:
                return out, perturb_list, in_for_grad_list, sg_list, local_output
            else:
                return out, perturb_list, in_for_grad_list, sg_list

    def cal_grad(self, grad_last, in_for_grad_list, perturb_list, sg_list, loss=None, local_loss=False, local_output=None, loss_func=None):
        """Turn the feedback signal into gradients and accumulate them into the params.

        Args:
            grad_last: loss gradient w.r.t. the network output (fed to ``self.fb``).
            in_for_grad_list: conv1 / conv2 inputs recorded by ``forward``.
            perturb_list: perturbations recorded by ``forward``. It is mutated
                in place: overwritten with gradient estimates (PZO, ZO) or
                appended to (DFA).
            sg_list: surrogate gradients recorded by ``forward``.
            loss: loss value used by the zeroth-order estimate (non PZO/DFA modes).
            local_loss, local_output, loss_func: when ``local_loss`` is true, the
                gradient of the local classifiers' loss w.r.t. ``local_output``
                is added, scaled by ``ll_lambda``.

        Gradients are accumulated via ``backward`` into conv1, conv2,
        ``skipinit_gain`` (through ``op_scale``) and the downsample module.
        """
        with torch.no_grad():
            if self.feedback_mode == 'PZO':
                for i in range(len(self.fb)):
                    perturb_list[i] = self.fb[i](grad_last).reshape(perturb_list[i].shape) * sg_list[i]
            elif self.feedback_mode == 'DFA':
                for i in range(len(self.fb)):
                    perturb_list.append(self.fb[i](grad_last).reshape(sg_list[i].shape) * sg_list[i])
            else:
                # Zeroth-order estimate from the recorded perturbations and the loss.
                for i in range(len(perturb_list)):
                    perturb_list[i] = perturb_list[i] * loss * sg_list[i] / self.p_scale

        if local_loss:
            with torch.enable_grad():
                for lo in local_output:
                    lo.requires_grad_(True)
                loss_local = loss_func(self.local_fc[0](local_output[0])) + loss_func(self.local_fc[1](local_output[1]))
            loss_local.backward()

            for i in range(len(perturb_list)):
                perturb_list[i] = perturb_list[i] + self.ll_lambda * local_output[i].grad.reshape(perturb_list[i].shape) * sg_list[i]

        self.set_grad(self.conv1, in_for_grad_list[0], perturb_list[0])
        self.set_grad(self.conv2, in_for_grad_list[1], perturb_list[1], op_scale=self.skipinit_gain*self.alpha)
        if self.downsample is not None:
            # The shortcut feeds the block output directly, so it receives the
            # block-output gradient.
            self.set_grad(self.downsample, in_for_grad_list[0], perturb_list[1])

    def update_momentum_feedback(self, perturb_list, output):
        """PZO only: momentum update of the feedback matrices.

        ``fb.weight <- momentum_fb * fb.weight
        + (1 - momentum_fb) * perturb^T @ output / (batch * p_scale)``,
        where ``output`` is the tensor passed by the caller (the network
        output in ``OPZOSpikingNFResNet.forward``).
        """
        if self.feedback_mode == 'PZO':
            for i in range(len(self.fb)):
                #new_weight = perturb_list[i].flatten(1).t().mm(output) / output.shape[0] / self.p_scale
                new_weight = perturb_list[i].flatten(1).t().mm(output) * ((1 - self.momentum_fb) / output.shape[0] / self.p_scale)
                self.fb[i].weight.data *= self.momentum_fb
                #self.fb[i].weight.data += (1 - self.momentum_fb) * new_weight
                self.fb[i].weight.data += new_weight

    def get_perturb(self, x, p_scale):
        """Return Gaussian noise shaped like ``x`` with standard deviation ``p_scale``."""
        perturb = torch.randn_like(x)
        return perturb * p_scale

    def set_p_scale(self, p_scale):
        """Set the perturbation standard deviation."""
        self.p_scale = p_scale

    def set_grad(self, op, in_for_grad, grad, op_scale=1.):
        """Re-run ``op`` on ``in_for_grad`` with autograd and backprop ``grad``.

        Accumulates gradients into ``op``'s parameters, into ``op_scale`` if it
        requires grad, and into ``in_for_grad`` itself, whose
        ``requires_grad`` flag is switched on.
        """
        with torch.enable_grad():
            in_for_grad.requires_grad_(True)
            tmp = op(in_for_grad) * op_scale
        tmp.backward(grad)

    def get_spike(self):
        """Return the ``.spike`` tensors of sn1 and sn2 on CPU, reshaped to (dim0, -1)."""
        spikes = []
        spike = self.sn1.spike.cpu()
        spikes.append(spike.reshape(spike.shape[0], -1))
        spike = self.sn2.spike.cpu()
        spikes.append(spike.reshape(spike.shape[0], -1))
        return spikes


class SequentialModule(nn.Sequential):
    """``nn.Sequential`` that forwards keyword arguments to every child.

    It also concatenates the ``get_spike()`` lists of its children.
    """

    def forward(self, input, **kwargs):
        result = input
        for child in self._modules.values():
            result = child(result, **kwargs)
        return result

    def get_spike(self):
        collected = []
        for child in self._modules.values():
            collected.extend(child.get_spike())
        return collected


class OPZOSpikingNFResNet(nn.Module):
    """Spiking NF-ResNet trained by backprop or by output-feedback gradient estimates.

    ``forward`` runs a plain autograd pass when ``bp=True`` or in eval mode.
    Otherwise (training with ``bp=False``) it does three things:

    1. Runs a ``torch.no_grad`` pass that records per-layer inputs,
       perturbations and surrogate gradients.
    2. Differentiates the loss only w.r.t. the network output.
    3. Writes parameter gradients layer by layer, using the feedback selected
       by ``feedback_mode`` (``'PZO'``, ``'DFA'`` or ``'ZO'``).
    """

    def __init__(self, block, layers, num_classes=1000, c_in=3,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 weight_standardization=True, spiking_neuron: callable = None,
                 alpha=0.2, drop_rate=0.0,
                 feedback_mode='PZO', momentum_fb=0.99999, p_scale=0.2, h_in=224, w_in=224, local_loss=False,
                 **kwargs):
        """Build the network.

        Args:
            block: residual block class (e.g. ``BasicBlock``).
            layers: number of blocks in each of the four stages.
            num_classes: classifier output size.
            c_in: input channel count.
            groups, width_per_group: forwarded to every block.
            replace_stride_with_dilation: three flags (stages 2-4) replacing
                that stage's stride with dilation.
            weight_standardization: use ``layer.WSConv2d`` convolutions.
            spiking_neuron: neuron factory forwarded to every block.
            alpha: residual-branch scale; also drives the expected-variance
                bookkeeping that sets each block's ``beta``.
            drop_rate: stored only; dropout is currently disabled.
            feedback_mode: ``'PZO'``, ``'DFA'`` or ``'ZO'``.
            momentum_fb: momentum of the PZO feedback-matrix update.
            p_scale: standard deviation of the injected perturbations.
            h_in, w_in: input spatial size, used to size the feedback and
                local-loss layers. Assumed to divide exactly through every
                stride-2 stage (NOTE(review): confirm for non-standard sizes).
            local_loss: add per-layer local classifiers.
            **kwargs: forwarded to the blocks (and from there to the neurons).
        """
        super(OPZOSpikingNFResNet, self).__init__()
        self.ws = weight_standardization
        self.alpha = alpha
        self.drop_rate = drop_rate

        self.feedback_mode = feedback_mode
        self.h_in, self.w_in = h_in, w_in
        self.local_loss = local_loss

        self.num_classes = num_classes

        self.momentum_fb = momentum_fb
        self.p_scale = p_scale

        self.inplanes = 64
        # Channel count of the stem (conv1 + pooling) output. ``self.inplanes``
        # is advanced by _make_layer, so it must not be used for the stem
        # after the stages have been built.
        stem_channels = self.inplanes
        self.dilation = 1
        self.c_in = c_in
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        if weight_standardization:
            self.conv1 = layer.WSConv2d(self.c_in, self.inplanes, kernel_size=7, stride=2, padding=3, bias=True, gain=True)
            self.sn_scale = 2.74
        else:
            self.conv1 = nn.Conv2d(self.c_in, self.inplanes, kernel_size=7, stride=2, padding=3, bias=True)
            self.sn_scale = 1.
        h_in, w_in = h_in // 2, w_in // 2
        self.maxpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        h_in, w_in = h_in // 2, w_in // 2
        expected_var = 1.0
        self.layer1, expected_var = self._make_layer(block, 64, layers[0], alpha=self.alpha, var=expected_var, spiking_neuron=spiking_neuron, h_in=h_in, w_in=w_in, **kwargs)
        self.layer2, expected_var = self._make_layer(block, 128, layers[1], stride=2, alpha=self.alpha, var=expected_var,
                                                     dilate=replace_stride_with_dilation[0], spiking_neuron=spiking_neuron, h_in=h_in, w_in=w_in, **kwargs)
        h_in, w_in = h_in // 2, w_in // 2
        self.layer3, expected_var = self._make_layer(block, 256, layers[2], stride=2, alpha=self.alpha, var=expected_var,
                                                     dilate=replace_stride_with_dilation[1], spiking_neuron=spiking_neuron, h_in=h_in, w_in=w_in, **kwargs)
        h_in, w_in = h_in // 2, w_in // 2
        self.layer4, expected_var = self._make_layer(block, 512, layers[3], stride=2, alpha=self.alpha, var=expected_var,
                                                     dilate=replace_stride_with_dilation[2], spiking_neuron=spiking_neuron, h_in=h_in, w_in=w_in, **kwargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        torch.nn.init.zeros_(self.fc.weight)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

        # Flattened size of one stem output map: conv1 and the pooling each
        # have stride 2, so the spatial size is divided by 4 in each dimension.
        stem_hw_dim = self.h_in * self.w_in // 16

        if feedback_mode in ['DFA', 'PZO']:
            # Feedback matrix for conv1 & maxpool.
            self.fb1 = nn.Linear(num_classes, stem_hw_dim * stem_channels, bias=False)
            if feedback_mode == 'DFA':
                nn.init.kaiming_normal_(self.fb1.weight, mode='fan_out')
            else:
                nn.init.constant_(self.fb1.weight, 0)

        # for antithetic: perturbations stored on the first call of a pair and
        # re-used negated (then cleared) on the second call.
        self.perturb_var_list = []

        if local_loss:
            # Local classifier on the stem output (stem_channels wide).
            self.local_fc1 = nn.Linear(stem_hw_dim * stem_channels, num_classes, bias=False)
            self.ll_lambda = 0.01

    def _make_layer(self, block, planes, blocks, stride=1, alpha=1.0, var=1.0, dilate=False, spiking_neuron: callable = None, h_in=56, w_in=56, **kwargs):
        """Build one stage of ``blocks`` residual blocks.

        The first block may change resolution and width, using a 1x1 shortcut
        projection when needed. ``var`` tracks the expected variance of the
        residual stream, and each block divides its input by
        ``beta = sqrt(var)``.

        Returns:
            ``(SequentialModule, var)``, where ``var`` is the expected variance
            at the stage output.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.ws:
                downsample = wsconv1x1(self.inplanes, planes * block.expansion, stride)
            else:
                downsample = conv1x1(self.inplanes, planes * block.expansion, stride)

        layers = []
        beta = var ** 0.5
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, self.ws, beta, alpha, spiking_neuron,
                            feedback_mode=self.feedback_mode, momentum_fb=self.momentum_fb, p_scale=self.p_scale,
                            h_in=h_in, w_in=w_in, num_classes=self.num_classes, local_loss=self.local_loss, **kwargs))
        self.inplanes = planes * block.expansion
        h_in, w_in = h_in // stride, w_in // stride
        if downsample is not None:
            # A projected shortcut restarts the variance at 1.
            var = 1. + self.alpha ** 2
        else:
            var += self.alpha ** 2
        for _ in range(1, blocks):
            beta = var ** 0.5
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                weight_standardization=self.ws, beta=beta, alpha=alpha, spiking_neuron=spiking_neuron,
                                feedback_mode=self.feedback_mode, momentum_fb=self.momentum_fb, p_scale=self.p_scale,
                                h_in=h_in, w_in=w_in, num_classes=self.num_classes, local_loss=self.local_loss, **kwargs))
            var += self.alpha ** 2

        return SequentialModule(*layers), var

    def forward(self, x, bp=False, loss_func=None, deterministic=True, antithetic=True, perturb_before_neuron=False, only_local_loss=False):
        """Run the network and, in feedback-training mode, assign gradients.

        Args:
            x: input batch.
            bp: use the plain autograd path (also taken whenever not training).
            loss_func: maps logits to a scalar loss. Optional on the autograd
                path; required when training with ``bp=False``.
            deterministic, antithetic, perturb_before_neuron, only_local_loss:
                forwarded to every block (see ``BasicBlock.forward``).

        Returns:
            On the autograd path, ``x`` (logits), or ``(x, loss.item())`` when
            ``loss_func`` is given; ``loss.backward()`` is called when training.
            In feedback-training mode, ``(x, loss.item())`` after gradients
            have been accumulated into all parameters.

        Raises:
            ValueError: if training with ``bp=False`` and ``loss_func`` is None.
        """
        if bp or (not self.training):
            x = self.conv1(x)
            x = self.maxpool(x)
            if not deterministic:
                perturb = self.get_perturb(x, self.p_scale)
                x = x + perturb
            block_kwargs = dict(bp=bp, deterministic=deterministic, antithetic=antithetic,
                                perturb_before_neuron=perturb_before_neuron, only_local_loss=only_local_loss)
            x = self.layer1(x, **block_kwargs)
            x = self.layer2(x, **block_kwargs)
            x = self.layer3(x, **block_kwargs)
            x = self.layer4(x, **block_kwargs)
            x = self.avgpool(x)
            x = x.flatten(1)
            x = self.fc(x)

            if loss_func is not None:
                loss = loss_func(x)
                if self.training:
                    loss.backward()
                return x, loss.item()
            return x

        if loss_func is None:
            # Fail before the (expensive) forward pass rather than at loss time.
            raise ValueError('loss_func is required when training with bp=False')

        use_antithetic = False
        with torch.no_grad():
            perturb_list = []
            in_for_grad_list = []
            sg_list = []
            if antithetic and len(self.perturb_var_list) > 0:
                use_antithetic = True
                self.perturb_var_index = 0
            if self.local_loss:
                local_output = []

            # conv1 & maxpool
            in_for_grad_list.append(x)
            x = self.conv1(x)
            x = self.maxpool(x)
            if (self.feedback_mode in ['PZO', 'ZO']) or (not deterministic):
                if antithetic:
                    if use_antithetic:
                        perturb = -self.perturb_var_list[self.perturb_var_index]
                        self.perturb_var_index += 1
                    else:
                        perturb = self.get_perturb(x, self.p_scale)
                        self.perturb_var_list.append(perturb)
                else:
                    perturb = self.get_perturb(x, self.p_scale)
                x += perturb
                perturb_list.append(perturb)
            else:
                # No stem perturbation (e.g. deterministic DFA).
                perturb_list.append(None)
            if self.local_loss:
                local_output.append(x.flatten(1))
            sg_list.append(torch.ones_like(x))

            alllayers = [*self.layer1, *self.layer2, *self.layer3, *self.layer4]
            # layers: entry i + 1 of each list holds block i's records.
            for block in alllayers:
                if self.local_loss:
                    x, perturb_list_, in_for_grad_list_, sg_list_, local_output_ = block(x, bp, deterministic, antithetic, perturb_before_neuron, only_local_loss)
                    local_output.append(local_output_)
                else:
                    x, perturb_list_, in_for_grad_list_, sg_list_ = block(x, bp, deterministic, antithetic, perturb_before_neuron, only_local_loss)

                in_for_grad_list.append(in_for_grad_list_)
                perturb_list.append(perturb_list_)
                sg_list.append(sg_list_)

            # classifier
            x = torch.flatten(self.avgpool(x), 1)
            in_for_grad_list.append(x)
            x = self.fc(x)

        # Only the loss w.r.t. the logits is differentiated with autograd.
        with torch.enable_grad():
            x.requires_grad_(True)
            loss = loss_func(x)
            # local loss for conv1
            if self.local_loss:
                local_output[0].requires_grad_(True)
                loss_local = loss_func(self.local_fc1(local_output[0]))
        loss.backward()
        if self.local_loss:
            loss_local.backward()

        grad_last = x.grad.data

        with torch.no_grad():
            if self.feedback_mode == 'PZO':
                # update momentum feedback (must use the raw perturbations,
                # before perturb_list entries are overwritten below)
                new_weight = perturb_list[0].flatten(1).t().mm(x) * ((1 - self.momentum_fb) / x.shape[0] / self.p_scale)
                self.fb1.weight.data *= self.momentum_fb
                self.fb1.weight.data += new_weight

                for i, block in enumerate(alllayers):
                    block.update_momentum_feedback(perturb_list[i + 1], x)

        # calculate gradients for the residual blocks
        for i, block in enumerate(alllayers):
            lo = local_output[i + 1] if self.local_loss else None
            block.cal_grad(grad_last, in_for_grad_list[i + 1], perturb_list[i + 1], sg_list[i + 1], loss, self.local_loss, lo, loss_func)

        # conv1
        with torch.no_grad():
            if self.feedback_mode in ['PZO', 'DFA']:
                # perturb_list[0] is None when no stem perturbation was drawn,
                # so take the target shape from the (same-shaped) sg_list entry.
                perturb_list[0] = self.fb1(grad_last).reshape(sg_list[0].shape) * sg_list[0]
            else:
                perturb_list[0] = perturb_list[0] * loss * sg_list[0] / (self.p_scale**2)

        if self.local_loss:
            perturb_list[0] = perturb_list[0] + self.ll_lambda * local_output[0].grad.reshape(perturb_list[0].shape) * sg_list[0]
        self.set_grad(nn.Sequential(self.conv1, self.maxpool), in_for_grad_list[0], perturb_list[0])

        # fc: exact gradient from the logits
        self.set_grad(self.fc, in_for_grad_list[-1], grad_last)

        # clear antithetic: the stored perturbations have now been used twice
        if antithetic and use_antithetic:
            self.perturb_var_list = []

        return x, loss.item()

    def get_perturb(self, x, p_scale):
        """Return Gaussian noise shaped like ``x`` with standard deviation ``p_scale``."""
        perturb = torch.randn_like(x)
        return perturb * p_scale

    def set_grad(self, op, in_for_grad, grad):
        """Re-run ``op`` on ``in_for_grad`` with autograd and backprop ``grad``.

        Gradients are accumulated into ``op``'s parameters and into ``in_for_grad``.
        """
        with torch.enable_grad():
            in_for_grad.requires_grad_(True)
            tmp = op(in_for_grad)
        tmp.backward(grad)


def _opzo_spiking_resnet(arch, block, layers, pretrained, progress, spiking_neuron, **kwargs):
    """Build an ``OPZOSpikingNFResNet`` for the given architecture.

    Args:
        arch: architecture name (e.g. ``'resnet18'``), used in error messages.
        block: residual block class used for every stage.
        layers: number of blocks in each of the four stages.
        pretrained: must be false; this module defines no checkpoint URLs.
        progress: accepted for torchvision-style factory compatibility; unused.
        spiking_neuron: factory that builds the spiking neuron modules.
        **kwargs: forwarded to ``OPZOSpikingNFResNet``.

    Raises:
        NotImplementedError: if ``pretrained`` is true.
    """
    if pretrained:
        # No ``model_urls`` table or weight loader exists in this module, so
        # pretrained loading cannot succeed. Fail before building the network.
        raise NotImplementedError(
            'No pretrained weights are available for OPZO spiking NF-ResNet ({})'.format(arch))
    return OPZOSpikingNFResNet(block, layers, spiking_neuron=spiking_neuron, **kwargs)


def opzo_spiking_nfresnet18(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNet-18: BasicBlock stages with depths 2-2-2-2."""
    stage_depths = [2, 2, 2, 2]
    return _opzo_spiking_resnet('resnet18', BasicBlock, stage_depths, pretrained, progress,
                                spiking_neuron, **kwargs)

def opzo_spiking_nfresnet34(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNet-34: BasicBlock stages with depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return _opzo_spiking_resnet('resnet34', BasicBlock, stage_depths, pretrained, progress,
                                spiking_neuron, **kwargs)

def _bottleneck_block():
    """Return the module-level ``Bottleneck`` class, or raise if none exists.

    The deeper factories below need a ``Bottleneck`` block, but this module
    does not define one. Looking it up at call time gives a clear error
    instead of a bare ``NameError``, and picks the class up automatically if
    one is added to the module later.
    """
    block = globals().get('Bottleneck')
    if block is None:
        raise NotImplementedError(
            'Bottleneck blocks are not implemented for OPZO spiking NF-ResNets; '
            'only opzo_spiking_nfresnet18 and opzo_spiking_nfresnet34 are available.')
    return block

def opzo_spiking_nfresnet50(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNet-50: Bottleneck stages with depths 3-4-6-3."""
    block = _bottleneck_block()
    return _opzo_spiking_resnet('resnet50', block, [3, 4, 6, 3], pretrained, progress, spiking_neuron, **kwargs)

def opzo_spiking_nfresnet101(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNet-101: Bottleneck stages with depths 3-4-23-3."""
    block = _bottleneck_block()
    return _opzo_spiking_resnet('resnet101', block, [3, 4, 23, 3], pretrained, progress, spiking_neuron, **kwargs)

def opzo_spiking_nfresnet152(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNet-152: Bottleneck stages with depths 3-8-36-3."""
    block = _bottleneck_block()
    return _opzo_spiking_resnet('resnet152', block, [3, 8, 36, 3], pretrained, progress, spiking_neuron, **kwargs)

def opzo_spiking_nfresnext50_32x4d(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNeXt-50 (32 groups, width 4 per group)."""
    block = _bottleneck_block()
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _opzo_spiking_resnet('resnext50_32x4d', block, [3, 4, 6, 3], pretrained, progress, spiking_neuron, **kwargs)

def opzo_spiking_nfresnext101_32x8d(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking NF-ResNeXt-101 (32 groups, width 8 per group)."""
    block = _bottleneck_block()
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _opzo_spiking_resnet('resnext101_32x8d', block, [3, 4, 23, 3], pretrained, progress, spiking_neuron, **kwargs)

def opzo_spiking_wide_nfresnet50_2(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking wide NF-ResNet-50-2 (width_per_group = 128)."""
    block = _bottleneck_block()
    kwargs['width_per_group'] = 64 * 2
    return _opzo_spiking_resnet('wide_resnet50_2', block, [3, 4, 6, 3], pretrained, progress, spiking_neuron, **kwargs)

def opzo_spiking_wide_nfresnet101_2(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """OPZO spiking wide NF-ResNet-101-2 (width_per_group = 128)."""
    block = _bottleneck_block()
    kwargs['width_per_group'] = 64 * 2
    return _opzo_spiking_resnet('wide_resnet101_2', block, [3, 4, 23, 3], pretrained, progress, spiking_neuron, **kwargs)

