import copy
from multiprocessing.context import assert_spawning
import random

import numpy as np

from typing import List
from abc import ABC, abstractmethod
from collections import OrderedDict


from envs.multiagentenv import MultiAgentEnv


class GameObject(ABC):
    """ Common interface of every object that lives in the environment """

    @abstractmethod
    def reset(self):
        """ Reset the object; concrete classes decide what is restored """

    @abstractmethod
    def move(self):
        """ Move the object; concrete classes decide how (or if) it moves """


class Agent(GameObject):
    """ An agent (player) acting in the environment """

    def __init__(self,
                 init_pos: List[int],
                 speed: int = 1,
                 agent_id: int = 0,
                 dest_pos=None) -> None:
        """
        init_pos: initial position of the agent, index 0 is the one changed by
                  move_up / move_down
        speed: amount added to / subtracted from index 0 per move, 1 by default
        agent_id: identifier stored on the agent
        dest_pos: destination position; a copy of init_pos when omitted,
                  otherwise the given object is stored as-is
        """
        self.agent_id = agent_id
        self.speed = speed
        self.dest_pos = copy.deepcopy(init_pos) if dest_pos is None else dest_pos

        # positions are kept as independent numpy arrays
        self.init_pos = np.array(init_pos)
        self.curr_pos = self.init_pos.copy()
        self.prev_pos = self.init_pos.copy()

    def reset(self) -> None:
        """ Put the agent back at its initial position """
        self.curr_pos = self.init_pos.copy()
        self.prev_pos = self.init_pos.copy()

    def _update_prev_pos(self):
        """ Remember the current position before it is modified """
        self.prev_pos = self.curr_pos.copy()

    def move(self) -> None:
        """ Generic move is a no-op; see move_up / move_down """
        return None

    def move_down(self) -> None:
        """ Increase index 0 of the position by `speed` """
        self._update_prev_pos()
        self.curr_pos[0] = self.curr_pos[0] + self.speed

    def move_up(self) -> None:
        """ Decrease index 0 of the position by `speed` """
        self._update_prev_pos()
        self.curr_pos[0] = self.curr_pos[0] - self.speed


class Tree(GameObject):
    """ The Tree planted by the agent """
    def __init__(self,
                 init_pos: List[int],
                 growth_cycle: int = -1,
                 plant_timestep: int = -1,
                 agent_id: int = 0) -> None:
        """
        init_pos: position of the tree
        growth_cycle: number of grow() calls after which the tree is mature
        plant_timestep: timestep recorded for the planting of the tree
        agent_id: identifier stored on the tree (the planting agent)
        """
        self.init_pos = np.array(init_pos)
        self.curr_pos = copy.deepcopy(self.init_pos)
        self.prev_pos = copy.deepcopy(self.init_pos)
        self.agent_id = agent_id
        self.growth_cycle = growth_cycle
        self.plant_timestep = plant_timestep
        self.timer = 0  # number of grow() calls so far
        self._is_tree_mature = False  # latched once the tree has matured

    def reset(self):
        """ Restore the freshly planted state (position, timer and maturity) """
        self.curr_pos = copy.deepcopy(self.init_pos)
        self.prev_pos = copy.deepcopy(self.init_pos)
        # the growth state has to be cleared as well, otherwise a reset tree
        # would still report the age / maturity it had before the reset
        self.timer = 0
        self._is_tree_mature = False

    def grow(self):
        """ Let the tree grow by one timestep """
        self.timer += 1

    def grow_info(self):
        """
        Tell the agent how tall the tree is now.
        Returns 0 once the value would exceed the growth cycle.
        """
        # NOTE that .grow() will be called later, so add 1 here
        grow_info = self.timer + 1
        if grow_info > self.growth_cycle:
            return 0  # means that the tree is mature before current timestep
        return grow_info

    @property
    def duration(self):
        """ The growth cycle of the tree """
        return self.growth_cycle

    @property
    def time_left(self):
        """ Remaining number of grow() calls until the growth cycle is reached """
        return self.growth_cycle - self.timer

    def is_mature(self):
        """
        Return True once the timer has hit the growth cycle.
        The result is latched: a tree that was seen mature stays mature even
        if grow() is called again afterwards.
        """
        if not self._is_tree_mature:
            self._is_tree_mature = self.timer == self.growth_cycle
        return self._is_tree_mature

    def move(self):
        """ Trees cannot move: this is a deliberate no-op """
        return None


