# -*- coding: utf-8 -*-

# trained on CIFAR10

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

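"""
Small_Inception: a compact Inception-style CNN built from Conv-BN-activation
units, followed by global average pooling and a linear classifier.
"""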

class Small_Inception(nn.Module):
    """Small Inception-style CNN: Inception/Downsample feature stack,
    global average pooling and a linear classifier.

    Args:
        args: configuration namespace. This class reads ``args.data``,
            ``args.input_dim`` and ``args.num_classes``; every sub-block it
            builds also receives ``args``.
        drop_rate: accepted for interface compatibility; not used.
        init_scale: multiplier on the standard deviation used to initialise
            Conv2d and Linear weights.
    """

    def __init__(self, args, drop_rate=0.0, init_scale=1.0):
        super(Small_Inception, self).__init__()
        # Expected side length of the final feature map for the known
        # datasets (presumably 28x28 MNIST -> 6 and 32x32 CIFAR -> 7).
        # Kept as a public attribute for backward compatibility; forward()
        # pools adaptively and no longer needs it, so other datasets work too
        # instead of failing with AttributeError.
        if args.data == 'MNIST':
            self.para_temp = 6
        elif args.data == 'cifar10' or args.data == 'cifar100':
            self.para_temp = 7
        else:
            self.para_temp = None

        self.features = nn.Sequential(
            BasicConv2d(args, args.input_dim, 96, 3, 1, padding=1),
            BasicInception(args, 96, 32, 32),
            BasicInception(args, 64, 32, 48),
            BasicDownsample(args, 80, 80),
            BasicInception(args, 160, 112, 48),
            BasicInception(args, 160, 96, 64),
            BasicInception(args, 160, 80, 80),
            BasicInception(args, 160, 48, 96),
            BasicDownsample(args, 144, 96),
            BasicInception(args, 240, 176, 160),
            BasicInception(args, 336, 176, 160),
        )
        # Linear classifier on the 336-channel pooled features.
        self.fc = nn.Linear(336, args.num_classes)

        # Module traversal order (and hence RNG consumption) is unchanged.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal init with fan_out = k_h * k_w * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, init_scale * math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
                # Glorot/Xavier-normal: std = sqrt(2 / (fan_in + fan_out)).
                fan_out, fan_in = m.weight.size()
                std = math.sqrt(2.0 / (fan_in + fan_out))
                m.weight.data.normal_(0.0, init_scale * std)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        out = self.features(x)
        # Global average pooling down to 1x1. For a para_temp x para_temp map
        # this equals F.avg_pool2d(out, self.para_temp), but it also handles
        # any other spatial size.
        out = F.adaptive_avg_pool2d(out, 1)
        # (N, 336, 1, 1) -> (N, 336). Unlike view(-1, 336), this never folds
        # leftover spatial positions into the batch dimension.
        out = torch.flatten(out, 1)
        return self.fc(out)

class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d(eps=0.001) -> activation.

    The activation is selected by ``args.activation`` on every forward call:
    ``'tanh'``, ``'sigmoid'`` or ``'leaky_relu'``; any other value falls back
    to ReLU. ``args`` is stored on the module as ``self.args``.
    """

    def __init__(self, args, in_channels, out_channels, kernel_size, stride, padding):
        super(BasicConv2d, self).__init__()
        self.args = args
        # bias=False: the following BatchNorm provides a learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, bias=False, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.bn(self.conv(x))
        activation = self.args.activation
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        if activation == 'tanh':
            return torch.tanh(x)
        if activation == 'sigmoid':
            return torch.sigmoid(x)
        # In-place below is safe: x is a new tensor produced by self.bn, so
        # the caller's input is never modified.
        if activation == 'leaky_relu':
            return F.leaky_relu(x, inplace=True)
        return F.relu(x, inplace=True)

class BasicInception(nn.Module):
    """Two parallel BasicConv2d branches concatenated along channels.

    ``conv1`` is a 1x1 branch with ``out_channels1`` filters and ``conv2`` a
    3x3 (padding 1) branch with ``out_channels3`` filters. Both use stride 1,
    so spatial size is preserved and the result has
    ``out_channels1 + out_channels3`` channels.
    """

    def __init__(self, args, in_channels, out_channels1, out_channels3):
        super().__init__()
        self.conv1 = BasicConv2d(args, in_channels, out_channels1,
                                 kernel_size=1, stride=1, padding=0)
        self.conv2 = BasicConv2d(args, in_channels, out_channels3,
                                 kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Tuple elements are evaluated left to right: conv1 first, then conv2.
        branches = (self.conv1(x), self.conv2(x))
        return torch.cat(branches, dim=1)

class BasicDownsample(nn.Module):
    """Stride-2 reduction block with a conv branch and a max-pool branch.

    ``conv`` is a 3x3, stride-2, unpadded BasicConv2d producing
    ``out_channels`` maps. ``max`` is a 3x3, stride-2 max-pool that keeps all
    ``in_channels``. Neither branch pads, so both produce the same spatial
    size and the concatenated output has ``out_channels + in_channels``
    channels.
    """

    def __init__(self, args, in_channels, out_channels):
        super().__init__()
        self.conv = BasicConv2d(args, in_channels, out_channels,
                                kernel_size=3, stride=2, padding=0)
        self.max = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        conv_branch = self.conv(x)
        pool_branch = self.max(x)
        return torch.cat((conv_branch, pool_branch), dim=1)
