import numpy as np
import os, sys
import gym
import warmup
from types import SimpleNamespace
import json
import torch

from utils import compute_entropy, plot_trajs, print_data, plot_ee_arm26
from replay_memory import ReplayMemory
from controllers import HomeoKinesisSplitBrainController
from controllers import DEP, NoiseController
from cluster import read_params_from_cmdline, save_metrics_params, announce_fraction_finished, \
    announce_early_results, exit_for_resume
from env_wrappers import apply_wrappers
from matplotlib import pyplot as plt
from utils import plot_ee_arm26

from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    """Scale the ``'figure.figsize'`` entry of ``fig_dict`` and return the dict.

    The dictionary is modified in place; the returned object is the very
    same dictionary that was passed in.

    Args:
        fig_dict: Mapping holding a ``'figure.figsize'`` entry whose first
            two items are the width and the height.
        fraction_width: Factor the width is multiplied by.
        fraction_height: Factor the height is multiplied by (default 1,
            i.e. the height is left as is).

    Returns:
        ``fig_dict`` with ``'figure.figsize'`` replaced by a
        ``(width, height)`` tuple of the scaled values.
    """
    base_width = fig_dict['figure.figsize'][0]
    base_height = fig_dict['figure.figsize'][1]
    scaled = (base_width * fraction_width, base_height * fraction_height)
    fig_dict['figure.figsize'] = scaled
    return fig_dict

# Global matplotlib styling: first apply the tueplots NeurIPS 2022 bundle,
# then override the figure size with the one tueplots suggests for a
# 2 x 5 grid of subplots, with its width scaled to 75 %.
plt.rcParams.update(bundles.neurips2022())
fig_dict = change_figsize(
    figsizes.neurips2022(nrows=2, ncols=5),
    fraction_width=0.75,
)
plt.rcParams.update(fig_dict)


#bundle = bundles.neurips2022()
#bundle['figure.figsize'] = (3.5, 2.8)
#bundle['axes.labelsize'] = 40
#bundle['xtick.labelsize'] = 22
#bundle['ytick.labelsize'] = 22
#bundle['axes.titlesize'] = 26
#plt.rcParams.update(bundle)

# bundle['font.size'] = 14
# bundle['legend.fontsize'] = 8

# NOTE(review): WINDOW is not referenced in the visible part of this file.
WINDOW = 10
# Inserted into the names of the trajectory files that are loaded and into the
# name of the saved PDF (presumably the actuation type - confirm).
PREFIX = 'torque'
# Values substituted into the "act_mult" part of the trajectory file names;
# the main script draws one subplot row per entry.
ACT_MULT = [1, 300]
# Alternative settings for the 'muscle' variant:
#PREFIX = 'muscle'
#ACT_MULT = [1, 100]


def prepare_params():
    """Load the experiment parameters.

    If the last command-line argument is the string ``'0'``, the defaults are
    read from ``param_files/default_params.json`` next to this script and
    passed through ``get_params``.  Otherwise the parameters come from
    ``read_params_from_cmdline`` and are handed to ``prepare_cluster``; in
    that case the object returned by ``read_params_from_cmdline`` itself is
    kept as the "original" parameters.

    Returns:
        Tuple ``(orig_params, params)`` where ``params`` is the namespace
        built by ``get_params``.
    """
    use_cmdline_params = sys.argv[-1] != '0'
    if use_cmdline_params:
        orig_params = read_params_from_cmdline()
        _, params = prepare_cluster(orig_params)
        return orig_params, params

    print()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    defaults_path = os.path.join(script_dir, 'param_files/default_params.json')
    with open(defaults_path, 'r') as f:
        raw_params = json.load(f)
    orig_params, params = get_params(raw_params)
    return orig_params, params

def prepare_cluster(params):
    """Create ``params.working_dir`` if needed, then convert the parameters.

    Args:
        params: Parameter object exposing a ``working_dir`` attribute and
            accepted by ``get_params``.

    Returns:
        Whatever ``get_params(params)`` returns, i.e. the
        ``(orig_params, params)`` pair.
    """
    working_dir = params.working_dir
    # exist_ok=True: an already existing directory is not an error.
    os.makedirs(working_dir, exist_ok=True)
    converted = get_params(params)
    return converted


