import re

import gym
import numpy as np
from dm_control.utils import rewards

# Torso-height lower bound used for the walker "stand" component.
_WALKER_STAND_HEIGHT = 1.2
# Speeds passed as ``move_speed`` for the walker_walk / walker_run tasks.
_WALKER_WALK_SPEED = 1
_WALKER_RUN_SPEED = 8
# Speed lower bound (and margin) used for the cheetah_run "speed" component.
_CHEETAH_RUN_SPEED = 10
# Speeds passed as ``desired_speed`` for the quadruped_run / quadruped_walk tasks.
_QUADRUPED_RUN_SPEED = 5
_QUADRUPED_WALK_SPEED = 0.5
# Head-height lower bound used for the humanoid "stand" component.
_HUMANOID_STAND_HEIGHT = 1.4

# Environment ids must fully match ``dmc_<env_name>_<digits>-v<digits>``; the
# named group ``env_name`` is the key looked up in the dispatch tables below.
dmc_id_regex = r"dmc_(?P<env_name>[\w_]+)_[0-9]+-v[0-9]+"


def reward_components_walker(env, move_speed):
    """Compute the reward components for the walker tasks.

    Args:
        env: Environment exposing ``physics`` with ``torso_height()``,
            ``torso_upright()`` and ``horizontal_velocity()``; it is also
            passed to ``ground_truth_func`` to pick the total-reward formula.
        move_speed: Lower bound on horizontal velocity for the "move"
            component; half of it is used as the tolerance margin.

    Returns:
        Dict with keys ``"move"``, ``"stand"`` and ``"reward"``.
    """
    physics = env.physics

    # "stand" mixes a height tolerance term (weight 3) with an uprightness
    # term rescaled as (1 + upright) / 2 (weight 1).
    height_term = rewards.tolerance(
        physics.torso_height(),
        bounds=(_WALKER_STAND_HEIGHT, float("inf")),
        margin=_WALKER_STAND_HEIGHT / 2,
    )
    upright_term = (1 + physics.torso_upright()) / 2
    stand_term = (3 * height_term + upright_term) / 4

    # "move" is a linear tolerance on horizontal velocity.
    move_term = rewards.tolerance(
        physics.horizontal_velocity(),
        bounds=(move_speed, float("inf")),
        margin=move_speed / 2,
        value_at_margin=0.5,
        sigmoid="linear",
    )

    parts = {"move": move_term, "stand": stand_term}
    # The total is derived from the individual components by the
    # environment-specific ground-truth formula.
    total = ground_truth_func(env)(parts)
    parts["reward"] = total
    return parts


def reward_components_cheetah_run(env):
    """Compute the reward components for the cheetah_run task.

    Args:
        env: Environment exposing ``physics`` with ``speed()`` and
            ``control()``; it is also passed to ``ground_truth_func``.

    Returns:
        Dict with keys ``"speed"``, ``"ctrl"`` and ``"reward"``.
    """
    physics = env.physics

    speed_term = rewards.tolerance(
        physics.speed(),
        bounds=(_CHEETAH_RUN_SPEED, float("inf")),
        margin=_CHEETAH_RUN_SPEED,
        value_at_margin=0,
        sigmoid="linear",
    )

    # The cheetah task has a single real reward component. A control-cost
    # pseudoreward (negative sum of squared controls) is added as a second
    # axis so that trajectories can be plotted on a 2d plane.
    controls = physics.control()
    ctrl_term = -np.sum(controls ** 2)

    parts = {"speed": speed_term, "ctrl": ctrl_term}
    total = ground_truth_func(env)(parts)
    parts["reward"] = total
    return parts


def reward_components_quadruped(env, desired_speed):
    """Compute the reward components for the quadruped tasks.

    Args:
        env: Environment exposing ``physics`` with ``torso_upright()`` and
            ``torso_velocity()``; it is also passed to ``ground_truth_func``.
        desired_speed: Lower bound (and tolerance margin) for the first
            element of the torso velocity.

    Returns:
        Dict with keys ``"move"``, ``"upright"`` and ``"reward"``.
    """
    physics = env.physics
    upright_bound = 1.0

    # "upright": linear tolerance on torso uprightness with lower bound
    # ``upright_bound`` and a margin of ``1 + upright_bound``.
    upright_term = rewards.tolerance(
        physics.torso_upright(),
        bounds=(upright_bound, float("inf")),
        sigmoid="linear",
        margin=1 + upright_bound,
        value_at_margin=0,
    )

    # "move": linear tolerance on element 0 of the torso velocity.
    forward_velocity = physics.torso_velocity()[0]
    move_term = rewards.tolerance(
        forward_velocity,
        bounds=(desired_speed, float("inf")),
        margin=desired_speed,
        value_at_margin=0.5,
        sigmoid="linear",
    )

    parts = {"move": move_term, "upright": upright_term}
    total = ground_truth_func(env)(parts)
    parts["reward"] = total
    return parts


