import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
import math
import torch.utils.model_zoo as model_zoo
from utils.common import interpolateFeatures,_get_num_features
from utils.route_norm import RouteNorm
from utils.trans_norm import TransNorm2d


# Names exported by a star-import of this module.
# NOTE(review): ``resnet10`` is defined further down in this file but is not
# listed here -- confirm whether that omission is intentional.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

# Device chosen once, at import time: CUDA when it is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download locations of the pretrained checkpoints, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    With ``stride=1`` the padding keeps the spatial size of the input.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    The shortcut is the unchanged input unless a ``downsample`` module is
    supplied, in which case the shortcut is ``downsample(x)``.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # First conv carries the stride; the second one keeps the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class TestBlock(nn.Module):
    """Residual block that consumes a feature map plus a reference map.

    The two inputs are concatenated along the channel axis before the first
    convolution, so ``inplanes`` must equal the summed channel count. The
    shortcut is the original ``x`` (not the concatenation) unless a
    ``downsample`` module is given; that module receives the concatenation.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x, ref_x):
        """Fuse ``x`` with ``ref_x`` and add the result onto the shortcut."""
        shortcut = x
        stacked = torch.cat((x, ref_x), dim=1)

        out = self.relu(self.bn1(self.conv1(stacked)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            # The projection sees the concatenated tensor, not the raw ``x``.
            shortcut = self.downsample(stacked)

        out += shortcut
        return self.relu(out)

class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1 -> 3x3 -> 1x1.

    The last 1x1 convolution produces ``planes * 4`` channels. The shortcut is
    the unchanged input unless a ``downsample`` module is supplied.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        wide_planes = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # The stride is applied by the middle 3x3 convolution.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = nn.Conv2d(planes, wide_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/bn stages, add the shortcut and apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class SampledAttention(nn.Module):
    ''' Dot-product attention that samples a single value row per query.

    Scores are ``q @ k^T / temperature``. Their softmax over the key axis is
    used as a categorical distribution from which one key index is drawn, and
    the matching row of ``v`` is returned. The ``squeeze(1)`` below means the
    module expects exactly one query per batch entry.
    '''

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, q, k, v):
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        # Drop the singleton query axis so each row is one distribution.
        distribution = F.softmax(scores, 2).squeeze(1)
        chosen = distribution.multinomial(num_samples=1)
        # Broadcast the chosen index across the feature axis for gather.
        gather_index = chosen.unsqueeze(2).repeat(1, 1, v.size(-1))
        sampled = torch.gather(v, 1, gather_index)
        # The second return value is the scaled scores *before* the softmax.
        return sampled, scores


class ScaledDotProductAttention(nn.Module):
    ''' Scaled dot-product attention with dropout on the attention weights.

    Returns ``(dropout(softmax(q @ k^T / temperature)) @ v, weights)`` where
    ``weights`` is the attention matrix after softmax and dropout.
    '''

    def __init__(self, temperature, dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, q, k, v):
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        # Normalise over the key axis, then randomly drop attention weights.
        weights = self.dropout(F.softmax(scores, 2))
        mixed = torch.bmm(weights, v)
        return mixed, weights


class AttentionBasedTransfer(nn.Module):
    """Fuse a target feature map with reference maps through attention.

    Every (sample, channel) plane of the target map is flattened into one
    query vector; the matching planes of the reference maps form the keys and
    values. The attended result is projected back to the plane size, added to
    the target map and layer-normalised over the last axis.

    Args:
        tg_layer_id: Identifier of the target layer (stored only).
        block_hidden_dim: Side length of the square feature planes. Note that
            ``self.block_hidden_dim`` holds the *flattened* plane size
            (``block_hidden_dim ** 2``).
        n_head: Number of attention heads.
        type: ``'avg'`` selects ``ScaledDotProductAttention``; any other value
            selects ``SampledAttention``.
        dropout: Dropout probability for the ``'avg'`` attention weights and
            for the output projection.
    """

    def __init__(self,tg_layer_id, block_hidden_dim, n_head=1,type='avg',dropout=0.0):
        super(AttentionBasedTransfer, self).__init__()
        self.tg_layer_id = tg_layer_id
        # Side length is kept so forward() can restore the plane shape without
        # a float sqrt round-trip.
        self.feat_size = block_hidden_dim
        self.block_hidden_dim = block_hidden_dim*block_hidden_dim
        self.in_dim = round(block_hidden_dim/n_head)
        self.n_head = n_head

        proj_dim = n_head * self.in_dim
        self.w_qs = torch.nn.Linear(self.block_hidden_dim, proj_dim, bias=False)
        self.w_ks = torch.nn.Linear(self.block_hidden_dim, proj_dim, bias=False)
        self.w_vs = torch.nn.Linear(self.block_hidden_dim, proj_dim, bias=False)

        init_std = np.sqrt(2.0 / (self.block_hidden_dim * 2))
        torch.nn.init.normal_(self.w_qs.weight, mean=0, std=init_std)
        torch.nn.init.normal_(self.w_ks.weight, mean=0, std=init_std)
        torch.nn.init.normal_(self.w_vs.weight, mean=0, std=init_std)

        temperature = np.power(self.in_dim, 0.5)
        if type == 'avg':
            self.attention = ScaledDotProductAttention(temperature=temperature, dropout=dropout)
        else:
            self.attention = SampledAttention(temperature=temperature)

        self.fc = torch.nn.Linear(proj_dim, self.block_hidden_dim)
        # The second positional argument of LayerNorm is ``eps``; the previous
        # call passed the feature size there, which made eps as large as the
        # plane side and effectively disabled the normalisation. The
        # normalised shape (last axis only) is unchanged, so the parameter
        # shapes are the same as before.
        # NOTE(review): if normalising over the whole (H, W) plane was
        # intended, this would need ``[block_hidden_dim, block_hidden_dim]``.
        self.layer_norm = torch.nn.LayerNorm(block_hidden_dim)
        torch.nn.init.xavier_normal_(self.fc.weight)
        self.dropout = torch.nn.Dropout(dropout)
        # Attention tensor of the most recent forward pass (for inspection).
        self.attn = None

    def forward(self, input1, ref_inputs):
        """Attend from ``input1`` over ``ref_inputs`` and return the fused map.

        Args:
            input1: Target tensor whose last two axes are square planes of
                side ``block_hidden_dim``, with batch and channel axes first.
            ref_inputs: Sequence of tensors shaped like ``input1``. This
                method does not add a "no transfer" sentinel itself; callers
                that want one must include it in the sequence.

        Returns:
            Tensor with the shape of ``input1``.
        """
        batch_size = input1.size(0)
        nchannel = input1.size(1)
        n_rows = batch_size * nchannel
        n_refs = len(ref_inputs)

        # reshape (rather than view) also accepts non-contiguous inputs.
        context = torch.stack(ref_inputs, dim=2).reshape(n_rows, n_refs, self.block_hidden_dim)
        query = input1.reshape(n_rows, 1, self.block_hidden_dim)

        # Project and split into heads.
        q = self.w_qs(query).view(n_rows, 1, self.n_head, self.in_dim)         # (b*c) x lq x n x dq
        k = self.w_ks(context).view(n_rows, n_refs, self.n_head, self.in_dim)  # (b*c) x lk x n x dk
        v = self.w_vs(context).view(n_rows, n_refs, self.n_head, self.in_dim)  # (b*c) x lv x n x dv

        # Fold the head axis into the batch axis for bmm-based attention.
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, 1, self.in_dim)       # (n*b*c) x lq x dq
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, n_refs, self.in_dim)  # (n*b*c) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, n_refs, self.in_dim)  # (n*b*c) x lv x dv

        input2, attn = self.attention(q, k, v)
        self.attn = attn

        # Undo the head folding and concatenate the heads per row.
        input2 = input2.view(self.n_head, n_rows, 1, self.in_dim)
        input2 = input2.permute(1, 2, 0, 3).contiguous().view(n_rows, 1, -1)  # (b*c) x lq x (n*dv)
        input2 = self.dropout(self.fc(input2))
        input2 = input2.reshape(batch_size, nchannel, self.feat_size, self.feat_size)

        y = input1 + input2
        y = self.layer_norm(y)
        return y


class Transnorm(nn.Module):
    """Normalise a target batch jointly with a source batch via ``TransNorm2d``.

    Both batches are concatenated along the batch axis, passed through the
    ``TransNorm2d`` layer, and only the leading (target) part is returned.
    """

    def __init__(self,tg_layer_id,src_layer_id, planes):
        super().__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.tn = TransNorm2d(planes)

    def forward(self, input1, input2):
        n_target = input1.size()[0]
        joint = torch.cat((input1, input2), dim=0)
        normalised = self.tn(joint)
        # Keep only the rows that came from ``input1``.
        return normalised[:n_target]


class CrossStitch(nn.Module):
    """Learned weighted sum of two feature maps.

    Holds a two-element parameter ``wt`` initialised to ``[0.5, 0.5]``; the
    output is ``wt[0] * input1 + wt[1] * input2``. ``planes`` is accepted for
    signature parity with the other transfer modules but is not used.
    """

    def __init__(self,tg_layer_id,src_layer_id, planes):
        super().__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.wt = nn.Parameter(torch.full((2,), 0.5))

    def forward(self, input1, input2):
        target_weight, source_weight = self.wt[0], self.wt[1]
        return target_weight * input1 + source_weight * input2


class LinearXStitch(nn.Module):
    """Cross-gated sum of two feature maps.

    Each map is average-pooled to one vector per sample and fed to a linear
    layer that yields one scalar gate per sample. The gates are applied
    crosswise: the gate computed from ``input2`` scales ``input1`` and the
    gate computed from ``input1`` scales ``input2``. Both linear biases start
    at 1.0.
    """

    def __init__(self,tg_layer_id,src_layer_id, tg_feat_size=48, src_feat_size=48):
        super().__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.su = nn.Linear(src_feat_size, 1)
        self.tu = nn.Linear(tg_feat_size, 1)
        for gate in (self.su, self.tu):
            nn.init.constant_(gate.bias, 1.0)

    def forward(self, input1, input2):
        # Pooling kernel equals the height (dim 2) of each map.
        pooled_target = F.avg_pool2d(input1, input1.size(2)).view(-1, input1.size(1))
        pooled_source = F.avg_pool2d(input2, input2.size(2)).view(-1, input2.size(1))

        gate_for_target = self.su(pooled_source).reshape(input1.size(0), 1, 1, 1)
        gate_for_source = self.tu(pooled_target).reshape(input2.size(0), 1, 1, 1)
        return gate_for_target * input1 + gate_for_source * input2


class LinearStitch(nn.Module):
    """Self-gated sum of two feature maps.

    Each map is average-pooled to one vector per sample and fed to a linear
    layer that yields one scalar gate per sample. Every map is scaled by the
    gate computed from its own pooled features, then the two are summed. Both
    linear biases start at 1.0.
    """

    def __init__(self,tg_layer_id,src_layer_id, tg_feat_size=48, src_feat_size=48):
        super().__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.su = nn.Linear(src_feat_size, 1)
        self.tu = nn.Linear(tg_feat_size, 1)
        for gate in (self.su, self.tu):
            nn.init.constant_(gate.bias, 1.0)

    def forward(self, input1, input2):
        # Pooling kernel equals the height (dim 2) of each map.
        pooled_target = F.avg_pool2d(input1, input1.size(2)).view(-1, input1.size(1))
        pooled_source = F.avg_pool2d(input2, input2.size(2)).view(-1, input2.size(1))

        target_gate = self.tu(pooled_target).reshape(input1.size(0), 1, 1, 1)
        source_gate = self.su(pooled_source).reshape(input2.size(0), 1, 1, 1)
        return target_gate * input1 + source_gate * input2


class ShareableResidualLayer(nn.Sequential):
    """Residual stage whose output can be fused with source-network features.

    The stage is a stack of ``blocks`` residual blocks (stored as the
    sequential children ``'0'``, ``'1'``, ...). When ``transfer_types`` is
    given, extra modules are created:

    * ``ru`` -- one 1x1 conv + batch-norm per source feature that projects
      the source channels to ``planes`` (only when ``ru_units`` is true);
    * ``cs`` -- a ``ModuleDict`` keyed by transfer type holding the modules
      that combine the stage output with the (projected) source feature.

    Args:
        id: Index of this stage (also used to pick a combiner after the
            expectation over all sources, see ``forward``).
        block: Residual block class; must expose ``expansion``.
        inplanes: Channels entering the stage.
        planes: Base channel count of the blocks.
        blocks: Number of residual blocks.
        block_hidden_dim: Side length handed to ``AttentionBasedTransfer``.
        stride: Stride of the first block.
        decisioner: Object exposing ``probabilities`` (a numpy array) used to
            average over all source features.
        transfer_types: Iterable of transfer-type names to build modules for.
        source_input_info: ``(source_feature_ids, source_input_channels)``;
            required when ``transfer_types`` is given.
        ru_units: Whether to build and apply the ``ru`` projections.
    """

    def __init__(self, id, block, inplanes, planes, blocks, block_hidden_dim,stride=1, decisioner=None,transfer_types=None, source_input_info=None, ru_units=True):
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the resolution or the channel count.
        downsample = None
        if stride != 1 or inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(inplanes, planes, stride, downsample)]
        for _ in range(1, blocks):
            layers.append(block(out_planes, planes))

        # Initialise the Module machinery before touching any attribute.
        super(ShareableResidualLayer, self).__init__(*layers)

        self.id = id
        self.transfer_types = transfer_types
        self.inplanes = out_planes
        self.num_blocks = blocks
        self.block_hidden_dim = block_hidden_dim
        self.decisioner = decisioner
        self.route_wts = []
        self.ru_units = ru_units

        if self.ru_units:
            self.ru = nn.ModuleList([])
        self.cs = nn.ModuleDict({})
        self.id2mod = None
        if transfer_types:
            print('Layer ' + str(id) + ' using ' + str(transfer_types) + '.')
            assert source_input_info is not None
            source_feature_ids, source_input_channels = source_input_info
            # Maps a source feature id to the index of its module in ru / cs.
            self.id2mod = {ids: i for i, ids in enumerate(source_feature_ids)}
            self.id2mod[-1] = -1
            self.source_input_channels = source_input_channels  # e.g. [64, 64, 128, 256, 512]
            n_sources = len(self.source_input_channels)

            if self.ru_units:
                self.ru.extend([nn.Sequential(
                    nn.Conv2d(src_planes, planes, 1),
                    nn.BatchNorm2d(planes),
                ) for src_planes in self.source_input_channels])

            # NOTE(review): the combiners below are sized with ``inplanes``
            # (the stage's input channels) although forward() applies them to
            # the stage output -- confirm for stages where these differ.
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # module creation (and RNG use) is reproducible across runs,
            # unlike iterating a set of strings.
            for transfer_type in dict.fromkeys(self.transfer_types):
                if transfer_type == 'combine': # No RU
                    self.cs['combine'] = nn.ModuleList([nn.Sequential(
                        nn.Conv2d(inplanes+src_planes, inplanes, 1),
                        nn.BatchNorm2d(inplanes),
                    ) for src_planes in self.source_input_channels])
                elif transfer_type == 'shared':
                    self.cs['shared'] = nn.ModuleList([nn.Sequential(
                        nn.Conv2d(inplanes*2, inplanes, 1),
                        nn.BatchNorm2d(inplanes))
                        for _ in range(n_sources)])
                elif transfer_type == 'simpleblock':
                    # Registered under the transfer-type name: forward() looks
                    # the module up as cs['simpleblock'] (it used to be stored
                    # as 'residual', which made that lookup a KeyError).
                    self.cs['simpleblock'] = TestBlock(inplanes*2, inplanes, 1)
                elif transfer_type == 'block':
                    self.cs['block'] = nn.ModuleList([TestBlock(inplanes*2, inplanes, 1)
                                                     for _ in range(n_sources)])
                elif transfer_type == 'transnorm':
                    self.cs['transnorm'] = nn.ModuleList(
                        [Transnorm(id, i, inplanes) for i in range(n_sources)])
                elif transfer_type == 'xstitch':
                    self.cs['xstitch'] = nn.ModuleList(
                        [CrossStitch(id, i, inplanes) for i in range(n_sources)])
                elif transfer_type == 'linstitch':
                    self.cs['linstitch'] = nn.ModuleList(
                        [LinearXStitch(id, i, inplanes, inplanes) for i in range(n_sources)])
                elif transfer_type == 'attention':
                    # sample or avg type
                    self.cs['attention'] = AttentionBasedTransfer(id, self.block_hidden_dim,n_head=8,type='sample')
                elif transfer_type == 'routenorm':
                    self.cs['routenorm'] = nn.ModuleList(
                        [RouteNorm(inplanes) for _ in range(n_sources)])

    def forward(self, input, src_input=None):
        """Run the residual blocks, then optionally fuse a source feature.

        Args:
            input: Feature map entering the stage.
            src_input: ``None`` for a plain pass, or a tuple
                ``(source_feature_id, transfer_type, ref_input)`` where
                ``ref_input`` is a tensor, a list of tensors (one per source
                feature), or ``None`` for no transfer.

        Returns:
            The stage output, fused with the reference feature when one is
            given and ``transfer_type`` is recognised; otherwise the plain
            block output.
        """
        ref_id, transfer_type, ref_input = -1, '', None
        if src_input is not None:
            src_id, transfer_type, ref_input = src_input
            ref_id = self.id2mod[src_id] if self.id2mod else -1

        # Only the first num_blocks children are residual blocks; 'ru' and
        # 'cs' are registered on this Sequential too and must be skipped.
        for idx in range(self.num_blocks):
            input = self._modules[str(idx)](input)

        if ref_input is None:
            return input

        if self.ru_units:
            if not isinstance(ref_input, list):
                ref_input = interpolateFeatures(ref_input, input.size(3)) # Matches the last two dimensions
                ref_input = self.ru[ref_id](ref_input)
            else:
                ref_input = [interpolateFeatures(rin, input.size(3)) for rin in ref_input]
                ref_input = [rproj(rin) for (rin, rproj) in zip(ref_input, self.ru)]
                ref_input.append(torch.zeros_like(input)) # PASS option or sentinel attention

        if ref_id == -1 and transfer_type != 'attention':
            # Take the expectation over all reference features, weighted by
            # the decisioner's probabilities.
            ref_input = torch.stack(ref_input,dim=-1)
            p = torch.from_numpy(self.decisioner.probabilities).float().to(input.device)
            ref_input = torch.matmul(ref_input,p)
            ref_id = self.id-1

        output = input
        if transfer_type == 'combine':
            output = torch.cat([input,ref_input],1) # Input channels + Ref_input Channels
            output = self.cs['combine'][ref_id](output)
        elif transfer_type == 'shared':
            # NOTE(review): with ru_units enabled ref_input was already
            # projected above, so this applies ru[ref_id] a second time (and
            # needs self.ru to exist) -- confirm this is intended.
            output = torch.cat([input,self.ru[ref_id](ref_input)],1) # Input channels + Ref_input Channels
            output = self.cs['shared'][ref_id](output)
        elif transfer_type == 'simpleblock':
            output = self.cs['simpleblock'](input, ref_input)
        elif transfer_type in ['block','transnorm', 'xstitch', 'linstitch', 'routenorm']:
            output = self.cs[transfer_type][ref_id](input, ref_input)
        elif transfer_type in ['attention']:
            output = self.cs[transfer_type](input, ref_input)
        return output


class ResNet(nn.Module):
    """ResNet backbone with three operating modes.

    * Plain: ``transfer_types`` is falsy -- four ordinary residual stages.
    * Spottune: ``transfer_types`` is set but ``source_info`` is not -- each
      stage has a parallel copy and ``forward`` blends the two per sample
      using ``weights``.
    * Auto-transfer: both are set -- stages are ``ShareableResidualLayer``
      instances that can fuse features of a source network.

    Args:
        block: Residual block class (must expose ``expansion``).
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the final linear classifier.
        name: Architecture name handed to ``_get_num_features`` to look up
            the channel counts of the source features.
        transfer_types: Transfer-type names (enables the non-plain modes).
        source_info: ``(source_feature_ids, source_input_pass_id,
            decisioners)`` for auto-transfer mode.
        ru_units: Forwarded to ``ShareableResidualLayer``.
    """

    def __init__(self, block, layers, num_classes=1000, name="resnet18", transfer_types=None, source_info=None, ru_units=True):
        super(ResNet, self).__init__()

        self.num_classes = num_classes
        self.inplanes = 64
        self.nlayers = len(layers)

        # Stem: 7x7 stride-2 conv, batch-norm, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if transfer_types:
            if source_info:
                sic = _get_num_features(name)
                self.source_feature_ids, self.source_input_pass_id, decisioners = source_info
                self.source_input_channels = [sic[ids] for ids in self.source_feature_ids]
                source_input_info = (self.source_feature_ids, self.source_input_channels)
                self.transfer_types = transfer_types
                self.ru_units = ru_units

                # NOTE(review): block_hidden_dim is 56/56/28/14 here, while
                # ShareableResidualLayer.forward applies the transfer module
                # to the stage *output* (56/28/14/7 for a 224x224 input) --
                # confirm the sizes used by the 'attention' transfer for
                # stages 2-4.
                self.layer1 = ShareableResidualLayer(1, block, self.inplanes, 64, layers[0],block_hidden_dim=56,
                                                     decisioner=decisioners[0],
                                                     transfer_types=self.transfer_types,
                                                     source_input_info=source_input_info,
                                                     ru_units=self.ru_units)
                self.inplanes = 64 * block.expansion
                self.layer2 = ShareableResidualLayer(2, block, self.inplanes, 128, layers[1], stride=2,block_hidden_dim=56,
                                                     decisioner=decisioners[1],
                                                     transfer_types=self.transfer_types,
                                                     source_input_info=source_input_info,
                                                     ru_units=self.ru_units)
                self.inplanes = 128 * block.expansion
                self.layer3 = ShareableResidualLayer(3, block, self.inplanes, 256, layers[2], stride=2,block_hidden_dim=28,
                                                     decisioner=decisioners[2],
                                                     transfer_types=self.transfer_types,
                                                     source_input_info=source_input_info,
                                                     ru_units=self.ru_units)
                self.inplanes = 256 * block.expansion
                self.layer4 = ShareableResidualLayer(4, block, self.inplanes, 512, layers[3], stride=2,block_hidden_dim=14,
                                                     decisioner=decisioners[3],
                                                     transfer_types=self.transfer_types,
                                                     source_input_info=source_input_info,
                                                     ru_units=self.ru_units)
                self.inplanes = 512 * block.expansion
            else: # Spottune
                inplanes = self.inplanes
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

                # _make_layer advances self.inplanes; rewind it so the
                # parallel stages get the same channel layout.
                self.inplanes = inplanes
                self.parallel_layer1 = self._make_layer(block, 64, layers[0])
                self.parallel_layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.parallel_layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.parallel_layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.parallel_layers = [self.parallel_layer1, self.parallel_layer2, self.parallel_layer3, self.parallel_layer4]
        else:
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Plain list for indexed access; the stages are already registered as
        # submodules through the attributes above.
        self.layers = [self.layer1,self.layer2,self.layer3,self.layer4]

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialise every conv (normal, std = sqrt(2 / fan_out)) and
        # every 2-d batch-norm (weight 1, bias 0) in the model.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a plain residual stage and advance ``self.inplanes``."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x, ref_features=None, pairs=None, weights=None):
        """Classify ``x`` and return the intermediate feature maps.

        Args:
            x: Input image batch.
            ref_features: Source-network features indexed by source feature
                id; selects auto-transfer mode when given.
            pairs: One ``(target_layer_id, source_feature_id,
                transfer_type)`` triple per stage; required with
                ``ref_features``.
            weights: Per-sample, per-stage blend factors for Spottune mode
                (column ``t`` weights the parallel copy of stage ``t``).

        Returns:
            ``(logits, [r0, f1, f2, f3, f4])`` where ``r0`` is the stem
            output after ReLU and ``f1``..``f4`` are the stage outputs.
        """
        c0 = self.conv1(x)
        b0 = self.bn1(c0)
        r0 = self.relu(b0)
        p0 = self.maxpool(r0)
        input = p0
        s_feat = []
        if ref_features is None:
            if weights is None:
                f1 = self.layer1(p0)
                f2 = self.layer2(f1)
                f3 = self.layer3(f2)
                f4 = self.layer4(f3)
            else:  # Spottune
                for t, (layer, parallel_layer) in enumerate(zip(self.layers, self.parallel_layers)):
                    action = weights[:, t].contiguous()
                    action_mask = action.float().view(-1, 1, 1, 1)
                    output = layer(input)
                    parallel_output = parallel_layer(input)
                    # Per-sample blend of the main and the parallel stage.
                    input = output * (1 - action_mask) + parallel_output * action_mask
                    s_feat.append(input)
                [f1, f2, f3, f4] = s_feat

        else: # Auto-transfer
            assert pairs is not None
            # pairs format (target_layer_id, source_feature_id, transfer_type)
            for l, (target_layer_id, source_feature_id, transfer_type) in enumerate(pairs):
                ref_input = None
                if source_feature_id != self.source_input_pass_id:
                    if source_feature_id == -1:
                        # All source features; the stage averages over them.
                        ref_input = [ref_features[sid] for sid in self.source_feature_ids]
                    else:
                        ref_input = ref_features[source_feature_id]
                else:
                    source_feature_id = self.source_feature_ids[l]
                    if transfer_type in ['combine', 'shared', 'block']:
                        # Zero placeholder so these combiners still get a
                        # second input. Allocate it on the activation's own
                        # device rather than a module-level default, so the
                        # model also works on CPU or a non-default GPU.
                        ref_input = torch.zeros(
                            [input.size(0), self.source_input_channels[l], input.size(2), input.size(3)],
                            dtype=input.dtype, device=input.device)
                input = self.layers[l](input, (source_feature_id, transfer_type, ref_input))
                s_feat.append(input)
            [f1, f2, f3, f4] = s_feat

        f5 = self.avgpool(f4)
        f5 = f5.view(f5.size(0), -1)
        out = self.fc(f5)
        return out, [r0, f1, f2, f3, f4]

    def forward_with_features(self, x):
        """Alias of ``forward(x)`` for callers that want logits and features."""
        return self.forward(x)


def resnet10(pretrained=False, **kwargs):
    """Build a ResNet-10: ``BasicBlock`` with one block in each stage.

    Args:
        pretrained (bool): No pretrained weights exist for this variant; when
            true, a notice is printed and the untrained model is returned.
        **kwargs: Passed through to ``ResNet``.
    """
    model = ResNet(BasicBlock, [1, 1, 1, 1], name="resnet10", **kwargs)
    if not pretrained:
        return model
    print('Pretraining not supported now')
    return model


def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18: ``BasicBlock`` with stage depths ``[2, 2, 2, 2]``.

    Args:
        pretrained (bool): If True, download the checkpoint registered under
            ``model_urls['resnet18']`` (ImageNet weights) and load it.
        **kwargs: Passed through to ``ResNet``.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], name="resnet18", **kwargs)
    if not pretrained:
        return model
    checkpoint = model_zoo.load_url(model_urls['resnet18'])
    model.load_state_dict(checkpoint)
    return model


def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34: ``BasicBlock`` with stage depths ``[3, 4, 6, 3]``.

    Args:
        pretrained (bool): If True, download the checkpoint registered under
            ``model_urls['resnet34']`` (ImageNet weights) and load it.
        **kwargs: Passed through to ``ResNet``.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], name="resnet34", **kwargs)
    if not pretrained:
        return model
    checkpoint = model_zoo.load_url(model_urls['resnet34'])
    model.load_state_dict(checkpoint)
    return model


def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50: ``Bottleneck`` with stage depths ``[3, 4, 6, 3]``.

    Args:
        pretrained (bool): If True, download the checkpoint registered under
            ``model_urls['resnet50']`` (ImageNet weights) and load it.
        **kwargs: Passed through to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], name="resnet50", **kwargs)
    if not pretrained:
        return model
    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    model.load_state_dict(checkpoint)
    return model


def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101: ``Bottleneck`` with stage depths ``[3, 4, 23, 3]``.

    Args:
        pretrained (bool): If True, download the checkpoint registered under
            ``model_urls['resnet101']`` (ImageNet weights) and load it.
        **kwargs: Passed through to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], name="resnet101", **kwargs)
    if not pretrained:
        return model
    checkpoint = model_zoo.load_url(model_urls['resnet101'])
    model.load_state_dict(checkpoint)
    return model


def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152: ``Bottleneck`` with stage depths ``[3, 8, 36, 3]``.

    Args:
        pretrained (bool): If True, download the checkpoint registered under
            ``model_urls['resnet152']`` (ImageNet weights) and load it.
        **kwargs: Passed through to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], name="resnet152", **kwargs)
    if not pretrained:
        return model
    checkpoint = model_zoo.load_url(model_urls['resnet152'])
    model.load_state_dict(checkpoint)
    return model
