from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv

class Bottleneck(nn.Module):
    """Apply two ``conv_block`` layers in sequence: ``in_ch -> mid_ch -> out_ch``.

    ``conv_block`` is the project helper imported from ``.shared``; its
    internals (padding, normalisation, activation) are not visible here.

    Args:
        in_ch: channel count handed to the first ``conv_block``.
        mid_ch: channel count between the two blocks.
        out_ch: channel count produced by the second ``conv_block``.
        kernel_size: kernel size forwarded to both ``conv_block`` calls.
    """

    def __init__(self, in_ch, mid_ch, out_ch, kernel_size=3):
        super().__init__()
        # Attribute names and registration order are part of the state_dict
        # layout, so they stay ``conv1`` then ``conv2``.
        self.conv1 = conv_block(in_ch, mid_ch, kernel_size)
        self.conv2 = conv_block(mid_ch, out_ch, kernel_size)

    def forward(self, x):
        """Run ``x`` through ``conv1`` and then ``conv2``."""
        out = x
        for stage in (self.conv1, self.conv2):
            out = stage(out)
        return out

class UNet_up(nn.Module):
    """
    Coarse-to-fine multi-scale refinement network built from ``Bottleneck`` stages.

    Loosely inspired by UNet (paper: https://arxiv.org/abs/1505.04597).

    The input is average-pooled to 1/16, 1/8, 1/4 and 1/2 resolution. Processing
    starts at 1/16. Each stage runs a ``Bottleneck``, bilinearly resizes the result
    to the next finer pooled resolution, and concatenates it with the pooled input
    at that resolution. The final stage works on the full-resolution input.

    Channel counts are fixed by the layer definitions below:
      - ``Conv16`` is built for 4 input channels.
      - Each later stage takes 68 = 64 (previous stage) + 4 (pooled input) channels.
      - ``Conv1`` produces 5 output channels.

    NOTE(review): ``in_ch`` and ``out_ch`` are accepted but currently ignored.
    Their defaults (3 and 1) do not match the fixed 4-in / 5-out layers. They are
    kept only for call compatibility. Wiring them in would change the network
    existing callers build, so confirm the intent before doing so.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Retained for attribute compatibility. ``forward`` resizes to the exact
        # skip-connection size instead of a fixed x2, so spatial sizes that are
        # not multiples of 16 no longer break ``torch.cat``.
        # ``align_corners=False`` is PyTorch's default for bilinear mode. Passing
        # it explicitly only silences the UserWarning; results are unchanged.
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.Avgpool16 = nn.AvgPool2d(kernel_size=16, stride=16)
        self.Avgpool8 = nn.AvgPool2d(kernel_size=8, stride=8)
        self.Avgpool4 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.Avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.Conv16 = Bottleneck(4, 64, 64)
        self.Conv8 = Bottleneck(68, 64, 64)
        self.Conv4 = Bottleneck(68, 64, 64, kernel_size=5)
        self.Conv2 = Bottleneck(68, 64, 64, kernel_size=7)
        self.Conv1 = Bottleneck(68, 64, 5, kernel_size=9)

    @staticmethod
    def _upsample_to(feat, ref):
        """Bilinearly resize 4-D ``feat`` to the spatial (H, W) size of ``ref``.

        When ``ref`` is exactly twice the size of ``feat``, this matches
        ``nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)``.
        """
        return F.interpolate(feat, size=ref.shape[-2:], mode='bilinear',
                             align_corners=False)

    def forward(self, x):
        """Refine ``x`` from 1/16 resolution up to full resolution.

        Args:
            x: 4-D input batch with 4 channels (required by ``Conv16``).

        Returns:
            Output of ``Conv1``, with 5 channels. Its spatial size depends on
            ``conv_block`` padding (presumably equal to ``x``'s; TODO confirm).
        """
        # Pooled copies of the input, used both as stage inputs and as the
        # size targets for the upsampled stage outputs.
        p2 = self.Avgpool2(x)
        p4 = self.Avgpool4(x)
        p8 = self.Avgpool8(x)
        p16 = self.Avgpool16(x)

        x_16 = self._upsample_to(self.Conv16(p16), p8)
        x_8 = self._upsample_to(self.Conv8(torch.cat((x_16, p8), dim=1)), p4)
        x_4 = self._upsample_to(self.Conv4(torch.cat((x_8, p4), dim=1)), p2)
        x_2 = self._upsample_to(self.Conv2(torch.cat((x_4, p2), dim=1)), x)
        x_1 = self.Conv1(torch.cat((x_2, x), dim=1))

        return x_1
