import math
import torch
from torch import nn


# from: https://github.com/felixgwu/img_classification_pk_pytorch/blob/master/models/resnet.py
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; the same ``stride`` is
    used along both spatial axes.
    """
    kernel = (3, 3)
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel,
        stride=(stride,) * 2,
        padding=(1, 1),
        bias=False,
    )


class BasicBlockWithDeathRate(nn.Module):
    """Two-convolution residual block with stochastic depth.

    Residual branch: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
    added to the shortcut and passed through a final ReLU.

    In training mode the residual branch is skipped with probability
    ``death_rate``. In that case the (possibly downsampled) shortcut is
    returned as-is, without the final ReLU. When the branch does run during
    training, it is divided by ``1 - death_rate``. So, in expectation over the
    drop decision, its contribution matches evaluation mode, where the branch
    always runs and is not rescaled.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        death_rate: probability of dropping the residual branch in training.
        downsample: optional module applied to the input to build the shortcut
            (e.g. when ``stride`` or the channel count changes).
    """

    # Channel multiplier read by ResNetSD (block output = planes * expansion).
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, death_rate=0., downsample=None):
        super().__init__()
        # Submodules are registered in this exact order: parameter
        # initialization and state_dict key order depend on it.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.death_rate = death_rate

    def forward(self, x):
        """Apply the block to ``x``, randomly dropping the branch in training."""
        shortcut = x if self.downsample is None else self.downsample(x)

        # TODO (inherited from upstream): "fix the bug of original Stochastic
        # depth" -- the bug itself is not described in this file.
        # torch.rand is only drawn in training mode; evaluation never drops.
        survives = not self.training or torch.rand(1)[0] >= self.death_rate
        if not survives:
            return shortcut

        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        if self.training:
            # Inverted stochastic-depth scaling (compensates for dropped runs).
            branch /= (1. - self.death_rate)
        return self.relu2(shortcut + branch)


class DownsampleB(nn.Module):
    """Parameter-free shortcut: average-pool, then zero-pad the channel axis.

    The input is average-pooled with window and stride ``stride``. Zero
    channels are then appended along dim 1 so the output has exactly ``n_out``
    channels. If ``n_out <= n_in`` no channels are added.
    Assumes batched ``(N, C, H, W)`` input -- TODO confirm for other callers.

    Args:
        n_in: number of input channels.
        n_out: desired number of output channels.
        stride: pooling window size and stride.
    """

    def __init__(self, n_in, n_out, stride):
        super(DownsampleB, self).__init__()
        self.avg = nn.AvgPool2d(stride)
        # Kept for backward compatibility; forward() no longer uses it.
        self.expand_ratio = n_out // n_in
        # Pad by the exact channel difference, so n_out need not be a
        # multiple of n_in (the old ratio-based padding silently produced
        # n_in * (n_out // n_in) channels in that case).
        self.pad_channels = max(n_out - n_in, 0)

    def forward(self, x):
        x = self.avg(x)
        if self.pad_channels == 0:
            return x
        # Use real zeros rather than x * 0. Then inf/NaN activations cannot turn
        # the padding into NaN, and the padding carries no autograd history.
        pad = x.new_zeros((x.size(0), self.pad_channels) + tuple(x.shape[2:]))
        return torch.cat([x, pad], 1)


class ResNetSD(nn.Module):
    """
    Small ResNet with stochastic depth for CIFAR & SVHN.

    Layout: a 3x3 stem conv (16 channels), then three stages of
    ``N = (depth - 2) // 6`` residual blocks with 16, 32 and 64 planes. The
    first block of stages 2 and 3 uses stride 2. This is followed by global
    average pooling and a linear classifier.

    Args:
        depth: total depth, must be 6N + 2 with N >= 1.
        death_mode: how the drop probability is assigned to the 3N blocks:
            ``'uniform'`` gives every block ``death_rate``, while ``'linear'``
            ramps from ``death_rate / (3N)`` for the first block to
            ``death_rate`` for the last. Any other value (including ``None``)
            disables stochastic depth (all rates 0).
        death_rate: drop probability used by ``death_mode``.
        num_classes: number of output logits.
        block: residual block class. It must accept
            ``(inplanes, planes, stride, death_rate=..., downsample=...)`` and
            expose an ``expansion`` attribute.
        name: prefix of ``self.name`` (the depth is appended).
    """

    def __init__(self, depth, death_mode=None, death_rate=0, num_classes=10,
                 block=BasicBlockWithDeathRate, name='resnet_sd'):
        super(ResNetSD, self).__init__()

        # depth = 1 stem conv + 3 stages * N blocks * 2 convs + 1 fc = 6N + 2.
        # N >= 1 is required: with depth == 2 every stage would be empty and
        # _make_layer would fail on death_rates[0].
        assert depth >= 8 and (depth - 2) % 6 == 0, 'depth should be one of 6N+2 with N >= 1'
        n = (depth - 2) // 6
        nblocks = 3 * n
        if death_mode == 'uniform':
            death_rates = [death_rate] * nblocks
        elif death_mode == 'linear':
            death_rates = [float(i + 1) * death_rate / float(nblocks) for i in range(nblocks)]
        else:
            # None or unrecognised mode: stochastic depth disabled.
            death_rates = [0.] * nblocks

        self.name = f"{name}{depth}"
        self.inplanes = 16
        # Submodule registration order is kept stable: it determines the
        # state_dict layout and the order of the weight re-initialization below.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, death_rates[:n])
        self.layer2 = self._make_layer(block, 32, death_rates[n:2 * n], stride=2)
        self.layer3 = self._make_layer(block, 64, death_rates[2 * n:], stride=2)
        # Global average pooling. For 32x32 inputs (8x8 final maps) this is
        # identical to AvgPool2d(8), and it also accepts other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # He (fan-out) normal init for convs; BN starts as the identity affine.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, death_rates, stride=1):
        """Build one stage with one block per entry of ``death_rates``.

        Only the first block gets ``stride`` and, when the stride or channel
        count changes, a ``DownsampleB`` shortcut. Also updates
        ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleB(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample=downsample, death_rate=death_rates[0])]
        self.inplanes = planes * block.expansion
        for death_rate in death_rates[1:]:
            layers.append(block(self.inplanes, planes, death_rate=death_rate))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
