import numpy as np
from gym.envs.mujoco import Walker2dEnv
from . import register_env

@register_env('walker-rand-params')
class Walker2DRandParamsEnv(Walker2dEnv):
    """
    Walker2d environment whose physical parameters are randomized per task.

    A task is a dict mapping a mujoco model attribute name to a new array of the
    same shape. The parameters that can be randomized (selected via
    ``rand_params``) are:
        - body_mass
        - body_inertia
        - dof_damping (damping coefficient at the joints)
        - geom_friction
    Every element is scaled independently by ``base ** u`` with
    ``u ~ Uniform(-log_scale_limit, log_scale_limit)``.

    Falling does not terminate an episode: ``step`` reports ``done`` only once
    ``max_episode_steps`` steps have been taken since the last reset.
    """
    RAND_PARAMS = ['body_mass', 'dof_damping', 'body_inertia', 'geom_friction']
    # NOTE: 'geom_size' is listed here but is neither saved nor randomized by
    # save_parameters / sample_tasks.
    RAND_PARAMS_EXTENDED = RAND_PARAMS + ['geom_size']

    # (model attribute, base of the log-uniform multiplier). The tuple order is
    # the order in which the global numpy RNG is consumed in sample_tasks; keep
    # it stable so seeded runs reproduce the same task sets.
    _PARAM_LOG_BASES = (
        ('body_mass', 1.5),
        ('body_inertia', 1.5),
        ('dof_damping', 1.3),
        ('geom_friction', 1.5),
    )

    def __init__(self, n_tasks=2, randomize_tasks=True, max_episode_steps=200, log_scale_limit=3.0, rand_params=RAND_PARAMS):
        """
        Args:
            n_tasks (int): number of parameter sets sampled at construction.
            randomize_tasks (bool): must be True; non-random tasks are not implemented.
            max_episode_steps (int): number of steps after which ``step`` returns done=True.
            log_scale_limit (float): half-width of the uniform exponent range.
            rand_params (list of str): names of model attributes to randomize.
        """
        # Set before super().__init__(), presumably because the parent
        # constructor may already call step() (which reads these) -- confirm
        # against the gym version in use.
        self.randomize_tasks = randomize_tasks
        self._max_episode_steps = max_episode_steps
        self._step = 0

        super().__init__()
        self.log_scale_limit = log_scale_limit
        self.rand_params = rand_params
        self.save_parameters()

        self.tasks = self.sample_tasks(n_tasks)
        self.reset_task(0)

    def step(self, a):
        """
        Advance the simulation by one control step.

        Reward = forward velocity + alive bonus (1.0) - 1e-3 * sum(a**2).

        Args:
            a: action passed to ``do_simulation``.

        Returns:
            (observation, reward, done, info): ``done`` is True only once
            ``_max_episode_steps`` steps have been taken since the last
            ``reset_model``; ``info`` is an empty dict.
        """
        posbefore = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        posafter = self.sim.data.qpos[0]
        alive_bonus = 1.0
        reward = ((posafter - posbefore) / self.dt)
        reward += alive_bonus
        reward -= 1e-3 * np.square(a).sum()
        ob = self._get_obs()

        # Height/angle-based termination is intentionally disabled: only the
        # time limit ends an episode.
        self._step += 1
        done = bool(self._step >= self._max_episode_steps)
        return ob, reward, done, {}

    def _get_obs(self):
        """Return qpos without its first entry, concatenated with qvel clipped to [-10, 10]."""
        qpos = self.sim.data.qpos
        qvel = self.sim.data.qvel
        return np.concatenate([qpos[1:], np.clip(qvel, -10, 10)]).ravel()

    def reset_model(self):
        """
        Reset the joint state to the initial pose plus small uniform noise and
        restart the episode step counter.

        Returns:
            The initial observation.
        """
        # Without this reset, _step keeps growing across episodes and every
        # step after the first max_episode_steps would report done=True.
        self._step = 0
        self.set_state(
            self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq),
            self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
        )
        return self._get_obs()

    def viewer_setup(self):
        """Configure the render camera (tracked body, distance, look-at height, elevation)."""
        self.viewer.cam.trackbodyid = 2
        # The model statistics live on the model (as in gym's own mujoco envs).
        self.viewer.cam.distance = self.model.stat.extent * 0.5
        self.viewer.cam.lookat[2] += .8
        self.viewer.cam.elevation = -20

    def sample_tasks(self, n_tasks):
        """
        Generates randomized parameter sets for the mujoco env.

        Each parameter listed in ``rand_params`` is multiplied element-wise by
        ``base ** u`` with ``u ~ Uniform(-log_scale_limit, log_scale_limit)``,
        relative to the values captured by ``save_parameters``.

        Args:
            n_tasks (int) : number of different meta-tasks needed

        Returns:
            tasks (list) : an (n_tasks) length list of dicts mapping
                parameter name -> new parameter array

        Raises:
            NotImplementedError: if ``randomize_tasks`` is False.
        """
        if not self.randomize_tasks:
            raise NotImplementedError('only randomized tasks (randomize_tasks=True) are supported')

        param_sets = []
        for _ in range(n_tasks):
            new_params = {}
            for name, base in self._PARAM_LOG_BASES:
                if name not in self.rand_params:
                    continue
                init_value = self.init_params[name]
                # NOTE: draws from the global numpy RNG, not self.np_random.
                multipliers = np.array(base) ** np.random.uniform(
                    -self.log_scale_limit, self.log_scale_limit, size=init_value.shape)
                new_params[name] = init_value * multipliers
            param_sets.append(new_params)

        return param_sets

    def set_task(self, task):
        """
        Write every parameter array in ``task`` into the mujoco model in place.

        Args:
            task (dict): parameter name -> array with the same shape as the
                corresponding ``self.model`` attribute.
        """
        for param, param_val in task.items():
            param_variable = getattr(self.model, param)
            assert param_variable.shape == param_val.shape, 'shapes of new parameter value and old one must match'
            param_variable[:] = param_val
        self.cur_params = task

    def get_task(self):
        """Return the parameter dict most recently applied (initially the saved defaults)."""
        return self.cur_params

    def save_parameters(self):
        """
        Snapshot the model's current values of the parameters in ``rand_params``.

        The arrays are copied: attributes read from ``self.model`` may be live
        views into simulator memory, and ``set_task`` overwrites them in place.
        Without the copy, the saved "initial" values would drift to the last
        applied task and later ``sample_tasks`` calls would compound multipliers.
        """
        self.init_params = {
            name: np.array(getattr(self.model, name), copy=True)
            for name, _ in self._PARAM_LOG_BASES
            if name in self.rand_params
        }
        self.cur_params = self.init_params

    def get_all_task_idx(self):
        """Return the indices of all sampled tasks."""
        return range(len(self.tasks))

    def reset_task(self, idx):
        """
        Apply task ``idx`` from ``self.tasks`` to the model and reset the env.

        Args:
            idx (int): index into ``self.tasks``.
        """
        self._goal_idx = idx
        self._task = self.tasks[idx]
        self._goal = idx
        self.set_task(self._task)
        self.reset()


