import numpy as np
import os, sys
import gym
import warmup
from types import SimpleNamespace
import json
import torch

from utils import compute_entropy, plot_trajs, print_data, plot_ee_arm26
from replay_memory import ReplayMemory
from controllers import HomeoKinesisSplitBrainController
from controllers import DEP, NoiseController
from cluster import read_params_from_cmdline, save_metrics_params, announce_fraction_finished, \
    announce_early_results, exit_for_resume
from env_wrappers import apply_wrappers
from matplotlib import pyplot as plt
from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes
# Global matplotlib styling. The tueplots NeurIPS 2022 bundle and figure-size
# presets are applied first; a second copy of the bundle with a custom figure
# size and larger label/tick/title fonts is applied on top of them.
# (font.size and legend.fontsize are deliberately left at the bundle values.)
plt.rcParams.update(bundles.neurips2022())
plt.rcParams.update(figsizes.neurips2022())
bundle = bundles.neurips2022()
bundle.update({
    'figure.figsize': (3.5, 2.8),
    'axes.labelsize': 20,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'axes.titlesize': 22,
})
plt.rcParams.update(bundle)


# NOTE(review): WINDOW is not referenced in the visible part of this file.
WINDOW = 10
# Alternative configuration for the 'torque' variant; swap it with the two
# assignments below to use it.
#PREFIX = 'torque'
#ACT_MULT = [1, 6, 15, 30, 300]
# PREFIX tags the entropy files loaded and the PDFs written by the __main__
# block, and selects the per-multiplier action count shown in the plot titles
# (2 for 'torque', 6 otherwise).
PREFIX = 'muscle'
# Action multipliers the __main__ block iterates over (one figure per entry).
ACT_MULT = [1, 2, 5, 10, 100]


def prepare_params():
    """Load the run parameters and return them in raw and namespace form.

    The last command-line argument selects the source:

    * ``'0'``: ``param_files/default_params.json`` next to this file is read
      and passed through ``get_params``.
    * anything else: the parameters come from ``read_params_from_cmdline``
      and are passed through ``prepare_cluster``, which also creates the
      working directory.

    Returns:
        Tuple ``(orig_params, params)``: the raw parameter mapping and the
        ``SimpleNamespace`` view built by ``get_params``.
    """
    use_default_file = sys.argv[-1] == '0'
    if not use_default_file:
        cmdline_params = read_params_from_cmdline()
        # Only the namespace view is taken from prepare_cluster; the mapping
        # read from the command line is returned as the "original" one.
        _, namespaced = prepare_cluster(cmdline_params)
        return cmdline_params, namespaced

    print()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    defaults_path = os.path.join(script_dir, 'param_files/default_params.json')
    with open(defaults_path, 'r') as f:
        loaded = json.load(f)
    return get_params(loaded)

def prepare_cluster(params):
    """Ensure the run's working directory exists and convert the parameters.

    Args:
        params: parameter object exposing a ``working_dir`` attribute; it is
            forwarded unchanged to ``get_params``.

    Returns:
        Whatever ``get_params(params)`` returns.
    """
    working_dir = params.working_dir
    # exist_ok=True makes this a no-op when the directory is already there.
    os.makedirs(working_dir, exist_ok=True)
    converted = get_params(params)
    return converted


def get_params(params):
    """Convert a parameter mapping into attribute-style namespaces.

    Every top-level value that is a dict becomes a ``SimpleNamespace`` and
    the mapping itself is wrapped in a ``SimpleNamespace`` too. Only this one
    level of nesting is converted.

    Args:
        params: mapping with string keys providing ``copy()`` and
            ``items()``. It is left unmodified.

    Returns:
        Tuple ``(orig_params, namespace)``: a shallow copy of ``params``
        (nested dicts stay dicts) and the namespace view.
    """
    orig_params = params.copy()
    # Build a fresh dict instead of assigning the namespaces back into
    # ``params``: writing into the argument would replace the caller's nested
    # dicts, and prepare_params keeps using the very mapping it passed in on
    # the cluster path.
    converted = {
        key: SimpleNamespace(**val) if isinstance(val, dict) else val
        for key, val in params.items()
    }
    return orig_params, SimpleNamespace(**converted)

