from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv


class UNet_patch(nn.Module):
    """Per-pixel predictor that encodes a patch around every (pooled) pixel.

    The input is first downsampled 2x with average pooling. For every pixel of
    the pooled map, a ``patch_size x patch_size`` neighbourhood is extracted
    (reflection padding at the borders). Each patch goes through a small
    convolutional encoder. Before each of the first four stages, a 2-channel
    coordinate map is appended and a 2x max-pool follows the stage. The encoder
    output is reduced to ``out_ch`` values per pixel by a 1x1 convolution. The
    per-pixel map is finally upsampled 2x bilinearly.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of output channels per pixel.
        patch_size: side length of the neighbourhood extracted around each
            pooled pixel. The four 2x max-pooling stages must reduce it to 1x1.
            That holds for 16 <= patch_size <= 31, assuming ``conv_block``
            preserves spatial size (NOTE(review): confirm in ``shared``).
            ``patch_size // 2`` must also be smaller than the pooled height and
            width, as required by ``nn.ReflectionPad2d``.
    """

    def __init__(self, in_ch=4, out_ch=5, patch_size=16):
        super().__init__()

        n1 = 16
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.ps = patch_size

        # Pad so that a stride-1 Unfold yields exactly one patch per pixel.
        self.refpad_unfold = nn.Sequential(
            nn.ReflectionPad2d(padding=self._unfold_padding(patch_size)),
            nn.Unfold(kernel_size=patch_size, stride=1, padding=0),
        )

        self.Avgpooling = nn.AvgPool2d(kernel_size=2, stride=2)
        self.Maxpooling = nn.MaxPool2d(kernel_size=2, stride=2)
        # "+ 2" accounts for the coordinate channels appended by pos_encoding.
        self.Conv1 = conv_block(in_ch + 2, filters[0])
        self.Conv2 = conv_block(filters[0] + 2, filters[1])
        self.Conv3 = conv_block(filters[1] + 2, filters[2])
        self.Conv4 = conv_block(filters[2] + 2, filters[3])
        # No coordinate channels are appended before the last stage.
        self.Conv5 = conv_block(filters[3], filters[4])

        self.pred = nn.Conv2d(filters[4], out_ch, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    @staticmethod
    def _unfold_padding(patch_size):
        """Return the (left, right, top, bottom) reflection padding for unfolding.

        The padding totals ``patch_size - 1`` per axis. A stride-1 ``Unfold``
        with ``kernel_size=patch_size`` therefore produces exactly ``H * W``
        patches, for odd as well as even patch sizes. Each pixel sits at index
        ``patch_size // 2`` inside its own patch.
        """
        before = patch_size // 2
        after = patch_size - 1 - before
        return (before, after, before, after)

    def forward(self, x):
        """Predict a per-pixel map.

        Args:
            x: tensor of shape (B, in_ch, H, W).

        Returns:
            Tensor of shape (B, out_ch, 2 * (H // 2), 2 * (W // 2)).

        Raises:
            ValueError: if the encoder does not reduce each patch to 1x1
                (``patch_size`` too large).
        """
        x = self.Avgpooling(x)
        B, C, H, W = x.shape
        # (B, C*ps*ps, H*W) -> (B*H*W, C, ps, ps): one patch per pooled pixel.
        x_p = self.refpad_unfold(x).permute(0, 2, 1).reshape(B * H * W, C, self.ps, self.ps)

        # Each stage: append coordinates, convolve, halve the patch size.
        for conv in (self.Conv1, self.Conv2, self.Conv3, self.Conv4):
            x_p = self.Maxpooling(conv(self.pos_encoding(x_p)))

        x_p = self.Conv5(x_p)
        x_p = self.pred(x_p)  # (B*H*W, out_ch, h, w); h == w == 1 expected
        if x_p.shape[-2:] != (1, 1):
            # Otherwise the reshape below would silently fold spatial positions
            # into the channel axis.
            raise ValueError(
                f"patch_size={self.ps} leaves a {tuple(x_p.shape[-2:])} map after "
                "the encoder; it must reduce each patch to 1x1"
            )
        x_p = x_p.reshape(B, H, W, -1).permute(0, 3, 1, 2)  # (B, out_ch, H, W)
        return self.up(x_p)

    @staticmethod
    def pos_encoding(x):
        """Append two normalized coordinate channels to ``x``.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C + 2, H, W).
            - Channel ``C`` varies along the width; channel ``C + 1`` varies
              along the height.
            - Along an axis of length ``n``, the values are ``(i - n/2) / (n/2)``
              for ``i = 0..n-1``, i.e. from -1 up to ``1 - 2/n``.
            - The coordinates are created directly on ``x``'s device and cast
              to ``x``'s dtype.
        """
        B, _, H, W = x.shape

        def _axis(n):
            r = torch.arange(n, device=x.device, dtype=torch.float32)
            return ((r - n / 2) / (n / 2)).to(x.dtype)

        x_coord = _axis(W).view(1, 1, 1, W).expand(B, 1, H, W)
        y_coord = _axis(H).view(1, 1, H, 1).expand(B, 1, H, W)
        return torch.cat((x, x_coord, y_coord), dim=1)