import numpy as np


class CliffWalkingEnv:
    """Tabular 4x4 grid-world with slippery moves and a finite horizon.

    Grid states are numbered row-major (``s = 4 * row + col``)::

        0,  1,  2,  3
        4,  5,  6,  7
        8,  9,  10, 11
        12, 13, 14, 15

    State 16 is an extra absorbing terminal state. Episodes start in state 8.
    Any action taken in state 11 yields reward +5 and moves to the terminal
    state; any action taken in the bottom row (states 12-15) yields reward -1.
    Actions are 0=left, 1=right, 2=up, 3=down; a move off the grid leaves the
    agent in place.

    Attributes:
        N_S: number of states (16 grid cells + 1 terminal state).
        N_A: number of actions.
        H: episode horizon; ``step`` reports truncation once ``h >= H``.
        initial_state_dist: start-state distribution, shape ``(N_S,)``.
        terminal_state: index of the absorbing terminal state.
        prob: slip probability passed to the constructor.
        P: transition tensor ``P[s, a, s']``, shape ``(N_S, N_A, N_S)``.
        r: reward table ``r[s, a]``, shape ``(N_S, N_A)``.
    """

    # (row delta, column delta) for actions 0=left, 1=right, 2=up, 3=down.
    _MOVES = ((0, -1), (0, 1), (-1, 0), (1, 0))

    def __init__(self, prob=0.5):
        """Build the transition tensor ``P`` and the reward table ``r``.

        Args:
            prob: slip probability in ``[0, 1]``. With probability
                ``1 - prob`` the chosen action is executed; with probability
                ``prob`` one of the four actions is executed uniformly at
                random (possibly the chosen one).

        Raises:
            ValueError: if ``prob`` is outside ``[0, 1]``; such a value would
                put negative entries into ``P``.
        """
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"prob must be in [0, 1], got {prob!r}")
        self.N_S = 16 + 1
        self.N_A = 4
        self.H = 20  # time horizon

        self.initial_state_dist = np.zeros((self.N_S,))
        self.initial_state_dist[8] = 1
        self.terminal_state = 16
        self.prob = prob
        self.P = np.zeros((self.N_S, self.N_A, self.N_S))
        slip = self.prob / 4
        for s in range(self.terminal_state):
            x, y = self.ind2loc(s)
            for a, (dx, dy) in enumerate(self._MOVES):
                # Clip to the grid so moves off the edge keep the agent in place.
                s_ = self.loc2ind(np.clip(x + dx, 0, 3), np.clip(y + dy, 0, 3))
                # Slip: the outcome of action `a` is reached by every action
                # with probability prob / 4 ...
                self.P[s, :, s_] += slip
                # ... and by action `a` itself with extra probability 1 - prob.
                self.P[s, a, s_] += 1 - self.prob
        # The goal state 11 always moves to the terminal state, which is absorbing.
        self.P[11, :, :] = 0
        self.P[11, :, self.terminal_state] = 1
        self.P[self.terminal_state, :, :] = 0
        self.P[self.terminal_state, :, self.terminal_state] = 1
        # Rows already sum to 1 up to float rounding; renormalize to clean that up.
        self.P = self.P / self.P.sum(-1, keepdims=True)
        assert np.all(np.abs(self.P.sum(-1) - 1) < 1e-8), self.P
        self.r = np.zeros((self.N_S, self.N_A))  # r(s, a)
        self.r[11, :] = 5
        self.r[12:16, :] = -1

    def loc2ind(self, x, y):
        """Return the state index of grid cell (row ``x``, column ``y``)."""
        return 4 * x + y

    def ind2loc(self, s):
        """Return the (row, column) grid cell of state index ``s``."""
        return s // 4, s % 4

    def reset(self):
        """Start a new episode: zero the step counter and sample a start state.

        Returns:
            The sampled start state.
        """
        self.h = 0
        self.state = np.random.choice(self.N_S, p=self.initial_state_dist)
        return self.state

    def step(self, action):
        """Apply ``action`` in the current state and sample the next state.

        The reward is ``r[s, action]`` for the state ``s`` *before* the move.

        Returns:
            A tuple ``(next_state, reward, terminated, truncated)``, where
            ``terminated`` is true when ``next_state`` is the terminal state and
            ``truncated`` is true once the horizon ``H`` has been reached.
        """
        self.h += 1
        reward = self.r[self.state, action]
        self.state = np.random.choice(self.N_S, p=self.P[self.state, action])
        return self.state, reward, self.state == self.terminal_state, self.h >= self.H

    def step_multiple(self, action, sample_num):
        """Sample ``sample_num`` independent next states from (state, action).

        The step counter advances once, and the environment moves to the first
        sample, so ``sample_num`` must be at least 1.

        Returns:
            A tuple ``(states, rewards, terminated, truncated)``. ``states``,
            ``rewards`` and ``terminated`` are arrays of length ``sample_num``;
            ``truncated`` is a single flag for the horizon ``H``.
        """
        self.h += 1
        reward = np.repeat(self.r[self.state, action], sample_num)
        state = np.random.choice(self.N_S, size=sample_num, p=self.P[self.state, action])
        self.state = state[0]
        return state, reward, state == self.terminal_state, self.h >= self.H

    def set_state(self, state):
        """Overwrite the current state; the step counter ``h`` is unchanged."""
        self.state = state