import ADGT
import torch
import numpy as np
import cv2
import os
from PIL import Image
from torchvision import transforms
from model.inception import inception_v3
from attribution_methods.our.model.resnet import resnet50
from attribution_methods.our.model.vgg import vgg16_bn
from attribution_methods.our.model.canet import canet11_imagenet, canet11_imagenet_no_bn
import torchvision
import torch.nn as nn
import torch.nn.functional as F

# Expose only GPU index 1 to this process.  Set before any CUDA call is made
# so that the restriction is already in place when CUDA initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Backbone handed to prepare_model().  Recognised there: 'inception_v3',
# 'resnet50' and 'vgg16'; any other value (e.g. 'canet11') builds the CANet
# model.
MODEL = 'vgg16'

# Attribution methods to render; visualize() implements these four names.
# Other names that have been used with this list:
#   'FullGrad', 'GradCAM', 'DeepLIFT', 'our_fullgrad', 'Our_GBP', 'our_method'
#   ['IntegratedGradients', 'InputXGradient', 'Saliency', 'RectGrad', 'SmoothGrad']
method = ['Saliency', 'GBP', 'RectGrad', 'our']

torch.set_num_threads(4)
# layer_names=['fc','Mixed_7c','Mixed_7b','Mixed_7a','Mixed_6e','Mixed_6d','Mixed_6c','Mixed_6b','Mixed_6a',
#           'Mixed_5d','Mixed_5c','Mixed_5b','Conv2d_4a_3x3','Conv2d_3b_1x1','Conv2d_2b_3x3','Conv2d_2a_3x3',
#           'Conv2d_1a_3x3']
# layer_names=['classifier','features.28','features.14','features.0',]
#layer_names = ['fc', 'layer4', 'layer3', 'layer2', 'layer1', 'conv1']
use_cuda = True
DATASET_NAME = 'ImageNet'
ROOT = 'result'
AUG = True
# Create the output directory without a separate existence check, so a
# concurrent run creating it at the same moment cannot make this fail.
os.makedirs(ROOT, exist_ok=True)

def gbp(img, net, y=None):
    """Compute Guided Backpropagation attributions for ``img``.

    Every ``nn.ReLU`` in ``net`` is temporarily hooked.  On the backward pass
    the gradient leaving each ReLU is replaced by ``relu(grad_out)`` masked to
    the positions where that ReLU's forward output was positive.

    Args:
        img: Input batch.  ``requires_grad`` is switched on in place.
        net: Model called as ``net(img, feature=True)``; it must return a
            pair ``(out, feature)``.
        y: Optional class indices, one per sample.  Any shape holding one
            index per sample (``(N,)`` or ``(N, 1)``) or a single index
            shared by the whole batch is accepted.  When ``None`` the sum of
            ``feature`` is explained instead of a class score.

    Returns:
        ``(grad, loss)``: the modified gradient of ``loss`` with respect to
        ``img`` and the scalar ``loss`` that was differentiated.
    """
    handles = []
    # Stack of boolean "ReLU fired here" masks, pushed in forward order and
    # popped during backward, which visits the layers in reverse.
    positive_masks = []

    def forward_hook_fn(module, input, output):
        # Snapshot the mask now instead of keeping (and later editing in
        # place) the activation tensor that the autograd graph also holds.
        positive_masks.append(output.detach() > 0)

    def backward_hook_fn(module, grad_in, grad_out):
        mask = positive_masks.pop()
        upstream = grad_out[0]
        # Keep only positive incoming gradient, and only where the unit fired.
        new_grad_in = F.relu(upstream) * mask.to(upstream.dtype)
        # ReLU has no parameters, so the input gradient is a 1-tuple.
        return (new_grad_in,)

    try:
        # NOTE(review): the legacy register_backward_hook is kept on purpose;
        # PyTorch documents that full backward hooks reject in-place modules,
        # and the hooked models may use nn.ReLU(inplace=True) -- confirm
        # before migrating.
        for m in net.modules():
            if isinstance(m, nn.ReLU):
                handles.append(m.register_backward_hook(backward_hook_fn))
                handles.append(m.register_forward_hook(forward_hook_fn))

        img.requires_grad_(True)
        out, feature = net(img, feature=True)
        if y is None:
            loss = torch.sum(feature)
        else:
            # Build the indices on the model's device (no hard-coded CUDA) and
            # flatten the labels so an (N, 1) tensor selects out[i, y[i]]
            # instead of broadcasting into an N x N gather.
            rows = torch.arange(img.size(0), device=out.device)
            cols = torch.as_tensor(y, device=out.device).reshape(-1)
            loss = torch.sum(out[rows, cols])
        grad = torch.autograd.grad(loss, img)[0]
    finally:
        # Always detach the hooks, even if the forward/backward pass raised,
        # so the model is not left permanently modified.
        for h in handles:
            h.remove()
    return grad, loss

