import logging
from collections import defaultdict
from enum import Enum
from itertools import product
from gym.utils import seeding
import numpy as np
import torch as th
from .utils import matched
from itertools import combinations
from envs.multiagentenv import MultiAgentEnv


class Action(Enum):
    """Discrete per-agent moves; the member value is the integer action id."""

    NONE = 0   # stay in place
    NORTH = 1  # row - 1
    SOUTH = 2  # row + 1
    WEST = 3   # col - 1
    EAST = 4   # col + 1

class CellEntity(Enum):
    """Integer codes for what occupies a cell in a grid observation."""

    EMPTY = 0
    LANDMARK = 1
    AGENT = 2

class Player:
    """A grid agent: its position, the grid shape, and an optional controller."""

    def __init__(self):
        # All fields stay unset until setup()/set_controller() are called.
        self.controller = None
        self.position = None
        self.field_size = None
        self.current_step = None

    def setup(self, position, field_size):
        """Place the player at ``position`` on a grid of shape ``field_size``."""
        self.position, self.field_size = position, field_size

    def set_controller(self, controller):
        """Attach the object whose ``_step`` method picks this player's actions."""
        self.controller = controller

    def step(self, obs):
        """Delegate action selection for ``obs`` to the controller's ``_step``."""
        controller = self.controller
        return controller._step(obs)

    @property
    def name(self):
        """The controller's ``name`` if a (truthy) controller is set, else ``"Player"``."""
        return self.controller.name if self.controller else "Player"

# task_id2config = {
#     "2": {
#         "n_agents": 2,
#         "field_size": [6,6],
#         "sight": 5,
#         "episode_limit": 50,
#         "reach_range": 2,
#     },
#     "3": {
#         "n_agents": 3,
#         "field_size": [8,8],
#         "sight": 7,
#         "episode_limit": 50,
#         "reach_range": 2,
#     },
#     "4": {
#         "n_agents": 4,
#         "field_size": [10,10],
#         "sight": 9,
#         "episode_limit": 50,
#         "reach_range": 2,
#     },
#     "5": {
#         "n_agents": 5,
#         "field_size": [10,10],
#         "sight": 9,
#         "episode_limit": 50,
#         "reach_range": 2,
#     },
#     "6": {
#         "n_agents": 6,
#         "field_size": [15,15],
#         "sight": 14,
#         "episode_limit": 50,
#         "reach_range": 2,
#     },
#     "7": {
#         "n_agents": 7,
#         "field_size": [15,15],
#         "sight": 14,
#         "episode_limit": 70,
#         "reach_range": 2,
#     },
#     "8": {
#         "n_agents": 8,
#         "field_size": [15,15],
#         "sight": 14,
#         "episode_limit": 70,
#         "reach_range": 2,
#     },
#     "9": {
#         "n_agents": 9,
#         "field_size": [15,15],
#         "sight": 14,
#         "episode_limit": 70,
#         "reach_range": 2,
#     },
#     "10": {
#         "n_agents": 10,
#         "field_size": [15,15],
#         "sight": 14,
#         "episode_limit": 70,
#         "reach_range": 2,
#     },
#     "11": {
#         "n_agents": 11,
#         "field_size": [18, 18],
#         "sight": 17,
#         "episode_limit": 70,
#         "reach_range": 2,
#     },
#     "12": {
#         "n_agents": 12,
#         "field_size": [18, 18],
#         "sight": 17,
#         "episode_limit": 70,
#         "reach_range": 2,
#     },
#     "15": {
#         "n_agents": 15,
#         "field_size": [20,20],
#         "sight": 19,
#         "episode_limit": 80,
#         "reach_range": 2,
#     }
# }

# Predefined configs for task ids 56-62: n agents and n landmarks (all of them
# targets) on a side x side grid, with sight = side - 1.
task_id2config = {
    tid: {
        "n_agents": n,
        "n_landmarks": n,
        "field_size": [side, side],
        "sight": side - 1,
        "episode_limit": limit,
        "reach_range": 2,
    }
    for tid, n, side, limit in (
        (56, 2, 12, 70),
        (57, 3, 12, 70),
        (58, 4, 12, 70),
        (59, 5, 12, 70),
        (60, 6, 15, 100),
        (61, 7, 15, 120),
        (62, 8, 15, 120),
    )
}


# Configs for the "k of 6 landmarks are targets" tasks on a 15x15 grid.
# MTGridMPEEnv maps task ids 0-19 -> _new, 20-34 -> _new_1, 35-49 -> _new_2
# and 50-55 -> _new_3.
task_id2config_new = dict(
    n_agents=3,
    n_landmarks=6,
    field_size=[15, 15],
    sight=14,
    episode_limit=70,
    reach_range=2,
)

