
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConvolution(nn.Module):
    """Two stacked ``conv3x3 -> normalization -> activation`` stages.

    Both convolutions use a 3x3 kernel with padding 1, so H and W are
    preserved; the channel count goes ``in_ch -> out_ch -> out_ch``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both convolutions.
        nonlin: activation class; instantiated as ``nonlin(inplace=True)``,
            so it must accept an ``inplace`` keyword.
        bn_layer: normalization class, instantiated as ``bn_layer(out_ch)``.
    """

    def __init__(self, in_ch, out_ch, nonlin=nn.ReLU, bn_layer=nn.BatchNorm2d):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        # Layers are appended in the same order as a literal Sequential would
        # list them, so the indices (and state_dict keys) are conv.0 ... conv.5.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                bn_layer(out_ch),
                nonlin(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.conv(x)


class InConvolution(nn.Module):
    """Input stem of the network: a single :class:`DoubleConvolution`.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        nonlin: activation class forwarded to ``DoubleConvolution``
            (default ``nn.ReLU``, same as before).
        bn_layer: normalization class forwarded to ``DoubleConvolution``
            (default ``nn.BatchNorm2d``, same as before).
    """

    def __init__(self, in_ch, out_ch, nonlin=nn.ReLU, bn_layer=nn.BatchNorm2d):
        super(InConvolution, self).__init__()
        # Accept the same nonlin/bn_layer knobs as DownSampler/UpSampler so the
        # whole encoder can be configured consistently.
        self.conv = DoubleConvolution(in_ch, out_ch, nonlin=nonlin, bn_layer=bn_layer)

    def forward(self, x):
        """Apply the stem convolution block to ``x``."""
        return self.conv(x)


class DownSampler(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a :class:`DoubleConvolution`.

    The pooling halves H and W (rounding down); the convolution block maps
    ``in_ch`` to ``out_ch`` channels. ``nonlin`` and ``bn_layer`` are passed
    through to ``DoubleConvolution`` unchanged.
    """

    def __init__(self, in_ch, out_ch, nonlin=nn.ReLU, bn_layer=nn.BatchNorm2d):
        super().__init__()
        pool = nn.MaxPool2d(2)
        double_conv = DoubleConvolution(in_ch, out_ch, nonlin=nonlin, bn_layer=bn_layer)
        # Kept as one Sequential named ``mpconv`` so state_dict keys stay stable.
        self.mpconv = nn.Sequential(pool, double_conv)

    def forward(self, x):
        """Pool ``x`` by 2 and apply the convolution block."""
        return self.mpconv(x)


class UpSampler(nn.Module):
    """Decoder stage: upsample ``x1`` by 2, align it to ``x2``, concatenate, convolve.

    Args:
        in_ch: channel count of ``cat([x2, upsampled x1])``, i.e. the input
            of the trailing ``DoubleConvolution``. With ``bilinear=False`` the
            transposed convolution is built for ``in_ch // 2`` channels, so
            ``x1`` must then carry ``in_ch // 2`` channels.
        out_ch: number of output channels.
        bilinear: if True, upsample with parameter-free bilinear interpolation
            (``align_corners=True``); otherwise with a learned 2x2
            ``ConvTranspose2d`` of stride 2.
        nonlin: activation class forwarded to ``DoubleConvolution``.
        bn_layer: normalization class forwarded to ``DoubleConvolution``.
    """

    def __init__(self, in_ch, out_ch, bilinear=True, nonlin=nn.ReLU, bn_layer=nn.BatchNorm2d):
        super().__init__()

        # Bilinear interpolation adds no weights; the transposed convolution
        # learns the upsampling at the cost of extra parameters and memory.
        half = in_ch // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half, half, 2, stride=2)
        )
        self.conv = DoubleConvolution(in_ch, out_ch, nonlin=nonlin, bn_layer=bn_layer)

    def forward(self, x1, x2):
        """Upsample ``x1``, match it to ``x2``'s spatial size, fuse and convolve.

        Tensors are indexed as (N, C, H, W): dims 2 and 3 are compared and
        the concatenation happens along dim 1.
        """
        x1 = self.up(x1)

        # Pad (or crop, for negative differences) x1 so H and W match x2;
        # an odd positive difference puts the extra pixel on the bottom/right.
        # For background on this padding issue, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        d_h = x2.size()[2] - x1.size()[2]
        d_w = x2.size()[3] - x1.size()[3]
        top, left = d_h // 2, d_w // 2
        x1 = F.pad(x1, [left, d_w - left, top, d_h - top])

        # Skip connection first, upsampled features second.
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConvolution(nn.Module):
    """Output head: a 1x1 convolution mapping ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to the output channels; spatial size is unchanged."""
        return self.conv(x)


class UNet(nn.Module):
    """U-Net for dense (per-pixel) prediction.

    Encoder: an ``InConvolution`` stem followed by ``depth`` ``DownSampler``
    stages, each halving H and W with ``MaxPool2d(2)``. The first
    ``depth - 1`` stages double the feature count; the last keeps it, so the
    bottleneck and the deepest skip connection have the same width.

    Decoder: ``depth`` ``UpSampler`` stages, each concatenating the upsampled
    features with the skip connection of matching resolution, followed by a
    1x1 ``OutConvolution`` and an optional output non-linearity.

    Because of the ``depth`` poolings, H and W must be at least ``2**depth``
    (a necessary, not sufficient, condition).

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels.
        depth: number of down-/up-sampling stages.
        initial_features: feature count produced by the input stem.
        out_nonlin: callable applied to the 1x1-conv output (default
            ``torch.sigmoid``); pass ``None`` to return raw values.
        bilinear: if True, decoder stages upsample by bilinear interpolation
            instead of a learned ``ConvTranspose2d``. Defaults to False, the
            previously hard-coded behaviour.
    """

    def __init__(self, n_channels, n_classes, depth=4, initial_features=16, out_nonlin=torch.sigmoid,
                 bilinear=False):
        super(UNet, self).__init__()

        self.n_classes = n_classes
        self.depth = depth
        self.out_nonlin = out_nonlin
        self.bilinear = bilinear

        # input convolution
        self.inc = InConvolution(n_channels, initial_features)

        # down-sampling part: double the features at every stage but the last
        down = []
        current_features = initial_features
        for _ in range(self.depth - 1):
            down.append(DownSampler(current_features, 2 * current_features))
            current_features *= 2
        down.append(DownSampler(current_features, current_features))
        self.down = nn.ModuleList(down)

        # up-sampling part: every stage consumes cat(upsampled, skip), i.e.
        # twice the skip width; its output (a quarter of that) equals the
        # width of the next, shallower skip connection.
        current_features *= 2
        up = []
        for _ in range(self.depth - 1):
            up.append(UpSampler(current_features, current_features // 4, bilinear=bilinear))
            current_features //= 2
        up.append(UpSampler(current_features, initial_features, bilinear=bilinear))
        self.up = nn.ModuleList(up)

        # output convolution
        self.outc = OutConvolution(initial_features, n_classes)

    def forward(self, x):
        """Run the network.

        Args:
            x: mapping whose ``'X'`` entry is the input tensor; the decoder
                indexes dims 2/3 and concatenates on dim 1, so presumably
                (N, n_channels, H, W) — confirm with callers.

        Returns:
            dict with key ``'prediction'`` holding the (optionally
            ``out_nonlin``-transformed) output with ``n_classes`` channels.
        """
        inp = x['X']

        # encoder: keep every intermediate map as a skip connection
        hiddens = [self.inc(inp)]
        for down in self.down[:self.depth]:
            hiddens.append(down(hiddens[-1]))

        # decoder: pair stages with skips from deepest to shallowest
        out = hiddens[-1]
        for up, skip in zip(self.up[:self.depth], reversed(hiddens[:-1])):
            out = up(out, skip)

        # apply output convolution
        out = self.outc(out)
        if self.out_nonlin is not None:
            out = self.out_nonlin(out)

        return {'prediction': out}


def test_binary_model():
    """Smoke-test a single-class UNet on one 48x48 single-channel input.

    Prints the shapes and asserts that the prediction keeps the input's
    shape and lies in [0, 1] (the default ``out_nonlin`` is a sigmoid).
    """
    net = UNet(n_channels=1, n_classes=1, depth=4)
    x = torch.from_numpy(np.zeros((1, 1, 48, 48), dtype=np.float32))
    out = net({"X": x})
    print("\nBinary Segmentation Model")
    print("input_shape: ", x.shape)
    print("output_shape:", out["prediction"].shape)

    # Without assertions this test could never fail under pytest.
    prediction = out["prediction"]
    assert prediction.shape == (1, 1, 48, 48)
    assert bool(((prediction >= 0) & (prediction <= 1)).all())


if __name__ == "__main__":
    # Run the smoke test when this module is executed as a script.
    test_binary_model()
