from gym import core, spaces
from gym.utils import seeding
import numpy as np
import matplotlib.pyplot as plt


# Side length of the square maze (rows == cols) for each mode; only mode 1
# has a layout implemented in GuardedMaze.
MAZE_SIZE = {1:8, 2:16}
# (d_row, d_col) displacement added to the current (row, col) for each action.
ACTION = {0: [-1, 0], 1: [0, 1], 2: [1, 0], 3: [0, -1]} # up, right, down, left

class GuardedMaze(core.Env):
    """Grid-world maze with an optional stochastic "guard" (kill-zone) cell.

    The agent starts at ``init_s`` and must reach ``goal``. Every ordinary
    move, and every bump into a wall, costs -1. Reaching the goal pays
    ``goal_reward`` and ends the episode. At each ``reset`` a guard is
    present with probability ``guard_prob``. While it is present, entering
    the kill-zone cell yields a random reward: -1 (p=0.2), +13 (p=0.4) or
    -15 (p=0.4).

    Observations are ``np.array([row, col])`` with the row axis mirrored
    (``rows - 1 - row``) relative to the internal map.

    Only ``mode=1`` (an 8x8 maze) is implemented.
    """

    def __init__(
        self,
        mode=1,
        max_steps=200,
        guard_prob=1.0,
        goal_reward=20.,
        stochastic_trans=False,
    ):
        self.mode = mode
        self.rows = MAZE_SIZE[self.mode]
        self.cols = MAZE_SIZE[self.mode]
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Tuple(
            [spaces.Discrete(self.rows), spaces.Discrete(self.cols)])

        self.guard_prob = guard_prob
        self.goal_reward = goal_reward
        self.stochastic_trans = stochastic_trans
        self.max_steps = max_steps

        # Not used by the reward logic in this class (the goal pays
        # goal_reward); kept so external code reading it keeps working.
        self.L = 2 * self.rows

        self.map = self._build_wall()
        if mode == 1:
            self.init_s = (6, 1)
            self.goal = (2, 6)
        else:
            raise NotImplementedError
        self.curr_s = self.init_s

        # Re-sampled at every reset().
        self.has_guard = None
        self.curr_cost = None

        # Create a default RNG so reset()/step() work even when seed() is
        # never called explicitly; a later seed(...) call replaces it.
        self.rng = None
        self.seed()

        # Episode tracking.
        self.tot_reward = 0
        self.nsteps = 0
        self.state_traj = [self.init_s]

    def seed(self, seed=None):
        """(Re)create the environment's random number generator.

        Args:
            seed: Seed passed to ``gym.utils.seeding.np_random``; ``None``
                draws a fresh seed.

        Returns:
            A one-element list with the seed actually used, following the
            gym ``Env.seed`` convention.
        """
        self.rng, np_seed = seeding.np_random(seed)
        return [np_seed]

    def _flip_rows(self, pos):
        """Return ``pos`` as an array with the row index mirrored."""
        return np.array([self.rows - 1 - pos[0], pos[1]])

    def reset(self):
        """Start a new episode: resample the guard and return to ``init_s``.

        Returns:
            The row-flipped start position as ``np.array([row, col])``.
        """
        self.has_guard = self.rng.random() < self.guard_prob
        self.curr_s = self.init_s
        self.tot_reward = 0
        self.nsteps = 0
        self.state_traj = [self.init_s]
        return self._flip_rows(self.init_s)

    def step(self, action: int):
        """Apply ``action`` (0=up, 1=right, 2=down, 3=left).

        Returns:
            ``(obs, reward, done, info)``. ``obs`` is the row-flipped new
            position. ``done`` is True when the goal is reached or after
            ``max_steps`` steps. ``info`` holds the running return ``'r'``,
            the step count ``'l'`` and ``'goal'`` (whether the goal was
            reached on this step).
        """
        # With stochastic transitions, the action is replaced by a uniformly
        # random one with probability 0.1.
        if self.stochastic_trans and self.rng.random() >= 0.9:
            action = self.rng.choice(4)

        new_pos = np.array(self.curr_s) + np.array(ACTION[action])
        r, next_pos, done = self._next_position(new_pos)
        self.tot_reward += r
        self.nsteps += 1
        # 'goal' is captured before the time limit can also set done.
        info = {'r': self.tot_reward, 'l': self.nsteps, 'goal': done}

        if self.nsteps >= self.max_steps:
            done = True

        self.curr_s = tuple(next_pos)
        self.state_traj.append(self.curr_s)

        return self._flip_rows(next_pos), r, done, info

    def _next_position(self, new_pos):
        """Resolve a proposed move.

        Args:
            new_pos: Proposed (row, col) after applying the action.

        Returns:
            ``(reward, next_pos, done)``. ``next_pos`` is where the agent
            actually ends up (it stays put when ``new_pos`` is a wall), and
            ``done`` is True iff ``new_pos`` is the goal.
        """
        done = tuple(new_pos) == self.goal
        cell = self.map[new_pos[0]][new_pos[1]]

        if done:
            return self.goal_reward, new_pos, done
        if cell == 1:
            # Bumping into a wall costs a step and leaves the agent in place.
            return -1., np.array(self.curr_s), done
        if self.has_guard and cell == -1:
            # Kill-zone reward: -1 w.p. 0.2, +13 w.p. 0.4, -15 w.p. 0.4.
            rnd = self.rng.random()
            if rnd < 0.2:
                r = -1
            elif rnd < 0.6:
                r = 13
            else:
                r = -15
            return r, new_pos, done
        return -1., new_pos, done

    def _build_wall(self):
        """Build the maze layout: 1 = wall, -1 = kill zone, 0 = free cell."""
        grid = np.zeros((self.rows, self.cols))
        # Outer walls.
        grid[0, :] = 1
        grid[:, 0] = 1
        grid[-1, :] = 1
        grid[:, -1] = 1

        if self.mode == 1:
            # Inner walls.
            grid[2, 2:6] = 1
            grid[2:6, 5] = 1
            # The kill zone is only placed when a guard can ever appear.
            if self.guard_prob:
                grid[6, 5] = -1
        else:
            raise NotImplementedError

        return grid

    def _get_im(self):
        """Render the maze as a channels-first float image of shape (3, rows, cols).

        Channel 0 marks the kill zone, channel 1 the goal and channel 2 the
        start cell. Walls add 0.7 to all three channels.
        """
        kill_zone = (self.map == -1).astype(float)

        goal = np.zeros((self.rows, self.cols))
        goal[self.goal] = 1

        init_s = np.zeros((self.rows, self.cols))
        init_s[self.init_s] = 1

        im = np.stack([kill_zone, goal, init_s], axis=0)
        walls = np.where(self.map == 1, 0.7, 0.0)
        im += np.stack(3 * [walls], axis=0)
        return im

    def show(self, ax=None, color_scale=0.7, show_traj=True, traj_col='w'):
        """Draw the maze, and optionally the latest trajectory, on ``ax``.

        Args:
            ax: Matplotlib axes to draw on; a new single-panel figure is
                created when None.
            color_scale: Multiplier applied to the rendered image.
            show_traj: Whether to overlay ``state_traj`` (internal,
                unflipped coordinates, matching the drawn map).
            traj_col: Matplotlib color code for the trajectory.

        Returns:
            The axes that were drawn on.
        """
        if ax is None:
            ax = Axes(1, 1, grid=False)[0]
        # (3, rows, cols) -> (rows, cols, 3) for imshow.
        im = np.moveaxis(self._get_im(), 0, -1)
        # Blank the start-cell channel (the "agent square").
        im[self.init_s[0], self.init_s[1], 2] = 0
        ax.imshow(color_scale * im)

        if show_traj:
            track = np.array(self.state_traj)
            ax.plot(track[:, 1], track[:, 0], f'{traj_col}.-')
            ax.plot(track[:1, 1], track[:1, 0], f'{traj_col}>', markersize=12)
            ax.plot(track[-1:, 1], track[-1:, 0], f'{traj_col}s', markersize=10)

        return ax

