import setting
import os

setting.init()
from bandit import *
import matplotlib.pyplot as plt
import copy
import itertools


# Global Matplotlib styling: 15pt fonts and LaTeX text rendering (the labels
# used below rely on \textsc / \rm, which need usetex).
plt.rcParams["font.size"] = 15
plt.rcParams["text.usetex"] = True


### Toggling the algorithms (each flag adds one or more curves to the figure)
plot_satucb = True  # Sat-UCB (algo3)
plot_satucbplus = True  # Sat-UCB+ (algo3_old)
plot_satucb_variants = False  # algo3_old, algo3xucb and algo3xavg
plot_other_algorithms = True  # UCB1, UCB_alpha, mean-reward UCL, deterministic UCL
plot_ucb_algorithm = False  # UCB1 on its own
plot_algo1xucb = False  # Algorithm 1 with UCB exploration
# UCB_alpha = False # Perform the UCB-alpha experiment (outdated code)

### Toggling the special experiments
experiment_multiple_satisfaction_levels = False  # Sat-UCB+ swept over several levels
experiment_1_run = False  # single 1000-step run; shows the figure and stops
experiment_round_robin = False  # Algorithm 1 vs. its round-robin variant

### Choose the setting for the experiment
reward_distribution = "gaussian"  # "gaussian" or "bernoulli"
realizable_case = True  # True: some arm mean reaches the satisfaction level
setting_id = 1  # 1, 2 or 3

nb_arm = 20
nb_repetition = 100  # number of runs averaged per algorithm
resolution_value = 300  # dpi of the saved figure
plt.figure(figsize=(8, 6))

### Plotting parameters
# Spread statistic requested from experiment(): "std" or "quantile". It is
# drawn as a confidence band/bars rather than a true error, and only when
# plot_error is True.
error_type = "std"
plot_error = True  # draw the spread around each curve
error_bar = True  # True: error bars; False: shaded band (only when plot_error=True)
################
# Setting up the experiment
# The non-realizable case uses a longer horizon (and sparser error bars).
if realizable_case:
    nb_step, errorevery_bar = 10000, 3300
else:
    nb_step, errorevery_bar = 50000, 12000

if setting_id == 1:
    # Linearly spaced means 0, 1/K, ..., (K-1)/K; the best arm is the last one.
    mean_arms = [k / nb_arm for k in range(nb_arm)]
    satisfaction_level = 0.8 if realizable_case else 1.0
elif setting_id == 2:
    # Square-root spaced means listed in decreasing order, so the best arm
    # is the first one.
    mean_arms = [(k / nb_arm) ** 0.5 for k in range(nb_arm - 1, -1, -1)]
    satisfaction_level = 0.92 if realizable_case else 1.1
elif setting_id == 3:
    # A single arm with mean 1.0; every other arm has mean 0.
    mean_arms = [0] * nb_arm
    mean_arms[0] = 1.0
    satisfaction_level = 0.5 if realizable_case else 1.1
else:
    raise ValueError("Undefined setting_id")

print("The mean of arms are: ", mean_arms)
# Ascending copy of the means, used for the summary statistics below.
copy_mean_arms = np.sort(np.asarray(mean_arms))
maximum_arm = np.max(mean_arms)
print("Arm means:", copy_mean_arms)
print("The largest mean of the arms is: ", maximum_arm)
print("The second largest mean of the arms is: ", copy_mean_arms[-2])
print("The satisfying level is: ", satisfaction_level)
print(
    "The number of satisfying arms is ",
    int(np.count_nonzero(copy_mean_arms >= satisfaction_level)),
)

# Bandit definition: a single bandit instance wrapped in a list, which is what
# every experiment() call below receives.
if reward_distribution == "gaussian":
    # Gaussian rewards with sigma = 1 for every arm.
    bandits = [GaussianBandit(nb_arm, means=list(mean_arms), sigmas=[1] * nb_arm)]
elif reward_distribution == "bernoulli":
    bandits = [BernoulliBandit(nb_arm, means=list(mean_arms))]
else:
    # Fail fast: otherwise `bandits` stays undefined and the first
    # experiment() call dies with an unrelated-looking NameError.
    raise ValueError(f"Undefined reward_distribution: {reward_distribution!r}")


