import numpy as np

from TorchRay.torchray.attribution.grad_cam import grad_cam
from TorchRay.torchray.attribution.grad_cam_positive import grad_cam_positive
from src import Pruner
from src.attribution_methods import grad_times_image, vanilla_saliency as VBP
from src.attribution_methods.gradcam import generate_grad_cam_saliency_maps
from src.attribution_methods.guided_backprop import GuidedBackprop
from src.attribution_methods.integrated_gradients import generate_integrad_saliency_maps
from torch.nn import functional as F

def _gradcam_layers(dataset_name, model_name):
    """Return ``(target_layer, grad_cam_v2_layer)`` configured for *model_name*.

    ``target_layer`` is forwarded to ``generate_grad_cam_saliency_maps`` (method
    ``'gradcam'``); ``grad_cam_v2_layer`` is the module name forwarded to TorchRay's
    ``grad_cam`` (method ``'gradcam_v2_1'``). An entry is ``None`` when no layer is
    configured for that model.
    """
    target_layer = None
    grad_cam_v2_layer = None
    if model_name == 'ConvNetSimple':  # Last Conv+Relu Layer
        target_layer = -4
    elif model_name in ('Resnet8', 'Resnet8v2', 'Resnet10'):
        target_layer = (-3, 0)
        grad_cam_v2_layer = 'layer3.0.relu'
    elif model_name == 'Resnet18':
        target_layer = (-3, 1)
    elif model_name == 'Resnet50':
        target_layer = (-3, 2)
        # NOTE(review): the '0.' prefix suggests the Birdsnap Resnet50 is wrapped
        # in a container module — confirm against the model definition.
        grad_cam_v2_layer = '0.layer4' if dataset_name == 'Birdsnap' else 'layer4'
    return target_layer, grad_cam_v2_layer


def _require_layer(layer, attribution_method, model_name):
    """Return *layer*, or raise ``ValueError`` if none is configured for *model_name*."""
    if layer is None:
        raise ValueError(
            f"Attribution method {attribution_method!r} has no layer configured "
            f"for model {model_name!r}")
    return layer


def _pruner_saliency(model, preprocessed_image, device, prune, generate):
    """Build a ``Pruner``, run ``prune(pruner)`` then return ``generate(pruner)``.

    ``pruner.remove_handles()`` runs in ``finally`` so the pruner's handles are
    released even when pruning or saliency generation raises.
    """
    pruner = Pruner.Pruner(model, preprocessed_image, device)
    try:
        prune(pruner)
        return generate(pruner)
    finally:
        pruner.remove_handles()


def _abs_if_requested(grad, attribution_method):
    """Return ``np.abs(grad)`` when the method name asks for 'abs' (and not 'noabs')."""
    if 'abs' in attribution_method and 'noabs' not in attribution_method:
        return np.abs(grad)
    print('abs not used')
    return grad


