import numpy as np
import os, sys
import gym
import warmup
from types import SimpleNamespace
import json
import torch

from utils import compute_entropy, plot_trajs, print_data, plot_ee_arm26
from replay_memory import ReplayMemory
from controllers import HomeoKinesisSplitBrainController
from controllers import DEP, NoiseController
from cluster import read_params_from_cmdline, save_metrics_params, announce_fraction_finished, \
    announce_early_results, exit_for_resume
from env_wrappers import apply_wrappers
from matplotlib import pyplot as plt
from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes
from pudb import set_trace
# Matplotlib styling: start from the tueplots ICML 2022 bundle at half-column
# figure size, then re-apply the bundle with a few font sizes overridden.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_half())
bundle = bundles.icml2022()
bundle.update({
    'font.size': 14,
    'axes.labelsize': 10,
    'legend.fontsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.titlesize': 10,
})
plt.rcParams.update(bundle)
# Re-apply the figure size after the bundle, as the original setup does.
plt.rcParams.update(figsizes.icml2022_half())

# Action-multiplier values swept by the __main__ block; each value is handed to
# the environment as its "action_multiplier" argument inside entropy_loop.
# The commented-out lists are alternative sweeps.
#ACT_MULT = [1, 6, 15, 30, 300]
#ACT_MULT = [1, 6, 15]
ACT_MULT = [1, 2, 5, 10, 100]
#ACT_MULT = [1, 10]
def prepare_params():
    """Load the experiment parameters for a local or a cluster run.

    When the last command-line argument is the string '0', the defaults are
    read from ``param_files/default_params.json`` located next to this file.
    Otherwise the parameters come from ``read_params_from_cmdline()`` and the
    working directory is created through ``prepare_cluster``.

    Returns:
        tuple: ``(orig_params, params)`` where ``params`` is the
        ``SimpleNamespace`` view produced by ``get_params``. In the cluster
        branch ``orig_params`` is the object returned by
        ``read_params_from_cmdline()`` itself.
    """
    local_run = sys.argv[-1] == '0'
    if not local_run:
        orig_params = read_params_from_cmdline()
        params = prepare_cluster(orig_params)[1]
        return orig_params, params

    print()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    defaults_path = os.path.join(script_dir, 'param_files/default_params.json')
    with open(defaults_path, 'r') as f:
        raw_params = json.load(f)
    return get_params(raw_params)

def prepare_cluster(params):
    """Create the run's working directory and convert the parameters.

    Args:
        params: parameter mapping exposing a ``working_dir`` attribute.

    Returns:
        tuple: the ``(orig_params, params)`` pair produced by ``get_params``.
    """
    working_dir = params.working_dir
    os.makedirs(working_dir, exist_ok=True)
    orig_params, namespace = get_params(params)
    return orig_params, namespace


def get_params(params):
    """Build an attribute-style view of a parameter mapping.

    The input mapping is left untouched. Previously the nested dicts of the
    input were replaced in place by ``SimpleNamespace`` objects, which broke
    callers that keep using the input with item access afterwards.

    Args:
        params: mapping of parameter names to values; values that are dicts
            (including dict subclasses) are converted one level deep.

    Returns:
        tuple: ``(orig_params, namespace)`` where ``orig_params`` is a shallow
        copy of ``params`` and ``namespace`` is a ``SimpleNamespace`` whose
        dict-valued entries are themselves ``SimpleNamespace`` objects.
    """
    orig_params = params.copy()
    converted = {
        key: SimpleNamespace(**val) if isinstance(val, dict) else val
        for key, val in params.items()
    }
    return orig_params, SimpleNamespace(**converted)

def get_controller(orig_params, params, env):
    """Instantiate the controller selected by ``params.general.controller``.

    Args:
        orig_params: dict-style parameters; the entry keyed by the controller
            name is passed to the controller constructor.
        params: namespace-style parameters providing ``general.controller``.
        env: environment whose ``action_space.shape[0]`` is passed twice to
            the constructor (presumably as input and output dimension --
            TODO confirm against the controller signatures).

    Returns:
        The constructed controller instance.

    Raises:
        ValueError: if the configured controller name is not registered here.
    """
    controllers = {'DEP' : DEP,
                   'HomeoSplitBrain': HomeoKinesisSplitBrainController,
                   'Noise': NoiseController}
    name = params.general.controller
    try:
        controller_cls = controllers[name]
    except KeyError:
        # Fail with a readable message instead of "'NoneType' is not callable".
        raise ValueError(
            f'Unknown controller {name!r}; expected one of {sorted(controllers)}'
        ) from None
    act_dim = env.action_space.shape[0]
    return controller_cls(orig_params[name], act_dim, act_dim, env)


