'''
Adapted from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# Only the factory function is star-exported; the classes remain importable by name.
__all__ = ["densenet"]


class Bottleneck(nn.Module):
    """Bottleneck dense layer: (BN ->) ReLU -> 1x1 conv -> (BN ->) ReLU -> 3x3 conv.

    The 1x1 convolution produces ``expansion * growthRate`` channels and the
    3x3 convolution (padding 1) reduces them to ``growthRate`` channels with
    unchanged spatial size. Those new channels are concatenated onto the
    unmodified input along dim 1, so the output has
    ``inplanes + growthRate`` channels.

    Args:
        inplanes: number of input channels.
        expansion: width multiplier for the 1x1 bottleneck convolution.
        growthRate: number of channels this layer adds.
        dropRate: dropout probability applied to the new channels
            (0 disables dropout).
        no_batch_norm: if True, the BatchNorm2d layers are not created/used.
    """

    def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0, no_batch_norm=False):
        super(Bottleneck, self).__init__()
        planes = expansion * growthRate
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3,
                               padding=1, bias=False)
        if not no_batch_norm:
            self.bn1 = nn.BatchNorm2d(inplanes)
            self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate
        self.no_batch_norm = no_batch_norm

    def forward(self, x):
        """Return ``cat((x, new_features), dim=1)`` without modifying ``x``."""
        if self.no_batch_norm:
            # Out-of-place on purpose: the in-place ``self.relu`` would
            # overwrite the caller's ``x``, which is still needed for the
            # skip concatenation below.
            out = F.relu(x)
        else:
            # bn1 returns a fresh tensor, so the in-place ReLU is safe here.
            out = self.relu(self.bn1(x))
        out = self.conv1(out)
        if not self.no_batch_norm:
            out = self.bn2(out)
        # conv1/bn2 output is a fresh tensor; in-place ReLU is safe.
        out = self.relu(out)
        out = self.conv2(out)
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)

        return torch.cat((x, out), 1)


class BasicBlock(nn.Module):
    """Plain dense layer: (BN ->) ReLU -> 3x3 conv.

    The 3x3 convolution (padding 1) produces ``growthRate`` new channels with
    unchanged spatial size. They are concatenated onto the unmodified input
    along dim 1, so the output has ``inplanes + growthRate`` channels.

    Args:
        inplanes: number of input channels.
        expansion: unused here; mirrors the ``Bottleneck`` constructor
            signature.
        growthRate: number of channels this layer adds.
        dropRate: dropout probability applied to the new channels
            (0 disables dropout).
        no_batch_norm: if True, the BatchNorm2d layer is not created/used.
    """

    def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0, no_batch_norm=False):
        super(BasicBlock, self).__init__()
        if not no_batch_norm:
            self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate
        self.no_batch_norm = no_batch_norm

    def forward(self, x):
        """Return ``cat((x, new_features), dim=1)`` without modifying ``x``."""
        if self.no_batch_norm:
            # Out-of-place on purpose: the in-place ``self.relu`` would
            # overwrite the caller's ``x``, which is still needed for the
            # skip concatenation below.
            out = F.relu(x)
        else:
            # bn1 returns a fresh tensor, so the in-place ReLU is safe here.
            out = self.relu(self.bn1(x))
        out = self.conv1(out)
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)

        return torch.cat((x, out), 1)


class Transition(nn.Module):
    """Transition layer between dense blocks: (BN ->) ReLU -> 1x1 conv -> 2x2 avg pool.

    Changes the channel count from ``inplanes`` to ``outplanes`` and halves
    the spatial resolution (``F.avg_pool2d(out, 2)``).

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        no_batch_norm: if True, the BatchNorm2d layer is not created/used.
    """

    def __init__(self, inplanes, outplanes, no_batch_norm=False):
        super(Transition, self).__init__()
        if not no_batch_norm:
            self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.no_batch_norm = no_batch_norm

    def forward(self, x):
        """Apply the transition without modifying the caller's ``x``."""
        if self.no_batch_norm:
            # Out-of-place so the in-place ``self.relu`` does not clobber
            # the caller's input tensor.
            out = F.relu(x)
        else:
            # bn1 returns a fresh tensor, so the in-place ReLU is safe here.
            out = self.relu(self.bn1(x))
        out = self.conv1(out)
        return F.avg_pool2d(out, 2)


class DenseNet(nn.Module):
    """DenseNet classifier with an optional no-batch-norm variant.

    Layout: 3x3 stem convolution -> dense block -> transition -> dense
    block -> transition -> dense block -> (BN ->) ReLU -> 8x8 average pool
    -> linear layer. Each transition halves the spatial size, so the final
    ``AvgPool2d(8)`` collapses the feature map to 1x1 for 32x32 inputs
    (32 -> 16 -> 8). Other input sizes are presumably unsupported
    (the ``fc`` input width assumes a 1x1 pooled map).

    Args:
        depth: total depth; must satisfy ``(depth - 4) % 3 == 0``. Each dense
            block gets ``(depth - 4) // 3`` layers when ``block`` is
            ``BasicBlock`` and ``(depth - 4) // 6`` layers otherwise.
        block: layer class used inside dense blocks (``Bottleneck`` or
            ``BasicBlock``).
        dropRate: dropout probability passed to every dense layer.
        num_classes: number of output logits.
        growthRate: channels added by each dense layer; the stem produces
            ``2 * growthRate`` channels.
        compressionRate: divisor applied (with flooring) to the channel
            count by each transition.
        no_batch_norm: if True, no BatchNorm2d layers are created anywhere.
    """

    def __init__(self, depth=22, block=Bottleneck, dropRate=0,
                 num_classes=10, growthRate=12, compressionRate=2,
                 no_batch_norm=False):
        super(DenseNet, self).__init__()

        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
        # Floor division in both branches: true division would make ``n`` a
        # float for BasicBlock and ``range(n)`` in _make_denseblock would
        # raise TypeError.
        n = (depth - 4) // 3 if block is BasicBlock else (depth - 4) // 6

        self.growthRate = growthRate
        self.dropRate = dropRate

        # self.inplanes is a running channel count that the _make_* helpers
        # read and update as layers are appended.
        self.inplanes = growthRate * 2
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_denseblock(block, n, no_batch_norm)
        self.trans1 = self._make_transition(compressionRate, no_batch_norm)
        self.dense2 = self._make_denseblock(block, n, no_batch_norm)
        self.trans2 = self._make_transition(compressionRate, no_batch_norm)
        self.dense3 = self._make_denseblock(block, n, no_batch_norm)
        if not no_batch_norm:
            self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(self.inplanes, num_classes)
        self.no_batch_norm = no_batch_norm

        # Weight initialization: normal(0, sqrt(2 / fan_out)) for convs
        # (fan_out = k_h * k_w * out_channels); BN starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_denseblock(self, block, blocks, no_batch_norm):
        """Build ``blocks`` dense layers, growing ``self.inplanes`` by
        ``self.growthRate`` after each one."""
        layers = []
        for _ in range(blocks):
            # Currently we fix the expansion ratio as the default value
            layers.append(block(self.inplanes,
                                growthRate=self.growthRate,
                                dropRate=self.dropRate,
                                no_batch_norm=no_batch_norm))
            self.inplanes += self.growthRate

        return nn.Sequential(*layers)

    def _make_transition(self, compressionRate, no_batch_norm):
        """Build a Transition that divides ``self.inplanes`` by
        ``compressionRate`` (floored) and record the new channel count."""
        inplanes = self.inplanes
        outplanes = int(math.floor(self.inplanes // compressionRate))
        self.inplanes = outplanes
        return Transition(inplanes, outplanes, no_batch_norm)

    def forward(self, x):
        """Return class logits of shape ``(x.size(0), num_classes)``."""
        x = self.conv1(x)
        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)
        if not self.no_batch_norm:
            x = self.bn(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def densenet(**kwargs):
    """Create a :class:`DenseNet`.

    Every keyword argument is forwarded unchanged to the ``DenseNet``
    constructor; the freshly built model is returned.
    """
    model = DenseNet(**kwargs)
    return model