def saliency(img, net, y=None):
    """Compute the plain input gradient (vanilla saliency) for ``img``.

    Args:
        img: Input batch.  ``requires_grad`` is switched on in place.
        net: Model called as ``net(img, feature=True)``; it must return a
            pair ``(out, feature)``.
        y: Optional class indices, one per sample.  Any shape holding one
            index per sample (``(N,)`` or ``(N, 1)``) or a single index
            shared by the whole batch is accepted.  When ``None`` the sum of
            ``feature`` is differentiated instead of a class score.

    Returns:
        ``(grad, loss)``: the gradient of ``loss`` with respect to ``img``
        and the scalar ``loss`` itself.
    """
    img.requires_grad_(True)
    out, feature = net(img, feature=True)
    if y is None:
        loss = torch.sum(feature)
    else:
        # Build the indices on the model's device (no hard-coded CUDA) and
        # flatten the labels so an (N, 1) tensor selects out[i, y[i]] instead
        # of broadcasting into an N x N gather.
        rows = torch.arange(img.size(0), device=out.device)
        cols = torch.as_tensor(y, device=out.device).reshape(-1)
        loss = torch.sum(out[rows, cols])
    grad = torch.autograd.grad(loss, img)[0]

    return grad, loss

def rectgrad(img, net, y=None, p=50):
    """Compute RectGrad-style attributions for ``img``.

    Every ``nn.ReLU`` in ``net`` is temporarily hooked.  On the backward pass
    a per-sample threshold is taken as the ``p``-th percentile of
    ``activation * grad_out``; entries of ``grad_out`` that are below that
    threshold are zeroed and the rest are passed on unchanged.  The gradient
    that reaches the input is finally multiplied by the input and clamped at
    zero.

    Args:
        img: Input batch.  ``requires_grad`` is switched on in place.
        net: Model called as ``net(img, feature=True)``; it must return a
            pair ``(out, feature)``.
        y: Optional class indices, one per sample.  Any shape holding one
            index per sample (``(N,)`` or ``(N, 1)``) or a single index
            shared by the whole batch is accepted.  When ``None`` the sum of
            ``feature`` is explained instead of a class score.
        p: Percentile (0-100) used for the per-sample threshold.

    Returns:
        ``(grad, loss)``: ``relu(modified_gradient * img)`` and the scalar
        ``loss`` that was differentiated.
    """
    handles = []
    # Stack of ReLU outputs, pushed in forward order and popped during
    # backward, which visits the layers in reverse.
    activations = []

    def forward_hook_fn(module, input, output):
        activations.append(output.detach())

    def backward_hook_fn(module, grad_in, grad_out):
        activation = activations.pop()
        upstream = grad_out[0]
        batch = upstream.size(0)

        # Per-sample percentile of activation * gradient.  numpy is used for
        # the percentile, so the scores make a round trip through host memory.
        scores = (activation * upstream).reshape(batch, -1)
        tau = np.percentile(scores.cpu().numpy(), p, axis=1)
        # Bring the threshold back on the gradient's own device/dtype and
        # shape it (N, 1, ..., 1) so it broadcasts for any tensor rank.
        tau = torch.as_tensor(tau, dtype=upstream.dtype, device=upstream.device)
        tau = tau.view(batch, *([1] * (upstream.dim() - 1)))

        # NOTE(review): the threshold is computed from activation * grad_out
        # but the comparison below is applied to grad_out alone, and the
        # activation does not otherwise gate the gradient.  This is kept
        # exactly as the code has always computed it -- confirm it is the
        # intended rule.
        new_grad_in = torch.where(
            torch.ge(upstream, tau), upstream, torch.zeros_like(upstream)
        )
        # ReLU has no parameters, so the input gradient is a 1-tuple.
        return (new_grad_in,)

    try:
        # NOTE(review): the legacy register_backward_hook is kept on purpose;
        # PyTorch documents that full backward hooks reject in-place modules,
        # and the hooked models may use nn.ReLU(inplace=True) -- confirm
        # before migrating.
        for m in net.modules():
            if isinstance(m, nn.ReLU):
                handles.append(m.register_backward_hook(backward_hook_fn))
                handles.append(m.register_forward_hook(forward_hook_fn))

        img.requires_grad_(True)
        out, feature = net(img, feature=True)
        if y is None:
            loss = torch.sum(feature)
        else:
            # Build the indices on the model's device (no hard-coded CUDA) and
            # flatten the labels so an (N, 1) tensor selects out[i, y[i]]
            # instead of broadcasting into an N x N gather.
            rows = torch.arange(img.size(0), device=out.device)
            cols = torch.as_tensor(y, device=out.device).reshape(-1)
            loss = torch.sum(out[rows, cols])
        grad = torch.autograd.grad(loss, img)[0]
        # Gradient * input, keeping only the positive part.
        grad = F.relu(grad * img.detach())
    finally:
        # Always detach the hooks, even if the forward/backward pass raised,
        # so the model is not left permanently modified.
        for h in handles:
            h.remove()
    return grad, loss

