import torch
import torch.nn as nn


# Names exported by a star-import of this module: only the four backbone
# factory functions. FeatureRefiner and the ExtendedResNet* classes defined
# further down are not listed here, so they must be imported explicitly.
__all__ = ['resnet20', 'resnet32', 'resnet44', 'resnet56']

# NOTE(review): not referenced anywhere else in the visible part of this file;
# presumably kept for external callers -- confirm before removing.
NUM_CLASSES = 10


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        A freshly constructed ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Two-convolution residual block whose convolutions can be gated off.

    ``block_gates`` is a two-element sequence of booleans that is stored by
    reference (not copied), so flipping an entry on the caller's side changes
    the behaviour of later forward passes:

    * ``block_gates[0]`` enables ``conv1 -> bn1 -> relu1``;
    * ``block_gates[1]`` enables ``conv2 -> bn2``.

    Whatever the gates select, the (optionally down-sampled) input is added
    to the result and passed through ``relu2``.
    """

    expansion = 1

    def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            block_gates: two booleans enabling the first / second conv stage.
            inplanes: number of input channels.
            planes: number of output channels of both convolutions.
            stride: stride of the first convolution.
            downsample: optional module applied to the input to produce the
                residual branch; ``None`` means the input is used as is.
        """
        super(BasicBlock, self).__init__()
        self.block_gates = block_gates
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=False)  # To enable layer removal inplace must be False
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the enabled stages, add the residual and apply the final ReLU."""
        residual = out = x

        if self.block_gates[0]:
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu1(out)

        if self.block_gates[1]:
            out = self.conv2(out)
            out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # Out-of-place add on purpose: when a gate is disabled ``out`` may be
        # the caller's input tensor itself or the output of ``relu1`` (which
        # autograd keeps for the backward pass). An in-place ``+=`` would
        # silently overwrite the former and invalidate the latter.
        out = out + residual
        out = self.relu2(out)

        return out


class ResNetCifar(nn.Module):
    """Three-stage ResNet feature extractor with 16/32/64 channel widths.

    The network consists of a 3x3 stem convolution, three stages built from
    ``block`` (the second and third with stride 2), an optional extra 3x3
    convolution, an 8x8 average pool and an optional extra 64->64 linear
    layer. There is no classification layer: ``forward`` returns the
    flattened pooled features.

    ``layer_gates[stage][block_idx]`` is a two-element list of booleans that
    is handed (by reference) to the corresponding block as its first
    constructor argument.
    """

    def __init__(self, block, layers, extra_conv, extra_fc):
        """
        Args:
            block: block class; called as
                ``block(gates, inplanes, planes, stride, downsample)`` and
                expected to expose an ``expansion`` class attribute.
            layers: sequence whose first three entries give the number of
                blocks in each stage.
            extra_conv: if truthy, add a 64->64 3x3 convolution after stage 3.
            extra_fc: if truthy, add a 64->64 linear layer after pooling.
        """
        # nn.Module has to be initialised before attributes are assigned.
        super(ResNetCifar, self).__init__()
        self.extra_conv = extra_conv
        self.extra_fc = extra_fc
        self.nlayers = 0
        # One independent [True, True] gate pair per block, grouped by stage.
        self.layer_gates = [
            [[True, True] for _ in range(layers[stage])]
            for stage in range(3)
        ]

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0])
        self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2)
        if extra_conv:
            self.conv1000 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        if extra_fc:
            self.fc1000 = nn.Linear(64, 64)
        # NOTE(review): the fixed 8x8 window presumably targets 32x32 inputs,
        # for which two stride-2 stages leave an 8x8 map -- confirm.
        self.avgpool = nn.AvgPool2d(8, stride=1)

        # Re-initialise every convolution (Kaiming normal, fan-out) and reset
        # every 2-D batch norm to weight 1 / bias 0. Linear layers keep the
        # PyTorch default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and update ``self.inplanes``.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1-conv + batch-norm ``downsample`` module
        for its residual branch.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(layer_gates[0], self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for i in range(1, blocks):
            stage.append(block(layer_gates[i], self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the flattened, average-pooled features of ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        if self.extra_conv:
            x = self.conv1000(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        if self.extra_fc:
            x = self.fc1000(x)

        return x


def resnet20(extra_conv=False, extra_fc=False):
    """Return a ``ResNetCifar`` built from ``BasicBlock`` with 3 blocks per stage."""
    return ResNetCifar(BasicBlock, [3, 3, 3], extra_conv, extra_fc)


def resnet32(extra_conv=False, extra_fc=False):
    """Return a ``ResNetCifar`` built from ``BasicBlock`` with 5 blocks per stage."""
    return ResNetCifar(block=BasicBlock, layers=[5, 5, 5],
                       extra_conv=extra_conv, extra_fc=extra_fc)


def resnet44(extra_conv=False, extra_fc=False):
    """Return a ``ResNetCifar`` built from ``BasicBlock`` with 7 blocks per stage."""
    return ResNetCifar(block=BasicBlock, layers=[7, 7, 7],
                       extra_conv=extra_conv, extra_fc=extra_fc)


def resnet56(extra_conv=False, extra_fc=False):
    """Return a ``ResNetCifar`` built from ``BasicBlock`` with 9 blocks per stage."""
    return ResNetCifar(block=BasicBlock, layers=[9, 9, 9],
                       extra_conv=extra_conv, extra_fc=extra_fc)


class FeatureRefiner(nn.Module):
    """Stack of linear layers with optional batch norm and ReLU in between.

    Layer ``i`` (1-based) is registered as ``weight<i>`` and maps
    ``input_dims[i - 1]`` features to ``input_dims[i]`` features. Every layer
    except the last one is optionally followed by ``bn<i>`` and ``relu<i>``
    (applied in that order). An optional ``bn0`` normalises the input.
    """

    def __init__(self, bias, relu, batch_norm, num_layers, input_dims, num_classes, num_per_class, bn_start, **kwargs):
        """
        Args:
            bias: whether the linear layers have a bias term.
            relu: whether to apply a ReLU after every non-final layer.
            batch_norm: whether to apply batch norm after every non-final layer.
            num_layers: number of linear layers.
            input_dims: sequence of ``num_layers + 1`` feature widths.
            num_classes: multiplied with ``num_per_class`` to give
                ``num_samples``.
            num_per_class: see ``num_classes``.
            bn_start: whether to batch-normalise the input features.
            **kwargs: must contain ``'weight_decays'``, ``'scaling'`` and
                ``'dist'`` (see ``_initialize_weights``); other keys are
                ignored.

        Raises:
            KeyError: if a required keyword is missing.
            NotImplementedError: if ``dist`` is not a supported scheme.
        """
        super(FeatureRefiner, self).__init__()

        self.num_samples = num_classes * num_per_class
        self.num_per_class = num_per_class
        self.num_layers = num_layers
        self.relu = relu
        self.batch_norm = batch_norm
        self.input_dims = input_dims
        self.bias = bias
        self.bn_start = bn_start
        self.wds = kwargs['weight_decays']

        # Attribute names of the sub-modules. ReLU / batch norm exist only for
        # the non-final layers; the lists are None when the feature is off.
        non_final = range(1, num_layers)
        self.weight_names = ['weight' + str(idx) for idx in range(1, num_layers + 1)]
        self.relu_names = ['relu' + str(idx) for idx in non_final] if relu else None
        self.batch_norm_names = ['bn' + str(idx) for idx in non_final] if batch_norm else None

        # Defining the layers
        if self.bn_start:
            setattr(self, 'bn0', nn.BatchNorm1d(num_features=self.input_dims[0]))
        for layer_num in range(num_layers):
            out_features = self.input_dims[layer_num + 1]
            setattr(self, self.weight_names[layer_num],
                    nn.Linear(in_features=self.input_dims[layer_num],
                              out_features=out_features,
                              bias=bias))
            if layer_num == num_layers - 1:
                continue
            if relu:
                setattr(self, self.relu_names[layer_num], nn.ReLU())
            if batch_norm:
                # The norm sees the output of this linear layer, so it must be
                # as wide as that output (not as wide as the layer index).
                setattr(self, self.batch_norm_names[layer_num],
                        nn.BatchNorm1d(num_features=out_features))

        self._initialize_weights(**kwargs)

    def _normalize_input(self, features):
        """Apply ``bn0`` when input normalisation is enabled."""
        if self.bn_start:
            features = getattr(self, 'bn0')(features)
        return features

    def _apply_layer(self, features, layer_num):
        """Apply linear layer ``layer_num`` plus its batch norm / ReLU, if any."""
        features = getattr(self, self.weight_names[layer_num])(features)
        if layer_num != self.num_layers - 1:
            if self.batch_norm:
                features = getattr(self, self.batch_norm_names[layer_num])(features)
            if self.relu:
                features = getattr(self, self.relu_names[layer_num])(features)
        return features

    def forward(self, features):
        """Run ``features`` through every layer and return the result."""
        features = self._normalize_input(features)
        for layer_num in range(self.num_layers):
            features = self._apply_layer(features, layer_num)

        return features

    def forward_until_layer(self, features, until_layer_n):
        """Run only the first ``until_layer_n`` layers (norm / ReLU included)."""
        features = self._normalize_input(features)
        for layer_num in range(until_layer_n):
            features = self._apply_layer(features, layer_num)

        return features

    def forward_until_layer_clean(self, features, until_layer_n):
        """Like ``forward_until_layer`` but stop right after the last linear map.

        The first ``until_layer_n - 1`` layers are applied in full; layer
        ``until_layer_n`` contributes only its linear transform, without the
        batch norm / ReLU that would normally follow it. With
        ``until_layer_n == 0`` only the optional input norm is applied.
        """
        features = self._normalize_input(features)
        for layer_num in range(until_layer_n - 1):
            features = self._apply_layer(features, layer_num)

        if until_layer_n >= 1:
            features = getattr(self, self.weight_names[until_layer_n - 1])(features)

        return features

    def _initialize_weights(self, **kwargs):
        """(Re-)initialise the weight matrices of all linear layers.

        Reads ``kwargs['weight_decays']``, ``kwargs['scaling']`` and
        ``kwargs['dist']``:

        * ``'gaussian'``: ``scaling * N(0, 1) / sqrt(weight_decays[i + 1])``
          for the layer with 0-based index ``i`` (entry 0 of
          ``weight_decays`` is not used here);
        * ``'kaiming'``: Kaiming-uniform (ReLU gain) multiplied by ``scaling``.

        Biases are left untouched.

        Raises:
            NotImplementedError: for any other ``dist`` value.
        """
        weight_decays = kwargs['weight_decays']
        scaling = kwargs['scaling']
        dist = kwargs['dist']
        if dist == 'gaussian':
            with torch.no_grad():
                for layer_idx in range(self.num_layers):
                    layer = getattr(self, self.weight_names[layer_idx])
                    layer.weight.data = \
                        scaling * torch.randn_like(layer.weight.data) / \
                        weight_decays[layer_idx + 1] ** (1 / 2)
        elif dist == 'kaiming':
            with torch.no_grad():
                for layer_idx in range(self.num_layers):
                    layer = getattr(self, self.weight_names[layer_idx])
                    nn.init.kaiming_uniform_(layer.weight.data, nonlinearity='relu')
                    layer.weight.data *= scaling
        else:
            raise NotImplementedError

    def perturb_weights(self, scaling):
        """Add Gaussian noise to every weight matrix, in place.

        For each layer the noise is
        ``scaling * N(0, 1) * ||W||_F / sqrt(number of entries of W)``,
        i.e. it is scaled by the root-mean-square of the current weights.
        """
        with torch.no_grad():
            for layer_idx in range(self.num_layers):
                weight = getattr(self, self.weight_names[layer_idx]).weight.data
                # A plain Python scalar keeps this device-agnostic (a 1-element
                # CPU tensor cannot be combined with CUDA weights).
                root_count = weight.numel() ** 0.5
                # For a 2-D input the default of torch.linalg.norm is the
                # Frobenius norm; evaluated before the in-place update below.
                frobenius = torch.linalg.norm(weight)
                weight += scaling * torch.randn_like(weight) / root_count * frobenius


class ExtendedResNet20(nn.Module):
    """``resnet20`` backbone followed by a ``FeatureRefiner`` head.

    ``kwargs['extra_conv']`` and ``kwargs['extra_fc']`` configure the
    backbone; the complete ``kwargs`` dict is also forwarded to
    ``FeatureRefiner``.
    """

    def __init__(self, **kwargs):
        super(ExtendedResNet20, self).__init__()
        use_extra_conv = kwargs['extra_conv']
        use_extra_fc = kwargs['extra_fc']
        self.backbone = resnet20(use_extra_conv, use_extra_fc)
        self.fr = FeatureRefiner(**kwargs)

    def forward(self, inputs):
        """Extract backbone features and refine them with the head."""
        return self.fr(self.backbone(inputs))


class ExtendedResNet32(nn.Module):
    """``resnet32`` backbone followed by a ``FeatureRefiner`` head.

    ``kwargs['extra_conv']`` and ``kwargs['extra_fc']`` configure the
    backbone; the complete ``kwargs`` dict is also forwarded to
    ``FeatureRefiner``.
    """

    def __init__(self, **kwargs):
        super(ExtendedResNet32, self).__init__()
        use_extra_conv = kwargs['extra_conv']
        use_extra_fc = kwargs['extra_fc']
        self.backbone = resnet32(use_extra_conv, use_extra_fc)
        self.fr = FeatureRefiner(**kwargs)

    def forward(self, inputs):
        """Extract backbone features and refine them with the head."""
        return self.fr(self.backbone(inputs))


class ExtendedResNet44(nn.Module):
    """``resnet44`` backbone followed by a ``FeatureRefiner`` head.

    ``kwargs['extra_conv']`` and ``kwargs['extra_fc']`` configure the
    backbone; the complete ``kwargs`` dict is also forwarded to
    ``FeatureRefiner``.
    """

    def __init__(self, **kwargs):
        super(ExtendedResNet44, self).__init__()
        use_extra_conv = kwargs['extra_conv']
        use_extra_fc = kwargs['extra_fc']
        self.backbone = resnet44(use_extra_conv, use_extra_fc)
        self.fr = FeatureRefiner(**kwargs)

    def forward(self, inputs):
        """Extract backbone features and refine them with the head."""
        return self.fr(self.backbone(inputs))


class ExtendedResNet56(nn.Module):
    """``resnet56`` backbone followed by a ``FeatureRefiner`` head.

    ``kwargs['extra_conv']`` and ``kwargs['extra_fc']`` configure the
    backbone; the complete ``kwargs`` dict is also forwarded to
    ``FeatureRefiner``.
    """

    def __init__(self, **kwargs):
        super(ExtendedResNet56, self).__init__()
        use_extra_conv = kwargs['extra_conv']
        use_extra_fc = kwargs['extra_fc']
        self.backbone = resnet56(use_extra_conv, use_extra_fc)
        self.fr = FeatureRefiner(**kwargs)

    def forward(self, inputs):
        """Extract backbone features and refine them with the head."""
        return self.fr(self.backbone(inputs))
