import ADGT

import os
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from model import resnet,vgg
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utils.visualization import save_images5,save_images2,save_images4

from sanity_checks_quantitative_baseline import ssim,draw
# Pin this run to physical GPU 6 (must be set before CUDA is first initialised to take effect).
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
# Backbone loaded by prepare_model() and passed to split_model in __main__; 'vgg16' is the commented-out alternative.
MODEL = 'resnet50'#'vgg16'#
# Attribution methods to run; index j into this list is the row index of every result grid in this file.
METHOD=['InputXGradient','InputXSG','InputXGBP','DeepLIFT','IntegratedGradients']
# Limit PyTorch's intra-op CPU thread pool.
torch.set_num_threads(4)

# When True, prepare_model()/prepare_img() move the model and the image batch to the GPU.
use_cuda = True
# Dataset label ('ImageNet', the same name the ADGT explainer is built with in __main__).
DATASET_NAME = 'ImageNet'
# Output tree: result3/ and result3/Mask_Final (ROOT); both are created at import time if missing.
if not os.path.exists('result3'):
    os.mkdir('result3')
ROOT = 'result3/Mask_Final'
if not os.path.exists(ROOT):
    os.mkdir(ROOT)

class Block_Layer(nn.Module):
    """ReLU preceded by an occlusion mask over one square patch of the input.

    With ``cell = x.size(2) // 7``, the patch covers rows and columns
    ``[3 * cell, 3 * cell + cell * size)``. For ``size=1`` that is the centre
    cell of a 7x7 grid. Values inside the patch are multiplied by 0, and then
    ReLU is applied.

    The column range is also derived from ``x.size(2)``, so the patch is only
    centred for square inputs. Indexing with ``[:, :, rows, cols]`` requires
    ``x`` to have at least 4 dimensions.

    Args:
        size: side length of the patch, measured in grid cells.
    """

    def __init__(self,size=1):
        super().__init__()
        self.relu = nn.ReLU()
        self.size = size

    def forward(self,x):
        cell = x.size(2) // 7
        lo = 3 * cell
        hi = lo + cell * self.size
        # Multiply by a mask instead of writing zeros into x, so x itself is never modified.
        keep = torch.ones_like(x)
        keep[:, :, lo:hi, lo:hi] = 0
        return self.relu(x * keep)
def prepare_model(model_name=MODEL):
    """Load a pretrained classifier by name, optionally move it to the GPU, and print it.

    Args:
        model_name: one of 'inception_v3' (fetched through torch.hub),
            'resnet50' or 'vgg16' (built from the project's ``model`` package).

    Returns:
        The pretrained model, moved to CUDA when the module-level ``use_cuda``
        flag is set.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
            Previously this fell through and failed with an UnboundLocalError.
    """
    if model_name == 'inception_v3':
        model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
    elif model_name == 'resnet50':
        model = resnet.resnet50(pretrained=True)
    elif model_name == 'vgg16':
        model = vgg.vgg16(pretrained=True)
    else:
        raise ValueError(
            "unsupported model_name {!r}; expected 'inception_v3', 'resnet50' or 'vgg16'".format(model_name))
    if use_cuda:
        model = model.cuda()

    print(model)
    return model


