from gym_minigrid.minigrid import *
from numpy.core.fromnumeric import searchsorted
from functools import reduce
import operator

class DeepSea(MiniGridEnv):
    """
    Deep-sea-treasure style grid world with a two-component (vector) reward.

    The layout comes from ``sea_map``: row index = y, column index = x, each
    shifted by one cell to make room for the surrounding wall. Every cell with
    a positive value holds a ``Goal``; reaching it ends the episode and the
    cell's value is paid out as treasure.

    ``step`` returns ``reward`` as a two-element list ``[treasure, time_cost]``.
    """

    # Default treasure layout. Positive entries become Goal cells whose value
    # is the treasure reward. Non-positive cells are ordinary empty floor; their
    # value only matters if an episode ends while the agent stands on one.
    DEFAULT_SEA_MAP = (
        (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (0.7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (-10, 8.2, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        (-10, -10, 11.5, 0, 0, 0, 0, 0, 0, 0, 0),
        (-10, -10, -10, 14.0, 15.1, 16.1, 0, 0, 0, 0, 0),
        (-10, -10, -10, -10, -10, -10, 0, 0, 0, 0, 0),
        (-10, -10, -10, -10, -10, -10, 0, 0, 0, 0, 0),
        (-10, -10, -10, -10, -10, -10, 19.6, 20.3, 0, 0, 0),
        (-10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0),
        (-10, -10, -10, -10, -10, -10, -10, -10, 22.4, 0, 0),
        (-10, -10, -10, -10, -10, -10, -10, -10, -10, 23.7, 0),
    )

    def __init__(
        self,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        sea_map=None,
    ):
        """
        Create the environment.

        Args:
            agent_start_pos: (x, y) start cell of the agent, or None to call
                ``place_agent()`` instead.
            agent_start_dir: initial direction index of the agent.
            sea_map: optional 2-D array-like of cell values. Defaults to
                ``DEFAULT_SEA_MAP``. The grid is sized to this map plus a
                one-cell wall border on every side.

        Raises:
            ValueError: if ``sea_map`` is not two-dimensional.
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        # Must be set before super().__init__(), which presumably triggers
        # reset()/_gen_grid() and therefore reads self.sea_map — verify.
        if sea_map is None:
            sea_map = self.DEFAULT_SEA_MAP
        self.sea_map = np.array(sea_map, dtype=float)
        if self.sea_map.ndim != 2:
            raise ValueError("sea_map must be a 2-D array")

        height = self.sea_map.shape[0] + 2
        width = self.sea_map.shape[1] + 2
        super().__init__(
            height=height,
            width=width,
            max_steps=4 * width * height,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

        # Observations are the full (width, height, 3) grid encoding,
        # flattened into one vector (see observation()).
        img_size = self.width * self.height * 3
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(img_size,),
            dtype='float32',
        )

        # Restrict the agent to the first three actions: left, right, forward.
        self.action_space = spaces.Discrete(self.actions.forward + 1)

    def _gen_grid(self, width, height):
        """Build the walled grid and place one Goal on every positive cell."""
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Every treasure cell (positive value) becomes a Goal; +1 skips the wall.
        for y, row in enumerate(self.sea_map):
            for x, value in enumerate(row):
                if value > 0:
                    self.put_obj(Goal(), x + 1, y + 1)

        # Place the agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = ""

    def step(self, action):
        """
        Advance one step and return a vector reward.

        Returns:
            (obs, reward, done, info), where ``reward`` is
            ``[treasure, time_cost]``. On a terminal step, ``treasure`` is the
            sea_map value under the agent and ``time_cost`` is -1. Otherwise
            ``treasure`` is the base-class reward and ``time_cost`` is -0.1.

        Raises:
            ValueError: if ``action`` is outside ``[0, action_space.n)``.
        """
        if not 0 <= action < self.action_space.n:
            raise ValueError("Invalid action!")

        obs, reward, done, info = MiniGridEnv.step(self, action)
        obs = self.observation(obs)

        if done:
            x, y = self.agent_pos
            # NOTE(review): done is presumably also set when max_steps is
            # reached; the agent may then stand on a 0 or -10 cell, and that
            # value is paid out as "treasure" — confirm this is intended.
            return obs, [self.sea_map[y - 1, x - 1], -1], done, info

        return obs, [reward, -0.1], done, info

    def reset(self):
        """Reset the base environment and return the flattened observation."""
        obs = MiniGridEnv.reset(self)
        return self.observation(obs)

    def observation(self, obs):
        """
        Return a flat, fully observable encoding of the whole grid.

        The incoming ``obs`` is ignored. The grid encoding is re-built from
        ``self.grid``, and the agent's cell (addressed as [x, y]) is overwritten
        with (agent, red, direction). The result has length width * height * 3,
        as declared in ``observation_space``.
        """
        env = self.unwrapped
        full_grid = env.grid.encode()
        x, y = env.agent_pos
        full_grid[x][y] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir,
        ])
        # Cast to float32 so the dtype matches observation_space. The previous
        # `/ 1.` produced float64. The values are small ints, so the cast is exact.
        return full_grid.flatten().astype(np.float32)