class Forest(GameObject):
    """ Keeps track of the tree of every agent and of the trees left to plant """

    def __init__(self, n_agents, grid_size, tree_num) -> None:
        """
        n_agents: number of agents
        grid_size: width of the field; the per-agent tables are keyed by
                   range(grid_size)
        tree_num: number of trees every agent owns at the start
        """
        self.n_agents = n_agents
        self.grid_size = grid_size
        self.tree_num = tree_num
        self.agents_tree_num = self._initial_tree_num()
        self.agents_tree = self._empty_slots()
        # trees requested while the current tree is mature; they are planted
        # by _postprocessing() right after the mature tree is destroyed
        self._pending_agents_tree = self._empty_slots()

    def _empty_slots(self):
        """ Return a table with no tree for any agent """
        return {i: None for i in range(self.grid_size)}

    def _initial_tree_num(self):
        """ Return a table giving every agent its initial number of trees """
        return OrderedDict({
            i: copy.deepcopy(self.tree_num) for i in range(self.grid_size)
        })

    def plant_tree(self, agent_id, tree_obj, multiple_plant=True):
        """
        Try to plant `tree_obj` for `agent_id`.

        Nothing happens when the agent has no tree left. Otherwise:
        - if the agent has no tree, or multiple_plant is False, the tree is
          planted immediately (replacing any current tree) and one tree is
          consumed;
        - if multiple_plant is True and the current tree is mature, the new
          tree is queued and planted by destory_tree();
        - if multiple_plant is True and the current tree is still growing,
          the request is ignored.
        """
        # no trees available
        if self.agents_tree_num[agent_id] == 0:
            return

        current_tree = self.agents_tree.get(agent_id)
        if current_tree is None or not multiple_plant:
            self.agents_tree[agent_id] = tree_obj
            self.agents_tree_num[agent_id] -= 1
        elif current_tree.is_mature():
            # the mature tree is destroyed at this timestep, the new one
            # takes its place afterwards
            self._pending_agents_tree[agent_id] = tree_obj
        # else: the tree will not be planted in the forest

    def count_matured_trees(self):
        """ Number of agents whose tree is mature """
        return sum(1 for tree in self.agents_tree.values() if tree and tree.is_mature())

    def count_growing_trees(self):
        """ Number of agents that currently have a tree (mature or not) """
        return sum(1 for tree in self.agents_tree.values() if tree is not None)

    def is_tree_mature(self, agent_id):
        """ Return True if `agent_id` has a tree and that tree is mature """
        tree = self.agents_tree[agent_id]
        # an agent without a tree has no mature tree (instead of failing on None)
        return tree is not None and bool(tree.is_mature())

    def grow(self):
        """Let the trees grow"""
        for tree in self.agents_tree.values():
            if tree:
                tree.grow()

    def destory_tree(self):
        """ Remove every mature tree, then plant the queued trees """
        for agent_id, tree in self.agents_tree.items():
            if tree and tree.is_mature():
                self.agents_tree[agent_id] = None
        self._postprocessing()  # post processing the tree

    def _postprocessing(self):
        """ Plant the queued tree of every agent whose slot is free """
        # NOTE(review): unlike plant_tree(), planting a queued tree does not
        # decrement agents_tree_num -- confirm whether this is intended
        for agent_id, tree in self._pending_agents_tree.items():
            if tree and self.agents_tree[agent_id] is None:
                assert self.agents_tree_num[agent_id] > 0, \
                    f"agents_tree_num[agent_id] <= 0. Current value: {self.agents_tree_num[agent_id]}"

                self.agents_tree[agent_id] = tree
                self._pending_agents_tree[agent_id] = None

    def move(self):
        """ The forest cannot move: this is a deliberate no-op """
        return None

    def reset(self):
        """ Remove all trees (planted and queued) and restore the tree counts """
        self.agents_tree = self._empty_slots()
        # queued trees must be dropped too, otherwise a tree queued before
        # the reset would be planted by the next destory_tree() call
        self._pending_agents_tree = self._empty_slots()
        self.agents_tree_num = self._initial_tree_num()


