import matplotlib
import os
import scipy.stats

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from scipy.stats import gaussian_kde, sem
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from mpl_toolkits.mplot3d import Axes3D
from typing import List, Any
from matplotlib import ticker

# Matplotlib line + marker format strings used to tell curves apart.
SHAPES = ['-v', '-^', '-<', '-D', '-o', '-']

# Default colour palette (hex codes) used by the plotting helpers below.
COLORS = [
    '#10ac84', '#00a8ff', '#9c88ff', '#fbc531', '#4cd137',
    '#487eb0', '#f368e0', '#ff9f43', '#ee5253', '#0abde3',
]

# Marker format per exploration technique. ResMax appears under two
# spellings ('resmax' and 'ResMax'), so both keys share one style.
marker_map = dict(zip(
    ['epsilon-greedy', 'softmax', 'mellowmax'],
    SHAPES[1:4],
))
marker_map.update(dict.fromkeys(['resmax', 'ResMax'], SHAPES[4]))

# Fixed colour per exploration technique (both ResMax spellings share one).
color_map = {
    'epsilon-greedy': '#2CA02C',
    'softmax': '#1F77B4',
    'mellowmax': 'red',
}
color_map.update(dict.fromkeys(['resmax', 'ResMax'], '#FF7F0E'))

def use_pdf_backend() -> None:
    """Switch matplotlib to its 'pdf' backend."""
    backend_name = 'pdf'
    matplotlib.use(backend_name)

def save_plot(addr: str, outside_plot=True) -> None:
    """Save the current matplotlib figure to ``addr``.

    outside_plot : bool
        True (default): save with a tight bounding box.
        False: first remove all subplot margins and spacing, then save with a
        tight bounding box and no padding.
    """
    if not outside_plot:
        # Drop every margin/gap between subplots before cropping with zero padding.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.savefig(addr, bbox_inches='tight', pad_inches=0)
        return
    plt.savefig(addr, bbox_inches='tight')

def plot_mean_std(mean: List[float], std: List[float], color: str = COLORS[0], 
                  shape: str = SHAPES[-1], x_values: List[float] = None, 
                  xscale: str = 'linear', label: str = None, add_markers=False, 
                  markers_shape_interval: int = 5, xlabel: str = None,
                  ylabel: str = None,
                  ):
    """Plot a mean curve with a shaded ``mean ± std/2`` band on the current axes.

        mean, std : sequence of float
            Per-point mean and standard deviation; must have the same length.
            Plain lists and NumPy arrays are both accepted.
        color : str
            Colour of both the line and the shaded band.
        shape : str
            Matplotlib format string for the line (default: solid line).
        x_values : sequence of float, optional
            x coordinates; defaults to 0, 1, ..., len(mean) - 1.
        xscale : str
            Passed to ``plt.xscale``.
        label : str, optional
            Legend label of the mean curve.
        add_markers : bool
            If True, draw markers every ``markers_shape_interval`` points and
            on the last point.
        markers_shape_interval : int
            Spacing, in points, between markers.
        xlabel, ylabel : str, optional
            Axis labels.

    Raises:
        ValueError: if ``mean`` and ``std`` have different lengths.
    """
    # Convert up front: the annotations advertise lists, but list arithmetic
    # (``std/2``) and ``.shape`` only work on arrays.
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    if len(mean) != len(std):
        raise ValueError('mean and std must have the same length, got {} and {}'.format(len(mean), len(std)))

    ub = mean + std / 2
    lb = mean - std / 2

    if x_values is None:
        x_values = np.arange(len(mean))

    # Plotting the standard deviation shadow
    plt.fill_between(x_values, ub, lb, color=color, alpha=.1)

    # Indices of the points that get a marker (None = matplotlib default).
    markers_on = None
    if add_markers:
        n_points = len(mean)
        markers_on = list(range(0, n_points, markers_shape_interval))
        # Always mark the final point as well, without duplicating it.
        if markers_on and markers_on[-1] != n_points - 1:
            markers_on.append(n_points - 1)

    # Plotting the mean values
    plt.plot(x_values, mean, shape, color=color, label=label, markevery=markers_on)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.xscale(xscale)
     

def _episode_window(n_episodes_reward):
    """Translate ``n_episodes_reward`` into a slice over one run's episodes.

    None -> all episodes; n > 0 -> the first n; n < 0 -> the last |n|.

    Raises:
        ValueError: if ``n_episodes_reward`` converts (via ``int``) to 0.
    """
    if n_episodes_reward is None:
        return slice(None)
    n = int(n_episodes_reward)
    if n > 0:
        return slice(None, n)
    if n < 0:
        return slice(n, None)
    raise ValueError('n_episodes_reward should be n_episodes_reward<=-1 or n_episodes_reward>=1')


def plot_hyperparameter_sensitivity(rewards: List[List[List[float]]], x_values: List[float] = None,
                                    n_episodes_reward: int = None, xlabel: str = None,
                                    ylabel: str = None, xscale: str = 'linear') -> None:
    """Plot mean (± std/2 across runs) reward for each hyper-parameter setting, in a new figure.

        rewards : nested sequence
            ``rewards[i][j]`` is a sequence of per-episode rewards for run ``j`` of
            setting ``i``. Every setting is assumed to have as many runs as
            ``rewards[0]``.
        x_values : sequence of float, optional
            x coordinate of each setting; defaults to 0, 1, ...
        n_episodes_reward : int, optional
            Which episodes of each run to average: None = all, n > 0 = first n,
            n < 0 = last |n|.
        xlabel, ylabel : str, optional
            Axis labels.
        xscale : str
            Passed to ``plt.xscale``.

    Raises:
        ValueError: if ``n_episodes_reward`` is 0.
    """
    # Validate before creating the figure so a bad argument leaves no empty figure behind.
    window = _episode_window(n_episodes_reward)

    plt.figure()

    # run_means[i, j]: average reward of run j of setting i over the selected episodes.
    run_means = np.zeros((len(rewards), len(rewards[0])))
    for i, exp_rewards in enumerate(rewards):
        for j, episode_rewards in enumerate(exp_rewards):
            run_means[i, j] = np.mean(episode_rewards[window])

    plot_mean_std(run_means.mean(1), run_means.std(1), xlabel=xlabel, ylabel=ylabel,
                  xscale=xscale, x_values=x_values)

