import gym
import torch


def get_noncolliding_states(
        batch_size: int,
        num_balls: int,
        radius: float = 1,
        radius_multiple_separation: float = 4,
) -> torch.Tensor:
    """Build a batch of identical states in which the objects are spread apart.

    Each row has width ``4 * (num_balls + 1) + 2``: four ego entries, then four
    entries per other ball, then two goal entries. The ego entries stay zero.
    Every entry of other ball ``k`` (1-based) is set to
    ``(k + 1) * radius * radius_multiple_separation``. Both goal entries are set
    to ``(num_balls + 2) * radius * radius_multiple_separation``. Objects
    therefore sit at increasing multiples of ``radius * radius_multiple_separation``.

    NOTE(review): the four per-ball entries are presumably position and
    velocity; here every one of them receives the same value.

    Args:
        batch_size: Number of (identical) rows to produce.
        num_balls: Number of non-ego balls.
        radius: Ball radius used to scale the spacing.
        radius_multiple_separation: Spacing step, expressed in radii.

    Returns:
        Tensor of shape ``(batch_size, 4 * (num_balls + 1) + 2)`` with torch's
        default dtype.
    """
    state_dim = 4 * (num_balls + 1) + 2
    # torch.zeros already leaves the ego block (first four entries) at zero.
    states = torch.zeros((batch_size, state_dim))
    for slot in range(1, num_balls + 1):
        lo = 4 * slot
        # Keep the original ((k + 1) * radius) * separation evaluation order.
        states[:, lo:lo + 4] = (slot + 1) * radius * radius_multiple_separation
    states[:, -2:] = (num_balls + 2) * radius * radius_multiple_separation
    return states


def get_colliding_states(batch_size: int, num_balls: int, radius: float = 1) -> torch.Tensor:
    """Build a batch of states where the first non-ego ball overlaps the ego.

    Starts from ``get_noncolliding_states`` and zeroes the first other ball's
    four entries. That places it exactly on the ego, whose entries are zero.

    Args:
        batch_size: Number of (identical) rows to produce.
        num_balls: Number of non-ego balls; must be at least 1.
        radius: Ball radius forwarded to ``get_noncolliding_states``.

    Returns:
        Tensor of shape ``(batch_size, 4 * (num_balls + 1) + 2)``.

    Raises:
        ValueError: If ``num_balls`` is not positive. Without another ball, the
            slice ``4:8`` would land on the goal entries and silently produce
            a goal-reached state instead of a colliding one.
    """
    # An explicit check, rather than ``assert``, so it is not stripped under -O.
    if num_balls <= 0:
        raise ValueError(
            f"get_colliding_states requires num_balls >= 1, got {num_balls}"
        )
    states = get_noncolliding_states(batch_size, num_balls, radius=radius)
    # Since the ego is at 0, put the first other ball there as well.
    states[:, 4:8] = 0
    return states


def get_goal_reached_states(batch_size: int, num_balls: int, radius: float = 1) -> torch.Tensor:
    """Build a batch of states whose goal sits at the origin, on top of the ego.

    Starts from ``get_noncolliding_states``, where the ego entries are zero, and
    zeroes both goal entries.

    Args:
        batch_size: Number of (identical) rows to produce.
        num_balls: Number of non-ego balls.
        radius: Ball radius forwarded to ``get_noncolliding_states``.

    Returns:
        Tensor of shape ``(batch_size, 4 * (num_balls + 1) + 2)``.
    """
    base = get_noncolliding_states(batch_size, num_balls, radius=radius)
    # The slice is a view, so zeroing it in place updates ``base``.
    base[:, -2:].zero_()
    return base


def get_dist_from_goal_states(batch_size: int, num_balls: int, dist: float, radius: float = 1) -> torch.Tensor:
    """Build a batch of states with the goal ``dist`` away from the ego along x.

    Starts from ``get_noncolliding_states`` and zeroes the ego entries. It then
    sets the goal to ``(dist, 0)``; the second-to-last entry is the goal's x.

    Args:
        batch_size: Number of (identical) rows to produce.
        num_balls: Number of non-ego balls.
        dist: Goal offset from the ego along the x axis.
        radius: Ball radius forwarded to ``get_noncolliding_states``.

    Returns:
        Tensor of shape ``(batch_size, 4 * (num_balls + 1) + 2)``.
    """
    out = get_noncolliding_states(batch_size, num_balls, radius=radius)
    # Ego goes to the origin.
    out[:, :4].zero_()
    # ``goal`` is a view; writes through it land in ``out``.
    goal = out[:, -2:]
    goal[:, 0] = dist  # x offset from the ego
    goal[:, 1] = 0     # same y as the ego
    return out


def get_noop_actions(batch_size: int) -> torch.Tensor:
    """Return a ``(batch_size, 2)`` tensor of all-zero actions."""
    return torch.zeros(batch_size, 2)


def get_one_actions(batch_size: int) -> torch.Tensor:
    """Return a ``(batch_size, 2)`` tensor whose every action component is 1."""
    return torch.ones(batch_size, 2)


def dummy_bouncing_balls_obs_space(low: float = -1, high: float = 1, num_balls: int = 10) -> gym.spaces.Box:
    """Return a placeholder observation space; the reward models don't use it.

    Only its width is meaningful: four ego entries, four per other ball and two
    goal entries. This matches the rows built by ``get_noncolliding_states``.
    """
    obs_dim = 4 * (num_balls + 1) + 2
    return gym.spaces.Box(low=low, high=high, shape=(obs_dim,))


def dummy_bouncing_balls_act_space(low: float = -1, high: float = 1) -> gym.spaces.Box:
    """Return a placeholder 2-D action space, dummy in the same way as the obs space.

    Its width matches the ``(batch_size, 2)`` tensors from the action helpers.
    """
    action_dim = 2
    return gym.spaces.Box(low=low, high=high, shape=(action_dim,))
