from rlberry.experiment import load_experiment_results
from rlberry.manager.evaluation import plot_writer_data
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Global matplotlib settings applied to every figure produced by this script.
rcParams.update({
    'figure.figsize': (10, 7),  # (width, height) in inches
    'mathtext.default': 'regular',
    'font.size': 14,
})

# ------------------------------------------
# RandQL vs OptQL
# ------------------------------------------
# Location passed to load_experiment_results when reading experiment outputs.
RESULTS_PATH = 'results'

def plot_exp(experiment_name):
    """Plot the cumulative episode rewards of one experiment and save a PDF.

    Loads the results of ``experiment_name`` from ``RESULTS_PATH``, keeps the
    managers whose agent name has an entry in ``PLOT_TITLES``, replaces that
    name by its display title, and plots the cumulative sum of the
    ``episode_rewards`` tag for the selected managers. The figure is saved as
    ``<experiment_name>.pdf`` (relative path).

    Parameters
    ----------
    experiment_name : str
        Name of the experiment given to ``load_experiment_results``; also
        used as the stem of the output file name.
    """
    # Display titles, keyed by the agent names found in the loaded results.
    # Managers whose agent name is not listed here are left out of the plot.
    PLOT_TITLES = {
        'adaptiveql': 'Adaptive Q-Learning',
        'adaptive_randql': 'Adaptive Randomized QL',
        'kernel_ucbvi': 'Kernel UCBVI',
    }
    output_data = load_experiment_results(RESULTS_PATH, experiment_name)

    # Sort on the original agent names (before they are replaced by titles)
    # so that the plotting order is the same from one run to the next.
    sorted_managers = sorted(
        output_data['manager'].values(), key=lambda x: x.agent_name
    )

    manager_list = []
    for manager in sorted_managers:
        if manager.agent_name not in PLOT_TITLES:
            continue
        manager.agent_name = PLOT_TITLES[manager.agent_name]
        manager_list.append(manager)
        print(manager.agent_name)
        print("n agents = ", len(manager.get_agent_instances()))
        # NOTE(review): presumably drops the loaded agents to free memory,
        # since only the writer data is needed for plotting — confirm.
        del manager.agent_handlers

    plot_writer_data(
        manager_list,
        tag="episode_rewards",
        preprocess_func=np.cumsum,
        show=False,
        title=' ',
    )
    plt.ylabel('cumulative reward')
    plt.xlabel('episode')
    plt.grid()
    plt.savefig('{}.pdf'.format(experiment_name))

# Experiments to plot; plot_exp writes one '<name>.pdf' file for each entry.
EXPERIMENT_NAMES = [
    'pball_lvl1_randql',
    'pball_lvl2_randql',
    'pball_lvl3_randql',
]

# Produce the figures one experiment at a time, in the order listed above.
for experiment_name in EXPERIMENT_NAMES:
    plot_exp(experiment_name)