def plot_state_visitation_1D(observation_data: np.ndarray, min_x: float, max_x: float,
        hist_type: str = 'histogram', bins: int = 20, title: str = None, bandwidth: float = None) -> None:
    """
    Plot how often each state of a 1D state space was visited, in a new figure.

        observation_data : np.ndarray
            Visited states, one entry per observation.
        min_x : float
            Lower end of the state range.
        max_x : float
            Upper end of the state range. The x-axis spans
            [min_x - 0.7, max_x + 0.7]; the KDE is evaluated on 100 points
            in [min_x, max_x].
        hist_type : str
            'histogram' (plt.hist with ``bins`` bins), 'kde' (Gaussian KDE
            density curve) or 'bar_chart' (count per unique state).
        bins: int
            Number of bins; only used by 'histogram'.
        title : str
            Optional title shown at the top of the plot.
        bandwidth: float
            Passed as ``bw_method`` to scipy's gaussian_kde; only used by 'kde'.

        Raises ValueError for an unknown ``hist_type``.
        Return None
    """
    plt.figure()

    if hist_type == 'histogram':
        plt.hist(observation_data, bins=bins, color=COLORS[0])
    elif hist_type == 'kde':
        kde = gaussian_kde(observation_data, bw_method=bandwidth)
        xgrid = np.linspace(min_x, max_x, 100)
        # Evaluate once; reused for both the filled area and the outline.
        density = kde(xgrid)
        plt.fill_between(xgrid, density, alpha=0.5, color=COLORS[0])
        plt.plot(xgrid, density, color=COLORS[0])
        plt.ylabel('Probability Density')
    elif hist_type == 'bar_chart':
        states, n_occurrences = np.unique(observation_data, return_counts=True)
        plt.bar(states, n_occurrences, color=COLORS[0])
        plt.ylabel('Number of Visitations')
    else:
        raise ValueError('hist_type value is wrong: {}'.format(hist_type))

    # Small margin so bars centred on the boundary states are not clipped.
    plt.xlim([min_x - 0.7, max_x + 0.7])

    if title is not None:
        plt.title(title)

    plt.xlabel('States')
 
def _padded_limits(low, high, fraction=0.1):
    """Return ``(low, high)`` widened on each side by ``fraction * (high - low)``."""
    pad = fraction * (high - low)
    return low - pad, high + pad


def plot_state_visitation_2D(observation_data: np.ndarray, min_x: float, max_x: float,
        min_y: float, max_y: float, contour: bool = True, kernel: str = 'gaussian',
        bandwidth: float = None, levels: int = 7, cmap: str = 'coolwarm', alpha: float = 0.5,
        remove_tri=False, x_label: str = None, y_label: str = None, n_jobs: int = 1, azim: float = 30,
        sample_size: int = 30):
    """
    Visualize the state-visitation density of data with a 2D state space, in a new figure.

    A kernel density estimate (sklearn KernelDensity) is fit to the observations and
    evaluated on a ``sample_size`` x ``sample_size`` grid over
    [min_x, max_x] x [min_y, max_y].

        observation_data : ndarray (n_observations, 2)
            Observations that have been gathered through timesteps.
        min_x, max_x, min_y, max_y : float
            Extent of the evaluation grid.
        contour : bool
            True: 2D filled contour plot with a colour bar.
            False: 3D wireframe with filled contours projected onto the three
            back planes.
        kernel : str
            Kernel passed to KernelDensity.
        bandwidth : float, optional
            KDE bandwidth. If None, it is chosen by GridSearchCV over 20
            log-spaced values in [0.1, 10] and printed.
        levels : int
            Number of contour levels.
        cmap : str
            Colour map name.
        alpha : float
            Transparency of the filled contours / wireframe.
        remove_tri : bool
            Contour mode only: paint white the region above the diagonal
            running from (min_x, min_y) to (max_x, max_y).
        x_label, y_label : str, optional
            Axis labels.
        n_jobs : int
            Parallel jobs for the bandwidth grid search.
        azim : float
            Azimuth of the 3D view.
        sample_size : int
            Number of grid points per axis.
    """
    fig = plt.figure()

    if bandwidth is None:
        # use grid search cross-validation to optimize the bandwidth
        params = {'bandwidth': np.logspace(-1, 1, 20)}
        grid = GridSearchCV(KernelDensity(kernel=kernel), params, n_jobs=n_jobs, verbose=1)
        grid.fit(observation_data)

        print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))

        kde = grid.best_estimator_
    else:
        kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(observation_data)

    xgrid = np.linspace(min_x, max_x, sample_size)
    ygrid = np.linspace(min_y, max_y, sample_size)
    Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
    # score_samples returns log-density; exponentiate to get the density itself.
    grid_points = np.exp(kde.score_samples(np.vstack([Xgrid.ravel(), Ygrid.ravel()]).T))
    grid_points = grid_points.reshape(sample_size, sample_size)

    if contour:
        plt.contour(xgrid, ygrid, grid_points, levels=levels, linewidths=0.5, colors='k')
        plt.contourf(xgrid, ygrid, grid_points, levels=levels, cmap=cmap, alpha=alpha)

        plt.colorbar()

        plt.xlabel(x_label)
        plt.ylabel(y_label)

        if remove_tri:
            # The diagonal gets one y value per x sample, so the x and y ranges no
            # longer need to have exactly 21 (or even equal numbers of) integer steps.
            diag_x = np.arange(min_x, max_x + 1)
            diag_y = np.linspace(min_y, max_y, len(diag_x))
            plt.fill_between(diag_x, diag_y, max_y, color='white', alpha=1, zorder=100)

    else:
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_wireframe(Xgrid, Ygrid, grid_points, cmap=cmap, alpha=alpha)

        # Display region: 10% padding on x and y; z starts 10% below zero to
        # leave room for the projected contour on the floor.
        x_lo, x_hi = _padded_limits(min_x, max_x)
        y_lo, y_hi = _padded_limits(min_y, max_y)
        z_max = np.max(grid_points)
        z_floor = -z_max * .1

        # Creating reflected contour plots on the floor and the two back walls
        ax.contourf(Xgrid, Ygrid, grid_points, zdir='z', offset=z_floor, levels=levels, cmap=cmap)
        ax.contourf(Xgrid, Ygrid, grid_points, zdir='x', offset=x_lo, levels=levels, cmap=cmap)
        ax.contourf(Xgrid, Ygrid, grid_points, zdir='y', offset=y_lo, levels=levels, cmap=cmap)

        # Specifying the azimuth of the observer
        ax.view_init(azim=azim)

        # Tweaking display region and labels
        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(y_lo, y_hi)
        ax.set_zlim(z_floor, z_max)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_zlabel('Probability Density')


def hyperparameter_scatter_plot(title : str, df : pd.DataFrame, col_title="Total Return"):
    """
    Draw a seaborn swarm plot of ``col_title`` per exploration technique,
    coloured by hyperparameter value ("Exp Value").
        title : str
            Accepted but currently not used.
        df : pandas DataFrame
            Must contain the columns "Exploration Technique", "Exp Value" and ``col_title``.
        col_title : str
            Column plotted on the y-axis (default "Total Return").
    """
    sns.set(font_scale=1.25)
    facet = sns.catplot(
        data=df,
        kind="swarm",
        x="Exploration Technique",
        y=col_title,
        hue="Exp Value",
        palette="rocket",
        legend_out=True,
    )
    # Slant the technique names so long labels don't overlap.
    facet.set_xticklabels(rotation=30)
    plt.xlabel("")


def hyperparameter_scatter_plot_envs(title : str, df : pd.DataFrame):
    """
    Draw one seaborn swarm plot of "Total Return" per exploration technique for
    each environment (one column of subplots per "Environment"), coloured by
    hyperparameter value ("Exp Value").
        title : str
            Accepted but currently not used.
        df : pandas DataFrame
            Must contain the columns "Environment", "Exploration Technique",
            "Exp Value" and "Total Return".
    """
    sns.set(font_scale=1.25)
    facet = sns.catplot(
        data=df,
        kind="swarm",
        x="Exploration Technique",
        y="Total Return",
        hue="Exp Value",
        col="Environment",
        palette="rocket",
        legend_out=True,
    )
    # The technique names are self-explanatory, so drop each subplot's x label.
    for subplot in facet.axes.flat:
        subplot.set_xlabel("")
    facet.set_xticklabels(rotation=30)


