from typing import Tuple

import torch

from offline_rl.envs.bouncing_balls_env import BouncingBallsEnv
from offline_rl.rewards.evaluation.transition_sampler import (
    ActionSampler,
    TransitionSampler,
    repeat_interleave_and_reshape_to_3d,
)


class BouncingBallsEnvConstantVelocityTransitionSampler(TransitionSampler):
    """Deterministic transition sampler for the bouncing balls environment.

    The ego ball and the other balls are each advanced by one step under the assumption that
    their velocities do not change; goal positions are carried over as-is. Exactly one
    transition is produced per input state.

    Args:
        dt: The time step between consecutive states (should match the one used by the environment).
    """
    def __init__(self, dt: float):
        self.dt = dt

    def sample(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """See base class documentation.

        Args:
            states: A 2D tensor holding one flat state per row.

        Returns:
            A tuple ``(actions, next_states, weights)`` where ``actions`` is all zeros with shape
            ``(len(states), 1, 2)``, ``next_states`` has a singleton second dimension, and
            ``weights`` is all ones with shape ``(len(states), 1)``.
        """
        assert states.ndim == 2
        batch_size = len(states)
        step = self.dt

        # Ego ball: position moves by step * velocity, velocity is kept.
        ego_pos = BouncingBallsEnv.get_ego_positions_from_flat_states(states)
        ego_vel = BouncingBallsEnv.get_ego_velocities_from_flat_states(states)
        ego_next = torch.cat((ego_pos + step * ego_vel, ego_vel), dim=1)

        # Other balls: same update, then flattened back to a single row per state.
        other_pos = BouncingBallsEnv.get_other_positions_from_flat_states(states)
        other_vel = BouncingBallsEnv.get_other_velocities_from_flat_states(states)
        other_next = torch.cat((other_pos + step * other_vel, other_vel), dim=-1)
        other_next = other_next.reshape(batch_size, -1)

        # Goal positions are copied through without modification.
        goals = BouncingBallsEnv.get_goal_positions_from_flat_states(states)

        flat_next_states = torch.cat((ego_next, other_next, goals), dim=1)
        zero_actions = torch.zeros((batch_size, 2), dtype=states.dtype, device=states.device)
        weights = torch.ones((batch_size, 1), dtype=states.dtype, device=states.device)

        # There is only one transition per state, but the expected return layout still has a
        # per-state transition dimension, so insert a singleton axis at position 1.
        return zero_actions.unsqueeze(1), flat_next_states.unsqueeze(1), weights

    @property
    def num_transitions_per_state(self) -> int:
        """The number of transitions produced per state, which is always one."""
        return 1


class BouncingBallsEnvActionSamplingTransitionSampler(TransitionSampler):
    """Transition sampler for the bouncing balls environment driven by sampled actions.

    The ego ball is advanced using the sampled actions as accelerations over one step, while
    the other balls are advanced with unchanged velocities. Goal positions are carried over.

    Args:
        dt: The time step between consecutive states (should match the one used by the environment).
        action_sampler: A callable object that samples actions and exposes ``num_actions``.
    """
    def __init__(self, dt: float, action_sampler: ActionSampler):
        self.dt = dt
        self.action_sampler = action_sampler
        # Kept as a shorthand for use inside sample().
        self.num_actions = self.action_sampler.num_actions

    def sample(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """See base class documentation.

        Args:
            states: A 2D tensor holding one flat state per row.

        Returns:
            A tuple ``(actions, next_states, weights)``: the actions exactly as returned by the
            action sampler, the corresponding next states, and all-ones weights with shape
            ``(len(states), self.num_actions)``.
        """
        assert states.ndim == 2
        num_states = len(states)
        dt = self.dt

        sampled_actions = self.action_sampler(num_states, states.dtype, states.device)

        def tile(per_state: torch.Tensor) -> torch.Tensor:
            # Presumably expands each state's row once per action into a 3D tensor; see the
            # imported helper for the exact layout.
            return repeat_interleave_and_reshape_to_3d(per_state, self.num_actions)

        ego_pos = tile(BouncingBallsEnv.get_ego_positions_from_flat_states(states))
        ego_vel = tile(BouncingBallsEnv.get_ego_velocities_from_flat_states(states))

        # Constant-acceleration update of the ego ball over one step.
        ego_next_pos = ego_pos + dt * ego_vel + 0.5 * dt**2 * sampled_actions
        ego_next_vel = ego_vel + dt * sampled_actions
        ego_next = torch.cat((ego_next_pos, ego_next_vel), dim=-1)

        # The other balls do not depend on the action: constant-velocity step, flattened to one
        # row per state and then expanded across actions.
        other_pos = BouncingBallsEnv.get_other_positions_from_flat_states(states)
        other_vel = BouncingBallsEnv.get_other_velocities_from_flat_states(states)
        other_next = torch.cat((other_pos + dt * other_vel, other_vel), dim=-1)
        other_next = tile(other_next.reshape(num_states, -1))

        goals = tile(BouncingBallsEnv.get_goal_positions_from_flat_states(states))

        next_states = torch.cat((ego_next, other_next, goals), dim=-1)
        weights = torch.ones((num_states, self.num_actions), dtype=states.dtype, device=states.device)

        return sampled_actions, next_states, weights

    @property
    def num_transitions_per_state(self) -> int:
        """The number of transitions per state, taken from the action sampler's ``num_actions``."""
        return self.action_sampler.num_actions
