import logging
from collections import defaultdict
from enum import Enum
from itertools import product
from gym.utils import seeding
import numpy as np
import torch as th
try:
    from .utils import matched
except:
    from utils import matched

# import sys
# sys.path.append("/home/ubuntu/zhanglichao/chenfeng/MATTAR_work/transfer_marl/src")
from envs.multiagentenv import MultiAgentEnv


class Action(Enum):
    """Discrete per-agent actions; an agent's integer action index is the member value."""

    NONE = 0   # stay on the current cell
    NORTH = 1  # row index - 1
    SOUTH = 2  # row index + 1
    WEST = 3   # column index - 1
    EAST = 4   # column index + 1

class CellEntity(Enum):
    """Entity codes used when encoding grid observations."""

    EMPTY = 0     # nothing occupies the cell
    LANDMARK = 1  # the cell holds a landmark
    AGENT = 2     # the cell holds an agent

class Player:
    """A grid agent: its cell, the field shape it lives on, and an optional controller."""

    def __init__(self):
        # All state stays unset until setup() / set_controller() are called.
        self.controller = self.position = self.field_size = None
        self.current_step = None

    def setup(self, position, field_size):
        """Record the player's cell and the shape of the field it is placed on."""
        self.position, self.field_size = position, field_size

    def set_controller(self, controller):
        """Attach the object whose ``_step`` method chooses this player's actions."""
        self.controller = controller

    def step(self, obs):
        """Delegate action selection for ``obs`` to the attached controller."""
        controller = self.controller
        return controller._step(obs)

    @property
    def name(self):
        """The controller's name if a (truthy) controller is attached, else "Player"."""
        return self.controller.name if self.controller else "Player"

# Number of players in the environment.
N_AGENTS = 2
# Grid shape as [rows, cols].
FIELD_SIZE = [9, 9]
# Half-width, in cells, of each agent's square observation window.
SIGHT = 8
# Maximum number of steps per episode.
EPISODE_LIMIT = 50
# An agent reaches the landmark when it is at most this many cells away in
# both row and column (diagonal neighbours included).
REACH_RANGE = 1


# Task table: task id (as a string) -> {"target_position": [row, col]} of the landmark.
# Tasks "0"-"7" are the main tasks.
task_id2config = {
    str(task_id): {"target_position": target}
    for task_id, target in enumerate(
        [[0, 4], [1, 7], [4, 8], [7, 7], [8, 4], [7, 1], [4, 0], [1, 1]]
    )
}
# Test tasks "8"-"11".
test_task_id2config = {
    str(task_id): {"target_position": target}
    for task_id, target in enumerate([[0, 6], [2, 8], [6, 8], [8, 6]], start=8)
}
# Merge the test tasks into the main table so every task id resolves there.
task_id2config.update(test_task_id2config)


