import datetime
import itertools
import os
from pprint import pformat

import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import time
from src.dataset.factory import create_dataset, MAP_DATASET_TO_ENUM
from src.models.utils import get_model
from src.generate_attribution import generate_attribution

# Experiment configuration consumed by perform_perturbation_analysis through ``opt``.
num_samples = 10_000

dataset, model = 'CIFAR10', 'Resnet8'

# Checkpoint of the trained model whose predictions are perturbed.
model_weights = './logs/model_ResNet8_dataset_CIFAR10_bs_256_name_torch.optim.Adam_lr_0.001/epoch_0024-model-val_accuracy_84.36.pth'

opt = {
    'dataset': dataset,
    'model': model,
    'model_weights': model_weights,
    'num_samples': num_samples,
}

def remove(image, attribution, replace_value, percentiles, bottom=False, return_mask=False, gray=False):
    """
    Blank out the most (or least) important pixels of an image, once per percentile.

    image         : numpy array of shape [C,H,W]. It is never modified; every result is a copy.
    attribution   : array of shape [C,H,W] or [H,W]; a larger value means a more important pixel.
    replace_value : sequence with one value per channel of ``image``; removed entries of
                    channel c are set to replace_value[c].
    percentiles   : iterable of scalars between 0 and 100, inclusive. For each one,
                    int(percentile * N / 100) entries are removed, where N is H*W when
                    ``gray`` and C*H*W otherwise.
    bottom        : if True remove the entries with the LOWEST attribution (keeps the most
                    important ones); otherwise remove those with the HIGHEST attribution.
    return_mask   : if True also return the boolean removal masks (for debugging).
    gray          : rank spatial locations with a single attribution channel (attribution[0]
                    for a 3-D map) and blank every image channel at the chosen locations;
                    otherwise rank every (channel, row, column) entry independently.

    Returns the list of modified images (one per percentile) or, when ``return_mask`` is
    True, the tuple (modified_images, masks) where each mask has shape [H,W] (gray) or
    [C,H,W] (not gray).
    """
    image = np.asarray(image)
    attribution = np.asarray(attribution)
    num_channels = image.shape[0]

    if gray:
        if attribution.ndim == 3:  # e.g. gradcam gives a multi-channel map; use its first channel
            attribution = attribution[0]
        mask_shape = image.shape[1:]
    else:
        mask_shape = image.shape
    flat_attribution = attribution.ravel()
    # Number of rankable entries: one channel's pixels (gray) or the whole image.
    num_candidates = int(np.prod(mask_shape))

    # Ascending order (least important first). Independent of the percentile, so sort once.
    order = np.argsort(flat_attribution)

    modified_images = []
    masks = []
    for percentile in percentiles:
        num_remove = int(percentile * num_candidates / 100)
        if bottom:
            removed = order[:num_remove]  # indices of the lowest attributions
        else:
            # Indices of the highest attributions. Use an explicit start index:
            # order[-0:] would select *every* entry when num_remove == 0.
            removed = order[max(order.size - num_remove, 0):]

        mask = np.zeros(flat_attribution.shape, dtype=bool)
        mask[removed] = True
        mask = mask.reshape(mask_shape)

        modified_image = np.copy(image)
        for channel in range(num_channels):
            # Gray masks are shared by all channels; full masks have one slice per channel.
            modified_image[channel, mask if gray else mask[channel]] = replace_value[channel]
        modified_images.append(modified_image)
        if return_mask:
            masks.append(mask)

    if return_mask:
        return modified_images, masks
    return modified_images

def _class_probability_after_removal(model, image, attribution_map, percentiles, class_index, device,
                                     bottom, gray):
    """
    Return the model's softmax probability of ``class_index`` after removing pixels of
    ``image`` with ``remove`` at each percentile (one value per percentile).

    Removed pixels are set to 0 in every channel. The forward pass runs under
    ``torch.no_grad()`` because only output values are needed, not gradients.
    """
    modified_images = remove(image, attribution_map,
                             replace_value=[0, 0, 0],
                             percentiles=percentiles,
                             bottom=bottom,
                             gray=gray)
    # One batch holding the modified image for every percentile.
    batch = torch.from_numpy(np.stack(modified_images, axis=0)).to(device)
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)
    return probabilities.cpu().numpy()[:, class_index]


def _print_deviation_table(title, attribution_methods, percentiles, deviations):
    """Print ``title`` with the percentiles, then one row of ``deviations`` per attribution method."""
    with np.printoptions(precision=3, formatter={'float': '{: 0.3f}'.format}, suppress=True, linewidth=np.inf):
        print(title, percentiles)
        for attribution_method, row in zip(attribution_methods, deviations):
            print(attribution_method.ljust(15), row)


