import gym
from gym import spaces
import numpy as np
from collections import OrderedDict
from threading import Lock
import sys
# from matplotlib.colors import hsv_to_rgb
import random
import math
import copy
# from od_mstar3 import cpp_mstar
# from od_mstar3.col_set_addition import NoSolutionError, OutOfTimeError

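'''
    Observation: four field-of-view maps (agent positions, current agent's goal,
        goals of the other visible agents clamped into the field of view, obstacles)
        plus a [dx, dy, magnitude] vector pointing from the agent to its goal

    Action space: (Tuple)
        agent_id: positive integer
        action: {0:STILL, 1:MOVE_NORTH, 2:MOVE_EAST, 3:MOVE_SOUTH, 4:MOVE_WEST,
        5:NE, 6:SE, 7:SW, 8:NW}
    Reward: ACTION_COST for each executed move, IDLE_COST for staying still off the goal,
        COLLISION_REWARD for an invalid move, GOAL_REWARD when a robot arrives at its target
'''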
# Reward values used by MAPFEnv._step.
ACTION_COST = -0.3       # executed move that does not end on the agent's goal
IDLE_COST = -.5          # staying still while off the goal
GOAL_REWARD = 20.0       # move that ends on the agent's own goal
COLLISION_REWARD = -5.0  # move rejected (out of bounds, obstacle or other agent)
FINISH_REWARD = 20.0     # added to every agent when group_reward_at_end is set and all are done

# For each action index, the action index of the reverse move (STILL has none, hence -1).
opposite_actions = {
    0: -1,
    1: 3, 2: 4, 3: 1, 4: 2,
    5: 7, 6: 8, 7: 5, 8: 6,
}

JOINT = False  # True for joint estimation of rewards for closeby agents

# action index -> offset added to an agent's (i, j) array index
dirDict = {
    0: (0, 0),
    1: (0, 1),
    2: (1, 0),
    3: (0, -1),
    4: (-1, 0),
    5: (1, 1),
    6: (1, -1),
    7: (-1, -1),
    8: (-1, 1),
}
# inverse lookup: offset -> action index
actionDict = {direction: action for action, direction in dirDict.items()}

class State(object):
    '''
    State.
    Implemented as 2 2d numpy arrays of identical shape.
    first one "state":
        static obstacle: -1
        empty: 0
        agent = positive integer (agent_id)
    second one "goals":
        agent goal = positive int(agent_id)

    Both arrays are copied on construction. Positions are (i, j) index
    tuples into these arrays; agent ids start at 1.
    '''
    def __init__(self, world0, goals, diagonal, num_agents=1, dynamic=False):
        shape = world0.shape
        assert(len(shape) == 2 and shape == goals.shape)
        self.state = world0.copy()
        self.goals = goals.copy()
        self.num_agents = num_agents
        # current positions, previous positions and goal positions, indexed by agent_id-1
        self.agents, self.agents_past, self.agent_goals = self.scanForAgents()
        self.diagonal = diagonal
        self.dynamic = dynamic
        assert(len(self.agents) == num_agents)

    def scanForAgents(self):
        '''
        Walk both grids once and collect, per agent, its position and its goal.
        Returns (positions, previous positions, goals); the first two lists start out equal.
        Asserts that every agent id in 1..num_agents has both a position and a goal.
        '''
        missing = (-1, -1)
        positions = [missing] * self.num_agents
        previous = [missing] * self.num_agents
        targets = [missing] * self.num_agents
        for cell in np.ndindex(*self.state.shape):
            occupant = self.state[cell]
            if occupant > 0:
                positions[occupant - 1] = cell
                previous[occupant - 1] = cell
            owner = self.goals[cell]
            if owner > 0:
                targets[owner - 1] = cell
        assert(missing not in positions and missing not in targets)
        assert(positions == previous)
        return positions, previous, targets

    def getPos(self, agent_id):
        '''Current position of the agent.'''
        return self.agents[agent_id - 1]

    def getPastPos(self, agent_id):
        '''Position the agent held before its most recent moveAgent call that was not rejected.'''
        return self.agents_past[agent_id - 1]

    def getGoal(self, agent_id):
        '''Goal position of the agent.'''
        return self.agent_goals[agent_id - 1]

    def diagonalCollision(self, agent_id, newPos):
        '''diagonalCollision(id,(x,y)) returns true if agent with id "id" collided diagonally with
        any other agent in the state after moving to coordinates (x,y)
        agent_id: id of the desired agent to check for
        newPos: coord the agent is trying to move to (and checking for collisions)

        Two moves collide when the midpoints of their (from, to) segments coincide.
        '''
        assert(len(newPos) == 2)
        # up until now we haven't moved the agent, so getPos returns the "old" location
        lastPos = self.getPos(agent_id)
        own_mid = ((lastPos[0] + newPos[0]) / 2., (lastPos[1] + newPos[1]) / 2.)
        for other in range(1, self.num_agents + 1):
            if other == agent_id:
                continue
            before = self.getPastPos(other)
            now = self.getPos(other)
            other_mid = ((before[0] + now[0]) / 2., (before[1] + now[1]) / 2.)
            if np.isclose(other_mid[0], own_mid[0]) and np.isclose(other_mid[1], own_mid[1]):
                return True
        return False

    # try to move agent and return the status
    def moveAgent(self, direction, agent_id):
        '''
        Apply the offset `direction` to the agent if the target cell allows it.
        Returns the same status codes as act(): 2 left goal, 1 on goal, 0 moved,
        -1 out of bounds, -2 static obstacle, -3 other agent (incl. diagonal crossing).
        The grids are only modified when the move is accepted.
        '''
        slot = agent_id - 1
        ax, ay = self.agents[slot]

        # Not moving is always allowed
        if direction == (0, 0):
            self.agents_past[slot] = self.agents[slot]
            return 1 if self.goals[ax, ay] == agent_id else 0

        # Otherwise, let's look at the validity of the move
        dx, dy = direction[0], direction[1]
        nx, ny = ax + dx, ay + dy
        n_rows, n_cols = self.state.shape
        if not (0 <= nx < n_rows and 0 <= ny < n_cols):  # out of bounds
            return -1
        target = self.state[nx, ny]
        if target < 0:  # collide with static obstacle
            return -2
        if target > 0:  # collide with robot
            return -3
        # check for diagonal collisions
        if self.diagonal and self.diagonalCollision(agent_id, (nx, ny)):
            return -3

        # No collision: we can carry out the action
        self.state[ax, ay] = 0
        self.state[nx, ny] = agent_id
        self.agents_past[slot] = self.agents[slot]
        self.agents[slot] = (nx, ny)
        if self.goals[nx, ny] == agent_id:
            return 1
        if self.goals[ax, ay] == agent_id:
            return 2
        return 0

    # try to execute action and return whether action was executed or not and why
    # returns:
    #     2: action executed and left goal
    #     1: action executed and reached goal (or stayed on)
    #     0: action executed
    #    -1: out of bounds
    #    -2: collision with wall
    #    -3: collision with robot
    def act(self, action, agent_id):
        # 0     1  2  3  4
        # still N  E  S  W
        return self.moveAgent(self.getDir(action), agent_id)

    def getDir(self, action):
        '''Offset tuple for an action index.'''
        return dirDict[action]

    def getAction(self, direction):
        '''Action index for an offset tuple.'''
        return actionDict[direction]

    # Compare with a plan to determine job completion
    def done(self):
        '''True when every agent stands on the cell marked with its own id in the goals grid.'''
        return all(
            self.goals[pos[0], pos[1]] == agent_id
            for agent_id, pos in enumerate(self.agents, start=1)
        )


class MAPFEnv(gym.Env):
    """
    Encapsulates all the information and functions of the Multi-agent Path Finding Gym environment.
    """
    metadata = {"render.modes": ["human", "ansi"]}

    # Initialize env
    def __init__(self, num_agents=1, observation_size=10, world0=None, goals0=None,
                 DIAGONAL_MOVEMENT=False, SIZE=(10,40), PROB=(0,.5),
                 FULL_HELP=False, blank_world=False, group_reward_at_end=False, dynamic=False, render_gif=False):
        """
        Store the configuration, build the initial world and define the action space.

        Args:
            num_agents: number of agents placed in the world
            observation_size: side length (in cells) of the square window built by _observe
            world0: optional predefined world grid; a random one is generated when None
            goals0: goal grid matching world0 (required unless blank_world is set)
            DIAGONAL_MOVEMENT: if the agents are allowed to move diagonally
                (9 actions instead of 5)
            SIZE: (smallest, largest) side length a randomly generated square grid may take
            PROB: (lowest, highest) probability that a given block is an obstacle
                in a randomly generated grid
            FULL_HELP: stored on the instance; not otherwise used here
            blank_world: world0 holds obstacles only, agents and goals are randomized
            group_reward_at_end: stored flag read by _step when the episode completes
            dynamic: stored on the instance
            render_gif: accepted but not used by this constructor
        """
        # Plain configuration
        self.num_agents = num_agents
        self.observation_size = observation_size
        self.SIZE = SIZE
        self.PROB = PROB
        self.FULL_HELP = FULL_HELP
        self.DIAGONAL_MOVEMENT = DIAGONAL_MOVEMENT
        self.group_reward_at_end = group_reward_at_end
        self.dynamic = dynamic

        # Runtime bookkeeping
        self.individual_rewards = [0] * num_agents  # a way of doing joint rewards
        self.fresh = True
        self.finished = False
        self.mutex = Lock()

        # Initialize data structures
        self._setWorld(world0, goals0, blank_world=blank_world)
        n_actions = 9 if DIAGONAL_MOVEMENT else 5
        self.action_space = spaces.Tuple([spaces.Discrete(self.num_agents), spaces.Discrete(n_actions)])
        self.viewer = None

    def isConnected(self,world0):
        """
        Check whether all free cells of world0 form a single 4-connected region.

        The search starts at the first free cell (value 0) in row-major order and
        spreads through every cell that is not an obstacle (-1), so cells holding
        agents are traversable. The input array is not modified.

        Args:
            world0: 2D numpy grid using -1 for obstacles and 0 for free cells
        Returns:
            True if no free cell is left unreached (also when there is no free
            cell at all), False otherwise.
        """
        grid = world0.copy()
        free_cells = np.argwhere(grid == 0)
        if free_cells.size == 0:
            # nothing to connect: no free cell can be unreachable
            return True

        n_rows, n_cols = grid.shape[0], grid.shape[1]
        # Explicit stack instead of recursion: no dependence on (or change of)
        # the interpreter's recursion limit, whatever the grid size.
        pending = [(int(free_cells[0][0]), int(free_cells[0][1]))]
        while pending:
            i, j = pending.pop()
            if i < 0 or i >= n_rows or j < 0 or j >= n_cols:  # out of bounds
                continue
            if grid[i, j] == -1:  # obstacle or already visited
                continue
            grid[i, j] = -1  # mark as visited
            pending.extend(((i + 1, j), (i, j + 1), (i - 1, j), (i, j - 1)))

        return not np.any(grid == 0)

    def getObstacleMap(self):
        """Return an int array shaped like the world grid with 1 on obstacle cells (-1) and 0 elsewhere."""
        obstacle_mask = np.equal(self.world.state, -1)
        return obstacle_mask.astype(int)

    def getGoals(self):
        """Return the goal position of every agent, ordered by agent id (1..num_agents)."""
        return [self.world.getGoal(agent_id) for agent_id in range(1, self.num_agents + 1)]

    def getPositions(self):
        """Return the current position of every agent, ordered by agent id (1..num_agents)."""
        return [self.world.getPos(agent_id) for agent_id in range(1, self.num_agents + 1)]

    def _setWorld(self, world0=None, goals0=None,blank_world=False):
        """
        Reset the statistics counters and (re)build self.world.

        Three cases:
          * world0 and goals0 given: they are used as they are.
          * world0 given with blank_world=True: world0 only holds obstacles; agents
            and goals are randomized on a private copy (world0 itself is left untouched).
          * world0 is None: obstacles, agents and goals are all randomized, with the
            grid side drawn from self.SIZE and the obstacle density from self.PROB.

        Each goal is placed on a cell reachable (4-connected) from its agent's start.
        Goal cells are drawn with the stdlib `random` module while everything else
        uses `np.random`, so both generators must be seeded for reproducible worlds.

        Args:
            world0: optional 2D grid (-1 obstacle, 0 free, >0 agent id)
            goals0: optional 2D grid of goal ids matching world0
            blank_world: flag indicating that the world given has no agent or goal positions
        Raises:
            Exception: world0 was given without goals0 and blank_world is False.
            ValueError: the grid has fewer free cells than there are agents
                (the placement loop could otherwise never terminate).
        """
        self.total_move = 0
        self.collision_total = 0
        self.collision_agent = 0
        self.collision_static = 0

        def getConnectedRegion(world,regions_dict,x,y):
            '''returns the set of tiles connected to the given tile
            this is memoized with a dict (also filled in for every agent tile met on the way)'''
            if (x,y) in regions_dict:
                return regions_dict[(x,y)]
            visited=set()
            sx,sy=world.shape[0],world.shape[1]
            work_list=[(x,y)]
            while work_list:
                (i,j)=work_list.pop()
                if(i<0 or i>=sx or j<0 or j>=sy):#out of bounds
                    continue
                if(world[i,j]==-1):#obstacle
                    continue
                if (i,j) in visited:
                    continue
                if world[i,j]>0:
                    # another agent in the same region shares this (still growing) set
                    regions_dict[(i,j)]=visited
                visited.add((i,j))
                work_list.extend(((i+1,j),(i,j+1),(i-1,j),(i,j-1)))
            regions_dict[(x,y)]=visited
            return visited

        def placeAgents(world):
            '''writes agent ids 1..num_agents onto random free cells of world (in place)
            and returns their locations ordered by agent id'''
            free_cells = int(np.count_nonzero(world == 0))
            if free_cells < self.num_agents:
                raise ValueError("cannot place %d agents on a world with only %d free cells"
                                 % (self.num_agents, free_cells))
            agent_locations=[]
            agent_counter = 1
            while agent_counter<=self.num_agents:
                x,y = np.random.randint(0,world.shape[0]),np.random.randint(0,world.shape[1])
                if(world[x,y] == 0):
                    world[x,y]=agent_counter
                    agent_locations.append((x,y))
                    agent_counter += 1
            return agent_locations

        def placeGoals(world,agent_locations):
            '''returns a goal grid with one goal per agent, each on a not yet used
            tile of the region connected to that agent's start'''
            goals = np.zeros(world.shape).astype(int)
            agent_regions=dict()
            goal_counter = 1
            while goal_counter<=self.num_agents:
                agent_pos=agent_locations[goal_counter-1]
                valid_tiles=getConnectedRegion(world,agent_regions,agent_pos[0],agent_pos[1])
                x,y  = random.choice(list(valid_tiles))
                # retry until the drawn tile is not already somebody's goal
                if(goals[x,y]==0 and world[x,y]!=-1):
                    goals[x,y]    = goal_counter
                    goal_counter += 1
            return goals

        #defines the State object, which includes initializing goals and agents
        #sets the world to world0 and goals, or if they are None randomizes world
        if world0 is not None:
            if blank_world:
                # work on a copy so the caller's obstacle map is not overwritten with agent ids
                world = world0.copy()
                agent_locations = placeAgents(world)
                goals = placeGoals(world,agent_locations)
                self.initial_world = world
                self.initial_goals = goals
                self.world = State(self.initial_world,self.initial_goals,self.DIAGONAL_MOVEMENT,self.num_agents)
                return
            if goals0 is None:
                raise Exception("you gave a world with no goals!")
            self.initial_world = world0
            self.initial_goals = goals0
            self.world = State(world0,goals0,self.DIAGONAL_MOVEMENT,self.num_agents)
            return

        #otherwise we have to randomize the world
        #RANDOMIZE THE STATIC OBSTACLES
        prob=np.random.triangular(self.PROB[0],.33*self.PROB[0]+.66*self.PROB[1],self.PROB[1])
        size=np.random.choice([self.SIZE[0],self.SIZE[0]*.5+self.SIZE[1]*.5,self.SIZE[1]],p=[.5,.25,.25])
        world     = -(np.random.rand(int(size),int(size))<prob).astype(int)

        #RANDOMIZE THE POSITIONS AND GOALS OF AGENTS
        agent_locations = placeAgents(world)
        goals = placeGoals(world,agent_locations)
        self.initial_world = world
        self.initial_goals = goals
        self.world = State(world,goals,self.DIAGONAL_MOVEMENT,num_agents=self.num_agents)

    # Returns an observation of an agent
    def _observe(self,agent_id):
        """
        Build the local observation of one agent.

        The window is observation_size x observation_size cells and its top-left
        corner sits observation_size//2 cells above/left of the agent.

        Returns:
            ([poss_map, goal_map, goals_map, obs_map], [dx, dy, mag]) where
              poss_map  marks every agent inside the window (the observer included),
              goal_map  marks the observer's own goal if it lies inside the window,
              goals_map marks the goals of the other visible agents, clamped onto
                        the window border when they lie outside,
              obs_map   marks obstacles and everything outside the world,
              dx, dy    is the unit vector from the agent to its goal and mag the
                        Euclidean distance (dx = dy = 0 when the agent is on its goal).
        """
        assert(agent_id>0)
        size = self.observation_size
        agent_pos = self.world.getPos(agent_id)
        row0, col0 = agent_pos[0] - size//2, agent_pos[1] - size//2
        obs_shape = (size, size)
        goal_map  = np.zeros(obs_shape)
        poss_map  = np.zeros(obs_shape)
        goals_map = np.zeros(obs_shape)
        obs_map   = np.zeros(obs_shape)

        grid = self.world.state
        goal_grid = self.world.goals
        n_rows, n_cols = grid.shape[0], grid.shape[1]
        visible_agents = []
        for r in range(size):
            i = row0 + r
            for c in range(size):
                j = col0 + c
                if not (0 <= i < n_rows and 0 <= j < n_cols):
                    # out of bounds, just treat as an obstacle
                    obs_map[r, c] = 1
                    continue
                occupant = grid[i, j]
                if occupant == -1:
                    # obstacles
                    obs_map[r, c] = 1
                if occupant == agent_id:
                    # agent's position
                    poss_map[r, c] = 1
                if goal_grid[i, j] == agent_id:
                    # agent's goal
                    goal_map[r, c] = 1
                if occupant > 0 and occupant != agent_id:
                    # other agents' positions
                    visible_agents.append(occupant)
                    poss_map[r, c] = 1

        # project the goals of the visible agents onto the window
        last_row, last_col = row0 + size - 1, col0 + size - 1
        for other in visible_agents:
            goal_x, goal_y = self.world.getGoal(other)
            clamped_x = max(row0, min(last_row, goal_x))
            clamped_y = max(col0, min(last_col, goal_y))
            goals_map[clamped_x - row0, clamped_y - col0] = 1

        own_goal = self.world.getGoal(agent_id)
        dx = own_goal[0] - agent_pos[0]
        dy = own_goal[1] - agent_pos[1]
        mag = (dx**2+dy**2)**.5
        if mag != 0:
            dx = dx/mag
            dy = dy/mag
        return ([poss_map,goal_map,goals_map,obs_map],[dx,dy,mag])

    def get_goal_distance(self, agent_id):
        """Return the Manhattan distance between the agent's current position and its goal."""
        assert agent_id > 0
        goal_x, goal_y = self.world.getGoal(agent_id)
        pos_x, pos_y = self.world.getPos(agent_id)
        return abs(goal_x - pos_x) + abs(goal_y - pos_y)

    def get_all_edge_cells(self, agent_id):
        """
        List the world coordinates of the border cells of the agent's window.

        The window spans observation_size columns starting observation_size//2
        left of the agent; its rows are shifted down by one ("center alignment").
        Cells are emitted in pairs: first along the two border rows, then along
        the two border columns.

        NOTE(review): both loops stop one short of the far end, so the corner
        (x_max, y_max) is never produced while (x_min, y_min) appears twice -
        confirm whether that is intended.
        """
        half = self.observation_size // 2
        agent_pos = self.world.getPos(agent_id)

        y_min = agent_pos[1] - half
        y_max = y_min + self.observation_size - 1  # center alignment
        x_min = agent_pos[0] - half + 1  # center alignment
        x_max = agent_pos[0] - half + self.observation_size  # center alignment

        edge_cells = []
        for y in range(y_min, y_max):
            edge_cells += [(x_min, y), (x_max, y)]
        for x in range(x_min, x_max):
            edge_cells += [(x, y_min), (x, y_max)]
        return edge_cells

    @staticmethod
    def get_closest_edge_cell(edge_cells, goal_position):
        """
        Return the cell of edge_cells with the smallest Manhattan distance to
        goal_position (the first one on ties), or None when edge_cells is empty.
        """
        def manhattan(cell):
            return abs(goal_position[0] - cell[0]) + abs(goal_position[1] - cell[1])

        return min(edge_cells, key=manhattan, default=None)

    def get_goal_in_fov_format(self, agent_id, show_distance_inside_fov=False, cap_limit=None, debug=False):
        """
        Returns observations for agent_id with Goal information in Field-of-view (FOV).

        Goal info is represented in FOV by the distance from the agent position to its
        goal, as computed by get_goal_distance (Manhattan distance, obstacles ignored).

        When the goal lies outside the window built by _observe, this distance is
        written into goal_map on the edge cell closest to the goal. Whether the goal
        is outside is derived from observation_size, so the method works for any
        window size.

        Args:
            agent_id: id (>0) of the observing agent
            show_distance_inside_fov: also replace every goal_map entry equal to 1
                by the distance
            cap_limit: optional upper bound applied to the distance
            debug: print intermediate values
        Returns:
            The tuple produced by _observe, with its goal_map modified in place.
        """
        half_size = self.observation_size // 2
        agent_pos = self.world.getPos(agent_id)
        goal_pos = self.world.getGoal(agent_id)

        # origin used to turn edge-cell world coordinates into goal_map indices
        top_left = (agent_pos[0] - half_size + 1, # center alignment
                    agent_pos[1] - half_size)
        edge_cells = self.get_all_edge_cells(agent_id)

        shortest_distance_edge_cell = MAPFEnv.get_closest_edge_cell(edge_cells, goal_pos)
        edge_index = None
        if shortest_distance_edge_cell is not None:
            edge_index = (shortest_distance_edge_cell[0] - top_left[0],
                          shortest_distance_edge_cell[1] - top_left[1])

        if debug and agent_id in [1,2,3,4]:
            print(f"AGENT_ID: {agent_id}")
            print(f"edge_cells: {edge_cells}")
            print(f"shortest_dista: {shortest_distance_edge_cell}")
            print(f"Agent_pos: {agent_pos}")
            print(f"Goal_pos: {goal_pos}")

        observation = self._observe(agent_id)
        goal_map = observation[0][1]

        if debug:
            print(f"top left: {(top_left[0], top_left[1])}")

        # goal position relative to the agent
        goal_pos_fov = (goal_pos[0] - agent_pos[0], goal_pos[1] - agent_pos[1])

        distance = self.get_goal_distance(agent_id)
        if cap_limit is not None:
            distance = min(distance, cap_limit)

        # _observe covers the offsets -(size//2) .. size - size//2 - 1 around the agent
        min_offset = -half_size
        max_offset = self.observation_size - half_size - 1
        goal_in_fov = (min_offset <= goal_pos_fov[0] <= max_offset
                       and min_offset <= goal_pos_fov[1] <= max_offset)
        if not goal_in_fov and edge_index is not None:
            goal_map[edge_index[0], edge_index[1]] = distance

        if show_distance_inside_fov:
            # same truncating comparison as int(value) == 1, applied to the whole map
            goal_map[goal_map.astype(int) == 1] = distance

        if debug and agent_id == 1:
            print(f"AFTER Goal map: {goal_map}")
            print("----------------------------------------------")

        if debug and agent_id in [1,2,3,4]:
            print(f"pos_in_fov: {edge_index}")
            print(f"Distance to goal: {self.get_goal_distance(agent_id)}")
            print(f"obs_goal: {observation[0][1]}")

        return observation

    # Resets environment
    def _reset(self, agent_id,world0=None,goals0=None):
        """
        Rebuild the world and return the starting information for agent_id.

        Args:
            agent_id: agent the returned information refers to
            world0, goals0: optional predefined grids forwarded to _setWorld
        Returns:
            (result of self._listNextValidActions(agent_id),
             whether the agent already stands on its goal,
             False)
        """
        self.finished = False

        # Initialize data structures; the context manager releases the lock
        # even when _setWorld raises.
        with self.mutex:
            self._setWorld(world0,goals0)
            self.fresh = True

        self.viewer = None
        on_goal = self.world.getPos(agent_id) == self.world.getGoal(agent_id)
        # NOTE(review): _listNextValidActions is not defined in the part of the class
        # visible here - presumably supplied elsewhere (e.g. a subclass); confirm.
        #we assume you don't start blocking anyone (the probability of this happening is insanely low)
        return self._listNextValidActions(agent_id), on_goal,False

    def _complete(self):
        """Report whether the world considers the task done (delegates to self.world.done())."""
        world = self.world
        return world.done()

    # def astar(self,world,start,goal,robots=[]):
    #     '''robots is a list of robots to add to the world'''
    #     for (i,j) in robots:
    #         world[i,j]=1
    #     try:
    #         path=cpp_mstar.find_path(world,[start],[goal],1,5)
    #     except NoSolutionError:
    #         path=None
    #     for (i,j) in robots:
    #         world[i,j]=0
    #     return path


    # Executes an action by an agent
    def _step(self, action_input,episode=0):
        """
        Execute one action of one agent and return the outcome.

        Args:
            action_input: (agent_id, action) with agent_id in 1..num_agents and
                action in 0..4 (0..8 with diagonal movement)
            episode: optional variable meant for reward discounting; not used here
        Returns:
            (state, reward, done, on_goal, valid_action) where state is the agent's
            observation from _observe, reward the value earned by this action, done
            whether all agents are on their goals, on_goal whether this agent is on
            its goal and valid_action whether the action was actually executed.
        """
        self.fresh = False
        n_actions = 9 if self.DIAGONAL_MOVEMENT else 5

        # Check action input
        assert len(action_input) == 2, 'Action input should be a tuple with the form (agent_id, action)'
        assert action_input[1] in range(n_actions), 'Invalid action'
        assert action_input[0] in range(1, self.num_agents+1)

        # Parse action input
        agent_id = action_input[0]
        action   = action_input[1]

        # Lock mutex (race conditions start here); released on every exit path,
        # including exceptions raised while acting or observing
        with self.mutex:
            # Execute action & determine reward
            action_status = self.world.act(action,agent_id)
            valid_action=action_status >=0
            #     2: action executed and left goal
            #     1: action executed and reached/stayed on goal
            #     0: action executed
            #    -1: out of bounds
            #    -2: collision with wall
            #    -3: collision with robot

            if action==0:#staying still
                if action_status == 1:#stayed on goal
                    reward=0
                else:#stayed off goal
                    reward=IDLE_COST
            elif action_status == 1: # reached goal
                reward = GOAL_REWARD
            elif action_status in (-1, -2, -3): # collision
                reward = COLLISION_REWARD
                self.collision_total += 1
                if action_status == -2:
                    self.collision_static += 1
                elif action_status == -3:
                    self.collision_agent += 1
            else: # plain move, including leaving the goal
                reward=ACTION_COST

            # Perform observation
            state = self._observe(agent_id)

            # Done?
            done = self.world.done()
            self.finished |= done

            if self.group_reward_at_end and done:
                for i_agent in range(self.num_agents):
                    self.individual_rewards[i_agent]+=FINISH_REWARD
            else:
                self.individual_rewards[agent_id-1]=reward

            # on_goal estimation
            on_goal = self.world.getPos(agent_id) == self.world.getGoal(agent_id)

        return state, reward, done, on_goal, valid_action

if __name__ == '__main__':
    # Smoke test: build a small random environment with 8 agents.
    n_agents = 8
    env = MAPFEnv(
        num_agents=n_agents,
        DIAGONAL_MOVEMENT=False,
        SIZE=(10, 11),
        PROB=(.3, .5),
    )
