from typing import Dict

import gym
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
import torch


def _sample_space_as_tensor(
        space: gym.spaces.Space,
        batch_size: int,
) -> torch.Tensor:
    """Draws `batch_size` samples from `space` and packs them into a tensor.

    Args:
        space: The space to sample from.
        batch_size: The number of samples to draw.

    Returns:
        A tensor whose leading dimension is `batch_size`. Its dtype is
        `torch.long` for `Discrete` spaces and the torch default float dtype
        otherwise.
    """
    # Pack the samples into a single ndarray first: building a tensor
    # straight from a Python list of ndarrays is a known slow path in torch.
    samples = np.asarray([space.sample() for _ in range(batch_size)])

    # Discrete samples are converted directly to integers rather than via an
    # intermediate float tensor, which would round large indices (> 2**24
    # under float32).
    if isinstance(space, gym.spaces.Discrete):
        dtype = torch.long
    else:
        # Matches what the legacy `torch.Tensor(...)` constructor produces.
        dtype = torch.get_default_dtype()
    return torch.as_tensor(samples, dtype=dtype)


def get_random_state_action_next_state_batch(
        obs_space: gym.spaces.Space,
        act_space: gym.spaces.Space,
        batch_size: int,
) -> Dict:
    """Gets a random batch of torch tensors for (s, a, s').

    States, actions and next states are sampled independently of each other,
    i.e. the next states are NOT the result of applying the actions to the
    states.

    Args:
        obs_space: The space from which to sample states.
        act_space: The space from which to sample actions.
        batch_size: The number of samples in the batch.

    Returns:
        A dictionary batch of (s, a, s') tensors, keyed by `SampleBatch.OBS`,
        `SampleBatch.ACTIONS` and `SampleBatch.NEXT_OBS`. Tensors sampled from
        a `Discrete` space have dtype `torch.long`; all others use the torch
        default float dtype.
    """
    # Sampling order (obs, actions, next obs) is kept so that seeded spaces
    # yield the same draws for each entry as before.
    obs = _sample_space_as_tensor(obs_space, batch_size)
    act = _sample_space_as_tensor(act_space, batch_size)
    next_obs = _sample_space_as_tensor(obs_space, batch_size)

    return {
        SampleBatch.OBS: obs,
        SampleBatch.ACTIONS: act,
        SampleBatch.NEXT_OBS: next_obs,
    }
