import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class Convolutional(nn.Module):
    """2-D convolution layer that can propagate a tensor ``R`` backwards.

    The forward pass is a plain ``nn.Conv2d``.  It additionally records the
    input and output shapes so that ``analyze`` can map a tensor shaped like
    the layer output back to a tensor shaped like the layer input by means
    of a transposed convolution with the same weights.

    Attributes set at runtime:
        X_shape: shape of the last input seen by ``forward`` (0 before the
            first forward pass).
        pre_activation_shape: shape of the last output produced by
            ``forward`` (0 before the first forward pass).
        biasXgrad: channel-summed product of ``R`` and the bias, stored by
            ``save_biasgrad``; stays ``None`` when the layer has no bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

        self.X_shape = 0
        self.pre_activation_shape = 0
        self.biasXgrad = None

        # initialize parameters: Xavier-uniform weights, zero bias
        nn.init.xavier_uniform_(self.conv.weight.data)
        if bias:
            self.conv.bias.data.fill_(0)

    def forward(self, x):
        """Apply the convolution and remember input/output shapes."""
        self.X_shape = x.shape

        # Call the module (not .forward) so that registered hooks still run.
        out = self.conv(x)
        self.pre_activation_shape = out.shape
        return out

    def analyze(self, method, R):
        """Propagate ``R`` from the layer output back to the layer input.

        Args:
            method: name of the analysis method; only the value ``'our+'``
                changes behaviour (see ``save_biasgrad``).
            R: tensor with as many elements as the last forward output.

        Returns:
            Tensor with the spatial size of the last forward input.
        """
        # if the previous layer was a dense layer, R needs to be reshaped
        # to the shape of the convolution output of the forward pass
        batch_size, _, height, width = self.pre_activation_shape
        target_shape = torch.Size([batch_size, self.conv.out_channels, height, width])
        if R.shape != target_shape:
            # reshape (unlike view) also handles non-contiguous tensors
            R = R.reshape(target_shape)
        return self._guided_backprop_backward(method, R)

    def save_biasgrad(self, method, R):
        """Store the channel-summed product of ``R`` and the bias.

        Nothing is stored when the layer has no bias.  For
        ``method == 'our+'`` and maps taller than 10 rows, the outermost
        rows and columns of the stored map are set to zero.

        Args:
            method: analysis method name (may be ``None``).
            R: tensor shaped like the layer output, ``(N, C_out, H, W)``.
        """
        bias = self.conv.bias
        if bias is None:
            return

        # Broadcast the bias along the channel axis: (1, C_out, 1, ..., 1).
        bias_size = [1] * R.dim()
        bias_size[1] = bias.size(0)
        contribution = (R * bias.view(bias_size)).sum(1, keepdim=True)

        if method == 'our+' and contribution.size(2) > 10:
            # Zero the one-pixel border of the map.
            mask = torch.ones_like(contribution)
            mask[:, :, 0, :] = 0
            mask[:, :, -1, :] = 0
            mask[:, :, :, 0] = 0
            mask[:, :, :, -1] = 0
            contribution = contribution * mask

        self.biasXgrad = contribution

    def _guided_backprop_backward(self, method, R):
        """Record the bias term for ``method`` and deconvolve ``R``."""
        self.save_biasgrad(method, R)
        return self.deconvolve(R, self.conv.weight)

    def _backprop_backward(self, R):
        """Record the (unmasked) bias term and deconvolve ``R``.

        ``save_biasgrad`` requires a method name; ``None`` is passed so that
        the border masking reserved for ``'our+'`` is never applied here.
        """
        self.save_biasgrad(None, R)
        return self.deconvolve(R, self.conv.weight)

    def deconvolve(self, y, weights):
        """Transposed convolution of ``y`` restoring the forward input size.

        Args:
            y: tensor shaped like the layer output.
            weights: convolution kernel, ``(C_out, C_in / groups, kH, kW)``.

        Returns:
            Tensor whose spatial size equals that of the last forward input.
        """
        # dimensions before convolution in forward pass
        # the deconvolved image has to have the same dimension
        _, _, org_height, org_width = self.X_shape

        # stride, padding and dilation from the forward convolution
        # (numeric padding is assumed; string padding modes are not handled)
        padding = self.conv.padding
        stride = self.conv.stride
        dilation = self.conv.dilation

        _, _, filter_height, filter_width = weights.shape

        # A dilated kernel covers dilation * (k - 1) + 1 input positions.
        span_height = dilation[0] * (filter_height - 1) + 1
        span_width = dilation[1] * (filter_width - 1) + 1

        # the transposed convolution yields the minimal size; the rows and
        # columns that the strided forward pass dropped are added back:
        # a = (i + 2p - k_eff) mod s
        output_padding = ((org_height + 2 * padding[0] - span_height) % stride[0],
                          (org_width + 2 * padding[1] - span_width) % stride[1])

        # perform actual deconvolution
        # this is basically a forward convolution with flipped (and permuted) filters/weights
        return F.conv_transpose2d(input=y, weight=weights, bias=None,
                                  padding=padding, stride=stride,
                                  groups=self.conv.groups, dilation=dilation,
                                  output_padding=output_padding)