def our(img, net, y=None):
    """Compute attributions with the sign-agreement ReLU backward rule.

    Every ``nn.ReLU`` in ``net`` is temporarily hooked.  On the backward pass
    the gradient leaving each ReLU is replaced by ``grad_out`` at the
    positions where ``relu_input * grad_out`` is strictly positive and by
    zero everywhere else; the ReLU's own derivative is not applied.

    Args:
        img: Input batch.  ``requires_grad`` is switched on in place.
        net: Model called as ``net(img, feature=True)``; it must return a
            pair ``(out, feature)``.
        y: Optional class indices, one per sample.  Any shape holding one
            index per sample (``(N,)`` or ``(N, 1)``) or a single index
            shared by the whole batch is accepted.  When ``None`` the sum of
            ``feature`` is explained instead of a class score.

    Returns:
        ``(grad, loss)``: the modified gradient of ``loss`` with respect to
        ``img`` and the scalar ``loss`` that was differentiated.
    """
    handles = []
    # Stack of ReLU *inputs*, pushed in forward order and popped during
    # backward, which visits the layers in reverse.
    relu_inputs = []

    def forward_hook_fn(module, input, output):
        # NOTE(review): if the model builds its ReLUs with inplace=True, the
        # tensor seen here as the input has already been overwritten by the
        # activation -- confirm which of the two is intended.
        relu_inputs.append(input[0].detach())

    def backward_hook_fn(module, grad_in, grad_out):
        pre_activation = relu_inputs.pop()
        upstream = grad_out[0]
        # 1 where input and incoming gradient are both non-zero with the same
        # sign, 0 otherwise.
        agree = torch.sign(F.relu(pre_activation * upstream))
        new_grad_in = upstream * agree
        # ReLU has no parameters, so the input gradient is a 1-tuple.
        return (new_grad_in,)

    try:
        # NOTE(review): the legacy register_backward_hook is kept on purpose;
        # PyTorch documents that full backward hooks reject in-place modules,
        # and the hooked models may use nn.ReLU(inplace=True) -- confirm
        # before migrating.
        for m in net.modules():
            if isinstance(m, nn.ReLU):
                handles.append(m.register_backward_hook(backward_hook_fn))
                handles.append(m.register_forward_hook(forward_hook_fn))

        img.requires_grad_(True)
        out, feature = net(img, feature=True)
        if y is None:
            loss = torch.sum(feature)
        else:
            # Build the indices on the model's device (no hard-coded CUDA) and
            # flatten the labels so an (N, 1) tensor selects out[i, y[i]]
            # instead of broadcasting into an N x N gather.
            rows = torch.arange(img.size(0), device=out.device)
            cols = torch.as_tensor(y, device=out.device).reshape(-1)
            loss = torch.sum(out[rows, cols])
        grad = torch.autograd.grad(loss, img)[0]
    finally:
        # Always detach the hooks, even if the forward/backward pass raised,
        # so the model is not left permanently modified.
        for h in handles:
            h.remove()
    return grad, loss

