import ADGT
import torch
import numpy as np
import cv2
import os
from PIL import Image
from torchvision import transforms
from model.inception import inception_v3
from attribution_methods.our.model.resnet import resnet50
from model.vgg import vgg16_bn
from attribution_methods.our.model.canet import canet11_imagenet, canet11_imagenet_no_bn
import torchvision
import torch.nn as nn

# Expose only physical GPU 5 to this process.
# NOTE(review): torch is imported before this assignment; presumably CUDA has
# not been initialised yet at this point, otherwise the setting has no effect
# -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# Backbone selector: default argument of prepare_model() and part of the
# output directory names.  Other values left in the original comment:
# 'resnet50', 'canet11', 'inception_v3'.
MODEL = 'vgg16'
# Attribution method names handed to ADGT.pure_explain by the visualisation
# helpers.  Longer lists that were used at some point are kept for reference.
#method = ['GIG','GradCAM','FullGrad','our_method_minus_mean' ,'DeepLIFT','Saliency','CAMERAS','IntegratedGradients']
method=['our_method_minus_mean' ]
# method=['IntegratedGradients','InputXGradient','Saliency','RectGrad','SmoothGrad']#

# Cap the number of intra-op CPU threads PyTorch uses.
torch.set_num_threads(4)
# Layer names, ordered from the classifier head back to the first convolution.
# Inception-v3 variant:
# layer_names=['fc','Mixed_7c','Mixed_7b','Mixed_7a','Mixed_6e','Mixed_6d','Mixed_6c','Mixed_6b','Mixed_6a',
#           'Mixed_5d','Mixed_5c','Mixed_5b','Conv2d_4a_3x3','Conv2d_3b_1x1','Conv2d_2b_3x3','Conv2d_2a_3x3',
#           'Conv2d_1a_3x3']
# VGG variant:
# layer_names=['classifier','features.28','features.14','features.0',]
# NOTE(review): the active list below looks like the ResNet layout while MODEL
# is 'vgg16'; in this file it is only read by the commented-out perturbation
# loop at the bottom -- confirm before re-enabling that loop.
layer_names = ['fc', 'layer4', 'layer3', 'layer2', 'layer1', 'conv1']
# When True, prepare_model() moves the network to the GPU and ADGT is told to
# use CUDA.
use_cuda = True
# Dataset name forwarded to ADGT.ADGT(name=...).
DATASET_NAME = 'ImageNet'
# Root directory under which every result folder is created.
ROOT = 'result'
# Forwarded to ADGT.ADGT(aug=...).
AUG = True
def setup_seed(seed):
    """Seed the PyTorch CPU, PyTorch CUDA and NumPy random generators.

    The same ``seed`` value is handed to each generator in turn.  cuDNN
    deterministic mode is intentionally left untouched here.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Make sure the result root exists.  exist_ok=True removes the window between
# the existence check and the creation, and makedirs also creates missing
# parent directories if ROOT is ever set to a nested path.
os.makedirs(ROOT, exist_ok=True)


def prepare_model(model_name=MODEL):
    """Build the pretrained network selected by ``model_name``.

    'inception_v3' is fetched through ``torch.hub``; 'resnet50' and 'vgg16'
    use the constructors imported at the top of the file; any other name
    falls back to ``canet11_imagenet(cane=True)``.  The network is moved to
    the GPU when the module-level ``use_cuda`` flag is set, printed once, and
    returned.
    """
    builders = {
        'inception_v3': lambda: torch.hub.load(
            'pytorch/vision:v0.6.0', 'inception_v3', pretrained=True),
        'resnet50': lambda: resnet50(pretrained=True),
        'vgg16': lambda: vgg16_bn(pretrained=True),
    }
    # Any name not listed above falls through to the CANet variant.
    build = builders.get(model_name, lambda: canet11_imagenet(cane=True))
    network = build()

    if use_cuda:
        network = network.cuda()

    print(network)
    return network


def prepare_img(path, means=(0.485, 0.456, 0.406), stds=(0.229, 0.224, 0.225), img_size=224):
    """Load one image, or every image below a directory, as normalised tensors.

    Args:
        path: Either an image file whose name ends in 'JPEG', 'png' or 'jpg',
            or a directory.  Anything that is not such a file name is listed
            with ``os.listdir`` and each entry is processed recursively.
        means: Per-channel mean passed to ``transforms.Normalize``.
        stds: Per-channel standard deviation passed to ``transforms.Normalize``.
        img_size: Side length of the centre crop taken after the shorter side
            has been resized to 256.

    Returns:
        A list of tensors of shape ``(1, 3, img_size, img_size)``, one per
        image, placed on the GPU when the module-level ``use_cuda`` is set.
        Directory entries are visited in sorted order so the list (and any
        index-based output naming derived from it) is reproducible.
    """
    if path.endswith(('JPEG', 'png', 'jpg')):
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made.
        with Image.open(path) as handle:
            input_image = handle.convert('RGB')
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=means, std=stds),
        ])
        img = preprocess(input_image).unsqueeze(0)
        if use_cuda:
            img = img.cuda()
        return [img]

    imgs = []
    for entry in sorted(os.listdir(path)):
        # Forward the preprocessing parameters; without this a directory call
        # would silently fall back to the defaults for every image.
        imgs += prepare_img(os.path.join(path, entry), means, stds, img_size)
    return imgs


def list_5():
    """Print the top-5 predicted class indices for the last demo images.

    Builds the default model, loads every image under
    ``sanity_checks_saliency/data/demo_images``, keeps at most the last five
    of them, runs one forward pass and prints the batch size followed by the
    top-5 label indices per image.  Nothing is returned.
    """
    max_K = 5
    net = prepare_model()
    net.eval()

    img = prepare_img(os.path.join('sanity_checks_saliency', 'data', 'demo_images'))
    img = torch.cat(img, 0)
    if use_cuda:
        img = img.cuda()
    # Keep the last max_K images.  The start index is clamped at 0: with fewer
    # than max_K images the raw difference is negative, which would be read as
    # an offset from the end and drop images from the front.
    img = img[max(img.size(0) - max_K, 0):]
    print(img.size())

    pred = net(img)
    _, topklabel = torch.topk(pred, max_K)
    print(topklabel)


from utils.visualization import save_images,save_images2


def visualize(net, img, method):
    """Save ``img`` and one explanation per attribution method.

    Args:
        net: The model handed to ``ADGT.pure_explain``.
        img: Input batch; saved as ``raw.png`` after being moved to the CPU
            and converted to a NumPy array.
        method: Iterable of attribution method names.

    All files are written to ``<ROOT>/class_independent/split``.
    """
    pth = os.path.join(ROOT, 'class_independent')
    adgt = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME, aug=AUG)
    pth_split = os.path.join(pth, 'split')
    # Creates pth and pth_split in one call and tolerates existing
    # directories (and a missing ROOT).
    os.makedirs(pth_split, exist_ok=True)
    save_images2(img.cpu().numpy(), os.path.join(pth_split, 'raw.png'))
    for m in method:
        adgt.pure_explain(img, net, m, file_name=pth_split)

def prepare_val(data_dir='/data_SSD2/zgh/workspace/data/ImageNet'):
    """Return the first shuffled validation batch as per-sample tuples.

    The random generators are seeded with 20 before the loader is built.
    Images from ``<data_dir>/val`` are resized to 256, centre-cropped to 224,
    converted to tensors and normalised.  The first batch of 64 is moved to
    the GPU and split along dimension 0 into chunks of size 1.

    Returns:
        ``(images, targets)`` where both are tuples of single-sample tensors,
        or ``None`` if the loader yields no batch.
    """
    setup_seed(20)
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'val'), pipeline)
    val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=64, shuffle=True,
        num_workers=4, pin_memory=False)

    # Only the first batch is ever consumed.
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        return None
    images, labels = first_batch
    return torch.split(images.cuda(), 1), torch.split(labels.cuda(), 1)

def all_visualize(net, method,img_dir='sanity_checks_saliency/data'):
    """Explain every sample of one validation batch with every method.

    For sample ``i`` of the batch returned by ``prepare_val()`` a directory
    ``<ROOT>/visualize_<MODEL>/<i>`` is created, the input is stored there as
    ``raw.png`` and ``ADGT.pure_explain`` is invoked once per entry of
    ``method`` with the sample's label passed as ``top1``.

    ``img_dir`` is accepted but not used by the current implementation.
    """
    out_root = os.path.join(ROOT, 'visualize_' + MODEL)
    explainer = ADGT.ADGT(use_cuda=use_cuda, name=DATASET_NAME, aug=AUG)
    if not os.path.exists(out_root):
        os.mkdir(out_root)

    samples, labels = prepare_val()
    for idx, sample in enumerate(samples):
        print(idx)
        sample_dir = os.path.join(out_root, str(idx))
        if not os.path.exists(sample_dir):
            os.mkdir(sample_dir)
        save_images2(sample.cpu().numpy(), os.path.join(sample_dir, 'raw.png'))
        for method_name in method:
            explainer.pure_explain(
                sample.cuda(), net, method_name,
                file_name=sample_dir, top1=labels[idx])

from attribution_methods.our.networks import ExplainableNet_ResNet,ExplainableNet
import torch.nn.functional as F

def class_independent(net,img_dir='sanity_checks_saliency/data/demo_images',method='our'):
    """Save mean-subtracted attribution maps for the top-4 classes per image.

    ``net`` is wrapped in ``ExplainableNet`` when ``MODEL`` is 'vgg16' and in
    ``ExplainableNet_ResNet`` otherwise.  For every image loaded from
    ``img_dir`` the wrapper is run once, ``analyze`` is called for each of the
    four highest-scoring labels, the mean of those four results is subtracted
    from each of them, the difference is passed through ReLU and written to
    ``<ROOT>/class_independent_demo<MODEL>/<i>/result_<k>_.jpg``.
    """
    out_root = os.path.join(ROOT, 'class_independent_demo' + MODEL)
    wrapper_cls = ExplainableNet if MODEL == 'vgg16' else ExplainableNet_ResNet
    model = wrapper_cls(net, method, beta=None)
    model.eval()
    if not os.path.exists(out_root):
        os.mkdir(out_root)

    top_k = 4
    for idx, sample in enumerate(prepare_img(img_dir)):
        sample_dir = os.path.join(out_root, str(idx))
        if not os.path.exists(sample_dir):
            os.mkdir(sample_dir)
        save_images2(sample.cpu().numpy(), os.path.join(sample_dir, 'raw.png'))

        scores = model(sample)
        _, top_labels = torch.topk(scores, top_k)
        print(idx, top_labels)

        # One analyze() call per label, in rank order.
        maps = [
            model.analyze(method=method, index=top_labels[:, rank])
            for rank in range(top_k)
        ]
        baseline = sum(maps) / top_k
        for rank, attribution in enumerate(maps):
            contrast = F.relu(attribution - baseline)
            save_images(
                contrast.detach().cpu().numpy(),
                os.path.join(sample_dir, 'result_' + str(rank) + '_.jpg'))




def obtain_internal_results(net,img,method='our+'):
    """Dump un-aggregated, aggregated and mean-subtracted maps for ``img``.

    Output goes to ``<ROOT>/interal_<MODEL>_<method>`` (directory name kept
    exactly as spelled).  Written files:

    * ``raw.png`` -- the input.
    * ``internal/<i>.jpg`` -- element ``i`` of ``analyze(..., no_aggr=True)``
      for the top-1 label.
    * ``target.jpg`` -- ``analyze`` result for the top-1 label.
    * ``<k>.jpg`` -- ``analyze`` result for the k-th of the top-5 labels.
    * ``result_<k>_.jpg`` -- ReLU of that result minus the mean over the five.
    """
    out_root = os.path.join(ROOT,'interal_'+MODEL+'_'+method)
    internal_dir = os.path.join(out_root, 'internal')
    for directory in (out_root, internal_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)
    save_images2(img.cpu().numpy(), os.path.join(out_root, 'raw.png'))

    wrapper_cls = ExplainableNet if MODEL == 'vgg16' else ExplainableNet_ResNet
    model = wrapper_cls(net, method, beta=None)
    model.eval()

    scores = model(img)
    top_k = 5
    top_values, top_labels = torch.topk(scores, top_k)
    print(top_values, top_labels)
    _, target = torch.topk(scores, 1)
    print(target.squeeze().size())

    # Un-aggregated results for the top-1 label, saved one file per element.
    partial_maps = model.analyze(
        method=method, index=target.squeeze(), no_aggr=True)
    for part_idx, part in enumerate(partial_maps):
        save_images(
            part.detach().cpu().numpy(),
            os.path.join(internal_dir, str(part_idx) + '.jpg'))

    aggregated = model.analyze(method=method, index=target.squeeze())
    save_images(aggregated.detach().cpu().numpy(), os.path.join(out_root, 'target.jpg'))

    # Aggregated map for each of the top-k labels, saved as it is produced.
    maps = []
    for rank in range(top_k):
        attribution = model.analyze(
            method=method, index=top_labels[:, rank].squeeze())
        maps.append(attribution)
        save_images(
            attribution.detach().cpu().numpy(),
            os.path.join(out_root, str(rank) + '.jpg'))

    baseline = sum(maps) / top_k
    for rank, attribution in enumerate(maps):
        contrast = F.relu(attribution - baseline)
        save_images(
            contrast.detach().cpu().numpy(),
            os.path.join(out_root, 'result_' + str(rank) + '_.jpg'))
def randomization_test(net,img_dir='sanity_checks_saliency/data/demo_images',method='our',perturb=None):
    """Save top-1 attribution maps contrasted against the next four classes.

    Args:
        net: Network to wrap; ``ExplainableNet`` is used when ``MODEL`` is
            'vgg16', ``ExplainableNet_ResNet`` otherwise.
        img_dir: File or directory handed to ``prepare_img``.
        method: Forwarded to the wrapper constructor and to ``analyze``.
        perturb: Only affects the output directory name: when not ``None``
            its string form is appended to ``randomization_test_<MODEL>``.

    For every image, the ``analyze`` results of the 2nd..5th ranked labels,
    each divided by ``K-1``, are subtracted from the results of the top-1
    label, both for the ``no_aggr=True`` sequence and for the aggregated map.
    The sequence elements are written to ``<j>.jpg`` and the aggregated map to
    ``all.jpg`` inside ``<pth>/<image index>``.
    """
    # Output directory; the optional suffix keeps runs with different
    # perturbation stages apart.
    if perturb is None:
        pth = os.path.join(ROOT, 'randomization_test_' + MODEL)
    else:
        pth = os.path.join(ROOT, 'randomization_test_' + MODEL+str(perturb))
    if MODEL == 'vgg16':
        model = ExplainableNet(net, method, beta=None)
    else:
        model = ExplainableNet_ResNet(net, method, beta=None)
    model.eval()
    if not os.path.exists(pth):
        os.mkdir(pth)
    img=prepare_img(img_dir)
    # Alternative input source: img, target = prepare_val()
    for iii in range(len(img)):
        pth_split = os.path.join(pth, str(iii))
        if not os.path.exists(pth_split):
            os.mkdir(pth_split)
        save_images2(img[iii].cpu().numpy(), os.path.join(pth_split, 'raw.png'))
        # Forward pass happens before any analyze() call.
        # NOTE(review): presumably analyze() relies on state cached by this
        # forward pass -- confirm in ExplainableNet.
        pred=model(img[iii])
        K=5
        topkvalue, topklabel = torch.topk(pred, K)
        print(iii,topklabel)
        # Results for the top-1 label: un-aggregated sequence and aggregated map.
        reconstruction_results = model.analyze(
            method=method, index=topklabel[:,0],no_aggr=True
        )

        aggr = model.analyze(
            method=method, index=topklabel[:,0]
        )
        # Subtract the average over the remaining K-1 labels, one label at a
        # time.  The updates are in place on the objects returned by analyze().
        for i in range(1, K):
            temp0=model.analyze(
                method=method, index=topklabel[:,i], no_aggr=True
            )
            temp1=model.analyze(
                method=method, index=topklabel[:,i]
            )
            for j in range(len(reconstruction_results)):
                reconstruction_results[j]-=temp0[j]/(K-1)
            aggr-=temp1/(K-1)
        for j in range(len(reconstruction_results)):
            save_images(reconstruction_results[j].detach().cpu().numpy(),os.path.join(pth_split, str(j)+'.jpg'))
        save_images(aggr.detach().cpu().numpy(), os.path.join(pth_split, 'all.jpg'))
def perturbation(model,layername):
    """Re-initialise the weights of one named sub-module of ``model`` in place.

    Every module below the selected sub-module whose class name contains
    'Conv' or 'Linear' gets Xavier-normal weights; for 'Linear' modules the
    bias is additionally zeroed.

    Args:
        model: Network to modify.
        layername: Dotted attribute path of the sub-module, for example
            ``'fc'`` or ``'features.28'``.  Paths of any depth are followed;
            each component is printed as it is resolved.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AttributeError: If a path component does not exist.
    """
    def weights_init(m):
        classname = m.__class__.__name__
        if 'Conv' in classname:
            reset_bias = False
        elif 'Linear' in classname:
            reset_bias = True
        else:
            return
        # Wrapper classes (e.g. a "BasicConv2d" container) match by name but
        # own no weight tensor themselves; their children are visited by
        # apply() separately, so they are simply skipped.
        weight = getattr(m, 'weight', None)
        if weight is None:
            return
        nn.init.xavier_normal_(weight.data)
        # Layers built with bias=False expose bias as None.
        bias = getattr(m, 'bias', None)
        if reset_bias and bias is not None:
            bias.data.fill_(0)

    target = model
    for part in layername.split('.'):
        print(part)
        target = getattr(target, part)
    target.apply(weights_init)
    return model


import copy
# ---------------------------------------------------------------------------
# Script entry: runs on import/execution of this module.
# Builds the model selected by MODEL, switches it to eval mode and runs the
# class-independent demo on the default image directory.
# ---------------------------------------------------------------------------
print('visualize')
net = prepare_model()
net.eval()
# Alternative experiment: explain one validation batch with every method.
#all_visualize(net,method)
class_independent(net)
# Alternative experiment: randomisation test on the unmodified network,
# then again after re-initialising each entry of layer_names in turn.  `net`
# is reassigned inside the loop, so the re-initialisations accumulate;
# each test runs on a deep copy.
#randomization_test(copy.deepcopy(net))
#for i in range(len(layer_names)):
#    net=perturbation(net,layer_names[i])
#    randomization_test(copy.deepcopy(net),perturb=i)
# NOTE(review): the two calls below use `img`, which is not defined at module
# level in this file -- an input tensor must be prepared before enabling them.
#visualize(net, img, method)
#obtain_internal_results(net,img,'our')
