import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_perturb_lc(top):
    """Plot output change against the percentage of input pixels removed.

    Draws one curve per attribution method from the measurements hard-coded
    below, saves the figure as a PNG in the current working directory and
    then displays it with ``plt.show()``.

    Args:
        top: Truthy to word the x-axis label and the output file name as
            "Most" important pixels, falsy for "Least". It does NOT select
            a different data set: the same hard-coded values are plotted
            either way.
            NOTE(review): an earlier inline remark on the plotted values
            said "no top data", so these numbers presumably belong to the
            "Least" run -- confirm before publishing a top=True figure.
    """
    # The "Tableau 20" palette as 0-255 RGB triples.
    tableau20_rgb = [(31, 119, 180), (174, 199, 232),
                     (255, 127, 14), (255, 187, 120),
                     (44, 160, 44), (152, 223, 138),
                     (214, 39, 40), (255, 152, 150),
                     (148, 103, 189), (197, 176, 213),
                     (140, 86, 75), (196, 156, 148),
                     (227, 119, 194), (247, 182, 210),
                     (127, 127, 127), (199, 199, 199),
                     (188, 189, 34), (219, 219, 141),
                     (23, 190, 207), (158, 218, 229)]
    # matplotlib expects colour components in the [0, 1] range.
    tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in tableau20_rgb]

    # One row per curve, in plotting (and therefore legend) order:
    #   (legend label, index into tableau20, linestyle, marker, measurements)
    # Keeping the style and the numbers in the same row means a curve can
    # never be paired with another method's label by a reordering mistake.
    # A linestyle/marker of None lets matplotlib use its default.
    curves = [
        ('VanillaGradient', 2, None, None,      # VanillaSaliency
         [0.216, 0.285, 0.327, 0.361, 0.390, 0.428, 0.450, 0.477, 0.510,
          0.540, 0.751, 0.826, 0.856, 0.864, 0.875, 0.877, 0.863, 0.871]),
        ('Grad*Input', 6, None, None,           # input_x_grad
         [0.150, 0.168, 0.183, 0.194, 0.200, 0.215, 0.221, 0.229, 0.235,
          0.247, 0.368, 0.423, 0.475, 0.517, 0.546, 0.570, 0.636, 0.845]),
        ('GuidedBackprop', 8, None, None,       # gbp
         [0.150, 0.202, 0.242, 0.280, 0.314, 0.333, 0.350, 0.368, 0.381,
          0.387, 0.502, 0.591, 0.634, 0.641, 0.641, 0.644, 0.713, 0.831]),
        ('GradCam', 10, None, None,             # gradcam
         [0.076, 0.092, 0.103, 0.111, 0.120, 0.126, 0.123, 0.134, 0.136,
          0.148, 0.192, 0.207, 0.260, 0.311, 0.390, 0.501, 0.643, 0.772]),
        ('RectGrad', 12, None, None,            # rectgard
         [0.153, 0.208, 0.258, 0.294, 0.318, 0.343, 0.358, 0.365, 0.376,
          0.388, 0.532, 0.573, 0.616, 0.616, 0.624, 0.643, 0.711, 0.811]),
        ('IntegratedGradient', 4, None, None,   # integrad
         [0.139, 0.169, 0.190, 0.207, 0.209, 0.217, 0.219, 0.229, 0.239,
          0.246, 0.341, 0.391, 0.429, 0.454, 0.488, 0.510, 0.589, 0.836]),
        ('PruneGrad', 14, None, 'o',            # prune_grad_abs_86.5
         [0.039, 0.058, 0.073, 0.088, 0.100, 0.111, 0.118, 0.125, 0.140,
          0.152, 0.222, 0.294, 0.375, 0.466, 0.570, 0.676, 0.765, 0.833]),
        ('PrunePGD', 16, '--', '*',             # prune_pgd_abs_86.5
         [0.039, 0.058, 0.072, 0.088, 0.099, 0.111, 0.118, 0.125, 0.141,
          0.150, 0.220, 0.293, 0.377, 0.465, 0.570, 0.667, 0.764, 0.831]),
    ]

    # Only the first 10 measurements of each curve are drawn, preceded by an
    # explicit (0, 0) origin point, giving x positions 0..10.
    n_points = 10
    percent_removed = list(range(n_points + 1))

    side = 'Most' if top else 'Least'

    plt.figure(figsize=(12, 8))

    # Large tick labels so the plot stays readable when scaled down.
    plt.yticks([tick / 10 for tick in range(11)], fontsize=14)
    # Tick every plotted x position, including the last one at 10.
    plt.xticks(percent_removed, fontsize=14)

    plt.xlabel(f"% of {side} important input pixels removed", fontsize=14)
    plt.ylabel('Absolute Fractional Output Change', fontsize=14)

    # Limit the axes to where the data is to avoid unnecessary whitespace.
    plt.ylim(0, 0.75)
    plt.xlim(0, 10)

    # Grid lines help the reader trace values back to the axis ticks.
    plt.grid(True)

    # Hide the tick marks (the grid lines replace them) but keep the labels.
    # These parameters are booleans: a string such as "off" is truthy and
    # would leave the marks visible.
    plt.tick_params(axis="both", which="both", bottom=False, top=False,
                    labelbottom=True, left=False, right=False, labelleft=True)

    for label, color_index, linestyle, marker, values in curves:
        plt.plot(percent_removed,
                 [0] + values[:n_points],
                 lw=2,
                 color=tableau20[color_index],
                 label=label,
                 linestyle=linestyle,
                 marker=marker)

    plt.legend()

    # bbox_inches="tight" trims the surrounding whitespace; the format is
    # taken from the file extension.
    plt.savefig(f"Cifar10-Resnet8-{side}ImportantPixelPerturbationAffect.png",
                bbox_inches="tight")
    plt.show()


# Only render when executed as a script, so that importing this module does
# not write a PNG or open a plot window as a side effect.
if __name__ == "__main__":
    # Pass top=True instead to label the figure and output file as "Most".
    plot_perturb_lc(top=False)
