import ADGT

import os
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from model import resnet,vgg
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utils.visualization import save_images,save_images2
from attribution_methods.TRGBP import Block_Layer

# Restrict PyTorch to GPU 0; this takes effect as long as CUDA has not been
# initialized yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
MODEL = 'vgg16'  # 'resnet50'#'canet11'#'inception_v3'#
method = 'TRGBP'  # attribution method name handed to ADGT.pure_explain

torch.set_num_threads(4)
# Alternative layer list kept for an inception_v3 run:
# layer_names=['fc','Mixed_7c','Mixed_7b','Mixed_7a','Mixed_6e','Mixed_6d','Mixed_6c','Mixed_6b','Mixed_6a',
#           'Mixed_5d','Mixed_5c','Mixed_5b','Conv2d_4a_3x3','Conv2d_3b_1x1','Conv2d_2b_3x3','Conv2d_2a_3x3',
#           'Conv2d_1a_3x3']
# Module paths re-initialized one after another (in this order) by the
# cascade, starting at the classifier and ending at 'features.0'.
# NOTE(review): the features.* indices are presumably VGG16's conv layers --
# confirm against model/vgg.py.
layer_names = ['classifier', 'features.28', 'features.26', 'features.24', 'features.21', 'features.19',
               'features.17', 'features.14', 'features.12', 'features.10', 'features.7', 'features.5',
               'features.2', 'features.0']
# Alternative layer list kept for a resnet run:
# layer_names = ['fc', 'layer4', 'layer3', 'layer2', 'layer1', 'conv1']
use_cuda = True
DATASET_NAME = 'ImageNet'
# Output directory for raw.jpg and M.jpg. makedirs also creates the parent
# 'result2' and, with exist_ok=True, does not race with or fail on an
# already existing directory.
ROOT = 'result2/SanityChecks'
os.makedirs(ROOT, exist_ok=True)


def prepare_model(model_name=MODEL):
    """Build the pretrained network named *model_name*.

    Only ``'vgg16'`` is implemented. The model is printed, and it is moved
    to the GPU when the module-level ``use_cuda`` flag is set (the same flag
    ``prepare_img`` honours).

    Args:
        model_name: architecture name; defaults to the module-level ``MODEL``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if *model_name* is not a supported architecture.
    """
    if model_name != 'vgg16':
        # Previously this fell through to `return model` and surfaced as an
        # UnboundLocalError instead of a meaningful message.
        raise ValueError(
            "Unsupported model %r; only 'vgg16' is implemented" % (model_name,))
    model = vgg.vgg16(pretrained=True)
    if use_cuda:
        model = model.cuda()
    print(model)
    return model


