import numpy as np
import random 
import pickle as pk
from pathlib import Path
import os
import map_maker
import math
import matplotlib.pyplot as plt


class w_agent:
    """A bubble that drifts through the water world.

    Args:
        type: label stored as ``_type`` (``water_world`` compares it
            against ``'good'`` and ``'bad'``).
        range: ``(width, height)``; the start position is drawn uniformly
            from the integers in ``[0, width] x [0, height]``.
    """

    def __init__(self, type, range):
        width, height = range[0], range[1]
        self._type = type
        self._radius = 5
        start_x = random.randint(0, width)
        start_y = random.randint(0, height)
        self._loc = [start_x, start_y]
        # Each velocity component is +0.5, independently flipped to -0.5
        # with probability 1/2.
        speed = 0.5
        self._vx = -speed if random.random() < 0.5 else speed
        self._vy = -speed if random.random() < 0.5 else speed

class water_world:
    """Continuous 2-D water world.

    A controlled disc moves inside a ``_dimension`` tank while bubbles
    bounce off the walls. Touching a moving 'good' bubble stops it;
    touching a 'bad' bubble ends the episode.

    Args:
        types: mapping from bubble type (``'good'`` / ``'bad'``) to a
            sequence whose first element is the number of bubbles of that
            type.
    """

    def __init__(self, types):
        self._types = types
        self._visited = []
        self._max_vel = 0.5
        self._radius = 5
        self._vx = 0
        self._vy = 0
        self._action_space = ['up','down','left','right']
        self._dimension = (100,100)
        self._agents = self.create_agents(types)
        self._state_size = 1
        self._current_loc = [self._dimension[0]/2, self._dimension[1]/2]
        self._action_size = len(self._action_space)
        # get_states() emits 4 values for the controlled agent followed by
        # 5 values (type, vx, vy, x, y) per bubble.
        self._n_state_variables = 4 + 5 * len(self._agents)
        self._state_ranges = self.get_ranges()
        self._locations = []

    def get_ranges(self):
        """Return one ``(low, high)`` range per entry of ``get_states()``."""
        x_range = (0, self._dimension[0] + 1)
        y_range = (0, self._dimension[1] + 1)
        vel_bound = math.ceil(self._max_vel) + 1
        vel_range = (-vel_bound, vel_bound)
        type_range = (0, 2)  # bubble type encoded as 0 (bad) / 1 (good)
        ranges = [vel_range, vel_range, x_range, y_range]
        for _ in self._agents:
            ranges.extend([type_range, vel_range, vel_range, x_range, y_range])
        return ranges

    def create_agents (self, types):
        """Create ``types[t][0]`` bubbles of every type ``t``."""
        agents = []
        for bubble_type, spec in types.items():
            count = spec[0]
            # NOTE(review): spec[1] (presumably a radius) is not applied;
            # every bubble keeps w_agent's default radius.
            for _ in range(count):
                agents.append(w_agent(bubble_type, self._dimension))
        return agents

    def separate_agents(self):
        """Split the bubbles into ``(good, bad)`` lists by their ``_type``."""
        good = [b for b in self._agents if b._type == "good"]
        bad = [b for b in self._agents if b._type == "bad"]
        return good, bad

    def in_bound_x (self, x):
        """True if ``0 <= x <= width``."""
        return 0 <= x <= self._dimension[0]

    def in_bound_y (self, y):
        """True if ``0 <= y <= height``."""
        return 0 <= y <= self._dimension[1]

    def move_other_agents(self, agent):
        """Advance one bubble by its velocity, bouncing off the tank walls.

        A coordinate that leaves ``[0, dimension]`` is reflected back inside
        and the matching velocity component changes sign.
        """
        x = agent._vx + agent._loc[0]
        y = agent._vy + agent._loc[1]
        if not self.in_bound_x(x):
            agent._vx *= -1
            if x > 0:
                b = self._dimension[0]
                x = b - (x - b)
            else:
                x = abs(x)
        if not self.in_bound_y(y):
            agent._vy *= -1
            if y > 0:
                b = self._dimension[1]
                y = b - (y - b)
            else:
                y = abs(y)
        agent._loc[0], agent._loc[1] = x, y

    def move_all_agents(self):
        """Advance every bubble by one time step."""
        for bubble in self._agents:
            self.move_other_agents(bubble)

    def move_agent(self, action):
        """Set the controlled agent's velocity from ``action`` and move it.

        'up'/'down' move along +y/-y and 'right'/'left' along +x/-x at
        ``_max_vel``. An unknown action keeps the previous velocity. The
        position is clamped to the tank.
        """
        velocities = {
            'up': (0, self._max_vel),
            'down': (0, -self._max_vel),
            'right': (self._max_vel, 0),
            'left': (-self._max_vel, 0),
        }
        if action in velocities:
            self._vx, self._vy = velocities[action]
        x = min(max(self._vx + self._current_loc[0], 0), self._dimension[0])
        y = min(max(self._vy + self._current_loc[1], 0), self._dimension[1])
        self._current_loc[0], self._current_loc[1] = x, y

    def does_collide(self, agent):
        """True if ``agent`` overlaps the controlled disc."""
        dist = math.hypot(agent._loc[0] - self._current_loc[0],
                          agent._loc[1] - self._current_loc[1])
        return dist < agent._radius + self._radius

    def get_states(self):
        """Return the flat observation.

        The layout is ``[vx, vy, x, y]`` followed by ``[type, vx, vy, x, y]``
        for each bubble, where type is 1 for 'good' and 0 otherwise.
        """
        states = [self._vx, self._vy, self._current_loc[0], self._current_loc[1]]
        for bubble in self._agents:
            type_code = 1 if bubble._type == 'good' else 0
            states.extend([type_code, bubble._vx, bubble._vy,
                           bubble._loc[0], bubble._loc[1]])
        return states

    @staticmethod
    def _is_moving(bubble):
        """True if either velocity component of ``bubble`` is non-zero."""
        return bubble._vx != 0 or bubble._vy != 0

    def no_moving_green_bubble(self, green_bubbles):
        """True if none of ``green_bubbles`` is still moving (vacuously for [])."""
        return not any(self._is_moving(b) for b in green_bubbles)

    def step (self, action_index_input):
        """Apply action ``action_index_input`` and advance all bubbles.

        Returns:
            ``(state, reward, done, success)``. The reward is 1000 (done,
            success) when no good bubble is still moving, -1000 (done) when
            touching a bad bubble, and 10 when touching and stopping a moving
            good bubble. Otherwise it is -1.
        """
        action = self._action_space[action_index_input]
        self.move_agent(action)
        self.move_all_agents()
        good_bubbles, bad_bubbles = self.separate_agents()

        if self.no_moving_green_bubble(good_bubbles):
            return self.get_states(), 1000, True, True

        if any(self.does_collide(b) for b in bad_bubbles):
            return self.get_states(), -1000, True, False

        for bubble in good_bubbles:
            if self.does_collide(bubble) and self._is_moving(bubble):
                bubble._vx, bubble._vy = 0, 0
                return self.get_states(), 10, False, False

        return self.get_states(), -1, False, False

    def reset (self):
        """Re-centre the agent, zero its velocity, respawn the bubbles."""
        self._current_loc = [self._dimension[0]/2, self._dimension[1]/2]
        self._vx, self._vy = 0, 0
        self._agents = self.create_agents(self._types)
        return self.get_states()