class Axes:
    """A grid of matplotlib subplots addressed by a single flat index.

    Args:
        N: Number of panels in use; any leftover grid cells are hidden.
        W: Number of columns in the grid (rows are derived from N and W).
        axsize: (width, height) of each panel, in figure-size units.
        grid: 1 -> light dotted grid, 2 -> default grid, otherwise no grid.
        fontsize: Default label font size used by ``labs``.
    """

    def __init__(self, N, W=2, axsize=(5,3.5), grid=1, fontsize=13):
        self.fontsize = fontsize
        self.N = N
        self.W = W
        self.H = int(np.ceil(N / W))
        figsize = (self.W * axsize[0], self.H * axsize[1])
        _, self.axs = plt.subplots(self.H, self.W, figsize=figsize)

        # Pick the grid style once, then apply it to every used panel.
        if grid == 1:
            grid_style = {'color': 'k', 'linestyle': ':', 'linewidth': 0.3}
        elif grid == 2:
            grid_style = {}
        else:
            grid_style = None
        if grid_style is not None:
            for idx in range(self.N):
                self[idx].grid(**grid_style)

        # Hide the cells of the grid that hold no panel.
        for idx in range(self.N, self.W * self.H):
            self[idx].axis('off')

    def __len__(self):
        return self.N

    def __getitem__(self, item):
        one_row = self.H == 1
        one_col = self.W == 1
        if one_row and one_col:
            # A 1x1 grid stores a single Axes object, not an array.
            return self.axs
        if one_row or one_col:
            return self.axs[item]
        row, col = divmod(item, self.W)
        return self.axs[row, col]

    def labs(self, item, *args, **kwargs):
        """Label panel ``item`` via ``labels``, defaulting to this grid's font size."""
        kwargs.setdefault('fontsize', self.fontsize)
        labels(self[item], *args, **kwargs)