task_id2config_new_1 = dict(
    n_agents=2,
    n_landmarks=6,
    field_size=[15, 15],
    sight=14,
    episode_limit=70,
    reach_range=2,
)

task_id2config_new_2 = dict(
    n_agents=4,
    n_landmarks=6,
    field_size=[15, 15],
    sight=14,
    episode_limit=100,
    reach_range=2,
)

task_id2config_new_3 = dict(
    n_agents=5,
    n_landmarks=6,
    field_size=[15, 15],
    sight=14,
    episode_limit=100,
    reach_range=2,
)


class MTGridMPEEnv(MultiAgentEnv):
    """
    Class for Multi-task GridMPE.

    Agents move on a ``rows x cols`` integer grid (``self.field``) on which
    landmarks are marked with ``self.LANDMARK``. A task is a 0/1 list over the
    landmarks (an entry of ``self.task_ls``) that marks which landmarks are
    targets. An episode succeeds when both of these hold:

    * every agent has some landmark within ``reach_range`` (a square
      neighbourhood);
    * ``matched`` accepts the target-vs-agent reachability matrix. This is
      presumably a one-to-one assignment check; confirm in ``.utils``.

    Task ids (indices into ``self.task_ls``) and the predefined configuration
    applied when ``default_task`` is true:

    * 0-19:  3 of 6 landmarks are targets -> ``task_id2config_new``
    * 20-34: 2 of 6 landmarks are targets -> ``task_id2config_new_1``
    * 35-49: 4 of 6 landmarks are targets -> ``task_id2config_new_2``
    * 50-55: 5 of 6 landmarks are targets -> ``task_id2config_new_3``
    * 56-62: all k landmarks are targets, k = 2..8 -> ``task_id2config[task_id]``
    """

    action_set = [Action.NORTH, Action.SOUTH, Action.WEST, Action.EAST]

    def __init__(
            self,
            n_agents,
            n_landmarks,  # should be >= n_agents
            field_size,
            sight,
            episode_limit,
            reach_range,
            seed=None,
            default_task=True,
            task_id=None,
            fixed_reset=False,
            **kwargs,
    ):
        """
        Args:
            n_agents: number of agents (overridden when ``default_task``).
            n_landmarks: number of landmarks; should be at least ``n_agents``.
            field_size: ``(rows, cols)`` of the grid.
            sight: half-width of the square around an agent it can observe.
            episode_limit: maximum number of steps per episode.
            reach_range: half-width of the square in which an agent reaches
                a landmark.
            seed: seed for ``self.np_random``.
            default_task: if true, replace the size/agent arguments above with
                the predefined configuration of ``task_id``.
            task_id: index into ``self.task_ls``; drawn uniformly at random
                when ``None``.
            fixed_reset: if true, ``reset`` uses the hard-coded positions set
                below instead of sampling agents and landmarks.
            **kwargs: ignored.

        Raises:
            ValueError: if ``default_task`` is true and ``task_id`` has no
                predefined configuration.
        """
        self.logger = logging.getLogger(__name__)
        self.seed(seed)

        self.n_agents = n_agents
        self.n_landmarks = n_landmarks
        self.n_targets = n_agents
        self.players = [Player() for _ in range(n_agents)]
        self.sight = sight

        self.field = np.zeros(field_size, np.int32)
        self.landmark_locs = None
        self.LANDMARK = 1

        self._game_over = None
        self._valid_actions = None
        self.episode_limit = episode_limit
        self.reach_range = reach_range

        self.fixed_reset = fixed_reset
        # Hard-coded start positions used when ``fixed_reset`` is true; only
        # 3 agent and 6 landmark positions are defined.
        self.player_rows = [7, 6, 8]
        self.player_cols = [8, 7, 7]
        self.given_rows = [7, 3, 3, 7, 11, 11]
        self.given_cols = [14, 10, 4, 0, 4, 10]

        self._score = 0

        # Ids 0-55: choose k of 6 landmarks as targets (k = 3, 2, 4, 5 in that
        # order); ids 56-62: all of k landmarks are targets, k = 2..8.
        self.task_ls = (
            self.gen_tasks(6, 3)
            + self.gen_tasks(6, 2)
            + self.gen_tasks(6, 4)
            + self.gen_tasks(6, 5)
            + [[1] * task_agent for task_agent in range(2, 9)]
        )
        assert len(self.task_ls)==63, "The number of tasks should be 63, but got {}".format(len(self.task_ls))

        # Resolve a missing task id first so that the default configuration
        # and ``target_landmarks`` are both derived from the same id.
        self.task_id = task_id
        if self.task_id is None:
            self.task_id = self.np_random.randint(0, len(self.task_ls))

        if default_task:
            task_config = self._default_task_config(self.task_id)
            self.n_agents = task_config["n_agents"]
            self.n_targets = self.n_agents
            self.n_landmarks = task_config["n_landmarks"]
            self.players = [Player() for _ in range(self.n_agents)]
            self.field = np.zeros(task_config["field_size"], np.int32)
            self.sight = task_config["sight"]
            self.episode_limit = task_config["episode_limit"]
            self.reach_range = task_config["reach_range"]

        self.target_landmarks = self.task_ls[self.task_id]

    @staticmethod
    def _default_task_config(task_id):
        """Return the predefined config dict for ``task_id`` (0-62); raise ValueError otherwise."""
        if 0 <= task_id <= 19:
            return task_id2config_new
        if 20 <= task_id <= 34:
            return task_id2config_new_1
        if 35 <= task_id <= 49:
            return task_id2config_new_2
        if 50 <= task_id <= 55:
            return task_id2config_new_3
        if 56 <= task_id <= 62:
            return task_id2config[task_id]
        raise ValueError("No default task config for task_id {!r}".format(task_id))

    def reset_task(self, task_id=None):
        """
        Switch to ``task_id`` (random when ``None``) and update ``target_landmarks``.

        NOTE: this does not re-apply the default task configuration, so the
        new task is assumed to share ``n_agents``/``n_landmarks`` with the
        current one.
        """
        self.task_id = task_id
        if self.task_id is None:
            self.task_id = self.np_random.randint(0, len(self.task_ls))
        self.target_landmarks = self.task_ls[self.task_id]

    def gen_tasks(self, n, k):
        """Return every length-``n`` 0/1 list with exactly ``k`` ones, in ``combinations`` order."""
        return [
            [1 if idx in comb else 0 for idx in range(n)]
            for comb in combinations(range(n), k)
        ]

    def seed(self, seed=None):
        """Create ``self.np_random`` via gym's ``seeding`` helper and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    @property
    def field_size(self):
        return self.field.shape

    @property
    def rows(self):
        return self.field_size[0]

    @property
    def cols(self):
        return self.field_size[1]

    @property
    def game_over(self):
        return self._game_over

    @property
    def n_actions(self):
        return 5

    def _gen_valid_moves(self):
        """Cache, per player, the actions that keep it inside the grid."""
        self._valid_actions = {
            player: [
                action for action in Action if self._is_valid_action(player, action)
            ]
            for player in self.players
        }

    def neighborhood(self, row, col, distance=1, ignore_diag=False):
        """
        Return the sub-array of ``self.field`` within ``distance`` of (row, col),
        clipped to the grid.

        With ``ignore_diag`` the result is instead the sum of the vertical and
        horizontal segments through (row, col); the centre cell is counted twice.
        """
        if not ignore_diag:
            return self.field[
                   max(row - distance, 0): min(row + distance + 1, self.rows),
                   max(col - distance, 0): min(col + distance + 1, self.cols),
                   ]

        return (
                self.field[
                max(row - distance, 0): min(row + distance + 1, self.rows), col
                ].sum()
                + self.field[
                  row, max(col - distance, 0): min(col + distance + 1, self.cols)
                  ].sum()
        )

    def adjacent_players(self, row, col):
        """Return the players in the four cells orthogonally adjacent to (row, col)."""
        return [
            player
            for player in self.players
            if abs(player.position[0] - row) == 1
               and player.position[1] == col
               or abs(player.position[1] - col) == 1
               and player.position[0] == row
        ]

    def spawn_landmarks(self, rows=None, cols=None):
        """
        Fill ``self.landmark_locs`` and mark the landmarks on ``self.field``.

        With ``rows``/``cols`` given, landmark i is placed at (rows[i], cols[i]).
        Otherwise, random empty cells outside every agent's reach are sampled
        (up to 10000 draws in total); unfilled slots stay ``None``.
        """
        if rows is not None:
            for idx in range(self.n_landmarks):
                row, col = rows[idx], cols[idx]
                self.field[row, col] = self.LANDMARK
                self.landmark_locs[idx] = (row, col)
            return

        landmark_count = 0
        attempts = 0
        while landmark_count < self.n_landmarks and attempts < 10000:
            attempts += 1
            row = self.np_random.randint(0, self.rows)
            col = self.np_random.randint(0, self.cols)

            if not self._is_empty_location(row, col):
                continue

            # keep landmarks out of reach at reset so the task is not solved at step 0
            if self._reached_by_agent(row, col):
                continue

            self.field[row, col] = self.LANDMARK
            self.landmark_locs[landmark_count] = (row, col)

            landmark_count += 1

    def _is_empty_location(self, row, col):
        """True when (row, col) holds no landmark and no player."""
        if self.field[row, col] != 0:
            return False

        for a in self.players:
            if a.position and row == a.position[0] and col == a.position[1]:
                return False

        return True

    def _reached_by_agent(self, row, col):
        """True when any placed player has (row, col) within ``reach_range``."""
        return any(
            a.position and self._within_reach(a.position, (row, col))
            for a in self.players
        )

    def spawn_players(self, player_rows=None, player_cols=None):
        """
        Place every player, at (player_rows[i], player_cols[i]) when given,
        otherwise at a random empty cell (up to 10000 draws per player).

        A player that finds no empty cell keeps its previous position. Players
        not yet re-placed in this call still occupy their old cells for the
        emptiness check.
        """
        for i, player in enumerate(self.players):
            if player_rows is not None:
                row = player_rows[i]
                col = player_cols[i]
                player.setup(
                    (row, col),
                    self.field_size,
                )
            else:
                attempts = 0

                while attempts < 10000:
                    row = self.np_random.randint(0, self.rows)
                    col = self.np_random.randint(0, self.cols)
                    if self._is_empty_location(row, col):
                        player.setup(
                            (row, col),
                            self.field_size,
                        )
                        break
                    attempts += 1

    def _is_valid_action(self, player, action):
        """True when ``action`` keeps ``player`` on the grid; NONE is always valid."""
        if action == Action.NONE:
            return True
        elif action == Action.NORTH:
            return (
                    player.position[0] > 0
            )
        elif action == Action.SOUTH:
            return (
                    player.position[0] < self.rows - 1
            )
        elif action == Action.WEST:
            return (
                    player.position[1] > 0
            )
        elif action == Action.EAST:
            return (
                    player.position[1] < self.cols - 1
            )

        # Should not get here!!!
        self.logger.error("Undefined action {} from {}".format(action, player.name))
        raise ValueError("Undefined action")

    def get_valid_actions(self):
        """Return every joint action (one valid action per player) as a list of tuples."""
        return list(product(*[self._valid_actions[player] for player in self.players]))

    def _within_sight(self, ego_position, other_position):
        """True when ``other_position`` lies in the square of half-width ``sight`` around ``ego_position``."""
        return (ego_position[0] - self.sight <= other_position[0] <= ego_position[0] + self.sight) and \
            (ego_position[1] - self.sight <= other_position[1] <= ego_position[1] + self.sight)

    def _within_reach(self, ego_position, land_position):
        """True when ``land_position`` lies in the square of half-width ``reach_range`` around ``ego_position``."""
        return (ego_position[0] - self.reach_range <= land_position[0] <= ego_position[0] + self.reach_range) and \
            (ego_position[1] - self.reach_range <= land_position[1] <= ego_position[1] + self.reach_range)

    def _loc_increasement(self, locs):
        """Shift every (row, col) by +1, so the unseen marker (-1, -1) becomes (0, 0)."""
        return [[loc[0] + 1, loc[1] + 1] for loc in locs]

    def _get_state_info(self):
        """Return player positions followed by landmark positions."""
        state_info = [p.position for p in self.players]
        state_info.extend(self.landmark_locs)
        return state_info

    def _reach_landmark(self, ego_position):
        """True when any landmark lies within ``reach_range`` of ``ego_position``."""
        return np.sum(self.neighborhood(ego_position[0], ego_position[1], distance=self.reach_range)) > 0

    def get_obs_agent(self, agent_id):
        """
        Flat observation of ``agent_id``: its own position, then the other
        agents' and all landmarks' positions, with entities outside ``sight``
        replaced by (-1, -1). All coordinates are shifted by +1.
        """
        state_info = self._get_state_info()
        ego = self.players[agent_id].position
        obs = [loc if self._within_sight(ego, loc) else [-1, -1] for loc in state_info[:agent_id] + state_info[agent_id+1:]]
        obs = [[ego[0], ego[1]]] + obs
        obs = self._loc_increasement(obs)
        return np.array(obs).flatten()

    def get_obs(self):
        return [self.get_obs_agent(agent_id) for agent_id in range(self.n_agents)]

    def get_obs_size(self):
        # one (row, col) pair per agent and per landmark
        return 2 * (self.n_agents + self.n_landmarks)

    def get_state(self):
        """Flat global state: all agent then landmark positions, shifted by +1."""
        state_info = self._get_state_info()
        state_info = self._loc_increasement(state_info)
        return np.array(state_info).flatten()

    def get_state_size(self):
        return 2 * (self.n_agents + self.n_landmarks)

    def get_total_actions(self):
        # 5 possible actions
        return 5

    def get_avail_actions(self):
        """Per-agent 0/1 availability masks over the 5 action ids."""
        return [self.get_avail_agent_actions(agent_id) for agent_id in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        valid_actions = self._valid_actions[self.players[agent_id]]
        avail_actions = [1 if Action(i) in valid_actions else 0 for i in range(self.n_actions)]
        return avail_actions

    def reset(self):
        """
        Start a new episode: clear the grid, then place agents and landmarks.

        Returns:
            ``(obs, state)`` as produced by ``get_obs`` and ``get_state``.
        """
        self.field = np.zeros(self.field_size, np.int32)
        self.landmark_locs = [None for _ in range(self.n_landmarks)]
        # players are spawned first so random landmarks can avoid their reach
        if self.fixed_reset:
            self.spawn_players(player_rows=self.player_rows, player_cols=self.player_cols)
            self.spawn_landmarks(rows=self.given_rows, cols=self.given_cols)
        else:
            self.spawn_players()
            self.spawn_landmarks()

        self.current_step = 0
        self._score = 0
        self._game_over = False

        self._gen_valid_moves()

        return self.get_obs(), self.get_state()

    def step(self, actions):
        """
        Apply one joint action.

        Args:
            actions: sequence (or torch tensor) of ``n_agents`` integer action
                ids. Ids that are invalid at an agent's position are logged
                and replaced by ``Action.NONE``.

        Returns:
            ``(reward, done, info)``.
            ``reward`` is 100 on success plus -0.1 times the summed Manhattan
            distance from each target landmark to its nearest agent.
            ``done`` is true on success or once ``episode_limit`` steps are
            reached. ``info`` is ``{"battle_won": success}``.
        """
        if isinstance(actions, th.Tensor):
            actions = actions.cpu().numpy()

        assert len(actions)==self.n_agents, len(actions)

        self.current_step += 1

        # Coerce invalid actions to NONE (always valid), logging each one.
        checked_actions = []
        for player, a in zip(self.players, actions):
            action = Action(a)
            if action not in self._valid_actions[player]:
                self.logger.info(
                    "{}{} attempted invalid action {}.".format(
                        player.name, player.position, action
                    )
                )
                action = Action.NONE
            checked_actions.append(action)
        actions = checked_actions

        # Move players: if two or more players target the same cell, none of
        # them moves.
        # NOTE(review): only same-destination conflicts are resolved. Two
        # agents may swap cells, and an agent blocked by a collision can end
        # up sharing its cell with an agent that moved into it.
        collisions = defaultdict(list)
        for player, action in zip(self.players, actions):
            if action == Action.NONE:
                collisions[player.position].append(player)
            elif action == Action.NORTH:
                collisions[(player.position[0] - 1, player.position[1])].append(player)
            elif action == Action.SOUTH:
                collisions[(player.position[0] + 1, player.position[1])].append(player)
            elif action == Action.WEST:
                collisions[(player.position[0], player.position[1] - 1)].append(player)
            elif action == Action.EAST:
                collisions[(player.position[0], player.position[1] + 1)].append(player)

        for k, v in collisions.items():
            if len(v) > 1:  # make sure no more than one player arrives at a location
                continue
            v[0].position = k

        _succeed = self._succeed()
        reward_succ = 100 if _succeed else 0

        # Dense shaping: distance from each target landmark to its nearest agent.
        reward_near = 0
        for i in range(self.n_landmarks):
            if self.target_landmarks[i] != 0:
                dists = [abs(a.position[0] - self.landmark_locs[i][0]) + abs(a.position[1] - self.landmark_locs[i][1]) for a in self.players]
                reward_near -= min(dists)

        reward = reward_succ + reward_near * 0.1

        self._game_over = done = (
            _succeed or self.episode_limit <= self.current_step
        )

        # update valid moves
        self._gen_valid_moves()

        return reward, done, {"battle_won": _succeed}

    def _succeed(self):
        """
        True when every agent reaches some landmark and ``matched`` accepts
        the target-vs-agent matrix (1 = agent reaches target, 10 = it does not).
        """
        _score = 0
        for p in self.players:
            if self._reach_landmark(p.position):
                _score += 1
        if _score != self.n_targets:
            return False
        matrix = [
            [10 for _ in range(self.n_agents)] for _ in range(self.n_targets)
        ]
        target_locs = [self.landmark_locs[i] for i in range(self.n_landmarks) if self.target_landmarks[i] != 0]
        for i, landmark_loc in enumerate(target_locs):
            for j, p in enumerate(self.players):
                if self._within_reach(p.position, landmark_loc):
                    matrix[i][j] = 1
        _succeed = matched(matrix, self.n_agents)
        return _succeed

    def render(self):
        pass

    def close(self):
        pass

    def save_replay(self):
        pass

    def get_stats(self):
        return {}



class GridMPEEnv(MultiAgentEnv):
    """
    Class for GridMPE.

    Single-task variant: there are as many landmarks as agents and every
    landmark is a target. An episode succeeds when both of these hold:

    * every agent has a landmark within ``reach_range`` (a square
      neighbourhood);
    * ``matched`` accepts the landmark-vs-agent reachability matrix. This is
      presumably a one-to-one assignment check; confirm in ``.utils``.

    The reward is 1 on success and 0 otherwise.
    """

    action_set = [Action.NORTH, Action.SOUTH, Action.WEST, Action.EAST]

    def __init__(
            self,
            n_agents,
            field_size,
            sight,
            episode_limit,
            reach_range,
            seed=None,
            default_task=False,
            task_id=None,  # key of task_id2config (56-62), used when default_task
            **kwargs,
    ):
        """
        Args:
            n_agents: number of agents, and also of landmarks.
            field_size: ``(rows, cols)`` of the grid.
            sight: half-width of the square around an agent it can observe.
            episode_limit: maximum number of steps per episode.
            reach_range: half-width of the square in which an agent reaches
                a landmark.
            seed: seed for ``self.np_random``.
            default_task: if true, replace the arguments above with
                ``task_id2config[task_id]``.
            task_id: int (or numeric string) key of ``task_id2config``.
            **kwargs: ignored.

        Raises:
            KeyError: if ``default_task`` is true and ``task_id`` is not a key
                of ``task_id2config``.
        """
        self.logger = logging.getLogger(__name__)
        self.seed(seed)

        self.n_agents = self.n_landmarks = n_agents
        self.players = [Player() for _ in range(n_agents)]
        self.sight = sight

        self.field = np.zeros(field_size, np.int32)
        self.landmark_locs = None
        self.LANDMARK = 1

        self._game_over = None
        self._valid_actions = None
        self.episode_limit = episode_limit
        self.reach_range = reach_range

        self._score = 0

        if default_task:
            # task_id2config is keyed by ints, so a str(task_id) lookup can
            # never match; accept both 56 and "56".
            try:
                task_config = task_id2config[int(task_id)]
            except (TypeError, ValueError, KeyError):
                raise KeyError(
                    "No default task config for task_id {!r}; available ids: {}".format(
                        task_id, sorted(task_id2config)
                    )
                ) from None
            self.n_agents = self.n_landmarks = task_config["n_agents"]
            self.players = [Player() for _ in range(self.n_agents)]
            self.field = np.zeros(task_config["field_size"], np.int32)
            self.sight = task_config["sight"]
            self.episode_limit = task_config["episode_limit"]
            self.reach_range = task_config["reach_range"]

    def seed(self, seed=None):
        """Create ``self.np_random`` via gym's ``seeding`` helper and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    @property
    def field_size(self):
        return self.field.shape

    @property
    def rows(self):
        return self.field_size[0]

    @property
    def cols(self):
        return self.field_size[1]

    @property
    def game_over(self):
        return self._game_over

    @property
    def n_actions(self):
        return 5

    def _gen_valid_moves(self):
        """Cache, per player, the actions that keep it inside the grid."""
        self._valid_actions = {
            player: [
                action for action in Action if self._is_valid_action(player, action)
            ]
            for player in self.players
        }

    def neighborhood(self, row, col, distance=1, ignore_diag=False):
        """
        Return the sub-array of ``self.field`` within ``distance`` of (row, col),
        clipped to the grid.

        With ``ignore_diag`` the result is instead the sum of the vertical and
        horizontal segments through (row, col); the centre cell is counted twice.
        """
        if not ignore_diag:
            return self.field[
                   max(row - distance, 0): min(row + distance + 1, self.rows),
                   max(col - distance, 0): min(col + distance + 1, self.cols),
                   ]

        return (
                self.field[
                max(row - distance, 0): min(row + distance + 1, self.rows), col
                ].sum()
                + self.field[
                  row, max(col - distance, 0): min(col + distance + 1, self.cols)
                  ].sum()
        )

    def adjacent_players(self, row, col):
        """Return the players in the four cells orthogonally adjacent to (row, col)."""
        return [
            player
            for player in self.players
            if abs(player.position[0] - row) == 1
               and player.position[1] == col
               or abs(player.position[1] - col) == 1
               and player.position[0] == row
        ]

    def spawn_landmarks(self):
        """
        Place landmarks on random empty cells outside every agent's reach
        (up to 1000 draws in total); unfilled ``landmark_locs`` slots stay ``None``.
        """
        landmark_count = 0
        attempts = 0
        while landmark_count < self.n_landmarks and attempts < 1000:
            attempts += 1
            row = self.np_random.randint(0, self.rows)
            col = self.np_random.randint(0, self.cols)

            if not self._is_empty_location(row, col):
                continue

            # keep landmarks out of reach at reset so the task is not solved at step 0
            if self._reached_by_agent(row, col):
                continue

            self.field[row, col] = self.LANDMARK
            self.landmark_locs[landmark_count] = (row, col)

            landmark_count += 1

    def _is_empty_location(self, row, col):
        """True when (row, col) holds no landmark and no player."""
        if self.field[row, col] != 0:
            return False

        for a in self.players:
            if a.position and row == a.position[0] and col == a.position[1]:
                return False

        return True

    def _reached_by_agent(self, row, col):
        """True when any placed player has (row, col) within ``reach_range``."""
        return any(
            a.position and self._within_reach(a.position, (row, col))
            for a in self.players
        )

    def spawn_players(self):
        """
        Place every player on a random empty cell (up to 1000 draws per player).

        A player that finds no empty cell keeps its previous position. Players
        not yet re-placed in this call still occupy their old cells for the
        emptiness check.
        """
        for player in self.players:
            attempts = 0

            while attempts < 1000:
                row = self.np_random.randint(0, self.rows)
                col = self.np_random.randint(0, self.cols)
                if self._is_empty_location(row, col):
                    player.setup(
                        (row, col),
                        self.field_size,
                    )
                    break
                attempts += 1

    def _is_valid_action(self, player, action):
        """True when ``action`` keeps ``player`` on the grid; NONE is always valid."""
        if action == Action.NONE:
            return True
        elif action == Action.NORTH:
            return (
                    player.position[0] > 0
            )
        elif action == Action.SOUTH:
            return (
                    player.position[0] < self.rows - 1
            )
        elif action == Action.WEST:
            return (
                    player.position[1] > 0
            )
        elif action == Action.EAST:
            return (
                    player.position[1] < self.cols - 1
            )

        # Should not get here!!!
        self.logger.error("Undefined action {} from {}".format(action, player.name))
        raise ValueError("Undefined action")

    def get_valid_actions(self):
        """Return every joint action (one valid action per player) as a list of tuples."""
        return list(product(*[self._valid_actions[player] for player in self.players]))

    def _within_sight(self, ego_position, other_position):
        """True when ``other_position`` lies in the square of half-width ``sight`` around ``ego_position``."""
        return (ego_position[0] - self.sight <= other_position[0] <= ego_position[0] + self.sight) and \
            (ego_position[1] - self.sight <= other_position[1] <= ego_position[1] + self.sight)

    def _within_reach(self, ego_position, land_position):
        """True when ``land_position`` lies in the square of half-width ``reach_range`` around ``ego_position``."""
        return (ego_position[0] - self.reach_range <= land_position[0] <= ego_position[0] + self.reach_range) and \
            (ego_position[1] - self.reach_range <= land_position[1] <= ego_position[1] + self.reach_range)

    def _loc_increasement(self, locs):
        """Shift every (row, col) by +1, so the unseen marker (-1, -1) becomes (0, 0)."""
        return [[loc[0] + 1, loc[1] + 1] for loc in locs]

    def _get_state_info(self):
        """Return player positions followed by landmark positions."""
        state_info = [p.position for p in self.players]
        state_info.extend(self.landmark_locs)
        return state_info

    def _reach_landmark(self, ego_position):
        """True when any landmark lies within ``reach_range`` of ``ego_position``."""
        return np.sum(self.neighborhood(ego_position[0], ego_position[1], distance=self.reach_range)) > 0

    def get_obs_agent(self, agent_id):
        """
        Flat observation of ``agent_id``: its own position, then the other
        agents' and all landmarks' positions, with entities outside ``sight``
        replaced by (-1, -1). All coordinates are shifted by +1.
        """
        state_info = self._get_state_info()
        ego = self.players[agent_id].position
        obs = [loc if self._within_sight(ego, loc) else [-1, -1] for loc in state_info[:agent_id] + state_info[agent_id+1:]]
        obs = [[ego[0], ego[1]]] + obs
        obs = self._loc_increasement(obs)
        return np.array(obs).flatten()

    def get_obs(self):
        return [self.get_obs_agent(agent_id) for agent_id in range(self.n_agents)]

    def get_obs_size(self):
        # one (row, col) pair per agent and per landmark
        return 2 * (self.n_agents + self.n_landmarks)

    def get_state(self):
        """Flat global state: all agent then landmark positions, shifted by +1."""
        state_info = self._get_state_info()
        state_info = self._loc_increasement(state_info)
        return np.array(state_info).flatten()

    def get_state_size(self):
        return 2 * (self.n_agents + self.n_landmarks)

    def get_total_actions(self):
        # 5 possible actions
        return 5

    def get_avail_actions(self):
        """Per-agent 0/1 availability masks over the 5 action ids."""
        return [self.get_avail_agent_actions(agent_id) for agent_id in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        valid_actions = self._valid_actions[self.players[agent_id]]
        avail_actions = [1 if Action(i) in valid_actions else 0 for i in range(self.n_actions)]
        return avail_actions

    def reset(self):
        """
        Start a new episode: clear the grid, place agents, then landmarks.

        Returns:
            ``(obs, state)`` as produced by ``get_obs`` and ``get_state``.
        """
        self.field = np.zeros(self.field_size, np.int32)
        # spawn players on the board
        self.spawn_players()
        # spawn the landmarks on the board
        self.landmark_locs = [None for _ in range(self.n_landmarks)]
        self.spawn_landmarks()

        self.current_step = 0
        self._score = 0
        self._game_over = False

        self._gen_valid_moves()

        return self.get_obs(), self.get_state()

    def step(self, actions):
        """
        Apply one joint action.

        Args:
            actions: sequence (or torch tensor) of integer action ids, one per
                agent. Ids that are invalid at an agent's position are logged
                and replaced by ``Action.NONE``.

        Returns:
            ``(reward, done, info)``.
            ``reward`` is 1 on success, else 0. ``done`` is true on success
            or once ``episode_limit`` steps are reached. ``info`` is
            ``{"battle_won": success}``.
        """
        if isinstance(actions, th.Tensor):
            actions = actions.cpu().numpy()

        self.current_step += 1

        # Coerce invalid actions to NONE (always valid), logging each one.
        checked_actions = []
        for player, a in zip(self.players, actions):
            action = Action(a)
            if action not in self._valid_actions[player]:
                self.logger.info(
                    "{}{} attempted invalid action {}.".format(
                        player.name, player.position, action
                    )
                )
                action = Action.NONE
            checked_actions.append(action)
        actions = checked_actions

        # Move players: if two or more players target the same cell, none of
        # them moves.
        # NOTE(review): only same-destination conflicts are resolved. Two
        # agents may swap cells, and an agent blocked by a collision can end
        # up sharing its cell with an agent that moved into it.
        collisions = defaultdict(list)
        for player, action in zip(self.players, actions):
            if action == Action.NONE:
                collisions[player.position].append(player)
            elif action == Action.NORTH:
                collisions[(player.position[0] - 1, player.position[1])].append(player)
            elif action == Action.SOUTH:
                collisions[(player.position[0] + 1, player.position[1])].append(player)
            elif action == Action.WEST:
                collisions[(player.position[0], player.position[1] - 1)].append(player)
            elif action == Action.EAST:
                collisions[(player.position[0], player.position[1] + 1)].append(player)

        for k, v in collisions.items():
            if len(v) > 1:  # make sure no more than one player arrives at a location
                continue
            v[0].position = k

        _succeed = self._succeed()
        reward = 1 if _succeed else 0
        self._game_over = done = (
            _succeed or self.episode_limit <= self.current_step
        )

        # update valid moves
        self._gen_valid_moves()

        return reward, done, {"battle_won": _succeed}

    def _succeed(self):
        """
        True when every agent reaches some landmark and ``matched`` accepts
        the landmark-vs-agent matrix (1 = agent reaches landmark, 10 = it does not).
        """
        _score = 0
        for p in self.players:
            if self._reach_landmark(p.position):
                _score += 1
        if _score != self.n_landmarks:
            return False
        matrix = [
            [10 for _ in range(self.n_agents)] for _ in range(self.n_landmarks)
        ]
        for i, landmark_loc in enumerate(self.landmark_locs):
            for j, p in enumerate(self.players):
                if self._within_reach(p.position, landmark_loc):
                    matrix[i][j] = 1
        _succeed = matched(matrix, self.n_agents)
        return _succeed

    def render(self):
        pass

    def close(self):
        pass

    def save_replay(self):
        pass

    def get_stats(self):
        return {}