def hyperparameter_violin_plot(title : str, df : pd.DataFrame):
    """
    Draw a seaborn violin plot of "Total Return" for each exploration technique.
        title : str
            Accepted but currently not used.
        df : pandas DataFrame
            Must contain the columns "Exploration Technique" and "Total Return".
    """
    sns.set(font_scale=1.25)
    # scale='width' gives every violin the same maximum width.
    sns.violinplot(data=df, x="Exploration Technique", y="Total Return", scale='width')
    plt.xticks(rotation=30)
    plt.xlabel("")
    

def avg_reward_learning_curve_plot(title :str, df : pd.DataFrame,  columns : list, window=1000, ci = None, average_params=False):
    """
    Plot the rolling-average reward against step; seaborn aggregates the rows
    (presumably one per random seed) that share a line.
        title : str
            The title that will be shown at the top of the plot
        df : pandas DataFrame
            must have columns: ["Environment", "Exploration Technique", "Exp Value"]
            (listed in ``columns``); the remaining columns hold per-step rewards.
        columns : list
            List of non-numerical (identifier) columns of the dataset
        window : int
            Width of the rolling mean taken along the steps of each row.
        ci :
            Passed to seaborn.lineplot to control the confidence interval;
            enabling it makes plotting very slow.
        average_params : bool
            True: one line per "Exploration Technique" (pooling its Exp Values).
            False: one line per "Exploration Technique: Exp Value" configuration.
    """
    sns.set_style(style='white')
    plt.subplots(figsize=(7, 5))
    sns.set(font_scale=1.4)

    # infer_objects() returns a new frame; assign it so object-typed numeric
    # columns are actually converted before rolling.
    df = df.infer_objects()
    # Rolling mean along the steps of every row. Transposing is equivalent to the
    # deprecated ``DataFrame.rolling(..., axis=1)``.
    df = df.set_index(columns).T.rolling(window).mean().T
    df = df.reset_index()

    # NOTE(review): step columns are assumed to be labelled 0..n-1, and the last
    # non-identifier column is excluded — presumably it is not a step; confirm.
    df = pd.melt(df, id_vars=columns, var_name="Step", value_name="Reward",
                 value_vars=range(0, df.shape[1] - len(columns) - 1))

    if average_params:
        hue = "Exploration Technique"
    else:
        df['Config'] = df["Exploration Technique"].str.cat(df['Exp Value'].astype(str), sep=": ")
        hue = "Config"
    plot = sns.lineplot(x="Step", y="Reward", hue=hue, ci=ci, data=df)

    plot.set_title(title)
    plot.set(ylabel="Average Reward")

    plot.legend().set_title('')
    plot.legend(loc="best", borderaxespad=0.)


def learning_curve_plot(title :str, df : pd.DataFrame,  columns : list, ci = None):
    """
    Plot cumulative reward against step, one line per
    "Exploration Technique: Exp Value" configuration; seaborn aggregates the rows
    (presumably one per random seed) that share a configuration.
        title : str
            The title that will be shown at the top of the plot
        df : pandas DataFrame
            must have columns: ["Environment", "Exploration Technique", "Exp Value"];
            the remaining columns hold per-step values.
        columns : list
            List of non-numerical (identifier) columns of the dataset
        ci :
            Passed to seaborn.lineplot to control the confidence interval;
            enabling it makes plotting very slow.
    """
    sns.set_style(style='white')
    plt.figure(figsize=(7, 5))
    df = df.infer_objects()
    sns.set(font_scale=1.4)

    # NOTE(review): step columns are assumed to be labelled 0..n-1, and the last
    # non-identifier column is excluded — presumably it is not a step; confirm.
    step_labels = range(0, df.shape[1] - len(columns) - 1)
    long_df = pd.melt(df, id_vars=columns, var_name="Step", value_name="Reward", value_vars=step_labels)

    technique = long_df["Exploration Technique"]
    long_df['Config'] = technique.str.cat(long_df['Exp Value'].astype(str), sep=": ")

    ax = sns.lineplot(data=long_df, x="Step", y="Reward", hue="Config", ci=ci)
    ax.set_title(title)
    ax.set(ylabel="Cumulative Reward")

    ax.legend().set_title('')
    ax.legend(loc="best", borderaxespad=0., frameon=False)


def sensitivity(title :str, df : pd.DataFrame, ci=None):
    """
    Plot "Total Return" against the exploration hyper-parameter ("Exp Value"),
    one line per "Exploration Technique", in a new 7x5 figure with a symlog
    (base 2) x-axis and the legend placed outside the axes on the right.
        title : str
            The title that will be shown at the top of the plot
        df : pandas DataFrame
            Must contain "Exp Value", "Total Return" and "Exploration Technique".
        ci :
            Passed to seaborn.lineplot to control the confidence interval.

    Returns the matplotlib Axes the plot was drawn on.
    """
    sns.set(font_scale=1.4)
    sns.set_style(style='white')
    plt.figure(figsize=(7, 5))
    plot = sns.lineplot(x="Exp Value", y="Total Return", hue='Exploration Technique', data=df, ci=ci)
    plot.set_xscale('symlog', base=2)
    plt.xlabel(r'$\epsilon, \eta, \tau$')
    plot.set_title(title)

    plot.legend(loc=2, borderaxespad=0., bbox_to_anchor=(1.05, 1), framealpha=0)
    plot.spines['top'].set_visible(False)
    plot.spines['right'].set_visible(False)

    return plot


def new_sensitivity(df, figsize, environments=None, xticks=None, 
                    summary_variable="Total Return", save_path='sensitivity-{}.pdf'):
    """
    Generate and save one sensitivity plot per environment.

    For every technique in a fixed palette ('resmax-normalized', 'resmax',
    'softmax', 'epsilon-greedy'), the mean of ``summary_variable`` over rows
    sharing Environment / Exploration Technique / Exp Value (averaging across
    the random seeds) is plotted against Exp Value on a symlog (base 2) x-axis,
    with a band of ± half the standard error of the mean.

        df : pandas DataFrame
            must have columns: ["Environment", "Exploration Technique", "Exp Value"]
            plus ``summary_variable``.
        figsize : tuple
            Passed to plt.subplots for each environment's figure.
        environments : iterable, optional
            Environments to plot; defaults to every value of df["Environment"].
        xticks : list, optional
            x tick positions; defaults to [0, 1, 2, 4, ..., 4096].
        summary_variable : str
            The variable to represent performance. By default it is Total Return,
            but something else like average return may be used instead, as long
            as the column is in df.
        save_path : str
            Format string that receives the environment name, one file per environment.
    """
    palette = {
        'resmax-normalized': 'red',
        'resmax': 'green',
        'softmax': 'blue',
        'epsilon-greedy': 'orange',
    }

    if xticks is None:
        xticks = [0, 1] + [2**i for i in range(1, 13)]
    if environments is None:
        environments = df["Environment"].unique()

    keys = ["Environment", "Exploration Technique", "Exp Value"]
    sem_df = df.groupby(keys).agg(scipy.stats.sem).reset_index()
    mean_df = df.groupby(keys).mean().reset_index()  # averages across the random seeds

    for env in environments:
        _, ax = plt.subplots(1, 1, figsize=figsize)
        env_mean = mean_df.loc[mean_df["Environment"] == env]
        env_sem = sem_df.loc[sem_df["Environment"] == env]
        ax.set_title(env)
        ax.set_xscale('symlog', base=2)

        for technique, colour in palette.items():
            tech_mean = env_mean.loc[env_mean["Exploration Technique"] == technique]
            tech_sem = env_sem.loc[env_sem["Exploration Technique"] == technique]
            exp_values = tech_mean["Exp Value"]
            centre = tech_mean[summary_variable]
            half_band = tech_sem[summary_variable] / 2
            ax.plot(exp_values, centre, color=colour, label=technique)
            ax.fill_between(exp_values, centre - half_band, centre + half_band, alpha=0.1, color=colour)

        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)

        ax.set_xlabel(r'$\epsilon, \tau, \eta$', fontsize=16)
        plt.xticks(xticks)
        plt.legend()

        save_plot(save_path.format(env))


