import numpy as np
from collections import defaultdict

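# Legacy action table, written as [dx, dy] with y pointing UP. GridWorld
# below uses the opposite y direction (row 0 at the top) and the action
# indices 0: down, 1: right, 2: up, 3: left.
# action_space = [
#   [0,1],  # 'up'
#   [1,0],  # 'right'
#   [0,-1], # 'down'
#   [-1,0]  # 'left'
# ]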

class GridWorld:
    """Finite grid-world MDP with optional uniform transition noise.

    States are integers numbered row-major. State ``s`` sits at column
    ``x = s % grid_width`` and row ``y = s // grid_width``, with row 0 at
    the top and ``y`` growing downward::

                        x
                0   1   2   3   4
            0   0   1   2   3   4
        y   1   5   6   7   8   9
            2   10  11  12  13  14

    Actions::

        0 : down  (y + 1)
        1 : right (x + 1)
        2 : up    (y - 1)
        3 : left  (x - 1)

    A cell can be entered only when its ``grid`` value is 0. Allowed
    actions are computed for every cell, including cells whose own grid
    value is nonzero.
    """

    def __init__(self, grid, reward_map, epsilon = 0.0):
        """Build the model and precompute per-state actions / successors.

        Args:
            grid: 2-D array-like of cell codes; 0 marks an enterable cell.
            reward_map: 2-D array-like indexed as ``reward_map[y][x]``,
                giving the reward attached to each cell.
            epsilon: randomness in the transition. This probability mass is
                spread uniformly over the allowed actions of the current
                state; the remaining ``1 - epsilon`` goes to the chosen
                action (see ``state_transition_law``).
        """
        self.grid = np.array(grid)
        self.grid_height, self.grid_width = self.grid.shape
        self.size_of_state_space = self.grid_height * self.grid_width
        self.size_of_action_space = 4
        self.reward_map = np.array(reward_map)
        # Randomness in transition.
        self.epsilon = epsilon
        # allowed_actions[s]: ascending list of legal action indices in s.
        self.allowed_actions = self.enumerate_allowed_actions()
        # possible_next_states[s]: sorted successors of s over its legal actions.
        self.possible_next_states = self.enumerate_possible_next_state()

    def get_state_space_size(self):
        """Return the number of states (grid_height * grid_width)."""
        return self.size_of_state_space

    def stox(self, s):
        """Return the column index x of state ``s``."""
        return int(s % self.grid_width)

    def stoy(self, s):
        """Return the row index y of state ``s``."""
        return int(s // self.grid_width)

    def stoxy(self, s):
        """Return ``(x, y)`` for state ``s``."""
        return self.stox(s), self.stoy(s)

    def xytos(self, x, y):
        """Return the state index of cell ``(x, y)``."""
        return int(x + y * self.grid_width)

    def deterministic_transition(self, s, a):
        """Return the state reached from ``s`` by action ``a``, ignoring noise.

        No bounds or wall checks are made. Stepping past the left/right edge
        wraps into the adjacent row, and stepping past the top/bottom yields
        an index outside ``[0, size_of_state_space)``. Consult
        ``allowed_actions[s]`` to know which moves are legal.

        Raises:
            ValueError: if ``a`` is not one of 0, 1, 2, 3.
        """
        x, y = self.stoxy(s)
        if a == 0:    # down
            nx, ny = x, y + 1
        elif a == 1:  # right
            nx, ny = x + 1, y
        elif a == 2:  # up
            nx, ny = x, y - 1
        elif a == 3:  # left
            nx, ny = x - 1, y
        else:
            raise ValueError(f"invalid action {a!r}; expected 0, 1, 2 or 3")
        return self.xytos(nx, ny)

    def enumerate_allowed_actions(self):
        """Return, for every state, the actions that stay on the grid and
        land on a cell whose grid value is 0.

        Each inner list is in ascending action order.
        """
        allowed_actions = []
        for s in range(self.size_of_state_space):
            x, y = self.stoxy(s)
            # In-bounds test per action, listed in action-index order so the
            # resulting lists are sorted ascending.
            in_bounds = (
                (0, y < self.grid_height - 1),  # down
                (1, x < self.grid_width - 1),   # right
                (2, y > 0),                     # up
                (3, x > 0),                     # left
            )
            allowed_in_s = []
            for a, inside in in_bounds:
                if not inside:
                    continue
                nx, ny = self.stoxy(self.deterministic_transition(s, a))
                if self.grid[ny][nx] == 0:
                    allowed_in_s.append(a)
            allowed_actions.append(allowed_in_s)
        return allowed_actions

    def enumerate_possible_next_state(self):
        """Return, for every state, the sorted list of successor states
        reachable through its allowed actions."""
        return [
            sorted(
                self.deterministic_transition(s, a)
                for a in self.allowed_actions[s]
            )
            for s in range(self.size_of_state_space)
        ]

    def state_transition_law(self, s, a):
        """Return the next-state distribution ``P(ns | s, a)``.

        With probability ``1 - epsilon`` the agent performs ``a``. The
        remaining ``epsilon`` is split evenly over the allowed actions of
        ``s``. If ``a`` is not allowed in ``s`` (it leads off the grid or into
        a blocked cell), the intended move leaves the agent in ``s``. If
        ``s`` has no allowed actions at all, the agent stays in ``s`` with
        probability 1. The result therefore always sums to 1.

        Returns:
            ``defaultdict(int)`` mapping next-state index to probability;
            unlisted states have probability 0.
        """
        state_transition_probability = defaultdict(int)
        allowed = self.allowed_actions[s]
        if not allowed:
            # Nowhere to go: the only outcome is staying put.
            state_transition_probability[s] = 1.0
            return state_transition_probability

        noise = self.epsilon / len(allowed)
        for ns in self.possible_next_states[s]:
            state_transition_probability[ns] = noise

        # Decide the intended successor by legality of `a`, not by comparing
        # raw indices: deterministic_transition wraps at row edges, so an
        # illegal move could otherwise alias a legal neighbour.
        if a in allowed:
            intended = self.deterministic_transition(s, a)
        else:
            intended = s
        state_transition_probability[intended] += 1 - self.epsilon
        return state_transition_probability

    def reward_function(self, s, a, ns):
        """Return the reward for taking ``a`` in ``s`` and landing in ``ns``.

        The reward is read from the cell of the current state ``s``
        (``reward_map[y][x]``); ``a`` and ``ns`` are accepted for interface
        uniformity but are not used.
        """
        x, y = self.stoxy(s)
        return self.reward_map[y][x]
