import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torch.autograd import Variable

def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Build a bias-free transposed convolution, optionally batch-normalised.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k_size: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
        pad: padding of the transposed convolution.
        bn: when true, append ``nn.BatchNorm2d(c_out)`` after the layer.

    Returns:
        An ``nn.Sequential`` holding the transposed convolution at index 0
        and, if requested, the batch norm at index 1.
    """
    upsample = nn.ConvTranspose2d(
        c_in, c_out, k_size, stride=stride, padding=pad, bias=False)
    norm = [nn.BatchNorm2d(c_out)] if bn else []
    return nn.Sequential(upsample, *norm)

def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Build a bias-free 2-D convolution, optionally batch-normalised.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k_size: kernel size of the convolution.
        stride: stride of the convolution.
        pad: padding of the convolution.
        bn: when true, append ``nn.BatchNorm2d(c_out)`` after the layer.

    Returns:
        An ``nn.Sequential`` holding the convolution at index 0 and, if
        requested, the batch norm at index 1.
    """
    convolution = nn.Conv2d(
        c_in, c_out, k_size, stride=stride, padding=pad, bias=False)
    norm = [nn.BatchNorm2d(c_out)] if bn else []
    return nn.Sequential(convolution, *norm)

class G12(nn.Module):
    """Generator for transferring from mnist to svhn.

    Takes a 1-channel image and produces a 3-channel image when
    ``conf.exp_id`` is 0 or 3, and a 1-channel image otherwise.

    Args:
        conf: configuration object; only ``exp_id`` is read here.
        conv_dim: channel width of the first encoding layer.
        svhn_input: accepted but not used by this class.
    """
    def __init__(self, conf, conv_dim=64, svhn_input=None):
        super().__init__()

        self.config = conf
        wide = conv_dim * 2
        # exp_id 0 and 3 select the 3-channel output; every other id is 1-channel
        out_channels = 3 if self.config.exp_id in (0, 3) else 1

        # encoder: two stride-2 convolutions
        self.conv1 = conv(1, conv_dim, 4)
        self.conv2 = conv(conv_dim, wide, 4)

        # middle: two stride-1 3x3 convolutions (no skip connection is added)
        self.conv3 = conv(wide, wide, 3, 1, 1)
        self.conv4 = conv(wide, wide, 3, 1, 1)

        # decoder: two stride-2 transposed convolutions, the last without BN
        self.deconv1 = deconv(wide, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, out_channels, 4, bn=False)

    def forward(self, x):
        """Map ``x`` through encoder, middle and decoder; output is tanh-bounded.

        With the default ``conv_dim`` and a 32x32 input the feature maps go
        (?, 64, 16, 16) -> (?, 128, 8, 8) -> (?, 64, 16, 16) -> (?, C, 32, 32),
        where C is 3 or 1 depending on ``conf.exp_id``.
        """
        slope = 0.05
        hidden = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.deconv1):
            hidden = F.leaky_relu(stage(hidden), slope)
        return F.tanh(self.deconv2(hidden))

