import collections

import gym

from ray.rllib.policy.sample_batch import SampleBatch


class RandomPolicy:
    """Test policy that ignores its inputs and returns ``action_space.sample()``."""

    def __init__(self, action_space):
        """Remember the space that actions are drawn from.

        Args:
            action_space: Any object exposing a ``sample()`` method
                (presumably a gym space in practice).
        """
        self.action_space = action_space

    def step(self, *args, **kwargs):  # pylint: disable=unused-argument
        """Return one sample from ``self.action_space``.

        All positional and keyword arguments (e.g. the current observation)
        are accepted only for interface compatibility and are ignored.
        """
        sampled_action = self.action_space.sample()
        return sampled_action


def sample_rollout(env: gym.Env, policy, num_steps: int) -> SampleBatch:
    """Run ``policy`` in ``env`` and collect the transitions into a SampleBatch.

    Intended only for tests; real algorithms should use RLlib's own rollout
    machinery instead. A simple RLlib helper for this could not be found, which
    is why this exists.

    The environment is reset once, then stepped at most ``num_steps`` times.
    Collection stops early after the first step whose ``done`` flag is truthy.
    That terminal transition is still recorded.

    Args:
        env: Environment to roll out in. ``reset()`` is expected to return an
            observation, and ``step()`` is expected to return a 4-tuple
            ``(next_obs, reward, done, info)``.
        policy: Object whose ``step(obs)`` method returns an action.
        num_steps: Maximum number of transitions to collect.

    Returns:
        A SampleBatch with the OBS, ACTIONS, NEXT_OBS, REWARDS, DONES and INFOS
        columns, one entry per collected step. If no step is taken, the batch
        is built with no columns at all.
    """
    # Columns are created lazily on first append, so a zero-step rollout
    # produces a SampleBatch with no keys rather than empty columns.
    columns = collections.defaultdict(list)

    obs = env.reset()
    for _step_index in range(num_steps):
        action = policy.step(obs)
        next_obs, reward, done, info = env.step(action)

        transition = (
            (SampleBatch.OBS, obs),
            (SampleBatch.ACTIONS, action),
            (SampleBatch.NEXT_OBS, next_obs),
            (SampleBatch.REWARDS, reward),
            (SampleBatch.DONES, done),
            (SampleBatch.INFOS, info),
        )
        for column_name, value in transition:
            columns[column_name].append(value)

        if done:
            break
        obs = next_obs

    return SampleBatch(**columns)