def get_controller(orig_params, params, env):
    """Instantiate the controller named by ``params.general.controller``.

    Args:
        orig_params: raw parameter mapping; the entry stored under the
            controller's name is handed to the controller constructor.
        params: namespace view of the parameters (``params.general.controller``
            selects the class).
        env: environment; ``env.action_space.shape[0]`` is passed twice to the
            constructor, followed by ``env`` itself.

    Returns:
        The constructed controller instance.

    Raises:
        ValueError: if the configured controller name is not registered.
    """
    controllers = {'DEP': DEP,
                   'HomeoSplitBrain': HomeoKinesisSplitBrainController,
                   'Noise': NoiseController}
    name = params.general.controller
    controller_cls = controllers.get(name)
    if controller_cls is None:
        # Without this check an unknown name surfaces as a KeyError on
        # orig_params or as "'NoneType' object is not callable".
        raise ValueError(
            f'Unknown controller {name!r}; expected one of {sorted(controllers)}'
        )
    action_dim = env.action_space.shape[0]
    return controller_cls(orig_params[name], action_dim, action_dim, env)


def post_run(orig_params, avg_return, action_buff, state_buff):
    """Compute and print the state entropies and optionally report metrics.

    Args:
        orig_params: raw parameter mapping, forwarded to
            ``save_metrics_params``.
        avg_return: average episode return, reported as ``'avg_return'``.
        action_buff: recorded actions (not used here).
        state_buff: recorded states the entropies are computed from.

    Returns:
        The value returned by ``compute_entropy(state_buff)``.
    """
    entropies = compute_entropy(state_buff)
    print_data(entropies, state_buff)

    # Metrics are only reported when the last command-line argument is not '0'.
    report_metrics = sys.argv[-1] != '0'
    if report_metrics:
        metrics = {
            'avg_return': avg_return,
            'entropy': np.mean(entropies),
        }
        save_metrics_params(metrics, orig_params)
    return entropies

def entropy_loop(orig_params, params, act_mult=None):
    """Roll out the configured controller and store the resulting entropies.

    Builds the environment ``params.general.env`` with the given action
    multiplier, runs 10 episodes with the controller selected by
    ``params.general.controller``, computes the entropies of the visited
    joint states via ``post_run`` and saves them as
    ``./entropy_exp/{PREFIX}_entropies_{controller}_act_mult{act_mult}[_{noise_type}].npy``,
    which is the naming scheme the ``__main__`` block of this file loads.

    Args:
        orig_params: raw parameter mapping (used to build the controller and
            passed on to ``post_run``).
        params: namespace view of the parameters.
        act_mult: value for the environment's ``action_multiplier`` argument.
            When omitted, a module-level ``act_mult`` is used, which is what
            this function implicitly relied on before.

    Raises:
        ValueError: if ``act_mult`` is neither passed nor defined at module
            level.
    """
    if act_mult is None:
        # Backward compatibility: the multiplier used to be read as a free
        # variable, i.e. from the loop variable of the __main__ block.
        act_mult = globals().get('act_mult')
        if act_mult is None:
            raise ValueError('act_mult must be passed to entropy_loop or defined at module level')

    print(f'Params are {orig_params}')
    env = gym.make(params.general.env, identifier=params.id)
    env.merge_args({"action_multiplier": act_mult})
    env.apply_args()
    print(f'Env is {params.general.env}')
    ctrl = get_controller(orig_params, params, env)

    # Initialization stays best effort (controllers without ``initialize`` are
    # skipped), but a failing ``initialize`` is reported instead of being
    # swallowed by a bare ``except``.
    initialize = getattr(ctrl, 'initialize', None)
    if callable(initialize):
        try:
            initialize(env.observation_space, env.action_space)
        except Exception as err:
            print(f'Controller initialization failed, continuing without it: {err!r}')

    # NOTE(review): ``mem`` is never used below. It is kept because the
    # constructor might have side effects (the second argument looks like a
    # seed) -- TODO confirm and remove.
    mem = ReplayMemory(50000, 0, store_all_data=True)

    episodes = 10
    avg_return = 0
    state_buff = []
    action_buff = []
    full_states = []
    for ep in range(episodes):
        print(f'EP {ep} of {episodes}')
        ep_return = 0
        ep_steps = 0
        env.reset()
        # The controller is driven by env.muscles_dep, not by the observation
        # returned from reset()/step().
        state = env.muscles_dep
        while True:
            action = ctrl.step(state)
            _, reward, _, _ = env.step(action)
            state = env.muscles_dep
            # Exactly one state sample per step: the joint state used to be
            # appended twice, duplicating every sample in the entropy input
            # and making state_buff twice as long as action_buff.
            state_buff.append(env.joint_state.copy())
            full_states.append(env.full_state.copy())
            action_buff.append(action.copy())

            ep_return += reward

            # Episodes are cut after a fixed number of steps; the terminal
            # flag of the environment is ignored.
            if ep_steps >= 500:
                avg_return += ep_return
                print(f'{avg_return / (ep + 1)}')
                break
            ep_steps += 1

    avg_return /= episodes

    # Correlation matrix of the action dimensions over all recorded steps.
    actions = torch.as_tensor(np.array(action_buff).transpose())
    corr = torch.corrcoef(actions)
    fig, ax = plt.subplots()
    ax.imshow(corr, cmap='gist_heat')
    # NOTE(review): the tick positions presumably assume 50 action
    # dimensions -- confirm for other action multipliers.
    ax.set_xticks([0, 25, 49])
    ax.set_yticks([0, 25, 49])
    ax.set_xticklabels([1, 25, 50])
    ax.set_yticklabels([1, 25, 50])
    ax.set_xlabel('muscle i', fontsize=8)
    ax.set_ylabel('muscle j', fontsize=8)
    # Saving the figure is currently disabled:
    # plt.savefig("correlation_matrix.pdf", bbox_inches='tight',pad_inches = 0)

    # Only needed by the (disabled) end-effector plot below.
    full_states = np.array(full_states)
    entropies = post_run(orig_params, avg_return, action_buff, state_buff)

    if params.general.controller == 'Noise':
        filler = "_" + params.Noise.noise_type
    else:
        filler = ""
    # Create the output directory up front so a long run is not lost to a
    # missing folder, and include PREFIX so the file name matches what the
    # __main__ block loads.
    os.makedirs('./entropy_exp', exist_ok=True)
    np.save(f'./entropy_exp/{PREFIX}_entropies_{params.general.controller}_act_mult{act_mult}{filler}.npy', entropies)
    # plot_ee_arm26(full_states, env, episodes=episodes, title='DEP', name='dep')