#####################################################################################################################################################################################################################################################################################################################
######################## Experiment ########################
def plot_results(algo_name, regrets, var, plot_error, error_bar, errorevery_bar=0):
    """Add one algorithm's cumulative-regret curve to the current figure.

    Args:
        algo_name: Legend label of the curve.
        regrets: Per-step regrets; the plotted curve is ``np.cumsum(regrets)``.
        var: Spread returned by ``experiment``. A 2-D array with two rows is
            read as explicit (lower, upper) bounds of the band. Anything else
            is treated as a per-step standard deviation and divided by
            ``sqrt(nb_repetition)``. (Presumably the 2-row form comes from
            ``error_type="quantile"``; confirm in ``experiment``.)
        plot_error: If falsy, only the curve is drawn and ``var`` is ignored.
        error_bar: If truthy, draw error bars; otherwise draw a shaded band.
        errorevery_bar: Spacing, in steps, between drawn error bars. Integer
            values below 1 (including the default 0) are raised to 1, because
            matplotlib rejects a zero spacing.

    Relies on the module-level ``nb_repetition`` and on ``plt``/``np`` being in
    scope.
    """
    cum_regrets = np.cumsum(regrets)
    if not plot_error:
        plt.plot(cum_regrets, label=algo_name)
        return

    # Take the x axis from the data's own length rather than the global
    # nb_step, so curves with a different horizon can be plotted too.
    steps = np.arange(len(cum_regrets))
    var = np.asarray(var)
    # Checking only shape[0] == 2 would misread a 1-D std array of length 2
    # as bounds, so require a 2-D array with two rows.
    has_bounds = var.ndim == 2 and var.shape[0] == 2
    if not has_bounds:
        # Assuming var is the across-run standard deviation, this yields the
        # standard error of the mean over nb_repetition runs.
        var = var / np.sqrt(nb_repetition)

    if not error_bar:
        if has_bounds:
            lower, upper = var[0], var[1]
        else:
            lower, upper = cum_regrets - var, cum_regrets + var
        plt.plot(cum_regrets, label=algo_name)
        # A label starting with "_" keeps the band out of the legend.
        plt.fill_between(
            steps, lower, upper, alpha=0.2, label="_{}".format(algo_name)
        )
        return

    if has_bounds:
        # errorbar wants non-negative distances from the curve, not absolute bounds.
        yerr = [
            np.maximum(cum_regrets - var[0], 0),
            np.maximum(var[1] - cum_regrets, 0),
        ]
    else:
        yerr = var
    if isinstance(errorevery_bar, (int, np.integer)) and errorevery_bar < 1:
        errorevery_bar = 1
    plt.errorbar(
        steps,
        cum_regrets,
        yerr=yerr,
        errorevery=errorevery_bar,
        label=algo_name,
    )


if experiment_1_run:
    # Single short run (1000 steps, 1 repetition) of UCB1 and Sat-UCB. It plots
    # the expected reward of the arm played at each step against the
    # satisfaction level, shows the figure and then stops the script.
    rewards_ucb, regrets_ucb, expectations_ucb, std_ucb = experiment(
        ucb,
        bandits,
        satisfaction_level,
        1000,
        1,
        error_type=error_type,
    )
    rewards_algo3, regrets_algo3, expectations_algo3, std_algo3 = experiment(
        algo3,
        bandits,
        satisfaction_level,
        1000,
        1,
        error_type=error_type,
    )
    plt.plot(expectations_ucb, ".", label=r"\textsc{UCB1}")
    plt.plot(expectations_algo3, ".", label=r"\textsc{Sat-UCB}")
    plt.plot(
        np.ones_like(expectations_algo3) * satisfaction_level,
        label=r"\rm Satisfaction level",
    )
    plt.xlabel(r"\rm Time step")
    plt.ylabel(r"\rm Expected reward")
    plt.legend()
    plt.tight_layout()
    plt.show()
    # `exit()` is injected by the `site` module for interactive use and is
    # missing under `python -S`; raising SystemExit is the portable equivalent.
    raise SystemExit

