import ADGT
import torch
import numpy as np
import cv2
import os
from PIL import Image
from torchvision import transforms
from model.inception import inception_v3
from attribution_methods.our.model.resnet import resnet50
from model.vgg import vgg16_bn
from attribution_methods.our.model.canet import canet11_imagenet,canet11_imagenet_no_bn
import torchvision
# Expose only device index 4 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# Default architecture name for prepare_model().
# Alternatives noted by the author: 'resnet50', 'canet11', 'inception_v3'.
MODEL = 'vgg16'

# Attribution method names handed to ADGT, one run per entry.
# Entries previously listed here as well: 'DeepLIFT', 'our_fullgrad'.
method = ['Our_GBP', 'FullGrad', 'GradCAM', 'Saliency',
          'our_method_minus_mean', 'GIG', 'CAMERAS']
# Alternative method set:
# method = ['IntegratedGradients', 'InputXGradient', 'Saliency', 'RectGrad', 'SmoothGrad']

torch.set_num_threads(4)

# Sub-module names re-initialised (in this order) by the model_independent_*
# experiments. Alternative lists kept for reference:
# layer_names = ['fc', 'Mixed_7c', 'Mixed_7b', 'Mixed_7a', 'Mixed_6e', 'Mixed_6d',
#                'Mixed_6c', 'Mixed_6b', 'Mixed_6a', 'Mixed_5d', 'Mixed_5c',
#                'Mixed_5b', 'Conv2d_4a_3x3', 'Conv2d_3b_1x1', 'Conv2d_2b_3x3',
#                'Conv2d_2a_3x3', 'Conv2d_1a_3x3']
# layer_names = ['classifier', 'features.28', 'features.14', 'features.0']
# NOTE(review): the active list looks like ResNet attribute names while MODEL
# is 'vgg16' -- confirm it matches the chosen model before enabling the
# model_independent_* experiments.
layer_names = ['fc', 'layer4', 'layer3', 'layer2', 'layer1', 'conv1']

use_cuda = True
DATASET_NAME = 'ImageNet'
ROOT = 'result'
AUG = True

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(ROOT, exist_ok=True)

def prepare_model(model_name=MODEL):
    """Construct the network selected by ``model_name`` and print it.

    Args:
        model_name: ``'inception_v3'``, ``'resnet50'`` or ``'vgg16'`` select
            the corresponding constructor with ``pretrained=True``; any other
            value falls through to ``canet11_imagenet(cane=True)``.

    Returns:
        The constructed model, moved to the GPU when the module-level
        ``use_cuda`` flag is true.
    """
    def _build():
        if model_name == 'inception_v3':
            # Loaded through torch.hub rather than the local
            # inception_v3(pretrained=True, transform_input=True) constructor.
            return torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
        if model_name == 'resnet50':
            return resnet50(pretrained=True)
        if model_name == 'vgg16':
            return vgg16_bn(pretrained=True)
        return canet11_imagenet(cane=True)

    network = _build()
    if use_cuda:
        network = network.cuda()

    # Dump the architecture so the layer names are visible in the log.
    print(network)
    return network

