import numpy as np

from src.saliency_methods import IntegratedGradients
import torchvision

import torchvision.transforms as transforms
from src.utils import BlockMnist
from torch.utils.data import DataLoader
from src.model import *


def norm_map(s_map, *, low=2, high=98, eps=1e-6):
    """Rescale a saliency map to [0, 1] using robust percentile bounds.

    Values at or below the ``low`` percentile map to 0. Values at or above
    the ``high`` percentile (plus ``eps``) map to 1. Everything in between
    is scaled linearly. Percentiles are used instead of min/max so that a
    few extreme pixels do not compress the rest of the map.

    Args:
        s_map: NumPy array of attribution scores (any shape); the
            percentiles are taken over all of its elements.
        low: lower percentile used as the 0 point (default 2).
        high: upper percentile used as the 1 point (default 98).
        eps: added to the upper bound. When ``high >= low`` this keeps the
            denominator strictly positive, e.g. for a constant map.

    Returns:
        ``np.ndarray`` with the same shape as ``s_map`` and values in [0, 1].
    """
    v_min = np.percentile(s_map, low)
    v_max = np.percentile(s_map, high) + eps
    return np.clip((s_map - v_min) / (v_max - v_min), 0, 1)


def leakage_measure(model, attr_method='IG', logger=None):
    """Measure attribution mass falling on the null block of BlockMNIST test images.

    The MNIST test split is loaded with the ``BlockMnist(test=True)``
    transform. Each sample's saliency map comes from Integrated Gradients
    (``k=10``) and is normalised with ``norm_map``. Only one half of the map
    is kept: the first 28 rows when ``null_block_loc.npy`` stores 0 for that
    sample, otherwise the remaining rows. Within that half, the pixels where
    the ``null_block.pt`` mask equals 1 are selected, and their L2 norm is
    taken. The reported value is the mean of this norm over all samples.

    Args:
        model: network to explain. Input batches are moved to CUDA, so the
            model must already live there. This function does not change the
            model's train/eval mode; callers should set it.
        attr_method: currently ignored; Integrated Gradients is always used.
            Kept for interface compatibility.
        logger: optional logger whose ``info`` method receives the result.
            When ``None`` the result is printed instead.

    Returns:
        The mean null-block L2 norm as a float, or ``nan`` if no samples
        were processed. The log line labels this value "std", but it is a
        mean of per-sample L2 norms.
    """
    te_dataset = torchvision.datasets.MNIST('Data',
                                            train=False,
                                            transform=transforms.Compose([
                                                transforms.ToTensor(),
                                                BlockMnist(test=True),
                                            ]),
                                            download=True)
    # shuffle=False is required: null_locate is indexed by dataset order.
    te_loader = DataLoader(te_dataset, batch_size=256, shuffle=False, num_workers=4)

    explainer = IntegratedGradients(model, k=10)

    # Per-sample selector: 0 -> null block in rows [:28], else rows [28:].
    null_locate = np.load('null_block_loc.npy')
    # Boolean mask over the flattened half-map; computed once, not per sample.
    null_mask = np.array(torch.load('null_block.pt')).reshape(-1) == 1

    norm_sum = 0.0
    n_samples = 0  # running dataset index, also the divisor for the mean
    for image, label in te_loader:
        image = image.cuda()
        target = label.cuda()

        saliency_map = explainer.shap_values(image, sparse_labels=target)
        saliency_map = saliency_map.detach().cpu().numpy()

        for bth in range(image.shape[0]):
            # Keep the leading batch axis so the 4-D slicing below works.
            s_map = norm_map(saliency_map[bth:bth + 1])

            loc = null_locate[n_samples]
            block_s = s_map[:, :, :28, :] if loc == 0 else s_map[:, :, 28:, :]
            block_s = block_s.reshape(-1)[null_mask]
            norm_sum += np.linalg.norm(block_s, 2)
            n_samples += 1

    # Divide by the number of samples actually seen rather than a hard-coded size.
    mean_norm = norm_sum / n_samples if n_samples else float('nan')
    if logger is not None:
        logger.info('#### leakage std: %.3f ####\n' % mean_norm)
    else:
        print('leakage std: %.3f \n' % mean_norm)
    return mean_norm


if __name__ == '__main__':

    # Checkpoint file names (relative to checkpoint/mnist_/) to evaluate.
    model_names = [
        'model.pth',
    ]

    for m_name in model_names:
        m_name = m_name.strip()  # str.strip() returns a new string
        print(m_name)
        # The activation is inferred from the checkpoint name.
        act_func = 'ReLU' if 'ReLU' in m_name else 'Softplus'
        model = Model(i_c=1, n_c=10, act=act_func)
        model = torch.nn.DataParallel(model)
        model_pth = 'checkpoint/mnist_/' + m_name
        # Load to CPU first so checkpoints saved on any GPU id still load;
        # load_state_dict copies the values into the model's own parameters.
        pretrained_model = torch.load(model_pth, map_location='cpu')
        model.load_state_dict(pretrained_model, strict=True)
        # leakage_measure feeds CUDA tensors, and DataParallel requires the
        # module's parameters to be on cuda:0.
        model = model.cuda()
        # Evaluation: disable training-only behaviour such as dropout/batch-norm updates.
        model.eval()
        leakage_measure(model=model)
