import numpy as np
import matplotlib.pyplot as plt
import os

from scipy.stats import multivariate_normal
from scipy import interpolate
from matplotlib.gridspec import GridSpec
from util.SaveAndLoad import load_and_sort_all_models_linmoe

# Default palette of matplotlib colour names used by the plotting helpers below.
# The base list is repeated 50 times so that index-based lookups (colors[i]) keep
# working for large indices; colours therefore repeat. Note that 'red' appears
# twice in the base list.
colors = ['blue', 'red', 'green', 'purple' ,'black', 'gray', 'rosybrown', 'maroon', 'red', 'coral', 'chocolate',
          'darkorange',  'olive', 'lightgreen', 'limegreen', 'lightseagreen', 'cyan', 'lightblue', 'deepskyblue', 'lightslategray',
           'blueviolet', 'plum', 'magenta', 'palevioletred']*50

def load_mean_dists(path):
    """Load the ``mean_dists`` array stored under ``path``.

    ``path`` is used as a plain string prefix (no separator is inserted), so it
    has to end with a path separator already. ``mean_dists.npy`` is tried
    first; if that file cannot be opened or cannot be read without pickle
    support, the ``arr_0`` entry of ``mean_dists.npz`` is returned instead.
    """
    try:
        return np.load(path + 'mean_dists.npy')
    except (OSError, ValueError):
        # OSError: file missing/unreadable, ValueError: pickled content.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return np.load(path + 'mean_dists.npz', allow_pickle=True)['arr_0']

def load_cmp_mean_task_rewards(path):
    """Load the ``comp_mean_task_rewards_c_e`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'comp_mean_task_rewards_c_e.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'comp_mean_task_rewards_c_e.npz', allow_pickle=True)['arr_0']

def load_cmp_entropies(path):
    """Load the ``comp_entropoies_c_e`` array stored under ``path``.

    The (misspelled) file name is the one used on disk and is kept as is.
    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'comp_entropoies_c_e.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'comp_entropoies_c_e.npz', allow_pickle=True)['arr_0']
def load_cmp_mean_log_resps(path):
    """Load the ``comp_mean_log_resps_c_e`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'comp_mean_log_resps_c_e.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'comp_mean_log_resps_c_e.npz', allow_pickle=True)['arr_0']
def load_ctxt_entropies(path):
    """Load the ``cond_ctxt_entropies_e_c`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'cond_ctxt_entropies_e_c.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'cond_ctxt_entropies_e_c.npz', allow_pickle=True)['arr_0']
def load_ctxt_mean_log_resps(path):
    """Load the ``cond_ctxt_mean_log_resps_c_e`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'cond_ctxt_mean_log_resps_c_e.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'cond_ctxt_mean_log_resps_c_e.npz', allow_pickle=True)['arr_0']
def load_weight_entropies(path):
    """Load the ``marginal_entropies_e`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'marginal_entropies_e.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'marginal_entropies_e.npz', allow_pickle=True)['arr_0']
def load_weight_rewards(path):
    """Load the ``marginal_rewards_e_c`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'marginal_rewards_e_c.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'marginal_rewards_e_c.npz', allow_pickle=True)['arr_0']
def load_weight_weights(path):
    """Load the ``marginal_weights_c_e`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'marginal_weights_c_e.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'marginal_weights_c_e.npz', allow_pickle=True)['arr_0']
