from typing import Tuple

import torch

from offline_rl.rewards.evaluation.transition_sampler import TransitionSampler


class LineEnvUniformPolicyTransitionSampler(TransitionSampler):
    """Samples transitions in the line env under a uniformly random policy.

    For every input state both available actions are enumerated: action ``0``
    decrements the state and action ``1`` increments it. Resulting states are
    clipped to the closed interval ``[0, 2 * side_length]``.

    Args:
        side_length: The maximum length of a side in the line environment.
    """
    def __init__(self, side_length: int):
        self.side_length = side_length

    def sample(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """See base class documentation.

        Args:
            states: A 2-D tensor with one row per state.

        Returns:
            A tuple ``(actions, next_states, weights)`` where

            * ``actions`` has shape ``(num_states, 2, 1)`` and holds ``0`` and
              ``1`` along the second axis,
            * ``next_states`` has shape ``(num_states, 2, state_dim)`` and holds
              the clipped successor of each state under action ``0``
              (``state - 1``) and action ``1`` (``state + 1``),
            * ``weights`` has shape ``(num_states, 2)`` and is all ones, i.e.
              both transitions are weighted equally (unnormalized).
        """
        assert states.ndim == 2
        num_states = len(states)

        actions = torch.zeros((num_states, 2, 1), dtype=states.dtype, device=states.device)
        actions[:, 1, :] = 1

        # Successors must be derived from the input states. Detach so the
        # returned transitions do not carry the autograd history of `states`.
        current = states.detach()
        moved_down = torch.clamp(current - 1, min=0)
        moved_up = torch.clamp(current + 1, max=self.side_length * 2)
        # Stack along a new "transition" axis so that index 0 lines up with
        # action 0 and index 1 with action 1.
        next_states = torch.stack((moved_down, moved_up), dim=1)

        weights = torch.ones((num_states, 2), device=states.device)

        return actions, next_states, weights

    @property
    def num_transitions_per_state(self) -> int:
        """Number of transitions returned per state by `sample` (one per action)."""
        return 2
