import time
import itertools
import numpy as np
from mujoco_py import MjViewer
from mushroom_rl.environments.mujoco_envs.humanoids.reward_goals import EKEYS, QKEYS


def determine_dims(env_id):
    """
    Look up the state-layout constants of a supported MuJoCo environment.

    Args:
        env_id (str): identifier of the environment, e.g. "Hopper-v2".

    Returns:
        A tuple ``(qpos_start, size_obs_space, size_act_space)``. The replay
        functions in this file use ``qpos_start`` as the first ``qpos`` index
        that is overwritten from a stored state, ``size_obs_space`` as the
        number of leading state entries copied into ``qpos`` and
        ``size_act_space`` as the number of following entries copied into
        ``qvel``.

    Raises:
        NotImplementedError: if ``env_id`` is not one of the supported ids.

    """
    # each row: (environment ids sharing a layout, (qpos_start, size_obs_space, size_act_space))
    layouts = (
        (("HalfCheetah-v2", "Walker2d-v2"), (1, 8, 9)),
        (("Hopper-v2",), (1, 5, 6)),
        (("Humanoid-v2", "HumanoidStandup-v2"), (2, 22, 23)),
        (("Ant-v2",), (2, 13, 14)),
    )

    for env_names, dims in layouts:
        if env_id in env_names:
            return dims

    raise NotImplementedError("The environment %s is not supported for the player." % env_id)


def replay_expert_trajectories_stepwise_data(expert_dataset, env_id, mdp):
    """
    Replay a step-wise (flat) expert dataset in a MuJoCo viewer.

    Every stored state is written into the simulator and rendered, then the
    function sleeps for one environment time step. Whenever the episode-start
    flag of a sample is set, the tracked x position and the camera are reset
    and the following sample is skipped.

    Args:
        expert_dataset (dict): must provide 'states' and 'episode_starts',
            two iterables that are consumed in lockstep;
        env_id (str): environment id understood by ``determine_dims``;
        mdp: environment wrapper; ``mdp.env.sim`` and ``mdp.env.dt`` are used.

    Raises:
        NotImplementedError: raised by ``determine_dims`` for unsupported ids.

    """
    # get a viewer for rendering
    viewer = MjViewer(mdp.env.sim)
    xpos = 0

    # determine the dimensionality of qpos and qvel depending on the environment
    qpos_start, size_obs_space, size_act_space = determine_dims(env_id)

    states = expert_dataset['states']
    starts = expert_dataset['episode_starts']

    # explicit iterators, so that samples can be skipped from inside the loop
    it_states = iter(states)
    it_starts = iter(starts)

    for state, start in zip(it_states, it_starts):
        sim_state = mdp.env.sim.get_state()
        # qpos[0] is not taken from the stored state, it is integrated below
        # (presumably the root x position is not part of the observation -- TODO confirm)
        sim_state.qpos[0] = xpos
        sim_state.qpos[qpos_start:] = state[0:size_obs_space]
        sim_state.qvel[:] = state[size_obs_space:(size_obs_space+size_act_space)]
        mdp.env.sim.set_state(sim_state)
        mdp.env.sim.forward()
        viewer.render()
        # state[size_obs_space] is the entry written to qvel[0]; integrate it
        # over one time step to advance the x position and the camera
        dx = state[size_obs_space] * mdp.env.dt
        xpos += dx
        viewer.cam.lookat[0] += dx
        time.sleep(mdp.env.dt)
        if start:
            xpos = 0
            viewer.cam.lookat[0] = xpos
            # get the next state and start flag to throw them away; the default
            # keeps next() from raising StopIteration out of this function when
            # the flagged sample is the last one of the dataset
            next(it_states, None)
            next(it_starts, None)
            print('Resetting episode ...')
            time.sleep(1)


