import torch.nn as nn
import torch
import torch.nn.functional as F

class MaxPool(nn.Module):
    """2-D max pooling that can route a signal ``R`` back to the argmax positions.

    ``forward`` caches the input size and the pooling indices; ``analyze`` then
    scatters a tensor shaped like the pooled output back onto the input grid via
    ``nn.MaxUnpool2d`` (non-max positions receive zero).
    """

    def __init__(self, kernel_size=2, stride=None, padding=0, dilation=1, ceil_mode=False):
        super(MaxPool, self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
                                 ceil_mode=ceil_mode, return_indices=True)
        # MaxUnpool2d has no dilation option; analyze() always passes an explicit
        # output_size, so the unpool geometry is pinned to the original input.
        self.unpool = nn.MaxUnpool2d(kernel_size=kernel_size, stride=stride, padding=padding)
        self.indices = None  # argmax indices from the last forward pass
        self.X_size = 0  # input size from the last forward pass

    def forward(self, x):
        """Pool ``x``, caching its size and the argmax indices for ``analyze``."""
        self.X_size = x.size()
        output, self.indices = self.pool(x)
        return output

    def analyze(self, method, R):
        """Map ``R`` from the pooled resolution back to the input resolution.

        Args:
            method: Unused here; accepted so all layers share one ``analyze`` signature.
            R: Tensor with as many elements as the last pooled output; it may be
                flattened (e.g. ``(N, C*H_out*W_out)``) or already 4-D.

        Returns:
            Tensor of the cached input size with each entry of ``R`` placed at
            the position that won the corresponding max, zeros elsewhere.
        """
        # The indices tensor has exactly the pooled output shape for any
        # kernel_size / stride / padding / ceil_mode, unlike a hard-coded H/2, W/2.
        R = R.reshape(self.indices.shape)
        return self.unpool(R, self.indices, self.X_size)


class AvgPool(nn.Module):
    """2-D average pooling that can redistribute a signal ``R`` back onto its input.

    ``forward`` caches the input; ``analyze`` follows the Z / S / C pattern
    (``S = R / Z``, ``C = pool^T(S)``, ``R_in = X * C``) so each input element
    receives a share of ``R`` proportional to its contribution to the pooled
    averages. Presumably this is an LRP-style relevance rule; confirm with callers.
    """

    def __init__(self, kernel_size=2, stride=None, padding=0, dilation=1, ceil_mode=False):
        super(AvgPool, self).__init__()
        # `dilation` is accepted for signature parity with MaxPool, but
        # nn.AvgPool2d has no dilation option, so it is ignored.
        self.pool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode)

    def forward(self, x):
        """Cache ``x`` for ``analyze`` and return its average-pooled output."""
        self.X = x
        return self.pool(x)

    def analyze(self, method, R):
        """Redistribute ``R`` from the pooled resolution onto the cached input.

        Args:
            method: Unused here; accepted so all layers share one ``analyze`` signature.
            R: Tensor with as many elements as the pooled output; it may be
                flattened or already 4-D.

        Returns:
            Tensor shaped like the cached input: ``X * pool^T(R / (pool(X) + 1e-9))``.
        """
        # The adjoint of average pooling is obtained via autograd. This is exact for
        # any kernel_size / stride / padding / ceil_mode, always works on NCHW
        # spatial dims, and stays on the input's device and dtype.
        with torch.enable_grad():  # analyze() may be called under torch.no_grad()
            x = self.X.detach().requires_grad_(True)
            z = self.pool(x) + 1e-9  # small stabilizer against division by zero
            s = R.reshape(z.shape) / z.detach()
            # create_graph keeps the result differentiable w.r.t. R when R needs grad.
            (c,) = torch.autograd.grad(z, x, grad_outputs=s, create_graph=s.requires_grad)
        return self.X * c
class AdaptiveAvgPool(nn.Module):
    """Adaptive 2-D average pooling with an approximate inverse for ``analyze``.

    ``analyze`` upsamples a tensor shaped like the pooled output back to the
    input's spatial size with bilinear interpolation, rescaled by
    ``(H_out * W_out) / (H_in * W_in)`` so the overall sum is roughly preserved.
    """

    def __init__(self, output_size):
        super(AdaptiveAvgPool, self).__init__()
        self.output_size = output_size
        self.pool = nn.AdaptiveAvgPool2d(output_size)
        self.X_size = 0  # input size from the last forward pass
        self.pooled_size = None  # realised (H_out, W_out) from the last forward pass

    def forward(self, x):
        """Pool ``x``, caching the input size and realised output size for ``analyze``."""
        self.X_size = x.size()
        output = self.pool(x)
        # `output_size` may be an int or contain None (meaning "keep input size"),
        # so record the actual spatial size instead of indexing output_size.
        self.pooled_size = tuple(output.shape[-2:])
        return output

    def analyze(self, method, R):
        """Map ``R`` from the pooled resolution back to the input resolution.

        Args:
            method: Unused here; accepted so all layers share one ``analyze`` signature.
            R: Tensor with ``N * C * H_out * W_out`` elements; it may be flattened
                (e.g. ``(N, C*H_out*W_out)``) or already 4-D.

        Returns:
            Tensor of shape ``(N, C, H_in, W_in)``: bilinearly upsampled ``R``
            (``align_corners=True``) scaled by ``(H_out*W_out)/(H_in*W_in)``.
        """
        batch, channels, height, width = self.X_size
        out_h, out_w = self.pooled_size
        # The channel count comes from the cached input, so both flattened and 4-D R work.
        R = R.reshape(batch, channels, out_h, out_w)
        R = F.interpolate(R, size=(height, width), mode='bilinear', align_corners=True)
        return R * (out_h * out_w) / (height * width)