# -*- coding: utf-8 -*-

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Classical U-Net: input_size == output_size
"""


class UNet(nn.Module):
    """Classical 4-level U-Net whose output has the same shape as its input.

    Args:
        args: configuration namespace. This class reads ``args.data`` (selects
            the number of image channels) and ``args.skip`` (skip-connection
            ablation, see ``forward``); ``args`` is also forwarded to the
            ``inconv`` / ``down`` / ``up`` sub-blocks.
        bilinear: if True, the expanding path upsamples with bilinear
            interpolation; otherwise ``up`` uses a transposed convolution.
        init_scale: multiplier on the standard deviation used to initialise
            ``nn.Conv2d`` and ``nn.Linear`` weights.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset name.
    """

    # Number of input/output image channels for each supported ``args.data``.
    _DATASET_CHANNELS = {'MNIST': 1, 'cifar10': 3, 'cifar100': 3}

    def __init__(self, args, bilinear=True, init_scale=1.0):
        super(UNet, self).__init__()
        self.args = args
        # Fail loudly on unknown datasets instead of hitting an unbound
        # ``n_channels`` further down.
        if args.data not in self._DATASET_CHANNELS:
            raise ValueError(
                "Unsupported args.data %r; expected one of %s"
                % (args.data, ', '.join(sorted(self._DATASET_CHANNELS))))
        n_channels = self._DATASET_CHANNELS[args.data]

        self.inc = inconv(args, n_channels, 64) # First step of Contracting
        self.down1 = down(args, 64, 128) # Second step of Contracting
        self.down2 = down(args, 128, 256) # Third step of Contracting
        self.down3 = down(args, 256, 512) # Fourth step of Contracting
        self.down4 = down(args, 512, 512) # Bottleneck of U-Net
        # Each up block's in_ch is the channel count after concatenating the
        # upsampled map with the matching skip tensor (e.g. 512 + 512 = 1024).
        self.up1 = up(args, 1024, 256, bilinear) # First step of Expanding
        self.up2 = up(args, 512, 128, bilinear) # Second step of Expanding
        self.up3 = up(args, 256, 64, bilinear) # Third step of Expanding
        self.up4 = up(args, 128, 64, bilinear) # Fourth step of Expanding
        self.outc = outconv(64, n_channels) # Output Conv layer with 1*1 filter

        # Weight initialisation:
        #   Conv2d      -> He-normal, std = sqrt(2 / (k_h * k_w * out_channels)).
        #   BatchNorm2d -> identity affine (weight 1, bias 0).
        #   Linear      -> Glorot-normal, std = sqrt(2 / (fan_in + fan_out)).
        # nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the transposed
        # convolutions used when bilinear=False keep PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, init_scale * math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                fan_out, fan_in = m.weight.size()
                std = math.sqrt(2.0 / (fan_in + fan_out))
                m.weight.data.normal_(0.0, init_scale * std)

    def _num_zeroed_skips(self):
        """Return how many of the shallowest skip connections to zero out.

        ``args.skip`` equal to 1, 2, 3 or 4 zeros that many skip tensors,
        starting from the shallowest (x1); any other value keeps them all.
        """
        skip = self.args.skip
        for k in (1, 2, 3, 4):
            if skip == k:
                return k
        return 0

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Skip tensors ordered deepest-first, matching up1 .. up4.
        skips = [x4, x3, x2, x1]
        n_zero = self._num_zeroed_skips()
        if n_zero:
            # Ablation: replace the n_zero shallowest skips with zero tensors
            # of identical shape, so the up blocks' channel counts still fit.
            skips[-n_zero:] = [torch.zeros_like(s) for s in skips[-n_zero:]]

        x = x5
        for up_block, skip in zip((self.up1, self.up2, self.up3, self.up4), skips):
            x = up_block(x, skip)
        x = self.outc(x)
        # Raw output (no sigmoid) is returned.
        return x
    

class double_conv(nn.Module):
    """Two stacked 3x3 convolutions (padding=1), each followed by an activation.

    ``args.bn`` controls BatchNorm placement:
      * ``'no_bn'``   -- no BatchNorm at all;
      * ``'no_2'``    -- BatchNorm only after the first convolution;
      * other values  -- BatchNorm after both convolutions.

    ``args.activation`` picks the non-linearity: ``'tanh'``, ``'sigmoid'``,
    ``'leaky_relu'`` (in-place), or in-place ReLU for any other value.

    The layers live in ``self.conv`` (an ``nn.Sequential``) in the order
    conv -> [BN] -> act -> conv -> [BN] -> act.
    """

    def __init__(self, args, in_ch, out_ch):
        super(double_conv, self).__init__()
        bn_mode = args.bn

        def make_activation():
            # A fresh module per call; the two stages never share an instance.
            kind = args.activation
            if kind == 'tanh':
                return nn.Tanh()
            if kind == 'sigmoid':
                return nn.Sigmoid()
            if kind == 'leaky_relu':
                return nn.LeakyReLU(inplace=True)
            return nn.ReLU(inplace=True)

        # (input channels, whether a BatchNorm follows the conv) per stage.
        stages = (
            (in_ch, bn_mode != 'no_bn'),
            (out_ch, bn_mode not in ('no_bn', 'no_2')),
        )
        layers = []
        for stage_in, with_bn in stages:
            layers.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            if with_bn:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(make_activation())
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    

class inconv(nn.Module):
    """Input stem of the U-Net: a single ``double_conv`` with no pooling."""

    def __init__(self, args, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(args, in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)
    

class down(nn.Module):
    """Contracting step: 2x2 max-pooling followed by ``double_conv``."""

    def __init__(self, args, in_ch, out_ch):
        super(down, self).__init__()
        pool = nn.MaxPool2d(2)  # halves H and W (floor)
        convs = double_conv(args, in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)
    

class up(nn.Module):
    """Expanding step: upsample, align with the skip tensor, concat, double_conv.

    Args:
        args: configuration namespace forwarded to ``double_conv``.
        in_ch: channel count of the concatenation ``[skip, upsampled]`` fed to
            ``double_conv``. When ``bilinear`` is False the upsampled input is
            expected to carry ``in_ch // 2`` channels.
        out_ch: output channels of the trailing ``double_conv``.
        bilinear: use parameter-free bilinear interpolation instead of a
            learned 2x2, stride-2 transposed convolution.
    """

    def __init__(self, args, in_ch, out_ch, bilinear=False):
        super(up, self).__init__()
        half = in_ch // 2
        # Both options double H and W and keep the channel count unchanged:
        # nn.Upsample never touches channels, and the transposed conv maps
        # in_ch//2 -> in_ch//2 channels.
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half, half, 2, stride=2)
        )
        self.conv = double_conv(args, in_ch, out_ch)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are (batch, channel, height, width). Pad the upsampled map
        # so its height/width match the skip tensor x2 (negative amounts crop);
        # F.pad takes (left, right, top, bottom) for the last two dims.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, (dx // 2, dx - dx // 2, dy // 2, dy - dy // 2))
        return self.conv(torch.cat([x2, upsampled], dim=1))
    

class outconv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)