import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
import numpy as np
import sys
sys.path.append('..')
from config import opt
import seaborn as sns

# Use Times New Roman for all text in every figure this module creates
# (process-global matplotlib setting; presumably chosen to match the paper's typeface).
plt.rcParams["font.family"] = "Times New Roman"

import matplotlib.pyplot as plt
import matplotlib

def plot_rect(title, x_list, y_list):
    """Draw a bar chart of ``y_list`` against ``x_list`` and save it as a PNG.

    The chart is written to ``{opt.work_dir}results/rect_chart/<name>.png``,
    where ``<name>`` is ``title`` with spaces replaced by underscores.
    ``title`` is only used for the file name; it is not drawn on the chart.

    Args:
        title: Name of the chart, used to build the output file name.
        x_list: Bar positions.
        y_list: Bar heights, one per entry of ``x_list``.
    """
    import os

    save_dir = f'{opt.work_dir}results/rect_chart/'
    os.makedirs(save_dir, exist_ok=True)
    # Draw on a dedicated figure: a bare plt.bar() targets whatever figure is
    # current, so consecutive calls would stack bars from earlier charts.
    fig, ax = plt.subplots()
    ax.bar(x=x_list, height=y_list, width=1, color='blue')
    save_name = '_'.join(title.split(' '))
    fig.savefig(f'{save_dir}{save_name}.png')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_line(title, x_list, y_list, labels,
              x_label=None, y_label=None,
              y_lim=None, legend=True):
    """Plot one line per series on a single axes and save it as a PDF.

    The chart is written to ``{opt.work_dir}results/line_chart/<name>.pdf``,
    where ``<name>`` is ``title`` with spaces replaced by underscores.

    Args:
        title: Chart title; also used to build the output file name.
        x_list: One x sequence per series.
        y_list: One y sequence per series, aligned with ``x_list``.
        labels: One legend label per series.
        x_label: Optional x-axis label (skipped when falsy).
        y_label: Optional y-axis label (skipped when falsy).
        y_lim: Optional ``(bottom, top)`` pair for the y axis.
        legend: Whether to draw a legend.

    Raises:
        AssertionError: if ``x_list``, ``y_list`` and ``labels`` differ in length.
    """
    import os

    save_dir = f'{opt.work_dir}results/line_chart/'
    assert len(x_list) == len(y_list) == len(labels), \
        'x_list, y_list and labels must have the same length'
    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots()
    for xs, ys, label in zip(x_list, y_list, labels):
        ax.plot(xs, ys, label=label)
    ax.set_title(title)
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
    if legend:
        ax.legend()
    if y_lim:
        # Set limits on this axes explicitly rather than via pyplot's
        # implicit "current axes" state.
        ax.set_ylim(y_lim[0], y_lim[1])

    save_name = '_'.join(title.split(' '))
    fig.savefig(f'{save_dir}{save_name}.pdf')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_compare_fig(x_list=None, y_list=None, labels=None):
    """Plot query efficiency against the number of attacked images.

    Draws one line per series under the title 'compare image' (no legend)
    by delegating to ``plot_line``, which saves the chart to
    ``{opt.work_dir}results/line_chart/compare_image.pdf``.

    Args:
        x_list: One x sequence per series; defaults to no series.
        y_list: One y sequence per series, aligned with ``x_list``;
            defaults to no series.
        labels: One label per series; defaults to no series. Labels are
            attached to the lines, but no legend is drawn.
    """
    # Missing series default to empty lists, which yields a chart that has
    # only the title and axis labels.
    x_list = [] if x_list is None else x_list
    y_list = [] if y_list is None else y_list
    labels = [] if labels is None else labels
    plot_line('compare image', x_list, y_list, labels,
              x_label='num of attacked image',
              y_label='query efficiency',
              legend=False)