class Simple_Grid:
    """Stochastic grid-world maze loaded with ``map_maker.get_map``.

    Locations are ``[row, col]``. A maze cell equal to ``1`` blocks movement
    and a cell equal to ``-1`` is a pitfall. With probability
    ``1 - _stoch_prob`` the chosen action slips to one of its two
    perpendicular actions.
    """

    # Row/column offsets for each movement action.
    _MOVES = {'up': (-1, 0), 'down': (1, 0), 'left': (0, -1), 'right': (0, 1)}

    def __init__(self, env_name, start, goal):
        self._visited = []
        self._action_space = ['up','down','left','right']
        self._maze = map_maker.get_map(env_name)
        self._dimension = self._maze.shape
        self._state_size = self._dimension[0] * self._dimension[1]
        self._start = start
        self._goal = goal
        self._current_loc = self._start
        self._action_size = len(self._action_space)
        # Perpendicular actions that an intended action may slip into.
        self._action_probs = {0:[2,3], 1:[2,3], 2:[0,1], 3:[0,1]}
        self._stoch_prob = 0.8
        self._visit_map = np.zeros_like(self._maze)
        self._n_state_variables = 2
        self._state_ranges = [(0, self._dimension[0]),  # robot y (row)
                              (0, self._dimension[1])]  # robot x (col)
        self._locations = []

    def reset_visited (self):
        """Clear the per-cell visit counters."""
        self._visit_map = np.zeros_like(self._maze)

    def action_stochastic (self, action_index):
        """Return ``action_index``, or with prob. ``1 - _stoch_prob`` a perpendicular one."""
        if random.uniform(0, 1) > self._stoch_prob:
            first, second = self._action_probs[action_index]
            return first if random.uniform(0, 1) > 0.5 else second
        return action_index

    @staticmethod
    def _same_loc(p, q):
        """Compare two locations regardless of list/tuple type."""
        return list(p) == list(q)

    def _shift(self, loc, action):
        """Return the cell reached from ``loc`` by ``action`` (unchecked)."""
        d_row, d_col = self._MOVES[action]
        return [loc[0] + d_row, loc[1] + d_col]

    def _mark_visit(self, loc):
        """Increment the visit counter of ``loc``."""
        self._visit_map[loc[0]][loc[1]] += 1

    def step (self, action_index_input):
        """Take one (possibly slipped) action from the current location.

        Goal and pitfall are detected on the location held *before* this
        move. Entering the goal therefore returns -1, and the following call
        returns the terminal reward 500. The pitfall is handled the same way
        and returns -1000. Bumping a blocked cell or the border costs -2; a
        normal move costs -1.

        Returns:
            ``(next_loc, reward, done, success)``
        """
        action = self.index_to_action(self.action_stochastic(action_index_input))
        current = self._current_loc
        next_loc = self._shift(current, action)

        if self._same_loc(current, self._goal):
            self._mark_visit(current)
            return current, 500, True, True
        if self._maze[current[0]][current[1]] == -1:
            self._mark_visit(current)
            return current, -1000, True, False
        if not self.in_bound(next_loc) or self._maze[next_loc[0]][next_loc[1]] == 1:
            self._mark_visit(current)
            return current, -2, False, False

        self._current_loc = next_loc
        self._mark_visit(next_loc)
        return next_loc, -1, False, False

    def reset (self):
        """Move the agent back to the start location and return it."""
        self._current_loc = self._start
        return self._start

    # checks if a location is within the env bounds
    def in_bound (self, loc):
        """True if ``loc`` is a valid ``[row, col]`` of the maze."""
        return (0 <= loc[0] < self._dimension[0]
                and 0 <= loc[1] < self._dimension[1])

    # action_index into action
    def index_to_action (self, action_index):
        """Return the action name for ``action_index``."""
        return self._action_space [action_index]

    # state to state_index
    def state_to_index (self, state):
        """Row-major index of ``[row, col]`` in ``range(self._state_size)``."""
        # Multiply by the number of columns; the row count only coincides
        # with it on square maps.
        return state[0]*self._dimension[1] + state[1]

    def update_visited (self, state):
        """Record ``state`` in ``self._visited`` if it is not there yet."""
        if state not in self._visited:
            self._visited.append(state)

    def transition (self, state, action_index_input):
        """Simulate one (possibly slipped) action from ``state`` (no side effects).

        Unlike ``step``, goal and pitfall are detected on the *next* cell.

        Returns:
            ``(next_loc, reward, done, success, pitfall)``; the reward is
            1000 at the goal, -1000 in a pitfall, -2 when blocked or out of
            bounds, and -1 otherwise.
        """
        action = self.index_to_action(self.action_stochastic(action_index_input))
        next_loc = self._shift(state, action)

        if not self.in_bound(next_loc):
            return state, -2, False, False, False
        if self._same_loc(next_loc, self._goal):
            return next_loc, 1000, True, True, False
        cell = self._maze[next_loc[0]][next_loc[1]]
        if cell == -1:
            return next_loc, -1000, True, False, True
        if cell == 1:
            return state, -2, False, False, False
        return next_loc, -1, False, False, False


