import ADGT
import torch
import numpy as np
import cv2
import os
from PIL import Image
from torchvision import transforms
from model.inception import inception_v3
from model.resnet import resnet50
from model.vgg import vgg16
import torchvision
from utils.visualization import save_images
import torch.nn as nn
# Expose only GPU 1 to this process (CUDA reads this variable when it initialises).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
MODEL = 'vgg16'  # other supported choices: 'resnet50', 'inception_v3'
# Attribution methods evaluated at every randomisation stage (names dispatched by get_mask).
method = [
    'Our_GBP',
    'our_method',
    'our_fullgrad',
    'GradCAM',
    'InputXGradient',
    'Saliency',
    'DeepLIFT',
    'RectGrad',
    'IntegratedGradients',
]
torch.set_num_threads(4)
# Sub-modules re-initialised cumulatively, in this order, by the driver loop.
# Models without an entry here (e.g. inception_v3) get an empty list.
layer_names = {
    'vgg16': ['classifier'] + ['features.' + str(idx)
                               for idx in (28, 26, 24, 21, 19, 17, 14, 12, 10, 7, 5, 2, 0)],
    'resnet50': ['fc', 'layer4', 'layer3', 'layer2', 'layer1', 'conv1'],
}.get(MODEL, [])
def perturbation(model, layername):
    """Re-initialise, in place, the Conv/Linear weights inside one named sub-module.

    ``layername`` is a dotted path of sub-module names (e.g. ``'features.28'``,
    ``'layer4'``, or deeper paths such as ``'layer4.0.conv1'``). Every module
    visited by ``apply`` on that sub-module (the sub-module itself and all its
    descendants) is handled as follows:

    * class name contains 'Conv': weight gets Xavier-normal init;
    * class name contains 'Linear': weight gets Xavier-normal init, bias
      (if present) is zeroed;
    * modules that do not own a tensor ``weight`` (e.g. a wrapper class such
      as ``BasicConv2d`` around a real ``Conv2d``) are skipped; the wrapped
      layer is still reached by ``apply``.

    NOTE(review): convolution biases are left untouched while Linear biases
    are zeroed, which matches the previous behaviour. Confirm this asymmetry
    is intended.

    Args:
        model: the ``nn.Module`` to modify.
        layername: dotted path to the sub-module to randomise.

    Returns:
        ``model`` itself (modified in place).
    """
    def weights_init(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        if not isinstance(weight, torch.Tensor):
            # Wrapper or container module: nothing of its own to re-initialise.
            return
        if 'Conv' in classname:
            nn.init.xavier_normal_(weight.data)
        elif 'Linear' in classname:
            nn.init.xavier_normal_(weight.data)
            bias = getattr(m, 'bias', None)
            if bias is not None:  # Linear(..., bias=False) registers bias as None
                bias.data.fill_(0)

    print(layername)
    # Walk the full dotted path; the old code honoured only the first two parts.
    target = model
    for part in layername.split('.'):
        target = getattr(target, part)
    target.apply(weights_init)
    return model
# Maps the public method name accepted by get_mask to the name importable
# from the ``attribution_methods`` package.
_ATTRIBUTION_MODULES = {
    'GradientSHAP': 'GradientSHAP',
    'DeepLIFTSHAP': 'DeepLIFTSHAP',
    'Guided_BackProp': 'Guided_BackProp',
    'DeepLIFT': 'DeepLIFT',
    'IntegratedGradients': 'IntegratedGradients',
    'InputXGradient': 'InputXGradient',
    'Occlusion': 'Occlusion',
    'Saliency': 'Saliency',
    'GradCAM': 'Grad_CAM',
    'SmoothGrad': 'SmoothGrad',
    'RectGrad': 'RectGrad',
    'AIR': 'AttrInvRec',
    'FullGrad': 'Full_Grad',
    'FGour': 'FGour',
    'Our_GBP': 'Our_GBP',
    'our_method': 'our_method',
    'our_fullgrad': 'our_fullgrad',
}


def get_mask(img, model, method, train_loader=None, root='result', topk=3):
    """Compute, save and return attribution maps of ``method`` for the top-k classes.

    The directory ``<root>/<method>`` is created, including any missing
    parents. For each of the ``topk`` highest-scoring classes of ``model(img)``:

    * the explainer's attribution map is computed;
    * the map is reduced to the mean absolute value over dim 1;
    * it is written with ``save_images`` to ``<root>/<method>/<i>.jpg``.

    The maps are then concatenated along axis 0 and centred by subtracting
    their mean over axis 0. The result is saved as
    ``<root>/<method>/removemean.jpg``.

    Args:
        img: input batch fed to ``model``. Presumably already on the model's
            device; TODO confirm for new callers.
        model: network to explain.
        method: a key of ``_ATTRIBUTION_MODULES``. For an unknown name,
            'no this method' is printed and nothing else happens.
        train_loader: optional loader. When given, the explainer is built as
            ``Explainer(model, train_loader)`` instead of
            ``Explainer(model, nclass=1000)``. Previously this argument was
            silently ignored.
        root: output directory.
        topk: number of top predicted classes to explain (default 3).

    Returns:
        List of numpy attribution maps, one per explained class, or ``None``
        when ``method`` is unknown.
    """
    import importlib

    color = False  # True would keep the explainer's per-channel map unreduced
    out_dir = os.path.join(root, method)
    # makedirs creates missing parents (e.g. result/sanitycheck/<MODEL>),
    # where os.mkdir raised FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)

    module_name = _ATTRIBUTION_MODULES.get(method)
    if module_name is None:
        print('no this method')
        return None

    # Equivalent of ``from attribution_methods import <module_name>``: take the
    # attribute if the package exposes it, else import it as a submodule.
    package = importlib.import_module('attribution_methods')
    try:
        alg = getattr(package, module_name)
    except AttributeError:
        alg = importlib.import_module('attribution_methods.' + module_name)

    if train_loader is None:
        explainer = alg.Explainer(model, nclass=1000)
    else:
        explainer = alg.Explainer(model, train_loader)

    pred = model(img)
    _, topklabel = torch.topk(pred, topk)
    masks = []
    for i in range(topklabel.size(1)):
        label = topklabel[:, i]
        mask = explainer.get_attribution_map(img.clone(), label)
        if not color:
            mask = torch.mean(torch.abs(mask), 1, keepdim=True)
        mask = mask.detach().cpu().numpy()
        save_images(mask, os.path.join(out_dir, str(i) + '.jpg'))
        masks.append(mask)

    # Subtract the cross-class mean to highlight what differs between classes.
    stacked = np.concatenate(masks, 0)
    centred = stacked - np.mean(stacked, 0, keepdims=True)
    save_images(centred, os.path.join(out_dir, 'removemean.jpg'))
    return masks

def prepare_img(img_path, img_size=224, device='cuda'):
    """Load an image file as a normalised single-image batch.

    Processing steps:

    * convert the image to RGB;
    * resize it (``transforms.Resize(img_size)``);
    * centre-crop it to ``img_size`` x ``img_size``;
    * convert it to a tensor;
    * normalise it with the per-channel mean/std below.

    Args:
        img_path: path of the image file.
        img_size: side length of the square crop (default 224, as before).
        device: device the batch is moved to (default 'cuda', as before).

    Returns:
        Tensor of shape (1, 3, img_size, img_size) on ``device``.
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])
    # Image.open is lazy and keeps the file open; the context manager closes it.
    with Image.open(img_path) as raw:
        input_image = raw.convert('RGB')
    return preprocess(input_image).unsqueeze(0).to(device)
def prepare_model(model_name=MODEL, device='cuda'):
    """Build a pretrained classifier by name, move it to ``device`` and print it.

    Args:
        model_name: 'inception_v3' (loaded through ``torch.hub``), 'resnet50' or
            'vgg16' (both from the local ``model`` package). Defaults to the
            module-level MODEL.
        device: device the model is moved to (default 'cuda', as before).

    Returns:
        The model.

    Raises:
        ValueError: for an unsupported ``model_name``. Previously this case
            crashed with an UnboundLocalError on ``model``.
    """
    if model_name == 'inception_v3':
        # torch.hub may need network access to fetch the repo/weights the first time.
        model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
    elif model_name == 'resnet50':
        model = resnet50(pretrained=True)
    elif model_name == 'vgg16':
        model = vgg16(pretrained=True)
    else:
        raise ValueError(
            "unsupported model_name %r; expected 'inception_v3', 'resnet50' or 'vgg16'"
            % (model_name,))
    model = model.to(device)

    print(model)
    return model
# Sanity check (cascading parameter randomisation). First explain the image
# with the pretrained network ('raw'). Then re-initialise the layers in
# ``layer_names`` order, one more per stage, and explain again after each stage.
img = prepare_img(os.path.join('sanity_checks_saliency', 'data', 'demo_images',
                               'ILSVRC2012_val_00015410.JPEG'))
net = prepare_model()
net.eval()
import copy
temp_net = copy.deepcopy(net)  # copy that gets randomised; ``net`` keeps its weights
# makedirs creates every missing level. The old
# ``if not os.path.exists('result')`` guard skipped result/sanitycheck/<MODEL>
# whenever 'result' already existed, so later os.mkdir calls failed.
os.makedirs(os.path.join('result', 'sanitycheck', MODEL), exist_ok=True)

pth = os.path.join('result', 'sanitycheck', MODEL, 'raw')
for m in method:
    get_mask(img, temp_net, m, root=pth)
for name in layer_names:
    # perturbation re-initialises temp_net in place, so randomisation accumulates.
    temp_net = perturbation(temp_net, name)
    pth = os.path.join('result', 'sanitycheck', MODEL, name)
    for m in method:
        get_mask(img, temp_net, m, root=pth)