def prepare_img(path):
    """Load every file in directory *path* as one ImageNet-normalised batch.

    Each image is converted to RGB, resized so its shorter side is 256, and
    centre-cropped to 224x224. It is then converted to a tensor and normalised
    with the ImageNet channel means/stds.

    Files are processed in sorted filename order. ``os.listdir`` order is
    arbitrary, and the batch index is used to name output files, so sorting
    keeps runs reproducible.

    Args:
        path: directory containing only image files.

    Returns:
        A float tensor of shape ``(num_files, 3, 224, 224)``, on the GPU when
        the module-level ``use_cuda`` flag is set.

    Raises:
        ValueError: if *path* contains no files.
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    img_size = 224
    # The pipeline is identical for every file, so build it once.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])

    imgs = []
    for fn in sorted(os.listdir(path)):
        img_path = os.path.join(path, fn)
        # Close the file handle; convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw:
            input_image = raw.convert('RGB')
        imgs.append(preprocess(input_image).unsqueeze(0))
    if not imgs:
        raise ValueError('no images found in {!r}'.format(path))

    preprocessed_imgs = torch.cat(imgs, 0)
    if use_cuda:
        preprocessed_imgs = preprocessed_imgs.cuda()
    return preprocessed_imgs

def draw(X,pth):
    """Render saliency tensors as per-map normalised single-channel images and save them.

    NOTE: this definition shadows the ``draw`` imported from
    ``sanity_checks_quantitative_baseline`` at the top of the file.

    Each element of *X* is presumably an ``(N, C, H, W)`` tensor (bilinear
    interpolation requires 4-D input); confirm with callers. Elements narrower
    than 224 in dim 2 are bilinearly upsampled to 224x224. Each element is then
    summed over dim 1 and divided by its own max absolute value, so values lie
    in [-1, 1]. The results are concatenated along axis 0 and handed to
    ``save_images5(..., pth)``.
    """
    maps = []
    for m in X:
        temp = m
        if temp.size(2) < 224:
            # align_corners=False is the default; passing it explicitly silences the warning.
            temp = F.interpolate(temp, (224, 224), mode='bilinear', align_corners=False)
        temp = temp.sum(1, keepdim=True).detach().cpu().numpy()
        high = np.abs(temp).max()
        # An all-zero map would otherwise become NaN (0 / 0); leave it as zeros.
        if high > 0:
            temp = temp / high
        maps.append(temp)
    save_images5(np.concatenate(maps, 0), pth)


from utils.visualization import save_images


def blocking_model2(model,target_layer=2,size=2):
    """Swap one ReLU in ``model.features`` for a ``Block_Layer(size)``, in place.

    The ReLU modules in ``model.features`` are ranked 0, 1, 2, ... in order.
    The one whose rank satisfies ``rank + target_layer == 12`` is replaced. If
    no rank matches, nothing is replaced.

    When ``features`` holds exactly 13 ReLUs (presumably VGG16; confirm),
    ``target_layer`` counts backwards from the last ReLU.

    Prints ``len(model.features)`` and returns the same ``model`` object.
    """
    layers = model.features
    print(len(layers))
    relu_slots = [pos for pos, layer in enumerate(layers) if isinstance(layer, nn.ReLU)]
    for rank, pos in enumerate(relu_slots):
        if rank + target_layer == 12:
            layers[pos] = Block_Layer(size)
    return model
def blocking_model_vgg(model,size=2):
    """Replace the first ReLU in ``model.features`` with ``Block_Layer(size)``, in place.

    Prints ``len(model.features)``. If ``features`` contains no ReLU, the model
    is left untouched. Returns the same ``model`` object.
    """
    layers = model.features
    print(len(layers))
    first_relu = next(
        (pos for pos, layer in enumerate(layers) if isinstance(layer, nn.ReLU)),
        None,
    )
    if first_relu is not None:
        layers[first_relu] = Block_Layer(size)
    return model
def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)
def blocking_model_res(model,size=2):
    """Replace ``model.relu`` with ``Block_Layer(size)`` and return the same model.

    NOTE(review): which activation(s) ``model.relu`` drives depends on the
    project's resnet implementation; confirm before relying on this.
    """
    blocker = Block_Layer(size)
    setattr(model, 'relu', blocker)
    return model
def quatification(X,save_pth,suffix=''):
    """Score how much attribution mass falls inside the blocked centre patch.

    For every entry ``X[i][j]``, where ``j`` indexes ``METHOD``:
      * the map is bilinearly resized to 224x224;
      * its absolute value is taken and divided by its max;
      * it is scored with ``energy_rate(..., 2)``.
    An all-zero map yields NaN.

    The resulting ``(len(METHOD), len(X))`` matrix is printed and written to
    ``<save_pth>/<suffix>energy_rate.csv``. It is then plotted, scaled by 100,
    to ``<save_pth>/<suffix>energy_rate.jpg``.

    NOTE(review): ``X[i]`` presumably holds one entry per explained layer;
    confirm against split_model.
    """
    # The module-level name `draw` is redefined in this file for saliency tensors
    # (it calls .size(2) on each element), so handing it a 2-D numpy matrix raises
    # TypeError. The plotting helper imported at the top from
    # sanity_checks_quantitative_baseline is the one this call was written for.
    # NOTE(review): its signature is not visible here; confirm it takes (array, path).
    from sanity_checks_quantitative_baseline import draw as plot_energy_rates

    n_entries = len(X)
    # One column per entry. The old hard-coded 13 columns raised IndexError for
    # more entries and padded the CSV/plot with fake zero scores for fewer.
    result = np.zeros([len(METHOD), n_entries])
    for i in range(n_entries):
        for j in range(len(METHOD)):
            base = F.interpolate(X[i][j], size=(224, 224),
                                 mode='bilinear', align_corners=False)
            base = base.abs()
            base = base / base.max()
            result[j, i] = energy_rate(base, 2)
    print(result)
    # Persist the numbers before plotting, so a plotting failure cannot lose them.
    np.savetxt(os.path.join(save_pth, suffix + 'energy_rate.csv'), result, delimiter=',')
    plot_energy_rates(result * 100, os.path.join(save_pth, suffix + 'energy_rate.jpg'))
def energy_rate(x,size=1):
    """Return the fraction of ``|x|`` that lies inside the centre patch.

    The patch matches the one ``Block_Layer`` zeroes. With
    ``cell = x.size(2) // 7`` it spans rows and columns
    ``[3 * cell, 3 * cell + cell * size)``. ``x`` must have at least 4
    dimensions.

    Returns:
        ``sum(|x| inside the patch) / sum(|x|)`` as a Python float. This is NaN
        when ``x`` is all zeros.
    """
    cell = x.size(2) // 7
    start = cell * 3
    stop = start + cell * size
    inside = torch.zeros_like(x)
    inside[:, :, start:stop, start:stop] = 1
    magnitude = x.abs()
    ratio = torch.sum(magnitude * inside) / torch.sum(magnitude)
    return ratio.item()
def visualize(X,save_pth,size=224):
    """Plot a grid of attribution maps and save it to *save_pth*.

    Column ``i`` shows ``X[i]`` and row ``j`` shows method ``METHOD[j]``.

    Each map is bilinearly resized to ``size`` x ``size`` and summed over
    dim 1. It is drawn with the diverging ``seismic`` colormap, with limits
    symmetric around zero, so positive and negative attributions share one
    scale per cell. Each map presumably holds a batch of one, since it is
    squeezed to 2-D before plotting; confirm with callers.
    """
    n_cols = len(X)
    n_rows = len(METHOD)
    fig = plt.figure()
    for i in range(n_cols):
        for j in range(n_rows):
            temp = F.interpolate(X[i][j], size=(size, size),
                                 mode='bilinear', align_corners=False)
            # detach() so maps that still require grad can be converted to numpy.
            temp = temp.sum(1).squeeze().detach().cpu().numpy()
            # Create each cell exactly once. Calling both plt.subplot and
            # fig.add_subplot for the same slot can stack two axes on recent
            # matplotlib, leaving a tick-labelled axes underneath.
            ax = fig.add_subplot(n_rows, n_cols, j * n_cols + i + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            high = np.abs(temp).max()
            image = ax.imshow(temp, cmap="seismic")
            image.set_clim(-high, high)

    # Size the figure to the actual grid (the height was hard-coded to 10 rows).
    fig.set_size_inches(size * n_cols / 100, size * n_rows / 100)
    fig.subplots_adjust(top=0.99, bottom=0.01, right=0.99, left=0.01, hspace=0.05, wspace=0.05)
    fig.savefig(save_pth)
    plt.close(fig)


import copy


import torch.nn as nn
import functools




if __name__=='__main__':
    # Build the pretrained backbone and switch it to inference mode. Honour the
    # module-level use_cuda flag instead of calling .cuda() unconditionally.
    net = prepare_model()
    net.eval()
    if use_cuda:
        net.cuda()
    # Replace one activation with Block_Layer so a centre patch is zeroed.
    if MODEL=='vgg16':
        blocking_model_vgg(net)
    else:
        blocking_model_res(net)
    img = prepare_img(os.path.join('sanity_checks_saliency', 'data', 'images'))
    if use_cuda:
        img = img.cuda()

    import split_model

    # NOTE(review): `adgt` is not used below; it is kept in case constructing
    # ADGT has required side effects. Confirm, then drop it.
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME)
    pth = ROOT
    os.makedirs(pth, exist_ok=True)
    model_As, model_Bs, feature_size = split_model.prepare_model(net,mode=MODEL)

    # Explain each image on its own (batch of one) and save the maps plus the raw input.
    for i in range(img.size(0)):
        base_img = img[i].unsqueeze(0)
        X = split_model.explain_middle_layer_and_raw(base_img, net, model_As, model_Bs, feature_size, METHOD)
        split_model.visualize_all(X, pth, prefix=str(i), methods=METHOD)
        save_images2(base_img.detach().cpu().numpy(), os.path.join(pth, str(i) + 'raw.png'))