import numpy as np
import os
# import tensorflow as tf
from gym import utils
from gym.envs.mujoco import mujoco_env


class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Walker2d MuJoCo environment with randomized body mass and joint damping.

    On every ``reset_model`` call one scale factor is drawn from
    ``mass_scale_set`` and one from ``damping_scale_set``; the model's
    original ``body_mass`` and ``dof_damping`` arrays are multiplied by the
    drawn factors and written back into the model.

    Args:
        mass_scale_set: Candidate multipliers for the model's body masses.
        damping_scale_set: Candidate multipliers for the model's DoF damping.
    """

    def __init__(
        self, mass_scale_set=[0.75, 1.0, 1.25], damping_scale_set=[0.75, 1.0, 1.25]
    ):
        self.prev_qpos = None
        dir_path = os.path.dirname(os.path.realpath(__file__))
        mujoco_env.MujocoEnv.__init__(
            self, os.path.join(dir_path, "assets", "walker2d.xml"), 5
        )

        # Snapshot the unscaled parameters so every rescale starts from the
        # same baseline instead of compounding on the previous episode.
        self.original_mass = np.copy(self.model.body_mass)
        self.original_damping = np.copy(self.model.dof_damping)

        # Store copies: the defaults are mutable lists shared by every call,
        # so keeping a reference would alias state across instances.
        self.mass_scale_set = list(mass_scale_set)
        self.damping_scale_set = list(damping_scale_set)

        # Until the first reset the model is unmodified, i.e. both scales are
        # 1.0. Initializing them here keeps change_env() and
        # get_sim_parameters() usable before reset_model() has run.
        self.mass_scale = 1.0
        self.damping_scale = 1.0

        utils.EzPickle.__init__(self, mass_scale_set, damping_scale_set)

    def _set_observation_space(self, observation):
        """Set the observation space and record the preprocessed obs width.

        Args:
            observation: A sample observation; a leading axis is added before
                it is passed through ``obs_preproc``.
        """
        super(Walker2dEnv, self)._set_observation_space(observation)
        proc_observation = self.obs_preproc(observation[None])
        self.proc_observation_space_dims = proc_observation.shape[-1]

    def step(self, action):
        """Advance the simulation by ``frame_skip`` steps using ``action``.

        The reward is the change in ``qpos[0]`` divided by ``dt``, plus an
        alive bonus of 1.0, minus ``1e-3`` times the squared action sum.
        The episode is done when ``qpos[1]`` (height) leaves ``(0.8, 2.0)``
        or ``qpos[2]`` (angle) leaves ``(-1.0, 1.0)``.

        Args:
            action: Control input forwarded to ``do_simulation``.

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``info`` is an
            empty dict.
        """
        posbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        posafter, height, ang = self.sim.data.qpos[0:3]

        alive_bonus = 1.0
        reward = (posafter - posbefore) / self.dt
        reward += alive_bonus
        reward -= 1e-3 * np.square(action).sum()

        healthy = 0.8 < height < 2.0 and -1.0 < ang < 1.0
        done = not healthy
        return self._get_obs(), reward, done, {}

    def seed(self, seed=None):
        """Record the seed on ``self._seed`` and seed the parent environment.

        Args:
            seed: Seed forwarded to the parent ``seed``; ``None`` is allowed.

        Returns:
            Whatever the parent ``seed`` implementation returns.
        """
        # NOTE(review): when seed is None, 0 is recorded here but the parent
        # still receives None, so _seed may not reflect the seed actually
        # used. Confirm whether that is intended.
        self._seed = 0 if seed is None else seed
        return super().seed(seed)

    def _get_obs(self, full=False):
        """Return the observation: ``qpos[1:]`` followed by clipped ``qvel``.

        Args:
            full: Not supported; passing a truthy value raises.

        Returns:
            Flat array of ``qpos`` without its first entry, concatenated with
            ``qvel`` clipped to ``[-10, 10]``.

        Raises:
            NotImplementedError: If ``full`` is truthy.
        """
        if full:
            raise NotImplementedError
        qpos = self.sim.data.qpos
        qvel = self.sim.data.qvel
        return np.concatenate([qpos[1:], np.clip(qvel, -10, 10)]).ravel()

    def obs_preproc(self, obs):
        """Return ``obs`` unchanged; only NumPy arrays are supported.

        Raises:
            NotImplementedError: If ``obs`` is not an ``np.ndarray``.
        """
        if not isinstance(obs, np.ndarray):
            raise NotImplementedError
        return obs

    def obs_postproc(self, obs, pred):
        """Return ``obs + pred``; only NumPy ``obs`` is supported.

        Raises:
            NotImplementedError: If ``obs`` is not an ``np.ndarray``.
        """
        if not isinstance(obs, np.ndarray):
            raise NotImplementedError
        return obs + pred

    def targ_proc(self, obs, next_obs):
        """Return the difference ``next_obs - obs``."""
        return next_obs - obs

    def reset_model(self):
        """Reset the state with small noise and resample the scale factors.

        Uniform noise in ``[-0.005, 0.005]`` is added to the initial ``qpos``
        and ``qvel``; then a mass scale and a damping scale are drawn from
        their sets and applied to the model via ``change_env``.

        Returns:
            The observation after the reset.
        """
        # Draw order (qpos noise, qvel noise, mass index, damping index) is
        # kept fixed so a given seed reproduces the same episode.
        qpos = self.init_qpos + self.np_random.uniform(
            low=-.005, high=.005, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.uniform(
            low=-.005, high=.005, size=self.model.nv
        )
        self.set_state(qpos, qvel)

        self.prev_qpos = np.copy(self.sim.data.qpos.flat)

        mass_index = self.np_random.randint(len(self.mass_scale_set))
        self.mass_scale = self.mass_scale_set[mass_index]

        damping_index = self.np_random.randint(len(self.damping_scale_set))
        self.damping_scale = self.damping_scale_set[damping_index]

        self.change_env()
        return self._get_obs()

    def change_env(self):
        """Write the scaled mass and damping arrays into the model.

        Both arrays are recomputed from the originals captured in
        ``__init__``, so repeated calls do not accumulate.
        """
        self.model.body_mass[:] = self.original_mass * self.mass_scale
        self.model.dof_damping[:] = self.original_damping * self.damping_scale

    def change_mass(self, mass):
        """Set the mass scale factor.

        NOTE(review): this does not touch the model; ``change_env`` has to be
        called for the new scale to take effect, and ``reset_model`` will
        overwrite it with a random draw.
        """
        self.mass_scale = mass

    def change_damping(self, damping):
        """Set the damping scale factor.

        NOTE(review): this does not touch the model; ``change_env`` has to be
        called for the new scale to take effect, and ``reset_model`` will
        overwrite it with a random draw.
        """
        self.damping_scale = damping

    def viewer_setup(self):
        """Configure the viewer camera (tracked body, distance, lookat, tilt)."""
        camera = self.viewer.cam
        camera.trackbodyid = 2
        camera.distance = self.model.stat.extent * 0.5
        camera.lookat[2] = 1.15
        camera.elevation = -20

    def get_sim_parameters(self):
        """Return the current simulation parameters as a length-2 array.

        Returns:
            ``np.array([normalized_mass, damping_scale])`` where the mass
            scale is standardized with the mean and standard deviation of a
            fixed reference set and the damping scale is returned as-is.
        """
        # NOTE(review): this reference set differs from the default
        # mass_scale_set ([0.75, 1.0, 1.25]) and damping is not normalized at
        # all. Left unchanged because downstream consumers may depend on
        # these exact values; confirm whether it should track the instance's
        # own scale sets.
        training_mass_set = [0.25, 0.5, 1.5, 2.5]
        mass_avg = np.mean(training_mass_set)
        mass_std = np.std(training_mass_set)
        normalized_mass = (self.mass_scale - mass_avg) / mass_std
        return np.array([normalized_mass, self.damping_scale])

    def num_modifiable_parameters(self):
        """Return the number of randomized parameters (mass and damping)."""
        return 2

    def log_diagnostics(self, paths, prefix):
        """Do nothing; this environment logs no diagnostics."""
        return