def _draw_substitute_panel(axis, x_list, y_list, labels, linestyles, line_width,
                           x_label, y_label, y_max, y_step):
    """Draw one panel of the substitute-model comparison and style its axes.

    Args:
        axis: Matplotlib axes to draw on.
        x_list: One x sequence (query budget) per method.
        y_list: One metric sequence per method, aligned with ``x_list``.
        labels: Legend label per method.
        linestyles: Line style per method.
        line_width: Width used for every line.
        x_label: Text of the x-axis label.
        y_label: Text of the y-axis label (the metric name).
        y_max: Upper y limit; the lower limit is 0.
        y_step: Spacing of the major y ticks.
    """
    for xs, ys, label, style in zip(x_list, y_list, labels, linestyles):
        axis.plot(xs, ys, label=label, linewidth=line_width, linestyle=style)
    axis.set_xlabel(x_label, fontsize=13, weight='bold')
    axis.set_ylabel(y_label, fontsize=13, weight='bold')
    axis.tick_params(labelsize=11)
    for tick_label in axis.get_xticklabels() + axis.get_yticklabels():
        tick_label.set_fontweight('bold')
    axis.xaxis.set_major_locator(MultipleLocator(1))
    axis.legend(fontsize=11, loc='upper left')
    axis.set_xlim(0, 7)
    axis.set_ylim(0.0, y_max)
    axis.yaxis.set_major_locator(MultipleLocator(y_step))


