import argparse
import datetime
import itertools
import os
import random
from pprint import pformat

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms
from tqdm import tqdm

from src.dataset.factory import create_dataset, MAP_DATASET_TO_ENUM
from src.generate_attribution import generate_attribution
from src.models.utils import get_model, append_linear_layer_transform
from src.utils.sysutils import get_cores_count


def remove(image, attribution, replace_value, percentiles, bottom=False, return_mask=False, gray=False):
    """
    Replace a fraction of an image's pixels, ranked by attribution score, with a constant value.

    A separate copy of ``image`` is modified for every entry of ``percentiles``; the input
    image itself is never changed.

    image         : numpy array of shape [C,H,W]
    attribution   : attribution scores used to rank the pixels.
                    gray=True  -> shape [H,W], or [K,H,W] in which case only attribution[0] is used
                                  (gradcam gives a 3 channel map).
                    gray=False -> same number of elements as ``image``; every channel value is
                                  ranked (and replaced) independently.
    replace_value : sequence with one replacement value per channel of ``image``
    percentiles   : iterable of scalars between 0 and 100, inclusive. For each one, percentile % of
                    the ranked elements (pixels if gray, channel values otherwise) are replaced.
    bottom        : if True replace the elements with the LOWEST attribution (keeps the top ones);
                    otherwise replace the elements with the HIGHEST attribution.
    return_mask   : if True also return the boolean mask of replaced elements per percentile.
    gray          : if True rank whole pixels (all channels replaced together) using a single
                    channel attribution map.

    Returns a list of modified images (one per percentile) and, when ``return_mask`` is True,
    a second list with the corresponding masks ([H,W] if gray else [C,H,W]).
    """
    attribution = np.asarray(attribution)
    num_channels = image.shape[0]

    # The ranking does not depend on the percentile, so flatten and sort only once.
    if gray and attribution.ndim == 3:  # gradcam gives 3 channel image
        scores = np.ravel(attribution[0])
    else:  # Single channel map (gray) or per-element map (not gray)
        scores = np.ravel(attribution)
    num_scores = scores.size
    order = np.argsort(scores)  # Ascending: lowest attribution first

    modified_images = []
    masks = []
    for percentile in percentiles:
        modified_image = np.copy(image)

        # In gray mode ``scores`` has one entry per pixel, so this is a pixel count for any C.
        num_replace = int(percentile * num_scores / 100)
        if bottom:
            replace_index = order[:num_replace]  # Indices of lowest values
        else:
            # Explicit start index: order[-0:] would select EVERY index when num_replace == 0.
            replace_index = order[num_scores - num_replace:]  # Indices of highest values

        mask = np.zeros(num_scores, dtype=bool)
        mask[replace_index] = True

        if gray:
            mask = mask.reshape(image.shape[1:])
            for channel in range(num_channels):
                modified_image[channel, mask] = replace_value[channel]
        else:
            mask = mask.reshape(image.shape)
            for channel in range(num_channels):
                modified_image[channel, mask[channel]] = replace_value[channel]

        modified_images.append(modified_image)
        if return_mask:
            masks.append(mask)

    if return_mask:
        return modified_images, masks
    return modified_images


def pixel_perturbation_debug(outdir,
                             name,
                             input_image,
                             preprocessed_image,
                             modified_images,
                             masks):
    """
    Plot a grid comparing perturbed images with their source images, for visual debugging.

    The first row only carries the column titles. Each following row shows, for one pair of
    ``modified_images``/``masks``: input image, preprocessed image, mask and modified image.

    outdir             : directory in which ``<name>.png`` is saved; nothing is saved if falsy.
    name               : figure title and file name (without extension); must not be None.
    input_image        : numpy array of shape [C,H,W]
    preprocessed_image : numpy array of shape [C,H,W]
    modified_images    : list of numpy arrays of shape [C,H,W]
    masks              : list of boolean masks of shape [C,H,W] or [H,W], paired with modified_images
    """
    assert name is not None, "Name is used as title as well as to save image if outdir provided"
    fig = plt.figure()
    fig.suptitle(f'{name}')
    # Input Image, Preprocessed Image, Mask and Modified Image columns
    plot_rows = len(modified_images) + 1  # One extra for column title
    plot_columns = 4

    titles = ['Input Image', 'Preprocessed Image', 'Mask', 'Modified Image']
    for location, title in enumerate(titles, start=1):
        smap = fig.add_subplot(plot_rows, plot_columns, location)
        smap.axis('off')
        smap.set_title(title, fontsize=12, fontweight='bold')

    input_image_transpose = np.transpose(input_image, (1, 2, 0))
    preprocessed_image_transpose = np.transpose(preprocessed_image, (1, 2, 0))

    # Use 'default' for preprocessed image. If float is used, image - image.min() will remove normalization done.
    # NOTE: the modified image is normalized data as well, so arguably it should also use 'default'.
    display_modes = ['default', 'default', 'float', 'float']

    for index, (modified_image, mask) in enumerate(zip(modified_images, masks)):
        first_location = (index + 1) * plot_columns + 1
        locations = range(first_location, first_location + plot_columns)
        images = [input_image_transpose,
                  preprocessed_image_transpose,
                  np.transpose(mask, (1, 2, 0)) if mask.ndim == 3 else mask,
                  np.transpose(modified_image, (1, 2, 0))]

        for location, image, display_mode in zip(locations, images, display_modes):
            smap = fig.add_subplot(plot_rows, plot_columns, location)
            smap.axis('off')

            if display_mode == 'default':
                smap.imshow(image)
            elif display_mode == 'float':
                # Min-max scale to [0, 255]; a constant image (e.g. an all-False mask) has a
                # range of 0, so skip the division instead of producing NaNs.
                image = np.float64(image)
                image = image - image.min()
                value_range = image.max()
                if value_range > 0:
                    image /= value_range
                smap.imshow(np.uint8(image * 255))
            elif display_mode == 'RedBlueHeatmap':
                image = np.float64(image)
                pos_values = np.copy(image)
                pos_values[pos_values < 0.] = 0.0
                neg_values = np.copy(image)
                neg_values[neg_values > 0.] = 0.0
                abs_neg_values = abs(neg_values)

                smap.imshow(pos_values, cmap='Reds')
                smap.imshow(abs_neg_values, cmap='Blues', alpha=0.5)

    fig.tight_layout()
    # Save BEFORE show(): after show() returns the figure may already be destroyed by the
    # backend, and saving then produces a blank image.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
        fig.savefig(os.path.join(outdir, name + '.png'))
        print(f'Saved figure at {outdir}/{name}.png')
    plt.show()
    # This is called per sample and attribution method; release the figure to avoid piling them up.
    plt.close(fig)


def perform_perturbation_analysis(arguments, opt):
    """
    Run the pixel perturbation benchmark for every attribution method and save the results.

    For each evaluated test image and attribution method, the top / bottom ``percentiles`` % of
    pixels (ranked by the attribution map) are replaced with 0 and the absolute fractional change
    of the model's softmax output for the originally predicted class is measured. The mean change
    per method and percentile is printed every 500 batches and at the end, and saved as
    ``pixel_perturbation.npy`` (shape [num_methods, num_percentiles, 2]; last axis: 0 = most
    important pixels removed, 1 = least important pixels removed) plus ``percentiles.npy`` in
    ``<outdir>/<timestamp>_<dataset>_<model arch>/``.

    arguments : dict with keys 'outdir', 'random_seed', 'model_args' (must contain
                'model_arch_name'), 'dataset_args' and 'val_data_args'.
    opt       : parsed console arguments; ``dataset``, ``model`` and ``datasetdir`` are used.

    Raises ValueError if no prune thresholds are configured for ``opt.dataset``.
    """

    # Setup result directory
    outdir = arguments.get("outdir")
    print('Arguments:\n{}'.format(pformat(arguments)))

    # Set random seed throughout python
    seed = arguments['random_seed']
    print('Using Random Seed value as: %d' % seed)
    torch.manual_seed(seed)  # Set for pytorch, used for gpu as well.
    random.seed(seed)  # Set for python
    np.random.seed(seed)  # Set for numpy

    # Set device - cpu or gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device - {device}')

    # Load Model with weights(if available)
    model: torch.nn.Module = get_model(arguments.get('model_args'), device, arguments['dataset_args']).to(device)

    # Load parameters for the Dataset
    if opt.dataset == 'ImageNet':
        transform = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                        std=[0.229, 0.224, 0.225])])
        dataset = datasets.ImageNet(opt.datasetdir, split='val', download=False, transform=transform)

    else:
        dataset = create_dataset(arguments['dataset_args'],
                                 arguments['val_data_args'],  # Just use val_data_args as train_data_args
                                 arguments['val_data_args'])  # Split doesnt matter, we use test dataset

    attribution_methods = ['VanillaSaliency', 'input_x_grad', 'gbp', 'gradcam', 'rectgard', 'integrad']
    prune_methods = ['prune_grad_abs', 'prune_pgd_abs', 'PruneGrad-Mid']
    if opt.dataset == 'Birdsnap':
        prune_thresholds = [97]
    elif opt.dataset == 'CIFAR10':
        # prune_thresholds = [66.5, 81.5, 86.5, 88.0, 90.5, 92.0]  # [66.5, 81.5, 86.5]
        prune_thresholds = [86.5]
    elif opt.dataset == 'ImageNet':
        prune_thresholds = [90]
    else:
        raise ValueError(f'No prune thresholds configured for dataset {opt.dataset!r}')
    attribution_methods.extend(prune_method + '_' + str(prune_threshold)
                               for prune_method, prune_threshold in itertools.product(prune_methods, prune_thresholds))

    print('Running pixel perturbation for: ', attribution_methods)

    if opt.dataset in ['CIFAR10', 'Birdsnap']:
        dataloader = dataset.test_dataloader
        testset = dataset.testset
    else:
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=get_cores_count())
        testset = dataset

    num_samples = min(10000, len(testset))  # Birdsnap test set has far few images.

    if opt.dataset in ['Birdsnap', 'ImageNet'] and opt.model == 'Resnet50':
        num_samples = int(num_samples / 5)
        smaller_test_set = torch.utils.data.Subset(testset, random.sample(range(0, len(testset)), num_samples))
        dataloader = torch.utils.data.DataLoader(smaller_test_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=get_cores_count())

    # Step sizes to remove top k or bottom k
    percentiles = list(range(1, 11, 1))  # 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
    debug = False
    # Reduce every attribution map to a single channel (per-pixel max over channels) before ranking.
    use_gray = True

    model.eval()
    timestamp = datetime.datetime.now().isoformat()

    # Sum of the output change for each attribution method, percentile and
    # removal of top (index 0) / bottom (index 1) percentile pixels.
    # np.float64 rather than the removed np.float alias (AttributeError on NumPy >= 1.24).
    output_deviation_sum = np.zeros((len(attribution_methods), len(percentiles), 2), dtype=np.float64)
    # Number of images accumulated into output_deviation_sum; used for the means so they stay
    # correct for any batch size and when the loop stops early.
    num_evaluated = 0

    # Save results in corresponding directory
    attribution_output_dir = os.path.join(outdir,
                                          # E.g. ./pixel_perturbation_analysis/timestamp_dataset_model/attrname/
                                          f'{timestamp}_'
                                          f'{opt.dataset}_'
                                          f'{arguments["model_args"]["model_arch_name"].split(".")[-1]}/')
    os.makedirs(attribution_output_dir, exist_ok=True)

    print_options = dict(precision=3, formatter={'float': '{: 0.3f}'.format}, suppress=True, linewidth=np.inf)

    for counter, data in enumerate(tqdm(dataloader, total=1 if debug else num_samples)):
        if counter == num_samples:
            break
        inputs, _ = data
        inputs = inputs.to(device)
        with torch.no_grad():  # Plain inference; no autograd graph needed here.
            outputs = model(inputs).cpu()
        _, max_prob_indices = torch.max(outputs, 1)
        outputs = torch.nn.functional.softmax(outputs, dim=1).numpy()

        for attribution_method_index, attribution_method in enumerate(attribution_methods):
            for preprocessed_image, max_prob_index, output in zip(inputs, max_prob_indices, outputs):
                attribution_map = generate_attribution(opt.dataset,
                                                       model,
                                                       opt.model,
                                                       preprocessed_image.unsqueeze(0),
                                                       max_prob_index,
                                                       device,
                                                       attribution_method)

                gray = 'gradcam' == attribution_method
                if use_gray:
                    gray = True
                    attribution_map = np.max(attribution_map, axis=0)

                image_np = preprocessed_image.cpu().numpy()
                # Replace value 0: for normalized inputs (as in the ImageNet transform above) this is the
                # per-channel mean colour; black in the original image would be -mean/std.
                modified_images_bottom_remove = remove(image_np.copy(), attribution_map,
                                                       replace_value=[0, 0, 0],
                                                       percentiles=percentiles, bottom=True, gray=gray,
                                                       return_mask=debug)
                modified_images_top_remove = remove(image_np.copy(), attribution_map,
                                                    replace_value=[0, 0, 0],
                                                    percentiles=percentiles, bottom=False, gray=gray,
                                                    return_mask=debug)

                if debug:
                    modified_images_t, masks_t = modified_images_top_remove
                    modified_images_b, masks_b = modified_images_bottom_remove
                    input_image = dataset.denormalize(preprocessed_image).cpu().numpy()
                    pixel_perturbation_debug(attribution_output_dir,
                                             f'{counter}_{attribution_method}_top_percent_kept',
                                             input_image, image_np, modified_images_t, masks_t)
                    pixel_perturbation_debug(attribution_output_dir,
                                             f'{opt.dataset}_{opt.model}_{counter}_{attribution_method}_bottom_percent_kept',
                                             input_image, image_np, modified_images_b, masks_b)
                    continue

                # One batch per removal direction, one image per percentile - ToDo - Do in single pass
                with torch.no_grad():
                    top_batch = torch.from_numpy(np.stack(modified_images_top_remove, axis=0)).to(device)
                    bottom_batch = torch.from_numpy(np.stack(modified_images_bottom_remove, axis=0)).to(device)
                    output_top_q = torch.nn.functional.softmax(model(top_batch), dim=1).cpu().numpy()
                    output_bottom_q = torch.nn.functional.softmax(model(bottom_batch), dim=1).cpu().numpy()

                # Relative change of the originally predicted class probability, per percentile
                class_index = int(max_prob_index)
                original_prob = output[class_index]
                top_deviation = np.abs((original_prob - output_top_q[:, class_index]) / original_prob)
                bottom_deviation = np.abs((original_prob - output_bottom_q[:, class_index]) / original_prob)

                output_deviation_sum[attribution_method_index, :, 0] += top_deviation
                output_deviation_sum[attribution_method_index, :, 1] += bottom_deviation

        num_evaluated += inputs.shape[0]

        if counter % 500 == 499:
            output_deviation_mean = output_deviation_sum / num_evaluated

            with np.printoptions(**print_options):
                print("\nAffect of removal of most important pixels at:-")
                for attribution_method_index, attribution_method in enumerate(attribution_methods):
                    print(attribution_method.ljust(20) + ' = ',
                          np.array2string(output_deviation_mean[attribution_method_index, :, 0], separator=', '))

                print("Affect of removal of least important pixels at:-")
                for attribution_method_index, attribution_method in enumerate(attribution_methods):
                    print(attribution_method.ljust(20) + ' = ',
                          np.array2string(output_deviation_mean[attribution_method_index, :, 1], separator=', '))
            print()

        if debug:
            break

    # Mean over the images actually evaluated (guard against an empty loader).
    output_deviation_mean = output_deviation_sum / max(num_evaluated, 1)

    with np.printoptions(**print_options):
        print("Affect of removal of most important pixels at:- \npercentiles ", percentiles)
        for ind, attr in enumerate(attribution_methods):
            print(attr.ljust(15), output_deviation_mean[ind, :, 0])
    print()
    with np.printoptions(**print_options):
        print("Affect of removal of least important pixels at:- \npercentiles ", percentiles)
        for ind, attr in enumerate(attribution_methods):
            print(attr.ljust(15), output_deviation_mean[ind, :, 1])

    # plot_pixel_perturbation(attribution_methods, opt.dataset, output_deviation_mean, percentiles, top=True)
    # plot_pixel_perturbation(attribution_methods, opt.dataset, output_deviation_mean, percentiles, top=False)

    # Save in directory
    np.save(os.path.join(attribution_output_dir, 'pixel_perturbation.npy'), output_deviation_mean)
    np.save(os.path.join(attribution_output_dir, 'percentiles.npy'), np.asarray(percentiles))


def plot_pixel_perturbation(attribution_methods,
                            dataset_name,
                            pixel_perturbation,
                            percentiles,
                            top=True,
                            filename_prefix='Cifar10-Resnet8'):
    """
    Plot the output change against the percentage of removed pixels, one line per method, and save it.

    attribution_methods : method names, in the row order of ``pixel_perturbation``
    dataset_name        : dataset name (kept for interface compatibility; not used in the plot)
    pixel_perturbation  : array indexed [method, percentile, k] with k = 0 for most important
                          pixels removed and k = 1 for least important pixels removed
    percentiles         : x axis values, one per column of ``pixel_perturbation``
    top                 : plot the "most important removed" curves if True, else the "least" ones
    filename_prefix     : prefix of the saved file
                          ``<prefix>-<Most|Least>ImportantPixelPerturbationAffect.png``

    Raises Exception if there are more prune methods than reserved palette colours.
    """
    # Tableau 20 palette, scaled to the [0, 1] RGB range matplotlib expects.
    tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in [
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]]

    fig = plt.figure(figsize=(12, 8))
    plt.yticks([accuracy / 10 for accuracy in range(0, 11)], fontsize=14)
    plt.xticks(percentiles, fontsize=14)
    plt.ylim(0, 0.75)
    plt.xlim(0, max(percentiles))
    plt.grid(True)

    plt.xlabel(f"% of {'Most' if top else 'Least'} important input pixels removed", fontsize=14)
    plt.ylabel('Absolute Fractional Output Change', fontsize=14)

    # Remove the tick marks; they are unnecessary with the grid lines.
    # Booleans are required here: the strings "off"/"on" are both truthy.
    plt.tick_params(axis="both", which="both", bottom=False, top=False,
                    labelbottom=True, left=False, right=False, labelleft=True)

    # method -> (legend label, palette index, linestyle, marker)
    attribution_to_label_mapping = {
        'VanillaSaliency': ('VanillaGradient', 2, None, None),
        'integrad': ('IntegratedGradient', 4, None, None),
        'input_x_grad': ('Grad*Input', 6, None, None),
        'gbp': ('GuidedBackprop', 8, None, None),
        'gradcam': ('GradCam', 10, None, None),
        'rectgard': ('RectGrad', 12, None, None),
    }

    # Prune methods cycle through odd palette indices: 1..9 for prune_grad, 11..19 for prune_pgd.
    prune_grad_counter = 1
    prune_grad_marker = 'd'
    prune_pgd_counter = 11
    prune_pgd_marker = '+'

    column = 0 if top else 1
    for index, attribution_method in enumerate(attribution_methods):
        values = pixel_perturbation[index, :, column]
        if attribution_method in attribution_to_label_mapping:
            label, color_index, linestyle, marker = attribution_to_label_mapping[attribution_method]
            plt.plot(percentiles, values, lw=2, color=tableau20[color_index],
                     label=label, linestyle=linestyle, marker=marker)
        elif 'prune_grad' in attribution_method:
            if prune_grad_counter >= 10:
                raise Exception('Not enough colors to print following methods', attribution_methods)
            plt.plot(percentiles, values, lw=2, color=tableau20[prune_grad_counter],
                     label=attribution_method, linestyle='dashed', marker=prune_grad_marker)
            prune_grad_counter += 2
        elif 'prune_pgd' in attribution_method:
            if prune_pgd_counter >= 20:
                raise Exception('Not enough colors to print following methods', attribution_methods)
            # Use this branch's own counter (previously prune_grad_counter, which clashed colours).
            plt.plot(percentiles, values, lw=2, color=tableau20[prune_pgd_counter],
                     label=attribution_method, linestyle=None, marker=prune_pgd_marker)
            prune_pgd_counter += 2

    plt.legend()

    fig.savefig(f"{filename_prefix}-{'Most' if top else 'Least'}ImportantPixelPerturbationAffect.png",
                bbox_inches="tight")
    plt.show()
    plt.close(fig)


def parse_console_arguments():
    """
    Parse the command line options of the pixel perturbation analysis.

    When ``--model_weights`` is not given, a default checkpoint path is filled in for the
    CIFAR10/Resnet8 and Birdsnap/Resnet50 combinations; ImageNet runs leave it as None.
    Any other combination without ``--model_weights`` exits with a usage error (SystemExit)
    instead of an opaque KeyError.

    Returns the argparse Namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='The dataset to choose',
                        choices=['CIFAR10', 'ImageNet', 'Birdsnap'])
    parser.add_argument('--model', type=str, default='Resnet8', help='The model to use',
                        choices=['Resnet8', 'Resnet50'])
    parser.add_argument('--model_weights', type=str, required=False, default=None, help='Pretrained weights')
    parser.add_argument('--output_dir', type=str, required=False, default='./logs_birdsnap/pixel_perturbation_analysis/',
                        help='Output directory path')
    parser.add_argument('--datasetdir', type=str, required=False, default='./data/ILSVRC2012',
                        help='Dataset directory path')
    opt = parser.parse_args()

    # Default checkpoints per (dataset, model), used if no model weights are provided.
    default_model_weights = {
        ('CIFAR10', 'Resnet8'):
            './logs/model_ResNet8_dataset_CIFAR10_bs_256_name_torch.optim.Adam_lr_0.001/epoch_0024-model-val_accuracy_84.36.pth',
        ('Birdsnap', 'Resnet50'):
            './logs_birdsnap/2020-02-26T12:35:01.546063/epoch_0011-model-val_acc_60.910338517840806.pth',
    }

    if not opt.model_weights and opt.dataset != 'ImageNet':
        weights_path = default_model_weights.get((opt.dataset, opt.model))
        if weights_path is None:
            parser.error(f'No default model weights for dataset {opt.dataset} with model {opt.model}; '
                         f'pass --model_weights explicitly.')
        opt.model_weights = weights_path

    return opt


def main():
    """
    Script entry point: build the experiment configuration from the command line and run it.
    """
    opt = parse_console_arguments()

    dataset_args = {'name': MAP_DATASET_TO_ENUM[opt.dataset]}

    # shuffle=True in case if a subset of test dataset is used.
    val_data_args = {'batch_size': 1, 'shuffle': True}
    # Shuffling should be on to evaluate
    assert val_data_args['shuffle']

    # model name -> (architecture import path, constructor keyword arguments)
    architectures = {
        'Resnet8': ('src.models.classification.PytorchCifarResnet.ResNet8', {}),
        'Resnet50': ('torchvision.models.resnet50', {'pretrained': True}),
    }
    arch_name, constructor_args = architectures[opt.model]
    model_args = {
        'model_arch_name': arch_name,
        'model_weights_path': opt.model_weights,
        'model_constructor_args': constructor_args,
    }
    if opt.dataset == 'Birdsnap':
        model_args['model_transformer'] = append_linear_layer_transform

    perform_perturbation_analysis({
        'dataset_args': dataset_args,
        'val_data_args': val_data_args,
        'model_args': model_args,
        'outdir': opt.output_dir,
        'random_seed': 42,
    }, opt)


# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()

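# Snippet to load saved results and, for each output-change threshold in delta_outputs,
# show the first x value at which y reaches that threshold (x.npy / y.npy are example file names).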
# >>> import numpy as np
# >>> delta_outputs = [5, 10, 15, 20, 30, 40]
# >>> x = np.load('x.npy')
# >>> y = np.load('y.npy')
# >>> for delta_output in delta_outputs:
# ...     x[np.where(y >= delta_output)[0][0]]