import numpy as np
import os
from gym import utils
from gym.envs.mujoco import mujoco_env


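# HalfCheetah with a mid-episode dynamics change: for the first `change_step` steps
# of each episode (default 100) it keeps the mass/damping scales sampled at reset,
# then switches once to a new mass and damping scale.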
class HalfCheetahEnvChange(mujoco_env.MujocoEnv, utils.EzPickle):
    """HalfCheetah whose body mass and joint damping change partway through an episode.

    On every reset a mass scale and a damping scale are drawn uniformly from
    ``mass_scale_set`` / ``damping_scale_set`` and applied to the model.
    Once more than ``change_step`` steps have elapsed in the episode, the
    scales are switched (once per episode) to ``changed_mass_scale`` /
    ``changed_damping_scale``.

    Args:
        mass_scale_set: Candidate multipliers for ``model.body_mass``.
        damping_scale_set: Candidate multipliers for ``model.dof_damping``.
        change_step: The change is applied on the first step for which the
            per-episode step counter exceeds this value.
        changed_mass_scale: Mass multiplier applied after the change.
        changed_damping_scale: Damping multiplier applied after the change.
    """

    # Mass scales used to normalise ``mass_scale`` in ``get_sim_parameters``.
    # NOTE(review): these differ from the default ``mass_scale_set``;
    # presumably they are the scales a model was trained on — confirm.
    _TRAINING_MASS_SET = (0.25, 0.5, 1.5, 2.5)

    def __init__(
        self,
        mass_scale_set=(0.75, 1.0, 1.25),
        damping_scale_set=(0.75, 1.0, 1.25),
        change_step=100,
        changed_mass_scale=0.75,
        changed_damping_scale=0.5,
    ):
        self.prev_qpos = None
        # gym's MujocoEnv.__init__ calls step() once (to build the observation
        # space) before the attributes below exist. Keep the dynamics-change
        # schedule disabled until construction has finished.
        self._schedule_ready = False
        dir_path = os.path.dirname(os.path.realpath(__file__))
        mujoco_env.MujocoEnv.__init__(
            self, os.path.join(dir_path, "assets", "half_cheetah.xml"), 5
        )

        self.original_mass = np.copy(self.model.body_mass)
        self.original_damping = np.copy(self.model.dof_damping)
        # The freshly loaded model corresponds to unit scales until
        # reset_model() samples new ones.
        self.mass_scale = 1.0
        self.damping_scale = 1.0

        self.mass_scale_set = list(mass_scale_set)
        self.damping_scale_set = list(damping_scale_set)

        self.change_step = change_step
        self.changed_mass_scale = changed_mass_scale
        self.changed_damping_scale = changed_damping_scale
        self.steps = 0
        self.changed = False
        self._schedule_ready = True

        # Record every constructor argument so that pickling/copying rebuilds
        # the env with the same configuration.
        utils.EzPickle.__init__(
            self,
            self.mass_scale_set,
            self.damping_scale_set,
            change_step,
            changed_mass_scale,
            changed_damping_scale,
        )

    def _set_observation_space(self, observation):
        """Build the observation space from the *full* observation.

        The ``observation`` argument is ignored; ``_get_obs(full=True)`` is used
        instead. Also records the width of the preprocessed observation in
        ``proc_observation_space_dims``.

        NOTE(review): ``step`` returns ``_get_obs()`` (no leading element),
        which is one element shorter than the full observation used here —
        confirm callers expect that.
        """
        observation = self._get_obs(full=True)
        super(HalfCheetahEnvChange, self)._set_observation_space(observation)
        proc_observation = self.obs_preproc(observation[None])
        self.proc_observation_space_dims = proc_observation.shape[-1]

    def step(self, action):
        """Advance the simulation by one control step.

        Returns:
            ``(obs, reward, done, info)``. The reward is the change in
            ``qpos[0]`` divided by ``dt``, minus ``0.1 * sum(action ** 2)``.
            ``done`` is always False and ``info`` is empty.
        """
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        self.do_simulation(action, self.frame_skip)
        ob = self._get_obs()

        reward_ctrl = -0.1 * np.square(action).sum()
        reward_run = (self.sim.data.qpos.flat[0] - self.prev_qpos[0]) / self.dt
        reward = reward_run + reward_ctrl

        done = False

        if self._schedule_ready:
            self._advance_schedule()

        return ob, reward, done, {}

    def _advance_schedule(self):
        """Count one step and apply the one-time dynamics change when it is due."""
        self.steps += 1
        if self.steps > self.change_step and not self.changed:
            print('Changing')
            self.mass_scale = self.changed_mass_scale
            self.damping_scale = self.changed_damping_scale
            self.change_env()
            self.changed = True

    def seed(self, seed=None):
        """Seed the env, remembering the seed in ``_seed`` (0 when None).

        Returns whatever the parent ``seed`` returns (gym's MujocoEnv returns
        a list of the seeds used).
        """
        self._seed = 0 if seed is None else seed
        return super().seed(seed)

    def _get_obs(self, full=False):
        """Return ``qpos[1:]`` followed by ``qvel``.

        With ``full=True``, the result is prefixed with
        ``(qpos[0] - prev_qpos[0]) / dt``, the finite-difference rate of the
        first position coordinate over the last step.
        """
        if full:
            return np.concatenate(
                [
                    (self.sim.data.qpos.flat[:1] - self.prev_qpos[:1]) / self.dt,
                    self.sim.data.qpos.flat[1:],
                    self.sim.data.qvel.flat,
                ]
            )
        return np.concatenate(
            [
                self.sim.data.qpos.flat[1:],
                self.sim.data.qvel.flat,
            ]
        )

    def obs_preproc(self, obs):
        """Map full observations to model inputs along the last axis.

        Drops element 0, keeps element 1, and replaces element 2 with its sine
        and cosine (presumably because it is an angle). Elements 3 onward are
        kept unchanged.

        Raises:
            NotImplementedError: if ``obs`` is not a NumPy array.
        """
        if not isinstance(obs, np.ndarray):
            raise NotImplementedError
        return np.concatenate(
            [
                obs[..., 1:2],
                np.sin(obs[..., 2:3]),
                np.cos(obs[..., 2:3]),
                obs[..., 3:],
            ],
            axis=-1,
        )

    def obs_postproc(self, obs, pred):
        """Combine an observation with a model prediction.

        Element 0 is taken directly from ``pred``. The remaining elements are
        ``obs + pred``, i.e. ``pred`` holds deltas for them.

        Raises:
            NotImplementedError: if ``obs`` is not a NumPy array.
        """
        if not isinstance(obs, np.ndarray):
            raise NotImplementedError
        return np.concatenate(
            [pred[..., :1], obs[..., 1:] + pred[..., 1:]], axis=-1
        )

    def targ_proc(self, obs, next_obs):
        """Build training targets: ``next_obs[0]`` directly, then deltas for the rest.

        This is the inverse of ``obs_postproc``.
        """
        return np.concatenate(
            [next_obs[..., :1], next_obs[..., 1:] - obs[..., 1:]], axis=-1
        )

    def reset_model(self):
        """Reset the state with small Gaussian noise and resample the dynamics.

        Draws new mass/damping scales from the configured sets, applies them,
        and restarts the per-episode change schedule.
        """
        qpos = self.init_qpos + self.np_random.normal(
            loc=0, scale=0.001, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.normal(
            loc=0, scale=0.001, size=self.model.nv
        )
        self.set_state(qpos, qvel)
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)

        random_index = self.np_random.randint(len(self.mass_scale_set))
        self.mass_scale = self.mass_scale_set[random_index]

        random_index = self.np_random.randint(len(self.damping_scale_set))
        self.damping_scale = self.damping_scale_set[random_index]

        self.change_env()
        self.steps = 0
        self.changed = False
        return self._get_obs()

    def change_env(self):
        """Write ``original * scale`` into the model's body masses and DOF dampings."""
        self.model.body_mass[:] = self.original_mass * self.mass_scale
        self.model.dof_damping[:] = self.original_damping * self.damping_scale

    def change_mass(self, mass):
        """Store a new mass scale.

        The model is not updated until ``change_env`` runs. Note that
        ``reset_model`` overwrites this value with a sampled scale.
        """
        self.mass_scale = mass

    def change_damping(self, damping):
        """Store a new damping scale.

        The model is not updated until ``change_env`` runs. Note that
        ``reset_model`` overwrites this value with a sampled scale.
        """
        self.damping_scale = damping

    def viewer_setup(self):
        """Position the render camera relative to the model's extent."""
        self.viewer.cam.distance = self.model.stat.extent * 0.25
        self.viewer.cam.elevation = -55

    def get_sim_parameters(self):
        """Return ``[normalised mass scale, damping scale]``.

        The mass scale is z-scored against ``_TRAINING_MASS_SET``. The damping
        scale is returned unnormalised.
        """
        mass_avg = np.mean(self._TRAINING_MASS_SET)
        mass_std = np.std(self._TRAINING_MASS_SET)
        normalized_mass = (self.mass_scale - mass_avg) / mass_std
        return np.array([normalized_mass, self.damping_scale])

    def num_modifiable_parameters(self):
        """Return the number of varied dynamics parameters (mass and damping)."""
        return 2

    def log_diagnostics(self, paths, prefix):
        """No-op diagnostics hook."""
        return