def generate_attribution(dataset_name, model, model_name, preprocessed_image, label, device, attribution_method):
    """Compute an attribution (saliency) map for one preprocessed image.

    Args:
        dataset_name: Dataset name; only consulted to choose the Resnet50 layer
            for ``'gradcam_v2_1'`` (``'Birdsnap'`` uses ``'0.layer4'``).
        model: The network to explain.
        model_name: Architecture key used to choose GradCAM layers
            (``'ConvNetSimple'``, ``'Resnet8'``, ``'Resnet8v2'``, ``'Resnet10'``,
            ``'Resnet18'``, ``'Resnet50'``). Other names work for methods that
            need no layer.
        preprocessed_image: Model input; its last two dimensions are the output
            size the ``'gradcam_v2_1'`` map is interpolated to.
        label: Target passed to every attribution helper (presumably a class index).
        device: Device forwarded to the attribution helpers.
        attribution_method: One of ``'VanillaSaliency'``, ``'input_x_grad'``,
            ``'gbp'``, ``'gradcam'``, ``'gradcam_v2_1'``, ``'integrad'``,
            ``'rectgrad'`` (legacy spelling ``'rectgard'`` also accepted),
            ``'prune_grad..._<X>'``, ``'prune_pgd..._<X>'``,
            ``'sparsity0_prune_pgd'`` or ``'PruneInteg..._<X>'``, where ``<X>`` is
            the text after the last ``'_'`` and is parsed as a float. For the
            ``prune_grad``/``prune_pgd`` families the absolute value is taken
            when the name contains ``'abs'`` but not ``'noabs'``.

    Returns:
        ``np.ndarray`` whose first axis has length 3 (the leading batch axis of
        the ``(1, 3, ...)`` attribution is dropped).

    Raises:
        ValueError: unknown ``attribution_method``; a GradCAM method whose layer
            is not configured for ``model_name``; or a non-numeric ``<X>`` suffix.
        AssertionError: the chosen method did not produce a 4-D ``(1, 3, ...)``
            ndarray (only checked when assertions are enabled).
    """
    rectgrad_pruning_threshold = 90  # forwarded to generate_rectgrad
    target_layer, grad_cam_v2_layer = _gradcam_layers(dataset_name, model_name)

    if attribution_method == 'VanillaSaliency':
        vanilla_backprop = VBP.VanillaSaliency(model, device)
        grad = vanilla_backprop.generate_saliency(preprocessed_image, label, make_single_channel=False)
    elif attribution_method == 'input_x_grad':
        grad = grad_times_image.generate_grad_times_image_saliency(model, preprocessed_image, label, device, make_single_channel=False)
        grad = grad[np.newaxis, ...]  # add batch axis
    elif attribution_method == 'gbp':
        grad = GuidedBackprop(model, device).generate_gradients(preprocessed_image, label, make_single_channel=False)
    elif attribution_method == 'gradcam':
        layer = _require_layer(target_layer, attribution_method, model_name)
        grad = generate_grad_cam_saliency_maps(model, None, preprocessed_image,
                                               label, target_layer=layer, device=device)[2]
        # Add a batch axis, then replicate the single map into 3 channels.
        grad = np.stack((grad.reshape((1,) + grad.shape),) * 3, axis=1)
    elif attribution_method == 'gradcam_v2_1':
        layer = _require_layer(grad_cam_v2_layer, attribution_method, model_name)
        saliency = grad_cam(model, preprocessed_image, label, saliency_layer=layer)
        # Upsample the layer-resolution map to the input's spatial size.
        image_shape = (preprocessed_image.shape[-2], preprocessed_image.shape[-1])
        saliency = F.interpolate(saliency, image_shape, mode="bilinear", align_corners=False)
        grad = saliency.detach().cpu().clone().numpy()
        grad = np.concatenate((grad,) * 3, axis=1)  # replicate into 3 channels
    elif attribution_method == 'integrad':
        grad = generate_integrad_saliency_maps(model, preprocessed_image, label, device,
                                               steps=50, make_single_channel=False)
        grad = grad.reshape((1,) + grad.shape)  # add batch axis
    elif attribution_method in ('rectgrad', 'rectgard'):
        # 'rectgard' is the historical (misspelled) key, kept for compatibility.
        # NOTE(review): generate_rectgrad is not imported in the visible source —
        # confirm where it is defined, otherwise this branch raises NameError.
        grad = generate_rectgrad(model, preprocessed_image, label, rectgrad_pruning_threshold, device,
                                 make_single_channel=False)
    elif attribution_method.startswith('prune_grad'):  # prune_grad_..._X, X = percentile to prune
        threshold = float(attribution_method.split('_')[-1])
        grad = _pruner_saliency(
            model, preprocessed_image, device,
            lambda p: p.prune(percentile_to_prune=threshold, debug=False),
            lambda p: p.generate_saliency(make_single_channel=False))
        grad = _abs_if_requested(grad, attribution_method)
    elif attribution_method.startswith('prune_pgd'):  # prune_pgd_..._X, X = percentile to prune
        threshold = float(attribution_method.split('_')[-1])
        grad = _pruner_saliency(
            model, preprocessed_image, device,
            lambda p: p.prune(percentile_to_prune=threshold, debug=False),
            lambda p: p.generate_saliency_pgd_l2(make_single_channel=False, debug=False))
        grad = _abs_if_requested(grad, attribution_method)
    elif attribution_method == 'sparsity0_prune_pgd':
        # Prune with the percentile reported by the pruner's own base_sparsity().
        grad = np.abs(_pruner_saliency(
            model, preprocessed_image, device,
            lambda p: p.prune(percentile_to_prune=p.base_sparsity(), debug=False),
            lambda p: p.generate_saliency_pgd_l2(make_single_channel=False, debug=False)))
    elif attribution_method.startswith('PruneInteg'):  # PruneInteg..._X, X = sparsity
        sparsity = float(attribution_method.split('_')[-1])
        grad = np.abs(_pruner_saliency(
            model, preprocessed_image, device,
            lambda p: p.prune_integrad(sparsity, debug=False),
            lambda p: p.generate_saliency(make_single_channel=False)))
    else:
        raise ValueError(f"Unknown attribution method: {attribution_method!r}")

    assert isinstance(grad, np.ndarray)
    assert grad.ndim == 4 and grad.shape[0] == 1 and grad.shape[1] == 3
    return grad[0]