def reward_components_humanoid_stand(env):
    """Compute the reward components for the humanoid_stand task.

    Args:
        env: Environment exposing ``physics`` with ``head_height()``,
            ``torso_upright()``, ``control()`` and
            ``center_of_mass_velocity()``; it is also passed to
            ``ground_truth_func`` to pick the total-reward formula.

    Returns:
        Dict with keys ``"small_control"``, ``"stand"``, ``"dont_move"`` and
        ``"reward"``.
    """
    standing = rewards.tolerance(
        env.physics.head_height(),
        bounds=(_HUMANOID_STAND_HEIGHT, float("inf")),
        margin=_HUMANOID_STAND_HEIGHT / 4,
    )

    upright = rewards.tolerance(
        env.physics.torso_upright(),
        bounds=(0.9, float("inf")),
        sigmoid="linear",
        margin=1.9,
        value_at_margin=0,
    )

    # "stand" is the product of the head-height and uprightness terms.
    stand_reward = standing * upright

    # Mean per-actuator control tolerance, rescaled as (4 + x) / 5.
    small_control = rewards.tolerance(
        env.physics.control(), margin=1, value_at_margin=0, sigmoid="quadratic"
    ).mean()  # type: ignore
    small_control = (4 + small_control) / 5

    # Mean tolerance over elements 0 and 1 of the centre-of-mass velocity.
    horizontal_velocity = env.physics.center_of_mass_velocity()[[0, 1]]
    dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()  # type: ignore

    components = {
        "small_control": small_control,
        "stand": stand_reward,
        "dont_move": dont_move,
    }

    # The total comes solely from the ground-truth formula for this env, as in
    # the other reward_components_* functions; it is not precomputed inline,
    # so there is a single definition of the total and nothing to drift.
    components["reward"] = ground_truth_func(env)(components)
    return components


def _walker_ground_truth(d):
    """Walker total: ``stand`` scaled by ``(5 * move + 1) / 6``."""
    return d["stand"] * (5 * d["move"] + 1) / 6


def _cheetah_ground_truth(d):
    """Cheetah total: the ``speed`` component alone."""
    return d["speed"]


def _quadruped_ground_truth(d):
    """Quadruped total: product of ``upright`` and ``move``."""
    return d["upright"] * d["move"]


def _humanoid_stand_ground_truth(d):
    """Humanoid-stand total: product of all three components."""
    return d["small_control"] * d["stand"] * d["dont_move"]


# Maps env name to a function that turns a components dict into the total
# reward. Tasks that share a formula share the same function.
ground_truth_rewards_dispatch = {
    "walker_walk": _walker_ground_truth,
    "walker_run": _walker_ground_truth,
    "cheetah_run": _cheetah_ground_truth,
    "quadruped_walk": _quadruped_ground_truth,
    "quadruped_run": _quadruped_ground_truth,
    "humanoid_stand": _humanoid_stand_ground_truth,
}

# Maps env name (the ``env_name`` group of ``dmc_id_regex``) to a tuple of
# ``(component_function, *extra_positional_args)``. ``reward_components``
# calls the function as ``component_function(env, *extra_positional_args)``.
reward_component_dispatch = {
    "walker_walk": (reward_components_walker, _WALKER_WALK_SPEED),
    "walker_run": (reward_components_walker, _WALKER_RUN_SPEED),
    "cheetah_run": (reward_components_cheetah_run,),
    "quadruped_walk": (reward_components_quadruped, _QUADRUPED_WALK_SPEED),
    "quadruped_run": (reward_components_quadruped, _QUADRUPED_RUN_SPEED),
    "humanoid_stand": (reward_components_humanoid_stand,),
}


def ground_truth_func(env):
    """Return the ground-truth total-reward function for ``env``.

    The environment name is extracted from ``env.spec.id`` using
    ``dmc_id_regex`` and looked up in ``ground_truth_rewards_dispatch``.

    Args:
        env: Environment whose ``spec.id`` fully matches ``dmc_id_regex``.

    Returns:
        A callable that maps a reward-components dict to the total reward.

    Raises:
        AssertionError: If ``env.spec.id`` does not match ``dmc_id_regex``.
        ValueError: If the extracted env name has no ground-truth formula.
    """
    match = re.fullmatch(dmc_id_regex, env.spec.id)
    assert match is not None

    env_name = match.group("env_name")

    # Validate against the table that is actually indexed below, so a name
    # missing from it yields the documented ValueError rather than a KeyError.
    if env_name not in ground_truth_rewards_dispatch:
        raise ValueError(f"Unsupported env {env_name}")

    return ground_truth_rewards_dispatch[env_name]


def reward_components(env):
    """Compute the reward components for ``env`` via the dispatch table.

    Args:
        env: Environment whose ``spec.id`` fully matches ``dmc_id_regex``.

    Returns:
        Whatever the component function registered for the env name returns.

    Raises:
        AssertionError: If ``env.spec.id`` does not match ``dmc_id_regex``.
        ValueError: If the env name is not in ``reward_component_dispatch``.
    """
    parsed_id = re.fullmatch(dmc_id_regex, env.spec.id)
    assert parsed_id is not None

    env_name = parsed_id.group("env_name")
    if env_name not in reward_component_dispatch:
        raise ValueError(f"Unsupported env {env_name}")

    # Each entry is (function, *extra positional args for that function).
    entry = reward_component_dispatch[env_name]
    component_fn = entry[0]
    extra_args = entry[1:]
    return component_fn(env, *extra_args)


class RewardComponentsWrapper(gym.Wrapper):
    """Gym wrapper that reports the reward components in the info dict."""

    def __init__(self, env):
        """Wrap ``env`` after checking that its id names a supported task.

        Raises:
            AssertionError: If ``env.spec.id`` does not match ``dmc_id_regex``.
            ValueError: If the env name is not in ``reward_component_dispatch``.
        """
        parsed_id = re.fullmatch(dmc_id_regex, env.spec.id)
        assert parsed_id is not None

        env_name = parsed_id.group("env_name")
        if env_name not in reward_component_dispatch:
            raise ValueError(f"Unsupported env {env_name}")

        super().__init__(env)

        # Mirror the wrapped env's _max_episode_steps on this wrapper when
        # the wrapped env defines it.
        if hasattr(env, "_max_episode_steps"):
            self._max_episode_steps = self.env._max_episode_steps

    def step(self, action):
        """Step the wrapped env and add ``info["reward_components"]``."""
        observation, reward, episode_over, info = self.env.step(action)
        info["reward_components"] = reward_components(self.env)
        return observation, reward, episode_over, info
