import copy
from typing import Dict, Optional, Tuple, Union

import gym
import numpy as np
import pyglet
import torch

from offline_rl.rewards.reward_model import RewardModel


class Ball:
    """Represents a ball in the environment.

    Args:
        max_position: Maximum position at which to reset or bounce off a wall.
        min_speed: Minimum speed at which to travel. Stored on the instance but not
            used by any method of this class.
        max_speed: Maximum speed at which to travel. Also bounds each sampled
            velocity component at reset.
        action_mode: Indicates how the ball accelerations should be sampled.
            See `act` for available modes.
    """
    RADIUS = 1

    def __init__(
            self,
            max_position: float,
            min_speed: float = 0,
            max_speed: float = 5,
            action_mode: str = "zero",
    ):
        self.max_position = max_position
        self.min_speed = min_speed
        self.max_speed = max_speed
        self.action_mode = action_mode
        self.reset()

        # This class member is used for rendering.
        self.trans = None

    @staticmethod
    def observation_space(
            *,
            max_position: float,
            max_speed: float,
            mode: str,
            position_buffer: float,
            **kwargs,
    ) -> gym.spaces.Space:
        """Gets the observation space of this type of ball.

        Args:
            max_position: Maximum position of the ball.
            max_speed: Maximum speed of the ball.
            mode: The mode in which to return the observation space ("dict" or "flat").
            position_buffer: Buffer to add to the low/high position value.
            **kwargs: Ignored; allows passing the full ball constructor kwargs.

        Returns:
            Space associated with this class.

        Raises:
            ValueError: If `mode` is not "dict" or "flat".
        """
        del kwargs
        position_low = -position_buffer
        position_high = max_position + position_buffer
        if mode == "dict":
            return gym.spaces.Dict({
                "position":
                gym.spaces.Box(low=position_low, high=position_high, shape=(2, )),
                # Velocity components can be negative (see `reset`), so the lower bound is -max_speed.
                "velocity":
                gym.spaces.Box(low=-max_speed, high=max_speed, shape=(2, )),
            })
        elif mode == "flat":
            # Layout matches `get_obs(mode="flat")`: [x, y, vx, vy].
            return gym.spaces.Box(
                low=np.array([position_low, position_low, -max_speed, -max_speed]).astype(np.float32),
                high=np.array([position_high, position_high, max_speed, max_speed]).astype(np.float32),
            )
        else:
            raise ValueError(f"Invalid mode: {mode}")

    @property
    def position(self) -> np.ndarray:
        """The (x, y) position as a length-2 array."""
        return np.array([self.x, self.y])

    @property
    def velocity(self) -> np.ndarray:
        """The (vx, vy) velocity as a length-2 array."""
        return np.array([self.vx, self.vy])

    def get_obs(self, mode: str = "dict") -> Union[Dict, np.ndarray]:
        """Returns the ball's observation.

        Args:
            mode: "dict" for {"position", "velocity"}, or "flat" for [x, y, vx, vy].

        Raises:
            ValueError: If `mode` is not "dict" or "flat".
        """
        if mode == "dict":
            return dict(
                position=self.position,
                velocity=self.velocity,
            )
        elif mode == "flat":
            return np.concatenate((self.position, self.velocity))
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def reset(self) -> None:
        """Resets internal state, but unlike the env does not return that state."""
        self.r = self.RADIUS
        self.x = np.random.uniform(0, self.max_position)
        self.y = np.random.uniform(0, self.max_position)
        self.vx = np.random.uniform(-self.max_speed, self.max_speed)
        self.vy = np.random.uniform(-self.max_speed, self.max_speed)
        self._cap_speed()

    def act(self) -> np.ndarray:
        """Returns an action (acceleration) for this ball.

        Modes:
            "zero": always returns [0, 0].
            "random": samples each component uniformly from [-1, 1].

        Raises:
            ValueError: If `action_mode` is not one of the modes above.
        """
        if self.action_mode == "zero":
            return np.array([0, 0])
        elif self.action_mode == "random":
            return np.random.uniform(-1, 1, size=2)
        else:
            raise ValueError(f"Invalid action mode: {self.action_mode}")

    def apply_accel(self, accel: np.ndarray, dt: float) -> None:
        """Applies an acceleration to this ball.

        Args:
            accel: Array of accel values.
            dt: Time delta / time for which to apply the acceleration.
        """
        self.x, self.y = self.position + self.velocity * dt + 0.5 * accel * dt**2
        self.vx, self.vy = self.velocity + accel * dt
        self._cap_speed()

    def _cap_speed(self) -> None:
        """Rescales the velocity so its magnitude is at most `max_speed`, keeping its direction."""
        speed = np.linalg.norm(self.velocity)
        if speed > self.max_speed:
            # Multiply by max/speed instead of dividing by speed/max: `speed` is strictly positive
            # here, so this never divides by zero (even when max_speed == 0).
            factor = self.max_speed / speed
            self.vx *= factor
            self.vy *= factor

    def bounce(self, other: "Ball") -> None:
        """Bounces this ball off the other ball.

        Currently a no-op: neither ball's state is modified.
        """