if __name__ == '__main__':
    # Script entry point: for every action multiplier, load the stored entropy
    # arrays of all controllers and draw one bar chart (mean +/- std).

    # Plot titles show the action count: 2 per multiplier unit for the
    # 'torque' prefix, 6 per unit otherwise.
    actions_per_unit = 2 if PREFIX == 'torque' else 6
    title_dict = {a: f'{actions_per_unit * a} actions' for a in ACT_MULT}

    orig_params, params = prepare_params()
    data = {}
    noise_types = ['colored_0', 'colored_1', 'colored_2', 'ou']
    # One colour/label per bar, in the order the bars are collected below
    # (four noise variants first, then DEP).
    colors = ['tab:blue', 'tab:pink', 'tab:red', 'tab:green', 'tab:orange']
    labels = ['white', 'pink', 'red', 'OU', 'DEP']

    # ``act_mult`` is intentionally a module-level name: entropy_loop can pick
    # the multiplier up from the module globals.
    for i, act_mult in enumerate(ACT_MULT):
        fig, ax = plt.subplots(1, 1)
        means = []
        stds = []
        for control in ['Noise', 'DEP']:
            params.general.controller = control
            # Noise is loaded once per noise type; every other controller has
            # a single file without a noise suffix (represented by None).
            variants = noise_types if control == 'Noise' else [None]
            for noise_type in variants:
                if noise_type is None:
                    filler = ""
                    key = f'{control}_{act_mult}'
                else:
                    params.Noise.noise_type = noise_type
                    filler = "_" + noise_type
                    print(filler)
                    key = f'{control}_{noise_type}_{act_mult}'
                entr = np.load(f'./entropy_exp/{PREFIX}_entropies_{control}_act_mult{act_mult}{filler}.npy')
                data[key] = entr
                means.append(np.mean(entr))
                stds.append(np.std(entr))

        bars = ax.bar(np.arange(len(means)), means, yerr=stds)
        # Only the first figure carries the y-axis label.
        if i == 0:
            ax.set_ylabel('sample entropy')
        ax.set_title(title_dict[act_mult])
        for color, bar in zip(colors, bars):
            bar.set_color(color)
        ax.set_xticks(np.arange(len(labels)), labels)
        # Save through the figure handle rather than the implicit "current
        # figure", then release it so figures do not pile up across iterations.
        fig.savefig(f'{PREFIX}_act_mult_{act_mult}_entropy.pdf')
        plt.close(fig)