@register_env('sparse-walker-rand-params')
class SparseWalkerRandParamsWrappedEnv(Walker2DRandParamsEnv):
    """
    Walker2DRandParamsEnv that also reports a ``sparse_reward`` entry in the
    step info dict.

    ``step`` still returns the dense reward. ``sparsify_rewards`` is currently
    a pass-through, so ``info['sparse_reward']`` equals the dense reward and
    ``goal_radius`` is stored but not used.
    """
    def __init__(self, n_tasks=2, randomize_tasks=True, max_episode_steps=200, goal_radius=0.5, **kwargs):
        """
        Args:
            n_tasks (int): number of parameter sets sampled at construction.
            randomize_tasks (bool): forwarded to the parent env.
            max_episode_steps (int): forwarded to the parent env.
            goal_radius (float): reward threshold reserved for sparsification
                (currently unused).
            **kwargs: extra keyword arguments forwarded to
                Walker2DRandParamsEnv (e.g. ``log_scale_limit``, ``rand_params``).
        """
        # Set before the parent constructor, which itself calls reset_task and
        # may step the env.
        self.goal_radius = goal_radius
        super(SparseWalkerRandParamsWrappedEnv, self).__init__(n_tasks, randomize_tasks, max_episode_steps, **kwargs)

    def step(self, action):
        """
        Step the parent env and add ``'sparse_reward'`` to the info dict.

        Returns:
            (observation, dense reward, done, info)
        """
        ob, reward, done, info = super().step(action)
        info['sparse_reward'] = self.sparsify_rewards(reward)
        return ob, reward, done, info

    def sparsify_rewards(self, r):
        """
        Return the sparse version of reward ``r``.

        Currently a pass-through: ``r`` is returned unchanged.
        """
        # TODO(review): goal-radius masking (keep r only when r >= goal_radius)
        # is disabled; decide whether it should be restored.
        return r

if __name__ == "__main__":
    # Demo loop: repeatedly apply a randomly chosen parameter set, print the
    # resulting body masses, and roll out 100 random actions. Runs forever.
    demo_env = SparseWalkerRandParamsWrappedEnv()
    task_pool = demo_env.sample_tasks(40)
    while True:
        demo_env.reset()
        chosen_task = np.random.choice(task_pool)
        demo_env.set_task(chosen_task)
        print(demo_env.model.body_mass)
        for _step_idx in range(100):
            # Uncomment to visualize: demo_env.render()
            random_action = demo_env.action_space.sample()
            demo_env.step(random_action)