def replay_expert_trajectories_episodewise_data(expert_dataset, env_id, mdp, order=None):
    """
    Replay an episode-wise expert dataset in a MuJoCo viewer.

    Episodes are played one after another. Each state of an episode is
    written into the simulator and rendered, followed by a sleep of one
    environment time step. After every episode the tracked x position and
    the camera are reset and the function pauses for one second.

    Args:
        expert_dataset (dict): must provide 'states' (one sequence of states
            per episode) and 'episode_returns';
        env_id (str): environment id understood by ``determine_dims``;
        mdp: environment wrapper; ``mdp.env.sim`` and ``mdp.env.dt`` are used;
        order: indices selecting and ordering the episodes to replay. If
            None, all episodes are replayed in random order.

    Raises:
        NotImplementedError: raised by ``determine_dims`` for unsupported ids.

    """
    # the viewer is created first, the layout lookup follows
    viewer = MjViewer(mdp.env.sim)
    x_offset = 0

    qpos_start, n_pos, n_vel = determine_dims(env_id)
    vel_end = n_pos + n_vel

    episodes = np.array(expert_dataset['states'])
    returns = np.array(expert_dataset['episode_returns'])

    # without a user-provided order, draw a random permutation of all episodes
    if order is None:
        order = np.arange(len(episodes))
        np.random.shuffle(order)
    episodes = episodes[order]
    returns = returns[order]

    # the returns are not used for rendering, but zipping keeps the number of
    # replayed episodes bounded by the shorter of the two arrays
    for episode, _ in zip(episodes, returns):
        for state in episode:
            env = mdp.env
            snapshot = env.sim.get_state()
            # qpos[0] is integrated from the velocity rather than read from the state
            snapshot.qpos[0] = x_offset
            snapshot.qpos[qpos_start:] = state[0:n_pos]
            snapshot.qvel[:] = state[n_pos:vel_end]
            env.sim.set_state(snapshot)
            env.sim.forward()
            viewer.render()

            # advance x position and camera by the first velocity entry times dt
            step = env.dt
            shift = state[n_pos] * step
            x_offset += shift
            viewer.cam.lookat[0] += shift
            time.sleep(step)

        # episode finished: move back to the origin
        x_offset = 0
        viewer.cam.lookat[0] = x_offset

        print('Resetting episode ...')
        time.sleep(1)


def prepare_expert_data(data_path):
    """
    Load an expert dataset from an ``.npz`` archive into a plain dict.

    'states', 'actions' and 'episode_starts' are mandatory. 'next_actions' /
    'next_next_states' and 'next_states' / 'absorbing' are optional; if one
    is missing a message is printed and loading continues. The reward
    information is taken from 'episode_returns' if available, otherwise from
    the step-based 'rewards'.

    Args:
        data_path: path (or file object) passed on to ``np.load``.

    Returns:
        A dict mapping the keys found in the archive to their arrays.

    Raises:
        KeyError: if a mandatory key is missing, or if the archive contains
            neither 'episode_returns' nor 'rewards'.

    """
    dataset = dict()

    # np.load keeps the archive open and reads arrays lazily on item access;
    # the with-block makes sure the file handle is closed on every exit path
    with np.load(data_path) as expert_files:
        # load expert training data
        dataset["states"] = expert_files["states"]
        dataset["actions"] = expert_files["actions"]
        dataset["episode_starts"] = expert_files["episode_starts"]

        # maybe we have next action and next next state
        try:
            dataset["next_actions"] = expert_files["next_actions"]
            dataset["next_next_states"] = expert_files["next_next_states"]
        except KeyError:
            print("Did not find next action or next next state.")

        # maybe we have next states and dones in the dataset
        try:
            dataset["next_states"] = expert_files["next_states"]
            dataset["absorbing"] = expert_files["absorbing"]
        except KeyError as e:
            print("Warning Dataset: %s" % e)

        # maybe we have episode returns, if yes done
        try:
            dataset["episode_returns"] = expert_files["episode_returns"]
            return dataset
        except KeyError:
            print("Warning Dataset: No episode returns. Falling back to step-based reward.")

        # this has to work
        try:
            dataset["rewards"] = expert_files["rewards"]
        except KeyError as err:
            raise KeyError("The dataset has neither an episode nor a step-based reward!") from err

    return dataset