def prepare_img(path):
    """Load the image at *path* as a normalized single-image batch.

    The image is converted to RGB, resized (``Resize(256)``), center-cropped
    to 224x224, turned into a tensor and normalized with the standard ImageNet
    channel mean/std, then given a leading batch dimension (1x3x224x224).
    The normalized batch is also written to ``ROOT/raw.jpg`` via
    ``save_images2``.

    Args:
        path: filesystem path of the image file.

    Returns:
        The preprocessed batch, moved to the GPU when ``use_cuda`` is set.
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    img_size = 224
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])
    # Close the file handle Image.open keeps; convert() returns a new, fully
    # loaded image, so it stays valid after the `with` block.
    with Image.open(path) as raw_image:
        input_image = raw_image.convert('RGB')
    img = preprocess(input_image).unsqueeze(0)
    save_images2(img.cpu().numpy(), os.path.join(ROOT, 'raw.jpg'))
    if use_cuda:
        img = img.cuda()
    return img


from utils.visualization import save_images


def model_independent_cascade(net, img, method, layer_names):
    """Explain *img* while cumulatively re-initializing layers of *net*.

    First explains the unmodified *net*. Then, for each entry of
    *layer_names* in order, re-initializes that layer with
    :func:`perturbation` (on top of all earlier re-initializations) and
    explains again. All ``len(layer_names) + 1`` explanations are drawn side
    by side (one column each, unmodified net first) into ``ROOT/M.jpg``.

    Note: :func:`perturbation` modifies *net* in place; pass a copy if the
    original weights are needed afterwards.

    Args:
        net: model to explain and progressively randomize.
        img: preprocessed input batch.
        method: attribution method name forwarded to ``adgt.pure_explain``.
        layer_names: dotted module paths to re-initialize, in cascade order.
    """
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME)

    # First column: explanation of the network before any randomization.
    M_result = [adgt.pure_explain(img, net, method)]
    for layer_name in layer_names:
        net = perturbation(net, layer_name)
        M_result.append(adgt.pure_explain(img, net, method))

    visualize(M_result, os.path.join(ROOT, 'M.jpg'))



def visualize(X, save_pth, size=224, num_rows=14):
    """Draw a grid of attribution maps and save it to *save_pth*.

    Column ``i`` shows ``X[i]`` and row ``j`` shows ``X[i][j]``. Each map is
    bilinearly resized to ``size`` x ``size``, summed over dim 1 and drawn
    with the ``seismic`` colormap on a colour scale symmetric around zero.

    Args:
        X: non-empty sequence of results; each ``X[i]`` must be indexable by
            ``0 .. num_rows - 1``. Every ``X[i][j]`` must be a tensor that
            ``F.interpolate`` accepts in bilinear mode, i.e. 4-D (presumably
            (batch, channels, H, W) -- TODO confirm).
        save_pth: output image path.
        size: side length each map is resized to; the figure is sized as
            ``size / 100`` inches per cell.
        num_rows: how many entries of each ``X[i]`` to plot (formerly a
            hard-coded 14).

    Raises:
        ValueError: if *X* is empty.
    """
    n_cols = len(X)
    if n_cols == 0:
        raise ValueError('visualize() needs at least one result to plot')
    fig = plt.figure(figsize=(size * n_cols / 100, size * num_rows / 100))
    for i in range(n_cols):
        for j in range(num_rows):
            # align_corners=False is the bilinear default; passing it
            # explicitly silences PyTorch's UserWarning.
            resized = F.interpolate(X[i][j], size=(size, size), mode='bilinear',
                                    align_corners=False)
            heat = resized.sum(1).squeeze().detach().cpu().numpy()
            # Create each cell's Axes exactly once: the former
            # plt.subplot + fig.add_subplot pair stacked two Axes per cell on
            # Matplotlib >= 3.6, where add_subplot always creates a new one.
            ax = fig.add_subplot(num_rows, n_cols, j * n_cols + i + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            high = np.abs(heat).max()
            ax.imshow(heat, cmap="seismic", vmin=-high, vmax=high)

    fig.subplots_adjust(top=0.99, bottom=0.01, right=0.99, left=0.01, hspace=0.05, wspace=0.05)
    fig.savefig(save_pth)
    plt.close(fig)

import copy


import torch.nn as nn
import functools


def perturbation(model, layername):
    """Re-initialize the weights of one (possibly nested) submodule in place.

    *layername* is a dotted module path such as ``'classifier'`` or
    ``'features.28'``; any depth is accepted (e.g. ``'layer4.0.conv1'``).
    Inside that submodule, every module whose class name contains ``'Conv'``
    or ``'Linear'`` and that owns a ``weight`` gets Xavier-normal weights.
    Biases of the ``'Linear'`` modules are zeroed when present; conv biases
    are left untouched.

    Args:
        model: the network to modify.
        layername: dotted path of the submodule to re-initialize.

    Returns:
        *model* itself, after modification.
    """
    def weights_init(m):
        weight = getattr(m, 'weight', None)
        # Container blocks (e.g. a 'BasicConv2d' wrapping a conv + BN) match
        # the name test but own no weight; skip them rather than crash.
        if weight is None:
            return
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.xavier_normal_(weight.data)
        elif 'Linear' in classname:
            nn.init.xavier_normal_(weight.data)
            bias = getattr(m, 'bias', None)
            if bias is not None:  # nn.Linear(..., bias=False) has bias None
                bias.data.fill_(0)

    print(layername)
    # getattr resolves named children as well as numeric nn.Sequential
    # indices such as '28', at any nesting depth.
    target = functools.reduce(getattr, layername.split('.'), model)
    target.apply(weights_init)
    return model


if __name__ == '__main__':
    # Pretrained network in inference mode.
    network = prepare_model()
    network.eval()

    demo_path = os.path.join('sanity_checks_saliency', 'data', 'demo_images',
                             'ILSVRC2012_val_00015410.JPEG')
    image = prepare_img(demo_path)

    # NOTE(review): this forward pass's output is never used -- confirm whether
    # it is still needed before removing it.
    prediction = network(image)

    # Randomize a deep copy so `network` keeps its pretrained weights.
    model_independent_cascade(copy.deepcopy(network), image, method, layer_names)