def summary_plot(df, summary_variable="Total Return"):
    """
    Generate a summary ridge plot: one row of axes per environment (shared y-axis).

    Within each row, ``summary_variable`` (averaged across the random seeds, i.e.
    over rows sharing Environment / Exploration Technique / Exp Value) is divided
    by that environment's maximum and plotted against Exp Value on a log x-axis,
    as one filled curve per exploration technique. Only the bottom row keeps its
    x tick labels.

        df : pandas DataFrame
            must have columns: ["Environment", "Exploration Technique", "Exp Value"]
            plus ``summary_variable``.
        summary_variable : str
            The variable to represent performance. By default it is Total Return,
            but something else like average return may be used instead, as long
            as the column is in df. It is both normalised and plotted.
    """
    num_envs = len(df["Environment"].unique())
    fig, axs = plt.subplots(num_envs, 1, figsize=(9, 3), sharey=True)
    last_idx = num_envs - 1

    colour_dict = {
        'ResMax': 'green',
        'softmax': 'blue',
        'epsilon-greedy': 'orange'
    }

    df = df.groupby(["Environment", "Exploration Technique", "Exp Value"]).mean()  # averages across the random seeds
    df = df.reset_index()

    for ax_idx, env in enumerate(df["Environment"].unique()):
        # plt.subplots returns a bare Axes (not an array) when there is one row.
        ax = axs[ax_idx] if num_envs > 1 else axs

        # Copy so the normalisation below does not write into a view of `df`.
        env_df = df.loc[df["Environment"] == env].copy()
        # Normalise by this environment's best value so all rows share one scale.
        env_df[summary_variable] /= env_df[summary_variable].max()

        for exp_technique in env_df["Exploration Technique"].unique():
            exp_df = env_df.loc[env_df["Exploration Technique"] == exp_technique]
            x = exp_df["Exp Value"]
            y = exp_df[summary_variable]
            # Techniques missing from colour_dict fall back to matplotlib's colour
            # cycle; the fill always reuses the line's colour.
            line, = ax.plot(x, y, color=colour_dict.get(exp_technique), label=exp_technique)
            ax.fill_between(x, y, alpha=0.1, color=line.get_color())

        ax.text(0, .5, env, fontweight='normal', ha="left", va="center", transform=ax.transAxes, fontsize=12)
        ax.set_xscale("log")

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.get_yaxis().set_ticks([])

        if ax_idx != last_idx:
            ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

    plt.xlabel(r'$\epsilon, \eta, \tau$', fontsize=16)
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right', frameon=False)

def summary_plot_same_scale(df, figsize, environments=None, resmax_xticks=None,
                            softmax_xticks=None, summary_variable="Total Return",
                            save_path='plot.pdf'):
    """
    Generate a summary ridge plot (one row per environment) and save it.

    Each row overlays three techniques, each on its own x-axis (twins sharing the
    row's y-axis) because their hyper-parameter ranges differ: softmax on the base
    axis (log base 2), ResMax on a twin (log base 2) and epsilon-greedy on a twin
    whose x-axis is inverted. Curves show the mean of ``summary_variable`` across
    the random seeds, with a band of ± half the standard error of the mean.

        df : pandas DataFrame
            must have columns: ["Environment", "Exploration Technique", "Exp Value"]
            plus ``summary_variable``.
        figsize : tuple
            Passed to plt.subplots.
        environments : iterable, optional
            Environments to plot, one row each; defaults to all of df["Environment"].
        resmax_xticks : list, optional
            Currently unused; accepted for backward compatibility.
        softmax_xticks : list, optional
            Tick positions on the bottom row's base axis, labelled only at the ends
            ('less greedy' / 'more greedy'); defaults to 2**-4, 2**-2, ..., 2**8.
        summary_variable : str
            The variable to represent performance. By default it is Total Return,
            but something else like average return may be used instead, as long
            as the column is in df.
        save_path : str
            Output path passed to save_plot.
    """
    if environments is None:
        environments = df["Environment"].unique()
    environments = list(environments)
    # One row per environment actually plotted, so a subset of environments gets
    # exactly that many rows and its last row receives the x-axis labels.
    num_envs = len(environments)
    fig, axs = plt.subplots(num_envs, 1, figsize=figsize)

    if softmax_xticks is None:
        softmax_xticks = [2**i for i in range(-4, 9, 2)]

    last_idx = num_envs - 1

    colour_dict = {
        'ResMax': 'green',
        'softmax': 'blue',
        'epsilon-greedy': 'orange'
    }

    keys = ["Environment", "Exploration Technique", "Exp Value"]
    df_std = df.groupby(keys).agg(scipy.stats.sem).reset_index()
    df = df.groupby(keys).mean().reset_index()  # averages across the random seeds

    for ax_idx, env in enumerate(environments):
        # plt.subplots returns a bare Axes (not an array) when there is one row.
        ax = axs[ax_idx] if num_envs > 1 else axs
        is_last_row = ax_idx == last_idx
        env_df = df.loc[df["Environment"] == env]
        env_df_std = df_std.loc[df_std["Environment"] == env]

        epsilon_axis = ax.twiny()
        resmax_axis = ax.twiny()
        axes = {
            'ResMax': resmax_axis,
            'epsilon-greedy': epsilon_axis,
            'softmax': ax
        }
        for exp_technique, tech_ax in axes.items():
            exp_df = env_df.loc[env_df["Exploration Technique"] == exp_technique]
            exp_df_std = env_df_std.loc[env_df_std["Exploration Technique"] == exp_technique]
            x = exp_df["Exp Value"]
            y = exp_df[summary_variable]
            y_std = exp_df_std[summary_variable]
            tech_ax.plot(x, y, color=colour_dict[exp_technique], label=exp_technique)
            tech_ax.fill_between(x, y - y_std/2, y + y_std/2, alpha=0.1, color=colour_dict[exp_technique])

            tech_ax.spines['top'].set_visible(False)
            tech_ax.spines['right'].set_visible(False)
            tech_ax.spines['left'].set_visible(False)
            tech_ax.get_yaxis().set_ticks([])
            if not is_last_row:
                tech_ax.spines['bottom'].set_visible(False)
                tech_ax.get_xaxis().set_ticks([])

        axes['softmax'].text(0, .5, env, fontweight='normal', ha="left", va="center", transform=ax.transAxes, fontsize=15)
        axes['softmax'].set_xscale('log', base=2)
        axes['ResMax'].set_xscale('log', base=2)
        axes['ResMax'].get_xaxis().set_ticks([])
        axes['softmax'].get_xaxis().set_ticks([])
        axes['epsilon-greedy'].get_xaxis().set_ticks([])

        if is_last_row:
            ax.set_xlabel(r'$\epsilon, \tau, \eta$', fontsize=16)
            ax.set_xticks(softmax_xticks)
            ax.set_xticklabels(['less greedy'] + ['']*(len(softmax_xticks)-2) + ['more greedy'], fontsize=14)

        # Epsilon-greedy becomes greedier as epsilon shrinks, so flip its axis.
        axes['epsilon-greedy'].invert_xaxis()

    # Base-axis entries plus the first entry from each twin (last row's axes).
    handles, labels = ax.get_legend_handles_labels()
    for twin_ax in (epsilon_axis, resmax_axis):
        twin_handles, twin_labels = twin_ax.get_legend_handles_labels()
        if twin_handles:
            handles.append(twin_handles[0])
            labels.append(twin_labels[0])
    fig.legend(handles, labels, loc='upper center', frameon=False, ncol=3,
               bbox_to_anchor=(.5, 0.9), borderaxespad=0., fontsize=14)

    save_plot(save_path)

