import sys
sys.path.append('.')

from matplotlib import pyplot as plt
from util.util import moving_average
import numpy as np


def _curve(aux_type, aux_num, color, policy='main', rep='nonlinear'):
    """Return one curve spec: the keys that select a saved result file, plus its plot colour."""
    return {
        'aux_type': aux_type,
        'aux_num': aux_num,
        'policy': policy,
        'color': color,
        'rep': rep,
    }


# Curves drawn on the figure, in plotting order. The commented-out entries are
# alternative line-ups that can be swapped in.
param_list = [
    _curve('None', 0, 'orange'),
    # _curve('hallway_NO_MS', 2, 'red'),
    # _curve('corner_NO_MS', 2, 'blue'),
    _curve('random_goal', 8, 'black'),
    # _curve('hallway', 3, 'red'),
    # _curve('corner', 4, 'blue'),
    _curve('MSGT', 8, 'green'),
    _curve('discovered', 8, 'limegreen'),

    # Alternative line-up using the maze_* auxiliary tasks:
    # _curve('MSGT', 8, 'green'),
    # _curve('random_goal', 8, 'black'),
    # _curve('maze_hallway', 2, 'red'),
    # _curve('maze_corner', 3, 'blue'),
    # _curve('discovered', 8, 'limegreen'),

    # Alternative line-up using the pinball_* auxiliary tasks:
    # _curve('MSGT', 5, 'green'),
    # _curve('MSGT_random_replacement', 5, 'grey'),
    # _curve('random_goal', 5, 'black'),
    # _curve('pinball_bottleneck', 4, 'red'),
    # _curve('pinball_corner', 5, 'blue'),
    # _curve('discovered', 4, 'limegreen'),
]

# Environment whose results are plotted: it is the directory the result files
# are loaded from, and further down it picks the smoothing window and tick
# layout ('four_rooms', 'maze' and 'pinball' are special-cased there).
env_name = 'four_rooms'
fig, ax = plt.subplots(figsize=(10, 6))

# Moving-average window (in data points) per environment; any environment not
# listed uses 10. Computed once because it does not depend on the curve.
_SMOOTHING_WINDOW = {'maze': 20, 'pinball': 40}
window = _SMOOTHING_WINDOW.get(env_name, 10)

for param in param_list:
    # Result file produced elsewhere (not visible here); this name pattern must
    # stay in sync with whatever saved it.
    result_path = '{}/num_steps_over_runs_aux_type_{}_policy_{}_aux_num_{}_rep_{}.npy'.format(
        env_name, param['aux_type'], param['policy'], param['aux_num'], param['rep'])
    num_steps_over_runs = np.load(result_path)

    # Expected 2-D: one row per run (rows are averaged below), one column per
    # point along the x-axis (presumably episodes -- TODO confirm).
    n_runs, n_points = num_steps_over_runs.shape
    if n_points < window:
        # Without this check numpy fails with an opaque "negative dimensions"
        # error, or (n_points == window - 1) an empty curve is silently plotted.
        raise ValueError(
            '{}: only {} points per run, need at least window={} to smooth'.format(
                result_path, n_points, window))

    # moving_average is assumed to return only the fully-overlapping part of
    # the average, i.e. n_points - window + 1 values per run; the
    # preallocated shape relies on that.
    smoothed = np.zeros((n_runs, n_points - window + 1))
    for i, run in enumerate(num_steps_over_runs):
        smoothed[i] = moving_average(run, window)

    mean_curve = smoothed.mean(axis=0)
    # Standard error of the mean across runs.
    # NOTE(review): np.std defaults to ddof=0 (population std); the usual
    # sample-based standard error uses ddof=1 -- confirm intended.
    se_curve = smoothed.std(axis=0) / np.sqrt(n_runs)

    ax.plot(mean_curve, color=param['color'])
    # NOTE(review): the shaded band is mean +/- SE/2 (total width one SE), not
    # the more common mean +/- SE -- confirm before reusing these figures.
    ax.fill_between(np.arange(mean_curve.shape[0]),
                    mean_curve - se_curve / 2,
                    mean_curve + se_curve / 2,
                    alpha=0.5, color=param['color'])


# Per-environment tick layout; environments not listed keep matplotlib's
# automatic ticks and limits.
if env_name == 'four_rooms':
    ax.set_xticks((0, 50, 100, 150, 200))
    ax.set_yticks((100, 300, 500))
    # ax.set_yticks((0, 50, 100, 150, 200))
elif env_name == 'maze':
    ax.set_xticks((0, 100, 200, 300))
    ax.set_yticks((100, 300, 500))
elif env_name == 'pinball':
    ax.set_xticks((0, 50, 100, 150))
    ax.set_yticks([1000, 2000, 3000])
    ax.set_yticklabels(['1k', '2k', '3k'])
    ax.set_ylim((0, 3000))

# Draw only the left and bottom spines.
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Large tick labels on both axes, set through the same Axes object as the rest
# of the figure instead of mixing in pyplot's implicit "current axes" calls.
# tick_params stores the size as the axis default, so it also applies to any
# ticks matplotlib creates later.
ax.tick_params(axis='both', labelsize=40)

# The output file is named after the environment and the last plotted curve's aux_type.
fig.savefig('{}_learning_curve_{}.pdf'.format(env_name, param_list[-1]['aux_type']))
# print(stats.ttest_ind(num_steps_over_runs_mean, num_steps_over_runs_mean_aux))