class TwoPlayerNavigationMPEEnv(MultiAgentEnv):
    """
    Two-player cooperative grid navigation task (MPE-style "navigation").

    Setup:
        Two agents start in the 3x3 block around the middle of a ``FIELD_SIZE``
        grid. A single landmark is placed at ``target_position``.

    Goal:
        Both agents must come within ``reach_range`` cells of the landmark in
        both row and column. This is a square window, so diagonal neighbours
        count.

    Reward and termination:
        ``step`` returns reward 1 once every agent is in reach, which also ends
        the episode, and 0 otherwise. The episode is also cut off after
        ``episode_limit`` steps.

    Agents never block each other and may share a cell.
    """

    # Movement actions (NONE excluded); kept for external users of the class.
    action_set = [Action.NORTH, Action.SOUTH, Action.WEST, Action.EAST]

    # (row, col) displacement applied by each action; NORTH moves towards row 0.
    _ACTION_DELTAS = {
        Action.NONE: (0, 0),
        Action.NORTH: (-1, 0),
        Action.SOUTH: (1, 0),
        Action.WEST: (0, -1),
        Action.EAST: (0, 1),
    }

    def __init__(
            self,
            target_position=None,
            sight=SIGHT,
            episode_limit=EPISODE_LIMIT,
            seed=None,
            default_task=False,
            task_id=None,
            **kwargs,
    ):
        """
        Args:
            target_position: [row, col] of the landmark. Overridden when
                ``default_task`` is True.
            sight: half-width of each agent's square observation window; other
                entities farther away are reported as unseen.
            episode_limit: maximum number of steps per episode.
            seed: seed for ``self.np_random`` (see ``seed``).
            default_task: if True, read ``target_position`` from
                ``task_id2config[str(task_id)]``.
            task_id: task key (e.g. 0-11), used only when ``default_task`` is True.
            **kwargs: accepted and ignored.

        Raises:
            KeyError: if ``default_task`` is True and ``task_id`` is unknown.
        """
        self.logger = logging.getLogger(__name__)
        self.seed(seed)

        self.n_agents = N_AGENTS
        self.n_landmarks = 1

        self.players = [Player() for _ in range(self.n_agents)]
        self.sight = sight

        # Grid of static entities only (the landmark). Players are tracked via
        # Player.position and are never written into this array.
        self.field = np.zeros(FIELD_SIZE, np.int32)

        self.landmark_locs = None
        self.LANDMARK = 1  # value written into self.field at the landmark cell

        self._game_over = None
        self._valid_actions = None
        self.episode_limit = episode_limit
        self.current_step = 0

        # An agent "reaches" the landmark when the landmark lies inside the
        # (2 * reach_range + 1)-wide square centred on the agent, so the agent
        # does not have to stand exactly on the target cell.
        self.reach_range = REACH_RANGE
        self.target_position = target_position

        self._score = 0

        if default_task:
            task_config = task_id2config[str(task_id)]
            self.target_position = task_config["target_position"]

    def seed(self, seed=None):
        """
        Re-create ``self.np_random`` from ``seed`` via gym's seeding helper.

        Returns:
            A one-element list with the seed actually used.
        """
        # NOTE(review): the spawn code calls np_random.randint / np_random.rand;
        # confirm the pinned gym version's seeding.np_random still provides them.
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    @property
    def field_size(self):
        """Shape of the grid as (rows, cols)."""
        return self.field.shape

    @property
    def rows(self):
        """Number of grid rows."""
        return self.field_size[0]

    @property
    def cols(self):
        """Number of grid columns."""
        return self.field_size[1]

    @property
    def game_over(self):
        """Whether the current episode has ended (None before the first reset)."""
        return self._game_over

    @property
    def n_actions(self):
        """Number of discrete actions per agent (the members of ``Action``)."""
        return 5

    def _gen_valid_moves(self):
        """
        Recompute ``self._valid_actions``: player -> list of Actions that keep
        that player inside the grid. Must be refreshed after every move.
        """
        self._valid_actions = {
            player: [
                action for action in Action if self._is_valid_action(player, action)
            ]
            for player in self.players
        }

    def neighborhood(self, row, col, distance=1, ignore_diag=False):
        """
        Return the part of ``self.field`` around (row, col), clipped to the grid.

        With ``ignore_diag=False`` this is the square sub-array of half-width
        ``distance``. With ``ignore_diag=True`` it is instead the scalar sum of
        the cells in the same row and column within ``distance``.
        """
        if not ignore_diag:
            return self.field[
                   max(row - distance, 0): min(row + distance + 1, self.rows),
                   max(col - distance, 0): min(col + distance + 1, self.cols),
                   ]

        # NOTE(review): the centre cell (row, col) is included in both slices
        # and therefore counted twice; confirm whether that is intended.
        return (
                self.field[
                max(row - distance, 0): min(row + distance + 1, self.rows), col
                ].sum()
                + self.field[
                  row, max(col - distance, 0): min(col + distance + 1, self.cols)
                  ].sum()
        )

    def adjacent_players(self, row, col):
        """Players exactly one cell away from (row, col) horizontally or vertically."""
        return [
            player
            for player in self.players
            if (abs(player.position[0] - row) == 1 and player.position[1] == col)
            or (abs(player.position[1] - col) == 1 and player.position[0] == row)
        ]

    def spawn_landmarks(self):
        """
        Place the single landmark at ``self.target_position``.

        Writes ``self.LANDMARK`` into ``self.field`` and records the cell in
        ``self.landmark_locs[0]``.

        Raises:
            ValueError: if no target position is configured, or it lies outside
                the grid. Negative indices would otherwise silently wrap around
                in numpy while ``landmark_locs`` kept the raw value.
        """
        if self.target_position is None:
            raise ValueError(
                "target_position is not set; pass target_position=... or "
                "default_task=True with a valid task_id"
            )
        row, col = self.target_position
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            raise ValueError(
                "target_position {} is outside the {}x{} field".format(
                    self.target_position, self.rows, self.cols
                )
            )
        # Internal invariant: reset() zeroes the field right before this call.
        assert self.field[row, col] == 0, "Not empty target position for landmark!"
        self.field[row, col] = self.LANDMARK
        self.landmark_locs[0] = (row, col)

    def _is_empty_location(self, row, col):
        """True if (row, col) holds no field entity and no player currently stands there."""
        if self.field[row, col] != 0:
            return False

        for a in self.players:
            if a.position and row == a.position[0] and col == a.position[1]:
                return False

        return True

    def _reached_by_agent(self, row, col):
        """True if (row, col) lies within ``reach_range`` (square window) of some player."""
        for a in self.players:
            if a.position and a.position[0] - self.reach_range <= row <= a.position[0] + self.reach_range and \
                    a.position[1] - self.reach_range <= col <= a.position[1] + self.reach_range:
                return True

        return False

    def spawn_players(self):
        """
        Place every player on a distinct random cell of the 3x3 block starting
        at (rows // 2 - 1, cols // 2 - 1), i.e. around the middle of the grid.

        Raises:
            RuntimeError: if a player cannot be placed within 1000 attempts.
        """
        row_start = self.rows // 2 - 1
        col_start = self.cols // 2 - 1

        # Forget positions from the previous episode; otherwise
        # _is_empty_location would treat those stale cells as occupied.
        for player in self.players:
            player.position = None

        for player in self.players:
            # Bounded rejection sampling (randint's upper bound is exclusive).
            for _ in range(1000):
                row = self.np_random.randint(row_start, row_start + 3)
                col = self.np_random.randint(col_start, col_start + 3)
                if self._is_empty_location(row, col):
                    player.setup((row, col), self.field_size)
                    break
            else:
                raise RuntimeError(
                    "Could not find an empty spawn cell for {}".format(player.name)
                )

    def spawn_players_around_target(self):
        """
        Place each player on a uniformly random cell within ``reach_range`` of
        the target, clipped to the grid.

        Players may overlap each other and the landmark cell.
        """
        target_row, target_col = self.target_position
        row_start, row_end = max(0, target_row - self.reach_range), min(self.rows - 1, target_row + self.reach_range)
        col_start, col_end = max(0, target_col - self.reach_range), min(self.cols - 1, target_col + self.reach_range)

        for player in self.players:
            row = self.np_random.randint(row_start, row_end + 1)
            col = self.np_random.randint(col_start, col_end + 1)
            player.setup(
                (row, col),
                self.field_size,
            )

    def _is_valid_action(self, player, action):
        """
        Return whether ``action`` keeps ``player`` inside the grid.

        Raises:
            ValueError: for an object that is not an ``Action`` member.
        """
        if action == Action.NONE:
            return True
        elif action == Action.NORTH:
            return player.position[0] > 0
        elif action == Action.SOUTH:
            return player.position[0] < self.rows - 1
        elif action == Action.WEST:
            return player.position[1] > 0
        elif action == Action.EAST:
            return player.position[1] < self.cols - 1

        # Should not get here!!!
        self.logger.error("Undefined action {} from {}".format(action, player.name))
        raise ValueError("Undefined action")

    def get_valid_actions(self):
        """Obtain all joint actions: the Cartesian product of each player's valid actions."""
        return list(product(*[self._valid_actions[player] for player in self.players]))

    def _within_sight(self, ego_position, other_position):
        """True if ``other_position`` is inside the square of half-width ``sight`` around ``ego_position``."""
        return (ego_position[0] - self.sight <= other_position[0] <= ego_position[0] + self.sight) and \
            (ego_position[1] - self.sight <= other_position[1] <= ego_position[1] + self.sight)

    def _within_reach(self, ego_position, land_position):
        """True if ``land_position`` is inside the square of half-width ``reach_range`` around ``ego_position``."""
        return (ego_position[0] - self.reach_range <= land_position[0] <= ego_position[0] + self.reach_range) and \
            (ego_position[1] - self.reach_range <= land_position[1] <= ego_position[1] + self.reach_range)

    def _loc_increasement(self, locs):
        """Shift each [row, col] by +1 so real cells are 1-based and [-1, -1] becomes [0, 0]."""
        return [[loc[0] + 1, loc[1] + 1] for loc in locs]

    def _get_state_info(self):
        """Raw 0-based positions: every player in index order, then every landmark."""
        state_info = [p.position for p in self.players]
        state_info.extend(self.landmark_locs)
        return state_info

    def _reach_landmark(self, ego_position):
        """
        True if any landmark lies within ``reach_range`` of ``ego_position``.

        The square window (diagonals included) is a relaxed success criterion:
        the agent does not need to stand on the landmark cell itself.
        """
        return np.sum(self.neighborhood(ego_position[0], ego_position[1], distance=self.reach_range)) > 0

    def get_obs_agent(self, agent_id):
        """
        Observation of agent ``agent_id`` as a flat array of 1-based coordinates.

        Layout: own [row, col], then every other player in index order, then
        each landmark. Entities outside the agent's sight window are encoded
        as [0, 0] (i.e. [-1, -1] before the +1 shift).
        """
        state_info = self._get_state_info()
        ego_position = self.players[agent_id].position
        others = state_info[:agent_id] + state_info[agent_id + 1:]
        obs = [
            loc if self._within_sight(ego_position, loc) else [-1, -1]
            for loc in others
        ]
        obs = [[ego_position[0], ego_position[1]]] + obs
        obs = self._loc_increasement(obs)
        return np.array(obs).flatten()

    def get_obs(self):
        """List of per-agent observations, indexed by agent id."""
        return [self.get_obs_agent(agent_id) for agent_id in range(self.n_agents)]

    def get_obs_size(self):
        """Length of one agent's observation: two coordinates per player and landmark."""
        return 2 * (self.n_agents + self.n_landmarks)

    def get_state(self):
        """Global state: all player then landmark positions, 1-based, flattened."""
        state_info = self._get_state_info()
        state_info = self._loc_increasement(state_info)
        return np.array(state_info).flatten()

    def get_state_size(self):
        """Length of the global state vector."""
        return 2 * (self.n_agents + self.n_landmarks)

    def get_total_actions(self):
        """Number of discrete actions per agent (five: NONE and four moves)."""
        return 5

    def get_avail_actions(self):
        """Per-agent 0/1 availability masks over the ``n_actions`` actions."""
        return [self.get_avail_agent_actions(agent_id) for agent_id in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """0/1 mask whose i-th entry is 1 iff ``Action(i)`` is currently valid for ``agent_id``."""
        valid_actions = self._valid_actions[self.players[agent_id]]
        return [1 if Action(i) in valid_actions else 0 for i in range(self.n_actions)]

    def reset(self, boost_sampling=False):
        """
        Start a new episode.

        Args:
            boost_sampling: if True, with probability 0.5 spawn the players
                already within reach of the target instead of near the middle.

        Returns:
            (observations, state) for the initial configuration.
        """
        self.field = np.zeros(self.field_size, np.int32)
        # Spawn players on the board.
        if boost_sampling and self.np_random.rand() > 0.5:
            self.spawn_players_around_target()
        else:
            self.spawn_players()
        # Spawn the landmark on the board.
        self.landmark_locs = [None for _ in range(self.n_landmarks)]
        self.spawn_landmarks()

        self.current_step = 0
        self._score = 0
        self._game_over = False

        self._gen_valid_moves()

        return self.get_obs(), self.get_state()

    def step(self, actions):
        """
        Advance the environment by one step.

        Args:
            actions: one integer action per agent (sequence, numpy array or
                torch tensor), interpreted as ``Action`` values. Actions that
                would leave the grid are replaced by ``Action.NONE``.

        Returns:
            (reward, done, info).
            reward: 1 when every agent is within reach of the landmark, else 0.
            done: True on success or once ``episode_limit`` steps have elapsed.
            info: ``{"battle_won": success}``.

        Raises:
            ValueError: if an action value is not a valid ``Action``.
        """
        if isinstance(actions, th.Tensor):
            actions = actions.cpu().numpy()

        self.current_step += 1

        # Coerce actions that are invalid for the player's current cell to NONE,
        # logging each replacement.
        chosen = []
        for player, raw_action in zip(self.players, actions):
            action = Action(raw_action)
            if action not in self._valid_actions[player]:
                self.logger.info(
                    "%s%s attempted invalid action %s.",
                    player.name, player.position, action,
                )
                action = Action.NONE
            chosen.append(action)

        # Move players. Colliding moves are NOT cancelled here: players are
        # allowed to overlap, so each (already in-bounds) move is applied as-is.
        for player, action in zip(self.players, chosen):
            d_row, d_col = self._ACTION_DELTAS[action]
            player.position = (player.position[0] + d_row, player.position[1] + d_col)

        _succeed = self._succeed()
        reward = 1 if _succeed else 0
        self._game_over = done = (
            _succeed or self.episode_limit <= self.current_step
        )

        # Update valid moves for the new positions.
        self._gen_valid_moves()

        return reward, done, {"battle_won": _succeed}

    def _succeed(self):
        """True when every player is within reach of a landmark."""
        return all(self._reach_landmark(p.position) for p in self.players)

    def _hard_succeed(self):
        """
        Stricter success test. The reach matrix (landmark x agent, 1 = in reach,
        10 = not) is passed to ``matched`` together with ``n_agents``.
        """
        # NOTE(review): this compares the number of agents in reach with
        # n_landmarks (1), so with two agents both in reach it returns False.
        # It looks inherited from a variant with n_agents == n_landmarks;
        # confirm before using. Not called anywhere in this class.
        _score = 0
        for p in self.players:
            if self._reach_landmark(p.position):
                _score += 1
        if _score != self.n_landmarks:
            return False
        matrix = [
            [10 for _ in range(self.n_agents)] for _ in range(self.n_landmarks)
        ]
        for i, landmark_loc in enumerate(self.landmark_locs):
            for j, p in enumerate(self.players):
                if self._within_reach(p.position, landmark_loc):
                    matrix[i][j] = 1
        _succeed = matched(matrix, self.n_agents)
        return _succeed

    def render(self):
        """Rendering is not implemented."""
        pass

    def close(self):
        """Nothing to release."""
        pass

    def save_replay(self):
        """Replays are not supported."""
        pass

    def get_stats(self):
        """No extra statistics are tracked."""
        return {}

if __name__ == "__main__":
    # Smoke run: one episode of task 2 with uniformly random actions, printing
    # observations, landmark locations, state, actions and reward each step.
    env = TwoPlayerNavigationMPEEnv(default_task=True, task_id=2)
    env.reset()
    while True:
        print(env.get_obs())
        print(env.landmark_locs)
        print(env.get_state())
        print(env._get_state_info())
        joint_action = [np.random.randint(5) for _ in range(2)]
        print(joint_action)
        reward, done, _ = env.step(joint_action)
        print("--->>> debug")
        print(reward)
        if done:
            break