def _window_sums(runs, window_size):
    """Sum each run's rewards over consecutive, non-overlapping windows.

    runs : array-like of shape (r, t)
    Trailing steps that do not fill a whole window are dropped.

    Returns ``(x, sums)``: ``x[k] = k * window_size`` is the first step of window k
    and ``sums`` has shape ``(r, t // window_size)``; x and sums always agree in length.

    Raises:
        ValueError: if ``window_size < 1``.
    """
    if window_size < 1:
        raise ValueError('window_size must be >= 1, got {}'.format(window_size))
    runs = np.asarray(runs)
    n_runs, n_steps = runs.shape
    n_windows = n_steps // window_size
    usable = runs[:, :n_windows * window_size]
    sums = usable.reshape(n_runs, n_windows, window_size).sum(axis=2)
    return np.arange(n_windows) * window_size, sums


def plot_average_learning_curves(reward_data, labels, window_size=1000, title = None,
                            y_axis_name = 'Average Reward', x_axis_name = 'Step', save_path = 'average_learning_curve_plot.pdf'):
    """
    Plot one learning curve per exploration heuristic and save the figure.

    Each run's rewards are summed over non-overlapping windows of ``window_size``
    steps; the curve is the mean of those sums over runs, with a band of ± half
    the standard error. The curve is plotted at the first step of each window.

        reward_data : array-like N x r x t
            N exploration heuristics, r runs each, t timesteps per run.
        labels : list[str] of length N
            Legend label per heuristic, e.g. 'ResMax: 4096'.
        window_size : int
            Steps per window.
        title : str, optional
        y_axis_name, x_axis_name : str
            Axis labels.
        save_path : str
            Output path passed to save_plot.
    """
    plt.figure()

    reward_data = np.array(reward_data)
    for i, exp_data in enumerate(reward_data):
        # x has one entry per *complete* window, matching the number of sums
        # (a partial trailing window used to give x one extra point).
        x, data = _window_sums(exp_data, window_size)

        y = data.mean(0)
        y_std = scipy.stats.sem(data, axis=0)

        plt.plot(x, y, label=labels[i])
        plt.fill_between(x, y - y_std/2, y + y_std/2, alpha=0.1)

    if title is not None:
        plt.title(title)
    plt.ylabel(y_axis_name)
    plt.xlabel(x_axis_name)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.legend()

    save_plot(save_path)

def plot_average_learning_curves_rolling(reward_data, labels, window_size=1000, title = None, y_axis_name = 'Average Reward', x_axis_name = 'Step'):
    """
    Like plot_average_learning_curves, but smooths each run with a moving average.

    Each run is smoothed with a 'valid'-mode moving average of ``window_size``
    steps (so it is shortened by window_size - 1 points). The curve is the mean
    over runs, with a band of ± half the standard error; the x-axis is the index
    of the smoothed point.

    arguments:
        reward_data : array-like N x r x t
            N exploration heuristics, r runs each, t timesteps per run.
        labels : list[str] of length N
            Legend label per heuristic, e.g. 'ResMax: 4096'.
        window_size : int
            Moving-average width.
        title : str, optional
        y_axis_name, x_axis_name : str
            Axis labels.
    """
    sns.set(font_scale=1.25)
    sns.set_style(style='white')
    plt.figure()

    # Averaging kernel, as in
    # https://stackoverflow.com/questions/14313510/how-to-calculate-rolling-moving-average-using-numpy-scipy
    kernel = np.ones(window_size)

    experiments = np.array(reward_data)
    for idx, runs in enumerate(experiments):
        smoothed = np.apply_along_axis(
            lambda row: np.convolve(row, kernel, 'valid') / window_size, 1, runs
        )
        curve = np.mean(smoothed, axis=0)
        spread = np.apply_along_axis(scipy.stats.sem, 0, smoothed)

        steps = range(0, len(curve))
        plt.plot(steps, curve, label=labels[idx])
        plt.fill_between(steps, curve - spread/2, curve + spread/2, alpha=0.1)

    if title is not None:
        plt.title(title)

    plt.ylabel(y_axis_name)
    plt.xlabel(x_axis_name)
    current_ax = plt.gca()
    current_ax.spines['top'].set_visible(False)
    current_ax.spines['right'].set_visible(False)
    plt.legend(loc="best", framealpha=0)

def _aggregate_runs(runs_data, aggregation_range):
    """Average over runs, then average consecutive groups of ``aggregation_range`` points.

    runs_data : array-like, runs along axis 0.
    Returns ``(x, mean, std)``: ``mean``/``std`` are the per-step mean and
    (population) standard deviation over runs, each averaged in groups of
    ``aggregation_range``; ``x`` numbers the aggregated points from 1.

    Raises:
        ValueError: if the number of steps is not divisible by ``aggregation_range``.
    """
    runs = np.asarray(runs_data)
    mean = np.mean(runs, axis=0)
    mean = np.mean(mean.reshape(-1, aggregation_range), axis=1)
    std = np.std(runs, axis=0)
    std = np.mean(std.reshape(-1, aggregation_range), axis=1)
    x = list(range(1, len(mean) + 1))
    return x, mean, std


def _plot_aggregated(runs_data, aggregation_range, label):
    """Plot the aggregated mean curve with a ± std/2 band on the current axes."""
    x, mean, std = _aggregate_runs(runs_data, aggregation_range)
    plt.plot(x, mean, label=label)
    plt.fill_between(x, mean - std/2, mean + std/2, alpha=0.1)