def perform_perturbation_analysis(arguments, opts):
    """
    Measure how much the model's confidence changes when the most / least important pixels
    (according to several attribution methods) are removed. Print and plot the results.

    For every test image, the predicted class comes from the unmodified image. For every
    attribution method and percentile, the top- and bottom-ranked pixels are zeroed via
    ``remove``. The absolute fractional change of that class' softmax probability is then
    averaged over all evaluated images.

    arguments : dict with 'model_args', 'dataset_args' and 'val_data_args' (forwarded to
                ``get_model`` / ``create_dataset``) and optionally 'random_seed'.
    opts      : dict; reads 'num_samples' (maximum number of test *batches* to evaluate)
                and 'model' (model name forwarded to ``generate_attribution``).
    """
    print('Arguments:\n{}'.format(pformat(arguments)))

    # Apply the seed so the shuffled test order, and therefore the evaluated subset, is
    # reproducible. Nothing else here consumes arguments['random_seed'].
    random_seed = arguments.get('random_seed')
    if random_seed is not None:
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    # Set device - cpu or gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device - {device}')

    # Load model with weights (if available)
    model: torch.nn.Module = get_model(arguments.get('model_args'), arguments.get('dataset_args'), device).to(device)

    # val_data_args is passed for both the train and the test loader; only the test loader is used.
    dataset = create_dataset(arguments['dataset_args'],
                             arguments['val_data_args'],
                             arguments['val_data_args'])
    dataloader = dataset.test_dataloader

    attribution_methods = ['VanillaSaliency', 'input_x_grad', 'gbp', 'gradcam', 'integrad']
    prune_methods = ['prune_grad_abs', 'PruneInteg']
    prune_thresholds = [86.5]
    attribution_methods.extend(prune_method + '_' + str(prune_threshold)
                               for prune_method, prune_threshold in itertools.product(prune_methods, prune_thresholds))

    num_samples = opts['num_samples']
    gray = True

    # Percentages of pixels to remove: 1, 2, ..., 10
    percentiles = list(range(1, 11))

    # Sum of |fractional output change|, indexed by
    # [attribution method, percentile, 0 = most important removed / 1 = least important removed].
    # np.float64 rather than the alias np.float, which NumPy 1.24 removed.
    output_deviation_sum = np.zeros((len(attribution_methods), len(percentiles), 2), dtype=np.float64)
    num_images = 0  # images actually evaluated, used for the mean

    model.eval()
    # Evaluate at most num_samples batches of the test loader.
    for inputs, _labels in itertools.islice(tqdm(dataloader, total=num_samples), num_samples):
        inputs = inputs.to(device)
        with torch.no_grad():
            logits = model(inputs).cpu()
        _, max_prob_indices = torch.max(logits, 1)
        outputs = torch.nn.functional.softmax(logits, dim=1).numpy()
        num_images += inputs.shape[0]

        for attribution_method_index, attribution_method in enumerate(attribution_methods):
            for preprocessed_image, max_prob_index, output in zip(inputs, max_prob_indices, outputs):
                # Kept outside no_grad: attribution methods presumably need gradients.
                attribution_map = generate_attribution(dataset,
                                                       model,
                                                       opts['model'],
                                                       preprocessed_image.unsqueeze(0),
                                                       max_prob_index,
                                                       device,
                                                       attribution_method)
                # Collapse axis 0 (channels) of the map with a max.
                attribution_map = np.max(attribution_map, axis=0)

                image = preprocessed_image.detach().cpu().numpy()
                class_index = int(max_prob_index)
                original_probability = output[class_index]

                top_probability = _class_probability_after_removal(
                    model, image, attribution_map, percentiles, class_index, device, bottom=False, gray=gray)
                bottom_probability = _class_probability_after_removal(
                    model, image, attribution_map, percentiles, class_index, device, bottom=True, gray=gray)

                # Absolute fractional change of the originally predicted class' probability.
                output_deviation_sum[attribution_method_index, :, 0] += np.abs(
                    (original_probability - top_probability) / original_probability)
                output_deviation_sum[attribution_method_index, :, 1] += np.abs(
                    (original_probability - bottom_probability) / original_probability)

    # Average over the images actually evaluated. num_samples counts batches, and the loader
    # may hold fewer of them. max(..., 1) guards against an empty loader.
    output_deviation_mean = output_deviation_sum / max(num_images, 1)

    _print_deviation_table("Affect of removal of most important pixels at:- \npercentiles ",
                           attribution_methods, percentiles, output_deviation_mean[:, :, 0])
    print()
    _print_deviation_table("Affect of removal of least important pixels at:- \npercentiles ",
                           attribution_methods, percentiles, output_deviation_mean[:, :, 1])

    plot_pixel_perturbation(attribution_methods, output_deviation_mean, percentiles, top=True)
    plot_pixel_perturbation(attribution_methods, output_deviation_mean, percentiles, top=False)


