import json  # note: ujson fails this test due to float equality
import copy
import pickle
import tempfile

import numpy as np
import pytest

from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict


@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_roundtripping(space):
    sample_1 = space.sample()
    sample_2 = space.sample()
    assert space.contains(sample_1)
    assert space.contains(sample_2)
    json_rep = space.to_jsonable([sample_1, sample_2])

    json_roundtripped = json.loads(json.dumps(json_rep))

    samples_after_roundtrip = space.from_jsonable(json_roundtripped)
    sample_1_prime, sample_2_prime = samples_after_roundtrip

    s1 = space.to_jsonable([sample_1])
    s1p = space.to_jsonable([sample_1_prime])
    s2 = space.to_jsonable([sample_2])
    s2p = space.to_jsonable([sample_2_prime])
    assert s1 == s1p, f"Expected {s1} to equal {s1p}"
    assert s2 == s2p, f"Expected {s2} to equal {s2p}"


@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
        Box(low=-np.inf, high=np.inf, shape=(1, 3)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(6),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_equality(space):
    space1 = space
    space2 = copy.deepcopy(space)
    assert space1 == space2, f"Expected {space1} to equal {space2}"


@pytest.mark.parametrize(
    "spaces",
    [
        (Discrete(3), Discrete(4)),
        (Discrete(3), Discrete(3, start=-1)),
        (MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
        (MultiBinary(8), MultiBinary(7)),
        (
            Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
            Box(low=np.array([-10, 0]), high=np.array([10, 9]), dtype=np.float32),
        ),
        (
            Box(low=-np.inf, high=0.0, shape=(2, 1)),
            Box(low=0.0, high=np.inf, shape=(2, 1)),
        ),
        (Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
        (
            Tuple([Discrete(5), Discrete(10)]),
            Tuple([Discrete(5, start=7), Discrete(10)]),
        ),
        (Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
        (Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
    ],
)
def test_inequality(spaces):
    space1, space2 = spaces
    assert space1 != space2, f"Expected {space1} != {space2}"


@pytest.mark.parametrize(
    "space",
    [
        Discrete(5),
        Discrete(8, start=-20),
        Box(low=0, high=255, shape=(2,), dtype="uint8"),
        Box(low=-np.inf, high=np.inf, shape=(3, 3)),
        Box(low=1.0, high=np.inf, shape=(3, 3)),
        Box(low=-np.inf, high=2.0, shape=(3, 3)),
    ],
)
def test_sample(space):
    space.seed(0)
    n_trials = 100
    samples = np.array([space.sample() for _ in range(n_trials)])
    expected_mean = 0.0
    if isinstance(space, Box):
        if space.is_bounded():
            expected_mean = (space.high + space.low) / 2
        elif space.is_bounded("below"):
            expected_mean = 1 + space.low
        elif space.is_bounded("above"):
            expected_mean = -1 + space.high
        else:
            expected_mean = 0.0
    elif isinstance(space, Discrete):
        expected_mean = space.start + space.n / 2
    else:
        raise NotImplementedError
    np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())


@pytest.mark.parametrize(
    "spaces",
    [
        (Discrete(5), MultiBinary(5)),
        (
            Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
            MultiDiscrete([2, 2, 8]),
        ),
        (
            Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8),
            Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        ),
        (Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
        (Dict({"position": Discrete(5)}), Discrete(5)),
        (Tuple((Discrete(5),)), Discrete(5)),
        (
            Box(low=np.array([-np.inf, 0.0]), high=np.array([0.0, np.inf])),
            Box(low=np.array([-np.inf, 1.0]), high=np.array([0.0, np.inf])),
        ),
    ],
)
def test_class_inequality(spaces):
    assert spaces[0] == spaces[0]
    assert spaces[1] == spaces[1]
    assert spaces[0] != spaces[1]
    assert spaces[1] != spaces[0]


@pytest.mark.parametrize(
    "space_fn",
    [
        lambda: Dict(space1="abc"),
        lambda: Dict({"space1": "abc"}),
        lambda: Tuple(["abc"]),
    ],
)
def test_bad_space_calls(space_fn):
    with pytest.raises(AssertionError):
        space_fn()


def test_seed_Dict():
    test_space = Dict(
        {
            "a": Box(low=0, high=1, shape=(3, 3)),
            "b": Dict(
                {
                    "b_1": Box(low=-100, high=100, shape=(2,)),
                    "b_2": Box(low=-1, high=1, shape=(2,)),
                }
            ),
            "c": Discrete(5),
        }
    )

    seed_dict = {
        "a": 0,
        "b": {
            "b_1": 1,
            "b_2": 2,
        },
        "c": 3,
    }

    test_space.seed(seed_dict)

    # "Unpack" the dict sub-spaces into individual spaces
    a = Box(low=0, high=1, shape=(3, 3))
    a.seed(0)
    b_1 = Box(low=-100, high=100, shape=(2,))
    b_1.seed(1)
    b_2 = Box(low=-1, high=1, shape=(2,))
    b_2.seed(2)
    c = Discrete(5)
    c.seed(3)

    for i in range(10):
        test_s = test_space.sample()
        a_s = a.sample()
        assert (test_s["a"] == a_s).all()
        b_1_s = b_1.sample()
        assert (test_s["b"]["b_1"] == b_1_s).all()
        b_2_s = b_2.sample()
        assert (test_s["b"]["b_2"] == b_2_s).all()
        c_s = c.sample()
        assert test_s["c"] == c_s


def test_box_dtype_check():
    # Related Issues:
    # https://github.com/openai/gym/issues/2357
    # https://github.com/openai/gym/issues/2298

    space = Box(0, 2, tuple(), dtype=np.float32)

    # casting will match the correct type
    assert space.contains(0.5)

    # float64 is not in float32 space
    assert not space.contains(np.array(0.5))
    assert not space.contains(np.array(1))


@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(3, start=-4),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_returns_list(space):
    def assert_integer_list(seed):
        assert isinstance(seed, list)
        assert len(seed) >= 1
        assert all([isinstance(s, int) for s in seed])

    assert_integer_list(space.seed(None))
    assert_integer_list(space.seed(0))


def convert_sample_hashable(sample):
    if isinstance(sample, np.ndarray):
        return tuple(sample.tolist())
    if isinstance(sample, (list, tuple)):
        return tuple(convert_sample_hashable(s) for s in sample)
    if isinstance(sample, dict):
        return tuple(
            (key, convert_sample_hashable(value)) for key, value in sample.items()
        )

    return sample


def sample_equal(sample1, sample2):
    return convert_sample_hashable(sample1) == convert_sample_hashable(sample2)


@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(3, start=-4),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_reproducibility(space):
    space1 = space
    space2 = copy.deepcopy(space)

    space1.seed(None)
    space2.seed(None)

    assert space1.seed(0) == space2.seed(0)
    assert sample_equal(space1.sample(), space2.sample())


@pytest.mark.parametrize(
    "space",
    [
        Tuple([Discrete(100), Discrete(100)]),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([Discrete(5), Discrete(5, start=10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_subspace_incorrelated(space):
    subspaces = space.spaces if isinstance(space, Tuple) else space.spaces.values()

    space.seed(0)
    states = [
        convert_sample_hashable(subspace.np_random.bit_generator.state)
        for subspace in subspaces
    ]

    assert len(states) == len(set(states))


def test_tuple():
    spaces = [Discrete(5), Discrete(10), Discrete(5)]
    space_tuple = Tuple(spaces)

    assert len(space_tuple) == len(spaces)
    assert space_tuple.count(Discrete(5)) == 2
    assert space_tuple.count(MultiBinary(2)) == 0
    for i, space in enumerate(space_tuple):
        assert space == spaces[i]
    for i, space in enumerate(reversed(space_tuple)):
        assert space == spaces[len(spaces) - 1 - i]
    assert space_tuple.index(Discrete(5)) == 0
    assert space_tuple.index(Discrete(5), 1) == 2
    with pytest.raises(ValueError):
        space_tuple.index(Discrete(10), 0, 1)


def test_multidiscrete_as_tuple():
    # 1D multi-discrete
    space = MultiDiscrete([3, 4, 5])

    assert space.shape == (3,)
    assert space[0] == Discrete(3)
    assert space[0:1] == MultiDiscrete([3])
    assert space[0:2] == MultiDiscrete([3, 4])
    assert space[:] == space and space[:] is not space
    assert len(space) == 3

    # 2D multi-discrete
    space = MultiDiscrete([[3, 4, 5], [6, 7, 8]])

    assert space.shape == (2, 3)
    assert space[0, 1] == Discrete(4)
    assert space[0] == MultiDiscrete([3, 4, 5])
    assert space[0:1] == MultiDiscrete([[3, 4, 5]])
    assert space[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert space[:, 0:1] == MultiDiscrete([[3], [6]])
    assert space[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
    assert space[:] == space and space[:] is not space
    assert space[:, :] == space and space[:, :] is not space


def test_multidiscrete_subspace_reproducibility():
    # 1D multi-discrete
    space = MultiDiscrete([100, 200, 300])
    space.seed(None)

    assert sample_equal(space[0].sample(), space[0].sample())
    assert sample_equal(space[0:1].sample(), space[0:1].sample())
    assert sample_equal(space[0:2].sample(), space[0:2].sample())
    assert sample_equal(space[:].sample(), space[:].sample())
    assert sample_equal(space[:].sample(), space.sample())

    # 2D multi-discrete
    space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    space.seed(None)

    assert sample_equal(space[0, 1].sample(), space[0, 1].sample())
    assert sample_equal(space[0].sample(), space[0].sample())
    assert sample_equal(space[0:1].sample(), space[0:1].sample())
    assert sample_equal(space[0:2, :].sample(), space[0:2, :].sample())
    assert sample_equal(space[:, 0:1].sample(), space[:, 0:1].sample())
    assert sample_equal(space[0:2, 0:2].sample(), space[0:2, 0:2].sample())
    assert sample_equal(space[:].sample(), space[:].sample())
    assert sample_equal(space[:, :].sample(), space[:, :].sample())
    assert sample_equal(space[:, :].sample(), space.sample())


def test_space_legacy_state_pickling():
    legacy_state = {
        "shape": (
            1,
            2,
            3,
        ),
        "dtype": np.int64,
        "np_random": np.random.default_rng(),
        "n": 3,
    }
    space = Discrete(1)
    space.__setstate__(legacy_state)

    assert space.shape == legacy_state["shape"]
    assert space._shape == legacy_state["shape"]
    assert space.np_random == legacy_state["np_random"]
    assert space._np_random == legacy_state["np_random"]
    assert space.n == 3
    assert space.dtype == legacy_state["dtype"]


@pytest.mark.parametrize(
    "space",
    [
        Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
        Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
        Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
        Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
    ],
)
def test_infinite_space(space):
    # for this test, make sure that spaces that are passed in have only 0 or infinite bounds
    # because space.high and space.low are both modified within the init
    # so we check for infinite when we know it's not 0
    space.seed(0)

    assert np.all(space.high > space.low), "High bound not higher than low bound"

    sample = space.sample()

    # check if space contains sample
    assert space.contains(
        sample
    ), "Sample {sample} not inside space according to `space.contains()`"

    # manually check that the sign of the sample is within the bounds
    assert np.all(
        np.sign(space.high) >= np.sign(sample)
    ), f"Sign of sample {sample} is less than space upper bound {space.high}"
    assert np.all(
        np.sign(space.low) <= np.sign(sample)
    ), f"Sign of sample {sample} is more than space lower bound {space.low}"

    # check that int bounds are bounded for everything
    # but floats are unbounded for infinite
    if np.any(space.high != 0):
        assert (
            space.is_bounded("above") == False
        ), "inf upper bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("above") == True
        ), "non-inf upper bound supposed to be bounded"

    if np.any(space.low != 0):
        assert (
            space.is_bounded("below") == False
        ), "inf lower bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("below") == True
        ), "non-inf lower bound supposed to be bounded"

    # check for dtype
    assert (
        space.high.dtype == space.dtype
    ), "High's dtype {space.high.dtype} doesn't match `space.dtype`'"
    assert (
        space.low.dtype == space.dtype
    ), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"


def test_discrete_legacy_state_pickling():
    legacy_state = {
        "n": 3,
    }

    d = Discrete(1)
    assert "start" in d.__dict__
    del d.__dict__["start"]  # legacy did not include start param
    assert "start" not in d.__dict__

    d.__setstate__(legacy_state)

    assert d.start == 0
    assert d.n == 3


@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_pickle(space):
    space.sample()

    # Pickle and unpickle with a string
    pickled = pickle.dumps(space)
    space2 = pickle.loads(pickled)

    # Pickle and unpickle with a file
    with tempfile.TemporaryFile() as f:
        pickle.dump(space, f)
        f.seek(0)
        space3 = pickle.load(f)

    sample = space.sample()
    sample2 = space2.sample()
    sample3 = space3.sample()
    assert sample_equal(sample, sample2)
    assert sample_equal(sample, sample3)
