from env.multi_agent_env import MultiAgentEnv
from env.rware.warehouse import Warehouse


class WarehouseMultiAgentEnv(Warehouse, MultiAgentEnv):
    """Wraps the RWARE ``Warehouse`` env to be compatible with RLlib multi-agent.

    The underlying ``Warehouse`` exchanges per-agent data as sequences indexed
    ``0 .. num_agents - 1``; this wrapper translates them to and from dicts
    keyed ``"agent_0"``, ``"agent_1"``, ... as RLlib's multi-agent API expects.
    """

    def __init__(self, **kwargs):
        """Build the underlying ``Warehouse`` from ``kwargs``.

        ``action_space`` / ``observation_space`` are set to the env's
        single-agent spaces (``sa_action_space`` / ``sa_observation_space``),
        presumably shared by every agent.
        """
        # NOTE(review): MultiAgentEnv.__init__ is not called here; confirm the
        # local MultiAgentEnv base does not rely on its own initializer.
        Warehouse.__init__(self, **kwargs)
        self.action_space = self.sa_action_space
        self.observation_space = self.sa_observation_space

    def reset(self, **kwargs):
        """Reset the env and return the initial observations keyed by agent id."""
        return self.group_items(Warehouse.reset(self, **kwargs))

    def step(self, actions: dict):
        """Advance the env by one step using a dict of per-agent actions.

        Args:
            actions: mapping ``"agent_<i>" -> action`` with an entry for every
                agent index ``0 .. num_agents - 1``.

        Returns:
            ``(obs, rewards, dones, infos)`` where ``obs``, ``rewards`` and
            ``infos`` are dicts keyed by agent id and ``dones`` is
            ``{"__all__": done}``.

        Raises:
            KeyError: if ``actions`` lacks an entry for some agent.
        """
        act = self.ungroup_items(actions)
        obs_list, rew_list, done, info = Warehouse.step(self, act)
        # Depending on the Warehouse version, ``done`` may be a per-agent list
        # of flags (upstream RWARE returns one flag per agent). A non-empty
        # list is always truthy, so storing it as-is under "__all__" would
        # terminate every episode after one step; collapse it to a single
        # bool. A scalar ``done`` is passed through unchanged.
        if isinstance(done, (list, tuple)):
            done = all(done)

        return (
            self.group_items(obs_list),
            self.group_items(rew_list),
            {"__all__": done},
            self.group_items(info),
        )

    def group_items(self, item):
        """Convert per-agent data into a dict keyed by agent id.

        A dict ``item`` (e.g. the env's ``info``) is treated as shared data:
        every agent id maps to that same dict object (not a copy), so a
        mutation through one agent's entry is visible through all of them.
        Any other ``item`` must be indexable by agent index
        ``0 .. num_agents - 1``.
        """
        if isinstance(item, dict):
            return {f"agent_{i}": item for i in range(self.num_agents)}
        return {f"agent_{i}": item[i] for i in range(self.num_agents)}

    def ungroup_items(self, item):
        """Convert a dict keyed by agent id into a list ordered by agent index.

        Raises:
            KeyError: if an ``"agent_<i>"`` key is missing from ``item``.
        """
        return [item[f"agent_{i}"] for i in range(self.num_agents)]