class residual_block(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN plus a shortcut connection.

    The shortcut is the identity when the block keeps both the channel count
    and the spatial size; otherwise it is a 1x1 strided convolution followed
    by batch norm so that it matches the main branch.

    Args:
        in_channels: number of input channels.
        filters: number of output channels.
        kernel_size: kernel size of both convolutions in the main branch.
        stride: spatial stride of the block. It is applied once, on the first
            convolution, so the main branch and the shortcut downsample by
            the same factor.
        padding: padding of both convolutions in the main branch.
    """
    def __init__(self, in_channels, filters=64, kernel_size=3, stride=1, padding=1):
        super(residual_block, self).__init__()
        self.main = nn.Sequential(
            # batch_size x in_channels x H x W
            nn.Conv2d(in_channels, filters, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True),
            # The second convolution must not stride again: the shortcut only
            # strides once, and striding twice here made the two branches
            # disagree in shape for any stride != 1.
            nn.Conv2d(filters, filters, kernel_size=kernel_size, stride=1, padding=padding, bias=False),
            nn.BatchNorm2d(filters)
            # batch_size x filters x H' x W'
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != filters:
            # projection shortcut to match channels and/or spatial size
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters)
            )

    def forward(self, inputs):
        """Return ``main(inputs) + shortcut(inputs)``."""
        return self.main(inputs) + self.shortcut(inputs)

class conv_block(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and a LeakyReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the convolution.
        stride: stride of the convolution.
        padding: padding of the convolution.
        leakiness: negative slope of the LeakyReLU.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, leakiness=0.2):
        super().__init__()
        # batch_size x in_channels x H x W  ->  batch_size x out_channels x H' x W'
        convolution = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(negative_slope=leakiness, inplace=False)
        self.main = nn.Sequential(convolution, normalisation, activation)

    def forward(self, inputs):
        """Apply convolution, batch norm and activation to ``inputs``."""
        return self.main(inputs)

class inject_noise(nn.Module):
    """Optionally apply dropout, then add Gaussian noise to the activations.

    The noise is sampled on the same device and with the same dtype as the
    activations, so the module works on CPU as well as on any GPU, and the
    input tensor is never modified in place.

    Args:
        opt: options object. ``D_noise_mean`` and ``D_noise_stddev`` are
            always read; ``D_keep_prob`` is read only when ``dropout`` is true.
        dropout: when true, an ``nn.Dropout`` is applied before the noise.
    """
    def __init__(self, opt, dropout=False):
        super(inject_noise, self).__init__()
        self.noise_mean = opt.D_noise_mean
        self.noise_stddev = opt.D_noise_stddev
        self.dropout = nn.Sequential()
        if dropout:
            # NOTE(review): nn.Dropout's first argument is the probability of
            # *zeroing* an element, while the option is named "keep_prob" --
            # confirm which of the two meanings is intended.
            self.dropout = nn.Dropout(opt.D_keep_prob, inplace=False)

    def forward(self, inputs):
        """Return ``dropout(inputs)`` plus N(noise_mean, noise_stddev) noise.

        No noise is added when ``noise_stddev`` is 0. The noise is added
        regardless of ``self.training``.
        """
        output = self.dropout(inputs)
        if self.noise_stddev != 0:
            noise = torch.empty_like(output).normal_(self.noise_mean, self.noise_stddev)
            # Out-of-place add: the dropout stage can hand back the caller's
            # own tensor (identity Sequential, or Dropout in eval mode), and
            # an in-place add would silently overwrite it.
            output = output + noise
        return output

class pixelda_D(nn.Module):
    """Convolutional discriminator ending in a single linear output.

    A 3x3 stride-1 convolution lifts the input to ``opt.ndf`` channels; the
    layers from ``make_layers`` then repeatedly double the channel count and
    shrink the feature map until it reaches ``opt.D_projection_size``; the
    result is flattened and fed to a linear layer with one output.

    Args:
        in_channels: number of channels of the input images.
        opt: options object; reads ``ndf``, ``D_projection_size``,
            ``leakiness`` and ``D_conv_block_size`` (plus whatever
            ``inject_noise`` reads).
        kernel_size: kernel size of the downsampling convolutions.
        stride: stride of the downsampling convolutions.
        padding: padding of the downsampling convolutions.
    """
    def __init__(self, in_channels, opt, kernel_size=3, stride=2, padding=1):
        super(pixelda_D, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        layers, out_channels, projection_size = self.make_layers(opt.ndf, opt.D_projection_size, opt)
        self.main = nn.Sequential(
            # batch_size x in_channels x H x W (make_layers assumes H = W = 32)
            nn.Conv2d(in_channels, opt.ndf, kernel_size=3, stride=1, padding=1, bias=True),
            # nn.BatchNorm2d(opt.ndf),
            nn.LeakyReLU(opt.leakiness, inplace=False),
            *layers
        )
        self.fully_connected = nn.Linear(projection_size*projection_size*out_channels, 1)

    def make_layers(self, filters, projection_size, opt):
        """Build the downsampling stack.

        Starting from a feature map of side 32, each stage appends
        ``opt.D_conv_block_size - 1`` stride-1 conv blocks, one conv block
        using this module's kernel/stride/padding, and an ``inject_noise``
        layer with dropout, doubling the channel count per stage.

        Returns:
            ``(layers, out_channels, feature_map)``: the list of modules, the
            channel count after the last stage, and the final feature-map side.

        Raises:
            ValueError: if a stage would not shrink the feature map, so the
                requested projection size can never be reached.
            AssertionError: if the stack steps past ``projection_size``
                without landing exactly on it.
        """
        feature_map = 32    # H or W
        layers = []
        in_channels = filters
        out_channels = in_channels
        while feature_map > projection_size:
            # Conv output-size formula, in integer arithmetic.
            next_size = int((feature_map + 2*self.padding - self.kernel_size) // self.stride) + 1
            if next_size >= feature_map:
                # Without this guard the loop would append layers forever
                # (e.g. stride=1, or a projection size below the fixed point).
                raise ValueError(
                    "feature map cannot shrink from %d to projection size %d with "
                    "kernel_size=%r, stride=%r, padding=%r"
                    % (feature_map, projection_size, self.kernel_size, self.stride, self.padding))
            out_channels = in_channels * 2
            for _ in range(1, opt.D_conv_block_size):
                layers.append(conv_block(in_channels, out_channels,
                              kernel_size=3, stride=1, padding=1, leakiness=opt.leakiness))
                in_channels = out_channels
            layers.append(conv_block(in_channels, out_channels,
                          self.kernel_size, self.stride, self.padding, opt.leakiness))
            layers.append(inject_noise(opt, dropout=True))
            in_channels = out_channels
            feature_map = next_size

        assert feature_map == projection_size, (
            "feature map ended at %r, expected projection size %r"
            % (feature_map, projection_size))
        return layers, out_channels, feature_map

    def forward(self, inputs):
        """Return one unnormalised score per sample, shape (batch, 1)."""
        output = self.main(inputs)
        output = output.view(output.size(0), -1)
        output = self.fully_connected(output)
        return output

class pixelda_G(nn.Module):
    """Residual generator: conv + ReLU, a stack of residual blocks, 1x1 conv + tanh.

    Args:
        out_channels: number of channels of the generated image.
        in_channels: number of channels of the input tensor.
        opt: options object; reads ``ngf`` (feature width) and
            ``G_residual_blocks`` (number of residual blocks).
        kernel_size: kernel size of the first convolution and of the
            residual blocks.
        stride: stride of the first convolution and of the residual blocks.
        padding: padding of the first convolution and of the residual blocks.
    """
    def __init__(self, out_channels, in_channels, opt, kernel_size=3, stride=1, padding=1):
        super().__init__()
        width = opt.ngf
        # The residual trunk is created before the surrounding convolutions,
        # which keeps the order in which parameters are initialised.
        trunk = [
            residual_block(width, width, kernel_size, stride, padding)
            for _ in range(opt.G_residual_blocks)
        ]
        # batch_size x in_channels x H x W  ->  batch_size x width x ...
        stem = [
            nn.Conv2d(in_channels, width, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
            nn.ReLU(True),
        ]
        # ... -> batch_size x out_channels x ..., squashed by tanh
        head = [
            nn.Conv2d(width, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stem, *trunk, *head)

    def forward(self, inputs):
        """Run ``inputs`` through the whole generator stack."""
        return self.main(inputs)
    
class G21(nn.Module):
    """Generator for transferring from svhn to mnist.

    Takes a 3-channel image when ``conf.exp_id`` is 0 or 3 and a 1-channel
    image otherwise; the output always has one channel.

    Args:
        conf: configuration object; only ``exp_id`` is read here.
        conv_dim: channel width of the first encoding layer.
        svhn_input: accepted but not used by this class.
    """
    def __init__(self, conf,  conv_dim=64, svhn_input=None):
        super().__init__()

        self.config = conf
        wide = conv_dim * 2
        # exp_id 0 and 3 select the 3-channel input; every other id is 1-channel
        in_channels = 3 if self.config.exp_id in (0, 3) else 1

        # encoder: two stride-2 convolutions
        self.conv1 = conv(in_channels, conv_dim, 4)
        self.conv2 = conv(conv_dim, wide, 4)

        # middle: two stride-1 3x3 convolutions (no skip connection is added)
        self.conv3 = conv(wide, wide, 3, 1, 1)
        self.conv4 = conv(wide, wide, 3, 1, 1)

        # decoder: two stride-2 transposed convolutions, the last without BN
        self.deconv1 = deconv(wide, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 1, 4, bn=False)

    def forward(self, x):
        """Map ``x`` through encoder, middle and decoder; output is tanh-bounded.

        With the default ``conv_dim`` and a 32x32 input the feature maps go
        (?, 64, 16, 16) -> (?, 128, 8, 8) -> (?, 64, 16, 16) -> (?, 1, 32, 32).
        """
        slope = 0.05
        hidden = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.deconv1):
            hidden = F.leaky_relu(stage(hidden), slope)
        return F.tanh(self.deconv2(hidden))
    
class D1(nn.Module):
    """Discriminator for mnist.

    Three stride-2 convolutions followed by a 4x4 convolution that acts as
    the final fully connected layer and yields one score per sample.

    Args:
        conv_dim: channel width of the first convolution.
    """
    def __init__(self, conv_dim=64):
        super(D1, self).__init__()
        self.conv1 = conv(1, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.fc = conv(conv_dim*4, 1, 4, 1, 0, False)

    def forward(self, x):
        """Score a batch of 1-channel images.

        Returns:
            A tensor of shape (batch,) for 32x32 inputs. The batch axis is
            always kept, including when the batch holds a single sample.
        """
        out = F.leaky_relu(self.conv1(x), 0.05)    # (?, 64, 16, 16)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 4, 4)
        out = self.fc(out)                         # (?, 1, 1, 1)
        # Drop singleton axes one by one, never touching axis 0: a bare
        # squeeze() also removes the batch axis when the batch size is 1 and
        # would return a 0-d tensor in that case.
        for axis in (3, 2, 1):
            out = out.squeeze(axis)
        return out

class D2(nn.Module):
    """Discriminator for svhn.

    Three stride-2 convolutions followed by a 4x4 convolution that acts as
    the final fully connected layer and yields one score per sample. The
    input has 3 channels when ``conf.exp_id`` is 0 or 3 and 1 channel
    otherwise.

    Args:
        conf: configuration object; only ``exp_id`` is read here.
        conv_dim: channel width of the first convolution.
    """
    def __init__(self, conf,conv_dim=64):
        super(D2, self).__init__()
        self.config = conf
        if self.config.exp_id == 0 or self.config.exp_id == 3:
            self.conv1 = conv(3, conv_dim, 4, bn=False)
        else:
            self.conv1 = conv(1, conv_dim, 4, bn=False)

        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.fc = conv(conv_dim*4, 1, 4, 1, 0, False)

    def forward(self, x):
        """Score a batch of images.

        Returns:
            A tensor of shape (batch,) for 32x32 inputs. The batch axis is
            always kept, including when the batch holds a single sample.
        """
        out = F.leaky_relu(self.conv1(x), 0.05)    # (?, 64, 16, 16)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 4, 4)
        out = self.fc(out)                         # (?, 1, 1, 1)
        # Drop singleton axes one by one, never touching axis 0: a bare
        # squeeze() also removes the batch axis when the batch size is 1 and
        # would return a 0-d tensor in that case.
        for axis in (3, 2, 1):
            out = out.squeeze(axis)
        return out