import pytest
import torch

from offline_rl.envs.bouncing_balls_env import BouncingBallsEnv
from offline_rl.rewards.evaluation.bouncing_balls_env_transition_samplers import (
    BouncingBallsEnvConstantVelocityTransitionSampler,
    BouncingBallsEnvActionSamplingTransitionSampler,
)
from offline_rl.rewards.evaluation.transition_sampler import BoundaryActionSampler


class TestBouncingBallsEnvConstantVelocityTransitionSampler:
    """Tests for ``BouncingBallsEnvConstantVelocityTransitionSampler``."""

    @pytest.mark.parametrize(
        "dt,num_states",
        [
            (0.2, 2),
            (0.1, 2),
            (0.1, 10),
            (0.1, 1),
        ],
    )
    def test_sample(self, dt, num_states):
        """Each state yields exactly one transition that moves balls at constant velocity.

        The test treats every state as consecutive 4-value blocks, one per ball
        (two position entries followed by two velocity entries), plus two
        trailing entries that are not checked here.
        """
        env = BouncingBallsEnv(dt=dt)
        sampler = BouncingBallsEnvConstantVelocityTransitionSampler(dt=dt)

        states = torch.tensor([env.reset() for _ in range(num_states)])
        state_dim = states.shape[1]
        actions, next_states, weights = sampler.sample(states)

        # A single sample per state, so the second axis has length 1.
        assert actions.shape == (num_states, 1, 2)
        assert next_states.shape == (num_states, 1, state_dim)
        assert weights.shape == (num_states, 1)

        # Split the ball blocks of all states at once: (num_states, num_balls, 4).
        balls = states[:, :-2].reshape(num_states, -1, 4)
        next_balls = next_states[:, 0, :-2].reshape(num_states, -1, 4)
        positions, velocities = balls[..., :2], balls[..., 2:4]

        # Positions advance by dt * velocity; velocities are carried over unchanged.
        assert torch.allclose(positions + dt * velocities, next_balls[..., :2])
        assert torch.allclose(velocities, next_balls[..., 2:4])


class TestBouncingBallsEnvActionSamplingTransitionSampler:
    """Shape tests for ``BouncingBallsEnvActionSamplingTransitionSampler``."""

    @pytest.mark.parametrize(
        "dt,num_states,state_dim",
        [
            (1.0, 10, 10),
            (0.0, 10, 10),
            (1.0, 1, 10),
            (1.0, 10, 22),
        ],
    )
    def test_sample(self, dt, num_states, state_dim, max_magnitude=5.0):
        """Sampled tensors carry one entry per action produced by the action sampler.

        Only output shapes are checked; the input states are all-zero placeholders
        of shape ``(num_states, state_dim)``.
        """
        action_sampler = BoundaryActionSampler(max_magnitude)
        sampler = BouncingBallsEnvActionSamplingTransitionSampler(dt, action_sampler)
        actions, next_states, weights = sampler.sample(torch.zeros(num_states, state_dim))

        # Second axis of every output enumerates the sampled actions.
        num_actions = action_sampler.num_actions
        assert actions.shape == (num_states, num_actions, 2)
        assert next_states.shape == (num_states, num_actions, state_dim)
        assert weights.shape == (num_states, num_actions)
