from prv_accountant.dpsgd import DPSGDAccountant
import math
import pickle
import os
from sklearn.metrics import roc_curve
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


def plot_curve(x, y, xlabel, ylabel, ax, label, color, style, title=None):
    """Draw one curve on log-log axes next to the y = x reference diagonal.

    Args:
        x, y: Curve coordinates (e.g. FPR and TPR values).
        xlabel, ylabel: Axis labels.
        ax: Matplotlib axes to draw on.
        label: Legend label for the curve.
        color: Line color.
        style: Matplotlib linestyle (e.g. '-' or ':').
        title: Optional axes title; left unchanged when None.
    """
    # Chance-level reference line; every call redraws it, which is harmless overplotting.
    diagonal = [0, 1]
    ax.plot(diagonal, diagonal, 'k-', lw=1.0)
    ax.plot(x, y, lw=2, label=label, color=color, linestyle=style)

    ax.set(xlabel=xlabel, ylabel=ylabel)
    ax.set(aspect=1, xscale='log', yscale='log')

    if title is None:
        return
    ax.title.set_text(title)


def compute_attack_advantage(fpr, tpr):
    """Return the attack advantage: the largest value of TPR - FPR along an ROC curve.

    Args:
        fpr: Array of false-positive rates (same length as ``tpr``).
        tpr: Array of true-positive rates.

    Returns:
        The maximum of the element-wise difference ``tpr - fpr``.
    """
    gaps = tpr - fpr
    return max(gaps)


def plot_roc_curve(y_true_list, y_score_list, legend_list, save_path, title, deltas, epsilons, accountants, flip_legend=False):
    """Plot empirical ROC curve(s) together with DP upper bounds on TPR, then save the figure.

    Each (y_true, y_score) pair becomes an ROC curve (via ``roc_curve``), drawn in black.
    For every index j, the bound
        TPR <= min(exp(2 * epsilons[j]) * FPR + (1 + exp(epsilons[j])) * deltas[j], 1)
    is evaluated on the FPR grid of the LAST ROC curve drawn. The bound at index 0 is drawn
    on its own, labelled with ``accountants[0]``. The bounds at all other indices are reduced
    with an element-wise minimum and drawn as one red "tightest" curve, which is labelled as PRV.
    NOTE(review): the exp(2*eps) / (1 + exp(eps))*delta form looks like a group-privacy (k=2)
    version of the (eps, delta)-DP TPR/FPR trade-off -- confirm against the paper.

    Args:
        y_true_list: Per-curve binary ground-truth membership labels.
        y_score_list: Per-curve attack scores, aligned with ``y_true_list``.
        legend_list: Per-curve legend labels.
        save_path: Output file path for the figure.
        title: Axes title.
        deltas: Delta value for each bound.
        epsilons: Epsilon value for each bound, aligned with ``deltas``.
        accountants: Accountant name for each bound; only index 0 appears in the legend.
        flip_legend: If True, reverse the order of the legend entries.

    Raises:
        AssertionError: If the per-curve lists differ in length.
        ValueError: If ``y_true_list`` is empty (there is no FPR grid for the bounds).
    """
    assert len(y_true_list) == len(y_score_list)
    assert len(legend_list) == len(y_true_list)
    if not y_true_list:
        # Without a curve there is no FPR grid; fail before creating a figure.
        raise ValueError("y_true_list must contain at least one ROC curve")

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    try:
        for y_true, y_score, legend in zip(y_true_list, y_score_list, legend_list):
            fpr, tpr, _ = roc_curve(y_true=y_true, y_score=y_score)
            plot_curve(x=fpr, y=tpr, xlabel='FPR', ylabel='TPR', style="-", ax=ax,
                       label=f"{legend}", color="black", title=title)

        # `fpr` now holds the grid of the last ROC curve; all bounds are evaluated on it.
        other_bounds = []
        for j, delta in enumerate(deltas):
            epsilon = epsilons[j]
            bound = np.minimum(math.exp(2 * epsilon) * fpr + (1 + math.exp(epsilon)) * delta, 1.0)

            if j == 0:
                # Reference bound (e.g. the RDP guarantee), drawn separately.
                first_color = plt.rcParams['axes.prop_cycle'].by_key()['color'][0]
                plot_curve(x=fpr, y=bound, xlabel='FPR', ylabel='TPR', ax=ax, color=first_color, style=':',
                           label=rf"UB ({accountants[j]}): $\epsilon$={round(epsilon, 2)} $\delta=${delta}",
                           title=None)
            else:
                other_bounds.append(bound)

        # Tightest bound over all remaining (epsilon, delta) pairs; skipped if there are none.
        if other_bounds:
            tightest_bound = np.amin(other_bounds, axis=0)
            plot_curve(x=fpr, y=tightest_bound, xlabel='FPR', ylabel='TPR', ax=ax, color="red", style=':',
                       label=r"tightest UB (PRV) over multiple $\delta$", title=None)

        if flip_legend:
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles=handles[::-1], labels=labels[::-1], loc='lower right', fontsize=9.3)
        else:
            ax.legend(loc='lower right', fontsize=9.3)
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def plot_deltas(shot, epsilons, deltas, accountants, config, data_path, out_dir=None):
    """Load saved attack scores and plot their ROC curve against the given DP bounds.

    For each target epsilon label (currently only '1'), this reads
    ``scores_{config}_{shot}_{epsilon}.pkl`` from ``data_path``. The file must be a pickled
    mapping with 'y_true' and 'scores' entries. The resulting plot is written to
    ``roc_eps_{config}_{epsilon}.pdf`` in ``out_dir``.

    Args:
        shot: Shot count; used in the input file name and the plot title.
        epsilons, deltas, accountants: Aligned bound parameters, forwarded to ``plot_roc_curve``.
        config: Model configuration tag ('none' is labelled "Head", anything else "FiLM").
        data_path: Directory containing the pickled score files.
        out_dir: Directory for the PDF output. When None, this falls back to the
            module-level ``OUT_DIR`` if it is defined, otherwise to '' (the current directory).
    """
    if out_dir is None:
        # Keep honouring the script-level OUT_DIR global, but don't raise NameError
        # when this module is imported and OUT_DIR was never set.
        out_dir = globals().get('OUT_DIR', '')

    legend = 'Head' if config == 'none' else 'FiLM'
    for epsilon in ['1']:
        title = 'shots = {}, original RDP $ϵ$={}'.format(shot, r'$\infty$' if epsilon == 'inf' else epsilon)
        scores_path = os.path.join(data_path, 'scores_{}_{}_{}.pkl'.format(config, str(shot), epsilon))
        # NOTE: pickle can execute arbitrary code; only load result files you produced yourself.
        with open(scores_path, "rb") as f:
            result = pickle.load(f)

        plot_roc_curve([result['y_true']], [result['scores']], [legend],
                       os.path.join(out_dir, 'roc_eps_{}_{}.pdf'.format(config, epsilon)),
                       title, deltas, epsilons, accountants, flip_legend=False)


