from copy import deepcopy
from typing import Dict, List, Optional

import numpy as np
from gym import Wrapper
from gym.envs.mujoco.mujoco_env import MujocoEnv


class MujocoDynamicsWrapper(Wrapper):
    """Apply a fixed dynamics shift to a MuJoCo environment (after Pinto et al [1]).

    Every entry of the wrapped model's ``geom_friction`` array is multiplied by
    ``friction_coef`` and every entry of ``body_mass`` by ``mass_coef``. The
    scaling happens once, in the constructor, and mutates the wrapped
    environment's model in place.

    [1] The full citation is not given here; presumably Pinto et al.,
        "Robust Adversarial Reinforcement Learning" (ICML 2017) -- TODO confirm.

    Args:
        env (MujocoEnv): `Mujoco` Environment from `gym`.
        friction_coef (float, optional): Factor applied to the friction of
            every geom. Defaults to 1.0.
        mass_coef (float, optional): Factor applied to the mass of every body.
            Defaults to 1.0.
    """

    def __init__(
        self, env: MujocoEnv, friction_coef: float = 1.0, mass_coef: float = 1.0
    ) -> None:
        super().__init__(env)
        model = self.env.model
        # mujoco_py refuses to rebind these attributes ("attribute
        # 'geom_friction' of 'mujoco_py.cymj.PyMjModel' objects is not
        # writable"), so each array is scaled in place through a full slice.
        scalings = (
            (model.geom_friction, friction_coef),
            (model.body_mass, mass_coef),
        )
        for values, factor in scalings:
            values[:] *= factor
        self.friction_coef = friction_coef
        self.mass_coef = mass_coef


class DomainRandomizationMujocoWrapper(Wrapper):
    """Resample the friction and mass scaling of a MuJoCo model on every reset.

    On each ``reset`` the model's ``geom_friction`` and ``body_mass`` arrays are
    restored to the values captured at construction time and then scaled by one
    coefficient each, drawn uniformly from the bounds in ``params_bound``. All
    geoms share the same friction coefficient and all bodies the same mass
    coefficient.

    Args:
        env (MujocoEnv): `Mujoco` Environment from `gym`.
        params_bound (Dict[str, List[float]]): Sampling bounds. Must contain the
            keys ``"friction_coef"`` and ``"mass_coef"``, each mapped to a
            sequence whose first two entries are ``low`` and ``high``, with
            ``low <= high``.
        rng (optional): Source of randomness exposing ``uniform(low=, high=)``,
            e.g. ``np.random.default_rng(seed)`` or a ``np.random.RandomState``.
            Defaults to the global ``np.random`` module (previous behaviour).

    Attributes:
        friction_coef (float): Friction coefficient applied at the last
            randomization (1.0 before the first one).
        mass_coef (float): Mass coefficient applied at the last randomization
            (1.0 before the first one).

    Raises:
        KeyError: If a required key is missing from ``params_bound``.
        ValueError: If a bound pair has ``low > high``.
    """

    _REQUIRED_KEYS = ("friction_coef", "mass_coef")

    def __init__(
        self,
        env: MujocoEnv,
        params_bound: Dict[str, List[float]],
        rng: Optional[np.random.Generator] = None,
    ):
        super().__init__(env)
        self._validate_bounds(params_bound)
        self.params_bound = params_bound
        self._rng = np.random if rng is None else rng
        # Independent copies of the nominal parameters: the model arrays are
        # overwritten in place on every reset.
        self.base_friction = deepcopy(env.model.geom_friction)
        self.base_mass = deepcopy(env.model.body_mass)
        # The model is still at its nominal parameters until the first reset.
        self.friction_coef = 1.0
        self.mass_coef = 1.0

    @classmethod
    def _validate_bounds(cls, params_bound: Dict[str, List[float]]) -> None:
        """Fail fast on missing keys or inverted ``[low, high]`` bounds."""
        for key in cls._REQUIRED_KEYS:
            if key not in params_bound:
                raise KeyError(f"params_bound is missing required key {key!r}")
            low, high = params_bound[key][0], params_bound[key][1]
            # numpy documents uniform() with high < low as undefined behaviour.
            if low > high:
                raise ValueError(
                    f"params_bound[{key!r}] must satisfy low <= high, "
                    f"got [{low}, {high}]"
                )

    def reset(self, **kwargs):
        """Randomize the dynamics, then reset the wrapped environment.

        Keyword arguments (e.g. ``seed``/``options`` on newer gym versions) are
        forwarded unchanged to the wrapped environment's ``reset``. A ``seed``
        passed here does not reseed the coefficient sampler; use ``rng`` for
        that.
        """
        self.domain_randomization_mujoco()
        return self.env.reset(**kwargs)

    def domain_randomization_mujoco(self):
        """Restore the nominal parameters and apply freshly sampled coefficients.

        Friction is sampled before mass, so the draw order from the RNG matches
        earlier versions of this wrapper.
        """
        bounds = self.params_bound
        self.friction_coef = self._rng.uniform(
            low=bounds["friction_coef"][0],
            high=bounds["friction_coef"][1],
        )
        self.mass_coef = self._rng.uniform(
            low=bounds["mass_coef"][0],
            high=bounds["mass_coef"][1],
        )
        model = self.env.model
        # Write through a full slice: mujoco_py does not allow rebinding these
        # attributes, only writing into the existing arrays. Scaling the saved
        # base values avoids compounding coefficients across resets.
        model.geom_friction[:] = self.base_friction * self.friction_coef
        model.body_mass[:] = self.base_mass * self.mass_coef