def post_run(orig_params, avg_return, action_buff, state_buff, typed):
    """Compute and report the entropies of a finished run.

    Args:
        orig_params: dict-style parameters; ``['general']['env']`` is checked
            for the substring 'torque'.
        avg_return: average episode return, stored in the saved metrics.
        action_buff: recorded actions (currently unused here).
        state_buff: recorded states handed to ``compute_entropy``.
        typed: label forwarded to ``compute_entropy``.

    Returns:
        The value returned by ``compute_entropy``.
    """
    # Any environment without 'torque' in its name is treated as muscle-driven.
    uses_muscles = 'torque' not in orig_params['general']['env']
    entropies = compute_entropy(state_buff, typed, uses_muscles)
    print_data(entropies, state_buff)

    # Local runs (last CLI argument '0') do not store metrics.
    is_local_run = sys.argv[-1] == '0'
    if not is_local_run:
        metrics = {'avg_return': avg_return,
                   'entropy': np.mean(entropies)}
        save_metrics_params(metrics, orig_params)
    return entropies

def _run_episodes(env, ctrl, episodes, max_steps=1000):
    """Roll out ``ctrl`` in ``env`` and return (total_return, states, actions, full_states)."""
    total_return = 0
    state_buff = []
    action_buff = []
    full_states = []
    for ep in range(episodes):
        print(f'EP {ep} of {episodes}')
        ep_return = 0
        ep_steps = 0
        env.reset()
        # The controller is fed env.muscles_dep, not the observation returned
        # by reset()/step().
        state = env.muscles_dep
        while True:
            action = ctrl.step(state)
            _, reward, _, _ = env.step(action)
            state = env.muscles_dep
            full_states.append(env.full_state.copy())
            # The last entry of ee_state is dropped before recording.
            state_buff.append(env.ee_state[:-1].copy())
            action_buff.append(action.copy())

            ep_return += reward

            # The env's terminal flag is ignored; the cap is checked before the
            # counter is incremented, so an episode performs max_steps + 1 steps.
            if ep_steps >= max_steps:
                total_return += ep_return
                print(f'{total_return / (ep + 1)}')
                break
            ep_steps += 1

        ctrl.reset()
        # NOTE(review): every fifth episode the controller is reset a second
        # time; kept from the original -- confirm whether this is intended.
        if not ep % 5:
            ctrl.reset()
    return total_return, state_buff, action_buff, full_states


def entropy_loop(orig_params, params, typed="", episodes=50, save=False):
    """Run a controller in the environment and measure state entropy.

    Reads the module-level global ``act_mult`` (assigned by the loop in the
    ``__main__`` block) for the environment's "action_multiplier" argument and
    for the output file name.

    Args:
        orig_params: dict-style parameters (printed, passed to the controller
            factory and to ``post_run``).
        params: namespace-style parameters (``general.env``,
            ``general.controller``, ``id`` and, for the Noise controller,
            ``Noise.noise_type``).
        typed: label forwarded to ``post_run``.
        episodes: number of episodes to roll out.
        save: when true, store the entropies and the full state trajectories
            as ``.npy`` files under ``./entropy_exp``.

    Returns:
        The mean of the entropies returned by ``post_run``.
    """
    print(f'Params are {orig_params}')
    env = gym.make(params.general.env, identifier=params.id)
    try:
        env.merge_args({"action_multiplier": act_mult})
        env.apply_args()
        print(f'Env is {params.general.env}')
        ctrl = get_controller(orig_params, params, env)

        # initialize() is an optional hook: skip it silently when the
        # controller has none, and report (but tolerate) a failing one.
        initialize = getattr(ctrl, 'initialize', None)
        if initialize is not None:
            try:
                initialize(env.observation_space, env.action_space)
            except Exception as err:
                print(f'Controller initialize failed, continuing: {err!r}')

        total_return, state_buff, action_buff, full_states = _run_episodes(
            env, ctrl, episodes)
    finally:
        # This function is called many times during tuning; release each env.
        env.close()

    avg_return = total_return / episodes
    full_states = np.array(full_states)

    entropies = post_run(orig_params, avg_return, action_buff, state_buff, typed)
    if save:
        filler = '_' + params.Noise.noise_type if params.general.controller == 'Noise' else ""
        prefix = 'torque' if 'torque' in params.general.env else 'muscle'
        name = f'{prefix}_entropies_{params.general.controller}_act_mult{act_mult}{filler}'
        # np.save does not create missing directories; make sure both exist.
        os.makedirs('./entropy_exp/state_trajs', exist_ok=True)
        np.save(f'./entropy_exp/{name}.npy', entropies)
        np.save(f'./entropy_exp/state_trajs/{name}.npy', full_states)
    return np.mean(entropies)

