import numpy as np
import os, sys
import gym
import warmup
from types import SimpleNamespace
import json
#import torch
import colorednoise as cn
from osim_rl.gym_version import NeuripsGym

#from utils import compute_entropy, plot_trajs, print_data, plot_ee_arm26
#from replay_memory import ReplayMemory
from controllers import HomeoKinesisSplitBrainController
from controllers import DEP, NoiseController
from cluster import read_params_from_cmdline, save_metrics_params, announce_fraction_finished,\
    announce_early_results, exit_for_resume
from env_wrappers import apply_wrappers
from matplotlib import pyplot as plt
from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes 
plt.rcParams.update(bundles.neurips2022())
plt.rcParams.update(figsizes.neurips2022())
#bundle = bundles.icml2022()
#bundle['font.size'] = 14
#bundle['axes.labelsize'] = 10
#bundle['legend.fontsize'] = 8
#bundle['xtick.labelsize'] = 8
#bundle['ytick.labelsize'] = 8
#bundle['axes.titlesize'] = 10
#plt.rcParams.update(bundle)
#plt.rcParams.update(figsizes.icml2022_half())




def prepare_params():
    """Return ``(orig_params, params)`` for this run.

    When the last command-line argument is ``'0'`` (local run), parameters are
    loaded from ``param_files/default_params.json`` next to this file and
    converted with ``get_params``. Otherwise they come from
    ``read_params_from_cmdline`` and are passed through ``prepare_cluster``.
    """
    if sys.argv[-1] != '0':
        # Cluster run: the object returned by read_params_from_cmdline is kept
        # as orig_params; the copy prepare_cluster hands back is discarded.
        cluster_params = read_params_from_cmdline()
        _, namespace = prepare_cluster(cluster_params)
        return cluster_params, namespace

    print()
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, 'param_files/default_params.json'), 'r') as handle:
        raw = json.load(handle)
    return get_params(raw)

def prepare_cluster(params):
    """Create ``params.working_dir`` if needed, then convert *params*.

    Returns whatever ``get_params(params)`` returns, i.e. the
    ``(orig_params, params)`` pair.
    """
    target_dir = params.working_dir
    os.makedirs(target_dir, exist_ok=True)
    converted = get_params(params)
    return converted


def get_params(params):
    """Split a parameter mapping into a raw copy and an attribute-access view.

    Args:
        params: Mapping of parameter names to values. Values that are dicts
            (one level deep, including dict subclasses) are exposed as
            ``SimpleNamespace`` objects in the returned view.

    Returns:
        ``(orig_params, namespace)``: a shallow copy of *params* (still
        indexable, e.g. ``orig_params['general']['env']``) and a
        ``SimpleNamespace`` allowing ``namespace.general.controller``.

    *params* itself is left unmodified, so callers that keep using it as a
    dict (e.g. ``orig_params[...]`` lookups) are unaffected by the conversion.
    """
    orig_params = params.copy()
    fields = {
        key: SimpleNamespace(**val) if isinstance(val, dict) else val
        for key, val in params.items()
    }
    return orig_params, SimpleNamespace(**fields)

def get_controller(orig_params, params, env):
    """Instantiate the controller named by ``params.general.controller``.

    Args:
        orig_params: Raw parameter mapping; ``orig_params[<controller name>]``
            is passed to the controller constructor as its configuration.
        params: Namespace view of the parameters (see ``get_params``).
        env: Environment whose ``action_space.shape[0]`` sizes the controller.

    Returns:
        A new controller instance.

    Raises:
        ValueError: if the configured controller name is not known.
    """
    controllers = {'DEP': DEP,
                   'HomeoSplitBrain': HomeoKinesisSplitBrainController,
                   'Noise': NoiseController}
    name = params.general.controller
    try:
        controller_cls = controllers[name]
    except KeyError:
        raise ValueError(
            f'Unknown controller {name!r}; expected one of {sorted(controllers)}'
        ) from None
    n_actions = env.action_space.shape[0]
    # The action dimensionality is passed for both size arguments.
    # NOTE(review): presumably because the controller input (env.muscles_dep)
    # has one entry per actuator — confirm against the controller signatures.
    return controller_cls(orig_params[name], n_actions, n_actions, env)