class AfforestationEnv(MultiAgentEnv):
    """
    Multi-Agent Afforestation Environment
    Protect the Farm from the sand storm

    The farm is a (grid_depth x n_agents) grid. From the first row to the
    last one it contains the sand storm field, the tree field and the agent
    field; agent i is created in column i.
    """
    ACTION_DIE = 0
    ACTION_NO_OP = 1
    ACTION_PLANT = 2
    ACTION_MOVE_UP = 3
    ACTION_MOVE_DOWN = 4

    AGENT_FLAG, AGENT_TAG, = 3, 'A'
    SAND_STORM_FLAG, SAND_STORM_TAG = 7, 'S'
    TREE_FLAG, TREE_TAG = 11, 'T'
    GRID_FLAG, GRID_TAG  = 0, '_'

    def __init__(self,
                 seed=None,
                 map_name='forest',
                 episode_limit=10,
                 reward_no_op=-0.1,
                 reward_single_tree=10,  # reward of only one tree is ok to protect the farm
                 reward_protect=100,  # reward of all tree are ok to protect the farm
                 reward_die=-2,  # agents die when they are attacked by sand storm after they plant a tree
                 reward_disaster=-100,  # reward of no tree is ok to protect the farm
                 n_agents=2,   # number of agents, aka the width of the field
                 tree_depth_size=1,   # the depth of the tree field
                 agent_depth_size=1,   # the depth of the countryside
                 storm_depth_size=1,   # the depth of the sand field
                 async_agents='all',  # all agents are async
                 max_delay=9,  # the maximum delay of the ACTION_PLANT action
                 return_to_the_origin=False,  # if True, agents should return to the initial position
                 return_to_origin_mode='strict',  # for v2.1+
                 return_to_origin_dist=-1,  # for safety and for v2.2+
                 time_limit=False,  # if True, there is a time limit for planting the tree for v2.3
                 state_last_action=True,
                 obs_hist_action=False,   # do not add historical actions into the observation
                 multiple_plant=True,
                 tree_num=2,  # the number of trees that can be planted for each agent
                 show_grow_info=False,  # show the grow info in the obs of the agent as well as in the state
                 disable_plant=False,  # disable the plant action when there is no tree
                 policy_override_rule='no',
                 disable_plant_when_tree_planted=False, # disable the plant action when there is already a tree planted
                 pos_next_to_tree_field=False,  # if True, the agent will be placed next to the tree field for v2.5
                 policy_override_prob=0):
        """
        N agents planting trees to protect the farm, one column per agent
        (S: sand storm field, T: tree field, A: agent field)
         ___________
        |_S_|_S_|_S_|
        |_T_|_T_|_T_|
        |_A_|_A_|_A_|

        Raises ValueError if no delay scheme is defined for n_agents
        (only 2, 3 and 4 agents are supported).
        """
        # NOTE the seed is kept in `_seed`: an instance attribute named `seed`
        # would shadow the seed() method, which reads `_seed`
        self._seed = seed
        self.map_name = map_name
        self.episode_limit = episode_limit
        self.max_delay = self.max_delay_step = max_delay
        self.n_agents = n_agents
        self.grid_size = n_agents
        self.tree_num = tree_num
        self.show_grow_info = show_grow_info

        # growth cycle of the tree planted by each agent
        if self.n_agents == 2:
            self.delay_types = [((i+1) * (self.max_delay//self.n_agents))
                if (i+1) < self.n_agents else self.max_delay for i in range(self.n_agents)]
        elif self.n_agents in {3, 4}:
            half_delay = self.max_delay // 2
            self.delay_types = [half_delay, half_delay] + [self.max_delay] * (self.n_agents - 2)
        else:
            # fail early with a clear message instead of an AttributeError on
            # the missing `delay_types` below
            raise ValueError(
                f"no delay scheme is defined for n_agents: {self.n_agents} (supported: 2, 3, 4)")

        self.agent_delays = [
            self.delay_types[i] for i in range(self.n_agents)
        ]

        self.reward_no_op = reward_no_op
        self.reward_single_tree = reward_single_tree
        self.reward_protect = reward_protect
        self.reward_die = reward_die
        self.reward_disaster = reward_disaster

        self.return_to_the_origin = return_to_the_origin
        self.return_to_origin_mode = return_to_origin_mode
        self.return_to_origin_dist = return_to_origin_dist
        self.time_limit = time_limit

        self.async_agents = async_agents

        self.state_last_action = state_last_action
        self.obs_hist_action = obs_hist_action
        self.policy_override_rule = policy_override_rule
        self.policy_override_prob = policy_override_prob

        self.agent_depth_size = agent_depth_size
        self.tree_depth_size = tree_depth_size
        self.storm_depth_size = storm_depth_size

        # if True, agents can plant multiple trees at the same time,
        # but if there is a tree available, the new tree will not planted
        self.multiple_plant = multiple_plant

        # disable the plant action when there is no tree
        self.disable_plant = disable_plant
        # disable the plant action when there is already a tree planted
        self.disable_plant_when_tree_planted = disable_plant_when_tree_planted

        # if True, the agent will be placed next to the tree field for v2.5
        self.pos_next_to_tree_field = pos_next_to_tree_field

        self.curr_t = 0

        # for debugging purpose; created here so that step() also works
        # before the first reset()
        self.debug_episode_actions = []

        # 1. Init the farm
        self._init_farm()
        # 2. Init the forest
        self._forest = Forest(self.n_agents, self.grid_size, self.tree_num)
        # 3. Create the agents
        self._create_agents()
        # 4. Init obs and state
        self._init_state_obs()

        self.unit_dim = self.agent_depth_size + self.tree_depth_size + self.storm_depth_size

    def _get_max_delay(self):
        """ Returns the maximum delay given to the constructor """
        return self.max_delay_step

    def _create_agents(self):
        """ Create one agent per column and reset the last actions """
        assert self.n_agents == self.grid_size, \
            "The number of agents should be equal to the width of the field"

        self._init_agent_action_sapce()

        self._agents = []
        for i in range(self.n_agents):
            # NOTE move_up (-1) and move_down (+1)
            if self.pos_next_to_tree_field:
                pos = [self.storm_depth_size + self.tree_depth_size, i]
            else:
                pos = [self.grid_depth-1, i]  # the last row, i-th column
            # the destination is always the last row of the agent's column
            agent = Agent(init_pos=pos,
                          speed=1,
                          agent_id=i,
                          dest_pos=[self.grid_depth-1, i])
            self._agents.append(agent)
        self._agent_ids = list(range(self.n_agents))
        # -1 everywhere means "no action taken yet"
        self.last_action = np.zeros((self.n_agents, self.n_actions)) - 1

    def _create_async_agents_sets(self, async_agents):
        """
        Store the ids of the async agents in `async_agents_sets`.
        async_agents: 'all', or a list / set of agent ids
        Raises ValueError for any other value.
        """
        if async_agents == 'all':
            self.async_agents_sets = set(range(self.n_agents))
        elif isinstance(async_agents, (list, set)):
            self.async_agents_sets = set(async_agents)
        else:
            raise ValueError(f"there is a problem in async_agents: {async_agents}")

    @property
    def n_actions(self):
        """ Number of actions in the action space of an agent """
        return len(self._agent_action_space)

    def _init_agent_action_sapce(self):
        """ Build the action space and its lookup tables """
        # the move actions only exist when the agent field is deeper than
        # one row, i.e. when there is room to move
        if self.agent_depth_size > 1:
            self._agent_action_space = [self.ACTION_DIE, self.ACTION_NO_OP, self.ACTION_PLANT, self.ACTION_MOVE_UP, self.ACTION_MOVE_DOWN]
        else:
            self._agent_action_space = [self.ACTION_DIE, self.ACTION_NO_OP, self.ACTION_PLANT]

        self._agent_action_space_index = {act: idx for idx, act in enumerate(self._agent_action_space)}
        self._agent_action_space_set = set(self._agent_action_space)
        self._agent_move_action_set = {self.ACTION_MOVE_UP, self.ACTION_MOVE_DOWN}

    def _init_farm(self):
        """ Create the grid with the agents at their initial positions """
        self._clear_farm()

        # 1. put the agents in the field
        # 2. NOTE tree and sand storm will not be in the field.
        if self.pos_next_to_tree_field:
            agent_row = self.storm_depth_size + self.tree_depth_size
        else:
            agent_row = self.grid_depth - 1
        for i in range(self.n_agents):
            self._grid[agent_row, i] = self.AGENT_FLAG

    def _clear_farm(self):
        """ (Re)compute grid_depth and create an empty grid """
        self.grid_depth = self.agent_depth_size + self.tree_depth_size + self.storm_depth_size
        self._grid = np.zeros((self.grid_depth, self.grid_size), dtype=np.int32)
        self._grid.fill(self.GRID_FLAG)

    def _update_farm(self):
        """ Redraw the grid from the agents, the trees and the timestep """
        # 1. start from an empty farm
        self._clear_farm()
        # 2. update the position of the agents
        for agent in self._agents:
            pos = agent.curr_pos
            self._grid[pos[0], pos[1]] = self.AGENT_FLAG
        # 3. update the position of the trees
        for tree in self._forest.agents_tree.values():
            if tree is not None:
                pos = tree.curr_pos
                if self.show_grow_info:
                    self._grid[pos[0], pos[1]] = tree.grow_info()
                else:
                    self._grid[pos[0], pos[1]] = self.TREE_FLAG
        # 4. check if there is sandstorm for the next step
        # NOTE the update of state actually prepares the state for the next step
        if self.curr_t + 1 == self.episode_limit - 1:
            for row in range(self.storm_depth_size):
                for col in range(self.n_agents):
                    self._grid[row, col] = self.SAND_STORM_FLAG

    def _init_state_obs(self):
        """ Compute the state and the observations """
        # 1. the farm is updated by _update_state()
        self._update_state()
        # 2. Extract the obs for each agent from the farm
        self._update_obs()

    def _update_obs(self):
        """ The obs of agent i is column i of the grid plus its number of trees left """
        state_grid = copy.deepcopy(self._grid)
        self._obs = [
            np.append(state_grid[:, i], np.array([self._forest.agents_tree_num[i]])) \
                for i in range(self.n_agents)
        ]

    def _update_state(self):
        """Update the farm in step()"""
        # update farm
        self._update_farm()

        # the state is made of the grid, the last actions (or -1 placeholders)
        # and the number of trees left for each agent
        self._state = []
        self._state.append(self._grid)
        if self.state_last_action:
            self._state.append(self.last_action)
        else:
            self._state.append(np.zeros((self.n_agents, self.n_actions)) - 1)

        # append the left tree num
        self._state.append(np.array([[v] for v in self._forest.agents_tree_num.values()]).astype(np.int32))

    def reset(self):
        """ Returns initial observations and states"""
        # reset game objects
        for agent in self._agents:
            agent.reset()
        self._forest.reset()

        # the timestep and the last actions are part of the state (storm row,
        # last-action block), so they must be reset BEFORE the state is built
        self.curr_t = 0
        self.last_action = np.zeros((self.n_agents, self.n_actions)) - 1
        self._agent_ids = list(range(self.n_agents))

        # for debugging purpose
        self.debug_episode_actions = []

        # after resetting game objects, init the farm, the state and the obs
        self._init_farm()
        self._init_state_obs()
        return self.get_obs(), self.get_state()

    def step(self, actions):
        """ Returns reward, terminated, info """
        actions = [int(act) for act in actions]

        # validate before the actions are used to index anything
        for i, act in enumerate(actions):
            assert act in self._agent_action_space_set, \
                    f'action out of range, agent_id: {i}, action: {int(act)}, ' \
                    f'action_space: {self._agent_action_space_set}'

        # one-hot encoding of the actions
        self.last_action = np.eye(self.n_actions)[np.array(actions)]
        self.debug_episode_actions.append(tuple(actions))

        # execute the action of each agent
        for i, act in enumerate(actions):
            if act == self.ACTION_PLANT:
                # create a tree at the plant position of the agent
                pos = self._get_plant_pos(agent_id=i)
                tree = Tree(init_pos=pos,
                            growth_cycle=self.agent_delays[i],
                            plant_timestep=self.curr_t,
                            agent_id=i)
                self._forest.plant_tree(agent_id=i, tree_obj=tree, multiple_plant=self.multiple_plant)
            elif act in self._agent_move_action_set:
                # move the agent
                self._move_agent(agent_id=i, action=act)
            # else: no-op (or die), nothing to execute

        # update state first and then the obs can be easily collected
        self._update_state()
        self._update_obs()

        # the reward is computed from the mature trees BEFORE they are destroyed
        reward, info, terminated = self._cal_reward()

        # update the forest
        self._forest.destory_tree()

        # let the trees grow
        self._forest.grow()

        self.curr_t += 1
        return reward, terminated, info

    def _cal_reward(self):
        """
        Returns reward, info, terminated for the current timestep.
        The episode terminates at the last timestep, or earlier when the time
        limit is enabled and violated.
        """
        terminated = False
        info = {'won': 0, 'one_tree': 0}

        reward = self.reward_no_op
        matured_trees_num = self._forest.count_matured_trees()
        if matured_trees_num == 1:
            info['one_tree'] = 1
        else:
            info[f'{matured_trees_num}_trees'] = 1

        if self.curr_t + 1 != self.episode_limit:
            assert matured_trees_num != self.n_agents, \
                "All agents have matured trees, when self.curr_t + 1 < self.episode_limit"

        # the last timestep
        if self.curr_t + 1 == self.episode_limit:
            terminated = True
            # check the num
            if matured_trees_num == self.n_agents:
                info['won'] = 1
                reward = self.reward_protect
            elif matured_trees_num >= 1:
                assert matured_trees_num < self.n_agents, \
                    "the num of mature trees should be less than the num of agents"
                reward = matured_trees_num * self.reward_single_tree
            else:
                reward = self.reward_disaster

            # NOTE agents should return to the origin
            if self.return_to_the_origin:
                # count the number of agents that not at the origin
                agents_not_at_origin = self.n_agents - self._get_the_num_of_agents_at_origin()
                reward += agents_not_at_origin * self.reward_die
                info['won'] = 0  # add this to disable the won
        else:
            if self.time_limit:
                if self._time_limit_checker():
                    reward, terminated = self.reward_disaster, True
        return reward, info, terminated

    def _time_limit_checker(self):
        """ Returns True if the time limit is reached and the last agent has no tree """
        # NOTE self.curr_t means the current timestep
        if self.curr_t == self.agent_depth_size - 1:
            # NOTE only the tree of the last agent is checked
            # use self.n_agents-1 or self.n_agents-2
            growing_trees_num = int(self._forest.agents_tree[self.n_agents-1] is not None)
            return growing_trees_num == 0
        return False

    def _get_the_num_of_agents_at_origin(self):
        """ Number of agents considered to be at the origin """
        return sum([self._is_agent_at_origin(agent_id) for agent_id in range(self.n_agents)])

    def _is_agent_at_origin(self, agent_id):
        """
        Returns a truthy value if the agent is considered to be at its
        destination, according to return_to_origin_mode:
        - 'strict': the current position equals the destination
        - 'proximal': the agent is close enough to the destination row
        Raises ValueError for any other mode.
        """
        # the version digit is read from the last '_'-separated token of the
        # map name -- presumably names look like '<name>_v2.1'; verify
        version = int(self.map_name.split('_')[-1].split('.')[0][1])
        assert version >= 2, "the version of the map should be >= 2"

        agent = self._agents[agent_id]
        if self.return_to_origin_mode == 'strict':
            return all(agent.curr_pos == agent.dest_pos)
        if self.return_to_origin_mode == 'proximal':
            # only there is a cell between the tree and the agent
            dist = int(agent.dest_pos[0] - agent.curr_pos[0]) + 1
            if self.return_to_origin_dist == -1:
                # for v2.1  one cell is enough
                return 1 if dist < self.agent_depth_size else 0
            assert self.return_to_origin_dist > 0, "the dist should be positive"
            return 1 if dist <= self.return_to_origin_dist else 0
        # an unknown mode used to fall through and return None, which only
        # failed later (and obscurely) when the results were summed
        raise ValueError(
            f"return_to_origin_mode: {self.return_to_origin_mode} is not a valid mode")

    def _move_agent(self, agent_id, action):
        """ Apply a move action to the agent; raises ValueError otherwise """
        agent = self._agents[agent_id]
        if action == self.ACTION_MOVE_DOWN:
            agent.move_down()
        elif action == self.ACTION_MOVE_UP:
            agent.move_up()
        else:
            raise ValueError(f"action: {action} is not a valid action")

    def _get_plant_pos(self, agent_id):
        """ Position of the tree of the agent: first row below the storm field """
        # agent_id is the column
        assert self.tree_depth_size == 1, f"tree_depth_size should be 1, current value: {self.tree_depth_size}"
        return [self.storm_depth_size, agent_id]

    def debug_get_true_reward(self, info, scheme='proximal'):
        """
        this is for debugging purpose
        use different schemes to hack the environment
        in order to find the key problem

        Currently returns reward_no_op for every recorded step.
        """
        episode_rewards = [self.reward_no_op] * len(self.debug_episode_actions)

        if scheme.startswith('proximal_'):
            # TODO
            pass
        return episode_rewards

    def check_async_actions(self, async_action_list, env_t, reward):
        """ Not implemented: does nothing """
        pass

    def _bound_position(self, gameobject: GameObject):
        """ Returns the [row, column] position of the object clipped to the grid """
        # rows are bounded by the depth of the grid, columns by its width
        upper_bound = [self.grid_depth - 1, self.grid_size - 1]
        return np.clip(gameobject.curr_pos, a_min=0, a_max=upper_bound)

    def get_obs(self, w_hist_actions=False):
        """ Returns all agent observations in a list """
        # a deep copy is returned when w_hist_actions is True, otherwise the
        # internal list itself
        if w_hist_actions:
            return copy.deepcopy(self._obs)
        return self._obs

    def get_obs_agent(self, agent_id):
        """ Returns observation for agent_id """
        return self._obs[agent_id]

    def get_obs_size(self):
        """ Returns the shape of the observation """
        return len(self._obs[0])

    def get_state(self):
        """ Returns the state as one flat array """
        return np.concatenate([arr.flatten() for arr in self._state])

    def get_state_size(self):
        """ Returns the shape of the state"""
        return int(sum([np.prod(np.array(arr).shape) for arr in self._state]))

    def get_avail_actions(self):
        """ Returns the available actions of all agents in a list """
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id """
        # check the agent's field depth
        agent = self._agents[agent_id]

        # no-op is always available
        available_actions = [0] * len(self._agent_action_space)
        available_actions[self._agent_action_space_index[self.ACTION_NO_OP]] = 1

        if self.agent_depth_size == 1:
            # the agent cannot move, only the plant action has to be decided
            if self._forest.agents_tree.get(agent_id) is not None:  # check if there is a tree planted
                if self.multiple_plant:
                    if not self.disable_plant:
                        available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = 1
                    else:
                        available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = \
                            int(self._forest.agents_tree_num[agent_id] > 0)
                # else do nothing, do not allow to plant trees
                if self.disable_plant_when_tree_planted:
                    available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = 0
            else:
                if not self.disable_plant:
                    available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = 1
                else:
                    available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = \
                        int(self._forest.agents_tree_num[agent_id] > 0)
        else:
            # 1. check the agent is next to the tree field
            if agent.curr_pos[0] - 1 == self.storm_depth_size:
                if self._forest.agents_tree.get(agent_id) is not None:
                    if self.multiple_plant:
                        available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = 1
                        # else do nothing, do not allow to plant trees
                    if self.disable_plant_when_tree_planted:
                        available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = 0
                else:
                    available_actions[self._agent_action_space_index[self.ACTION_PLANT]] = \
                        int(self._forest.agents_tree_num[agent_id] > 0)

                available_actions[self._agent_action_space_index[self.ACTION_MOVE_DOWN]] = 1
            # 2. check the agent is in the last row
            elif agent.curr_pos[0] + 1 == self.grid_depth:
                for act in [self.ACTION_MOVE_UP]:
                    available_actions[self._agent_action_space_index[act]] = 1
            # 3. in between: the agent can move up/down (or no-op)
            else:
                for act in [self.ACTION_MOVE_UP, self.ACTION_MOVE_DOWN]:
                    available_actions[self._agent_action_space_index[act]] = 1
        return available_actions

    def get_total_actions(self):
        """ Returns the total number of actions an agent could ever take """
        return len(self._agent_action_space)

    def get_agent_pending_actions(self, agent_id: int) -> List[int]:
        """
        Get pending actions: [a_{t-m}, a_{t-m+1}, ..., a_{t-1}]
        Currently always returns an empty list.
        """
        # TODO maybe this function can be deleted
        pending_actions = []
        return pending_actions

    def get_agents_pending_actions(self) -> List[List[int]]:
        """ Currently always returns an empty list """
        pending_actions = []
        # TODO maybe this function can be deleted
        return pending_actions

    def get_visibility_matrix(self):
        """Returns a boolean numpy array of dimensions
        (n_agents, n_agents) indicating which agents are visible to each
        agent: every agent sees all the other agents, but not itself.
        """
        # NOTE the builtin `bool` is used: the `np.bool` alias no longer
        # exists in recent NumPy versions
        return ~np.eye(self.n_agents, dtype=bool)

    def policy_override(self, actions, override_policy=True):
        """ Returns the actions untouched: no override rule is implemented yet """
        if self.policy_override_rule == 'no' or not override_policy:
            return actions
        # TODO at t=0, override the actions according to the rule
        return actions

    def get_action_durations(self, actions: List, curr_timestep: int):
        """
        Return action durations

        NOTE(review): `self.stag` is not defined anywhere in this class, so
        this method raises AttributeError when called. It looks like a
        leftover from another environment; the intended durations
        (presumably the per-agent delays) must be confirmed before fixing.
        """
        results = []
        for agent in self._agents:
            results.append(sum(np.abs(agent.init_pos-self.stag.init_pos)))

        assert len(results) == len(actions), "len(actions) != n_agents"
        return results

    def render(self):
        """ Print the planted trees and the grid """
        # TODO use pretty print
        for agent_id, tree in self._forest.agents_tree.items():
            if tree:
                print(f'Agent: {agent_id}, timeleft: {tree.time_left}, duration: {tree.duration}')
        print(self._grid)

    def close(self):
        """ Nothing to release """
        pass

    def seed(self):
        """Returns the random seed used by the environment."""
        # None when no seed has been recorded on the instance
        return getattr(self, '_seed', None)

    def save_replay(self):
        """ Not supported """
        raise NotImplementedError

    def get_env_info(self):
        """ Returns the sizes and limits describing the environment """
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "max_delay": self._get_max_delay(),
                    "unit_dim": self.unit_dim}
        return env_info

    def get_stats(self):
        """ Returns the statistics of the environment """
        # NOTE(review): `catch_the_stag` is never assigned in this class
        # (looks like a leftover from another environment); None is reported
        # instead of raising AttributeError -- confirm which stat is wanted
        stats = {
            "catch": getattr(self, "catch_the_stag", None),
        }
        return stats