def learning_curve_exp_algo(dic_data: dict, exp_technique: str, save_prefix: str = '', save_dir: str = '', epsilon_schedule: bool = False, aggregation_range: int = 1):
    """Save one learning-curve figure per environment for a single exploration technique.

    dic_data : nested dict
        With ``epsilon_schedule`` and ``exp_technique == 'epsilon-greedy'``:
        ``dic_data[env][technique][portion_decay][outside_value][step_size][exp_value]``;
        otherwise ``dic_data[env][technique][step_size][exp_value]``. Each leaf holds
        the per-run reward curves (runs along axis 0).
    exp_technique : str
        Technique key to plot.
    save_prefix, save_dir : str
        Used in the plot title and in the output file
        ``<save_dir>/learning-curves_<prefix>_<technique>_<env>.png``.
    epsilon_schedule : bool
        Selects the deeper nesting described above for epsilon-greedy.
    aggregation_range : int
        Number of consecutive points averaged into one plotted point.

    Errors are handled best-effort: any exception for an environment is printed and
    the next environment is processed; the figure is closed in every case.
    """
    for env in dic_data:
        fig = plt.figure()
        try:
            configs = dic_data[env][exp_technique]
            if exp_technique == 'epsilon-greedy' and epsilon_schedule:
                for portion_decay, by_outside in configs.items():
                    print('portion_decay: ', portion_decay)
                    for outside_value, by_step in by_outside.items():
                        print('outside_value: ', outside_value)
                        for step_size, by_exp in by_step.items():
                            print('step_size: ', step_size)
                            for exp_value, runs_data in by_exp.items():
                                print('exp_value: ', exp_value)
                                label = '({}, {}, {})'.format(step_size, outside_value, portion_decay)
                                try:
                                    _plot_aggregated(runs_data, aggregation_range, label)
                                except ValueError as e:
                                    # e.g. length not divisible by aggregation_range: skip this config.
                                    print(e)
            else:
                for step_size, by_exp in configs.items():
                    for exp_value, runs_data in by_exp.items():
                        label = '({}, {})'.format(step_size, exp_value)
                        _plot_aggregated(runs_data, aggregation_range, label)

            plt.legend()
            plt.title('{} {} {}'.format(save_prefix, exp_technique, env))
            save_plot(os.path.join(save_dir, 'learning-curves_{}_{}_{}.png'.format(save_prefix, exp_technique, env)))
        except Exception as e:
            # Best-effort: report and continue. NOTE(review): the message assumes a
            # missing key, but any error raised above lands here.
            print(e)
            print(exp_technique, ' does not exist')
        finally:
            # Close even on failure so figures don't accumulate across environments.
            plt.close(fig)

def _iter_nested_leaves(tree: dict, depth: int, path: tuple = ()):
    """Yield ``(path, leaf)`` for every value exactly ``depth`` dict levels below ``tree``.

    ``path`` is the tuple of keys followed to reach ``leaf``. Leaves are visited
    in dict insertion order, exactly like the equivalent nested ``for`` loops.
    """
    if depth == 0:
        yield path, tree
        return
    for key, subtree in tree.items():
        yield from _iter_nested_leaves(subtree, depth - 1, path + (key,))


def _window_mean(values, aggregation_range: int):
    """Average ``values`` over consecutive, non-overlapping windows of ``aggregation_range`` entries.

    Raises ValueError (from ``reshape``) when the length is not a multiple of
    ``aggregation_range``.
    """
    return np.mean(np.asarray(values).reshape(-1, aggregation_range), axis=1)


def _select_best_config(tree: dict, depth: int, aggregation_range: int, is_steps: bool):
    """Return ``(path, runs_data)`` of the best-scoring leaf of ``tree``, or ``(None, None)``.

    A leaf's curve is its mean over axis 0, window-averaged with
    ``aggregation_range``. Its score (area under the curve) is the sum of that
    curve. Lower scores win when ``is_steps`` is true and higher scores otherwise.
    On ties the first leaf seen is kept. Leaves that raise ValueError while being
    scored are printed and skipped.
    """
    best_path, best_runs = None, None
    best_auc = float('+inf') if is_steps else float('-inf')
    for path, runs_data in _iter_nested_leaves(tree, depth):
        try:
            auc = np.sum(_window_mean(np.mean(runs_data, axis=0), aggregation_range))
        except ValueError as e:
            print(e)
            continue
        is_better = auc < best_auc if is_steps else auc > best_auc
        if is_better:
            best_path, best_runs, best_auc = path, runs_data, auc
    return best_path, best_runs


def _technique_legend_label(exp_technique: str) -> str:
    """Return the legend text used for ``exp_technique`` in best-learning-curve figures."""
    if exp_technique == 'epsilon-greedy':
        return r"$\varepsilon$-greedy"
    if exp_technique == 'ResMax':
        return '(resmax)'
    return '({})'.format(exp_technique)


def best_learning_curve_exp_algo(dic_data: dict, save_prefix: str = '', save_dir: str = '', epsilon_schedule: bool = False, aggregation_range: int = 1, is_steps: bool = False):
    """Plot, for each environment, the best learning curve of every exploration technique.

    For each technique, the hyper-parameter configuration with the best area
    under its window-averaged mean curve is selected. That curve is drawn with a
    +/- one-standard-error band. One figure per environment is saved to
    ``<save_dir>/best-learning-curves_<save_prefix>_<env>.png``.

    Args:
        dic_data: ``dic_data[env][exp_technique][step_size][exp_value]`` -> per-run
            series (runs along axis 0). For ``'epsilon-greedy'`` with
            ``epsilon_schedule`` the layout is
            ``[portion_decay][outside_value][step_size][exp_value]``.
        save_prefix: ``'returns'`` selects the return y-label. Anything else
            selects the steps-per-episode label. Also used in the file name.
        save_dir: directory the figures are written to.
        epsilon_schedule: whether ε-greedy data uses the 4-level scheduled layout.
        aggregation_range: window length used to average consecutive points.
        is_steps: when true a smaller area is better (steps per episode);
            otherwise a larger area is better (returns).
    """
    schedule_param_names = ('best_portion_decay:', 'best_outside_value :',
                            'best_step_size :', 'best_exp_value :')
    plain_param_names = ('best_step_size :', 'best_exp_value :')

    for env in dic_data:
        fig = plt.figure()
        try:
            for exp_technique in dic_data[env]:
                scheduled = exp_technique == 'epsilon-greedy' and epsilon_schedule
                # The number of printed parameter names is also the nesting depth.
                param_names = schedule_param_names if scheduled else plain_param_names
                try:
                    best_path, runs_data = _select_best_config(
                        dic_data[env][exp_technique], len(param_names),
                        aggregation_range, is_steps)
                    if best_path is None:
                        print(env, '  ', exp_technique, ': no configuration could be scored')
                        continue

                    mean = _window_mean(np.mean(runs_data, axis=0), aggregation_range)
                    # Shaded band is +/- one standard error of the mean across runs.
                    std = _window_mean(sem(runs_data, axis=0), aggregation_range)
                    x = list(range(1, len(mean) + 1))
                    # Fixed colour per technique (default cycle for unknown ones).
                    # The band reuses the line's colour so both always match.
                    line, = plt.plot(x, mean, label=_technique_legend_label(exp_technique),
                                     color=color_map.get(exp_technique))
                    plt.fill_between(x, mean - std, mean + std, alpha=0.1, color=line.get_color())

                    print(env, '  ', exp_technique)
                    for name, value in zip(param_names, best_path):
                        print(name, value)
                    print(save_prefix)
                except Exception as e:
                    # Best effort: one broken technique must not abort the figure.
                    print(e)
                    print(exp_technique, ' does not exist')

            plt.legend(prop={'size': 14}, frameon=False)
            plt.xlabel(r'Step ($10^4$)', fontsize=16)
            plt.gca().spines['top'].set_visible(False)
            plt.gca().spines['right'].set_visible(False)

            formatter = ticker.ScalarFormatter(useMathText=True)
            formatter.set_scientific(True)
            formatter.set_powerlimits((-1, 1))
            plt.gca().yaxis.set_major_formatter(formatter)

            if save_prefix == 'returns':
                plt.ylabel('Average Return', fontsize=16)
            else:
                plt.ylabel('Average Steps per Episode', fontsize=16)

            plt.tick_params(axis='y', labelsize=16)
            plt.tick_params(axis='x', labelsize=16)

            plt.title('{}'.format(env), fontsize=18)
            save_plot(os.path.join(save_dir, 'best-learning-curves_{}_{}.png'.format(save_prefix, env)))
        finally:
            # Release the figure even if styling or saving fails.
            plt.close(fig)

