import numpy as np
import math
import pdb

class Gridworld(object):
    """Square ``length x length`` gridworld with four movement actions.

    The agent position is ``(x, y)`` with ``0 <= x, y < length``.  States are
    encoded as ``state = y * length + x`` (see :meth:`state_encoding`).  After
    every step the reward is the negative Manhattan distance from the new
    position to the goal cell ``(length - 1, length - 1)``; the episode
    terminates when the goal is reached.

    Actions:
        0 -- move right/east (+x)
        1 -- move up/north (+y)
        2 -- move left/west (-x)
        3 -- move down/south (-y)

    A move that would leave the grid, or any other action value, leaves the
    position unchanged.
    """

    def __init__(self, length):
        """Create a grid with ``length`` cells per side at a random position.

        Args:
            length: number of cells along each side of the grid.
        """
        self.length = length
        self.x = np.random.randint(length)
        self.y = np.random.randint(length)
        self.n_state = length * length
        self.n_action = 4
        # Zero-filled placeholders sized to the state/action counts;
        # presumably callers only read their shapes -- TODO confirm.
        self.observation_space = np.zeros((self.n_state,))
        self.action_space = np.zeros((self.n_action,))

    def reset(self):
        """Place the agent at ``(0, 0)`` and return ``(state, info)``.

        The start cell is fixed (not random): it is the corner farthest from
        the goal.  ``info`` is always an empty dict.
        """
        self.x = 0
        self.y = 0
        return self.state_encoding(), {}

    def state_encoding(self):
        """Return the integer state index of the current position."""
        return self.y * self.length + self.x

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, terminated, truncated, info)``.

        Args:
            action: one of 0 (right), 1 (up), 2 (left), 3 (down).

        Returns:
            The new encoded state; the reward, i.e. minus the Manhattan
            distance from the new position to the goal (0 on the goal,
            negative elsewhere); ``True`` once the goal is reached;
            ``False`` (never truncated); and an empty info dict.
        """
        last = self.length - 1
        if action == 0:  # right/east (+x)
            if self.x < last:
                self.x += 1
        elif action == 1:  # up/north (+y)
            if self.y < last:
                self.y += 1
        elif action == 2:  # left/west (-x)
            if self.x > 0:
                self.x -= 1
        elif action == 3:  # down/south (-y)
            if self.y > 0:
                self.y -= 1

        # TODO: make this a continuing task (e.g. reset the agent to (0, 0) on
        # reaching the goal instead of terminating). Otherwise the distance
        # function is not learned with respect to terminal states, which can
        # cause some confusion in the heatmaps.
        rew = -(np.abs(self.x - last) + np.abs(self.y - last))

        done = self.x == last and self.y == last
        return self.state_encoding(), rew, done, False, {}

    def state_decoding(self, state):
        """Invert :meth:`state_encoding`, mapping ``state`` back to ``(x, y)``."""
        x = state % self.length
        # Floor division avoids the float round-trip of int(state / length).
        y = int(state // self.length)
        return x, y

    def phi(self, s):
        """Return the feature of state ``s``: its x-coordinate."""
        x, _ = self.state_decoding(s)
        return x
