import gym
import numpy as np
import pytest
import torch

import offline_rl.utils.space_utils as space_utils


def test_discrete_space_input_formatter():
    """Formatting every index of a ``Discrete(5)`` space yields the 5x5 identity."""
    num_categories = 5
    formatter = space_utils.DiscreteSpaceInputFormatter(
        gym.spaces.Discrete(num_categories), torch.float32
    )

    # One sample per category, in ascending order: 0, 1, 2, 3, 4.
    indices = torch.arange(num_categories, dtype=torch.long)
    formatted = formatter(indices)

    # Row i of the expected output is the one-hot vector for index i.
    identity = torch.eye(num_categories, dtype=torch.float32)
    assert torch.equal(formatted, identity)


class TestMultiDiscreteSpaceInputFormatter:
    """Checks ``MultiDiscreteSpaceInputFormatter`` built with ``torch.float32``."""

    # (inputs, expected) pairs checked against a ``MultiDiscrete([2, 3])`` space.
    _FORMAT_CASES = [
        # Five-column inputs are expected to come back with identical values.
        (torch.zeros(1, 5), torch.zeros(1, 5)),
        (torch.ones(1, 5), torch.ones(1, 5)),
        # Index pairs: expected columns 0-1 mark the first index,
        # expected columns 2-4 mark the second index.
        (torch.tensor([[0, 0]]), torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]])),
        (torch.tensor([[0, 1]]), torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0]])),
        (torch.tensor([[0, 2]]), torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0]])),
        (torch.tensor([[1, 0]]), torch.tensor([[0.0, 1.0, 1.0, 0.0, 0.0]])),
        (torch.tensor([[1, 1]]), torch.tensor([[0.0, 1.0, 0.0, 1.0, 0.0]])),
        (torch.tensor([[1, 2]]), torch.tensor([[0.0, 1.0, 0.0, 0.0, 1.0]])),
    ]

    @staticmethod
    def _make_formatter(nvec):
        """Return a float32 formatter for ``gym.spaces.MultiDiscrete(nvec)``."""
        return space_utils.MultiDiscreteSpaceInputFormatter(
            gym.spaces.MultiDiscrete(nvec), torch.float32
        )

    @pytest.mark.parametrize("inputs,expected", _FORMAT_CASES)
    def test_format(self, inputs, expected):
        """Single-row inputs on a ``[2, 3]`` space match the expected encoding."""
        formatter = self._make_formatter([2, 3])
        assert torch.equal(formatter(inputs), expected)

    def test_batch_format(self):
        """Two rows on a ``[3, 2]`` space are each encoded into five columns."""
        formatter = self._make_formatter([3, 2])
        batch = torch.tensor([[0, 0], [1, 1]])
        # Expected columns 0-2 mark the first index, columns 3-4 the second.
        want = torch.tensor([
            [1.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 1.0],
        ])
        got = formatter(batch)
        assert torch.equal(got, want)


class TestBoxSpaceInputFormatter:
    """Checks ``BoxSpaceInputFormatter`` with ``should_normalize=True``."""

    @staticmethod
    def _normalizing_formatter(low, high):
        """Return a normalizing formatter for ``gym.spaces.Box(low, high)``."""
        box = gym.spaces.Box(low=low, high=high)
        return space_utils.BoxSpaceInputFormatter(box, should_normalize=True)

    @pytest.mark.parametrize("inputs,expected", [
        # Inside the bounds: 0 sits one third of the way from -0.5 to 1.
        (torch.full((1, 5), 0.0), torch.full((1, 5), -1 / 3)),
        # Exactly on the upper bound.
        (torch.full((1, 5), 1.0), torch.full((1, 5), 1.0)),
        # Below the lower bound: expected output is -1.
        (torch.full((1, 5), -1.0), torch.full((1, 5), -1.0)),
        # Exactly on the lower bound.
        (torch.full((1, 5), -0.5), torch.full((1, 5), -1.0)),
        # Above the upper bound: expected output is 1.
        (torch.full((1, 5), 2.0), torch.full((1, 5), 1.0)),
    ])
    def test_format(self, inputs, expected):
        """Uniform bounds ``[-0.5, 1]`` across five dimensions."""
        lower = np.full(5, -0.5, dtype=np.float32)
        upper = np.ones(5, dtype=np.float32)
        formatter = self._normalizing_formatter(lower, upper)
        assert torch.allclose(formatter(inputs), expected)

    @pytest.mark.parametrize("inputs,expected", [
        (torch.tensor([[-1.0, 1.0]]), torch.tensor([[-1.0, -1.0]])),
        (torch.tensor([[1.0, 2.0]]), torch.tensor([[1.0, 1.0]])),
        (torch.tensor([[0.0, 1.5]]), torch.tensor([[0.0, 0.0]])),
        (torch.tensor([[-1.0, 1.0], [1.0, 2.0]]), torch.tensor([[-1.0, -1.0], [1.0, 1.0]])),
    ])
    def test_format_multiple_valued_bounds(self, inputs, expected):
        """Per-dimension bounds: ``[-1, 1]`` for column 0, ``[1, 2]`` for column 1."""
        lower = np.array([-1, 1], dtype=np.float32)
        upper = np.array([1, 2], dtype=np.float32)
        formatter = self._normalizing_formatter(lower, upper)
        assert torch.allclose(formatter(inputs), expected)


class TestClipActionsToSpaceBounds:
    """Checks that ``clip_actions_to_space_bounds`` keeps actions inside a Box."""

    @pytest.mark.parametrize("low,high,shape,batch_size", [
        (-5, 5, (2,), 1),
        (-1, 1, (5,), 1),
        (-5, 5, (2,), 10),
    ])
    def test(self, low, high, shape, batch_size):
        """Push one entry past each bound in turn and check the clipped result."""
        box = gym.spaces.Box(low=low, high=high, shape=shape)
        # All-zero batch; zero lies inside every parametrized [low, high].
        actions = torch.zeros((batch_size, *shape), dtype=torch.float32)

        # Violate the lower bound with a single entry.
        actions[0, 0] = low - 1
        clipped = space_utils.clip_actions_to_space_bounds(actions, box)
        below_low = clipped < low
        assert not torch.any(below_low)

        # Overwrite the same entry so it now violates the upper bound.
        actions[0, 0] = high + 1
        clipped = space_utils.clip_actions_to_space_bounds(actions, box)
        above_high = clipped > high
        assert not torch.any(above_high)