def sensitivity_plot(dic_data: dict, env_step_sizes: dict, exp_values_dict: dict = None, save_prefix: str = '', save_dir: str = ''):
    """Plot accumulated performance against the exploration parameter, one figure per environment.

    For every environment in ``env_step_sizes`` and each of its step sizes,
    every technique's runs are summed over axis 1 (one total per run). The mean
    +/- standard error across runs is drawn against the exploration value on a
    base-2 symlog x-axis. Figures are saved to
    ``<save_dir>/sensitivity_plot_<save_prefix>_<env>.pdf``.

    Args:
        dic_data: ``dic_data[env][exp_technique][step_size][exp_value]`` -> run data.
        env_step_sizes: maps each environment to an iterable of step sizes.
        exp_values_dict: maps each technique to the exploration values to plot.
            These are looked up in ``str(float(value))`` form. When None, every
            technique in ``dic_data[env]`` is plotted with all exploration values
            present in the data.
        save_prefix: ``'returns'`` selects the return y-label. Also used in the
            file name.
        save_dir: directory the figures are written to.
    """
    for env in env_step_sizes:
        fig = plt.figure()
        try:
            ax = plt.gca()
            try:
                ax.set_xscale('symlog', basex=2)
            except TypeError:
                # Matplotlib >= 3.5 removed ``basex`` in favour of ``base``.
                ax.set_xscale('symlog', base=2)

            # Without an explicit selection, plot every technique present for this env.
            techniques = dic_data[env] if exp_values_dict is None else exp_values_dict
            for exp_technique in techniques:
                for step_size in env_step_sizes[env]:
                    if exp_values_dict is None:
                        # Keys come straight from the data: sort numerically, look up verbatim.
                        keys = sorted(dic_data[env][exp_technique][step_size], key=float)
                    else:
                        keys = [str(v) for v in sorted(float(v) for v in exp_values_dict[exp_technique])]
                    if not keys:
                        continue

                    step_data = dic_data[env][exp_technique][step_size]
                    exp_values = [float(k) for k in keys]
                    # One accumulated total per run (sum over axis 1).
                    run_totals = [np.sum(step_data[k], axis=1) for k in keys]
                    means = np.array([np.mean(t, axis=0) for t in run_totals])
                    stds = np.array([sem(t, axis=0) for t in run_totals])

                    label = r"$\varepsilon$-greedy" if exp_technique == 'epsilon-greedy' else exp_technique
                    # The band reuses the line colour; unknown techniques fall back to the cycle.
                    line, = plt.plot(exp_values, means, color=color_map.get(exp_technique), label=label)
                    plt.fill_between(exp_values, means - stds, means + stds, alpha=0.1, color=line.get_color())

            plt.legend(prop={'size': 14}, frameon=False)
            plt.xlabel(r'$\epsilon, \eta, \tau$', fontsize=16)
            plt.gca().spines['top'].set_visible(False)
            plt.gca().spines['right'].set_visible(False)

            if save_prefix == 'returns':
                plt.ylabel('Accumulated Returns', fontsize=16)
            else:
                plt.ylabel('Accumulated Steps per Episode', fontsize=16)

            formatter = ticker.ScalarFormatter(useMathText=True)
            formatter.set_scientific(True)
            formatter.set_powerlimits((-1, 1))
            plt.gca().yaxis.set_major_formatter(formatter)

            plt.title('{}'.format(env), fontsize=18)
            plt.tick_params(axis='y', labelsize=16)
            plt.tick_params(axis='x', labelsize=16)
            plt.xticks([0, 1, 2**2, 2**4, 2**6, 2**8, 2**10, 2**12, 2**14, 2**16])
            save_plot(os.path.join(save_dir, 'sensitivity_plot_{}_{}.pdf'.format(save_prefix, env)))
        finally:
            plt.close(fig)


def bar_plot(dic_data: dict, env_step_sizes: dict, exp_values_dict: dict = None, save_prefix: str = '', bar_width: float = 0.25, epsilon_schedule: bool = False, save_dir: str = '', schedule_exp_value: str = '0.1'):
    """Draw grouped bar charts of accumulated performance, one figure per environment.

    Every technique in ``exp_values_dict`` contributes one bar per exploration
    value. The bar height is the mean over runs of each run's total (summed
    over axis 1), and the error bar is the standard deviation of those totals.
    Technique ``i`` is offset by ``i * bar_width``. The three x-ticks are
    labelled Low/Medium/High, so each technique is expected to list three
    exploration values. Figures are saved to
    ``<save_dir>/bar_plot_<save_prefix>_<env>.pdf``.

    Args:
        dic_data: ``dic_data[env][exp_technique][step_size][exp_value]`` -> run data.
            For ``'epsilon-greedy'`` with ``epsilon_schedule`` the layout is
            ``[portion_decay][outside_value][step_size][exp_value]``.
        env_step_sizes: maps each environment to the single step size to plot.
        exp_values_dict: maps each technique to its exploration values (required).
            For scheduled ε-greedy, each value is indexable as
            ``(portion_decay, outside_value)``. Other values are looked up in
            ``str(float(value))`` form.
        save_prefix: ``'returns'`` selects the return y-label. Also used in the
            file name.
        bar_width: width of each bar and offset between techniques.
        epsilon_schedule: whether ε-greedy data uses the scheduled layout.
        save_dir: directory the figures are written to.
        schedule_exp_value: ``exp_value`` key read for scheduled ε-greedy runs.
    """
    for env, step_size in env_step_sizes.items():
        techniques = list(exp_values_dict)
        fig = plt.figure()
        try:
            for exp_idx, exp_technique in enumerate(techniques):
                scheduled = exp_technique == 'epsilon-greedy' and epsilon_schedule
                if scheduled:
                    # (portion_decay, outside_value) pairs, kept in the given order.
                    exp_values = list(exp_values_dict[exp_technique])
                else:
                    exp_values = sorted(float(v) for v in exp_values_dict[exp_technique])
                if not exp_values:
                    continue

                technique_data = dic_data[env][exp_technique]
                means = np.zeros(len(exp_values))
                stds = np.zeros(len(exp_values))
                for idx, exp_value in enumerate(exp_values):
                    if scheduled:
                        runs_data = technique_data[str(exp_value[0])][str(exp_value[1])][step_size][schedule_exp_value]
                    else:
                        runs_data = technique_data[step_size][str(exp_value)]
                    run_totals = np.sum(runs_data, axis=1)
                    means[idx] = np.mean(run_totals, axis=0)
                    stds[idx] = np.std(run_totals, axis=0)

                # Bars of each technique are shifted right by its index in exp_values_dict.
                poses = np.arange(len(exp_values)) + bar_width * exp_idx
                label = r"$\varepsilon$-greedy" if exp_technique == 'epsilon-greedy' else exp_technique
                plt.bar(poses, means, yerr=stds, color=color_map.get(exp_technique),
                        width=bar_width, edgecolor='white', label=label)

            # Centre each tick under its group of len(techniques) bars
            # (equals ``bar_width`` for three techniques).
            group_centre = bar_width * max(len(techniques) - 1, 0) / 2
            plt.xticks([r + group_centre for r in range(3)], ['Low', 'Medium', 'High'])
            plt.legend(prop={'size': 14})
            plt.xlabel('Exploration Intensity', fontsize=16)
            if save_prefix == 'returns':
                plt.ylabel('Accumulated Returns', fontsize=16)
            else:
                plt.ylabel('Accumulated Steps per Episode', fontsize=16)

            plt.gca().spines['top'].set_visible(False)
            plt.gca().spines['right'].set_visible(False)
            plt.gca().spines['left'].set_visible(False)

            plt.title('{}'.format(env), fontsize=18)
            plt.tick_params(axis='x', labelsize=16)
            save_plot(os.path.join(save_dir, 'bar_plot_{}_{}.pdf'.format(save_prefix, env)))
        finally:
            plt.close(fig)

