import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class Batch_Norm2d(nn.Module):
    """``nn.BatchNorm2d`` wrapper that can map a signal back through the layer.

    ``forward`` is a plain batch-norm call. ``analyze`` takes a tensor ``R``
    (presumably a relevance/gradient map shaped like this layer's output --
    TODO confirm against callers). It does two things:

    * It scales ``R`` per channel by ``weight / sqrt(running_var + eps)``.
      This is the slope of the eval-mode affine map.
    * It records the channel-summed product of ``R`` with the layer's
      effective bias in ``self.biasXgrad``.

    Both quantities come from the *running* statistics. They therefore
    describe the layer as it behaves in eval mode.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super(Batch_Norm2d, self).__init__()
        self.batch_norm = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
        # Shape of the most recent forward input (0 until forward has run).
        self.X_size = 0
        # Channel-summed R * effective_bias from the last save_biasgrad() call.
        self.biasXgrad = None

    def forward(self, x):
        """Record the input shape and apply batch normalisation to ``x``."""
        self.X_size = x.size()
        return self.batch_norm(x)

    @staticmethod
    def _channel_view(t, ndim):
        """Reshape per-channel vector ``t`` so it broadcasts along dim 1 of an ``ndim``-D tensor."""
        shape = [1] * ndim
        shape[1] = -1
        return t.view(shape)

    def save_biasgrad(self, method, R):
        """Store ``sum_c R[:, c] * b_eff[c]`` (keepdim) in ``self.biasXgrad``.

        The effective bias is
        ``b_eff = bias - running_mean * weight / sqrt(running_var + eps)``,
        the constant term of the eval-mode affine map. Nothing is stored when
        the layer has no bias (``affine=False``).

        When ``method == 'our+'`` and the result's height (dim 2) exceeds 10,
        the one-pixel border of the stored map is zeroed.
        """
        bn = self.batch_norm
        if bn.bias is None:
            return
        bias = -(bn.running_mean * bn.weight / torch.sqrt(bn.running_var + bn.eps)) + bn.bias
        temp = (R * self._channel_view(bias, R.dim()).expand_as(R)).sum(1, keepdim=True)
        if method == 'our+' and temp.size(2) > 10:
            # Zero the outermost rows and columns. Only the height is checked.
            mask = torch.ones_like(temp)
            mask[:, :, 0, :] = 0
            mask[:, :, -1, :] = 0
            mask[:, :, :, 0] = 0
            mask[:, :, :, -1] = 0
            temp = temp * mask
        self.biasXgrad = temp

    def analyze(self, method, R):
        """Propagate ``R`` back through the layer and return the result."""
        return self._guided_backprop_backward(method, R)

    def _guided_backprop_backward(self, method, R):
        """Record the bias term, then scale ``R`` by ``weight / sqrt(running_var + eps)``.

        With ``affine=False`` there is no weight tensor. The scale then reduces
        to ``1 / sqrt(running_var + eps)``, instead of failing on ``None``.
        Per-channel tensors are reshaped against ``R`` itself, so this does not
        depend on ``forward`` having been called first.
        """
        self.save_biasgrad(method, R)
        bn = self.batch_norm
        newR = R
        if bn.weight is not None:
            newR = newR * self._channel_view(bn.weight, R.dim())
        return newR / torch.sqrt(self._channel_view(bn.running_var, R.dim()) + bn.eps)

class Batch_Norm1d(nn.Module):
    """``nn.BatchNorm1d`` wrapper that can map a signal back through the layer.

    ``forward`` is a plain batch-norm call. ``analyze`` takes a tensor ``R``
    (presumably a relevance/gradient tensor shaped like this layer's output,
    ``(N, C)`` or ``(N, C, L)`` -- TODO confirm against callers). It scales
    ``R`` per channel by ``weight / sqrt(running_var + eps)``, the slope of
    the eval-mode affine map built from the running statistics.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super(Batch_Norm1d, self).__init__()
        self.batch_norm = nn.BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
        # Shape of the most recent forward input (0 until forward has run).
        self.X_size = 0
        # Channel-summed R * effective_bias; only set by an explicit save_biasgrad() call.
        self.biasXgrad = None

    def forward(self, x):
        """Record the input shape and apply batch normalisation to ``x``."""
        self.X_size = x.size()
        return self.batch_norm(x)

    @staticmethod
    def _channel_view(t, ndim):
        """Reshape per-channel vector ``t`` so it broadcasts along dim 1 of an ``ndim``-D tensor."""
        shape = [1] * ndim
        shape[1] = -1
        return t.view(shape)

    def save_biasgrad(self, R):
        """Store ``sum_c R[:, c] * b_eff[c]`` (keepdim) in ``self.biasXgrad``.

        The effective bias is
        ``b_eff = bias - running_mean * weight / sqrt(running_var + eps)``.
        Nothing is stored when the layer has no bias (``affine=False``).
        """
        bn = self.batch_norm
        if bn.bias is None:
            return
        bias = -(bn.running_mean * bn.weight / torch.sqrt(bn.running_var + bn.eps)) + bn.bias
        self.biasXgrad = (R * self._channel_view(bias, R.dim()).expand_as(R)).sum(1, keepdim=True)

    def analyze(self, method, R):
        """Propagate ``R`` back through the layer.

        ``method`` is accepted so the signature matches the 2-D wrapper. It is
        not used here.
        """
        return self._guided_backprop_backward(R)

    def _guided_backprop_backward(self, R):
        """Scale ``R`` by ``weight / sqrt(running_var + eps)``, broadcast over dim 1.

        The input may be 2-D ``(N, C)`` or 3-D ``(N, C, L)``. With
        ``affine=False`` the weight is treated as 1.
        """
        # NOTE: unlike the 2-D wrapper, save_biasgrad is not invoked here (the
        # call was disabled in the original), so biasXgrad is left untouched.
        bn = self.batch_norm
        newR = R
        if bn.weight is not None:
            newR = newR * self._channel_view(bn.weight, R.dim())
        return newR / torch.sqrt(self._channel_view(bn.running_var, R.dim()) + bn.eps)