import ADGT
import torch
import numpy as np
import cv2
import os
from PIL import Image
from torchvision import transforms
from model.inception import inception_v3
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from model import resnet,vgg
import matplotlib.pyplot as plt
from utils.visualization import save_images5,save_images2,save_images4

# Pin the job to a single GPU. This must be set before CUDA is first
# initialized in this process to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
MODEL = 'resnet50'  # also supported by prepare_model: 'vgg16', 'inception_v3'
ROOT = 'result3'    # root output directory for all visualizations
# EAFP-style creation: no exists()/mkdir() race, no-op if already present.
os.makedirs(ROOT, exist_ok=True)
# Attribution methods, passed by name to ADGT's pure_explain; one panel each.
method = [
    'InputXGradient',
    'SmoothGrad',
    'Guided_BackProp',
    'DeepLIFT',
    'IntegratedGradients',
    'GradCAM',
    'ScoreCAM',
    'CAMERAS',
    'FullGrad',
    'Guided_GradCAM',
]

# Which of the top-k predicted labels to explain (index into torch.topk
# output, which is sorted descending, so 0 is the top-1 prediction).
TARGET = 4
print(MODEL, TARGET)
torch.set_num_threads(4)

use_cuda = True  # move model and inputs to the GPU


def prepare_model(model_name=MODEL):
    """Build a pretrained classifier by name, moving it to the GPU if enabled.

    Args:
        model_name: One of ``'inception_v3'``, ``'resnet50'`` or ``'vgg16'``.
            Defaults to the module-level ``MODEL`` setting.

    Returns:
        The constructed model (moved to CUDA when ``use_cuda`` is true).
        The architecture is printed as a side effect.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'inception_v3':
        # Loaded from the torchvision hub rather than the local model.inception.
        model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
    elif model_name == 'resnet50':
        model = resnet.resnet50(pretrained=True)
    elif model_name == 'vgg16':
        model = vgg.vgg16(pretrained=True)
    else:
        # Without this branch an unknown name left ``model`` unbound and the
        # function crashed with an opaque UnboundLocalError below.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'inception_v3', 'resnet50' or 'vgg16'"
        )
    if use_cuda:
        model = model.cuda()

    print(model)
    return model


def prepare_img(path):
    """Load every file in a directory as one ImageNet-normalized batch.

    Each file is opened with PIL and converted to RGB. It is then resized so
    its shorter side is 256, center-cropped to 224x224, converted to a tensor
    and normalized with the ImageNet channel means/stds.

    Args:
        path: Directory to read. Every entry is assumed to be an image file
            PIL can open (subdirectories or other files will raise).

    Returns:
        A tensor of shape ``(num_images, 3, 224, 224)``, on CUDA when
        ``use_cuda`` is true. Rows follow sorted filename order, so batch
        index ``i`` maps to the same file across runs and filesystems.

    Raises:
        ValueError: If ``path`` contains no entries.
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    img_size = 224

    # The transform is identical for every file, so build it once.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])

    # os.listdir order is arbitrary; sort so index -> file is reproducible.
    filenames = sorted(os.listdir(path))
    if not filenames:
        raise ValueError(f"No images found in {path!r}")

    imgs = []
    for fn in filenames:
        # Context manager closes the underlying file handle; convert() returns
        # an independent copy, so it stays valid after the file is closed.
        with Image.open(os.path.join(path, fn)) as input_image:
            rgb = input_image.convert('RGB')
        imgs.append(preprocess(rgb).unsqueeze(0))

    preprocessed_imgs = torch.cat(imgs, 0)
    if use_cuda:
        preprocessed_imgs = preprocessed_imgs.cuda()
    return preprocessed_imgs

def draw(X, pth, size=224, save_fn=None):
    """Normalize a list of attribution maps and save them as one image set.

    Each map is expected to be a 4-D tensor ``(batch, channels, H, W)``, as
    required by ``F.interpolate`` with a 2-D output size. For each map:

    * it is bilinearly upsampled to ``(size, size)`` if its height is smaller;
    * channels are summed into a single channel;
    * it is scaled by its maximum absolute value into ``[-1, 1]``.

    Args:
        X: Iterable of attribution tensors.
        pth: Output path forwarded to the save function.
        size: Target spatial size for maps smaller than it (default 224).
        save_fn: Callable ``save_fn(array, pth)``; defaults to ``save_images5``.
    """
    if save_fn is None:
        save_fn = save_images5
    maps = []
    for m in X:
        temp = m
        if temp.size(2) < size:
            # align_corners=False is PyTorch's default; passed to avoid the warning.
            temp = F.interpolate(temp, (size, size), mode='bilinear', align_corners=False)
        temp = temp.sum(1, keepdim=True).detach().cpu().numpy()
        high = np.abs(temp).max()
        # An all-zero map would otherwise become 0/0 = NaN in the saved image.
        if high > 0:
            temp = temp / high
        maps.append(temp)
    save_fn(np.concatenate(maps, 0), pth)


print('visualize')
max_K = 5
# TARGET indexes into the top-max_K labels; fail fast before heavy work.
if not 0 <= TARGET < max_K:
    raise ValueError(f"TARGET={TARGET} must index into the top-{max_K} predictions")

net = prepare_model()
net.eval()
if use_cuda:
    net = net.cuda()
img = prepare_img(os.path.join('sanity_checks_saliency', 'data', 'images'))

# Only the label ranking is needed here, so skip building an autograd graph
# for the whole batch; the attribution calls below run with gradients on.
with torch.no_grad():
    pred = net(img)
_, topklabel = torch.topk(pred, max_K)
print(img.size(), topklabel)

adgt = ADGT.ADGT(use_cuda=use_cuda, name='ImageNet')
pth = os.path.join(ROOT, 'Visual_Inspection_' + MODEL + '_' + str(TARGET))
os.makedirs(pth, exist_ok=True)
# Labels are integer class indices; the default fmt ('%.18e') writes floats.
np.savetxt(os.path.join(pth, 'label.csv'), topklabel.cpu().numpy(), fmt='%d', delimiter=',')
for i in range(img.size(0)):
    base_img = img[i].unsqueeze(0)
    results = []
    for m in method:
        print(i, m)
        grad = adgt.pure_explain(base_img, net, m, target=topklabel[i, TARGET])
        results.append(grad)
    draw(results, os.path.join(pth, str(i) + 'baseline.png'))
    save_images2(base_img.detach().cpu().numpy(), os.path.join(pth, str(i) + 'raw.png'))