import torch.nn as nn
import math

from models.utils.ResNetConfig import get_resnet_config
from models.utils.utils import get_activation_function, make_conv_block


class PreActResNet(nn.Module):
    """ResNet-style classifier whose layout is taken from ``get_resnet_config``.

    The network is assembled as::

        layer0 -> layer1 ... layerN -> last -> avgpool -> flatten -> fc

    ``layer0`` is a dataset-dependent stem (see ``_make_layer0``).
    ``layer1``..``layerN`` are stages built by ``_make_layer``, one per entry of
    the configured plane list. ``last`` is a final conv block and ``fc`` is the
    linear classifier.

    Args:
        depth: Network depth, forwarded to ``get_resnet_config``.
        num_classes: Number of output features of the final ``nn.Linear``.
        activation_type: Forwarded to ``get_activation_function``.
        oper_order: Forwarded to ``get_resnet_config``. The mapping it returns
            is indexed with keys such as ``'front1'``, ``'full'`` and ``'end2'``,
            and the values are passed on as ``oper_order`` to the conv blocks.
        dataset: One of ``'cifar10'``, ``'cifar100'``, ``'tinyImageNet'`` or
            ``'ImageNet'``. It selects the stem built by ``_make_layer0``.
        block_type: Forwarded to ``get_resnet_config`` to choose the block class.
            That class must expose an ``expansion`` attribute and accept
            ``(inplanes, planes, activation_generator, oper_order, stride, shortcut)``.
        tau: Forwarded to ``get_activation_function``.
    """

    def __init__(self, depth, num_classes, activation_type, oper_order='cba', dataset='cifar10', block_type='Basic',
                 tau=None):
        super(PreActResNet, self).__init__()

        self.block_depth, self.plane_list, self.strides, self.oper_order, block = \
            get_resnet_config(arch=type(self).__name__, block_type=block_type,
                              depth=depth, dataset=dataset, oper_order=oper_order)

        self.activation_generator = get_activation_function(activation_type, tau=tau)
        self.dataset = dataset
        # Running input-channel count; _make_layer advances it after every block.
        self.inplanes = self.plane_list[0]

        # Both attributes hold the number of stages; kept for external callers.
        self.layer_num = len(self.plane_list)
        self.layer_depth = len(self.plane_list)

        # The stem depends only on self.dataset, so nothing is passed here
        # (passing `depth` used to raise TypeError on every construction).
        self.layer0 = self._make_layer0()
        for layer_ind in range(self.layer_depth):
            # Registered as layer1..layerN so forward() can fetch them by name.
            setattr(self, 'layer' + str(layer_ind + 1), self._make_layer(
                block, self.plane_list[layer_ind], self.block_depth[layer_ind], self.strides[layer_ind]))
        self.last = self._make_last_layer()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.plane_list[-1] * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Fan-out He initialisation: std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when the BatchNorm was built with affine=False.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer0(self):
        """Build the input stem for ``self.dataset`` (3 input channels).

        ``'cifar10'``, ``'cifar100'`` and ``'tinyImageNet'`` get a 3x3, stride-1
        conv block using ``oper_order['front1']``. ``'ImageNet'`` gets a 7x7,
        stride-2 conv block using ``oper_order['full']``, followed by a 3x3,
        stride-2 max-pool.

        Raises:
            ValueError: If ``self.dataset`` is not one of the names above.
        """
        if self.dataset in ('cifar10', 'cifar100', 'tinyImageNet'):
            return make_conv_block(3, self.inplanes, self.activation_generator, kernel_size=3, stride=1,
                                   padding=1, oper_order=self.oper_order['front1'])

        if self.dataset == "ImageNet":
            layer0 = make_conv_block(3, self.inplanes, self.activation_generator, kernel_size=7, stride=2,
                                     padding=3, oper_order=self.oper_order['full'])
            # NOTE(review): relies on make_conv_block returning a container
            # that supports add_module (e.g. nn.Sequential) — confirm.
            layer0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            return layer0

        raise ValueError("Unsupported dataset %r; expected one of 'cifar10', 'cifar100', "
                         "'tinyImageNet', 'ImageNet'" % (self.dataset,))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential`` stage.

        Only the first block uses ``stride``; the remaining blocks use stride 1.
        A block gets a 1x1 conv shortcut when its stride is not 1 or when its
        input channels differ from ``planes * block.expansion``. Otherwise it
        gets ``shortcut=None``. As a side effect, ``self.inplanes`` is advanced
        to ``planes * block.expansion``.

        Args:
            block: Block class (needs ``expansion`` and the constructor used below).
            planes: Base channel count of the stage.
            blocks: Number of blocks; values below 1 still yield one block.
            stride: Stride of the first block.

        Returns:
            nn.Sequential containing the constructed blocks.
        """
        out_planes = planes * block.expansion
        layers = []

        for block_stride in [stride] + [1] * (blocks - 1):
            shortcut = None
            if block_stride != 1 or self.inplanes != out_planes:
                shortcut = make_conv_block(self.inplanes, out_planes, self.activation_generator,
                                           kernel_size=1, stride=block_stride, padding=0,
                                           oper_order=self.oper_order['front1'])

            layers.append(block(self.inplanes, planes, self.activation_generator, self.oper_order,
                                block_stride, shortcut))
            self.inplanes = out_planes

        return nn.Sequential(*layers)

    def _make_last_layer(self):
        """Build the final conv block (``self.inplanes`` -> ``self.inplanes``).

        Uses ``oper_order['end2']``. All other settings are
        ``make_conv_block``'s defaults.
        """
        return make_conv_block(self.inplanes, self.inplanes, self.activation_generator,
                               oper_order=self.oper_order['end2'])

    def forward(self, x):
        """Run stem, stages, final block, global average pool and classifier.

        Returns:
            Output of ``self.fc`` with shape ``(batch, fc.out_features)``.
        """
        for i in range(self.layer_depth + 1):
            x = getattr(self, 'layer' + str(i))(x)
        x = self.last(x)

        x = self.avgpool(x)
        # flatten works on non-contiguous tensors too, unlike view.
        x = x.flatten(1)

        return self.fc(x)