class BouncingBallsEnv(gym.Env):
    """A goal reaching environment with bouncing balls.

    The ego ball is stored at index 0 of `self.balls`; the remaining entries are the other
    balls, whose accelerations come from their own `act` method.

    Args:
        min_num_balls: Minimum number of other agent balls in the scene.
        max_num_balls: Maximum number of other agent balls in the scene (inclusive).
        side_length: Length of a side of the box in which the balls bounce.
        dt: Time delta between time steps.
        terminate_on_wall_contact: If True, ends the episode when the ego hits the wall.
        obs_space_mode: The mode to use for the observation space ("dict" or "flat").
        is_continuing: If True, make the env infinite horizon and implicitly reset upon normal termination.
        max_timesteps: The maximum number of timesteps before ending the episode.
            This applies even in the continuing case.
        ball_cls: The class of the ball to use in the env. It must expose the same interface as
            `Ball`, including the `observation_space` static method.
        ball_kwargs: Key word arguments to use in creating the ball class.
    """
    goal_radius = 1
    max_accel_magnitude = 5
    reached_goal_reward = 1
    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50}

    def __init__(
            self,
            min_num_balls: int = 4,
            max_num_balls: int = 4,
            side_length: float = 50,
            dt: float = 0.2,
            terminate_on_wall_contact: bool = False,
            obs_space_mode: str = "flat",
            is_continuing: bool = True,
            max_timesteps: int = 400,
            ball_cls: type = Ball,
            # pylint: disable=dangerous-default-value
            ball_kwargs: Dict = dict(max_speed=5, action_mode="random"),
    ):
        assert min_num_balls >= 0
        assert max_num_balls >= min_num_balls
        self.min_num_balls = min_num_balls
        self.max_num_balls = max_num_balls
        self.side_length = side_length
        self.dt = dt
        self.terminate_on_wall_contact = terminate_on_wall_contact
        self.obs_space_mode = obs_space_mode
        self.is_continuing = is_continuing
        self.max_timesteps = max_timesteps
        self.ball_cls = ball_cls
        # Copy so this instance never aliases the shared default dict (or the caller's dict).
        self.ball_kwargs = dict(ball_kwargs)

        self.action_space = gym.spaces.Box(low=-self.max_accel_magnitude, high=self.max_accel_magnitude, shape=(2, ))
        self.observation_space = self._get_obs_space(position_buffer=5)

        self.t = 0
        self.viewer = None
        # Only used if environment is in continuing mode.
        self.reset_next_step = False

    def _make_ball(self):
        """Creates a new ball of the configured class inside the box."""
        return self.ball_cls(max_position=self.side_length, **self.ball_kwargs)

    def _get_ball_obs_space(self, position_buffer) -> gym.spaces.Space:
        """Returns the observation space of a single ball for the current mode."""
        return self.ball_cls.observation_space(
            max_position=self.side_length,
            mode=self.obs_space_mode,
            position_buffer=position_buffer,
            **self.ball_kwargs,
        )

    def _get_obs_space(self, position_buffer):
        """Builds the environment observation space for `obs_space_mode`.

        Args:
            position_buffer: Margin added to both sides of the ball position bounds.

        Raises:
            ValueError: If `obs_space_mode` is not "dict" or "flat".
        """
        if self.obs_space_mode == "dict":
            return gym.spaces.Dict({
                "ego":
                self._get_ball_obs_space(position_buffer),
                "balls":
                gym.spaces.Tuple([self._get_ball_obs_space(position_buffer) for _ in range(self.max_num_balls)]),
                "goal":
                gym.spaces.Box(low=0, high=self.side_length, shape=(2, )),
            })
        elif self.obs_space_mode == "flat":
            ball_obs_space = self._get_ball_obs_space(position_buffer)
            # The state is [ego, balls, goal].
            low = np.concatenate((
                np.tile(ball_obs_space.low, 1 + self.max_num_balls),
                [0, 0],
            ))
            high = np.concatenate((
                np.tile(ball_obs_space.high, 1 + self.max_num_balls),
                [self.side_length, self.side_length],
            ))
            return gym.spaces.Box(low=low.astype(np.float32), high=high.astype(np.float32))
        else:
            raise ValueError(f"Invalid obs space mode: {self.obs_space_mode}")

    def _get_obs_dict(self):
        """Returns the observation as {"ego", "balls", "goal"}."""
        return {
            "ego": self.balls[0].get_obs(),
            "balls": [ball.get_obs() for ball in self.balls[1:]],
            "goal": self.goal_position,
        }

    def _get_obs_flat(self):
        """Returns the observation as a flat [ego, balls..., goal] array.

        Slots for balls that are absent (fewer than `max_num_balls`) are left as zeros.
        """
        state = np.zeros_like(self.observation_space.low)
        for i, ball in enumerate(self.balls):
            ball_state = ball.get_obs(mode="flat")
            start_index = i * len(ball_state)
            end_index = start_index + len(ball_state)
            state[start_index:end_index] = ball_state
        state[-2:] = self.goal_position
        return state

    def _get_obs(self) -> Union[Dict, np.ndarray]:
        """Gets the observation given internal state.

        Raises:
            ValueError: If `obs_space_mode` is not "dict" or "flat".
        """
        if self.obs_space_mode == "dict":
            return self._get_obs_dict()
        elif self.obs_space_mode == "flat":
            return self._get_obs_flat()
        else:
            raise ValueError(f"Invalid obs space mode: {self.obs_space_mode}")

    def _reset_goal(self) -> None:
        """Samples a new goal position uniformly inside the box."""
        self.goal_position = np.array([np.random.uniform(0, self.side_length), np.random.uniform(0, self.side_length)])

    def _reset_ego(self) -> None:
        """Replaces the ego (index 0) with a freshly sampled ball."""
        self.balls[0] = self._make_ball()

    def _close_viewer(self) -> None:
        """Closes and discards the render window, if one is open."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

    def _continuing_reset(self) -> None:
        """Resets only the goal and ego (used in continuing mode)."""
        # Do _not_ reset self.t here o/w it will not be enforced.
        # The viewer is rebuilt on the next render since the ego ball object is replaced.
        self._close_viewer()
        self._reset_goal()
        self._reset_ego()

    def reset(self) -> Union[Dict, np.ndarray]:
        """Resets the environment, sampling new state values for the ego, goal, and balls."""
        # Do reset t here b/c if this is getting called it's either b/c we're not running in
        # continuing mode or b/c we are but hit the time limit anyway.
        self.t = 0
        self._close_viewer()
        self._reset_goal()
        # Plus 1 since `self.max_num_balls` is inclusive.
        num_balls = np.random.randint(low=self.min_num_balls, high=self.max_num_balls + 1)
        # One extra ball for the ego, which is stored at index 0.
        self.balls = [self._make_ball() for _ in range(num_balls + 1)]
        return self._get_obs()

    def close(self) -> None:
        """Releases rendering resources."""
        self._close_viewer()

    def step(self, action: np.ndarray) -> Tuple:
        """Step the environment forward.

        Args:
            action: An element from the action space, to be applied to the ego.

        Returns:
            The (next state, reward, terminal, info) tuple. In continuing mode, `info` contains
            `originally_terminal=True` on steps that would otherwise have been terminal.
        """
        action = np.clip(action, -self.max_accel_magnitude, self.max_accel_magnitude)
        info = dict()
        if self.reset_next_step:
            # If this is true, then the previous step resulted in a terminal state.
            # Last timestep we returned that terminal observation, and this step we now
            # reset the environment. When we reset here we have to not step the dynamics
            # later, while still appropriately identifying an initial terminal state.
            # Also, only reset ego and goal so ados don't move randomly.
            self._continuing_reset()
            update_state = False
            # Go back to normal execution after this step.
            self.reset_next_step = False
        else:
            # Current state is not the result of a reset, so step the dynamics per usual.
            update_state = True

        ego_collided = self._step_dynamics(action, update_state=update_state)
        ego_reached_goal = self._check_if_ego_reached_goal()

        terminal = ego_collided or ego_reached_goal
        if self.is_continuing and terminal:
            # Don't reset here so that the agent observes the terminal state.
            terminal = False
            self.reset_next_step = True
            info["originally_terminal"] = True

        self.t += 1
        if self.t >= self.max_timesteps:
            terminal = True
            # In case these both occur at the same timestep, avoid resetting twice.
            self.reset_next_step = False

        reward = self.reached_goal_reward if ego_reached_goal else 0
        return self._get_obs(), reward, terminal, info

    def _check_if_ego_reached_goal(self) -> bool:
        """Returns True if ego is overlapping the goal."""
        ego_ball = self.balls[0]
        dist = np.linalg.norm(self.goal_position - ego_ball.position)
        inside_dist = ego_ball.r + self.goal_radius
        return dist < inside_dist

    def _resolve_wall_contact(self, ball) -> bool:
        """Keeps `ball` inside the box, reflecting its velocity off any wall it touches.

        The x and y axes are handled independently, so a ball in a corner is corrected on both
        axes. The reflected velocity component always points away from the wall that was hit, so
        a ball already moving away from the wall is not sent back into it.

        Returns:
            True if the ball touched at least one wall.
        """
        touched = False
        if ball.x + ball.r >= self.side_length:
            # Right wall.
            ball.x = self.side_length - ball.r
            ball.vx = -abs(ball.vx)
            touched = True
        elif ball.x - ball.r <= 0:
            # Left wall.
            ball.x = ball.r
            ball.vx = abs(ball.vx)
            touched = True
        if ball.y + ball.r >= self.side_length:
            # Top wall.
            ball.y = self.side_length - ball.r
            ball.vy = -abs(ball.vy)
            touched = True
        elif ball.y - ball.r <= 0:
            # Bottom wall.
            ball.y = ball.r
            ball.vy = abs(ball.vy)
            touched = True
        return touched

    def _step_dynamics(self, action: np.ndarray, update_state: bool = True) -> bool:
        """Step the agent states forward in response to the action.

        Args:
            action: The action for the ego to take.
            update_state: If True updates the internal state; otherwise the step is simulated on
                a deep copy of the balls and only the collision result is returned.

        Returns:
            True if ego collided (with a ball, or with a wall when `terminate_on_wall_contact`).
        """
        balls = self.balls if update_state else copy.deepcopy(self.balls)

        # Skip ego ball in collecting the actions.
        accels = [action] + [ball.act() for ball in balls[1:]]

        # Apply the accelerations.
        for ball, accel in zip(balls, accels):
            ball.apply_accel(accel, self.dt)

        # Wall collisions: every ball is kept inside the box; only the ego's contact matters
        # for the return value.
        ego_touched_wall = False
        for i, ball in enumerate(balls):
            touched_wall = self._resolve_wall_contact(ball)
            if i == 0:
                ego_touched_wall = touched_wall
        ego_collided_with_wall = ego_touched_wall and self.terminate_on_wall_contact

        # Ball collisions.
        ego_collided_with_ball = False
        for i, b1 in enumerate(balls):
            for b2 in balls[i + 1:]:
                dx = b2.x - b1.x
                dy = b2.y - b1.y
                combined_radius_sq = (b1.r + b2.r)**2
                dist_sq = dx**2 + dy**2
                if dist_sq < combined_radius_sq:
                    # Collision occurred.
                    b1.bounce(b2)
                    if i == 0:
                        ego_collided_with_ball = True
        return ego_collided_with_wall or ego_collided_with_ball

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Renders the environment as an rgb image or directly to gui.

        Args:
            mode: The rendering mode. See metadata class constant for options.

        Returns:
            Image of the environment (if mode == "rgb_array") or None.
        """
        screen_width = 1000
        screen_height = 1000
        # Pixels per world unit.
        scale = 20

        if self.viewer is None:
            try:
                from gym.envs.classic_control import rendering
            except pyglet.canvas.xlib.NoSuchDisplayException:
                print("Rendering with pyglet requires a display.\n"
                      "If running in a jupyter notebook, see the following link:\n"
                      "https://stackoverflow.com/questions/40195740/how-to-run-openai-gym-render-over-a-server\n")
                raise

            self.viewer = rendering.Viewer(screen_width, screen_height)
            self.circles = []
            for i, ball in enumerate(self.balls):
                circle = rendering.make_circle(radius=ball.r * scale)
                # Ego is drawn blue, all other balls black.
                color = (0, 0, 1) if i == 0 else (0, 0, 0)
                circle.set_color(*color)
                circle_trans = rendering.Transform()
                circle.add_attr(circle_trans)
                self.viewer.add_geom(circle)
                self.circles.append(circle)
                ball.trans = circle_trans

            # The goal only moves on (continuing) reset, which also closes the viewer, so its
            # translation is set once here.
            goal = rendering.make_circle(radius=self.goal_radius * scale)
            goal.set_color(0, 1, 0)
            goal_trans = rendering.Transform()
            goal.add_attr(goal_trans)
            goal_trans.set_translation(*(self.goal_position * scale))
            self.viewer.add_geom(goal)

        for ball in self.balls:
            ball.trans.set_translation(ball.x * scale, ball.y * scale)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    @staticmethod
    def get_ego_states_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the ego states from a batch of states."""
        assert len(states.shape) == 2
        return states[:, :4]

    @staticmethod
    def get_ego_positions_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the ego positions from a batch of states."""
        assert len(states.shape) == 2
        return states[:, :2]

    @staticmethod
    def get_ego_velocities_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the ego velocities from a batch of states."""
        assert len(states.shape) == 2
        return states[:, 2:4]

    @staticmethod
    def get_other_states_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the states of the other balls from a batch of states."""
        assert len(states.shape) == 2
        return states[:, 4:-2]

    @staticmethod
    def get_other_positions_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the positions of the other balls from a batch of states.

        Returns:
            The positions as a shape (batch size, num other obstacles, 2D positions) tensor.
        """
        assert len(states.shape) == 2
        batch_size = states.shape[0]
        return states[:, 4:-2].reshape(batch_size, -1, 4)[:, :, :2]

    @staticmethod
    def get_other_velocities_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the velocities of the other balls from a batch of states.

        Returns:
            The velocities as a shape (batch size, num other obstacles, 2D velocities) tensor.
        """
        assert len(states.shape) == 2
        batch_size = states.shape[0]
        return states[:, 4:-2].reshape(batch_size, -1, 4)[:, :, 2:4]

    @staticmethod
    def get_goal_positions_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Extracts the goal positions from flat versions of the states."""
        assert len(states.shape) == 2
        return states[:, -2:]

    @staticmethod
    def get_distances_from_other_balls_from_flat_states(states: torch.FloatTensor) -> torch.FloatTensor:
        """Computes the distance of the ego ball from the other balls.

        Returns:
            Returns a tensor of distances of shape (batch size, num other balls).
        """
        ego_positions = BouncingBallsEnv.get_ego_positions_from_flat_states(states)
        ado_positions = BouncingBallsEnv.get_other_positions_from_flat_states(states)
        position_differences = ado_positions - ego_positions[:, None, :]
        return torch.norm(position_differences, dim=-1)

    @staticmethod
    def get_collisions(states: torch.FloatTensor) -> torch.Tensor:
        """Returns a boolean tensor indicating collisions between the ego and other balls.

        Returns:
            A bool tensor of shape (batch size, num other balls).
        """
        distances = BouncingBallsEnv.get_distances_from_other_balls_from_flat_states(states)
        # `_get_obs_flat` zero-fills the slots of absent balls. Those would otherwise look like balls
        # at the origin and "collide" with an ego near that corner. Real balls are clamped to at least
        # one radius from each wall, so an exactly-zero position is treated as padding.
        ado_positions = BouncingBallsEnv.get_other_positions_from_flat_states(states)
        is_padding = torch.all(ado_positions == 0, dim=-1)
        return torch.logical_and(distances < 2 * Ball.RADIUS, torch.logical_not(is_padding))

    @staticmethod
    def get_distances_from_goal(states: torch.FloatTensor) -> torch.Tensor:
        """Computes the distances from the goal across the batch, as a (batch size, 1) tensor."""
        ego_positions = BouncingBallsEnv.get_ego_positions_from_flat_states(states)
        goal_positions = BouncingBallsEnv.get_goal_positions_from_flat_states(states)
        return torch.norm(ego_positions - goal_positions, dim=-1).reshape(-1, 1)

    @staticmethod
    def get_reached_goal_indicators(states: torch.FloatTensor) -> torch.Tensor:
        """Returns a boolean tensor indicating for each element in the batch whether the goal has been reached."""
        distances = BouncingBallsEnv.get_distances_from_goal(states)
        return distances < Ball.RADIUS + BouncingBallsEnv.goal_radius


class BouncingBallsEnvRewardModel(RewardModel):
    """A reward model for the bouncing balls environment.

    Args:
        obs_space: The observation space used in the environment.
        act_space: The action space used in the environment.
        reaching_goal_reward: The reward received upon reaching the goal location.
        obstacle_collision_reward: The reward received upon collision with an obstacle.
        action_magnitude_reward: The value to multiply the action magnitude to get the reward component.
        distance_from_goal_reward: The scaling factor on the (negative) distance to the goal reward component.
        shaping_toward_goal_factor: The value to scale the toward goal shaping.
        shaping_discount: The discount factor used in potential shaping.
    """
    def __init__(
            self,
            obs_space: gym.spaces.Box,
            act_space: gym.spaces.Box,
            reaching_goal_reward: float,
            obstacle_collision_reward: float,
            action_magnitude_reward: float,
            distance_from_goal_reward: float,
            shaping_toward_goal_factor: float,
            shaping_discount: float,
    ):
        self.obs_space = obs_space
        self.act_space = act_space
        self.reaching_goal_reward = reaching_goal_reward
        self.obstacle_collision_reward = obstacle_collision_reward
        self.action_magnitude_reward = action_magnitude_reward
        self.distance_from_goal_reward = distance_from_goal_reward
        self.shaping_toward_goal_factor = shaping_toward_goal_factor
        self.shaping_discount = shaping_discount

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self.obs_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self.act_space

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Computes the reward for the environment accounting various factors (beyond that of the ground truth reward).

        Implementation note: the goal-reaching and collision terms are computed on `next_states`,
        the distance penalty on `states`, and the shaping term on both. `terminals` is unused.

        See base class for documentation on args and return value. The result has shape (batch size, 1).
        """
        assert next_states is not None, "BouncingBallsEnvRewardModel requires next_states."

        # Initialize the reward as zeros (allocated directly on the input device) and then add to it.
        rewards = torch.zeros((len(states), 1), device=states.device)

        # Goal reaching.
        reached_goal = BouncingBallsEnv.get_reached_goal_indicators(next_states)
        rewards += reached_goal * self.reaching_goal_reward

        # Collisions.
        in_collision_per_ado = BouncingBallsEnv.get_collisions(next_states)
        in_collision_per_state = torch.any(in_collision_per_ado, dim=1, keepdim=True)
        rewards += in_collision_per_state * self.obstacle_collision_reward

        # Action magnitude.
        action_magnitudes = torch.norm(actions, dim=1).reshape(-1, 1)
        rewards += action_magnitudes * self.action_magnitude_reward

        state_goal_distances = BouncingBallsEnv.get_distances_from_goal(states)
        next_state_goal_distances = BouncingBallsEnv.get_distances_from_goal(next_states)

        # Penalize distance from goal directly (reward is the negative of the distance).
        rewards += -state_goal_distances * self.distance_from_goal_reward

        # Potential shaping towards goal: gamma * phi(s') - phi(s) with phi = -sqrt(distance).
        # The potential is based on the sqrt distance so that it's sufficiently different from
        # the distance-based reward.
        state_potentials = state_goal_distances**(1 / 2)
        next_state_potentials = next_state_goal_distances**(1 / 2)
        shaping_toward_goal = -(next_state_potentials * self.shaping_discount - state_potentials)
        rewards += shaping_toward_goal * self.shaping_toward_goal_factor

        return rewards

    @classmethod
    def make_ground_truth_reward_model(cls, obs_space, act_space):
        """Builds a model whose only nonzero term is the environment's goal-reaching reward."""
        return cls(
            obs_space,
            act_space,
            reaching_goal_reward=BouncingBallsEnv.reached_goal_reward,
            obstacle_collision_reward=0,
            action_magnitude_reward=0,
            distance_from_goal_reward=0,
            shaping_toward_goal_factor=0,
            shaping_discount=1,
        )


