import pytest
import torch

from offline_rl.rewards.evaluation.transition_sampler import (
    add_dim_and_tile_first_dim,
    FixedDistributionTransitionSampler,
    UniformlyRandomActionSampler,
    BoundaryActionSampler,
    LinearActionSampler,
)


@pytest.mark.parametrize("arr,reps,expected_shape", [
    (torch.arange(200.0).reshape(10, 20), 5, (5, 10, 20)),
    (torch.arange(20.0).reshape(1, 20), 5, (5, 1, 20)),
    (torch.arange(200.0).reshape(10, 20), 1, (1, 10, 20)),
    (torch.arange(20.0).reshape(1, 20), 1, (1, 1, 20)),
])
def test_add_dim_and_tile_first_dim(arr, reps, expected_shape):
    """Check that ``add_dim_and_tile_first_dim`` prepends a leading dim of size ``reps``.

    The inputs hold distinct values rather than zeros. That way the test also
    verifies that every slice along the new leading dim is a copy of ``arr``,
    not merely that the output has the right shape.
    """
    result = add_dim_and_tile_first_dim(arr, reps)
    assert result.shape == expected_shape
    # Each tile along the new leading dim must reproduce the input exactly.
    for tile in result:
        assert torch.equal(tile, arr)


class TestFixedDistributionTransitionSampler:
    """Tests for ``FixedDistributionTransitionSampler``."""

    @pytest.mark.parametrize("num_samples_per_pair,state_dim,action_dim,batch_size,provide_weights", [
        (10, 5, 2, 128, True),
        (1, 5, 2, 128, True),
        (10, 5, 2, 1, True),
        (1, 5, 2, 1, True),
        (1, 1, 1, 1, True),
        (10, 1, 10, 1, False),
        (1, 1, 10, 1, False),
        (1, 1, 1, 1, False),
    ])
    def test_sample(self, num_samples_per_pair, state_dim, action_dim, batch_size, provide_weights):
        """``sample`` returns the stored actions/next states repeated once per input state.

        Deterministic, distinct values are used instead of zeros. This lets the
        test verify that every batch entry carries the stored actions and next
        states, not just that the output shapes are right.
        """
        actions = torch.arange(float(num_samples_per_pair * action_dim)).reshape(
            num_samples_per_pair, action_dim
        )
        next_states = torch.arange(float(num_samples_per_pair * state_dim)).reshape(
            num_samples_per_pair, state_dim
        )
        weights = torch.ones(num_samples_per_pair) if provide_weights else None
        transition_fn = FixedDistributionTransitionSampler(actions, next_states, weights)

        states = torch.zeros((batch_size, state_dim))
        tiled_actions, tiled_next_states, tiled_weights = transition_fn.sample(states)
        assert tiled_actions.shape == (batch_size, num_samples_per_pair, action_dim)
        assert tiled_next_states.shape == (batch_size, num_samples_per_pair, state_dim)
        assert tiled_weights.shape == (batch_size, num_samples_per_pair)
        # Every batch entry should hold the same fixed set of transitions.
        for batch_actions, batch_next_states in zip(tiled_actions, tiled_next_states):
            assert torch.equal(batch_actions, actions)
            assert torch.equal(batch_next_states, next_states)


class TestUniformlyRandomActionSampler:
    """Tests for ``UniformlyRandomActionSampler``."""

    @pytest.mark.parametrize("num_actions,max_magnitude,num_samples,dtype,device", [
        (10, 5.0, 20, torch.float32, "cpu"),
        (1, 5.0, 20, torch.float32, "cpu"),
        (10, 5.0, 1, torch.float32, "cpu"),
        (10, 1.0, 20, torch.float32, "cpu"),
        (10, 1.0, 20, torch.float64, "cpu"),
    ])
    def test_call(self, num_actions, max_magnitude, num_samples, dtype, device):
        """Sampled actions have the expected shape/dtype/device and stay within bounds.

        The bounds are checked inclusively. If the sampler draws from a
        half-open interval (as ``torch.rand``/``uniform_`` do), it can return
        ``-max_magnitude`` exactly. Strict comparisons would then make this
        randomized test intermittently fail.
        """
        sampler = UniformlyRandomActionSampler(num_actions, max_magnitude)
        actions = sampler(num_samples, dtype, device)
        # NOTE(review): the trailing dim of 2 assumes 2-D actions; confirm against the sampler.
        assert actions.shape == (num_samples, num_actions, 2)
        assert torch.all(torch.ge(actions, -max_magnitude))
        assert torch.all(torch.le(actions, max_magnitude))
        assert actions.dtype == dtype
        assert actions.device == torch.device(device)


class TestBoundaryActionSampler:
    """Tests for ``BoundaryActionSampler``."""

    @pytest.mark.parametrize("max_magnitude,num_samples,dtype,device", [
        (5.0, 20, torch.float32, "cpu"),
        (1.0, 20, torch.float32, "cpu"),
        (5.0, 1, torch.float32, "cpu"),
        (5.0, 20, torch.float64, "cpu"),
    ])
    def test_call(self, max_magnitude, num_samples, dtype, device):
        """Every sampled coordinate lies exactly on ``+max_magnitude`` or ``-max_magnitude``.

        Also checks the output shape against the sampler's ``num_actions`` and
        that the requested dtype and device are honoured.
        """
        boundary_sampler = BoundaryActionSampler(max_magnitude)
        sampled = boundary_sampler(num_samples, dtype, device)

        expected_shape = (num_samples, boundary_sampler.num_actions, 2)
        assert sampled.shape == expected_shape

        # Element-wise masks: a coordinate is valid if it hits either edge.
        on_lower_edge = sampled == -max_magnitude
        on_upper_edge = sampled == max_magnitude
        assert (on_lower_edge | on_upper_edge).all()

        expected_device = torch.device(device)
        assert sampled.dtype == dtype
        assert sampled.device == expected_device


class TestLinearActionSampler:
    """Tests for ``LinearActionSampler``."""

    @pytest.mark.parametrize("ndim", [2, 3, 1])
    @pytest.mark.parametrize("num_actions_each_dim,max_magnitude,num_samples,dtype,device", [
        (5, 5.0, 20, torch.float32, "cpu"),
        (5, 1.0, 20, torch.float32, "cpu"),
        (5, 5.0, 1, torch.float32, "cpu"),
        (5, 5.0, 20, torch.float64, "cpu"),
        (2, 5.0, 20, torch.float64, "cpu"),
    ])
    def test_call(self, ndim, num_actions_each_dim, max_magnitude, num_samples, dtype, device):
        """The sampler output is ``(num_samples, num_actions, ndim)`` with the requested dtype/device.

        ``num_actions`` is read back from the sampler rather than recomputed
        here, so this test does not depend on how the grid size is derived.
        """
        linear_sampler = LinearActionSampler(num_actions_each_dim, max_magnitude, ndim)
        sampled = linear_sampler(num_samples, dtype, device)

        expected_shape = (num_samples, linear_sampler.num_actions, ndim)
        expected_device = torch.device(device)

        assert sampled.shape == expected_shape
        assert sampled.dtype == dtype
        assert sampled.device == expected_device
