from numpy.lib.arraysetops import isin
import pytest
import gym
import numpy as np

from gym import spaces

from expground.algorithms.base_policy import pack_action_to_policy
from expground.algorithms.random_policy import RandomPolicy


@pytest.fixture(scope="session")
def observation_space():
    """Session-scoped observation space: a float32 box of shape (4,) bounded by [-1, 1]."""
    bound = 1.0
    box = spaces.Box(low=-bound, high=bound, shape=(4,), dtype=np.float32)
    return box


@pytest.fixture(scope="session")
def action_space():
    """Session-scoped discrete action space with ten actions."""
    num_actions = 10
    return spaces.Discrete(num_actions)


# With a single argname, every element of the value list is passed as-is, so
# the values must be plain booleans. One-element tuples such as ``(False,)``
# are always truthy and would silently skip the ``is_fixed=False`` case.
@pytest.mark.parametrize("is_fixed", [False, True])
def test_pack_action_to_policy(
    observation_space: gym.Space, action_space: gym.Space, is_fixed: bool
):
    """Check that a policy packed from a raw action set yields valid actions.

    The action set is every index of the discrete action space. When
    ``is_fixed`` is true a random normalized distribution over the actions is
    supplied; otherwise ``distribution`` is ``None``.

    ``compute_action`` is called twice, with its third argument set to
    ``False`` and then ``True``; the first element of each result must be
    contained in ``action_space``.
    """
    # generate action set
    assert isinstance(action_space, spaces.Discrete)
    action_set = list(range(action_space.n))

    distribution = None
    if is_fixed:
        # Normalize random weights so that they sum to one.
        logits = np.random.random(action_space.n)
        distribution = logits / np.sum(logits)

    policy = pack_action_to_policy(
        action_set, observation_space, is_fixed=is_fixed, distribution=distribution
    )

    # NOTE(review): judging by the printout below, the third argument looks
    # like an evaluation-mode flag (False = train, True = eval) -- confirm
    # against the policy implementation.
    res1 = policy.compute_action(observation_space.sample(), None, False)

    assert action_space.contains(res1[0]), (action_space, res1[0])

    res2 = policy.compute_action(observation_space.sample(), None, True)

    assert action_space.contains(res2[0]), (action_space, res2[0])

    print("========== outputs ==========")
    print("train mode is: {}\neval mode is: {}".format(res1, res2))


# With a single argname, every element of the value list is passed as-is, so
# the values must be plain booleans. One-element tuples such as ``(False,)``
# are always truthy and would silently skip the ``is_fixed=False`` case.
@pytest.mark.parametrize("is_fixed", [False, True])
def test_pack_policy_to_policy(
    observation_space: gym.Space, action_space: gym.Space, is_fixed: bool
):
    """Check that a policy packed from a list of sub-policies yields valid actions.

    One ``RandomPolicy`` is created per discrete action. When ``is_fixed`` is
    true a random normalized distribution over the sub-policies is supplied;
    otherwise ``distribution`` is ``None``.

    ``compute_action`` is called twice, with its third argument set to
    ``False`` and then ``True``; the first element of each result must be
    contained in ``action_space``.
    """
    assert isinstance(action_space, spaces.Discrete)
    sub_policies = [
        RandomPolicy(observation_space, action_space, None, None)
        for _ in range(action_space.n)
    ]

    distribution = None
    if is_fixed:
        # Normalize random weights so that they sum to one.
        logits = np.random.random(len(sub_policies))
        distribution = logits / np.sum(logits)

    # NOTE(review): this test is named for packing *policies*, yet it calls
    # ``pack_action_to_policy``. Confirm whether a dedicated policy-packing
    # helper exists and should be used here instead.
    policy = pack_action_to_policy(
        sub_policies, observation_space, is_fixed, distribution
    )

    # NOTE(review): judging by the printout below, the third argument looks
    # like an evaluation-mode flag (False = train, True = eval) -- confirm
    # against the policy implementation.
    res1 = policy.compute_action(observation_space.sample(), None, False)

    assert action_space.contains(res1[0]), (action_space, res1[0])

    res2 = policy.compute_action(observation_space.sample(), None, True)

    assert action_space.contains(res2[0]), (action_space, res2[0])

    print("========== outputs ==========")
    print("train mode is: {}\neval mode is: {}".format(res1, res2))