class Taxi_Domain:
    """Stochastic taxi domain on a ``map_maker`` grid.

    A maze cell with value ``v >= 2`` becomes landmark ``v - 1`` in
    ``self._locations``. Key ``0`` stands for "inside the taxi". A state is
    ``[row, col, passenger_1, ..., passenger_n, drop_off_landmark]``, where
    each passenger entry is a landmark id, or ``0`` while riding in the taxi.
    """

    # Row/column offsets for the movement actions.
    _MOVES = {'up': (-1, 0), 'down': (1, 0), 'left': (0, -1), 'right': (0, 1)}

    def __init__(self, env_name, start, passenger_n):
        self._visited = []
        self._action_space = ['up','down','left','right', 'pickup', 'dropoff']
        self._action_size = len(self._action_space)
        self._maze = map_maker.get_map(env_name)
        self._dimension = self._maze.shape
        self._passenger_n = passenger_n
        self._locations = {0: []}
        for y in range(self._dimension[0]):
            for x in range(self._dimension[1]):
                if self._maze[y, x] >= 2:
                    self._locations[self._maze[y, x] - 1] = [y, x]
        pick_up_loc, drop_off_loc = self.choose_pickup_locations(self._passenger_n)
        n_loc = len(self._locations)
        self._state_size = int(self._dimension[0] * self._dimension[1]
                               * math.pow(n_loc, passenger_n) * (n_loc - 1))
        self._start = start
        self._current_loc = self._start
        # Perpendicular slips for moves; pickup/dropoff never slip.
        self._action_probs = {0:[2,3], 1:[2,3], 2:[0,1], 3:[0,1], 4:[4,4], 5:[5,5]}
        self._stoch_prob = 0.8
        self._visit_map = np.zeros_like(self._maze)
        self._passenger_loc = pick_up_loc
        self._drop_loc = drop_off_loc
        self._n_state_variables = 3 + len(self._passenger_loc)
        self._state_ranges = self.get_state_ranges(passenger_n)
        self._extra_info = [self._locations, drop_off_loc]
        self._taxi_capacity = 1

    def get_state_ranges (self, passenger_n):
        """Return ``(low, high)`` (high exclusive) for each state variable.

        Passenger entries span ``0 .. n_loc - 1`` and the drop-off landmark
        spans ``1 .. n_loc - 1``, where ``n_loc = len(self._locations)``.
        This equals the previous hard-coded bound of 5 for four-landmark maps.
        """
        n_loc = len(self._locations)
        ranges = [(0, self._dimension[0]), (0, self._dimension[1])]
        ranges += [(0, n_loc)] * passenger_n
        ranges.append((1, n_loc))
        return ranges

    def choose_pickup_locations (self, passenger_n):
        """Draw ``passenger_n`` distinct pickup landmarks and a different drop-off.

        Raises:
            ValueError: if there are not at least ``passenger_n + 1``
                landmarks; otherwise the sampling below could never finish.
        """
        n_landmarks = len(self._locations) - 1
        if passenger_n >= n_landmarks:
            raise ValueError(
                f"need more than {passenger_n} landmarks, map has {n_landmarks}")
        pickup_locs = []
        while len(pickup_locs) < passenger_n:
            candidate = random.randint(1, n_landmarks)
            if candidate not in pickup_locs:
                pickup_locs.append(candidate)
        drop_off_loc = random.randint(1, n_landmarks)
        while drop_off_loc in pickup_locs:
            drop_off_loc = random.randint(1, n_landmarks)
        return pickup_locs, drop_off_loc

    def reset_visited (self):
        """Clear the per-cell visit counters."""
        self._visit_map = np.zeros_like(self._maze)

    # state to state_index
    def state_to_index (self, state):
        """Map ``state`` to a unique index in ``range(self._state_size)``.

        Mixed-radix layout, least significant first: grid cell (row-major),
        then each passenger entry in base ``len(self._locations)``, then the
        1-based drop-off landmark shifted to start at 0.
        """
        n_rows, n_cols = self._dimension[0], self._dimension[1]
        n_loc = len(self._locations)
        index = state[0] * n_cols + state[1]
        stride = n_rows * n_cols
        for passenger in state[2:-1]:
            index += int(passenger) * stride
            stride *= n_loc
        index += (int(state[-1]) - 1) * stride
        return int(index)

    def action_stochastic (self, action_index):
        """Return ``action_index``, or with prob. ``1 - _stoch_prob`` a slipped one."""
        if random.uniform(0, 1) > self._stoch_prob:
            first, second = self._action_probs[action_index]
            return first if random.uniform(0, 1) > 0.5 else second
        return action_index

    def _is_at(self, landmark, loc):
        """True if ``loc`` (list or tuple) is the cell of ``landmark``."""
        return list(self._locations[landmark]) == list(loc)

    def _state(self, loc):
        """Build the state list for taxi location ``loc``."""
        return [loc[0], loc[1]] + self._passenger_loc + [self._drop_loc]

    def there_is_passenger (self, current_location):
        """True if an undelivered passenger waits at ``current_location``."""
        return any(l != self._drop_loc and self._is_at(l, current_location)
                   for l in self._passenger_loc)

    def update_passenger_location (self, current_location, action):
        """Apply a pickup or dropoff to ``self._passenger_loc``.

        On 'pickup', every waiting passenger at ``current_location`` is
        removed from the list and a ``0`` (in taxi) is appended. On
        'dropoff', every ``0`` becomes the drop-off landmark.
        """
        if action == 'pickup':
            # Iterate over a snapshot because the list is modified in the loop.
            for l in list(self._passenger_loc):
                if l != self._drop_loc and self._is_at(l, current_location):
                    self._passenger_loc.remove(l)
                    self._passenger_loc.append(0)

        if action == 'dropoff':
            self._passenger_loc = [self._drop_loc if l == 0 else l
                                   for l in self._passenger_loc]

    def step (self, action_index_input):
        """Take one (possibly slipped) action.

        Rewards: -1 per move or bump; 0 for a valid pickup; -100 for an
        invalid pickup/dropoff; 500 for a valid dropoff (done once every
        passenger is at the drop-off); -1000 (done) if the current cell is a
        pitfall.

        Returns:
            ``(state, reward, done, success)``
        """
        r_move = -1
        r_wrong_pickup, r_wrong_dropoff = -100, -100
        r_correct_dropoff, r_correct_pickup = 500, 0
        action = self.index_to_action(self.action_stochastic(action_index_input))
        current = self._current_loc
        if action in self._MOVES:
            d_row, d_col = self._MOVES[action]
            next_loc = [current[0] + d_row, current[1] + d_col]
        else:  # pickup / dropoff keep the taxi in place
            next_loc = current

        if not self.in_bound(next_loc):
            return self._state(current), r_move, False, False
        if self._maze[current[0]][current[1]] == -1:
            return self._state(next_loc), -1000, True, False
        if self._maze[next_loc[0]][next_loc[1]] == 1:
            return self._state(current), r_move, False, False

        if action == 'pickup':
            if self.there_is_passenger(current) and self.taxi_is_not_full():
                self.update_passenger_location(current, action)
                return self._state(current), r_correct_pickup, False, False
            return self._state(current), r_wrong_pickup, False, False

        if action == 'dropoff':
            if self.any_passenger_in_taxi() and self._is_at(self._drop_loc, current):
                self.update_passenger_location(current, action)
                done = self.all_passengers_at_destination()
                return self._state(current), r_correct_dropoff, done, done
            return self._state(current), r_wrong_dropoff, False, False

        self._current_loc = next_loc
        return self._state(next_loc), r_move, False, False

    def taxi_is_not_full(self):
        """True while fewer than ``_taxi_capacity`` passengers ride in the taxi."""
        return self._passenger_loc.count(0) < self._taxi_capacity

    def any_passenger_in_taxi(self):
        """True if at least one passenger is in the taxi (entry ``0``)."""
        return 0 in self._passenger_loc

    def all_passengers_at_destination (self):
        """True when every passenger entry equals the drop-off landmark."""
        return all(p == self._drop_loc for p in self._passenger_loc)

    def reset (self):
        """Resample passengers/drop-off, return the taxi to start; return state."""
        self._passenger_loc, self._drop_loc = self.choose_pickup_locations(self._passenger_n)
        self._current_loc = self._start
        return self._state(self._current_loc)

    # checks if a location is within the env bounds
    def in_bound (self, loc):
        """True if ``loc`` is a valid ``[row, col]`` of the maze."""
        return (0 <= loc[0] < self._dimension[0]
                and 0 <= loc[1] < self._dimension[1])

    # action_index into action
    def index_to_action (self, action_index):
        """Return the action name for ``action_index``."""
        return self._action_space [action_index]

    def update_visited (self, state):
        """Record ``state`` in ``self._visited`` if it is not there yet."""
        if state not in self._visited:
            self._visited.append(state)