def plot_substitute_compare():
    """Plot the 2x3 grid comparing substitute-model training methods.

    The top row holds the 'cifar10-cifar100' results and the bottom row the
    'cifar100-cifar10' results. The columns are Accuracy, Fidelity and ASR
    against the query budget. Methods: Papernot, Black-Box Ripper,
    ActiveThief and SEEKER.

    The figure is saved to ``{opt.work_dir}results/line_chart/substitute_comparison``.
    The path has no extension, so matplotlib picks the format from
    ``rcParams['savefig.format']``.
    """
    import os

    save_dir = f'{opt.work_dir}results/line_chart/'
    save_path = save_dir + 'substitute_comparison'
    os.makedirs(save_dir, exist_ok=True)

    # Install the palette before creating the axes; each axes picks up its
    # colour cycle when it is created.
    sns.set_palette(sns.color_palette('Set2'))

    labels = ['Papernot', 'Black-Box Ripper', 'ActiveThief', 'SEEKER']
    linestyles = ['-', '--', '-.', ':']
    x_label = 'query budget(K)'
    line_width = 2

    subplot_w = 15
    subplot_h = 9
    fig, ax = plt.subplots(2, 3, figsize=(subplot_w, subplot_h))

    # Every method except Papernot shares the budget grid 0.0, 0.5, ..., 7.0;
    # Papernot was recorded at its own budgets (first entry of each x list).
    shared_budget = list(np.arange(0, 7.5, 0.5))
    n_budget = len(shared_budget)
    x_top = [[0.0, 0.11, 0.22, 0.44, 0.88, 1.76, 3.5, 7.0]] + [shared_budget] * 3
    x_bottom = [[0.0, 0.5, 1.0, 2.0, 4.0, 8.0]] + [shared_budget] * 3

    # (row, col, metric, y_max, y_step, x_list, y_list); y_list is ordered as
    # ``labels``: Papernot, Black-Box Ripper, ActiveThief, SEEKER.
    panels = [
        # cifar10-cifar100
        (0, 0, 'Accuracy', 1.0, 0.2, x_top, [
            [0.1, 0.1, 0.18, 0.2, 0.22, 0.22, 0.23, 0.25],
            [0.1] * n_budget,
            [0.1, 0.27, 0.32, 0.35, 0.36, 0.36, 0.37, 0.40, 0.40, 0.41, 0.42,
             0.42, 0.43, 0.43, 0.43],
            [0.1, 0.45, 0.55, 0.62, 0.627, 0.66, 0.666, 0.67, 0.69, 0.69, 0.69,
             0.69, 0.70, 0.70, 0.70],
        ]),
        (0, 1, 'Fidelity', 1.0, 0.2, x_top, [
            [0.1, 0.11, 0.18, 0.21, 0.22, 0.22, 0.23, 0.25],
            [0.1] * n_budget,
            [0.1, 0.29, 0.34, 0.37, 0.39, 0.39, 0.40, 0.42, 0.44, 0.44, 0.45,
             0.45, 0.46, 0.46, 0.46],
            [0.1, 0.47, 0.58, 0.65, 0.65, 0.69, 0.705, 0.71, 0.71, 0.72, 0.73,
             0.74, 0.744, 0.748, 0.75],
        ]),
        (0, 2, 'ASR', 1.1, 0.2, x_top, [
            [0.1, 0.1, 0.13, 0.124, 0.125, 0.11, 0.14, 0.167],
            [0.1] * n_budget,
            [0.1, 0.13, 0.20, 0.20, 0.26, 0.27, 0.26, 0.25, 0.27, 0.24, 0.27,
             0.26, 0.24, 0.245, 0.258],
            [0.1, 0.3, 0.51, 0.545, 0.67, 0.70, 0.75, 0.77, 0.76, 0.77, 0.81,
             0.81, 0.826, 0.84, 0.844],
        ]),
        # cifar100-cifar10
        (1, 0, 'Accuracy', 0.37, 0.1, x_bottom, [
            [0.01, 0.009, 0.01, 0.009, 0.01, 0.009],
            [0.01] * n_budget,
            [0.01, 0.06, 0.065, 0.08, 0.09, 0.095, 0.10, 0.11, 0.11, 0.10, 0.12,
             0.12, 0.122, 0.124, 0.124],
            [0.01, 0.07, 0.12, 0.16, 0.19, 0.22, 0.22, 0.256, 0.28, 0.27, 0.29,
             0.28, 0.3, 0.29, 0.305],
        ]),
        (1, 1, 'Fidelity', 0.37, 0.1, x_bottom, [
            [0.01, 0.009, 0.01, 0.009, 0.011, 0.01],
            [0.01] * n_budget,
            [0.01, 0.06, 0.07, 0.08, 0.09, 0.10, 0.105, 0.11, 0.10, 0.105, 0.12,
             0.12, 0.125, 0.12, 0.125],
            [0.01, 0.07, 0.13, 0.17, 0.20, 0.23, 0.23, 0.26, 0.283, 0.285, 0.31,
             0.31, 0.32, 0.31, 0.322],
        ]),
        (1, 2, 'ASR', 1.0, 0.2, x_bottom, [
            [0.28, 0.353, 0.005, 0.003, 0.006, 0.003],
            [0.28] * n_budget,
            [0.28, 0.35, 0.35, 0.41, 0.33, 0.325, 0.28, 0.28, 0.28, 0.25, 0.26,
             0.26, 0.27, 0.25, 0.289],
            [0.28, 0.49, 0.50, 0.56, 0.625, 0.68, 0.71, 0.69, 0.71, 0.71, 0.77,
             0.79, 0.75, 0.79, 0.773],
        ]),
    ]

    for row, col, metric, y_max, y_step, x_list, y_list in panels:
        _draw_substitute_panel(ax[row, col], x_list, y_list, labels, linestyles,
                               line_width, x_label, metric, y_max, y_step)

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', pad_inches=1)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def _simulator_qer_series(batch_queries, batch_attacks, scale):
    """Return cumulative successful attacks and QER for the Simulator attack.

    Args:
        batch_queries: Queries spent in each batch.
        batch_attacks: Per-batch attack figures; each is multiplied by ``scale``.
        scale: Panel-specific factor applied to ``batch_attacks``
            (its meaning is not stated here; presumably a normalising rate).

    Returns:
        ``(attack_n, qer)``. ``attack_n`` is the running total of the scaled
        batch attacks, and ``qer`` is ``attack_n`` divided by the running total
        of queries.
    """
    attack_n = np.cumsum(np.array(batch_attacks) * scale)
    queries = np.cumsum(batch_queries)
    return attack_n, attack_n / queries


def _flat_query_qer_series(query_count, rate, scale):
    """Return successful attacks and QER for a method with a fixed query cost.

    Attacked images run from 100 to 2000 in steps of 100. Successful attacks
    are ``images * rate * scale``, and QER divides them by the constant
    ``query_count``.
    """
    attack_n = np.arange(100, 2100, 100) * rate * scale
    return attack_n, attack_n / query_count


