import unittest

from src.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
from src.rllib.env.multi_agent_env import make_multi_agent


class TestGroupAgentsWrapper(unittest.TestCase):
    def test_group_agents_wrapper(self):
        MultiAgentCartPole = make_multi_agent("CartPole-v0")
        grouped_ma_cartpole = GroupAgentsWrapper(
            env=MultiAgentCartPole({
                "num_agents": 4
            }),
            groups={
                "group1": [0, 1],
                "group2": [2, 3]
            })
        obs = grouped_ma_cartpole.reset()
        self.assertTrue(len(obs) == 2)
        self.assertTrue("group1" in obs and "group2" in obs)
        self.assertTrue(
            isinstance(obs["group1"], list) and len(obs["group1"]) == 2)
        self.assertTrue(
            isinstance(obs["group2"], list) and len(obs["group2"]) == 2)


if __name__ == "__main__":
    import sys
    import pytest
    sys.exit(pytest.main(["-v", __file__]))
