import numpy as np
import os, sys 
import gym
import warmup
from types import SimpleNamespace
import json
import torch

from utils import compute_entropy, plot_trajs, print_data, plot_ee_arm26
from replay_memory import ReplayMemory
from controllers import HomeoKinesisSplitBrainController
from controllers import DEP, NoiseController
from cluster import read_params_from_cmdline, save_metrics_params, announce_fraction_finished, \
    announce_early_results, exit_for_resume
from env_wrappers import apply_wrappers
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes
# Apply the tueplots NeurIPS 2022 style bundle to matplotlib's global rcParams.
plt.rcParams.update(bundles.neurips2022())

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    """Scale the ``'figure.figsize'`` entry of ``fig_dict`` in place.

    Args:
        fig_dict: mapping with a ``'figure.figsize'`` entry whose first two
            items are the width and the height.
        fraction_width: factor applied to the width.
        fraction_height: factor applied to the height (default 1, unchanged).

    Returns:
        The same ``fig_dict`` object, with ``'figure.figsize'`` replaced by a
        ``(width, height)`` tuple of the scaled values.
    """
    size = fig_dict['figure.figsize']
    scaled_width = size[0] * fraction_width
    scaled_height = size[1] * fraction_height
    fig_dict['figure.figsize'] = (scaled_width, scaled_height)
    return fig_dict

#fig_dict = figsizes.neurips2022()
# Take the tueplots NeurIPS 2022 figure-size settings, scale the width by 0.4
# (height factor left at its default of 1) and apply them to matplotlib's rcParams.
plt.rcParams.update(change_figsize(figsizes.neurips2022(), fraction_width=0.4))
# Manual style overrides, currently disabled:
#bundle = bundles.neurips2022()
#bundle['figure.figsize'] = (4.5, 2.8)
# bundle['font.size'] = 14
#bundle['axes.labelsize'] = 14
#bundle['legend.fontsize'] = 9
#bundle['xtick.labelsize'] = 12
#bundle['ytick.labelsize'] = 12
#bundle['axes.titlesize'] = 22
#plt.rcParams.update(bundle)



# NOTE(review): WINDOW is not referenced in the code visible here — confirm it is still needed.
WINDOW = 10
# Alternative configuration, currently disabled:
#PREFIX = 'torque'
#ACT_MULT = [1, 6, 15, 30, 300]
# PREFIX is part of the entropy file names that are loaded and of the output
# PDF name; it also selects the x-axis tick values used in the main script.
PREFIX = 'muscle'
# Action multipliers: one entropy file per value is loaded for every controller setting.
ACT_MULT = [1, 2, 5, 10, 100]
#ACT_MULT = [1, 2]


def prepare_params():
    """Load the run parameters and return ``(orig_params, params)``.

    If the last command-line argument is ``'0'``, the parameters are read from
    ``param_files/default_params.json`` next to this script and converted with
    ``get_params``. Otherwise they are obtained from
    ``read_params_from_cmdline`` and passed through ``prepare_cluster``; in
    that case the object returned by ``read_params_from_cmdline`` is itself
    returned as ``orig_params``.
    """
    if sys.argv[-1] != '0':
        orig_params = read_params_from_cmdline()
        _, params = prepare_cluster(orig_params)
        return orig_params, params

    print()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    defaults_path = os.path.join(script_dir, 'param_files/default_params.json')
    with open(defaults_path, 'r') as f:
        loaded = json.load(f)
    orig_params, params = get_params(loaded)
    return orig_params, params

def prepare_cluster(params):
    """Ensure ``params.working_dir`` exists, then convert ``params``.

    Returns:
        The ``(orig_params, params)`` pair produced by ``get_params(params)``.
    """
    working_dir = params.working_dir
    # exist_ok=True: an already existing directory is not an error.
    os.makedirs(working_dir, exist_ok=True)
    converted = get_params(params)
    return converted


def get_params(params):
    """Return a plain copy of ``params`` and an attribute-style view of it.

    Args:
        params: mapping of parameter names to values (must provide ``copy()``
            and ``items()``). It is NOT modified by this function.

    Returns:
        (orig_params, namespace):
            ``orig_params`` is a shallow copy of ``params``.
            ``namespace`` is a ``SimpleNamespace`` with one attribute per
            top-level key; values that are plain ``dict`` objects are
            themselves wrapped in a ``SimpleNamespace`` (one level deep).
    """
    orig_params = params.copy()
    # Build a fresh mapping instead of overwriting entries of ``params``, so
    # the caller's object keeps its nested dicts.
    # The exact type check (not isinstance) is deliberate: only plain dicts
    # are wrapped, dict subclasses are passed through unchanged.
    converted = {
        key: SimpleNamespace(**val) if type(val) is dict else val
        for key, val in params.items()
    }
    return orig_params, SimpleNamespace(**converted)

def get_controller(orig_params, params, env):
    """Instantiate the controller selected by ``params.general.controller``.

    Args:
        orig_params: mapping indexed by controller name; the entry for the
            selected controller is passed as the first constructor argument.
        params: object exposing ``general.controller``; supported names are
            ``'DEP'``, ``'HomeoSplitBrain'`` and ``'Noise'``.
        env: environment; ``env.action_space.shape[0]`` is passed as the
            second and third constructor arguments and ``env`` as the fourth.

    Returns:
        The constructed controller instance.

    Raises:
        ValueError: if the configured controller name is not supported.
    """
    controllers = {'DEP': DEP,
                   'HomeoSplitBrain': HomeoKinesisSplitBrainController,
                   'Noise': NoiseController}
    name = params.general.controller
    try:
        controller_cls = controllers[name]
    except KeyError:
        # Fail with a readable message instead of "'NoneType' is not callable".
        raise ValueError(
            f'Unknown controller {name!r}; expected one of {sorted(controllers)}'
        ) from None
    n_actions = env.action_space.shape[0]
    return controller_cls(orig_params[name], n_actions, n_actions, env)