def post_run(orig_params, avg_return, action_buff, state_buff):
    """Compute and print state entropies; on cluster runs, also save metrics.

    Args:
        orig_params: Raw parameter mapping. If ``orig_params['general']['env']``
            contains ``'torque'`` the run is treated as torque-driven,
            otherwise as muscle-driven.
        avg_return: Mean episode return, reported as metric ``avg_return``.
        action_buff: Recorded actions; not used by the active code here.
        state_buff: Recorded states passed to ``compute_entropy``/``print_data``.

    Returns:
        The entropies returned by ``compute_entropy``.
    """
    # NOTE(review): compute_entropy and print_data are expected from `utils`,
    # whose import is commented out at the top of this file; as visible here,
    # this function raises NameError until that import is restored — confirm.
    muscle = 'torque' not in orig_params['general']['env']
    entropies = compute_entropy(state_buff, muscle)
    print_data(entropies, state_buff)
    # A last argv of '0' marks a local run: nothing is persisted.
    if sys.argv[-1] != '0':
        metrics = {'avg_return': avg_return,
                   'entropy': np.mean(entropies)}
        save_metrics_params(metrics, orig_params)
    return entropies


if __name__ == '__main__':
    orig_params, params = prepare_params()
    print(f'Params are {orig_params}')
    # The environment is constructed directly rather than via
    # gym.make(params.general.env, ...); params.general.env is only printed.
    env = NeuripsGym(visualize=True)
    print(f'Env is {params.general.env}')
    print(env.action_space)
    ctrl = get_controller(orig_params, params, env)
    # initialize() is treated as optional (a controller may not define it or
    # may not accept these arguments). Stay best-effort, but report why it was
    # skipped instead of silently swallowing everything (incl. Ctrl-C).
    try:
        ctrl.initialize(env.observation_space, env.action_space)
    except Exception as err:
        print(f'Skipping ctrl.initialize: {err!r}')

    episodes = 25
    max_episode_steps = 200  # fixed truncation; the env's terminal flag is ignored
    avg_return = 0
    # NOTE(review): state_buff is never filled here, so post_run receives it
    # empty — confirm which env state should be recorded.
    state_buff = []
    action_buff = []  # one action array per env step, across all episodes
    action_buff_int = []  # env.internal_action per step, when the env has it

    for ep in range(episodes):
        print(f'EP {ep} of {episodes}')
        ep_return = 0
        ep_steps = 0
        env.reset()
        # The controller is fed env.muscles_dep rather than the observation
        # returned by reset()/step(); copy it so later env updates cannot
        # alias a state the controller may keep.
        state = env.muscles_dep.copy()
        while True:
            action = ctrl.step(state)
            _, reward, _, _ = env.step(action)
            state = env.muscles_dep.copy()
            action_buff.append(np.array(action, copy=True))
            internal_action = getattr(env, 'internal_action', None)
            if internal_action is not None:
                action_buff_int.append(np.array(internal_action, copy=True))

            ep_return += reward

            # The counter starts at 0, so each episode runs
            # max_episode_steps + 1 env steps.
            if ep_steps >= max_episode_steps:
                avg_return += ep_return
                print(f'{avg_return / (ep+1)}')
                break
            ep_steps += 1

        # Reset controller internals after episodes 0, 5, 10, ...
        if not ep % 5:
            ctrl.reset()

    avg_return /= episodes

    # (action_dim, n_steps): one row per actuator, one column per env step.
    actions = np.array(action_buff).T
    print('actions')
    for std in actions.std(axis=1):
        print(std)
    if action_buff_int:
        print('internal actions')
        for std in np.array(action_buff_int).std(axis=0):
            print(std)
    print(actions.shape)

    # Row-wise Pearson correlation between actuators (numpy equivalent of
    # torch.corrcoef). Constant rows yield NaN; atleast_2d covers a 1-D action.
    corr = np.atleast_2d(np.corrcoef(actions))
    n_act = corr.shape[0]
    # First, middle and last actuator, labelled with 1-based indices.
    ticks = sorted({0, max(n_act // 2 - 1, 0), n_act - 1})
    labels = [t + 1 for t in ticks]
    fig, ax = plt.subplots()
    ax.imshow(corr, cmap='gist_heat')
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.set_xlabel('muscle i', fontsize=8)
    ax.set_ylabel('muscle j', fontsize=8)
    # The figure is neither saved nor shown; uncomment to persist it.
    # plt.savefig("correlation_matrix.pdf", bbox_inches='tight',pad_inches = 0)
    # NOTE(review): post_run calls compute_entropy/print_data, which are not
    # imported in this file as visible; restore that import before relying on it.
    entropies = post_run(orig_params, avg_return, action_buff, state_buff)