def prepare_img(path):
    """Load every entry of directory ``path`` as an image and batch them.

    Each file is opened with PIL, converted to RGB, resized (``Resize(256)``),
    centre-cropped to 224x224, converted to a tensor and normalised with the
    per-channel means/stds defined below.

    Args:
        path: Directory whose entries are all image files PIL can open.

    Returns:
        The per-image tensors concatenated along dim 0, in sorted filename
        order, moved to the GPU when the module-level ``use_cuda`` is true.
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    img_size = 224

    # The pipeline is identical for every file, so build it once.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])

    imgs = []
    # os.listdir returns entries in arbitrary order; sort so the batch order
    # (and any positional slicing done by the caller) is reproducible.
    for fn in sorted(os.listdir(path)):
        img_path = os.path.join(path, fn)
        # convert() returns a new in-memory image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(img_path) as opened:
            input_image = opened.convert('RGB')
        imgs.append(preprocess(input_image).unsqueeze(0))

    preprocessed_imgs = torch.cat(imgs, 0)
    if use_cuda:
        preprocessed_imgs = preprocessed_imgs.cuda()
    return preprocessed_imgs

# Number of images kept from the demo folder and number of top classes kept.
max_K=5
net=prepare_model()
net.eval()
img=prepare_img(os.path.join('sanity_checks_saliency','data','demo_images'))
# Keep at most the last max_K images. A negative-start slice clamps to the
# whole batch when fewer than max_K images exist, whereas the previous
# `img.size(0)-max_K` start went negative in that case and dropped images.
img=img[-max_K:]
print(img.size())
pred=net(img)
# Indices of the max_K highest-scoring classes per image, and of the top one.
_,topklabel=torch.topk(pred,max_K)
_,target=torch.topk(pred,1)


def class_independent(net,img,target,topklabel,method):
    """Call ``ADGT.explain_top5`` on ``img`` once per attribution method.

    Args:
        net: Model handed to ``explain_top5``.
        img: Input batch handed to ``explain_top5``.
        target: Unused by the active code path; kept for call compatibility.
        topklabel: Unused by the active code path; kept for call compatibility.
        method: Iterable of attribution method names.

    ``<ROOT>/class_independent/split`` is created if missing and passed as
    ``file_name`` (presumably the output location -- confirm against ADGT).
    """
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME, aug=AUG)
    pth_split = os.path.join(ROOT, 'class_independent', 'split')
    # Creates missing parents too and tolerates an existing directory.
    os.makedirs(pth_split, exist_ok=True)
    # Alternative entry points previously used here (they take target and
    # topklabel): adgt.explain_all and adgt.explain_split_fft.
    for m in method:
        adgt.explain_top5(img, net, m, file_name=pth_split)
from utils.visualization import save_images
def visualize(net,img,method):
    """Save the input batch as an image and run every attribution method.

    Args:
        net: Model handed to ``ADGT.pure_explain``.
        img: Input batch; a CPU/NumPy copy is written to ``raw.png`` through
            ``save_images`` before the methods run.
        method: Iterable of attribution method names.

    ``<ROOT>/class_independent/split`` is created if missing and passed as
    ``file_name`` (presumably the output location -- confirm against ADGT).
    """
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME, aug=AUG)
    pth_split = os.path.join(ROOT, 'class_independent', 'split')
    # Creates missing parents too and tolerates an existing directory.
    os.makedirs(pth_split, exist_ok=True)
    save_images(img.cpu().numpy(), os.path.join(pth_split, 'raw.png'))
    for m in method:
        adgt.pure_explain(img, net, m, file_name=pth_split)


def model_independent_cascade(net,img,target,topklabel,method,layer_names):
    """Re-initialise layers cumulatively and explain after each step.

    For each name in ``layer_names`` (in order) the named sub-module of
    ``net`` is re-initialised via ``perturbation`` and every attribution
    method is run with that layer name as ``suffix``. ``perturbation``
    returns the model it was given, so the changes accumulate on ``net``
    itself; pass a copy if the original weights are still needed.

    Args:
        net: Model to perturb and explain (modified in place).
        img: Input batch handed to ``ADGT.pure_explain``.
        target: Unused by the active code path; kept for call compatibility.
        topklabel: Unused by the active code path; kept for call compatibility.
        method: Iterable of attribution method names.
        layer_names: Sub-module names understood by ``perturbation``.

    ``<ROOT>/model_independent_cascade/split`` is created if missing and
    passed as ``file_name`` (presumably the output location -- confirm
    against ADGT).
    """
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME, aug=AUG)
    pth_split = os.path.join(ROOT, 'model_independent_cascade', 'split')
    # Creates missing parents too and tolerates an existing directory.
    os.makedirs(pth_split, exist_ok=True)
    # Alternative entry points previously used here (they take target and
    # topklabel): adgt.explain_all and adgt.explain_split.
    for layer_name in layer_names:
        net = perturbation(net, layer_name)
        for m in method:
            adgt.pure_explain(img, net, m, file_name=pth_split, suffix=layer_name)
import copy
def model_independent_individual(net,img,target,topklabel,method,layer_names):
    """Re-initialise one layer at a time on a fresh copy and explain it.

    For each name in ``layer_names`` a deep copy of ``net`` is made, only the
    named sub-module of the copy is re-initialised via ``perturbation``, and
    every attribution method is run with that layer name as ``suffix``.
    ``net`` itself is never passed to ``perturbation``.

    Args:
        net: Reference model; deep-copied once per layer.
        img: Input batch handed to ``ADGT.pure_explain``.
        target: Unused by the active code path; kept for call compatibility.
        topklabel: Unused by the active code path; kept for call compatibility.
        method: Iterable of attribution method names.
        layer_names: Sub-module names understood by ``perturbation``.

    ``<ROOT>/model_independent_individual/split`` is created if missing and
    passed as ``file_name`` (presumably the output location -- confirm
    against ADGT).
    """
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME, aug=AUG)
    pth_split = os.path.join(ROOT, 'model_independent_individual', 'split')
    # Creates missing parents too and tolerates an existing directory.
    os.makedirs(pth_split, exist_ok=True)
    # Alternative entry points previously used here (they take target and
    # topklabel): adgt.explain_all and adgt.explain_split_fft.
    for layer_name in layer_names:
        temp_net = perturbation(copy.deepcopy(net), layer_name)
        for m in method:
            adgt.pure_explain(img, temp_net, m, file_name=pth_split, suffix=layer_name)
import torch.nn as nn
import functools
def perturbation(model,layername):
    """Re-initialise, in place, the weights of one sub-module of ``model``.

    Args:
        model: Module to modify.
        layername: Attribute path of the target sub-module, with components
            separated by dots (e.g. ``'fc'`` or ``'features.28'``). Any number
            of components is followed; each one is printed as it is resolved.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AttributeError: If a path component does not exist.

    Within the target (and all of its descendants, via ``Module.apply``):
    modules whose class name contains ``'Conv'`` get Xavier-normal weights;
    otherwise modules whose class name contains ``'Linear'`` get Xavier-normal
    weights and a zeroed bias.
    """
    def weights_init(m):
        classname = m.__class__.__name__
        is_conv = classname.find('Conv') != -1
        is_linear = classname.find('Linear') != -1
        if not (is_conv or is_linear):
            return
        # Container modules can match by name (e.g. a conv+bn wrapper class)
        # without owning a weight themselves; their children are visited by
        # apply() on their own, so just skip the container.
        weight = getattr(m, 'weight', None)
        if weight is None:
            return
        nn.init.xavier_normal_(weight.data)
        if is_linear and not is_conv:
            # Layers built with bias=False have bias set to None.
            bias = getattr(m, 'bias', None)
            if bias is not None:
                bias.data.fill_(0)

    # Walk the whole dotted path instead of only its first two components.
    target = model
    for part in layername.split('.'):
        print(part)
        target = getattr(target, part)
    target.apply(weights_init)
    return model

# Driver: announce and run the experiment that is currently enabled.
print('visualize')

visualize(net=net, img=img, method=method)

# Other experiments, disabled for now. The model_independent_* variants are
# given a deep copy of the model.
# class_independent(net, img, target, topklabel, method)
# print('model_independent_cascade')
# model_independent_cascade(copy.deepcopy(net), img, target, topklabel, method, layer_names)
# print('model_independent_individual')
# model_independent_individual(copy.deepcopy(net), img, target, topklabel, method, layer_names)