# Deltas at which the PRV accountant is queried; the minimum over the resulting bounds is plotted.
_PRV_DELTAS = (1/1000, 1/2000, 1/5000, 1e-4, 1e-5, 1e-6)


def _rdp_and_prv_bounds(noise_multiplier, sample_rate, total_steps, eps_error, delta_error):
    """Return aligned (deltas, epsilons, accountants) lists: one RDP point plus PRV epsilons.

    Index 0 is the fixed RDP point (epsilon=1, delta=1/1000). Each later index is the PRV
    accountant's upper epsilon estimate for one delta in ``_PRV_DELTAS``, after
    ``total_steps`` DP-SGD steps. Each PRV result is also printed.
    """
    deltas, epsilons, accountants = [1/1000], [1], ["RDP"]

    prv_accountant = DPSGDAccountant(
        noise_multiplier=noise_multiplier,
        sampling_probability=sample_rate,
        eps_error=eps_error,
        delta_error=delta_error,
        max_steps=total_steps
    )

    for d in _PRV_DELTAS:
        # compute_epsilon yields (lower, estimate, upper); only the upper value is plotted.
        _, _, eps_upper = prv_accountant.compute_epsilon(num_steps=total_steps, delta=d)
        deltas.append(d)
        epsilons.append(eps_upper)
        accountants.append("PRV")
        print(f"δ={d} results in ϵ={eps_upper}")

    return deltas, epsilons, accountants


def plot_film_at_eps_1(data_path):
    """Plot the FiLM (config 'film', 10-shot) ROC curve against RDP and PRV bounds."""
    deltas, epsilons, accountants = _rdp_and_prv_bounds(
        noise_multiplier=22.96875,
        sample_rate=0.5,
        total_steps=248,
        eps_error=1e-5,
        delta_error=1e-11,
    )
    plot_deltas(10, epsilons, deltas, accountants, "film", data_path)


def plot_head_at_eps_1(data_path):
    """Plot the Head (config 'none', 10-shot) ROC curve against RDP and PRV bounds."""
    deltas, epsilons, accountants = _rdp_and_prv_bounds(
        noise_multiplier=29.0625,
        sample_rate=0.5,
        total_steps=398,
        eps_error=1e-4,
        delta_error=1e-10,
    )
    plot_deltas(10, epsilons, deltas, accountants, "none", data_path)


if __name__ == "__main__":
    # Output directory for the PDFs; plot_deltas reads it as a module global.
    OUT_DIR = ''
    results_dir = os.path.expanduser("~/Documents/results")
    for heading, plot_fn in (("Head", plot_head_at_eps_1), ("FiLM", plot_film_at_eps_1)):
        print(heading)
        plot_fn(results_dir)
