import brax
from brax.envs.ant import _SYSTEM_CONFIG as ANT_CONFIG
from brax.envs.ant import Ant
from brax.envs.halfcheetah import _SYSTEM_CONFIG as HALFCHEETAH_CONFIG
from brax.envs.halfcheetah import Halfcheetah
from brax.envs.hopper import _SYSTEM_CONFIG as HOPPER_CONFIG
from brax.envs.hopper import Hopper
from brax.envs.humanoid import _SYSTEM_CONFIG as HUMANOID_CONFIG
from brax.envs.humanoid import Humanoid
from brax.envs.inverted_pendulum import _SYSTEM_CONFIG as INVERTED_CONFIG
from brax.envs.inverted_pendulum import InvertedPendulum
from brax.envs.walker2d import _SYSTEM_CONFIG as WALKER_CONFIG
from brax.envs.walker2d import Walker2d
from brax.physics import config_pb2
from google.protobuf import text_format


def _change_config(
    config: config_pb2.Config, friction_coef: float = 1.0, mass_coef: float = 1.0
) -> config_pb2.Config:
    """Scale the global friction and every body mass of a `Brax` config.

    The config is modified in place; no copy is made.

    Args:
        config: `Brax` config to modify.
        friction_coef: Factor that `config.friction` is multiplied by.
            Defaults to 1.0.
        mass_coef: Factor that the mass of each entry in `config.bodies` is
            multiplied by. Defaults to 1.0.

    Returns:
        The same `config` object that was passed in, after scaling.
    """
    config.friction *= friction_coef

    # Iterating the repeated field yields the body messages themselves, so
    # updating `mass` here changes the config directly.
    for body in config.bodies:
        body.mass *= mass_coef

    return config


class BraxDynamicsAnt(Ant):
    """`Ant` whose system is built from a config with rescaled friction and mass.

    NOTE(review): the parent `__init__` is not invoked here and `**kwargs` is
    accepted but never used -- confirm the base class needs neither.
    """

    def __init__(self, friction_coef: float = 1.0, mass_coef: float = 1.0, **kwargs):
        # Parse the Ant config text into a fresh message for this instance.
        base_config = brax.Config()
        text_format.Parse(ANT_CONFIG, base_config)

        # Apply the multipliers, then build the physics system from the result.
        scaled_config = _change_config(base_config, friction_coef, mass_coef)
        self.sys = brax.System(scaled_config)


class BraxDynamicsHumanoid(Humanoid):
    """`Humanoid` whose system is built from a config with rescaled friction and mass.

    NOTE(review): the parent `__init__` is not invoked here and `**kwargs` is
    accepted but never used -- confirm the base class needs neither.
    """

    def __init__(self, friction_coef: float = 1.0, mass_coef: float = 1.0, **kwargs):
        # Parse the Humanoid config text into a fresh message for this instance.
        base_config = brax.Config()
        text_format.Parse(HUMANOID_CONFIG, base_config)

        # Apply the multipliers, then build the physics system from the result.
        scaled_config = _change_config(base_config, friction_coef, mass_coef)
        self.sys = brax.System(scaled_config)


class BraxDynamicsHalfcheetah(Halfcheetah):
    """`Halfcheetah` whose system is built from a config with rescaled friction and mass.

    NOTE(review): the parent `__init__` is not invoked here and `**kwargs` is
    accepted but never used -- confirm the base class needs neither.
    """

    def __init__(self, friction_coef: float = 1.0, mass_coef: float = 1.0, **kwargs):
        # Parse the Halfcheetah config text into a fresh message for this instance.
        base_config = brax.Config()
        text_format.Parse(HALFCHEETAH_CONFIG, base_config)

        # Apply the multipliers, then build the physics system from the result.
        scaled_config = _change_config(base_config, friction_coef, mass_coef)
        self.sys = brax.System(scaled_config)


class BraxDynamicsHopper(Hopper):
    """`Hopper` whose system is built from a config with rescaled friction and mass.

    NOTE(review): the parent `__init__` is not invoked here and `**kwargs` is
    accepted but never used -- confirm the base class needs neither.
    """

    def __init__(self, friction_coef: float = 1.0, mass_coef: float = 1.0, **kwargs):
        # Parse the Hopper config text into a fresh message for this instance.
        base_config = brax.Config()
        text_format.Parse(HOPPER_CONFIG, base_config)

        # Apply the multipliers, then build the physics system from the result.
        scaled_config = _change_config(base_config, friction_coef, mass_coef)
        self.sys = brax.System(scaled_config)


class BraxDynamicsWalker(Walker2d):
    """`Walker2d` whose system is built from a config with rescaled friction and mass.

    NOTE(review): the parent `__init__` is not invoked here and `**kwargs` is
    accepted but never used -- confirm the base class needs neither.
    """

    def __init__(self, friction_coef: float = 1.0, mass_coef: float = 1.0, **kwargs):
        # Parse the Walker2d config text into a fresh message for this instance.
        base_config = brax.Config()
        text_format.Parse(WALKER_CONFIG, base_config)

        # Apply the multipliers, then build the physics system from the result.
        scaled_config = _change_config(base_config, friction_coef, mass_coef)
        self.sys = brax.System(scaled_config)


class BraxDynamicsInvertedPendulum(InvertedPendulum):
    """`InvertedPendulum` whose system is built from a config with rescaled friction and mass.

    NOTE(review): the parent `__init__` is not invoked here and `**kwargs` is
    accepted but never used -- confirm the base class needs neither.
    """

    def __init__(self, friction_coef: float = 1.0, mass_coef: float = 1.0, **kwargs):
        # Parse the InvertedPendulum config text into a fresh message for this instance.
        base_config = brax.Config()
        text_format.Parse(INVERTED_CONFIG, base_config)

        # Apply the multipliers, then build the physics system from the result.
        scaled_config = _change_config(base_config, friction_coef, mass_coef)
        self.sys = brax.System(scaled_config)