if experiment_round_robin:
    # Algorithm 1 versus its round-robin variant. errorevery_bar is lowered by
    # 200 before each curve (presumably so that error bars of different curves
    # do not sit on top of each other).
    errorevery_bar = errorevery_bar - 200
    (rewards_algo1, regrets_algo1, expectations_algo1, std_algo1) = experiment(
        algo1, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    plot_results(
        r"\textsc{Algorithm 1}",
        regrets_algo1,
        std_algo1,
        plot_error,
        error_bar,
        errorevery_bar,
    )

    errorevery_bar = errorevery_bar - 200
    (rewards_algo_rr, regrets_algo_rr, expectations_algo_rr, std_algo_rr) = experiment(
        algoroundrobin, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    plot_results(
        r"\textsc{Algorithm 1 with Round Robin}",
        regrets_algo_rr,
        std_algo_rr,
        plot_error,
        error_bar,
        errorevery_bar,
    )


if plot_algo1xucb:
    # Algorithm 1 with UCB exploration; shift the error-bar spacing first.
    errorevery_bar = errorevery_bar - 200
    rewards_algo1xucb, regrets_algo1xucb, expectations_algo1xucb, std_algo1xucb = (
        experiment(
            algo1xucb, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
        )
    )
    plot_results(
        r"\textsc{Algorithm 1 with UCB Exploration}",
        regrets_algo1xucb,
        std_algo1xucb,
        plot_error,
        error_bar,
        errorevery_bar,
    )


if plot_satucb:
    # Sat-UCB (algo3); shift the error-bar spacing before drawing its curve.
    errorevery_bar = errorevery_bar - 200
    (rewards_algo3, regrets_algo3, expectations_algo3, std_algo3) = experiment(
        algo3, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    plot_results(
        r"\textsc{Sat-UCB}", regrets_algo3, std_algo3, plot_error, error_bar, errorevery_bar
    )

if plot_satucbplus:
    # Sat-UCB+ is algo3_old. Its results are kept under the *_algo3_old names
    # (as in the variants block) instead of overwriting the Sat-UCB (*_algo3)
    # results produced above.
    errorevery_bar -= 200
    (
        rewards_algo3_old,
        regrets_algo3_old,
        expectations_algo3_old,
        std_algo3_old,
    ) = experiment(
        algo3_old,
        bandits,
        satisfaction_level,
        nb_step,
        nb_repetition,
        error_type=error_type,
    )
    plot_results(
        r"\textsc{Sat-UCB}$^+$",
        regrets_algo3_old,
        std_algo3_old,
        plot_error,
        error_bar,
        errorevery_bar,
    )

######################## Experimental Algorithm3 multiple satisfaction level ########################
if experiment_multiple_satisfaction_levels:
    # Sweep Sat-UCB+ (algo3_old) over several satisfaction levels. Level lists
    # only exist for settings 1 and 3.
    if setting_id == 3:
        satisfaction_level_list = [1.0, 1.025, 1.05, 1.1, 1.2]
    elif setting_id == 1:
        satisfaction_level_list = [0.95, 0.96, 0.97, 1.0, 1.1]
    else:
        # Fail fast instead of hitting a NameError in the loop below.
        raise ValueError(
            "experiment_multiple_satisfaction_levels is only defined for "
            f"setting_id 1 or 3, got {setting_id!r}"
        )
    for level in satisfaction_level_list:
        rewards_algo3, regrets_algo3, expectations_algo3, std_algo3 = experiment(
            algo3_old,
            bandits,
            level,
            nb_step,
            nb_repetition,
            error_type=error_type,
        )
        errorevery_bar -= 200
        plot_results(
            r"\rm \textsc{Sat-UCB}$^+$ S-Level: " + str(level),
            regrets_algo3,
            std_algo3,
            plot_error,
            error_bar,
            errorevery_bar,
        )

######################## Algorithm3 Variants ########################
if plot_satucb_variants:
    # Three Sat-UCB variants, each run and then plotted with its own
    # error-bar spacing: Sat-UCB+ (algo3_old), Sat-UCB x UCB1, Sat-UCB x Average.
    rewards_algo3_old, regrets_algo3_old, expectations_algo3_old, std_algo3_old = (
        experiment(
            algo3_old, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
        )
    )
    errorevery_bar = errorevery_bar - 200
    plot_results(
        r"\textsc{Sat-UCB}$^+$",
        regrets_algo3_old,
        std_algo3_old,
        plot_error,
        error_bar,
        errorevery_bar,
    )

    rewards_algo3xucb, regrets_algo3xucb, expectations_algo3xucb, std_algo3xucb = (
        experiment(
            algo3xucb, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
        )
    )
    errorevery_bar = errorevery_bar - 200
    plot_results(
        r"\textsc{Sat-UCB x UCB1}",
        regrets_algo3xucb,
        std_algo3xucb,
        plot_error,
        error_bar,
        errorevery_bar,
    )

    rewards_algo3xavg, regrets_algo3xavg, expectations_algo3xavg, std_algo3xavg = (
        experiment(
            algo3xavg, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
        )
    )
    errorevery_bar = errorevery_bar - 200
    plot_results(
        r"\textsc{Sat-UCB x Average}",
        regrets_algo3xavg,
        std_algo3xavg,
        plot_error,
        error_bar,
        errorevery_bar,
    )

####### UCB UCB UCB UCB UCB
if plot_ucb_algorithm:
    # UCB1 on its own.
    # NOTE(review): plot_other_algorithms also runs and plots UCB1, so enabling
    # both flags draws the UCB1 curve twice.
    # Scaled-confidence variants were previously explored by calling
    # experiment(ucb, ..., parameter=0.5 / 0.25 / 0.125); add them back here
    # if needed.
    rewards_ucb, regrets_ucb, expectations_ucb, std_ucb = experiment(
        ucb, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    errorevery_bar -= 200
    plot_results(
        r"\textsc{UCB1}", regrets_ucb, std_ucb, plot_error, error_bar, errorevery_bar
    )

#########################################################
# if UCB_alpha==True:
#     setting.alpha=1
#     setting.delta_conf_level=0.001
#     rewards_ucb_alpha_elimination,regrets_ucb_alpha_elimination,expectations_ucb_alpha_elimination, std_ucb_alpha_elimination = experiment(ucb_alpha_elimination, bandits, satisfaction_level, nb_step, nb_repetition)
#     setting.alpha=1.2
#     setting.delta_conf_level=0.001
#     rewards_ucb_alpha_elimination2,regrets_ucb_alpha_elimination2,expectations_ucb_alpha_elimination2, std_ucb_alpha_elimination2 = experiment(ucb_alpha_elimination, bandits, satisfaction_level, nb_step, nb_repetition)

#     setting.alpha=1.7
#     setting.delta_conf_level=0.001
#     rewards_ucb_alpha_elimination3,regrets_ucb_alpha_elimination3,expectations_ucb_alpha_elimination3, std_ucb_alpha_elimination3 = experiment(ucb_alpha_elimination, bandits, satisfaction_level, nb_step, nb_repetition)
#     if error_bar == False:
#         plt.plot(np.cumsum(regrets_ucb_alpha_elimination), label=r'${\rm UCB}_{\alpha}$')
#     else:
#         errorevery_bar-=200
#         plt.errorbar(np.arange(nb_step), np.cumsum(regrets_ucb_alpha_elimination), yerr=std_ucb_alpha_elimination, errorevery=errorevery_bar, label=r'${\rm UCB}_{\alpha} 1$')

#         errorevery_bar-=400
#         plt.errorbar(np.arange(nb_step), np.cumsum(regrets_ucb_alpha_elimination2), yerr=std_ucb_alpha_elimination2, errorevery=errorevery_bar, label=r'${\rm UCB}_{\alpha} 1.2$')

#         errorevery_bar-=600
#         plt.errorbar(np.arange(nb_step), np.cumsum(regrets_ucb_alpha_elimination3), yerr=std_ucb_alpha_elimination3, errorevery=errorevery_bar, label=r'${\rm UCB}_{\alpha} 1.7$')
####### ucb_alpha_elimination       ucb_alpha_elimination
if plot_other_algorithms:
    # Baselines UCB1 and UCB_alpha (elimination variant). Both experiments are
    # run first, then both curves are plotted.
    (rewards_ucb, regrets_ucb, expectations_ucb, std_ucb) = experiment(
        ucb, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    # Hyper-parameters are stored on the shared `setting` module right before
    # the UCB_alpha run (presumably read by ucb_alpha_elimination; confirm in
    # bandit).
    setting.alpha = 1
    setting.delta_conf_level = 0.001
    (
        rewards_ucb_alpha_elimination,
        regrets_ucb_alpha_elimination,
        expectations_ucb_alpha_elimination,
        std_ucb_alpha_elimination,
    ) = experiment(
        ucb_alpha_elimination, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )

    errorevery_bar = errorevery_bar - 200
    plot_results(r"\textsc{UCB1}", regrets_ucb, std_ucb, plot_error, error_bar, errorevery_bar)

    errorevery_bar = errorevery_bar - 200
    plot_results(
        r"$\textsc{UCB}_{\alpha}$",
        regrets_ucb_alpha_elimination,
        std_ucb_alpha_elimination,
        plot_error,
        error_bar,
        errorevery_bar,
    )

####### ucb_alpha_no_elimination      ucb_alpha_no_elimination
# # setting.alpha=1
# # setting.delta_conf_level=0.001
# # rewards_ucb_alpha_no_elimination,regrets_ucb_alpha_no_elimination,expectations_ucb_alpha_no_elimination, std_ucb_alpha_no_elimination = experiment(ucb_alpha_no_elimination, bandits, satisfaction_level, nb_step, nb_repetition)
# # if error_bar == False:
# #     plt.plot(np.cumsum(regrets_ucb_alpha_no_elimination), label="ucb_alpha_no_elimination alpha=1 and delta=.001")
# # else:
# #     plt.errorbar(np.arange(nb_step), np.cumsum(regrets_ucb_alpha_no_elimination), yerr=std_ucb_alpha_no_elimination, label="ucb_alpha_no_elimination alpha=1 and delta=.01")
# # ###### satisfaction_mean_reward    satisfaction_mean_reward
# # rewards_satisfaction_mean_reward1,regrets_satisfaction_mean_reward1,expectations_satisfaction_mean_reward1, std_satisfaction_mean_reward1 = experiment(satisfaction_mean_reward1, bandits, satisfaction_level, nb_step, nb_repetition)
# # if error_bar == False:
# #     plt.plot(np.cumsum(regrets_satisfaction_mean_reward1), label="Satisfaction in Mean Reward UCL 1")
# # else:
# #     plt.errorbar(np.arange(nb_step), np.cumsum(regrets_satisfaction_mean_reward1), yerr=std_satisfaction_mean_reward1, label="Satisfaction in Mean Reward UCL 1")
# # # ###### satisfaction_mean_reward    satisfaction_mean_reward
# # rewards_satisfaction_mean_reward2,regrets_satisfaction_mean_reward2,expectations_satisfaction_mean_reward2, std_satisfaction_mean_reward2 = experiment(satisfaction_mean_reward2, bandits, satisfaction_level, nb_step, nb_repetition)
# # if error_bar == False:
# #     plt.plot(np.cumsum(regrets_satisfaction_mean_reward2), label="Satisfaction in Mean Reward UCL 2")
# # else:
# #     plt.errorbar(np.arange(nb_step), np.cumsum(regrets_satisfaction_mean_reward2), yerr=std_satisfaction_mean_reward2, label="Satisfaction in Mean Reward UCL 2")
# # ###### satisfaction_mean_reward    satisfaction_mean_reward
# # rewards_satisfaction_mean_reward3,regrets_satisfaction_mean_reward3,expectations_satisfaction_mean_reward3, std_satisfaction_mean_reward3 = experiment(satisfaction_mean_reward3, bandits, satisfaction_level, nb_step, nb_repetition)
# # if error_bar == False:
# #     plt.plot(np.cumsum(regrets_satisfaction_mean_reward3), label="Satisfaction in Mean Reward UCL 3")
# # else:
# #     plt.errorbar(np.arange(nb_step), np.cumsum(regrets_satisfaction_mean_reward3), yerr=std_satisfaction_mean_reward3, label="Satisfaction in Mean Reward UCL 3")
# ###### satisfaction_mean_reward    satisfaction_mean_reward
# rewards_satisfaction_mean_reward5,regrets_satisfaction_mean_reward5,expectations_satisfaction_mean_reward5, std_satisfaction_mean_reward5 = experiment(satisfaction_mean_reward5, bandits, satisfaction_level, nb_step, nb_repetition)
# if error_bar == False:
#     plt.plot(np.cumsum(regrets_satisfaction_mean_reward5), label="Satisfaction in Mean Reward UCL 5")
# else:
#     plt.errorbar(np.arange(nb_step), np.cumsum(regrets_satisfaction_mean_reward5), yerr=std_satisfaction_mean_reward5, label="Satisfaction in Mean Reward UCL 5")
# # ###### satisfaction_mean_reward    satisfaction_mean_reward
# rewards_satisfaction_mean_reward6,regrets_satisfaction_mean_reward6,expectations_satisfaction_mean_reward6, std_satisfaction_mean_reward6 = experiment(satisfaction_mean_reward6, bandits, satisfaction_level, nb_step, nb_repetition)
# if error_bar == False:
#     plt.plot(np.cumsum(regrets_satisfaction_mean_reward6), label="Satisfaction in Mean Reward UCL 6")
# else:
#     plt.errorbar(np.arange(nb_step), np.cumsum(regrets_satisfaction_mean_reward6), yerr=std_satisfaction_mean_reward6, label="Satisfaction in Mean Reward UCL 6")
#################################
#################################
#################################
###### satisfaction_mean_reward    satisfaction_mean_reward
if plot_other_algorithms:
    # Baseline: satisfaction in mean reward UCL.
    results = None
    (
        rewards_satisfaction_mean_reward4,
        regrets_satisfaction_mean_reward4,
        expectations_satisfaction_mean_reward4,
        std_satisfaction_mean_reward4,
    ) = experiment(
        satisfaction_mean_reward, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    del results
    errorevery_bar = errorevery_bar - 200
    plot_results(
        r"\rm Satisfaction in Mean Reward \textsc{UCL}",
        regrets_satisfaction_mean_reward4,
        std_satisfaction_mean_reward4,
        plot_error,
        error_bar,
        errorevery_bar,
    )
#######
####### deterministic_ucl  deterministic_ucl   deterministic_ucl
if plot_other_algorithms:
    # Baseline: deterministic UCL.
    (
        rewards_deterministic_ucl,
        regrets_deterministic_ucl,
        expectations_deterministic_ucl,
        std_deterministic_ucl,
    ) = experiment(
        deterministic_ucl, bandits, satisfaction_level, nb_step, nb_repetition, error_type=error_type
    )
    errorevery_bar = errorevery_bar - 200
    plot_results(
        r"\rm Deterministic \textsc{UCL}",
        regrets_deterministic_ucl,
        std_deterministic_ucl,
        plot_error,
        error_bar,
        errorevery_bar,
    )
################################################################
#########   Draw Draw Draw Draw Draw
# Figures go to Figures/New/<Distribution>/. The folder name and the
# "... rewards" caption follow reward_distribution ("Gaussian" for the default
# "gaussian"). The full directory must exist before savefig; creating only
# Figures/New made savefig fail on a fresh checkout.
distribution_name = reward_distribution.capitalize()
figure_dir = f"Figures/New/{distribution_name}"
os.makedirs(figure_dir, exist_ok=True)

# Give each plotted line its own dash pattern, cycling if there are more lines
# than patterns.
linestyles = [
    "solid",
    "dotted",
    "dashed",
    "dashdot",
    (0, (3, 1, 1, 1, 1, 1)),
    (5, (10, 3)),
    (0, (3, 1, 1, 1)),
]
ax = plt.gca()
for l, ls in zip(ax.lines, itertools.cycle(linestyles)):
    l.set_linestyle(ls)

if experiment_multiple_satisfaction_levels:
    title = (
        f"\\rm Average satisficing regret over {nb_repetition} runs\n"
        f"\\rm {distribution_name} rewards - Setting {setting_id}"
    )
    ylabel = r"\rm Satisficing regret"
    legend_kwargs = {}
    # Own file name so the sweep does not overwrite the realizable-case figure.
    filename = "EB-MultipleSatisfactionLevels.png"
elif realizable_case:
    title = (
        f"\\rm Average satisficing regret over {nb_repetition} runs\n"
        f"\\rm Realizable case - {distribution_name} rewards - Setting {setting_id}"
    )
    ylabel = r"\rm Satisficing regret"
    legend_kwargs = {}
    filename = "EB-Realizable.png"
else:
    title = (
        f"\\rm Average regret over {nb_repetition} runs\n"
        f"\\rm Not realizable case - {distribution_name} rewards - Setting {setting_id}"
    )
    ylabel = r"\rm Regret"
    legend_kwargs = {"fontsize": 12}
    filename = "EB-NonRealizable.png"

plt.title(title)
plt.xlabel(r"\rm Time step")
plt.ylabel(ylabel)
plt.legend(**legend_kwargs)
plt.tight_layout()
plt.savefig(f"{figure_dir}/{filename}", format="png", dpi=resolution_value)
plt.show()
