import torch
import torch.nn as nn
import numpy as np


class Dense(nn.Module):
    """Fully connected layer, optionally followed by ``BatchNorm1d``, with explicit backward passes.

    ``forward`` flattens every non-batch dimension, applies ``nn.Linear`` and,
    if enabled, ``nn.BatchNorm1d``.  ``analyze`` takes a signal ``R`` shaped
    like this layer's output, stores a bias-times-signal map in
    ``self.biasXgrad`` and propagates ``R`` back to the (flattened) input space
    through the weight matrix, with BatchNorm's running-statistics scaling
    folded in.

    Args:
        in_dim: Number of input features after flattening non-batch dims.
        out_dim: Number of output features.
        BN: If True, apply ``nn.BatchNorm1d(out_dim)`` after the linear map.
        bias: Whether the linear map has a learnable bias.
    """

    def __init__(self, in_dim, out_dim, BN=False, bias=True):
        super(Dense, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.batch_norm = nn.BatchNorm1d(out_dim) if BN else None
        # Populated by save_biasgrad(); stays None while the linear layer has no bias.
        self.biasXgrad = None

    def forward(self, x):
        """Flatten ``x`` to ``(batch, -1)`` and apply the linear (+ optional BN) map."""
        # reshape rather than view: view raises on non-contiguous inputs
        # (e.g. after a transpose/permute), reshape copies only when needed.
        x = x.reshape(x.shape[0], -1)
        out = self.linear(x)
        if self.batch_norm is not None:
            out = self.batch_norm(out)
        return out

    @staticmethod
    def _post_process(tensor, eps=1e-6):
        """Take ``|tensor|`` and min-max rescale each sample (dim 0) to roughly [0, 1].

        Normalisation is done on the flattened ``(batch, -1)`` view so the
        per-sample min/max broadcast correctly for inputs of any rank.
        ``eps`` guards against division by zero when a sample is constant.
        """
        flat = tensor.abs().reshape(tensor.size(0), -1)
        flat = flat - flat.min(1, keepdim=True)[0]
        flat = flat / (flat.max(1, keepdim=True)[0] + eps)
        return flat.reshape(tensor.shape)

    def save_biasgrad(self, method, R):
        """Store ``(R * bias)`` summed over dim 1 in ``self.biasXgrad``.

        Does nothing when the linear layer was built with ``bias=False``.

        Args:
            method: Attribution method name. For ``'fullgrad'`` the product is
                passed through :meth:`_post_process` (absolute value + per-sample
                min-max rescaling) before the sum.
            R: Signal whose dim 1 has ``out_dim`` entries (the bias is broadcast
                along that dim).
        """
        bias = self.linear.bias
        if bias is None:
            return
        # NOTE(review): only the linear bias is used; when BN is enabled its
        # shift/scale is not folded in here (unlike the weight in the backward
        # pass) — confirm this is intended.
        bias_shape = [1] * R.dim()
        bias_shape[1] = bias.size(0)
        contribution = R * bias.view(tuple(bias_shape)).expand_as(R)
        if method == 'fullgrad':
            contribution = self._post_process(contribution)
        self.biasXgrad = contribution.sum(1, keepdim=True)

    def analyze(self, method, R):
        """Record the bias map for ``R`` and return ``R`` propagated to the input space."""
        self.save_biasgrad(method, R)
        return self._guided_backprop_backward(R)

    def _effective_weight(self):
        """Return the linear weight with BatchNorm scaling folded in (if BN is enabled).

        Each output row ``i`` is multiplied by
        ``gamma_i / sqrt(running_var_i + eps)``, i.e. the scaling BN applies
        when it normalises with its running statistics.
        """
        weight = self.linear.weight
        if self.batch_norm is not None:
            weight = weight * self.batch_norm.weight.unsqueeze(1) / torch.sqrt(
                self.batch_norm.running_var.unsqueeze(1) + self.batch_norm.eps)
        return weight

    def _guided_backprop_backward(self, R):
        """Propagate ``R`` (shape ``(..., out_dim)``) back through the effective weight."""
        return torch.matmul(R, self._effective_weight())

    def _backprop_backward(self, R):
        """Propagate ``R`` (shape ``(..., out_dim)``) back through the effective weight."""
        return torch.matmul(R, self._effective_weight())