def post_run(orig_params, avg_return, action_buff, state_buff):
    """Compute and report the entropies of ``state_buff`` after a run.

    The entropies returned by ``compute_entropy(state_buff)`` are printed via
    ``print_data``. Unless the last command-line argument is ``'0'``, the
    average return and the mean entropy are also handed to
    ``save_metrics_params`` together with ``orig_params``.

    ``action_buff`` is currently unused by the active code.

    Returns:
        The entropies computed from ``state_buff``.
    """
    entropies = compute_entropy(state_buff)
    print_data(entropies, state_buff)
    save_results = sys.argv[-1] != '0'
    if save_results:
        save_metrics_params(
            {'avg_return': avg_return, 'entropy': np.mean(entropies)},
            orig_params,
        )
    # Disabled helpers for runs started with '0' as last argument:
    # plot_trajs(action_buff, state_buff)
    # torch.save(mem.buffer_full, f'memory_dep_arm750.pt')
    return entropies

def plot_entrops(fig, ax, x, means, std, color, label=''):
    """Draw one error-bar curve on ``ax`` and label the y axis.

    Args:
        fig: figure owning ``ax`` (not used here).
        ax: axes object providing ``errorbar`` and ``set_ylabel``.
        x: x positions of the points.
        means: y values of the points.
        std: error-bar sizes, passed to ``errorbar`` as ``yerr``.
        color: colour of the curve.
        label: legend label of the curve.
    """
    # Use the ``std`` argument; the previous code read a global named ``stds``
    # and only worked when the calling script happened to define it.
    ax.errorbar(x, means, yerr=std, color=color, label=label)
    ax.set_ylabel('state coverage')

def update_prop(handle, orig):
    """Copy the properties of ``orig`` onto ``handle``, then set its marker to "-".

    NOTE(review): "-" is normally a line-style string in matplotlib rather
    than a marker — confirm that ``set_marker("-")`` is what is intended.
    This function is not called by the code visible here.
    """
    handle.update_from(orig)
    handle.set_marker("-")


def _load_entropy_stats(controller, filler, key_prefix, data):
    """Load the saved entropy arrays of one controller setting and summarize them.

    For every multiplier in ``ACT_MULT`` the file
    ``./entropy_exp/{PREFIX}_entropies_{controller}_act_mult{act_mult}{filler}.npy``
    is loaded, stored in ``data`` under ``f'{key_prefix}_{act_mult}'`` and
    reduced to its mean and standard deviation.

    Returns:
        (means, stds): two lists with one entry per multiplier, in
        ``ACT_MULT`` order.
    """
    means, stds = [], []
    for act_mult in ACT_MULT:
        entr = np.load(f'./entropy_exp/{PREFIX}_entropies_{controller}_act_mult{act_mult}{filler}.npy')
        data[f'{key_prefix}_{act_mult}'] = entr
        means.append(np.mean(entr))
        stds.append(np.std(entr))
    return means, stds


if __name__ == '__main__':
    # x positions, one per entry of ACT_MULT; only the first value depends on PREFIX.
    if PREFIX == 'muscle':
        x = [6, 12, 30, 60, 600][:len(ACT_MULT)]
    else:
        x = [2, 12, 30, 60, 600][:len(ACT_MULT)]
    # The first four colors/labels belong to the noise types (in loop order),
    # the last one to DEP.
    colors = ['tab:blue', '#cd66b1', 'tab:red', 'tab:green', 'tab:orange']
    labels = ['White', 'Pink', 'Red', 'OU', 'DEP']
    orig_params, params = prepare_params()
    data = {}
    fig, ax = plt.subplots(1, 1)
    for control in ['Noise', 'DEP']:
        params.general.controller = control
        if control == 'Noise':
            # One curve per noise type.
            for n_id, noise_type in enumerate(['colored_0', 'colored_1', 'colored_2', 'ou']):
                params.Noise.noise_type = noise_type
                filler = "_" + params.Noise.noise_type
                print(filler)
                means, stds = _load_entropy_stats(
                    params.general.controller, filler, f'{control}_{noise_type}', data)
                plot_entrops(fig, ax, x, means, stds, color=colors[n_id], label=labels[n_id])
                print(means)
        else:
            # DEP files carry no noise-type suffix.
            means, stds = _load_entropy_stats(params.general.controller, "", control, data)
            plot_entrops(fig, ax, x, means, stds, color=colors[-1], label=labels[-1])
            print(means)
    # Raw string: the backslash is part of the label text, and '\#' is not a
    # valid escape sequence in a normal string literal.
    ax.set_xlabel(r'\# of actions')
    ax.set_xscale('log')
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    #plt.tight_layout()
    #ax.legend([Line2D([0,1],[0,1], linestyle='-', color=line.get_color()) for line in ax.lines], [label for label in
    #                                                                              labels], ncol=5,
    #          bbox_to_anchor=(-0.14, 1.0), loc='lower left', frameon=False)
    plt.savefig(f'lineplot_entropy_{PREFIX}.pdf')
