from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv


class UNet_inplace(nn.Module):
    """
    Fully-convolutional stack of five ``conv_block`` layers followed by a
    5x5 ``nn.Conv2d`` output head.

    Despite its name this is NOT a U-Net (https://arxiv.org/abs/1505.04597):
    there is no encoder/decoder split, no pooling or upsampling and no skip
    connections -- ``forward`` simply applies conv1..conv6 in order. The name
    is presumably inherited from the UNet template this class was derived from.

    Channel layout (hard-coded):
        conv1: 4 -> 64, conv2..conv5: 64 -> 64, conv6: 64 -> 5
    (The 4 input channels assume ``conv_block``'s signature is
    ``conv_block(in_ch, out_ch, kernel_size)`` -- TODO confirm in ``.shared``.)

    Args:
        in_ch: Currently IGNORED; the network is built for 4 input channels.
        out_ch: Currently IGNORED; the network always produces 5 output channels.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # NOTE(review): in_ch/out_ch are accepted only for call-site
        # compatibility and do not affect the architecture. Honoring them
        # would silently change the model for callers passing the template
        # defaults (3, 1) and break loading of existing checkpoints, so they
        # are left unused until callers are audited.

        # Submodule attribute names become state_dict keys; keep conv1..conv6
        # stable so saved checkpoints still load.
        # The third conv_block argument is presumably the kernel size
        # (31 -> 25 -> 13 -> 7 -> 5, shrinking with depth) -- TODO confirm.
        self.conv1 = conv_block(4, 64, 31)
        self.conv2 = conv_block(64, 64, 25)
        self.conv3 = conv_block(64, 64, 13)
        self.conv4 = conv_block(64, 64, 7)
        self.conv5 = conv_block(64, 64, 5)
        # Output head: kernel 5 with padding 2 (default stride 1, dilation 1)
        # keeps the spatial size of its input unchanged.
        self.conv6 = nn.Conv2d(64, 5, kernel_size=5, padding=2)

    def forward(self, x):
        """Apply conv1 through conv6 in sequence.

        Args:
            x: Input tensor; presumably (N, 4, H, W) -- TODO confirm.

        Returns:
            The output of ``conv6`` (5 channels).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)

        return x