def labels(ax, xlab=None, ylab=None, title=None, fontsize=12):
    """Set the x label, y label and title of ``ax``, skipping any that are None.

    Args:
        ax: Object with ``set_xlabel``/``set_ylabel``/``set_title`` methods
            (e.g. a matplotlib Axes).
        xlab: x-axis label text, or None to leave it unchanged.
        ylab: y-axis label text, or None to leave it unchanged.
        title: Title text, or None to leave it unchanged.
        fontsize: Either a single size applied to all three elements (int,
            float, numpy scalar, or a matplotlib size name such as
            ``'large'``), or an indexable ``(xlab, ylab, title)`` triple.
    """
    # Any scalar (ndim 0, which includes strings) is broadcast to all three
    # elements. Checking only for int made floats crash and made size names
    # be indexed character by character.
    if np.ndim(fontsize) == 0:
        fontsize = 3 * [fontsize]
    if xlab is not None:
        ax.set_xlabel(xlab, fontsize=fontsize[0])
    if ylab is not None:
        ax.set_ylabel(ylab, fontsize=fontsize[1])
    if title is not None:
        ax.set_title(title, fontsize=fontsize[2])

def test():
    """Run one random-action episode (env seeded with 11), print stats and plot it."""
    env = GuardedMaze()
    env.seed(11)
    env.reset()

    rewards, reached = [], []
    done = False
    while not done:
        _, reward, done, info = env.step(env.action_space.sample())
        rewards.append(reward)
        reached.append(info['goal'])

    print("reward", rewards)
    print('reach goal', reached)
    print("has guard", env.has_guard)
    env.show(show_traj=True)
    plt.show()

if __name__ == '__main__':
    # Script entry point: run the random-rollout demo.
    test()