from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv


class ArgmaxPooling(nn.Module):
    """Build a pyramid of one-hot label maps by repeated 2x2 majority pooling.

    Level 0 is the one-hot encoding of the input label map. Each following
    level halves the spatial size. It counts, per class, how many pixels of
    the previous level fall in each non-overlapping 2x2 window, then keeps
    the most frequent class. The counting is a stride-2 ``conv2d`` whose
    kernel sums only channel ``c`` into output channel ``c``. Ties go to the
    lowest class index, because ``torch.argmax`` returns the first maximal
    index. With an odd H or W, the last row/column is dropped at that level.

    Args:
        num_classes: number of label classes; the default 5 matches the
            original hard-coded value.
        num_levels: number of 2x2 pooling steps; the default 4 yields
            resolutions 1, 1/2, 1/4, 1/8 and 1/16 as before.
    """

    def __init__(self, num_classes=5, num_levels=4):
        super().__init__()
        self.num_levels = num_levels

        # Weight of shape (C, C, 2, 2) with ones only on the diagonal
        # (out_channel == in_channel). Each output channel therefore counts
        # the pixels of its own class in a 2x2 window. It stays a frozen
        # Parameter (not a buffer) so state_dict keys and parameters() are
        # unchanged for existing checkpoints / optimizers.
        kernel = (
            torch.eye(num_classes, dtype=torch.float32)
            .reshape(num_classes, num_classes, 1, 1)
            .repeat(1, 1, 2, 2)
        )
        self.argmax_kernel = nn.parameter.Parameter(kernel, requires_grad=False)

    def forward(self, y):
        """Return ``[y_ap1, y_ap2, ..., y_ap{2**num_levels}]``.

        Args:
            y: integer label map of shape (B, H, W). ``F.one_hot`` requires an
                int64 tensor with values in ``[0, num_classes)``; negative
                "ignore" labels raise. It must be on the same device as the
                module. Every pooling step needs H, W >= 2 at that level,
                i.e. H, W >= 2 ** num_levels overall.

        Returns:
            List of ``num_levels + 1`` float32 one-hot tensors of shape
            (B, num_classes, H_k, W_k), with H_k = floor(H / 2**k).
        """
        # The class count comes from the kernel itself, so it always agrees
        # with the loaded weights.
        num_classes = self.argmax_kernel.shape[0]
        # Cast to float32 so the conv still works after module.half()/.double().
        # This is a no-op (no copy) in the normal float32 case.
        weight = self.argmax_kernel.float()

        level = F.one_hot(y, num_classes=num_classes).permute(0, 3, 1, 2).float()
        pyramid = [level]
        for _ in range(self.num_levels):
            # counts[b, c, i, j] = number of class-c pixels in window (i, j).
            counts = F.conv2d(level, weight, bias=None, stride=2)
            level = (
                F.one_hot(counts.argmax(dim=1), num_classes=num_classes)
                .permute(0, 3, 1, 2)
                .float()
            )
            pyramid.append(level)
        return pyramid
