import matplotlib.pyplot as plt
import numpy as np

from learning_algorithms import Imed, FIMED, OIMED, Med, FMED, OMED
from bandits import *
from experiment import Experiment


########################################
#            Configuration             #
########################################
# Settings shared by every experiment below (figures 13, 14 and 15).
print("\nRunning Experiment on DSSAT\n", flush=True)
horizon = 10000  # horizon argument passed to Experiment.run
nbr_xp = 50      # number-of-experiments argument passed to Experiment.run
print(f"Horizon = {horizon}\nNumber of experiments = {nbr_xp}\n", flush=True)


def make_algorithms(bandit):
    """Return fresh instances of the six compared algorithms for *bandit*.

    The same list (Imed, FIMED, OIMED, Med, FMED, OMED) is used for every
    figure; building it in one place keeps the experiments consistent.
    """
    return [
        Imed(bandit),
        FIMED(bandit),
        OIMED(bandit),
        Med(bandit),
        FMED(bandit),
        OMED(bandit),
    ]


def run_experiment(bandit, suffix, nbr_xp, horizon):
    """Run all algorithms on *bandit*, plot the results and close figures.

    Parameters
    ----------
    bandit :
        Environment instance given to each algorithm and to ``Experiment``.
    suffix : str
        Suffix forwarded to ``Experiment`` (e.g. ``" figure 13"``).
    nbr_xp : int
        Number-of-experiments argument forwarded to ``Experiment.run``.
    horizon : int
        Horizon argument forwarded to ``Experiment.run``.

    Returns
    -------
    Experiment
        The experiment object after ``run`` and ``plot`` have been called.
    """
    print("Launching the experiment\n", flush=True)
    experiment = Experiment(make_algorithms(bandit), bandit, suffix=suffix)
    experiment.run(nbr_xp, horizon)
    experiment.plot()
    # Release every open matplotlib figure before the next experiment plots.
    plt.close('all')
    return experiment


########################################
#        Figure 13: DSSAT bandit       #
########################################
print("Loading DSSAT\n", flush=True)
bandit = DssatBandit()
experiment = run_experiment(bandit, " figure 13", nbr_xp, horizon)

########################################
#   Figures 14-15: synthetic bandits   #
########################################
# The Bernoulli (figure 14) and Beta (figure 15) environments share the
# same arm means so the two figures are directly comparable.
means = np.array([0.3, 0.4, 0.45, 0.5, 0.52, 0.55])

bandit = BernoulliBandit(means)
print(f"means = {means}\n", flush=True)
experiment = run_experiment(bandit, " figure 14", nbr_xp, horizon)

bandit = BetaBandit(means)
print(f"means = {means}\n", flush=True)
experiment = run_experiment(bandit, " figure 15", nbr_xp, horizon)
