import gymnasium as gym
import numpy as np


class DistractorWrapper(gym.Wrapper):
    """Append a task-irrelevant "distractor" vector to every observation.

    The distractor is drawn from ``N(0, distractor_scale**2)`` once per episode (on
    ``reset``) and stays fixed for every ``step`` of that episode.

    Assumes the wrapped env has a flat (1-D) ``Box`` observation space; the
    concatenation below only works for flat observations — TODO confirm for new envs.
    """

    def __init__(self, env: gym.Env, num_extra_dims: int, distractor_scale: float = 5.0):
        """
        Args:
            env: Environment whose observations are extended.
            num_extra_dims: Number of distractor dimensions appended to each observation.
            distractor_scale: Standard deviation of the Gaussian distractor. Defaults to
                5.0, the value that used to be hard-coded.
        """
        super().__init__(env)
        self.num_extra_dims = num_extra_dims
        self.distractor_scale = distractor_scale
        # Private RNG, created the first time ``reset`` receives a seed, so the
        # distractor is reproducible without consuming the wrapped env's own random
        # stream. While it is ``None`` the global NumPy RNG is used (legacy behaviour).
        self._distractor_rng = None

        base_space = self.env.observation_space
        unbounded = np.full(self.num_extra_dims, np.inf, dtype=np.float32)
        self.observation_space = gym.spaces.Box(
            low=np.concatenate([base_space.low, -unbounded]),
            high=np.concatenate([base_space.high, unbounded]),
            dtype=np.float32,
        )

    def _transform_obs(self, obs, distractor):
        """Return ``obs`` with ``distractor`` appended, cast to float32.

        The cast keeps observations consistent with the float32 ``observation_space``
        declared in ``__init__`` (otherwise a float64 part would upcast the result).
        """
        return np.concatenate([obs, distractor]).astype(np.float32, copy=False)

    def reset(self, **kwargs):
        """Reset the environment and sample a new distractor vector.

        If ``seed`` is passed, the distractor RNG is re-seeded from it. A child
        ``SeedSequence`` is used so the distractor stream is distinct from (and not
        correlated with) an env RNG seeded with the same integer.
        """
        seed = kwargs.get("seed")
        if seed is not None:
            child_seed = np.random.SeedSequence(seed).spawn(1)[0]
            self._distractor_rng = np.random.default_rng(child_seed)

        if self._distractor_rng is not None:
            sample = self._distractor_rng.standard_normal
        else:
            sample = np.random.standard_normal
        self.distractor = (self.distractor_scale * sample(self.num_extra_dims)).astype(
            np.float32
        )

        obs, info = self.env.reset(**kwargs)
        return self._transform_obs(obs, self.distractor), info

    def step(self, action):
        """Step the environment and append the current episode's distractor to the observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        obs = self._transform_obs(obs, self.distractor)

        return obs, reward, terminated, truncated, info


class FrictionScalingWrapper(gym.Wrapper):
    """
    Wrapper that scales the friction of all Box2D bodies in the environment.

    The scaling is applied after every ``reset`` to every fixture present in the
    unwrapped env's ``world`` at that moment. It multiplies the current friction, so it
    assumes the env recreates its bodies on reset. Otherwise the scale would compound
    across episodes — verify this for envs other than BipedalWalker.
    Envs without a ``world`` attribute (or with ``world=None``) are left untouched.
    """

    def __init__(self, env: gym.Env, friction_scale: float):
        """
        Args:
            env: A Box2D-based Gymnasium environment (e.g. BipedalWalker) whose
                unwrapped env exposes a ``world``.
            friction_scale (float): Factor to multiply each fixture's friction.
                                    For example, 0.5 will halve the friction.
                                    Must be non-negative.

        Raises:
            ValueError: If ``friction_scale`` is negative.
        """
        super().__init__(env)
        if friction_scale < 0:
            raise ValueError(f"friction_scale must be non-negative, got {friction_scale}")
        self.friction_scale = friction_scale

    def reset(self, **kwargs):
        """Reset the wrapped env, then scale the friction of every fixture in its world."""
        obs, info = self.env.reset(**kwargs)

        # getattr so non-Box2D envs (no ``world`` attribute) are skipped, not crashed on.
        world = getattr(self.env.unwrapped, "world", None)
        if world is not None:
            for body in world.bodies:
                for fixture in body.fixtures:
                    fixture.friction *= self.friction_scale
            # Contacts cache a mixed friction value; refresh it so contacts that already
            # exist pick up the new fixture frictions.
            for contact in world.contacts:
                contact.ResetFriction()
        return obs, info

    def step(self, action):
        """Step the wrapped env unchanged (friction is only modified on reset)."""
        return self.env.step(action)


class MassScalingWrapper(gym.Wrapper):
    """
    A wrapper that scales the mass of the robot's components (the hull and legs) in
    the BipedalWalker environment. It does so by scaling each fixture's density and then
    calling ResetMassData() on the affected bodies.

    The scaling is applied after every ``reset`` and multiplies the current density. It
    therefore assumes the env recreates the hull and leg bodies on reset; otherwise the
    scale would compound across episodes — verify for other envs.
    NOTE(review): any work the base env's own ``reset`` does with these bodies happens
    before the scaling is applied; confirm this is acceptable.
    """

    def __init__(self, env: gym.Env, mass_scale: float):
        """
        Args:
            env: The Gymnasium BipedalWalker environment (any env whose unwrapped env
                exposes ``hull`` and/or ``legs`` Box2D bodies).
            mass_scale (float): Factor to multiply each fixture's density by.
                                For example, 2.0 doubles the mass. Must be positive.

        Raises:
            ValueError: If ``mass_scale`` is not strictly positive.
        """
        super().__init__(env)
        # Zero or negative densities do not describe a valid mass for a dynamic body.
        if mass_scale <= 0:
            raise ValueError(f"mass_scale must be positive, got {mass_scale}")
        self.mass_scale = mass_scale

    def reset(self, **kwargs):
        """Reset the wrapped env, then scale the density of the hull and leg fixtures."""
        obs, info = self.env.reset(**kwargs)

        # Collect the robot components: the hull and the leg bodies (missing or None
        # attributes are simply skipped).
        base_env = self.env.unwrapped
        hull = getattr(base_env, "hull", None)
        legs = getattr(base_env, "legs", None)
        robot_bodies = ([hull] if hull is not None else []) + list(legs or [])

        # Scale each fixture's density and then reset the body mass.
        for body in robot_bodies:
            for fixture in body.fixtures:
                fixture.density *= self.mass_scale
            # Density changes do not take effect until the body recomputes its mass data.
            body.ResetMassData()
        return obs, info

    def step(self, action):
        """Step the wrapped env unchanged (mass is only modified on reset)."""
        return self.env.step(action)


if __name__ == "__main__":
    # Sanity check for MassScalingWrapper: three identically seeded BipedalWalker envs
    # receive the same action sequence. Each step prints
    #   1) the observation distance between the unmodified and the mass-scaled env, and
    #   2) the distance between two unmodified envs (a determinism control; expected to
    #      be ~0 if seeding makes the env deterministic).
    ENV_ID = "BipedalWalker-v3"
    SEED = 42

    base_env = gym.make(ENV_ID)
    wrapped_env = MassScalingWrapper(gym.make(ENV_ID), mass_scale=2.0)
    test_env = gym.make(ENV_ID)
    try:
        base_env.reset(seed=SEED)
        wrapped_env.reset(seed=SEED)
        test_env.reset(seed=SEED)
        # Seed the action sampler too, so the whole rollout is reproducible.
        base_env.action_space.seed(SEED)

        done = False
        while not done:
            action = base_env.action_space.sample()
            base_obs, _, base_term, base_trunc, _ = base_env.step(action)
            wrapped_obs, _, wrapped_term, wrapped_trunc, _ = wrapped_env.step(action)
            test_obs, _, test_term, test_trunc, _ = test_env.step(action)

            # Keep each env's observation in its own variable so the control
            # comparison is base vs test (not test vs wrapped).
            print(np.linalg.norm(base_obs - wrapped_obs))
            print(np.linalg.norm(base_obs - test_obs))

            # Stop as soon as ANY env finishes: stepping a finished episode is undefined.
            done = any(
                (base_term, base_trunc, wrapped_term, wrapped_trunc, test_term, test_trunc)
            )
    finally:
        for opened_env in (base_env, wrapped_env, test_env):
            opened_env.close()
