import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .model import resnet,resnet_small,resnet_nobn

from .layers import Convolutional, Dense, MaxPool,AvgPool,Batch_Norm2d,AdaptiveAvgPool,ReLU,Batch_Norm1d
from .layers.resblock import BasicBlock,Bottleneck,Downsample

class ExplainableNet(nn.Module):
    def __init__(self, model=None,method='gbp', beta=None):
        super(ExplainableNet, self).__init__()

        # replace relus by differentiable counterpart for beta growth
        self.beta = beta

        self.layers = nn.ModuleList([])

        if model is not None:
            self.fill_layers(model)

        # remove activation function in last layer
        self.layers[-1].activation_fn = None

        self.X = 0
        self.output = 0
        self.R=None

    def fill_layers(self, model):

        for layer in model.features:
            new_layer = self.create_layer(layer)
            if new_layer == 0:
                continue
            self.layers.append(new_layer)
        new_layer=self.create_layer(model.avgpool)
        self.layers.append(new_layer)
        for layer in model.classifier:
            new_layer = self.create_layer(layer)
            if new_layer == 0:
                continue
            self.layers.append(new_layer)

    def create_layer(self, layer):
        def inherit_bn(new_bn, old_bn):
            new_bn.weight.data = old_bn.weight.data
            new_bn.bias.data = old_bn.bias.data
            new_bn.running_var = old_bn.running_var
            new_bn.running_mean = old_bn.running_mean
        if type(layer) == torch.nn.Conv2d:
            new_layer = Convolutional(in_channels=layer.in_channels,
                                      out_channels=layer.out_channels,
                                      kernel_size=layer.kernel_size,
                                      stride=layer.stride,
                                      padding=layer.padding,
                                      bias=True)
            new_layer.conv.weight.data = layer.weight.data
            new_layer.conv.bias.data = layer.bias.data

        elif type(layer) == nn.MaxPool2d:
            new_layer = MaxPool(kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding)
        elif type(layer) == nn.AvgPool2d:
            new_layer = AvgPool(kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding)
        elif type(layer) == nn.AdaptiveAvgPool2d:
            new_layer = AdaptiveAvgPool(layer.output_size)
        elif type(layer) == nn.Linear:
            new_layer = Dense(in_dim=layer.in_features, out_dim=layer.out_features)
            new_layer.linear.weight.data = layer.weight.data
            new_layer.linear.bias.data = layer.bias.data

        elif type(layer) == (nn.Dropout or nn.Dropout2d):
            new_layer = layer

        elif type(layer) == nn.ReLU or type(layer)==nn.Softplus:
            return ReLU(self.beta)
        elif type(layer) == nn.BatchNorm2d:
            new_layer = Batch_Norm2d(num_features=layer.num_features, eps=layer.eps, momentum=layer.momentum,
                                     affine=layer.affine, track_running_stats=layer.track_running_stats)
            inherit_bn(new_layer.batch_norm, layer)
        elif type(layer) == resnet.Bottleneck:
            planes = layer.conv1.weight.data.size(0)
            inplanes = layer.conv1.weight.data.size(1)
            if layer.downsample is None:
                downsample = None
            else:
                downsample = Downsample(inplanes=inplanes, planes=planes, expansion=resnet.Bottleneck.expansion,
                                        stride=layer.stride)
                downsample.conv.conv.weight.data = layer.downsample.__getattr__('0').weight.data
                inherit_bn(downsample.bn.batch_norm, layer.downsample.__getattr__('1'))

            new_layer = Bottleneck(inplanes=inplanes, planes=planes, stride=layer.stride, downsample=downsample,
                                   beta=self.beta)
            new_layer.conv1.conv.weight.data = layer.conv1.weight.data
            new_layer.conv2.conv.weight.data = layer.conv2.weight.data
            new_layer.conv3.conv.weight.data = layer.conv3.weight.data
            inherit_bn(new_layer.bn1.batch_norm, layer.bn1)
            inherit_bn(new_layer.bn2.batch_norm, layer.bn2)
            inherit_bn(new_layer.bn3.batch_norm, layer.bn3)
        else:
            print(layer)
            print('ERROR: unknown layer')
            return None

        return new_layer

    def change_beta(self, beta):
        self.beta_activation = beta
        for layer in self.layers:
            if hasattr(layer, "beta"):
                layer.beta = beta

    def forward(self, x):
        self.X = x
        for layer in self.layers:
            x = layer.forward(x)
        self.output=x
        #self.R = x

        return x


    def classify(self, x):
        outputs = self.forward(x)
        return F.softmax(outputs, dim=1), torch.max(outputs, 1)[1]

    def analyze(self, method='gbp', R=None, index=None, no_aggr=False):
        # self.eval()
        if R is None and self.R is not None:
            R = self.R
        elif index is not None:
            R = torch.eye(self.output.shape[1])[index].to(self.output.device)

        for layer in reversed(self.layers):
            if type(layer) == nn.Dropout or type(layer) == nn.Dropout2d:  # ignore Dropout layer
                continue
            R = layer.analyze(method, R)
        if method == 'gbp':
            return torch.abs(R)
        if method == 'our' or method=='our_no_input':
            R = torch.sum(F.relu(R), dim=1, keepdim=True)
            #R = F.relu(R * torch.sign(torch.abs(self.X)) / (self.X + 1e-8))
            #R = torch.mean(R, dim=1, keepdim=True)
            if no_aggr:
                # R=F.relu6(R)
                if method == 'our_no_input':
                    R = 0
                # R=standarize(R)
                R=[R]
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            #print(module)
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad),
                                                         size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R.append(gradient)
                            #else:
                            #    gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,
                            #                                     1)

                            # R += gradient.sum(1, keepdim=True)

                            # R+=standarize(gradient)
            else:
                #R=F.relu6(R)
                if method=='our_no_input':
                    R=0
                #R=standarize(R)
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            #print(module)
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad),
                                                         size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R += gradient
                            #else:
                            #    gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,
                            #                                     1)

                            # R += gradient.sum(1, keepdim=True)

                            #R+=standarize(gradient)
        if method == 'our+':
            # R = R * torch.sign(torch.abs(self.X)) / (self.X + 1e-8)
            R = torch.sum(F.relu(R), dim=1, keepdim=True)
            #R = F.relu(R)
            #print('right')
            if no_aggr:
                R = [R]
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module)
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad), size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R.append(gradient)
                            #else:
                            #    gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,1)
                                #continue

                            # R += gradient.sum(1, keepdim=True)

                            # R+=standarize(gradient)
            else:
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad), size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R += gradient
                            #else:
                            #    gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,1)
                                #continue
                            # R += gradient.sum(1, keepdim=True)

        if method == 'fullgrad':
            def _postProcess(input, eps=1e-6):
                # Absolute value

                input = abs(input)

                # Rescale operations to ensure gradients lie between 0 and 1

                flatin = input.view((input.size(0), -1))

                temp, _ = flatin.min(1, keepdim=True)

                input = input - temp.unsqueeze(1).unsqueeze(1)

                flatin = input.view((input.size(0), -1))

                temp, _ = flatin.max(1, keepdim=True)

                input = input / (temp.unsqueeze(1).unsqueeze(1) + eps)

                return input

            # R = _postProcess(R)
            if no_aggr:
                R = [R]
                names = ['X']
                for name, module in self.named_modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:

                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate(module.biasXgrad, size=(self.X.size(2), self.X.size(3)),
                                                         mode='bilinear',
                                                         align_corners=True)
                            else:
                                gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,
                                                                 1)

                            R.append(gradient)
            else:
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate(module.biasXgrad, size=(self.X.size(2), self.X.size(3)),
                                                         mode='bilinear',
                                                         align_corners=True)
                            else:
                                gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,
                                                                 1)

                            # R += gradient.sum(1, keepdim=True)
                            R += gradient
        return R