class Office_Domain:
    """Office-world task: pick up coffee and mail, then reach the office.

    Locations are ``(row, col)`` tuples. Walls are stored as
    ``(row, col, action)`` triples in ``self.forbidden_transitions``; such an
    action leaves the agent in place.

    Raises:
        ValueError: if ``map_name`` is not a known map.
    """

    # Row/column offsets for each action, as applied by step().
    _MOVES = {'up': (-1, 0), 'down': (1, 0), 'left': (0, -1), 'right': (0, 1)}

    def __init__(self, map_name):
        self._visited = []
        self._action_space = ['up','down','left','right']
        self._action_size = len(self._action_space)

        if map_name == "office_36x36_map1":
            self.load_36x36_map()
        elif map_name == "office_45x45_map1":
            self.load_45x45_map()
        else:
            raise ValueError(f"unknown office map: {map_name!r}")
        self._coffee_locs = self._init_coffee_locs.copy()
        self._mail_locs = self._init_mail_locs.copy()
        self._has_coffee = 0
        self._has_mail = 0
        # x4 for every combination of (has_coffee, has_mail).
        self._state_size = self._dimension[0] * self._dimension[1] * 4
        self._current_loc = self._start
        self._action_probs = {0:[2,3], 1:[2,3], 2:[0,1], 3:[0,1]}
        self._stoch_prob = 0.8
        self._visit_map = np.zeros_like(self._maze)
        self._n_state_variables = 4
        self._state_ranges = [
            (0, self._dimension[0]),  # row
            (0, self._dimension[1]),  # col
            (0, 2),                   # has_coffee
            (0, 2),                   # has_mail
        ]

    def _build_room_walls(self, size):
        """Fill ``self.forbidden_transitions`` for a ``size`` x ``size`` map.

        Walls sit on every coordinate that is 0 or 2 mod 3. A door pair is
        then opened at every wall position that is 1 mod 3 along the wall.

        NOTE(review): step() looks up ``(row, col, action)`` and moves
        'up'/'down' along the row, but here 'up'/'down' walls are keyed on
        the second coordinate and 'left'/'right' walls on the first.
        Confirm the intended axis convention before relying on the room
        layout (kept as-is to preserve existing behaviour).
        """
        up, down, left, right = self._action_space
        walls = set()
        for x in range(size):
            for y in range(0, size, 3):
                walls.add((x, y, down))
                walls.add((x, y + 2, up))
        for y in range(size):
            for x in range(0, size, 3):
                walls.add((x, y, left))
                walls.add((x + 2, y, right))
        # Doors between neighbouring rooms (the last room has no neighbour).
        for y in range(1, size, 3):
            for x in range(2, size - 3, 3):
                walls.remove((x, y, right))
                walls.remove((x + 1, y, left))
        for x in range(1, size, 3):
            for y in range(2, size - 3, 3):
                walls.remove((x, y, up))
                walls.remove((x, y + 1, down))
        self.forbidden_transitions = walls

    def load_45x45_map(self):
        """Configure the 45 x 45 map: start, items, office, rooms, walls."""
        self._maze = np.zeros((45, 45))
        self._dimension = self._maze.shape
        self._start = (2,1)
        self._init_coffee_locs = [(8,14)]
        self._init_mail_locs = [(18,15)]
        self._office_loc = (26,29)
        self._rooms = {'a': (1,1), 'b': (43,1), 'c':(43,43), 'd':(1,43)}
        self._build_room_walls(45)

    def load_36x36_map(self):
        """Configure the 36 x 36 map: start, items, office, rooms, walls."""
        self._maze = np.zeros((36, 36))
        self._dimension = self._maze.shape
        self._start = (2,1)
        self._init_coffee_locs = [(8,14)]
        self._init_mail_locs = [(11,8)]
        self._office_loc = (17,20)
        self._rooms = {'a': (1,1), 'b': (34,1), 'c':(34,34), 'd':(1,34)}
        self._build_room_walls(36)

    def reset_visited (self):
        """Clear the per-cell visit counters."""
        self._visit_map = np.zeros_like(self._maze)

    def update_visited (self, state):
        """Record ``state`` in ``self._visited`` if it is not there yet."""
        if state not in self._visited:
            self._visited.append(state)

    # state to state_index
    def state_to_index (self, state):
        """Map ``[row, col, has_coffee, has_mail]`` into ``range(self._state_size)``.

        The index is the row-major cell index times 4, plus
        ``2*has_coffee + has_mail``.
        """
        row, col = state[0], state[1]
        flags = 2 * bool(state[2]) + bool(state[3])
        return (row * self._dimension[1] + col) * 4 + flags

    def action_stochastic (self, action_index):
        """Return ``action_index``, or with prob. ``1 - _stoch_prob`` a perpendicular one."""
        if random.uniform(0, 1) > self._stoch_prob:
            first, second = self._action_probs[action_index]
            return first if random.uniform(0, 1) > 0.5 else second
        return action_index

    def step (self, action_index_input):
        """Take one (possibly slipped) action.

        Forbidden transitions and off-map moves leave the agent in place.
        Entering a coffee (or, failing that, mail) cell sets the matching
        flag. Reaching the office while holding both gives 1000 and ends the
        episode; every other step gives 0.

        Returns:
            ``(state, reward, done, success)``
        """
        action = self.index_to_action(self.action_stochastic(action_index_input))
        a, b = self._current_loc
        if (a, b, action) not in self.forbidden_transitions:
            d_row, d_col = self._MOVES[action]
            a, b = a + d_row, b + d_col
        next_loc = (a, b)

        if self.in_bound(next_loc):
            self._current_loc = next_loc
        else:
            next_loc = self._current_loc

        if self._has_coffee and self._has_mail and next_loc == self._office_loc:
            state = [next_loc[0], next_loc[1], self._has_coffee, self._has_mail]
            return state, 1000, True, True

        if not self._has_coffee and next_loc in self._coffee_locs:
            self._has_coffee = 1
        elif not self._has_mail and next_loc in self._mail_locs:
            self._has_mail = 1
        state = [next_loc[0], next_loc[1], self._has_coffee, self._has_mail]
        return state, 0, False, False

    def reset (self):
        """Return to the start with empty hands; return the initial state."""
        self._current_loc = self._start
        self._has_coffee = 0
        self._has_mail = 0
        return [self._current_loc[0], self._current_loc[1], self._has_coffee, self._has_mail]

    # checks if a location is within the env bounds
    def in_bound (self, loc):
        """True if ``loc`` is a valid ``(row, col)`` of the map."""
        return (0 <= loc[0] < self._dimension[0]
                and 0 <= loc[1] < self._dimension[1])

    # action_index into action
    def index_to_action (self, action_index):
        """Return the action name for ``action_index``."""
        return self._action_space [action_index]