def load_test_entropy(path):
    """Load the ``test_mixture_model_entropy`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'test_mixture_model_entropy.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'test_mixture_model_entropy.npz', allow_pickle=True)['arr_0']
def load_test_reward(path):
    """Load the ``test_reward`` array stored under ``path``.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is returned instead.
    """
    try:
        return np.load(path + 'test_reward.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        return np.load(path + 'test_reward.npz', allow_pickle=True)['arr_0']

def load_x_n_samples(path):
    """Return the cumulative sum of the ``n_ep_samples_executed`` array.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is used instead.
    """
    try:
        n_samples_executed = np.load(path + 'n_ep_samples_executed.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        n_samples_executed = np.load(path + 'n_ep_samples_executed.npz', allow_pickle=True)['arr_0']
    return np.cumsum(n_samples_executed)

def load_x_n_samples_t(path):
    """Return the cumulative sum of the ``n_env_interacts`` array.

    ``path`` is a plain string prefix and must end with a path separator.
    The ``.npy`` file is preferred; when it cannot be opened or needs pickle
    support, the ``arr_0`` entry of the ``.npz`` file is used instead.
    """
    try:
        n_samples_executed = np.load(path + 'n_env_interacts.npy')
    except (OSError, ValueError):
        # Only fall back for missing/unreadable files, not for every exception.
        n_samples_executed = np.load(path + 'n_env_interacts.npz', allow_pickle=True)['arr_0']
    return np.cumsum(n_samples_executed)

def plot_time_line(path2exp, load_func, colors, fig=None, legend_names=None, single_exp_name=None, it=None):
    """Plot mean +/- 2 std (over repetitions) of a logged quantity per experiment.

    For every experiment directory in ``path2exp`` (or only ``single_exp_name``
    if given) the data of every repetition found in ``<exp>/log/`` is loaded
    with ``load_func`` from ``<rep>/0/`` or, when ``it`` is given, from
    ``<rep>/data_p_iteration/0/it_<it>/``. 2-D data is reduced to its first
    column. One line plus a shaded band is drawn per experiment, using
    ``colors[i]`` for the i-th experiment.

    :param path2exp: directory containing the experiment directories
    :param load_func: callable taking the repetition path and returning an array
    :param colors: sequence of colours, indexed by experiment position
    :param fig: figure to draw into; a new one is created if None
    :param legend_names: list that is extended in place with the legend entries
    :param single_exp_name: restrict plotting to this experiment directory
    :param it: iteration sub-directory to read from (see above)
    :return: tuple (fig, legend_names)
    """
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)

    if legend_names is None:
        legend_names = []
    exps_list = [single_exp_name] if single_exp_name is not None else os.listdir(path2exp)
    for i, c_exp_name in enumerate(exps_list):
        c_exp_path = path2exp + '/' + c_exp_name
        # add legend names -> hyperparams from cw are added after '__'
        legend_names.append(c_exp_path.split('/')[-1].split('__')[-1])
        c_exp_path += '/log/'
        data = []
        for rep in os.listdir(c_exp_path):
            if it is None:
                c_rep_path = c_exp_path + rep + '/0/'
            else:
                c_rep_path = c_exp_path + rep + '/data_p_iteration/0/it_' + str(it) + '/'
            loaded_data = load_func(c_rep_path)
            if loaded_data.ndim == 2:
                loaded_data = loaded_data[:, 0]
            data.append(loaded_data)

        data = np.stack(data)
        mean_data = np.mean(data, axis=0)
        std_data = np.std(data, axis=0)
        x_axis = range(mean_data.shape[0])
        plt.plot(x_axis, mean_data, color=colors[i])
        plt.fill_between(x_axis, mean_data - 2 * std_data, mean_data + 2 * std_data, alpha=0.2,
                         color=colors[i])
    # Explicit True: a bare plt.grid() toggles, which would switch the grid off
    # again when this function is called a second time on the same figure.
    plt.grid(True)
    plt.legend(legend_names)
    return fig, legend_names


def plot_rewards_on_n_samples_no_cw_path(path2exp, iterations_to_consider, mode='env_interacts', fig=None, legend_names=None, grid = True):
    """Plot the test reward of a single run over the amount of collected data.

    The x axis is the cumulative number of environment interactions when
    ``mode == 'env_interacts'`` and the cumulative number of executed episode
    samples otherwise. ``path2exp`` is handed directly to the loader helpers.
    The last entry of ``iterations_to_consider`` is dropped before indexing.

    :return: tuple (fig, legend_names); ``legend_names`` is returned unchanged
    """
    fig = plt.figure() if fig is None else fig
    plt.figure(fig.number)
    legend_names = [] if legend_names is None else legend_names

    considered_its = iterations_to_consider[:-1]

    raw_ep_counts = load_x_n_samples(path2exp)
    raw_env_counts = load_x_n_samples_t(path2exp)

    # The sample counters were saved one iteration too early (the model is saved
    # after the update), so shift them by one. No sampling happens in the last
    # iteration, hence the final count is simply repeated.
    ep_counts = np.zeros(raw_ep_counts.shape)
    ep_counts[:-1] = raw_ep_counts[1:]
    ep_counts[-1] = raw_ep_counts[-1]

    env_counts = np.zeros(raw_env_counts.shape)
    env_counts[:-1] = raw_env_counts[1:]
    env_counts[-1] = raw_env_counts[-1]

    # Considered iterations whose sample count already equals the final count.
    redundant = np.where(ep_counts[considered_its] == ep_counts[-1])[0]
    # Keep the last of them on the x axes ...
    drop_from_x = redundant[:-1]
    ep_counts = np.delete(ep_counts[considered_its], drop_from_x)
    env_counts = np.delete(env_counts[considered_its], drop_from_x)

    # ... while the reward array carries one extra final entry (reward after
    # clean-up), so all redundant indices are removed there.
    rewards = np.delete(load_test_reward(path2exp), redundant)

    x_values = env_counts if mode == 'env_interacts' else ep_counts
    plt.plot(x_values, rewards)
    if grid:
        plt.grid(True)

    return fig, legend_names

def plot_rewards_on_n_samples(path2exp, iterations_to_consider, colors, fig=None, legend_names=None, grid = True):
    """Plot mean test reward over the number of executed samples per experiment.

    For every experiment directory in ``path2exp`` the repetitions found in
    ``<exp>/log/<rep>/0/`` are loaded, their reward curves are linearly
    interpolated onto a common sample grid and the mean is drawn together with
    a band of +/- 2 standard errors (std / sqrt(number of repetitions)).
    The curves are sub-sampled (every 50th grid point plus the last one)
    before plotting.

    :param path2exp: directory containing the experiment directories
    :param iterations_to_consider: iteration indices; the last one is dropped
    :param colors: sequence of colours, indexed by experiment position
    :param fig: figure to draw into; a new one is created if None
    :param legend_names: list that is extended in place with the legend entries
    :param grid: whether to show the grid
    :return: tuple (fig, legend_names)
    """
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)

    if legend_names is None:
        legend_names = []
    exps_list = os.listdir(path2exp)
    iterations_to_consider = iterations_to_consider[:-1]    # the last one is the same as the one before but only with deleted comps
    skip_samples = 50
    last_x = None   # right x-limit; stays None when there is no experiment

    for i, c_exp_name in enumerate(exps_list):
        c_exp_path = path2exp + '/' + c_exp_name
        legend_names.append(c_exp_path.split('/')[-1].split('__')[-1])
        c_exp_path += '/log/'
        n_samples_reps = []
        rewards_reps = []
        for rep in os.listdir(c_exp_path):
            c_rep_path = c_exp_path + rep + '/0/'
            n_samples_tmp = load_x_n_samples(c_rep_path)
            n_samples = np.zeros(n_samples_tmp.shape)
            n_samples[:-1] = n_samples_tmp[1:] # mistakenly saved the number of samples too early, but model was saved after one update iteration.
            n_samples[-1] = n_samples_tmp[-1] # we do not sample in the last iterations -> number of samples stays same
            red_its = np.where(n_samples[iterations_to_consider] == n_samples[-1])[0]
            n_samples = np.delete(n_samples[iterations_to_consider], red_its[:-1])  # delete all the entries which have same number of n_samples
            # test_reward has as last entry the reward from after cleaning up
            # -> delete the rest (with red_its) -> we want to use the reward of the cleaned up model
            test_rewards = np.delete(load_test_reward(c_rep_path), red_its)
            n_samples_reps.append(n_samples)
            rewards_reps.append(test_rewards)

        # Interpolation is only valid where every repetition has data, i.e. on
        # [largest minimum, smallest maximum]. A boolean mask also works when no
        # grid point lies beyond the smallest maximum (the former index lookup
        # raised IndexError in that case).
        min_of_maxs = np.min([np.max(rep_samples) for rep_samples in n_samples_reps])
        max_of_mins = np.max([np.min(rep_samples) for rep_samples in n_samples_reps])
        # concatenate (not stack) so repetitions may differ in length
        all_n_samples = np.unique(np.concatenate(n_samples_reps))
        in_common_range = (all_n_samples >= max_of_mins) & (all_n_samples <= min_of_maxs)
        all_n_samples = all_n_samples[in_common_range]

        all_interpolated_rews = [interpolate.interp1d(c_n_samples, c_rewards)(all_n_samples)
                                 for c_n_samples, c_rewards in zip(n_samples_reps, rewards_reps)]

        data = np.stack(all_interpolated_rews)
        mean_data = np.mean(data, axis=0)
        std_data = np.std(data, axis=0)
        all_n_samples = np.concatenate((all_n_samples[::skip_samples], all_n_samples[-1].reshape(-1)))
        mean_data = np.concatenate((mean_data[::skip_samples], mean_data[-1].reshape(-1)))
        std_data = np.concatenate((std_data[::skip_samples], std_data[-1].reshape(-1)))/np.sqrt(data.shape[0])
        plt.plot(all_n_samples, mean_data, color=colors[i])
        plt.fill_between(all_n_samples, mean_data - 2 * std_data, mean_data + 2 * std_data, alpha=0.2, color=colors[i])
        last_x = all_n_samples[-1]

    if grid:
        # explicit True: a bare plt.grid() toggles the grid on repeated calls
        plt.grid(True)
    plt.legend(legend_names)
    if last_x is not None:
        # right limit is taken from the last plotted experiment
        min_xlim = plt.xlim()[0]
        plt.xlim((min_xlim, last_x))
    return fig, legend_names


def plot_mean_pi_s_o_time_line(path2exp, colors, linestyle=None, fig=None, legend_names=None):
    """Plot the mean context entropy over iterations for every experiment.

    For each repetition in ``<exp>/log/<rep>/0/`` the array returned by
    ``load_ctxt_entropies`` is averaged over its first axis; the first entry
    of the last axis is then used. Mean +/- 2 std over repetitions is drawn.

    :param path2exp: directory containing the experiment directories
    :param colors: sequence of colours, indexed by experiment position
    :param linestyle: optional sequence of line styles, indexed like ``colors``;
        when None the matplotlib default line style is used
    :param fig: figure to draw into; a new one is created if None
    :param legend_names: list that is extended in place with the legend entries
    :return: tuple (fig, legend_names)
    """
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)

    if legend_names is None:
        legend_names = []
    exps_list = os.listdir(path2exp)
    for i, c_exp_name in enumerate(exps_list):
        c_exp_path = path2exp + '/' + c_exp_name
        # add legend names -> hyperparams from cw are added after '__'
        legend_names.append(c_exp_path.split('/')[-1].split('__')[-1])
        c_exp_path += '/log/'
        mean_entropies_pi_s_o = []
        for rep in os.listdir(c_exp_path):
            c_rep_path = c_exp_path + rep + '/0/'
            mean_entropies_pi_s_o.append(np.mean(load_ctxt_entropies(c_rep_path), axis=0))
        mean_entropies_pi_s_o = np.stack(mean_entropies_pi_s_o)[:, :, 0]
        std_mean_entropies_pi_s_o = np.std(mean_entropies_pi_s_o, axis=0)
        mean_entropies_pi_s_o = np.mean(mean_entropies_pi_s_o, axis=0)

        line_kwargs = {'color': colors[i]}
        if linestyle is not None:
            # the default linestyle=None used to be indexed and raised TypeError
            line_kwargs['linestyle'] = linestyle[i]
        x_axis = range(mean_entropies_pi_s_o.shape[0])
        plt.plot(x_axis, mean_entropies_pi_s_o, **line_kwargs)
        plt.fill_between(x_axis, mean_entropies_pi_s_o - 2 * std_mean_entropies_pi_s_o,
                         mean_entropies_pi_s_o + 2 * std_mean_entropies_pi_s_o, alpha=0.2,
                         color=colors[i])
    # explicit True: a bare plt.grid() toggles the grid on repeated calls
    plt.grid(True)
    plt.legend(legend_names)
    return fig, legend_names

def plot_mean_pi_s_o_diff(path2exp, colors, it_lookback= 50, it_lookback_rewards = 10, it_thresh=99, fig=None,
                          legend_names=None, zero_resp=False, grid=True,xlim=None, plt_exp_mixture_entr=True):
    """Scatter final reward against context entropy for every experiment.

    Per experiment, the context entropies (``load_ctxt_entropies``, averaged
    over the first axis, first entry of the last axis) and the test rewards
    are averaged over repetitions. One error-bar point is drawn per
    experiment: x is the mean entropy over the ``it_lookback`` iterations
    that end ``it_thresh`` iterations before the last one (x error: 2 * mean
    std in that window), y is the mean reward over the last
    ``it_lookback_rewards`` iterations.

    With ``plt_exp_mixture_entr`` the point goes into subplot 211 and the
    mixture-model entropy curve (mean +/- 2 std) into subplot 212. Grid,
    x-limits and legend are applied to the axes that are current at the end.

    :param zero_resp: use marker 'o' instead of 'x'
    :return: tuple (fig, legend_names)
    """
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)
    exps_list = os.listdir(path2exp)
    if legend_names is None:
        legend_names = []
    marker = 'o' if zero_resp else 'x'
    # A slice end of -0 would select nothing, so use None ("until the end")
    # when no iterations are to be skipped.
    window = slice(-(it_lookback + it_thresh), -it_thresh if it_thresh else None)
    for i, c_exp_name in enumerate(exps_list):
        c_exp_path = path2exp + '/' + c_exp_name
        # add legend names -> hyperparams from cw are added after '__'
        legend_names.append(c_exp_path.split('/')[-1].split('__')[-1])
        c_exp_path += '/log/'
        mean_entropies_pi_s_o = []
        mean_rewards = []
        exp_mixt_entr = []
        for rep in os.listdir(c_exp_path):
            c_rep_path = c_exp_path + rep + '/0/'
            mean_entropies_pi_s_o.append(np.mean(load_ctxt_entropies(c_rep_path), axis=0))
            mean_rewards.append(load_test_reward(c_rep_path))
            if plt_exp_mixture_entr:
                exp_mixt_entr.append(load_test_entropy(c_rep_path))
        mean_entropies_pi_s_o = np.stack(mean_entropies_pi_s_o)[:, :, 0]
        mean_rewards = np.mean(np.stack(mean_rewards), axis=0)
        std_mean_entropies_pi_s_o = np.std(mean_entropies_pi_s_o, axis=0)
        mean_entropies_pi_s_o = np.mean(mean_entropies_pi_s_o, axis=0)

        mean_last_its = np.mean(mean_entropies_pi_s_o[window])
        mean_std_last_its = np.mean(std_mean_entropies_pi_s_o[window])
        mean_rewards_last_its = np.mean(mean_rewards[-it_lookback_rewards:])
        # NOTE(review): the reward spread over repetitions is not drawn (no yerr).

        if plt_exp_mixture_entr:
            plt.subplot(211)
        plt.errorbar(x=mean_last_its, y=mean_rewards_last_its, xerr=mean_std_last_its * 2,
                     fmt=marker, color=colors[i])
        if plt_exp_mixture_entr:
            exp_mixt_entr = np.stack(exp_mixt_entr)
            mean_exp_mixt_entr = np.mean(exp_mixt_entr, axis=0)
            std_exp_mixt_entr = np.std(exp_mixt_entr, axis=0)
            plt.subplot(212)
            x_axis = range(mean_exp_mixt_entr.shape[0])
            plt.plot(x_axis, mean_exp_mixt_entr, color=colors[i])
            # x has to be the iteration range, not the scalar number of iterations
            plt.fill_between(x_axis, mean_exp_mixt_entr - 2 * std_exp_mixt_entr,
                             mean_exp_mixt_entr + 2 * std_exp_mixt_entr,
                             color=colors[i], alpha=0.2)
    if grid:
        # explicit True: a bare plt.grid() toggles the grid on repeated calls
        plt.grid(True)
    if xlim is not None:
        plt.xlim(xlim[0], xlim[1])
    plt.legend(legend_names)
    return fig, legend_names

def plot_time_line_cmp_wise(path2exp_with_rep, load_func, colors, fig=None, cmp_indices=None):
    """Plot one curve per component for the data returned by ``load_func``.

    The loaded array is indexed as ``data[cmp_idx, :, 0]``, i.e. the first
    axis selects the component and the first entry of the last axis is
    plotted along the second axis.

    :param path2exp_with_rep: path handed to ``load_func``
    :param load_func: callable returning the array to plot
    :param colors: sequence of colours, indexed by component index
    :param fig: figure to draw into; a new one is created if None
    :param cmp_indices: component indices to plot; all components if None
    :return: the figure
    """
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)
    data = load_func(path2exp_with_rep)
    iterate_over = range(data.shape[0]) if cmp_indices is None else cmp_indices
    legend_names = []
    for cmp_idx in iterate_over:
        legend_names.append(str(cmp_idx))
        cmp_curve = data[cmp_idx, :, 0]
        plt.plot(range(cmp_curve.shape[0]), cmp_curve, color=colors[cmp_idx])
    # explicit True: a bare plt.grid() toggles the grid on repeated calls
    plt.grid(True)
    plt.legend(legend_names)
    return fig

def plot_cmp_ctxts_4d(model, fig, colors, cmp_indices):
    """Draw 4-D context components as two 2-D Gaussians each.

    Dimensions 0-1 of every component (mean and covariance block) are drawn
    into the first axes, dimensions 2-3 into the second axes.

    :param model: object providing ``num_components`` and ``ctxt_components``
        (each with ``mean`` and ``covar``)
    :param fig: figure to draw into; a new one with two stacked subplots is
        created if None. An existing figure's first two axes are reused;
        missing ones are added.
    :param colors: sequence of colours indexed by component index, or None
    :param cmp_indices: component indices to draw; all components if None
    :return: the figure
    """
    if fig is None:
        fig = plt.figure()
    # Previously the axes were only defined for a freshly created figure, so
    # passing an existing figure raised NameError.
    existing_axes = fig.axes
    fig_axes_1 = existing_axes[0] if len(existing_axes) > 0 else fig.add_subplot(211)
    fig_axes_2 = existing_axes[1] if len(existing_axes) > 1 else fig.add_subplot(212)
    if cmp_indices is None:
        cmp_indices = range(model.num_components)
    for i in cmp_indices:
        c_comp = model.ctxt_components[i]
        color = colors[i] if colors is not None else None
        draw_2d_gaussian(c_comp.mean[:2], c_comp.covar[:2, :2], fig_axes_1, color=color)
        draw_2d_gaussian(c_comp.mean[2:], c_comp.covar[2:, 2:], fig_axes_2, color=color)
    return fig

def plot_cmp_ctxts(model, fig=None, colors=None, cmp_indices=None, ctxt_range_min=-6, ctxt_range_max=6, swap_axis=False):
    """Draw the context components of ``model`` as Gaussians.

    Dispatches on ``model.ctxt_dim``: 4 -> ``plot_cmp_ctxts_4d``,
    2 -> ``draw_2d_gaussian``, 1 -> ``draw_1d_gaussian`` (evaluated on
    ``[ctxt_range_min, ctxt_range_max]``).

    :param model: object providing ``ctxt_dim``, ``num_components`` and
        ``ctxt_components`` (each with ``mean`` and ``covar``)
    :param fig: figure to draw into; a new one is created if None
    :param colors: sequence of colours indexed by component index, or None
    :param cmp_indices: component indices to draw; all components if None
    :param swap_axis: forwarded to ``draw_1d_gaussian``
    :return: the figure that was drawn into
    :raises ValueError: for any other context dimension (raised while drawing
        the first component)
    """
    ctxt_dim = model.ctxt_dim
    if ctxt_dim == 4:
        # Return the helper's figure: it creates one when ``fig`` is None and
        # that figure used to be lost (the function returned None).
        return plot_cmp_ctxts_4d(model, fig, colors, cmp_indices)

    if fig is None:
        fig = plt.figure()
        plt.grid(True)
    plt.figure(fig.number)
    if cmp_indices is None:
        cmp_indices = range(model.num_components)

    for i in cmp_indices:
        c_comp = model.ctxt_components[i]
        color = colors[i] if colors is not None else None
        if ctxt_dim == 2:
            draw_2d_gaussian(c_comp.mean, c_comp.covar, fig, color=color)
        elif ctxt_dim == 1:
            draw_1d_gaussian(c_comp.mean, sigma=c_comp.covar, range_min=ctxt_range_min, range_max=ctxt_range_max, fig=fig,
                             swap_axis=swap_axis, color=color)
        else:
            raise ValueError("Cannot plot more than 2 dimensions")
    return fig

def plot_ctxt_heat_map(ctxts, success_rates, n_samples_per_row=None, fig=None):
    """Draw success rates over a 2-D context grid as a heat map.

    ``success_rates`` is reshaped column-wise (Fortran order) to
    ``(n_samples_per_row, n_ctxts / n_samples_per_row)``. Tick labels are the
    rounded unique context values (second dimension reversed); every second
    label is hidden.

    :param ctxts: array with the contexts in its rows; must have 2 columns
    :param success_rates: one success rate per context (list or array)
    :param n_samples_per_row: first dimension of the reshaped grid; defaults
        to the number of contexts
    :param fig: figure to draw into (its first axes are used; axes are added
        if it has none); a new figure is created if None
    :return: the figure
    :raises ValueError: if the contexts are not 2-dimensional
    """
    n_ctxts = ctxts.shape[0]
    if fig is None:
        fig = plt.figure()
    if n_samples_per_row is None:
        n_samples_per_row = n_ctxts
    n_samples_per_column = int(n_ctxts/n_samples_per_row)
    if ctxts.shape[1] != 2:
        raise ValueError("Can only plot 2 dim ctxt heat map")
    success_rates = np.asarray(success_rates)

    # ``fig.axes`` is a list; pick an actual Axes object from it (the list
    # itself used to be treated as the axes when it was non-empty).
    existing_axes = fig.axes
    fig_axes = existing_axes[0] if existing_axes else fig.add_subplot(111)

    x_axis = np.round(np.unique(ctxts[:, 0]), 2)
    y_axis = np.round(np.unique(ctxts[:, 1])[::-1], 2)
    success_rates = np.reshape(success_rates, (n_samples_per_row, n_samples_per_column), 'F') # column-wise reshaping!
    fig_axes.imshow(success_rates, cmap='Blues')
    # the colour bar is taken from the pcolor mesh, which is limited to [0, 1]
    cbar = fig_axes.pcolor(success_rates, vmin=0.0, vmax=1.0)
    fig.colorbar(cbar)
    fig_axes.set_xticks(np.arange(x_axis.shape[0]))
    fig_axes.set_yticks(np.arange(y_axis.shape[0]))
    fig_axes.set_xticklabels(x_axis)
    fig_axes.set_yticklabels(y_axis)
    # show only every second tick label
    for tick_labels in (fig_axes.get_xticklabels(), fig_axes.get_yticklabels()):
        for k, label in enumerate(tick_labels):
            if k % 2 != 0:
                label.set_visible(False)
    fig_axes.set_title("Success Rates in Context Space")
    fig.tight_layout()
    return fig

def plot_joint_entropy(path2exp_with_rep, fig=None, **kwargs):
    """Plot the joint entropy of every saved model of one repetition.

    The models are loaded with ``load_and_sort_all_models_linmoe`` and
    ``model.joint_entropy(num_samples=7)`` is evaluated for each of the
    sorted models. Progress is printed to stdout.

    :param path2exp_with_rep: path handed to the model loader
    :param fig: figure to draw into; a new one is created if None
    :param kwargs: forwarded to ``plt.plot``
    :return: the figure
    """
    if fig is None:
        fig = plt.figure()
    _, sorted_models, _ = load_and_sort_all_models_linmoe(path2exp_with_rep)
    joint_entropies = np.zeros(len(sorted_models))
    print('number of models:', len(sorted_models))
    for it, model in enumerate(sorted_models):
        print('it:', it)
        joint_entropies[it] = model.joint_entropy(num_samples=7)
    # Make sure the curve ends up in the returned figure, also when a figure
    # was passed in that is not the current one.
    plt.figure(fig.number)
    plt.plot(joint_entropies, **kwargs)
    return fig

def draw_2d_gaussian(mu, sigma, fig, plt_std = 2, *args, **kwargs):
    """Draw the mean and a ``plt_std``-sigma ellipse of a 2-D Gaussian.

    :param mu: mean, indexable with at least two entries
    :param sigma: 2x2 covariance matrix (not modified)
    :param fig: pyplot figure (has ``number``) or an axes object; it is made
        the current drawing target
    :param plt_std: number of standard deviations spanned by the half axes
    :param args: forwarded to both ``plt.plot`` calls
    :param kwargs: forwarded to both ``plt.plot`` calls
    """
    try:
        plt.figure(fig.number)
    except Exception:
        # not a pyplot figure -> treat it as axes
        plt.sca(fig)

    # Work on a copy: the caller often passes (a view into) a model covariance,
    # which must not be altered by plotting.
    sigma = np.array(sigma, dtype=float)
    sigma[np.abs(sigma) < 1e-10] = 0  # remove numerical noise before the decomposition
    eigvals, eigvecs = np.linalg.eig(sigma)

    plt.plot(mu[0:1], mu[1:2], marker="x", *args, **kwargs)

    # Half axes along the eigenvectors; mapping the scaled unit circle with the
    # eigenvector matrix is valid for any sign convention of the eigenvectors.
    half_axes = plt_std * np.sqrt(eigvals)
    angles = np.linspace(0, 2 * np.pi, num=20)
    unit_circle = np.array([np.cos(angles), np.sin(angles)])
    r_ellipse = (eigvecs @ (half_axes[:, None] * unit_circle)).T
    plt.plot(mu[0] + r_ellipse[:, 0], mu[1] + r_ellipse[:, 1], *args, **kwargs)

def draw_context_bounds_2d(context_range_bounds, fig):
    """Outline the rectangular 2-D context range with thin red lines.

    ``context_range_bounds[0]`` holds the lower (x, y) corner and
    ``context_range_bounds[1]`` the upper one. ``fig`` may be a pyplot figure
    (has ``number``) or an axes object.
    """
    try:
        plt.figure(fig.number)
    except Exception:
        plt.sca(fig)

    lower, upper = context_range_bounds[0], context_range_bounds[1]
    x_lo, y_lo = lower[0], lower[1]
    x_hi, y_hi = upper[0], upper[1]

    # bottom, top, left and right edge (same drawing order as before)
    edges = (
        ([x_lo, x_hi], [y_lo, y_lo]),
        ([x_lo, x_hi], [y_hi, y_hi]),
        ([x_lo, x_lo], [y_lo, y_hi]),
        ([x_hi, x_hi], [y_lo, y_hi]),
    )
    for xs, ys in edges:
        plt.plot(xs, ys, 'r-', linewidth=1)

def draw_1d_gaussian(mu, sigma, range_min, range_max, fig, swap_axis=False, n_points = 1000, grid=True, *args, **kwargs):
    """Draw the density of a 1-D Gaussian on ``[range_min, range_max]``.

    :param mu: mean of the Gaussian
    :param sigma: covariance, passed as ``cov`` to ``multivariate_normal.pdf``
    :param fig: pyplot figure to draw into
    :param swap_axis: if True the density (shifted by -5) is drawn along the
        x axis and the evaluation points along the y axis
    :param n_points: number of evaluation points
    :param grid: whether to show the grid
    :param args: forwarded to ``plt.plot``
    :param kwargs: forwarded to ``plt.plot``
    """
    points = np.linspace(range_min, range_max, n_points)
    densities = multivariate_normal.pdf(points, mean=mu, cov=sigma)
    plt.figure(fig.number)
    if swap_axis:
        x_axis, y_axis = densities - 5, points
    else:
        x_axis, y_axis = points, densities
    plt.plot(x_axis, y_axis, *args, **kwargs)
    if grid:
        # This function is called once per component; a bare plt.grid() toggles
        # and would switch the grid off again on every second call.
        plt.grid(True)

def plot_tip_trajs(tip_trajs_list, fig=None, idx=None, use_colors=None):
    """Plot the first two columns of tip trajectories into two stacked axes.

    A new figure gets two data axes plus a narrow third axes on the right
    that holds the legend. Each entry of ``tip_trajs_list`` is either a list
    of trajectories (all drawn in colour ``k`` with label ``k``) or a single
    trajectory array (colour/label taken from ``idx[k]`` if ``idx`` is given,
    otherwise from ``k``; a label is only used once).

    :param tip_trajs_list: sequence of trajectories or lists of trajectories
    :param fig: figure with the layout described above; created if None
    :param idx: optional per-entry indices used for colour and label
    :param use_colors: colour sequence; defaults to the module palette
    :return: the figure
    """
    if fig is None:
        fig = plt.figure()
        gs = GridSpec(2, 9, figure=fig)
        fig.add_subplot(gs[0, :-1])
        plt.grid()
        fig.add_subplot(gs[1, :-1])
        plt.grid()
        fig.add_subplot(gs[:, -1])
    if use_colors is None:
        use_colors = colors

    plt.figure(fig.number)
    c_axes = fig.axes
    n_elems = 0
    used_labels = []
    for k, trajs in enumerate(tip_trajs_list):
        if isinstance(trajs, list):
            n_elems = len(trajs)
            for c_cmp_traj in trajs:
                c_axes[0].plot(c_cmp_traj[:, 0], color=use_colors[k], alpha=0.5, label=str(k))
                c_axes[1].plot(c_cmp_traj[:, 1], color=use_colors[k], alpha=0.5, label=str(k))
        else:
            if idx is not None:
                c_label = str(idx[k])
                c_color = use_colors[int(idx[k])]
            else:
                c_label = str(k)
                c_color = use_colors[k]
            if c_label in used_labels:
                c_label = None
            else:
                used_labels.append(c_label)
            n_elems = 1
            c_axes[0].plot(trajs[:, 0], color=c_color, alpha=0.5, label=c_label)
            c_axes[1].plot(trajs[:, 1], color=c_color, alpha=0.5, label=c_label)
    # Collect the handles from a data axes: axes 2 is the (empty) legend panel
    # here, so reading from it always produced an empty legend.
    handles, labels = c_axes[0].get_legend_handles_labels()
    # One legend entry per group; guard against a zero step for empty input.
    # NOTE(review): assumes all groups have the length of the last one.
    step = max(n_elems, 1)
    c_axes[-1].legend(handles[::step], labels[::step])
    return fig

def plot_single_dim_ball_traj(ball_trajs_list, fig=None, idx=None, use_colors=None, dim=2):
    """Plot column ``dim`` of ball trajectories into the first axes of ``fig``.

    Each entry of ``ball_trajs_list`` is either a list of trajectories (all
    drawn in colour ``k`` with label ``k``) or a single trajectory array
    (colour/label taken from ``idx[k]`` if ``idx`` is given, otherwise from
    ``k``; a label is only used once). When ``idx`` is given, the colours
    'olive' and 'cyan' are replaced by 'blue' and 'green'.

    :param ball_trajs_list: sequence of trajectories or lists of trajectories
    :param fig: figure to draw into; a new one with a single axes if None
    :param idx: optional per-entry indices used for colour and label
    :param use_colors: colour sequence; defaults to the module palette
    :param dim: trajectory column to plot
    :return: the figure
    """
    if fig is None:
        fig = plt.figure()
        fig.add_subplot(111)

    if use_colors is None:
        use_colors = colors

    plt.figure(fig.number)
    c_axes = fig.axes
    n_elems = 0
    used_labels = []
    # colour substitutions applied when colours are selected through ``idx``
    color_overrides = {'olive': 'blue', 'cyan': 'green'}
    for k, trajs in enumerate(ball_trajs_list):
        if isinstance(trajs, list):
            n_elems = len(trajs)
            for c_cmp_traj in trajs:
                c_axes[0].plot(c_cmp_traj[:, dim], color=use_colors[k], alpha=0.8, label=str(k))
        else:
            if idx is not None:
                c_label = str(idx[k])
                c_color = use_colors[int(idx[k])]
                if isinstance(c_color, str):
                    c_color = color_overrides.get(c_color, c_color)
            else:
                c_label = str(k)
                c_color = use_colors[k]
            if c_label in used_labels:
                c_label = None
            else:
                used_labels.append(c_label)
            n_elems = 1
            c_axes[0].plot(trajs[:, dim], color=c_color, alpha=0.8, label=c_label)
    handles, labels = c_axes[0].get_legend_handles_labels()
    # One legend entry per group; guard against a zero step for empty input.
    step = max(n_elems, 1)
    c_axes[-1].legend(handles[::step], labels[::step])
    plt.grid(True)
    return fig

def plot_ball_trajs(ball_trajs_list, fig=None, idx=None, use_colors=None):
    """Plot the first three columns of ball trajectories into three stacked axes.

    A new figure gets three data axes plus a narrow fourth axes on the right
    that holds the legend. Each entry of ``ball_trajs_list`` is either a list
    of trajectories (all drawn in colour ``k`` with label ``k``) or a single
    trajectory array (colour/label taken from ``idx[k]`` if ``idx`` is given,
    otherwise from ``k``; a label is only used once).

    :param ball_trajs_list: sequence of trajectories or lists of trajectories
    :param fig: figure with the layout described above; created if None
    :param idx: optional per-entry indices used for colour and label
    :param use_colors: colour sequence; defaults to the module palette
    :return: the figure
    """
    n_dims = 3
    if fig is None:
        fig = plt.figure()
        gs = GridSpec(3, 9, figure=fig)
        for row in range(n_dims):
            fig.add_subplot(gs[row, :-1])
            plt.grid()
        fig.add_subplot(gs[:, -1])
    if use_colors is None:
        use_colors = colors

    plt.figure(fig.number)
    c_axes = fig.axes
    n_elems = 0
    used_labels = []
    for k, trajs in enumerate(ball_trajs_list):
        if isinstance(trajs, list):
            n_elems = len(trajs)
            for c_cmp_traj in trajs:
                for d in range(n_dims):
                    c_axes[d].plot(c_cmp_traj[:, d], color=use_colors[k], alpha=0.5, label=str(k))
        else:
            if idx is not None:
                c_label = str(idx[k])
                c_color = use_colors[int(idx[k])]
            else:
                c_label = str(k)
                c_color = use_colors[k]
            if c_label in used_labels:
                c_label = None
            else:
                used_labels.append(c_label)
            n_elems = 1
            for d in range(n_dims):
                c_axes[d].plot(trajs[:, d], color=c_color, alpha=0.5, label=c_label)
    handles, labels = c_axes[2].get_legend_handles_labels()
    # One legend entry per group; guard against a zero step for empty input.
    step = max(n_elems, 1)
    c_axes[-1].legend(handles[::step], labels[::step])
    return fig

def plot_joint_trajs(joint_rajs_list, fig=None, idx=None, n_joints=7):
    """Plot joint trajectories, one axes per joint.

    A new figure gets ``n_joints`` stacked data axes plus a narrow axes on
    the right that holds the legend. Each entry of ``joint_rajs_list`` is
    either a list of trajectories (all drawn in colour ``k`` with label
    ``k``) or a single trajectory array (colour/label taken from ``idx[k]``
    if ``idx`` is given, otherwise from ``k``). Colours come from the module
    palette.

    :param joint_rajs_list: sequence of trajectories or lists of trajectories
    :param fig: figure with the layout described above; created if None
    :param idx: optional per-entry indices used for colour and label
    :param n_joints: number of trajectory columns / data axes
    :return: the figure
    """
    if fig is None:
        fig = plt.figure(constrained_layout=True)
        gs = GridSpec(n_joints, 9, figure=fig)
        for i in range(n_joints):
            fig.add_subplot(gs[i, :-1])
        fig.add_subplot(gs[:, -1])
    plt.figure(fig.number)
    c_axes = fig.axes
    n_elems = 0  # defined up front so an empty input does not raise NameError
    for k, trajs in enumerate(joint_rajs_list):
        if isinstance(trajs, list):
            n_elems = len(trajs)
            for c_cmp_traj in trajs:
                for i in range(n_joints):
                    c_axes[i].plot(c_cmp_traj[:, i], color=colors[k], alpha=0.5, label=str(k))
        else:
            n_elems = 1
            if idx is not None:
                c_label = str(idx[k])
                c_color = colors[int(idx[k])]
            else:
                c_label = str(k)
                c_color = colors[k]
            for i in range(n_joints):
                c_axes[i].plot(trajs[:, i], color=c_color, alpha=0.5, label=c_label)
    # Every joint axes carries the same labels; use the first one so that the
    # function also works for fewer than three joints.
    handles, labels = c_axes[0].get_legend_handles_labels()
    step = max(n_elems, 1)
    c_axes[-1].legend(handles[::step], labels[::step])
    return fig

def plot_entr_rewards(all_entropies, all_rewards, all_rewards_std, fig=None):
    """Plot rewards against entropies as error-bar points.

    The points are ordered by increasing entropy; the error bars span
    2 * ``all_rewards_std`` in y direction.

    :param all_entropies: sequence of entropies (x values)
    :param all_rewards: sequence of rewards (y values)
    :param all_rewards_std: sequence of reward standard deviations
    :param fig: figure to draw into; a new one is created if None
    :return: the figure that was drawn into
    """
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)
    all_entropies = np.array(all_entropies)
    all_rewards = np.array(all_rewards)
    all_rewards_std = np.array(all_rewards_std)
    sorted_idx = np.argsort(all_entropies)
    plt.errorbar(all_entropies[sorted_idx], all_rewards[sorted_idx], all_rewards_std[sorted_idx]*2, fmt='o')
    # explicit True: a bare plt.grid() toggles the grid on repeated calls
    plt.grid(True)
    # returned so a figure created here stays reachable, like in the other plotters
    return fig