class ExplainableNet_ResNet(nn.Module):
    def __init__(self, model=None,method='gbp', beta=None,BN=True):
        super(ExplainableNet_ResNet, self).__init__()

        # replace relus by differentiable counterpart for beta growth
        self.beta = beta
        self.fc = None
        self.layers = nn.ModuleList([])
        self.BN=BN
        if model is not None:
            if BN:
                self.fill_layers(model)
            else:
                self.fill_layers_nobn(model)

        # remove activation function in last layer
        self.layers[-1].activation_fn = None

        self.X = 0
        self.output=0
        self.h=0
        self.R=None

    def fill_layers_nobn(self, model):
        """Convert a ResNet-style ``model`` that has no stem batch norm.

        Appended to ``self.layers`` in order: ``conv1``, an activation
        replacement, ``maxpool`` (if present), the blocks of ``layer1`` to
        ``layer3``, the blocks of ``layer4`` (if present) and ``avgpool``.
        The classifier (``model.fc`` if present, else ``model.linear``) is
        converted and stored on ``self.fc`` rather than in ``self.layers``.

        ``create_layer`` returns ``None`` (after printing a diagnostic) for a
        layer type it does not support. Such layers are skipped instead of
        being stored as ``None`` entries, which would only fail later when
        the entry is called in ``forward``.
        """
        def append_converted(layer):
            converted = self.create_layer(layer)
            # None marks an unsupported layer; 0 is kept as a legacy skip
            # marker for overrides that may still return it.
            if converted is None or converted == 0:
                return
            self.layers.append(converted)

        # Stem: convolution -> activation (-> optional max pool)
        append_converted(model.conv1)
        self.layers.append(ReLU(self.beta))
        if hasattr(model, 'maxpool'):
            append_converted(model.maxpool)

        # Residual stages; only the fourth one is optional.
        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            if stage_name == 'layer4' and not hasattr(model, stage_name):
                continue
            for block in getattr(model, stage_name):
                append_converted(block)

        append_converted(model.avgpool)

        head = model.fc if hasattr(model, 'fc') else model.linear
        self.fc = self.create_layer(head)

    def fill_layers(self, model):
        """Convert a ResNet-style ``model`` that uses batch normalisation.

        Appended to ``self.layers`` in order: ``conv1``, ``bn1``, an
        activation replacement, ``maxpool`` (if present), the blocks of
        ``layer1`` to ``layer3``, the blocks of ``layer4`` (if present) and
        ``avgpool``. The classifier (``model.fc`` if present, else
        ``model.linear``) is converted and stored on ``self.fc`` rather than
        in ``self.layers``.

        ``create_layer`` returns ``None`` (after printing a diagnostic) for a
        layer type it does not support. Such layers are skipped instead of
        being stored as ``None`` entries, which would only fail later when
        the entry is called in ``forward``.
        """
        def append_converted(layer):
            converted = self.create_layer(layer)
            # None marks an unsupported layer; 0 is kept as a legacy skip
            # marker for overrides that may still return it.
            if converted is None or converted == 0:
                return
            self.layers.append(converted)

        # Stem: convolution -> batch norm -> activation (-> optional max pool)
        append_converted(model.conv1)
        append_converted(model.bn1)
        self.layers.append(ReLU(self.beta))
        if hasattr(model, 'maxpool'):
            append_converted(model.maxpool)

        # Residual stages; only the fourth one is optional.
        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            if stage_name == 'layer4' and not hasattr(model, stage_name):
                continue
            for block in getattr(model, stage_name):
                append_converted(block)

        append_converted(model.avgpool)

        head = model.fc if hasattr(model, 'fc') else model.linear
        self.fc = self.create_layer(head)


    def create_layer(self, layer):
        def inherit_bn(new_bn,old_bn):
            new_bn.weight.data = old_bn.weight.data
            new_bn.bias.data = old_bn.bias.data
            new_bn.running_var = old_bn.running_var
            new_bn.running_mean = old_bn.running_mean
        if type(layer) == torch.nn.Conv2d:
            new_layer = Convolutional(in_channels=layer.in_channels,
                                      out_channels=layer.out_channels,
                                      kernel_size=layer.kernel_size,
                                      stride=layer.stride,
                                      padding=layer.padding,
                                     bias=layer.bias is not None)
            new_layer.conv.weight.data = layer.weight.data
            if layer.bias is  not None:
                new_layer.conv.bias.data = layer.bias.data

        elif type(layer) == nn.MaxPool2d:
            new_layer = MaxPool(kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding)
        elif type(layer) ==nn.AvgPool2d:
            new_layer=AvgPool(kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding)
        elif type(layer)==nn.AdaptiveAvgPool2d:
            new_layer=AdaptiveAvgPool(layer.output_size)
        elif type(layer) == nn.Linear:
            new_layer = Dense(in_dim=layer.in_features, out_dim=layer.out_features)
            new_layer.linear.weight.data = layer.weight.data
            new_layer.linear.bias.data = layer.bias.data

        elif type(layer) == (nn.Dropout or nn.Dropout2d):
            new_layer = layer

        elif type(layer) == nn.ReLU:
            return ReLU(self.beta)
        elif type(layer)==nn.BatchNorm2d:
            new_layer=Batch_Norm2d(num_features=layer.num_features,eps=layer.eps,momentum=layer.momentum,
                                   affine=layer.affine,track_running_stats=layer.track_running_stats)
            inherit_bn(new_layer.batch_norm,layer)
        elif type(layer)==resnet.Bottleneck or type(layer)==resnet_nobn.Bottleneck:
            planes=layer.conv1.weight.data.size(0)
            inplanes=layer.conv1.weight.data.size(1)
            if layer.downsample is None:
                downsample=None
            else:
                downsample=Downsample(inplanes=inplanes,planes=planes,expansion=resnet.Bottleneck.expansion,stride=layer.stride,BN=self.BN)
                downsample.conv.conv.weight.data=layer.downsample.__getattr__('0').weight.data
                if  layer.downsample.__getattr__('0').bias is not None:
                    downsample.conv.conv.bias.data = layer.downsample.__getattr__('0').bias.data
                if self.BN:
                    inherit_bn(downsample.bn.batch_norm,layer.downsample.__getattr__('1'))


            new_layer=Bottleneck(inplanes=inplanes,planes=planes,stride=layer.stride,downsample=downsample,beta=self.beta,BN=self.BN)
            new_layer.conv1.conv.weight.data=layer.conv1.weight.data
            new_layer.conv2.conv.weight.data = layer.conv2.weight.data
            new_layer.conv3.conv.weight.data = layer.conv3.weight.data
            if layer.conv1.bias is not None:
                new_layer.conv1.conv.bias.data = layer.conv1.bias.data
                new_layer.conv2.conv.bias.data = layer.conv2.bias.data
                new_layer.conv3.conv.bias.data = layer.conv3.bias.data
            if self.BN:
                inherit_bn(new_layer.bn1.batch_norm,layer.bn1)
                inherit_bn(new_layer.bn2.batch_norm, layer.bn2)
                inherit_bn(new_layer.bn3.batch_norm, layer.bn3)
        elif type(layer)==resnet.BasicBlock or type(layer)==resnet_nobn.BasicBlock:
            planes=layer.conv1.weight.data.size(0)
            inplanes=layer.conv1.weight.data.size(1)
            if layer.downsample is None:
                downsample=None
            else:
                downsample=Downsample(inplanes=inplanes,planes=planes,expansion=resnet.BasicBlock.expansion,stride=layer.stride,BN=self.BN)
                downsample.conv.conv.weight.data=layer.downsample.__getattr__('0').weight.data
                if  layer.downsample.__getattr__('0').bias is not None:
                    downsample.conv.conv.bias.data = layer.downsample.__getattr__('0').bias.data
                if self.BN:
                    inherit_bn(downsample.bn.batch_norm,layer.downsample.__getattr__('1'))

            new_layer=BasicBlock(inplanes=inplanes,planes=planes,stride=layer.stride,downsample=downsample,beta=self.beta,BN=self.BN)
            new_layer.conv1.conv.weight.data=layer.conv1.weight.data
            new_layer.conv2.conv.weight.data = layer.conv2.weight.data
            if layer.conv1.bias is not None:
                new_layer.conv1.conv.bias.data = layer.conv1.bias.data
                new_layer.conv2.conv.bias.data = layer.conv2.bias.data
            if self.BN:
                inherit_bn(new_layer.bn1.batch_norm,layer.bn1)
                inherit_bn(new_layer.bn2.batch_norm, layer.bn2)
        else:
            print(layer)
            print('ERROR: unknown layer')
            return None

        return new_layer

    def change_beta(self, beta):
        self.beta_activation = beta
        for layer in self.layers:
            if hasattr(layer, "beta"):
                layer.beta=beta

    def forward(self, x):
        self.X = x
        for layer in self.layers:
            x = layer.forward(x)
        self.h=x
        #print(self.fc)
        x=self.fc(x)
        return x


    def classify(self, x):
        outputs = self.forward(x)
        self.output=outputs
        return F.softmax(outputs, dim=1), torch.max(outputs, 1)[1]

    def analyze(self, method='gbp', R=None, index=None,no_aggr=False):
        #self.eval()
        if R is None:
            R = self.R
        if index is not None:
            R = torch.eye(self.fc.linear.weight.size(0))[index].to(self.X.device)
        if method=='midgrad'or method=='gbp_all' and index is None:
            R=torch.ones_like(self.h).to(self.X.device)
        else:
            R=self.fc.analyze(method,R)
        for layer in reversed(self.layers):
            if type(layer) == nn.Dropout or type(layer) == nn.Dropout2d:  # ignore Dropout layer
                continue

            R = layer.analyze(method, R)
        if method=='gbp':
            return torch.abs(R)
        if method == 'our' or method=='our_no_input':
            #R = R * torch.sign(torch.abs(self.X)) / (self.X + 1e-8)
            #R = torch.mean(R, dim=1, keepdim=True)
            R = torch.sum(F.relu(R), dim=1, keepdim=True)
            #R=R.abs()
            if no_aggr:
                R = [R]
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module)
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad), size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R.append(gradient)
                            #else:
                            #gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,1)


                            # R += gradient.sum(1, keepdim=True)

                            # R+=standarize(gradient)
            else:
                if method=='our_no_input':
                    R=0
                for module in self.modules():
                    if hasattr(module,'biasXgrad'):
                        if module.biasXgrad is not None:
                            #print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad), size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R += gradient.sum(1, keepdim=True)
                                #R += gradient
                            #else:
                                #gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1, 1)




        if method == 'our+':
            # R = R * torch.sign(torch.abs(self.X)) / (self.X + 1e-8)
            R = torch.sum(F.relu(R), dim=1, keepdim=True)
            #R = F.relu(R)
            #print('right')
            if no_aggr:
                R = [R]
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module)
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad), size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                R.append(gradient)
                            #else:
                            #    gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,1)
                            #continue

                            #R += gradient.sum(1, keepdim=True)

                            # R+=standarize(gradient)
            else:
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate((module.biasXgrad), size=(self.X.size(2), self.X.size(3)),
                                                         mode='bicubic',
                                                         align_corners=True)
                                #R += gradient
                                R += gradient.sum(1, keepdim=True)
                            #else:
                            #    gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,1)



        if method=='fullgrad':
            def _postProcess( input, eps=1e-6):
                # Absolute value

                input = abs(input)

                # Rescale operations to ensure gradients lie between 0 and 1

                flatin = input.view((input.size(0), -1))

                temp, _ = flatin.min(1, keepdim=True)

                input = input - temp.unsqueeze(1).unsqueeze(1)

                flatin = input.view((input.size(0), -1))

                temp, _ = flatin.max(1, keepdim=True)

                input = input / (temp.unsqueeze(1).unsqueeze(1) + eps)

                return input

            #R = _postProcess(R)
            if no_aggr:
                R = [R]
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:

                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate(module.biasXgrad, size=(self.X.size(2), self.X.size(3)),
                                                         mode='bilinear',
                                                         align_corners=True)
                            else:
                                gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,
                                                                 1)

                            R.append(gradient)
            else:
                for module in self.modules():
                    if hasattr(module, 'biasXgrad'):
                        if module.biasXgrad is not None:
                            # print(module.biasXgrad.dim())
                            if module.biasXgrad.dim() == 4:
                                gradient = F.interpolate(module.biasXgrad, size=(self.X.size(2), self.X.size(3)),
                                                         mode='bilinear',
                                                         align_corners=True)
                            else:
                                gradient = module.biasXgrad.view(module.biasXgrad.size(0), module.biasXgrad.size(1), 1,
                                                                 1)

                            # R += gradient.sum(1, keepdim=True)
                            R += gradient

        if method=='midgrad' or method=='gbp_all':
            R=[R]
            for module in self.modules():
                if hasattr(module, 'grad'):
                    if module.grad is not None:
                        # print(module.biasXgrad.dim())
                        R.append(module.grad)
                        module.grad=None
            #print(len(R))
        return R

