import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torchvision import models
from utils.settings import img_size
from utils.preprocess import mean, std
from saliency_methods import ExpectedGradients
from explain_exp import ExplainerExp


def load_explainer(model, **kwargs):
    """Build the attribution explainer selected by ``kwargs['method_name']``.

    Args:
        model: The model to be explained; forwarded to the explainer.
        **kwargs: Must contain ``method_name``. For ``'ExpGrad'`` the keys
            ``k``, ``bg_size``, ``train_dataset``, ``test_batch_size``,
            ``random_alpha`` and ``cal_type`` are also required; they are
            forwarded to ``ExpectedGradients`` (``train_dataset`` as
            ``bg_dataset``, ``test_batch_size`` as ``batch_size``).

    Returns:
        The constructed explainer object.

    Raises:
        KeyError: If a required key is missing from ``kwargs``.
        ValueError: If ``method_name`` is not a supported method.
    """
    method_name = kwargs['method_name']
    if method_name != 'ExpGrad':
        # Fail loudly here instead of returning None, which would only blow up
        # later when the (missing) explainer is first used.
        raise ValueError(f"Unsupported explanation method: {method_name!r}")

    # ------------------------------------ Expected Gradients ----------------------------------------------
    print('============================ Expected Gradients ============================')
    k = kwargs['k']
    bg_size = kwargs['bg_size']
    train_dataset = kwargs['train_dataset']
    test_batch_size = kwargs['test_batch_size']
    random_alpha = kwargs['random_alpha']
    cal_type = kwargs['cal_type']

    return ExpectedGradients(model, k=k, bg_dataset=train_dataset, bg_size=bg_size,
                             batch_size=test_batch_size, random_alpha=random_alpha, cal_type=cal_type)


def load_dataset(dataset_name, test_batch_size, root='datasets', num_workers=4):
    """Load the ImageNet train split and a DataLoader over the val split.

    Both splits use the same preprocessing: resize to ``img_size`` x
    ``img_size``, convert to tensor, then normalize with ``mean``/``std``.

    Args:
        dataset_name: Name of the dataset; must contain ``'ImageNet'``.
        test_batch_size: Batch size of the returned validation loader.
        root: Directory passed to ``torchvision.datasets.ImageNet``.
        num_workers: Worker processes for the validation loader.

    Returns:
        ``(train_dataset, val_loader)``. The train split is returned as a
        dataset, not a loader; the val loader does not shuffle.

    Raises:
        ValueError: If ``dataset_name`` does not contain ``'ImageNet'``.
    """
    if 'ImageNet' not in dataset_name:
        # Returning None here used to surface as an opaque unpacking TypeError
        # in the caller.
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")

    # One pipeline shared by both splits; these transforms hold no per-sample state.
    preprocess = transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # ---------------------------- imagenet train ---------------------------
    imagenet_train_dataset = datasets.ImageNet(root=root, split='train', transform=preprocess)

    # ---------------------------- imagenet eval ---------------------------
    imagenet_val_dataset = datasets.ImageNet(root=root, split='val', transform=preprocess)
    imagenet_val_loader = torch.utils.data.DataLoader(
        imagenet_val_dataset, batch_size=test_batch_size,
        shuffle=False, num_workers=num_workers, pin_memory=False)

    return imagenet_train_dataset, imagenet_val_loader


def evaluate(method_name, model_name, dataset_name, metric, k=None, bg_size=None):
    """Run an explanation experiment for one model / explainer pair.

    Args:
        method_name: Explainer to use; currently only ``'ExpGrad'``.
        model_name: ``'vgg16'`` or ``'resnet34'`` (torchvision, pretrained).
        dataset_name: Forwarded to ``load_dataset``.
        metric: Experiment to run; currently only ``'pixel_perturb'``.
        k: Optional override of the explainer's ``k`` (default 1).
        bg_size: Optional override of the explainer's ``bg_size`` (default 50).

    Raises:
        ValueError: For an unsupported ``method_name``, ``model_name`` or
            ``metric``. These checks run before any model or dataset is loaded.
    """
    # Validate up front. Otherwise a typo fails only after the slow model and
    # dataset load, or, for an unknown metric, silently does nothing at all.
    if method_name != 'ExpGrad':
        raise ValueError(f"Unsupported method_name: {method_name!r}")
    if model_name not in ('vgg16', 'resnet34'):
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    if metric != 'pixel_perturb':
        raise ValueError(f"Unsupported metric: {metric!r}")

    if model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
    else:  # 'resnet34', validated above
        model = models.resnet34(pretrained=True)

    model = model.to('cuda')
    model = torch.nn.DataParallel(model)
    model.eval()

    # =================== load train dataset & test loader ========================
    test_bth = 80
    train_dataset, test_loader = load_dataset(dataset_name=dataset_name, test_batch_size=test_bth)

    # =================== load explainer ========================
    explainer_args = {
        'ExpGrad': {'method_name': 'ExpGrad', 'k': 1, 'bg_size': 50, 'train_dataset': train_dataset,
                    'test_batch_size': test_bth, 'random_alpha': True, 'cal_type': 'valid_ref'},
    }
    method_args = explainer_args[method_name]
    # Apply each override independently. Previously, passing k without bg_size
    # replaced the default bg_size with None, and bg_size alone was ignored.
    if k is not None:
        method_args['k'] = k
    if bg_size is not None:
        method_args['bg_size'] = bg_size

    # ---- metric == 'pixel_perturb' (validated above) ----
    explainer = load_explainer(model=model, **method_args)
    explain_exp = ExplainerExp(model, explainer=explainer, dataloader=test_loader)
    # Ratios 0.1 .. 0.9. Dividing by 10 yields e.g. 0.3 exactly, whereas
    # 3 * 0.1 gives 0.30000000000000004.
    explain_exp.perturb_exp(q_ratio_lst=[step / 10 for step in range(1, 10)])


if __name__ == '__main__':
    # Script entry point: evaluate Expected Gradients on VGG16.
    # NOTE(review): evaluate() only runs an experiment for metric 'pixel_perturb';
    # confirm that 'visualization' is the intended metric here.
    evaluate(
        method_name='ExpGrad',
        model_name='vgg16',
        dataset_name='ImageNet_vis',
        metric='visualization',
    )