def plot_pixel_perturbation(attribution_methods,
                            pixel_perturbation,
                            percentiles,
                            top=True):
    """
    Plot mean absolute fractional output change against the percentage of removed pixels,
    with one line per attribution method, then show the figure.

    attribution_methods : method names; row i of ``pixel_perturbation`` belongs to
                          attribution_methods[i].
    pixel_perturbation  : array indexed as [method, percentile, k], where k = 0 holds the
                          "most important pixels removed" values and k = 1 holds the
                          "least important pixels removed" values.
    percentiles         : x-axis values (percentage of pixels removed).
    top                 : plot the k = 0 curves if True, otherwise the k = 1 curves.

    A method without a predefined style is drawn with its own name as the label and a
    palette colour not used by the predefined styles. Such methods used to raise KeyError.
    """
    # Tableau 20 palette, rescaled from 0-255 RGB to the 0-1 floats matplotlib expects.
    tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in (
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    )]

    plt.figure(figsize=(12, 8))
    plt.yticks([accuracy / 10 for accuracy in range(0, 11)], fontsize=14)
    plt.xticks(list(percentiles), fontsize=14)
    plt.ylim(0, 0.75)
    plt.xlim(0, max(percentiles))
    plt.grid(True)

    plt.xlabel(f"% of {'Most' if top else 'Least'} important input pixels removed", fontsize=14)
    plt.ylabel('Absolute Fractional Output Change', fontsize=14)

    # Hide the tick marks (the grid lines already mark the positions) but keep the labels.
    # tick_params expects booleans; the legacy "on"/"off" strings are no longer supported.
    plt.tick_params(axis="both", which="both", bottom=False, top=False,
                    labelbottom=True, left=False, right=False, labelleft=True)

    # method -> (legend label, tableau20 colour index, linestyle, marker)
    attribution_to_label_mapping = {
        'VanillaSaliency': ('VanillaGradient', 2, None, None),
        'integrad': ('IntegratedGradient', 4, None, None),
        'input_x_grad': ('Grad*Input', 6, None, None),
        'gbp': ('GuidedBackprop', 8, None, None),
        'gradcam': ('GradCam', 10, None, None),
        'rectgard': ('RectGrad', 12, None, None),
        'prune_grad_abs_86.5': ('PruneGrad', 14, None, 'o'),
        'prune_pgd_abs_86.5': ('PrunePGD', 16, '--', '*'),
    }

    # Colours not claimed by a predefined style, handed out in order to unknown methods.
    used_colors = {style[1] for style in attribution_to_label_mapping.values()}
    fallback_colors = [i for i in range(len(tableau20)) if i not in used_colors]
    num_fallbacks = 0

    column = 0 if top else 1
    for index, attribution_method in enumerate(attribution_methods):
        style = attribution_to_label_mapping.get(attribution_method)
        if style is None:
            # e.g. 'PruneInteg_86.5' has no entry above.
            style = (attribution_method, fallback_colors[num_fallbacks % len(fallback_colors)], None, None)
            num_fallbacks += 1
        label, color_index, linestyle, marker = style
        # One line per method, coloured from the Tableau 20 set.
        plt.plot(percentiles,
                 pixel_perturbation[index, :, column],
                 lw=2,
                 color=tableau20[color_index],
                 label=label,
                 linestyle=linestyle,
                 marker=marker)

    plt.legend()
    plt.show()

# Dataset selection for the experiment configured in ``opt``.
dataset_args = dict(
    name=MAP_DATASET_TO_ENUM[opt['dataset']]
)

val_data_args = dict(
    batch_size=1,
    shuffle=True  # in case only a subset of the test dataset is used; val_data_args are used for test too.
)

# Shuffling should be on to evaluate. This is an explicit check because ``assert`` is
# stripped under ``python -O``.
if not val_data_args['shuffle']:
    raise ValueError("val_data_args['shuffle'] must be True for the perturbation analysis")

model_config = dict(
    Resnet8=dict(
        model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8',
        model_weights_path=opt['model_weights'],
        model_constructor_args=dict()),
)

model_args = model_config[opt['model']]

arguments = dict(
    dataset_args=dataset_args,
    val_data_args=val_data_args,
    model_args=model_args,
    random_seed=42
)

# Only start the (long-running) analysis when executed as a script, not on import.
if __name__ == '__main__':
    perform_perturbation_analysis(arguments, opt)