class BouncingBallsEnvFeasibilityRewardWrapper:
    """A reward wrapper that returns the base reward for feasible transitions and a different reward otherwise.

    Args:
        base_reward: The reward model to use for feasible transitions.
        alternative_reward: The reward model to use for infeasible transitions.
        dt: The time delta used by the environment to generate transitions.
        max_position_error: The maximum error in the position.
        check_ego: If true check ego dynamics as part of feasibility check.
    """
    def __init__(
            self,
            base_reward: RewardModel,
            alternative_reward: RewardModel,
            dt: float,
            max_position_error: float,
            check_ego: bool = False,
    ):
        self.base_reward = base_reward
        self.alternative_reward = alternative_reward
        self.dt = dt
        self.max_position_error = max_position_error
        self.check_ego = check_ego

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self.base_reward.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self.base_reward.action_space

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Checks for transition feasibility and returns the base reward if feasible and alternative otherwise.

        See base class for documentation.
        """
        assert next_states is not None
        feasibility_indicators = self.transitions_are_feasible(states, actions, next_states)
        base = self.base_reward.reward(states, actions, next_states, terminals)
        alternative = self.alternative_reward.reward(states, actions, next_states, terminals)
        # Select instead of blending with `mask * base + ~mask * alternative`: multiplying an
        # unselected inf/nan reward by zero would yield nan. Promote explicitly so the result dtype
        # matches the arithmetic blend regardless of torch version.
        dtype = torch.promote_types(base.dtype, alternative.dtype)
        return torch.where(feasibility_indicators, base.to(dtype), alternative.to(dtype))

    # pylint: disable=unused-argument
    def transitions_are_feasible(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: torch.Tensor,
    ) -> torch.Tensor:
        """Computes whether the provided transitions are possible / reasonable in the bouncing balls env.

        A transition is feasible when every checked ball's next position lies within
        `max_position_error` of a constant-velocity prediction `position + dt * velocity`.
        Accelerations and wall bounces are not modeled; the tolerance must absorb them.

        Args:
            states: The states transitioned from.
            actions: The actions taken in those states (unused).
            next_states: The next states transitioned to.

        Returns:
            A bool tensor of shape (batch size, 1) where `True` indicates the corresponding transition is feasible.
        """
        feasibility = torch.ones((len(states), 1), dtype=torch.bool, device=states.device)

        # Check for ego feasibility.
        if self.check_ego:
            ego_positions = BouncingBallsEnv.get_ego_positions_from_flat_states(states)
            ego_velocities = BouncingBallsEnv.get_ego_velocities_from_flat_states(states)
            expected_next_ego_positions = ego_positions + self.dt * ego_velocities
            next_ego_positions = BouncingBallsEnv.get_ego_positions_from_flat_states(next_states)
            ego_position_errors = (expected_next_ego_positions - next_ego_positions).norm(dim=1, keepdim=True)
            feasibility = torch.logical_and(feasibility, ego_position_errors < self.max_position_error)

        # Check for ado feasibility: every other ball must be within tolerance.
        ado_positions = BouncingBallsEnv.get_other_positions_from_flat_states(states)
        ado_velocities = BouncingBallsEnv.get_other_velocities_from_flat_states(states)
        expected_next_ado_positions = ado_positions + self.dt * ado_velocities
        next_ado_positions = BouncingBallsEnv.get_other_positions_from_flat_states(next_states)
        ado_position_errors = (expected_next_ado_positions - next_ado_positions).norm(dim=2)
        ado_feasibility = torch.all(ado_position_errors < self.max_position_error, dim=1, keepdim=True)
        feasibility = torch.logical_and(feasibility, ado_feasibility)

        return feasibility