def get_params(params):
    """Turn a parameter mapping into a namespace with attribute access.

    Every top-level value that is exactly a ``dict`` is wrapped in a
    ``SimpleNamespace`` so it can be read as ``params.section.key``; all
    other values are passed through unchanged.  Only one level is converted.

    The input mapping is *not* modified: the conversion is done on a new
    dictionary, so a caller that keeps using the object it passed in still
    sees its nested dictionaries.

    Args:
        params: Mapping of parameter names to values.

    Returns:
        Tuple ``(orig_params, namespace)`` where ``orig_params`` is a shallow
        copy of ``params`` (via ``params.copy()``) and ``namespace`` is the
        ``SimpleNamespace`` view.
    """
    orig_params = params.copy()
    # The exact type check (rather than isinstance) is intentional: dict
    # subclasses are left untouched, exactly as in the previous behaviour.
    converted = {
        key: SimpleNamespace(**val) if type(val) is dict else val
        for key, val in params.items()
    }
    return orig_params, SimpleNamespace(**converted)

def get_controller(orig_params, params, env):
    """Instantiate the controller selected by ``params.general.controller``.

    Args:
        orig_params: Mapping indexed by controller name; the entry for the
            selected controller is passed to its constructor.
        params: Namespace providing ``params.general.controller``, one of
            ``'DEP'``, ``'HomeoSplitBrain'`` or ``'Noise'``.
        env: Environment whose ``action_space.shape[0]`` is forwarded to the
            controller, followed by the environment itself.

    Returns:
        The constructed controller instance.

    Raises:
        ValueError: If the configured controller name is not one of the
            known controllers (previously this surfaced as an opaque
            "'NoneType' object is not callable" error).
    """
    controllers = {'DEP': DEP,
                   'HomeoSplitBrain': HomeoKinesisSplitBrainController,
                   'Noise': NoiseController}
    name = params.general.controller
    controller_cls = controllers.get(name)
    if controller_cls is None:
        raise ValueError(
            f"Unknown controller {name!r}; expected one of {sorted(controllers)}"
        )
    # The action dimension is passed as both the second and the third
    # positional argument (presumably input and output size - confirm
    # against the controller constructors).
    action_dim = env.action_space.shape[0]
    return controller_cls(orig_params[name], action_dim, action_dim, env)


def post_run(orig_params, avg_return, action_buff, state_buff):
    """Compute and report entropies after a run, optionally saving metrics.

    The entropies are obtained from ``compute_entropy(state_buff)`` and
    reported through ``print_data``.  Unless the last command-line argument
    is the string ``'0'``, the average return and the mean entropy are
    additionally stored with ``save_metrics_params``.

    Args:
        orig_params: Parameters handed to ``save_metrics_params``.
        avg_return: Value stored under the ``'avg_return'`` metric.
        action_buff: Not used by this function.
        state_buff: States the entropies are computed from.

    Returns:
        The value returned by ``compute_entropy``.
    """
    entropies = compute_entropy(state_buff)
    print_data(entropies, state_buff)

    save_metrics = sys.argv[-1] != '0'
    if save_metrics:
        save_metrics_params(
            {'avg_return': avg_return, 'entropy': np.mean(entropies)},
            orig_params,
        )
    return entropies



if __name__ == '__main__':
    # Draw a grid of plots from saved state trajectories: one row per entry
    # of ACT_MULT, one column per controller variant (the four noise types
    # followed by DEP), and save the figure as a PDF.
    title_dict_noise = dict(ou='OU', colored_0='White', colored_1='Pink', colored_2='Red')

    orig_params, params = prepare_params()

    # Column layout as (controller name, noise type or None, panel title).
    noise_types = ['colored_0', 'colored_1', 'colored_2', 'ou']
    panels = [('Noise', noise_type, title_dict_noise[noise_type])
              for noise_type in noise_types]
    panels.append(('DEP', None, 'DEP'))

    # squeeze=False keeps `axs` two-dimensional even for a single row or
    # column, so the `axs[row, col]` indexing below always works.
    fig, axs = plt.subplots(len(ACT_MULT), len(panels),
                            sharex=True, sharey=True, squeeze=False)

    for row, act_mult in enumerate(ACT_MULT):
        for col, (control, noise_type, title) in enumerate(panels):
            params.general.controller = control
            if noise_type is None:
                filler = ""
            else:
                params.Noise.noise_type = noise_type
                filler = "_" + params.Noise.noise_type
                print(filler)

            state_buff = np.load(f'./entropy_exp/state_trajs/{PREFIX}_entropies_{params.general.controller}_act_mult{act_mult}{filler}.npy')
            plot_ee_arm26(
                axs[row, col],
                state_buff,
                prefix=PREFIX,
                name=f'{PREFIX}_{params.general.controller}_{act_mult}_{filler}',
                title_bool=(row == 0),  # only set for the top row
                title=title,
                left=(col == 0),        # only set for the first column
            )

    # Axis labels only on the outer edge of the grid (axes are shared).
    for ax in axs[-1, :]:
        ax.set_xlabel('hand-x')
    for ax in axs[:, 0]:
        ax.set_ylabel('hand-y')
    plt.savefig(f'ee_{PREFIX}_arm26.pdf')


