from unittest import registerResult
from environments import *
from agents import *
import matplotlib.pyplot as plt
import os
from functions.functions import test_policy
from joblib import Parallel, delayed

import datetime
import time

print("-----------------------------------------------------------\n")
print("Bandit with Ranking Feedback Performances on Synthetic Data\n")
print("-----------------------------------------------------------\n")

# When True, a timestamped directory under results/ is created and the
# per-agent result arrays are written into it at the end of the script.
save = True

exp = int(input("Digit the Experiment you want to run: "))

# Start the wall-clock timer only after the interactive prompt, so the
# reported running time measures the experiment rather than the time spent
# waiting for the user to type an experiment number.
t0 = time.time()

if save:
    tail = datetime.datetime.now().strftime("%y_%m_%d-%H_%M_")
    # NOTE: `dir` shadows the builtin; the name is kept because later code
    # builds output paths from it.
    dir = 'results/'+tail+'Experiment_'+str(exp)
    # makedirs also creates the parent 'results/' folder when it is missing
    # (os.mkdir would raise FileNotFoundError). It still raises
    # FileExistsError if this exact run directory already exists, so results
    # from an earlier run in the same minute are never silently overwritten.
    os.makedirs(dir)
    dir += '/'


# Parameters of the Environment.
# Per-arm mean vectors, keyed by experiment id. Any id other than 1, 2 or 3
# falls back to the 8-arm configuration.
_EXPERIMENT_MEANS = {
    1: [0.9, 1.05, 1.12, 1.15],
    2: [0.05, 0.25, 0.5, 1.0],
    3: [0.03, 0.07, 0.1, 0.08, 0.97, 1],
}
_DEFAULT_MEANS = [0.05, 0.05, 0.1, 0.15, 0.25, 0.5, 0.75, 1.0]

means = np.array(_EXPERIMENT_MEANS.get(exp, _DEFAULT_MEANS))
# Same noise level for every experiment.
noise_std = 1.0
# Index and value of the arm with the largest mean.
best_arm = np.argmax(means)
best_arm_value = np.max(means)
arms = len(means)


# T is the horizon passed to every agent and to test_policy.
T = 2000
n_experiments = 50
# One parallel job per seed; each job is asked for mini_exp experiments
# (presumably mini_exp independent runs per seed -- confirm in test_policy).
num_seeds = 10
if n_experiments % num_seeds:
    # int(n_experiments / num_seeds) would silently drop the remainder and
    # run fewer experiments than requested.
    raise ValueError(
        f"n_experiments ({n_experiments}) must be divisible by "
        f"num_seeds ({num_seeds})"
    )
mini_exp = n_experiments // num_seeds
seeds = np.arange(num_seeds)

# Agents under comparison; the third argument of RankingAgent is a
# per-agent parameter (values 1, 1.5, 2), reflected in the short names below.
ranking_agent1 = RankingAgent(arms, T, 1)
ranking_agent15 = RankingAgent(arms, T, 1.5)
ranking_agent2 = RankingAgent(arms, T, 2)
ec_agent = ECagent(arms, T)
RLPE_agent = RLPE(arms, T)

# names_list[i] labels agents_list[i]; it is also used as the output file name.
agents_list = [ranking_agent1, ranking_agent15, ranking_agent2, ec_agent, RLPE_agent]
names_list = ['R1', 'R15', 'R2', 'EC', 'RLPE']

env = GaussianEnvironment(means, arms, noise_std = noise_std)

# For each agent, run test_policy once per seed as a separate joblib job;
# each agent contributes a list with one result array per seed.
results = []
for agent in agents_list:
    per_seed = Parallel(n_jobs=num_seeds)(
        delayed(test_policy)(seed, mini_exp, T, best_arm_value, env, agent)
        for seed in seeds
    )
    results.append(per_seed)
    print('Agent finished')

# Stack each agent's per-seed arrays along axis 0 into a single array.
results = [np.concatenate(per_seed, axis=0) for per_seed in results]
for agent_results in results:
    print(agent_results.shape)

if save:
    # One .npy file per agent, named after its entry in names_list.
    for name, agent_results in zip(names_list, results):
        np.save(dir+name, agent_results)

print('Running time = {}'.format(time.time()-t0))
