from .CAMERAS_raw import CAMERAS

import torch
import torch.nn.functional as F
class Explainer:
    """Produce CAMERAS attribution maps for a classifier.

    The CAMERAS explainer is attached to a layer chosen by attribute name:
    ``layer4`` if the model has one, else ``layer3``, else ``features``.
    """

    def __init__(self, model, nclass=1000):
        """Wrap ``model`` with a CAMERAS explainer.

        Args:
            model: network to explain. It is switched to eval mode in place,
                so the caller's model object is affected.
            nclass: kept for interface compatibility; not used by this class.
        """
        self.model = model
        # Prefer 'layer4', then 'layer3'; otherwise fall back to 'features'.
        name = 'features'
        if hasattr(model, 'layer4'):
            name = 'layer4'
        elif hasattr(model, 'layer3'):
            name = 'layer3'
        self.explain = CAMERAS(model, targetLayerName=name)
        self.model.eval()

    def get_attribution_map(self, img, target=None):
        """Return ReLU-clipped CAMERAS attributions for every image in ``img``.

        Images are explained one at a time along dim 0. Each is passed to the
        explainer with a leading batch dim of 1.

        Args:
            img: batched input tensor.
            target: class index to explain.
                ``None``: use ``argmax`` over dim 1 of ``self.model(img)``.
                Scalar (Python int or 0-dim tensor): applied to every image.
                Sequence or 1-D tensor: entry ``i`` is used for image ``i``.

        Returns:
            Detached tensor: the per-image attributions, passed through ReLU
            and concatenated along dim 0.
        """
        if target is None:
            target = torch.argmax(self.model(img), 1)
        elif not torch.is_tensor(target):
            # Previously a plain int/list crashed on ``target.size()``; convert it
            # to a tensor on the input's device, like the argmax-derived target.
            target = torch.as_tensor(target, device=img.device)

        attributions = []
        for i, single in enumerate(img):
            # A 0-dim target is shared by all images; otherwise index per image.
            t = target if target.dim() == 0 else target[i]
            attr = self.explain.attribute(single.unsqueeze(0), t)
            attributions.append(F.relu(attr).detach())
        return torch.cat(attributions, 0)
