import numpy as np
import matplotlib.pyplot as plt

# Multiplier on the standard deviation used to size the shaded band drawn
# around each mean curve: the band spans mean - STD_GAP * std to
# mean + STD_GAP * std.
STD_GAP: float = 0.5
# Transparency (matplotlib ``alpha``) of that shaded band.
ALPHA: float = 0.1


def plot_helper(
    timesteps,
    accuracies,
    accuracies_stds,
    label,
    color,
    broadcast=False,
    std_gap=None,
    alpha=None,
):
    """Plot a mean curve plus a shaded standard-deviation band on the current axes.

    Draws a solid line of ``accuracies`` against ``timesteps``. It then fills
    the region between ``accuracies - std_gap * accuracies_stds`` and
    ``accuracies + std_gap * accuracies_stds`` in the same color.

    Args:
        timesteps: x-axis values.
        accuracies: y values, one per timestep. When ``broadcast`` is True
            this should be a single scalar.
        accuracies_stds: standard deviations matching ``accuracies``. When
            ``broadcast`` is True this should be a single scalar.
        label: legend label for the line.
        color: matplotlib color used for both the line and the band.
        broadcast: if True, repeat the scalar ``accuracies`` /
            ``accuracies_stds`` across every timestep, which yields a
            horizontal line and band.
        std_gap: multiplier on the std for the band half-width. ``None``
            (default) uses the module-level ``STD_GAP``.
        alpha: transparency of the band. ``None`` (default) uses the
            module-level ``ALPHA``.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are still honoured, matching the original behaviour.
    if std_gap is None:
        std_gap = STD_GAP
    if alpha is None:
        alpha = ALPHA

    if broadcast:
        n_steps = len(timesteps)
        accuracies = np.full(n_steps, accuracies)
        accuracies_stds = np.full(n_steps, accuracies_stds)
    else:
        # Plain Python lists would fail in the band arithmetic below
        # (list * float raises TypeError), so normalise to ndarrays.
        accuracies = np.asarray(accuracies)
        accuracies_stds = np.asarray(accuracies_stds)

    plt.plot(
        timesteps,
        accuracies,
        label=label,
        linestyle='solid',
        linewidth=1,
        color=color,
    )
    half_width = std_gap * accuracies_stds
    plt.fill_between(
        timesteps,
        accuracies - half_width,
        accuracies + half_width,
        color=color,
        alpha=alpha,
    )


def plot_protected(
    timesteps,
    experiment_results,
    color,
    label,
):
    """Draw the protected-characteristic acceptance curve on the current figure.

    Plots ``experiment_results.mean_protected_accepted_averages`` with a
    shaded band built from ``experiment_results.std_protected_accepted_averages``
    (drawn by ``plot_helper``). Then sets the x/y axis labels.

    Args:
        timesteps: x-axis values.
        experiment_results: object exposing ``mean_protected_accepted_averages``
            and ``std_protected_accepted_averages``.
        color: matplotlib color for the line and its band.
        label: legend label for the line.
    """
    mean_curve = experiment_results.mean_protected_accepted_averages
    std_curve = experiment_results.std_protected_accepted_averages
    plot_helper(timesteps, mean_curve, std_curve, label=label, color=color)

    plt.xlabel("Timesteps")
    plt.ylabel("Percentage Accepted")


def plot_regret(
    timesteps,
    experiment_results,
    color,
    label,
):
    """Draw the regret curve on the current figure.

    Plots ``experiment_results.mean_train_cum_regret_averages`` with a shaded
    band built from ``experiment_results.std_train_cum_regret_averages``
    (drawn by ``plot_helper``). Then sets the x/y axis labels.

    Args:
        timesteps: x-axis values.
        experiment_results: object exposing ``mean_train_cum_regret_averages``
            and ``std_train_cum_regret_averages``.
        color: matplotlib color for the line and its band.
        label: legend label for the line.
    """
    plot_helper(
        timesteps=timesteps,
        accuracies=experiment_results.mean_train_cum_regret_averages,
        accuracies_stds=experiment_results.std_train_cum_regret_averages,
        label=label,
        color=color,
    )

    plt.xlabel("Timesteps")
    plt.ylabel("Regret")