def prepare_model(model_name=MODEL):
    """Build the network selected by ``model_name`` and print its layout.

    ``'inception_v3'`` is fetched through ``torch.hub``; ``'resnet50'`` and
    ``'vgg16'`` come from the project's own model modules with
    ``pretrained=True``.  Every other name falls back to
    ``canet11_imagenet(cane=True)``.  The model is moved to the GPU when the
    module-level ``use_cuda`` flag is set.

    Returns:
        The constructed model.
    """
    # Lazy constructors: only the selected one is ever invoked.
    builders = {
        'inception_v3': lambda: torch.hub.load(
            'pytorch/vision:v0.6.0', 'inception_v3', pretrained=True
        ),
        'resnet50': lambda: resnet50(pretrained=True),
        'vgg16': lambda: vgg16_bn(pretrained=True),
    }
    fallback = lambda: canet11_imagenet(cane=True)

    network = builders.get(model_name, fallback)()
    network = network.cuda() if use_cuda else network

    print(network)
    return network


def prepare_img(path):
    """Load every file in directory ``path`` into one normalised image batch.

    Each file is opened with PIL, converted to RGB, resized to 256, centre
    cropped to 224x224, converted to a tensor and normalised with the channel
    statistics below.  Files are visited in sorted name order so the layout
    of the returned batch is reproducible.

    Args:
        path: Directory containing the image files (and nothing else).

    Returns:
        A tensor stacking all preprocessed images along dimension 0, moved to
        the GPU when the module-level ``use_cuda`` flag is set.

    Raises:
        RuntimeError: If ``path`` contains no files.
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    img_size = 224

    # The pipeline does not depend on the individual file; build it once.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])

    tensors = []
    # os.listdir() returns entries in arbitrary order; sort them so that the
    # batch (and any slice taken from it) is the same on every run.
    for fn in sorted(os.listdir(path)):
        img_path = os.path.join(path, fn)
        # Close the file as soon as the pixels have been decoded instead of
        # leaving the handle open until garbage collection.
        with Image.open(img_path) as raw_image:
            tensors.append(preprocess(raw_image.convert('RGB')))

    if not tensors:
        raise RuntimeError('no images found in {!r}'.format(path))

    preprocessed_imgs = torch.stack(tensors, 0)
    if use_cuda:
        preprocessed_imgs = preprocessed_imgs.cuda()
    return preprocessed_imgs


# Number of images kept from the end of the loaded batch; also used as ``k``
# for ``topklabel`` below.
max_K = 1
net = prepare_model()
net.eval()
img = prepare_img(os.path.join('sanity_checks_saliency', 'data', 'demo_images'))
img = img[img.size(0) - max_K:]

# Only the predicted class indices are needed from this pass, so skip
# building an autograd graph that would otherwise stay alive through ``pred``.
with torch.no_grad():
    pred = net(img)
_, topklabel = torch.topk(pred, max_K)
# Top-1 class index per image; topk keeps the reduced dimension, so this has
# a trailing dimension of size 1.
_, target = torch.topk(pred, 1)
print(img.size(),target)



from utils.visualization import save_images


def visualize(net, img, method, target=None):
    """Save the input batch and one attribution image per requested method.

    Files are written below ``ROOT/visualization``: ``raw.png`` for the input
    and ``<name>.png`` for each entry of ``method``.

    Args:
        net: Model handed to the attribution functions.
        img: Input batch tensor.
        method: Iterable of method names; supported names are ``'Saliency'``,
            ``'GBP'``, ``'RectGrad'`` and ``'our'``.
        target: Optional class indices forwarded to the attribution functions.

    Raises:
        ValueError: If ``method`` contains a name that is not supported.
    """
    explainers = {
        'Saliency': saliency,
        'GBP': gbp,
        'RectGrad': rectgrad,
        'our': our,
    }
    method = list(method)
    # Reject unknown names up front.  Falling through silently would either
    # hit an undefined variable or re-save the previous method's result under
    # the wrong file name.
    unknown = [m for m in method if m not in explainers]
    if unknown:
        raise ValueError(
            'unsupported attribution method(s): {}; expected one of {}'.format(
                unknown, sorted(explainers)
            )
        )

    pth = os.path.join(ROOT, 'visualization')
    os.makedirs(pth, exist_ok=True)

    # detach() first: the attribution functions switch requires_grad on for
    # ``img`` in place, after which a bare .numpy() call would raise.
    save_images(img.detach().cpu().numpy(), os.path.join(pth, 'raw.png'))
    for m in method:
        result, _ = explainers[m](img, net, target)
        save_images(result.detach().cpu().numpy(), os.path.join(pth, m + '.png'))



# Announce the stage, then render every configured attribution method for the
# selected image(s) against the model's own top-1 prediction.
print('visualize')

visualize(net=net, img=img, method=method, target=target)