def give_tuned_scale(orig_params, params, episodes):
    """Pick the colored-noise scale that yields the highest mean entropy.

    Evaluates 100 evenly spaced scales between 0.001 and 50 with
    ``entropy_loop`` and writes the best one into both parameter objects
    (``noise_scale_colored``). Both objects are modified in place.

    Args:
        orig_params: dict-style parameters with a 'Noise' entry.
        params: namespace-style parameters with a ``Noise`` attribute.
        episodes: number of episodes per evaluated scale.

    Returns:
        tuple: the same ``(orig_params, params)`` objects, updated.
    """
    candidates = np.linspace(0.001, 50, 100)

    def apply_scale(value):
        # Keep the dict view and the namespace view in sync.
        orig_params['Noise']['noise_scale_colored'] = value
        params.Noise.noise_scale_colored = value

    scores = []
    for candidate in candidates:
        apply_scale(candidate)
        scores.append(entropy_loop(orig_params, params, episodes=episodes))

    best = np.argmax(scores)
    apply_scale(candidates[best])
    return orig_params, params

def give_tuned_ou(orig_params, params, episodes):
    """Grid-search the OU-noise parameters for the highest mean entropy.

    Evaluates every combination of 10 values each for ``sigma``,
    ``sigma_ou`` and ``theta`` with ``entropy_loop`` and writes the best
    combination into both parameter objects, which are modified in place.

    Args:
        orig_params: dict-style parameters with a 'Noise' entry.
        params: namespace-style parameters with a ``Noise`` attribute.
        episodes: number of episodes per evaluated combination.

    Returns:
        tuple: the same ``(orig_params, params)`` objects, updated.
    """
    N = 10
    scales = np.linspace(10, 0.0001, N)
    sigma_ou = np.linspace(0, 2, N)
    thetas = np.linspace(0, 2, N)

    def apply_setting(sigma, sig_ou, theta):
        # Keep the dict view and the namespace view in sync.
        orig_params['Noise']['sigma'] = sigma
        orig_params['Noise']['sigma_ou'] = sig_ou
        orig_params['Noise']['theta'] = theta
        params.Noise.sigma = sigma
        params.Noise.sigma_ou = sig_ou
        params.Noise.theta = theta

    entropy_vals = []
    for scale in scales:
        for sig in sigma_ou:
            for theta in thetas:
                apply_setting(scale, sig, theta)
                ent = entropy_loop(orig_params, params, episodes=episodes)
                entropy_vals.append((ent, scale, sig, theta))

    idx = np.argmax([x[0] for x in entropy_vals])
    # Restore each parameter from its own slot of the winning tuple; the
    # previous code assigned the best sigma to all three parameters.
    _, best_sigma, best_sigma_ou, best_theta = entropy_vals[idx]
    apply_setting(best_sigma, best_sigma_ou, best_theta)
    return orig_params, params

if __name__ == '__main__':
    # Sweep every action multiplier with each controller. For the Noise
    # controller, each noise type is tuned first and then evaluated and saved.
    orig_params, params = prepare_params()
    # `act_mult` must stay a module-level name: entropy_loop reads it as a global.
    for act_mult in ACT_MULT:
        # Shrink these tuples to run only a subset of controllers / noise types.
        for control in ('Noise', 'DEP'):
            params.general.controller = control
            orig_params['general']['controller'] = control

            if control != 'Noise':
                typed = f'dep_act_mult_{act_mult}'
                entropy_loop(orig_params, params, typed, save=True)
                continue

            for noise_type in ('colored_0', 'colored_1', 'colored_2', 'ou'):
                print(noise_type)
                typed = f'{noise_type}_actmult_{act_mult}'
                params.Noise.noise_type = noise_type
                orig_params['Noise']['noise_type'] = noise_type

                # Tune the noise parameters on short runs before the real one.
                if 'colored' in noise_type:
                    print(f'Tuning {noise_type}')
                    orig_params, params = give_tuned_scale(orig_params, params, episodes=5)
                if 'ou' in noise_type:
                    print('Tuning OU')
                    orig_params, params = give_tuned_ou(orig_params, params, episodes=5)
                print(f'For type {params.Noise.noise_type}')
                print(f'scale {params.Noise}')

                entropy_loop(orig_params, params, typed, save=True)