class ExplainableNet_Simple(nn.Module):
    """Explainable wrapper around a model that exposes its layers as ``model.main``.

    Every layer found in ``model.main`` is converted by :meth:`create_layer`
    into a counterpart from the project's ``layers`` package (or kept as-is
    for Dropout / Flatten). The resulting layers are run sequentially in
    :meth:`forward` and visited in reverse order in :meth:`analyze`.
    """

    # Layer types that are stored unchanged and skipped during analyze().
    _PASSTHROUGH_TYPES = (nn.Dropout, nn.Dropout2d, nn.Flatten)

    def __init__(self, model=None,method='gbp', beta=None):
        """Build the wrapper.

        Args:
            model: optional model exposing an iterable ``main`` of layers.
            method: accepted for interface compatibility; not read here.
            beta: value handed to the ``ReLU`` / ``Bottleneck`` replacements.
        """
        super(ExplainableNet_Simple, self).__init__()

        # replace relus by differentiable counterpart for beta growth
        self.beta = beta

        self.layers = nn.ModuleList([])

        if model is not None:
            self.fill_layers(model)

        # remove activation function in last layer
        # (guarded: without a model there is no last layer to index)
        if len(self.layers) > 0:
            self.layers[-1].activation_fn = None

        self.X = 0
        self.output = 0
        # analyze() falls back to self.R when no R is passed, so the
        # attribute has to exist.
        self.R = None

    def fill_layers(self, model):
        """Convert every layer of ``model.main`` and append it to ``self.layers``.

        Layers that :meth:`create_layer` cannot convert (it returns ``None``)
        are skipped so that no ``None`` entry ends up in the module list.
        """
        for layer in model.main:
            new_layer = self.create_layer(layer)
            # None is what create_layer returns for unknown layers; 0 is kept
            # as a legacy "skip" sentinel.
            if new_layer is None or new_layer == 0:
                continue
            self.layers.append(new_layer)

    @staticmethod
    def _inherit_bn(new_bn, old_bn):
        """Copy affine parameters and running statistics from old_bn to new_bn."""
        # weight / bias are None when the layer was built with affine=False
        if old_bn.weight is not None and new_bn.weight is not None:
            new_bn.weight.data = old_bn.weight.data
        if old_bn.bias is not None and new_bn.bias is not None:
            new_bn.bias.data = old_bn.bias.data
        new_bn.running_var = old_bn.running_var
        new_bn.running_mean = old_bn.running_mean

    def create_layer(self, layer):
        """Return the explainable counterpart of ``layer``.

        Dispatch is on the exact type of ``layer``. Dropout, Dropout2d and
        Flatten are returned unchanged. For an unsupported type the layer is
        printed together with an error message and ``None`` is returned.
        """
        layer_type = type(layer)

        if layer_type in self._PASSTHROUGH_TYPES:
            return layer

        if layer_type is nn.ReLU or layer_type is nn.Softplus:
            return ReLU(self.beta)

        if layer_type is nn.Conv2d:
            # Pass a plain bool: truth-testing a multi-element bias tensor
            # raises, and the flag only says whether a bias exists.
            has_bias = layer.bias is not None
            new_layer = Convolutional(in_channels=layer.in_channels,
                                      out_channels=layer.out_channels,
                                      kernel_size=layer.kernel_size,
                                      stride=layer.stride,
                                      padding=layer.padding,
                                      bias=has_bias)
            new_layer.conv.weight.data = layer.weight.data
            if has_bias:
                new_layer.conv.bias.data = layer.bias.data

        elif layer_type is nn.MaxPool2d:
            new_layer = MaxPool(kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding)

        elif layer_type is nn.AvgPool2d:
            new_layer = AvgPool(kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding)

        elif layer_type is nn.AdaptiveAvgPool2d:
            new_layer = AdaptiveAvgPool(layer.output_size)

        elif layer_type is nn.Linear:
            if layer.bias is not None:
                new_layer = Dense(in_dim=layer.in_features, out_dim=layer.out_features)
                new_layer.linear.weight.data = layer.weight.data
                new_layer.linear.bias.data = layer.bias.data
            else:
                new_layer = Dense(in_dim=layer.in_features, out_dim=layer.out_features,bias=False)
                new_layer.linear.weight.data = layer.weight.data

        elif layer_type is nn.BatchNorm2d:
            new_layer = Batch_Norm2d(num_features=layer.num_features, eps=layer.eps, momentum=layer.momentum,
                                     affine=layer.affine, track_running_stats=layer.track_running_stats)
            self._inherit_bn(new_layer.batch_norm, layer)

        elif layer_type is nn.BatchNorm1d:
            new_layer = Batch_Norm1d(num_features=layer.num_features, eps=layer.eps, momentum=layer.momentum,
                                     affine=layer.affine, track_running_stats=layer.track_running_stats)
            self._inherit_bn(new_layer.batch_norm, layer)

        elif layer_type is resnet.Bottleneck:
            # channel counts are read off the first convolution's weight
            planes = layer.conv1.weight.data.size(0)
            inplanes = layer.conv1.weight.data.size(1)
            if layer.downsample is None:
                downsample = None
            else:
                downsample = Downsample(inplanes=inplanes, planes=planes, expansion=resnet.Bottleneck.expansion,
                                        stride=layer.stride)
                # children '0' and '1' of the shortcut are its conv and its bn
                downsample.conv.conv.weight.data = getattr(layer.downsample, '0').weight.data
                self._inherit_bn(downsample.bn.batch_norm, getattr(layer.downsample, '1'))

            new_layer = Bottleneck(inplanes=inplanes, planes=planes, stride=layer.stride, downsample=downsample,
                                   beta=self.beta)
            new_layer.conv1.conv.weight.data = layer.conv1.weight.data
            new_layer.conv2.conv.weight.data = layer.conv2.weight.data
            new_layer.conv3.conv.weight.data = layer.conv3.weight.data
            self._inherit_bn(new_layer.bn1.batch_norm, layer.bn1)
            self._inherit_bn(new_layer.bn2.batch_norm, layer.bn2)
            self._inherit_bn(new_layer.bn3.batch_norm, layer.bn3)

        else:
            print(layer)
            print('ERROR: unknown layer')
            return None

        return new_layer

    def change_beta(self, beta):
        """Set ``beta`` on every layer in ``self.layers`` that has a ``beta`` attribute."""
        self.beta_activation = beta
        # keep the value used by create_layer in sync as well
        self.beta = beta
        for layer in self.layers:
            if hasattr(layer, "beta"):
                layer.beta = beta

    def forward(self, x):
        """Run ``x`` through all layers, remembering the input and the output."""
        self.X = x
        for layer in self.layers:
            x = layer.forward(x)

        # analyze() reads the width and device of the output when an
        # ``index`` is given.
        self.output = x

        return x

    def classify(self, x):
        """Return ``(softmax over dim 1, argmax over dim 1)`` of the forward output."""
        outputs = self.forward(x)
        return F.softmax(outputs, dim=1), torch.max(outputs, 1)[1]

    def _bias_gradients(self, keep_spatial_dims):
        """Yield one tensor per sub-module whose ``biasXgrad`` is set.

        4-D values are bilinearly resized to the height/width of ``self.X``.
        Any other value is viewed as ``(size(0), size(1), 1, 1)`` when
        ``keep_spatial_dims`` is true and as ``(size(0), size(1))`` otherwise.
        """
        for module in self.modules():
            bias_grad = getattr(module, 'biasXgrad', None)
            if bias_grad is None:
                continue

            if bias_grad.dim() == 4:
                yield F.interpolate(bias_grad, size=(self.X.size(2), self.X.size(3)),
                                    mode='bilinear',
                                    align_corners=True)
            elif keep_spatial_dims:
                yield bias_grad.view(bias_grad.size(0), bias_grad.size(1), 1, 1)
            else:
                yield bias_grad.view(bias_grad.size(0), bias_grad.size(1))

    @staticmethod
    def _min_max_normalize(values, eps=1e-6):
        """Take absolute values and rescale each sample (dim 0) to ``[0, 1)``."""
        values = abs(values)

        # The per-sample statistics come out as (N, 1); reshape them to
        # (N, 1, ..., 1) so they broadcast along the batch axis for inputs of
        # any rank instead of along the trailing axes.
        stat_shape = (values.size(0),) + (1,) * (values.dim() - 1)

        lowest, _ = values.view((values.size(0), -1)).min(1, keepdim=True)
        values = values - lowest.view(stat_shape)

        highest, _ = values.view((values.size(0), -1)).max(1, keepdim=True)
        return values / (highest.view(stat_shape) + eps)

    def analyze(self, method='gbp', R=None, index=None, no_aggr=False):
        """Propagate ``R`` backwards through the layers and post-process it.

        Args:
            method: forwarded to every layer's ``analyze``; ``'our'`` and
                ``'fullgrad'`` additionally trigger the post-processing below.
            R: start value; ``self.R`` is used when it is ``None``.
            index: when given together with ``R``, ``R`` is replaced by the
                row(s) ``index`` of an identity matrix as wide as the last
                forward output.
            no_aggr: if true, return a list ``[R, *bias gradients]`` instead
                of adding the bias gradients into ``R``.

        Returns:
            A tensor, or a list of tensors when ``no_aggr`` is true and
            ``method`` is ``'our'`` or ``'fullgrad'``.
        """
        # self.eval()
        if R is None:
            R = self.R
        elif index is not None:
            R = torch.eye(self.output.shape[1])[index].to(self.output.device)

        for layer in reversed(self.layers):
            if type(layer) in self._PASSTHROUGH_TYPES:  # ignore Dropout / Flatten layers
                continue

            R = layer.analyze(method, R)

        if method == 'our':
            eps = 1e-3
            # mask is 0 where |X| <= eps and 1 elsewhere
            mask = torch.sign(F.relu(torch.abs(self.X) - eps))
            R = F.relu(R * mask / (self.X + 1e-8))
            if no_aggr:
                R = [R]
                R.extend(self._bias_gradients(keep_spatial_dims=True))
            else:
                R = torch.mean(R, dim=1, keepdim=True)
                for gradient in self._bias_gradients(keep_spatial_dims=True):
                    R += gradient
        elif method == 'fullgrad':
            R = self._min_max_normalize(R * self.X)
            if no_aggr:
                R = [R]
                R.extend(self._bias_gradients(keep_spatial_dims=False))
            else:
                for gradient in self._bias_gradients(keep_spatial_dims=False):
                    R += self._min_max_normalize(gradient)
        return R

