#from captum.attr import GuidedBackprop
import torch
from .our import networks as nw
# Value forwarded as ``beta=`` to the ``nw`` explainable-net wrappers built by
# Explainer; its meaning is defined in ``.our.networks``.
Beta = 50
class Explainer:
    """FullGrad attribution explainer built on the project's explainable-net wrappers.

    The model given to the constructor is deep-copied and wrapped in either
    ``nw.ExplainableNet_ResNet`` (when it has a ``layer1`` attribute) or
    ``nw.ExplainableNet``, both configured with ``method='fullgrad'`` and
    ``beta=Beta``.
    """

    def __init__(self, model, nclass=1000):
        """Wrap a private copy of *model* for FullGrad analysis.

        Args:
            model: the network to explain. It is deep-copied first, so the
                wrapper never receives the caller's object.
            nclass: stored as ``self.nclass``; not otherwise used by this
                class (the class count is taken from the model output width).
        """
        import copy

        model = copy.deepcopy(model)
        # Heuristic: a ``layer1`` attribute is taken to mean a ResNet-style
        # model, which gets the dedicated ResNet wrapper.
        if hasattr(model, 'layer1'):
            self.model = nw.ExplainableNet_ResNet(model, method='fullgrad', beta=Beta)
        else:
            self.model = nw.ExplainableNet(model, method='fullgrad', beta=Beta)
        self.model.eval()
        self.nclass = nclass

    def get_attribution_map(self, img, target=None):
        """Return FullGrad attributions of *img* with respect to *target*.

        Args:
            img: input forwarded to the wrapped model.
            target: class index/indices to explain (int, sequence, or index
                tensor). If ``None``, the argmax of the model output along
                dim 1 is used for each sample.

        Returns:
            Whatever ``self.model.analyze("fullgrad", one_hot)`` returns, where
            ``one_hot`` is the row(s) of an identity matrix selected by *target*.
        """
        output = self.model(img)
        if target is None:
            # Predicted class per sample; assumes output is (batch, classes) — TODO confirm.
            target = output.argmax(dim=1, keepdim=False)
        # Build the identity on the output's device: indexing a CPU tensor with
        # CUDA indices (which argmax yields for a GPU model) raises RuntimeError,
        # while CPU index tensors and Python ints work on any device.
        one_hot = torch.eye(output.shape[1], device=output.device)[target].to(img.device)
        return self.model.analyze("fullgrad", one_hot)
