from attribution_methods import Guided_BackProp,Grad_CAM

import torch
import torch.nn.functional as F
import numpy as np

class Explainer():
    """Attribution explainer that multiplies two attribution maps together.

    One map comes from ``Grad_CAM.Explainer`` and the other from
    ``Guided_BackProp.Explainer``; both wrap the same model. The second map
    is clipped and rescaled with ``cut_most`` before the element-wise
    product is taken.
    """

    def __init__(self, model, nclass=1000):
        """Store the model and build the two underlying explainers.

        Args:
            model: The network to explain; it is called directly as
                ``model(img)`` when no target is supplied.
            nclass: Accepted for interface compatibility; it is not read
                anywhere in this class.
        """
        self.model = model
        self.explain_0 = Grad_CAM.Explainer(model)
        self.explain_1 = Guided_BackProp.Explainer(model)

    def get_attribution_map(self, img, target=None):
        """Return the product of the two attribution maps for ``img``.

        Args:
            img: Input passed unchanged to the model and to both explainers.
            target: Class index (or indices) to explain. When ``None``, the
                argmax of the model output along dim 1 is used.

        Returns:
            The element-wise product of the first explainer's map and the
            clipped second explainer's map.
        """
        if target is None:
            # Fall back to the model's own top prediction.
            predictions = self.model(img)
            target = torch.argmax(predictions, 1)

        coarse_map = self.explain_0.get_attribution_map(img, target=target)
        fine_map = cut_most(
            self.explain_1.get_attribution_map(img, target=target), 1
        )

        # Only dim 2 is compared; when the first map is smaller there, it is
        # bilinearly resized to the second map's dim-2/dim-3 extent.
        if fine_map.size(2) > coarse_map.size(2):
            out_size = (fine_map.size(2), fine_map.size(3))
            coarse_map = F.interpolate(coarse_map, out_size, mode='bilinear')

        return coarse_map * fine_map

def cut_most(heatmaps, p=1):
    """Clip each sample's largest values at a percentile and rescale by it.

    For every index along dim 0, the ``(100 - p)``-th percentile of all
    remaining elements is computed. Values above that threshold are replaced
    by the threshold, and the result is divided by the threshold, so the
    clipped values become exactly 1. Values at or below the threshold are
    only rescaled (there is no lower clipping).

    Args:
        heatmaps: Tensor with at least one dimension; dim 0 is treated as
            the per-sample axis. It is converted to float32.
        p: Percentage of the largest values to clip (default 1, i.e. the
            threshold is the 99th percentile).

    Returns:
        A float32 tensor with the same shape and device as ``heatmaps``.

    Note:
        If a sample's threshold is 0 the division yields ``inf``/``nan``
        for that sample; this is not guarded against.
    """
    heatmaps = heatmaps.float()
    batch = heatmaps.size(0)

    # The percentile is computed in NumPy; detach so tensors that require
    # grad can be converted, and reshape so non-contiguous inputs work.
    flat = heatmaps.detach().reshape(batch, -1).cpu().numpy()
    thresholds = np.percentile(flat, 100 - p, axis=1)

    # One threshold per sample, broadcastable over all trailing dims and
    # placed on the input's own device rather than forcing CUDA.
    broadcast_shape = (batch,) + (1,) * (heatmaps.dim() - 1)
    thresholds = torch.as_tensor(
        thresholds, dtype=heatmaps.dtype, device=heatmaps.device
    ).view(broadcast_shape)

    clipped = torch.where(torch.le(heatmaps, thresholds), heatmaps, thresholds)
    return clipped / thresholds