# compare with query-based methods
def plot_query_compare():
    """Plot the query-efficiency (QER) comparison against query-based attacks.

    Two side-by-side panels: the left one is the 'cifar10-cifar100' setting and
    the right one the 'cifar100-cifar10' setting.
    x axis: number of successful attacks (attacked pictures out of 10000).
    y axis: QER, i.e. successful attacks / number of queries.

    The figure is saved to ``{opt.work_dir}results/line_chart/qer_comparison``.
    The path has no extension, so matplotlib picks the format from
    ``rcParams['savefig.format']``.
    """
    import os

    save_dir = f'{opt.work_dir}results/line_chart/'
    save_path = save_dir + 'qer_comparison'
    os.makedirs(save_dir, exist_ok=True)

    # Install the palette before creating the axes; each axes picks up its
    # colour cycle when it is created.
    sns.set_palette(sns.color_palette('Set2'))

    labels = ['Papernot', 'Black-Box Ripper', 'ActiveThief', 'SEEKER', 'Simulator']
    linestyles = [(0, (1, 1)), '-', '--', '-.', ':']
    x_label = 'Successful attacks'
    line_width = 2

    subplot_w = 12
    subplot_h = 5
    fig, ax = plt.subplots(1, 2, figsize=(subplot_w, subplot_h))

    # Raw per-batch numbers of the Simulator attack; both panels use the same
    # raw numbers and differ only in the scale factor.
    simulator_batch_query = [1510, 1712, 1638, 1474, 1752, 1794, 1622, 1574, 1780, 1678,
                             1544, 1702, 1692, 1722, 1766, 1706, 1796, 1696, 1712, 1760]
    simulator_batch_attack_n = [81.1, 91.5, 87.0, 79.8, 94.6, 86.6, 87.0, 88.4, 91.6, 88.8,
                                79.8, 89.1, 87.6, 86.0, 89.2, 87.1, 88.2, 91.0, 88.9, 91.0]

    # Per panel: (scale, [(fixed query count, rate), ...]). The methods are in
    # ``labels`` order: Papernot, Black-Box Ripper, ActiveThief, SEEKER.
    panels = [
        # cifar10-cifar100
        (0.842, [(7040, 0.137), (7000, 0.081), (6500, 0.274), (6000, 0.864)]),
        # cifar100-cifar10
        (0.6526, [(8000, 0.03), (7000, 0.278), (6000, 0.274), (5500, 0.82)]),
    ]

    for axis, (scale, fixed_methods) in zip(ax, panels):
        series = [_flat_query_qer_series(query_count, rate, scale)
                  for query_count, rate in fixed_methods]
        series.append(_simulator_qer_series(simulator_batch_query,
                                            simulator_batch_attack_n, scale))
        for (attack_n, qer), label, style in zip(series, labels, linestyles):
            axis.plot(attack_n, qer, label=label, linewidth=line_width, linestyle=style)
        axis.set_xlabel(x_label, fontsize=13, weight='bold')
        axis.set_ylabel('QER', fontsize=13, weight='bold')
        axis.tick_params(labelsize=11)
        for tick_label in axis.get_xticklabels() + axis.get_yticklabels():
            tick_label.set_fontweight('bold')
        axis.legend(fontsize=11, loc='upper left')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', pad_inches=1)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def main():
    """Render the 'public data-victim prob' bar chart (ten bars at x = 0..9).

    plot_substitute_compare() and plot_query_compare() are not invoked here;
    call them to regenerate those figures.
    """
    # Alternative dataset, kept for reference:
    #   title = 'query data-victim prob'
    #   heights = [0.3180, 0.0348, 0.1114, 0.5958, 0.1303,
    #              0.2952, 0.0705, 0.0770, 0.1666, 0.2005]
    title = 'public data-victim prob'
    heights = [0.1324, 0.0343, 0.1202, 0.2040, 0.0771,
               0.1182, 0.1056, 0.0452, 0.0927, 0.0703]
    positions = list(range(len(heights)))
    plot_rect(title, positions, heights)


if __name__ == '__main__':
    main()