def single_run_learning_curve_exp_algo(dic_data: dict, exp_technique: str, save_prefix: str = '', save_dir: str = '', aggregation_range: int = 10000):
    """Plot every individual run of one exploration technique, one figure per configuration.

    For each environment and each ``(step_size, exp_value)`` of ``exp_technique``,
    every run's series is averaged over consecutive windows of
    ``aggregation_range`` entries. Each run is drawn as its own line, labelled
    with the run index. Figures are saved to
    ``<save_dir>/learning-curves_<save_prefix>_<exp_technique>_<step_size>_<exp_value>_<env>.png``.

    Environments without ``exp_technique`` are reported and skipped. Any other
    error in an environment is printed, and the remaining environments are still
    processed.
    """
    for env in dic_data:
        try:
            technique_data = dic_data[env][exp_technique]
        except KeyError:
            print(exp_technique, ' does not exist')
            continue

        try:
            for step_size, step_data in technique_data.items():
                for exp_value, runs in step_data.items():
                    fig = plt.figure()
                    try:
                        for idx, run_data in enumerate(runs):
                            aggregated = np.mean(np.asarray(run_data).reshape(-1, aggregation_range), axis=1)
                            # x must match the aggregated length, not the raw run length.
                            x = list(range(1, len(aggregated) + 1))
                            plt.plot(x, aggregated, marker_map[exp_technique], label='{}'.format(idx))

                        save_plot(os.path.join(save_dir, 'learning-curves_{}_{}_{}_{}_{}.png'.format(save_prefix, exp_technique, step_size, exp_value, env)))
                    finally:
                        plt.close(fig)
        except Exception as e:
            # Best effort: report the failure and continue with the next environment.
            print(e)
            print(exp_technique, ' could not be plotted for', env)

def sensitivity_plot_2(dic_data: dict, env_step_sizes: dict, exp_values_dict: dict = None, save_prefix: str = '', save_dir: str = ''):
    """Plot accumulated performance against every exploration value found in the data.

    Environments, techniques, step sizes and exploration values are all taken
    from ``dic_data`` itself. For each ``(technique, step_size)``, each run is
    summed over axis 1 and the mean +/- standard error across runs is drawn
    against the numeric exploration value on a base-2 symlog x-axis. Figures
    are saved to ``<save_dir>/sensitivity_plot_<save_prefix>_<env>.pdf``.

    Args:
        dic_data: ``dic_data[env][exp_technique][step_size][exp_value]`` -> run
            data. ``exp_value`` keys must be convertible with ``float``.
        env_step_sizes: unused; kept for signature compatibility.
        exp_values_dict: unused; kept for signature compatibility.
        save_prefix: ``'returns'`` selects the return y-label. Also used in the
            file name.
        save_dir: directory the figures are written to.
    """
    # Legend text is '(<name>)'. Only these techniques get a different display name.
    display_names = {'ResMax': 'resmax', 'epsilon-greedy': r"$\varepsilon$-greedy"}

    for env in dic_data:
        fig = plt.figure()
        try:
            ax = plt.gca()
            try:
                ax.set_xscale('symlog', basex=2)
            except TypeError:
                # Matplotlib >= 3.5 removed ``basex`` in favour of ``base``.
                ax.set_xscale('symlog', base=2)

            for exp_technique in dic_data[env]:
                label = '({})'.format(display_names.get(exp_technique, exp_technique))
                for step_data in dic_data[env][exp_technique].values():
                    if not step_data:
                        continue
                    # Sort keys numerically but look them up verbatim, so keys such
                    # as '1' (not '1.0') or non-string keys still resolve.
                    keys = sorted(step_data, key=float)
                    exp_values = [float(k) for k in keys]
                    run_totals = [np.sum(step_data[k], axis=1) for k in keys]
                    means = np.array([np.mean(t, axis=0) for t in run_totals])
                    stds = np.array([sem(t, axis=0) for t in run_totals])

                    # The band reuses the line colour; unknown techniques fall back to the cycle.
                    line, = plt.plot(exp_values, means, color=color_map.get(exp_technique), label=label)
                    plt.fill_between(exp_values, means - stds, means + stds, alpha=0.1, color=line.get_color())

            plt.legend(prop={'size': 14}, frameon=False)
            plt.xlabel(r'$\epsilon, \eta, \tau$', fontsize=16)
            plt.gca().spines['top'].set_visible(False)
            plt.gca().spines['right'].set_visible(False)

            if save_prefix == 'returns':
                plt.ylabel('Accumulated Returns', fontsize=16)
            else:
                plt.ylabel('Accumulated Steps per Episode', fontsize=16)
            formatter = ticker.ScalarFormatter(useMathText=True)
            formatter.set_scientific(True)
            formatter.set_powerlimits((-1, 1))
            plt.gca().yaxis.set_major_formatter(formatter)

            plt.title('{}'.format(env), fontsize=18)
            plt.tick_params(axis='y', labelsize=16)
            plt.tick_params(axis='x', labelsize=16)
            plt.xticks([0, 1, 2**2, 2**4, 2**6, 2**8, 2**10, 2**12, 2**14, 2**16, 2**18])

            save_plot(os.path.join(save_dir, 'sensitivity_plot_{}_{}.pdf'.format(save_prefix, env)))
        finally:
            plt.close(fig)
