from typing import Tuple

import torch

from offline_rl.rewards.evaluation.transition_sampler import ActionSampler, TransitionSampler, repeat_interleave_and_reshape_to_3d


class PointMazeEnvActionSamplingTransitionSampler(TransitionSampler):
    """Transition sampler for the point maze environment that is driven by sampled actions.

    For every input state a set of actions is drawn from ``action_sampler`` and the
    x/y position and x/y velocity of the state are advanced by one step of length ``dt``.

    Args:
        dt: Time delta between consecutive states (should match the value used by the
            environment).
        action_sampler: Callable object that produces the actions; must expose a
            ``num_actions`` attribute.
    """

    def __init__(self, dt: float, action_sampler: ActionSampler):
        self.action_sampler = action_sampler
        self.dt = dt
        # Cached at construction time purely as a shorthand for `sample`.
        self.num_actions = action_sampler.num_actions

    def sample(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample actions for each state and compute the resulting next states.

        Args:
            states: Tensor of shape ``(num_states, 6)``. Columns 0-1 are read as the x/y
                position and columns 3-4 as the x/y velocity; columns 2 and 5 (the z
                components) are not read.

        Returns:
            A tuple ``(actions, next_states, weights)`` where ``actions`` is whatever the
            action sampler returned (presumably ``(num_states, num_actions, 2)`` so that it
            broadcasts against the x/y terms -- confirm against ``ActionSampler``),
            ``next_states`` has the z position and z velocity entries set to zero, and
            ``weights`` is an all-ones tensor of shape ``(num_states, num_actions)``.
        """
        assert states.ndim == 2
        assert states.shape[1] == 6, "Only 6-dim version of env is supported."
        batch_size = states.shape[0]

        actions = self.action_sampler(batch_size, states.dtype, states.device)

        step = self.dt
        actions_per_state = self.num_actions

        # The dynamics below are restricted to the x/y plane; each state is expanded so
        # that there is one copy per sampled action.
        xy_pos = repeat_interleave_and_reshape_to_3d(states[:, :2], actions_per_state)
        xy_vel = repeat_interleave_and_reshape_to_3d(states[:, 3:5], actions_per_state)

        # Constant-acceleration update over one step, with the action as the acceleration.
        half_step_sq = 0.5 * step**2
        drifted_pos = xy_pos + step * xy_vel
        new_xy_pos = drifted_pos + half_step_sq * actions
        new_xy_vel = xy_vel + step * actions

        # The z position and z velocity slots are filled with zeros to restore the
        # 6-dim state layout.
        zeros_z = torch.zeros(
            (batch_size, actions_per_state, 1),
            dtype=xy_pos.dtype,
            device=xy_pos.device,
        )
        next_states = torch.cat((new_xy_pos, zeros_z, new_xy_vel, zeros_z), dim=-1)

        # Every sampled transition is given the same unit weight.
        uniform_weights = torch.ones(
            (batch_size, actions_per_state),
            device=states.device,
            dtype=states.dtype,
        )
        return actions, next_states, uniform_weights

    @property
    def num_transitions_per_state(self) -> int:
        """Number of transitions produced per input state (one per sampled action)."""
        return self.action_sampler.num_actions
