import numpy as np

import gym
import torch
from gym import spaces
from gym.spaces import Box


class TrafficJunctionWrapper(gym.Wrapper):
    r"""Adapt a traffic junction environment (bool vocab, as used in IC3Net) to a joined per-agent format.

    * Observations of all agents are packed into one array of shape
      ``(num_agents, 1, 2 + flattened_road_information_size)``.
    * Rewards get an extra axis inserted at position 1.
    * ``info`` is rewritten so that ``alive_mask`` refers to the agents that acted in the
      current step (the wrapped environment reports it one step ahead).
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        self.num_agents = env.nagents
        # Agents that acted in the current step / will act in the next step. Both start empty.
        self.alive_mask = np.zeros(self.num_agents)
        self.next_alive_mask = self.alive_mask
        self.num_actions = self.env.action_space.n
        # Passed to the wrapped env's reset(); None unless set from outside this class.
        self.epoch = None

        assert env.vocab_type == 'bool', "Only bool vocab type is supported (as used in IC3Net)"
        # We just redefine the observation space. The original per-agent space is a tuple:
        #   0: last action {0, 1}
        #   1: normalized road id [0, 1]
        #   2: one-hot encoded road information (includes one-hot encoded position)
        # Component 2's `n` may be multi-dimensional, hence np.prod for its flattened size.
        # The new space concatenates all three into a single flat row per agent.
        road_info_size = int(np.prod(env.observation_space.spaces[2].n))
        self.observation_space = [
            spaces.Box(low=0, high=1, shape=(1, 2 + road_info_size), dtype=np.float32)
        ] * self.num_agents

    def convert_and_join_observation(self, obs):
        """
        Converts the given agent observations into a single array with the updated observation space.

        :param obs: sequence of per-agent observations; entry ``a`` is indexed as
            ``(last_action, road_id, road_information)``
        :return: float array of shape ``(num_agents, *observation_space[0].shape)``
        """
        # NOTE(review): the declared Box dtype is float32, but this array is float64 (`float`);
        # kept unchanged for backward compatibility - confirm what downstream code expects.
        obs_tensor = np.zeros((self.num_agents, *self.observation_space[0].shape), dtype=float)
        for a in range(self.num_agents):
            agent_obs = obs[a]
            obs_tensor[a, 0, 0] = agent_obs[0]
            obs_tensor[a, 0, 1] = agent_obs[1]
            # The road information may be a multi-dimensional grid (see np.prod in __init__);
            # flatten it so it fits the 1-D slice. ravel is a no-op for already flat input.
            obs_tensor[a, 0, 2:] = np.ravel(agent_obs[2])

        return obs_tensor

    def reset(self):
        """
        Resets the wrapped environment and both alive masks.

        :return: initial observations joined by ``convert_and_join_observation``
        """
        initial_obs = self.env.reset(self.epoch)
        # for some reason, traffic junction starts without any agents
        self.alive_mask = self.next_alive_mask = np.zeros_like(self.alive_mask)
        return self.convert_and_join_observation(initial_obs)

    def step(self, action):
        """
        Steps the wrapped environment and reformats its outputs.

        :param action: forwarded unchanged to the wrapped environment's ``step``
        :return: ``(next_obs, reward, done, info)`` where ``next_obs`` is the joined observation
            array, ``reward`` has an extra axis at position 1, ``done`` is passed through, and
            ``info`` holds ``alive_mask`` and ``agent_done`` (each with an extra axis at
            position 1) plus a boolean ``success`` flag (True unless ``env.has_failed``).
        """
        self.alive_mask = self.next_alive_mask
        next_obs, reward, done, info = self.env.step(action)
        next_obs = self.convert_and_join_observation(next_obs)
        reward = np.expand_dims(reward, 1)
        # The alive mask of traffic junction tells you whether an agent is visible in the next
        # timestep, not the current one. We want to capture the agents that have been alive in
        # this timestep and executed an action. Copy it so that in-place updates by the wrapped
        # env (if any) cannot alter the mask we report on the next step.
        self.next_alive_mask = np.copy(info['alive_mask'])
        return next_obs, reward, done, {
            'alive_mask': np.expand_dims(self.alive_mask, 1),
            'agent_done': np.expand_dims(info['is_completed'], 1),
            'success': np.array(1 - self.env.has_failed, dtype=bool)
        }
