import numpy as np

class GridWorld:
    """Finite-horizon square grid world with random rewards and noisy actions.

    Cell ``(x, y)`` is identified with the integer state ``x * size + y``.
    Actions are integers in ``range(N_ACTIONS)``:

    * 0 -- ``y - 1``
    * 1 -- ``x + 1``
    * 2 -- ``y + 1``
    * 3 -- ``x - 1``

    Moves that would leave the grid are clamped to the border. On every
    :meth:`step`, with probability ``PERTURB_RATE`` the chosen action is
    replaced by one drawn from ``perturb_prob[state, action]``.
    :meth:`value_iteration` plans against exactly this noise model.

    Construction and :meth:`step` draw from NumPy's global RNG
    (``np.random``), so seed it beforehand for reproducible results.
    """

    N_ACTIONS = 4
    # (dx, dy) displacement for each action index, in the order documented above.
    _MOVES = ((0, -1), (1, 0), (0, 1), (-1, 0))
    # Probability that step() replaces the chosen action with a perturbed one.
    PERTURB_RATE = 0.5

    def __init__(self, size=None, horizon=None):
        """Build a randomly generated grid world.

        Args:
            size: Side length of the square grid (default 10).
            horizon: Number of steps per episode (default 10).
        """
        if size is None:
            size = 10
        if horizon is None:
            horizon = 10
        self.size = size
        self.state_space = list(range(self.size * self.size))
        # map[x, y] holds the state index of cell (x, y).
        self.map = np.array(self.state_space).reshape(self.size, self.size)
        self.horizon = horizon

        # Each cell independently receives a standard-normal reward with
        # probability 0.5, otherwise 0. The cell-by-cell draw order is kept
        # deliberately so that a given seed always yields the same map.
        self.reward_map = np.zeros((self.size, self.size))
        for x in range(self.size):
            for y in range(self.size):
                if np.random.rand() < 0.5:
                    self.reward_map[x][y] = np.random.randn()
        # Flattened view of reward_map: reward[x * size + y] == reward_map[x, y].
        self.reward = np.reshape(self.reward_map, (self.size * self.size,))

        # perturb_prob[s, a, :] is the distribution over replacement actions
        # used when action a is perturbed in state s; each row sums to 1.
        raw = np.random.uniform(
            low=0.1, high=0.4,
            size=(len(self.state_space), self.N_ACTIONS, self.N_ACTIONS),
        )
        self.perturb_prob = raw / raw.sum(axis=2, keepdims=True)

        # Sets state, x, y, time and done to their episode-start values.
        self.reset()

    def xy2state(self, x, y):
        """Return the state index of grid cell ``(x, y)``."""
        return int(x * self.size + y)

    def state2xy(self, state):
        """Return the ``(x, y)`` grid cell of ``state``."""
        x = int(state // self.size)
        y = int(state % self.size)
        return x, y

    def _check_action(self, action):
        """Raise ValueError unless ``action`` is one of ``range(N_ACTIONS)``."""
        if action not in range(self.N_ACTIONS):
            raise ValueError(
                f"action must be in range({self.N_ACTIONS}), got {action!r}"
            )

    def _move(self, x, y, action):
        """Return the cell reached from ``(x, y)`` by ``action``, clamped to the grid."""
        dx, dy = self._MOVES[action]
        last = self.size - 1
        return min(max(x + dx, 0), last), min(max(y + dy, 0), last)

    def reset(self, state_init=None):
        """Start a new episode.

        Args:
            state_init: Starting state index; defaults to ``size * size // 2``.

        Returns:
            The starting state.

        Raises:
            ValueError: If ``state_init`` is outside ``[0, size * size)``.
        """
        n_states = self.size * self.size
        if state_init is None:
            state_init = n_states // 2
        if not 0 <= state_init < n_states:
            raise ValueError(
                f"state_init must be in [0, {n_states}), got {state_init!r}"
            )
        self.state = state_init
        self.x, self.y = self.state2xy(self.state)
        self.time = 0
        self.done = False
        return self.state

    def step(self, action):
        """Advance the environment by one time step.

        The reward is that of the cell occupied *before* moving. With
        probability ``PERTURB_RATE`` the action is first replaced by a draw
        from ``perturb_prob[state, action]``. Once ``horizon`` steps have
        been taken the episode is done: further calls leave the state
        unchanged and return a reward of 0.

        Args:
            action: Action in ``range(N_ACTIONS)``.

        Returns:
            ``(state, reward)`` after the step.

        Raises:
            ValueError: If ``action`` is not a valid action.
        """
        self._check_action(action)
        if self.done:
            return self.state, 0.0

        reward = self.reward_map[self.x, self.y]

        if np.random.rand() < self.PERTURB_RATE:
            # Equivalent to choosing from [0, 1, 2, 3]; same RNG consumption.
            action = np.random.choice(
                self.N_ACTIONS, p=self.perturb_prob[self.state, action, :]
            )

        self.x, self.y = self._move(self.x, self.y, action)
        self.state = self.xy2state(self.x, self.y)

        self.time += 1
        if self.time >= self.horizon:
            self.done = True
        return self.state, reward

    def get_next_state(self, state, action, disturb):
        """Return the state reached from ``state`` by ``action``, deterministically.

        Args:
            state: Current state index.
            action: Action in ``range(N_ACTIONS)``.
            disturb: If 1, the action is replaced by the opposite direction,
                ``(action + 2) % N_ACTIONS``, before moving.

        Raises:
            ValueError: If the (possibly flipped) action is not valid.
        """
        if disturb == 1:
            action = (action + 2) % self.N_ACTIONS
        self._check_action(action)
        x, y = self.state2xy(state)
        return self.xy2state(*self._move(x, y, action))

    def value_iteration(self):
        """Compute the optimal finite-horizon Q-function by backward induction.

        Performs ``horizon`` Bellman backups starting from ``Q = 0`` under the
        same dynamics as :meth:`step`. With probability
        ``1 - PERTURB_RATE`` the chosen action is executed; otherwise an
        action drawn from ``perturb_prob[s, a]`` is executed instead.

        Returns:
            ``(opt_value, opt_policy, Q)`` where ``Q`` has shape
            ``(len(state_space), N_ACTIONS)``, ``opt_value = Q.max(axis=1)``
            and ``opt_policy = Q.argmax(axis=1)`` (ties go to the lowest
            action index).
        """
        n_states = len(self.state_space)
        # successor[s, a]: deterministic next state when action a is executed in s.
        # Computed once; it does not change across backups.
        successor = np.array(
            [[self.get_next_state(s, a, disturb=0) for a in range(self.N_ACTIONS)]
             for s in self.state_space],
            dtype=int,
        ).reshape(n_states, self.N_ACTIONS)

        keep = 1.0 - self.PERTURB_RATE
        Q = np.zeros((n_states, self.N_ACTIONS))
        for _ in range(self.horizon):
            # v_next[s, i]: optimal value-to-go after executing action i from s.
            v_next = Q.max(axis=1)[successor]
            # perturbed[s, a] = sum_i perturb_prob[s, a, i] * v_next[s, i]
            perturbed = (self.perturb_prob * v_next[:, None, :]).sum(axis=2)
            Q = self.reward[:, None] + keep * v_next + self.PERTURB_RATE * perturbed

        opt_value = np.max(Q, axis=1)
        opt_policy = np.argmax(Q, axis=1)
        return opt_value, opt_policy, Q


if __name__ == '__main__':
    # Fixed seed so the printed map, values and policy are reproducible.
    np.random.seed(1)
    env = GridWorld(size=10, horizon=10)
    value, policy, _ = env.value_iteration()
    # Derive the display shape from the environment instead of hard-coding it,
    # so changing `size` above does not break the reshapes below.
    grid_shape = (env.size, env.size)
    print('Reward Map:')
    print(np.round(env.reward_map, 2))
    print('Value Function:')
    print(np.round(value.reshape(grid_shape), 2))
    print('Policy:')
    print(policy.reshape(grid_shape))

